Skip to content

Commit

Permalink
Fix test_checkpoint_utils multiprocess test (#791)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #791

Reviewed By: JKSenthil

Differential Revision: D56260189

fbshipit-source-id: 32d5035a7c654308695e8e6f5126e4e0da75f610
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Apr 19, 2024
1 parent 5534617 commit 35f9d92
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/framework/callbacks/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed import launcher
from torchsnapshot import Snapshot
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
Expand All @@ -34,7 +33,7 @@
from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.fsspec import get_filesystem
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed
from torchtnt.utils.test_utils import skip_if_not_distributed

METADATA_FNAME: str = ".metadata"

Expand Down Expand Up @@ -88,10 +87,11 @@ def test_latest_checkpoint_path(self) -> None:

@skip_if_not_distributed
def test_latest_checkpoint_path_distributed(self) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(
config, entrypoint=self._latest_checkpoint_path_distributed
)()
spawn_multi_process(
2,
"gloo",
self._latest_checkpoint_path_distributed,
)

@staticmethod
def _latest_checkpoint_path_distributed() -> None:
Expand Down Expand Up @@ -130,6 +130,7 @@ def _latest_checkpoint_path_distributed() -> None:
path_container = [path_2] if is_rank0 else [None]
pg.broadcast_object_list(path_container, 0)
expected_path = path_container[0]
tc.assertIsNotNone(expected_path)
tc.assertEqual(
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
)
Expand Down

0 comments on commit 35f9d92

Please sign in to comment.