Skip to content

Commit

Permalink
fix iteration time logger logging steps (#786)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #786

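Make sure all steps are logged with the right value: the `on_*_step_end` hooks now log against the previous step (`num_steps_completed - 1`), and new `on_*_end` hooks log against the final `num_steps_completed`.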

Reviewed By: anshulverma

Differential Revision: D56199868

fbshipit-source-id: 69c6088e75c9af79d91547b094e5d8a6f7c3cfaf
  • Loading branch information
galrotem authored and facebook-github-bot committed Apr 17, 2024
1 parent 6de95a5 commit c7095bf
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 4 deletions.
65 changes: 61 additions & 4 deletions tests/framework/callbacks/test_iteration_time_logger.py
Expand Up @@ -10,17 +10,22 @@
import unittest
from unittest.mock import call, MagicMock

import torch
from pyre_extensions import none_throws

from torch.utils.tensorboard import SummaryWriter
from torchtnt.framework._callback_handler import CallbackHandler
from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyEvalUnit,
DummyPredictUnit,
DummyTrainUnit,
generate_random_dataloader,
)
from torchtnt.framework.callbacks.iteration_time_logger import IterationTimeLogger

from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.framework.train import _train_impl, train
from torchtnt.utils.loggers.logger import MetricLogger


Expand Down Expand Up @@ -68,12 +73,12 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None:
call(
"Train Iteration Time (seconds)",
6.0, # the average of the last 4 numbers is 6
2, # after incrementing twice, step should be 2
1, # at on_train_step_end we report for step-1, we incremented twice so value should be 1
),
call(
"Prediction Iteration Time (seconds)",
16.0, # the average of the last 4 numbers is 16
2, # after incrementing twice, step should be 2
1, # at on_predict_step_end we report for step-1, we incremented twice so value should be 1
),
]
)
Expand All @@ -93,6 +98,58 @@ def test_with_train_epoch(self) -> None:
# 2 epochs, 6 iterations each, logging every third step
self.assertEqual(logger.log.call_count, 4)

def test_comparing_step_logging_time(self) -> None:
    """
    Check that the values IterationTimeLogger reports match the durations
    recorded by the train and eval iteration timers, step by step
    """

    unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
    mock_logger = MagicMock(spec=MetricLogger)
    callback = IterationTimeLogger(
        mock_logger, moving_avg_window=1, log_every_n_steps=1
    )
    loader = generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2)

    # The state is built by hand (instead of going through an entry point)
    # so that the timers stored on it can be inspected afterwards and
    # compared against what was logged. Calling _train_impl with this state
    # is similar to calling fit() and getting the state back.
    train_phase = PhaseState(
        dataloader=loader,
        max_epochs=2,
        max_steps_per_epoch=2,
    )
    eval_phase = PhaseState(
        dataloader=loader,
        max_steps_per_epoch=2,
        evaluate_every_n_epochs=1,
    )
    state = State(
        entry_point=EntryPoint.FIT,
        train_state=train_phase,
        eval_state=eval_phase,
    )

    _train_impl(state, unit, CallbackHandler([callback]))

    train_timer = none_throws(state.train_state).iteration_timer
    train_durations = train_timer.recorded_durations["train_iteration_time"]
    eval_timer = none_throws(state.eval_state).iteration_timer
    eval_durations = eval_timer.recorded_durations["eval_iteration_time"]

    # One expected call per step: the i-th recorded duration (0-based) is
    # expected to be logged against step i + 1.
    expected_calls = []
    for title, durations in (
        ("Train Iteration Time (seconds)", train_durations),
        ("Eval Iteration Time (seconds)", eval_durations),
    ):
        for step in range(1, 5):
            expected_calls.append(call(title, durations[step - 1], step))

    mock_logger.log.assert_has_calls(expected_calls, any_order=True)

def test_with_summary_writer(self) -> None:
"""
Test IterationTimeLogger callback with train entry point and SummaryWriter
Expand Down
28 changes: 28 additions & 0 deletions torchtnt/framework/callbacks/iteration_time_logger.py
Expand Up @@ -87,6 +87,15 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
self._log_step_metrics(
"train_iteration_time",
timer,
# on_train_step_end happens after the num steps is incremented, but before the timer list is populated,
# so it logs for step-1
unit.train_progress.num_steps_completed - 1,
)

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
    """Log the train iteration time metric when training ends.

    Forwards the train phase's iteration timer to ``_log_step_metrics``
    together with the unit's current number of completed train steps
    (no ``- 1`` adjustment is applied here).
    """
    train_timer = none_throws(state.train_state).iteration_timer
    final_step = unit.train_progress.num_steps_completed
    self._log_step_metrics("train_iteration_time", train_timer, final_step)

Expand All @@ -95,10 +104,29 @@ def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
self._log_step_metrics(
"eval_iteration_time",
timer,
# on_eval_step_end happens after the num steps is incremented, but before the timer list is populated,
# so it logs for step-1
unit.eval_progress.num_steps_completed - 1,
)

def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
    """Log the eval iteration time metric when evaluation ends.

    Forwards the eval phase's iteration timer to ``_log_step_metrics``
    together with the unit's current number of completed eval steps
    (no ``- 1`` adjustment is applied here).
    """
    eval_timer = none_throws(state.eval_state).iteration_timer
    final_step = unit.eval_progress.num_steps_completed
    self._log_step_metrics("eval_iteration_time", eval_timer, final_step)

def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
    """Log the predict iteration time metric at the end of a predict step.

    Forwards the predict phase's iteration timer to ``_log_step_metrics``
    with the step index set to one less than the completed-step count.
    """
    predict_timer = none_throws(state.predict_state).iteration_timer
    # This hook runs after the completed-step count is incremented but
    # before the timer list is populated, so the value being logged
    # belongs to the previous step.
    previous_step = unit.predict_progress.num_steps_completed - 1
    self._log_step_metrics("predict_iteration_time", predict_timer, previous_step)

def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
timer = none_throws(state.predict_state).iteration_timer
self._log_step_metrics(
"predict_iteration_time",
Expand Down

0 comments on commit c7095bf

Please sign in to comment.