From 89efc197f9c1bcbcdbc6507f1ed9922aef036ec6 Mon Sep 17 00:00:00 2001
From: Angel Li
Date: Mon, 3 Nov 2025 17:28:53 -0800
Subject: [PATCH] fix set_determinism on single gpu

set_determinism builds distinct_dims_in_mesh with the membership test
`dim in world_mesh.mesh_dim_names`. When mesh_dim_names is None that
test raises TypeError, because None does not support `in`.

Guard the comprehension so that a mesh without dim names produces an
empty distinct_dims_in_mesh list instead of raising.

Add a unit test that calls set_determinism with a mocked world size of
1 and a mocked mesh whose mesh_dim_names is None. The test makes no
assertions on the resulting seed; it only checks that the call
completes without raising.

---
 tests/unit_tests/test_set_determinism.py | 21 +++++++++++++++++++++
 torchtitan/distributed/utils.py          |  4 +++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py
index f6798dd01e..c8087731c5 100644
--- a/tests/unit_tests/test_set_determinism.py
+++ b/tests/unit_tests/test_set_determinism.py
@@ -208,6 +208,27 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size):
             f"Expected {mesh_sizes[0] * mesh_sizes[1]} unique seeds for (dp_shard, dp_replicate) combinations",
         )
 
+    @patch("torch.distributed.distributed_c10d.get_world_size")
+    @patch("torch.distributed.distributed_c10d.get_rank")
+    def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size):
+        """Test set_determinism for single GPU (empty mesh)"""
+        mock_get_world_size.return_value = 1
+        mock_get_rank.return_value = 0
+
+        base_seed = 42
+
+        fake_mesh = MagicMock()
+        fake_mesh.mesh_dim_names = None
+        fake_mesh.get_coordinate.return_value = None
+
+        debug_config = DebugConfig(seed=base_seed, deterministic=False)
+        set_determinism(
+            world_mesh=fake_mesh,
+            device=self.device,
+            debug_config=debug_config,
+            distinct_seed_mesh_dims=["pp"],
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 90ad886e22..f424276a3c 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -145,7 +145,9 @@ def set_determinism(
     # and choose a unique seed for each rank on the PP mesh.
     # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed.
     distinct_dims_in_mesh = [
-        dim for dim in distinct_seed_mesh_dims if dim in world_mesh.mesh_dim_names
+        dim
+        for dim in distinct_seed_mesh_dims
+        if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names
     ]
 
     if c10d.get_world_size() > 1 and distinct_dims_in_mesh: