Skip to content

Commit

Permalink
add progress reporter callback (#785)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #785

Reviewed By: JKSenthil

Differential Revision: D56175728

fbshipit-source-id: be61bf67dd0b0ac18d3633574ac7f91259e08432
  • Loading branch information
galrotem authored and facebook-github-bot committed Apr 16, 2024
1 parent 5beb537 commit 6de95a5
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/framework/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
LearningRateMonitor
MemorySnapshot
ModuleSummary
ProgressReporter
PyTorchProfiler
SlowRankDetector
SystemResourcesMonitor
Expand Down
49 changes: 49 additions & 0 deletions tests/framework/callbacks/test_progress_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from torchtnt.framework._test_utils import DummyAutoUnit
from torchtnt.framework.callbacks.progress_reporter import ProgressReporter
from torchtnt.framework.state import EntryPoint, State
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.progress import Progress


class ProgressReporterTest(unittest.TestCase):
    """Verifies the log line that ``ProgressReporter`` emits for a FIT entry point."""

    def test_log_with_rank(self) -> None:
        """Run the logging check through ``spawn_multi_process(2, "gloo", ...)``."""
        spawn_multi_process(2, "gloo", self._test_log_with_rank)

    @staticmethod
    def _test_log_with_rank() -> None:
        """
        Call ``on_train_end`` on a unit with known progress counters and compare
        the captured log output against the exact expected message.
        """
        train_progress = Progress(
            num_epochs_completed=1,
            num_steps_completed=5,
            num_steps_completed_in_epoch=3,
        )
        eval_progress = Progress(
            num_epochs_completed=2,
            num_steps_completed=15,
            num_steps_completed_in_epoch=7,
        )

        auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
        auto_unit.train_progress = train_progress
        auto_unit.eval_progress = eval_progress

        fit_state = State(entry_point=EntryPoint.FIT)
        reporter = ProgressReporter()

        # A standalone TestCase supplies the assertion helpers, since this
        # static method has no ``self`` (presumably because it is executed
        # inside the spawned processes -- confirm against spawn_multi_process).
        checker = unittest.TestCase()
        with checker.assertLogs(level="INFO") as captured:
            reporter.on_train_end(fit_state, auto_unit)

        expected_message = (
            f"INFO:torchtnt.framework.callbacks.progress_reporter:Progress Reporter: rank {get_global_rank()} at on_train_end. "
            "Train progress: completed epochs: 1, completed steps: 5, completed steps in current epoch: 3. "
            "Eval progress: completed epochs: 2, completed steps: 15, completed steps in current epoch: 7."
        )
        checker.assertEqual(captured.output, [expected_message])
2 changes: 2 additions & 0 deletions torchtnt/framework/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .learning_rate_monitor import LearningRateMonitor
from .memory_snapshot import MemorySnapshot
from .module_summary import ModuleSummary
from .progress_reporter import ProgressReporter
from .pytorch_profiler import PyTorchProfiler
from .slow_rank_detector import SlowRankDetector
from .system_resources_monitor import SystemResourcesMonitor
Expand All @@ -36,6 +37,7 @@
"LearningRateMonitor",
"MemorySnapshot",
"ModuleSummary",
"ProgressReporter",
"PyTorchProfiler",
"SlowRankDetector",
"SystemResourcesMonitor",
Expand Down
102 changes: 102 additions & 0 deletions torchtnt/framework/callbacks/progress_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
from typing import cast

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.distributed import get_global_rank

logger: logging.Logger = logging.getLogger(__name__)


class ProgressReporter(Callback):
    """
    Callback that writes a progress message to the module logger at every loop,
    epoch and step boundary (both start and end) of train, eval and predict.

    Each message contains the global rank, the name of the hook being run, and
    the progress counters that are relevant to the current entry point. This
    helps when debugging problems whose root cause may be unequal progress
    across ranks, for instance NCCL timeouts.

    If used, it's recommended to pass this callback as the first item in the
    callbacks list.
    """

    # ----------------------------- train hooks -----------------------------

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_start")

    def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_epoch_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_epoch_start")

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_step_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_step_start")

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_step_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_step_end")

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_epoch_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_epoch_end")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        """Log progress for the ``on_train_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_train_end")

    # ------------------------------ eval hooks -----------------------------

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_start")

    def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_epoch_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_start")

    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_step_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_step_start")

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_step_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_step_end")

    def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_epoch_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_end")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        """Log progress for the ``on_eval_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_eval_end")

    # ---------------------------- predict hooks ----------------------------

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_start")

    def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_epoch_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_start")

    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_step_start`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_step_start")

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_step_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_step_end")

    def on_predict_epoch_end(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_epoch_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_end")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        """Log progress for the ``on_predict_end`` hook."""
        self._log_with_rank_and_unit(state, unit, "on_predict_end")

    # ------------------------------- helper --------------------------------

    @classmethod
    def _log_with_rank_and_unit(
        cls, state: State, unit: AppStateMixin, hook: str
    ) -> None:
        """
        Build and log (at INFO level) one progress message.

        Args:
            state: supplies ``entry_point``, which selects the progress
                counters included in the message.
            unit: the unit whose ``train_progress`` / ``eval_progress`` /
                ``predict_progress`` is read, depending on the entry point.
            hook: name of the callback hook, embedded verbatim in the message.

        Raises:
            ValueError: if ``state.entry_point`` is not one of TRAIN,
                EVALUATE, PREDICT or FIT.
        """
        # Message pieces are collected here and joined with one space each.
        segments = [f"Progress Reporter: rank {get_global_rank()} at {hook}."]
        entry_point = state.entry_point

        if entry_point == EntryPoint.TRAIN:
            train_unit = cast(TTrainUnit, unit)
            segments.append(
                f"Train progress: {train_unit.train_progress.get_progress_string()}"
            )
        elif entry_point == EntryPoint.EVALUATE:
            eval_unit = cast(TEvalUnit, unit)
            segments.append(
                f"Eval progress: {eval_unit.eval_progress.get_progress_string()}"
            )
        elif entry_point == EntryPoint.PREDICT:
            predict_unit = cast(TPredictUnit, unit)
            segments.append(
                f"Predict progress: {predict_unit.predict_progress.get_progress_string()}"
            )
        elif entry_point == EntryPoint.FIT:
            # FIT reports the train counters first, then the eval counters.
            train_unit = cast(TTrainUnit, unit)
            segments.append(
                f"Train progress: {train_unit.train_progress.get_progress_string()}"
            )
            eval_unit = cast(TEvalUnit, unit)
            segments.append(
                f"Eval progress: {eval_unit.eval_progress.get_progress_string()}"
            )
        else:
            raise ValueError(
                f"State entry point {entry_point} is not supported in ProgressReporter"
            )

        logger.info(" ".join(segments))

0 comments on commit 6de95a5

Please sign in to comment.