Skip to content

Commit

Permalink
twfb logger (#790)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #790

Adding a time-wait-for-batch logger callback.

Reviewed By: JKSenthil

Differential Revision: D56315489

fbshipit-source-id: 5fa9210114231c3c7d97d4872252cb8bf659b2d7
  • Loading branch information
galrotem authored and facebook-github-bot committed Apr 19, 2024
1 parent 35f9d92 commit 2409e14
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/framework/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
SystemResourcesMonitor
TensorBoardParameterMonitor
TimeLimitInterrupter
TimeWaitForBatchLogger
IterationTimeLogger
TorchSnapshotSaver
TQDMProgressBar
Expand Down
128 changes: 128 additions & 0 deletions tests/framework/callbacks/test_time_wait_for_batch_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import ANY, call, MagicMock

import torch
from pyre_extensions import none_throws

from torch.utils.tensorboard import SummaryWriter
from torchtnt.framework._callback_handler import CallbackHandler
from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyPredictUnit,
generate_random_dataloader,
)
from torchtnt.framework.callbacks.time_wait_for_batch_logger import (
TimeWaitForBatchLogger,
)
from torchtnt.framework.predict import predict

from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.framework.train import _train_impl
from torchtnt.utils.loggers.logger import MetricLogger
from torchtnt.utils.timer import TimerProtocol


class TimeWaitForBatchLoggerTest(unittest.TestCase):
    """Unit tests for the ``TimeWaitForBatchLogger`` callback."""

    def test_log_step_metrics(self) -> None:
        """Nothing is logged off-cadence; the last wait time is logged on-cadence."""
        for spec in (MetricLogger, SummaryWriter):
            with self.subTest(spec=spec):
                mock_logger = MagicMock(spec=spec)
                # The callback writes through ``log`` for a MetricLogger and
                # through ``add_scalar`` for a SummaryWriter.
                if spec is MetricLogger:
                    record_fn = mock_logger.log
                else:
                    record_fn = mock_logger.add_scalar

                callback = TimeWaitForBatchLogger(
                    logger=mock_logger, log_every_n_steps=2
                )
                mock_timer = MagicMock(spec=TimerProtocol)
                mock_timer.recorded_durations = {"data_wait_time": [1, 2, 3]}

                # Step 1 is not a multiple of log_every_n_steps=2.
                callback._log_step_metrics(timer=mock_timer, label="foo", step=1)
                record_fn.assert_not_called()

                # Step 2 is on cadence: the last element of the data wait
                # time list (3) is logged against step 2.
                callback._log_step_metrics(timer=mock_timer, label="foo", step=2)
                record_fn.assert_has_calls([call("foo", 3, 2)])

    def test_comparing_twfb_logging_time(self) -> None:
        """Logged values match the durations recorded on the phase timers."""
        loader = generate_random_dataloader(
            num_samples=8, input_dim=2, batch_size=2
        )
        train_phase = PhaseState(
            dataloader=loader,
            max_epochs=2,
            max_steps_per_epoch=2,
        )
        eval_phase = PhaseState(
            dataloader=loader,
            max_steps_per_epoch=1,
            evaluate_every_n_epochs=1,
        )
        state = State(
            entry_point=EntryPoint.FIT,
            train_state=train_phase,
            eval_state=eval_phase,
        )

        mock_logger = MagicMock(spec=MetricLogger)
        callback = TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)
        # The logged values are compared against the timers stored on the
        # state, so the state is built by hand and handed to _train_impl
        # instead of going through fit() and using its returned state.
        _train_impl(
            state,
            DummyAutoUnit(module=torch.nn.Linear(2, 2)),
            CallbackHandler([callback]),
        )

        train_wait_times = none_throws(
            state.train_state
        ).iteration_timer.recorded_durations["data_wait_time"]
        eval_wait_times = none_throws(
            state.eval_state
        ).iteration_timer.recorded_durations["data_wait_time"]

        # Four train-step calls and two eval-step calls are expected, each
        # tagged with a 1-based step number.
        expected_train_calls = [
            call("Time Wait For Batch (Train)", train_wait_times[step - 1], step)
            for step in range(1, 5)
        ]
        expected_eval_calls = [
            call("Time Wait For Batch (Eval)", eval_wait_times[step - 1], step)
            for step in range(1, 3)
        ]

        mock_logger.log.assert_has_calls(
            expected_train_calls + expected_eval_calls,
            any_order=True,
        )

    def test_with_predict(self) -> None:
        """A single predict step produces one log call under the predict label."""
        mock_logger = MagicMock(spec=MetricLogger)
        callback = TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)
        predict(
            DummyPredictUnit(input_dim=2),
            generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2),
            max_steps_per_epoch=1,
            callbacks=[callback],
        )
        # The measured duration is not pinned (ANY); only label and step are.
        expected_call = call("Time Wait For Batch (Predict)", ANY, 1)
        mock_logger.log.assert_has_calls([expected_call])

    def test_invalid_log_every_n_steps(self) -> None:
        """Constructing the callback with ``log_every_n_steps=0`` raises ValueError."""
        mock_logger = MagicMock(spec=MetricLogger)
        expected_message = "log_every_n_steps must be at least 1, got 0"
        with self.assertRaisesRegex(ValueError, expected_message):
            TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=0)
2 changes: 2 additions & 0 deletions torchtnt/framework/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .system_resources_monitor import SystemResourcesMonitor
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
from .time_limit_interrupter import TimeLimitInterrupter
from .time_wait_for_batch_logger import TimeWaitForBatchLogger
from .torch_compile import TorchCompile
from .torchsnapshot_saver import TorchSnapshotSaver
from .tqdm_progress_bar import TQDMProgressBar
Expand All @@ -43,6 +44,7 @@
"SystemResourcesMonitor",
"TensorBoardParameterMonitor",
"TimeLimitInterrupter",
"TimeWaitForBatchLogger",
"TorchCompile",
"TorchSnapshotSaver",
"TQDMProgressBar",
Expand Down
90 changes: 90 additions & 0 deletions torchtnt/framework/callbacks/time_wait_for_batch_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import cast, Union

from pyre_extensions import none_throws
from torch.utils.tensorboard import SummaryWriter

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.distributed import rank_zero_fn
from torchtnt.utils.loggers.logger import MetricLogger
from torchtnt.utils.timer import TimerProtocol


class TimeWaitForBatchLogger(Callback):
    """
    A callback which logs time wait for batch as scalars to a MetricLogger.

    At the end of each train, eval and predict step, the last entry of the
    ``"data_wait_time"`` durations recorded on that phase's iteration timer is
    written to the logger, tagged with the number of steps completed.

    Args:
        logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
            or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
        log_every_n_steps: an optional int to control the log frequency

    Raises:
        ValueError: If ``log_every_n_steps`` is less than 1.
    """

    def __init__(
        self,
        logger: Union[MetricLogger, SummaryWriter],
        log_every_n_steps: int = 1,
    ) -> None:
        self._logger = logger
        if log_every_n_steps < 1:
            message = f"log_every_n_steps must be at least 1, got {log_every_n_steps}"
            raise ValueError(message)
        self._log_every_n_steps = log_every_n_steps

    # NOTE(review): rank_zero_fn presumably limits execution to rank 0 in
    # distributed runs -- confirm against torchtnt.utils.distributed.
    @rank_zero_fn
    def _log_step_metrics(
        self,
        *,
        timer: TimerProtocol,
        label: str,
        step: int,
    ) -> None:
        """
        Write the last recorded data wait time to the logger.

        Does nothing when ``step`` is not a multiple of ``log_every_n_steps``
        or when the timer has no ``"data_wait_time"`` durations.

        Args:
            timer: timer whose ``recorded_durations`` are read.
            label: name under which the scalar is logged.
            step: step number attached to the logged value.
        """
        off_cadence = step % self._log_every_n_steps != 0
        if off_cadence:
            return

        wait_times = timer.recorded_durations.get("data_wait_time")
        if not wait_times:
            # Key missing or list empty: nothing to report.
            return

        # SummaryWriter and MetricLogger take the same (name, value, step)
        # positional arguments here but under different method names.
        sink = self._logger
        if isinstance(sink, SummaryWriter):
            write_scalar = sink.add_scalar
        else:
            write_scalar = cast(MetricLogger, sink).log
        write_scalar(label, wait_times[-1], step)

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        """Log the train-phase data wait time for the step that just ended."""
        train_state = none_throws(state.train_state)
        self._log_step_metrics(
            timer=train_state.iteration_timer,
            label="Time Wait For Batch (Train)",
            step=unit.train_progress.num_steps_completed,
        )

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        """Log the eval-phase data wait time for the step that just ended."""
        eval_state = none_throws(state.eval_state)
        self._log_step_metrics(
            timer=eval_state.iteration_timer,
            label="Time Wait For Batch (Eval)",
            step=unit.eval_progress.num_steps_completed,
        )

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        """Log the predict-phase data wait time for the step that just ended."""
        predict_state = none_throws(state.predict_state)
        self._log_step_metrics(
            timer=predict_state.iteration_timer,
            label="Time Wait For Batch (Predict)",
            step=unit.predict_progress.num_steps_completed,
        )

0 comments on commit 2409e14

Please sign in to comment.