diff --git a/tests/framework/callbacks/test_checkpoint_utils.py b/tests/framework/callbacks/test_checkpoint_utils.py index 25efe7831a..ac1a019a98 100644 --- a/tests/framework/callbacks/test_checkpoint_utils.py +++ b/tests/framework/callbacks/test_checkpoint_utils.py @@ -421,23 +421,3 @@ def test_get_app_state(self) -> None: app_state.keys(), ["module", "optimizer", "loss_fn", "train_progress"], ) - - @skip_if_not_distributed - def test_rank_zero_read_and_broadcast(self) -> None: - spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast) - - @staticmethod - def _test_rank_zero_read_and_broadcast() -> None: - """ - Tests that rank_zero_read_and_broadcast decorator works as expected - """ - - @rank_zero_read_and_broadcast - def _test_method_for_rank_zero() -> str: - assert get_global_rank() == 0 - return "foo" - - init_from_env() - val_from_test_method = _test_method_for_rank_zero() - tc = unittest.TestCase() - tc.assertEqual(val_from_test_method, "foo") diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py index cb4cbba5ed..437f016144 100644 --- a/tests/utils/test_distributed.py +++ b/tests/utils/test_distributed.py @@ -30,6 +30,7 @@ get_world_size, PGWrapper, rank_zero_fn, + rank_zero_read_and_broadcast, revert_sync_batchnorm, spawn_multi_process, sync_bool, @@ -443,3 +444,22 @@ def _test_method(offset_arg: int, offset_kwarg: int) -> int: def test_spawn_multi_process(self) -> None: mp_list = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2) self.assertEqual(mp_list, [1, 2]) + + @skip_if_not_distributed + def test_rank_zero_read_and_broadcast(self) -> None: + spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast) + + @staticmethod + def _test_rank_zero_read_and_broadcast() -> None: + """ + Tests that rank_zero_read_and_broadcast decorator works as expected + """ + + @rank_zero_read_and_broadcast + def _test_method_for_rank_zero() -> str: + assert get_global_rank() == 0 + return "foo" + + val_from_test_method = _test_method_for_rank_zero() + tc = unittest.TestCase() + tc.assertEqual(val_from_test_method, "foo") diff --git a/torchtnt/framework/callbacks/_checkpoint_utils.py b/torchtnt/framework/callbacks/_checkpoint_utils.py index e67c956510..087c15c15b 100644 --- a/torchtnt/framework/callbacks/_checkpoint_utils.py +++ b/torchtnt/framework/callbacks/_checkpoint_utils.py @@ -10,18 +10,7 @@ import os import re -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Literal, - Optional, - Pattern, - Tuple, - TypeVar, -) +from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar import fsspec @@ -30,7 +19,7 @@ from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.state import State from torchtnt.framework.unit import AppStateMixin -from torchtnt.utils.distributed import get_global_rank, PGWrapper +from torchtnt.utils.distributed import rank_zero_read_and_broadcast from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.stateful import Stateful @@ -40,44 +29,6 @@ T = TypeVar("T") -def rank_zero_read_and_broadcast( - func: Callable[..., T], -) -> Callable[..., T]: - """ - Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks. - - Note: - By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument. - """ - - def wrapper(*args: Any, **kwargs: Any) -> T: - ret = None - rank = get_global_rank() - process_group = kwargs.pop("process_group", None) - - # Do all filesystem reads from rank 0 only - if rank == 0: - ret = func(*args, **kwargs) - - # If not running in a distributed setting, return as is - if not (dist.is_available() and dist.is_initialized()): - # we cast here to avoid type errors, since it is - # guaranteed the return value is of type T - return cast(T, ret) - - # Otherwise, broadcast result from rank 0 to all ranks - pg = PGWrapper(process_group) - path_container = [ret] - pg.broadcast_object_list(path_container, 0) - val = path_container[0] - - # we cast here to avoid type errors, since it is - # guaranteed the return value is of type T - return cast(T, val) - - return wrapper - - @rank_zero_read_and_broadcast def get_latest_checkpoint_path( dirpath: str, diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 9993fbd532..f5f56fdf25 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -24,7 +24,6 @@ get_best_checkpoint_path, get_checkpoint_dirpaths, get_latest_checkpoint_path, - rank_zero_read_and_broadcast, ) from torchtnt.framework.callbacks.checkpointer_types import ( BestCheckpointConfig, @@ -33,7 +32,7 @@ from torchtnt.framework.state import EntryPoint, State from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit from torchtnt.framework.utils import get_timing_context -from torchtnt.utils.distributed import PGWrapper +from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py index b65dfbfb96..c9496722b4 100644 --- a/torchtnt/utils/distributed.py +++ b/torchtnt/utils/distributed.py @@ -590,3 +590,41 @@ def _init_pg_and_rank_and_launch_method( finally: destroy_process_group() + + +def rank_zero_read_and_broadcast( + func: Callable[..., T], +) -> Callable[..., T]: + """ + Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks. + + Note: + By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument. + """ + + def wrapper(*args: Any, **kwargs: Any) -> T: + ret = None + rank = get_global_rank() + process_group = kwargs.pop("process_group", None) + + # Do all filesystem reads from rank 0 only + if rank == 0: + ret = func(*args, **kwargs) + + # If not running in a distributed setting, return as is + if not (dist.is_available() and dist.is_initialized()): + # we cast here to avoid type errors, since it is + # guaranteed the return value is of type T + return cast(T, ret) + + # Otherwise, broadcast result from rank 0 to all ranks + pg = PGWrapper(process_group) + path_container = [ret] + pg.broadcast_object_list(path_container, 0) + val = path_container[0] + + # we cast here to avoid type errors, since it is + # guaranteed the return value is of type T + return cast(T, val) + + return wrapper