Fix bug in process_group_name when there is duplicate pgs (#100518)
Summary:
Pull Request resolved: #100518

With the new c10d API, we no longer need all ranks to call new_group. Integrate with the new API so that every rank only calls new_group 3 times, with a local barrier among the members of each group.
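
As a rough sketch of that pattern (illustration only: the gloo backend, a world size of 4, and a single pair-wise subgroup are assumptions for this example, not details from the commit), each rank calls new_group only for the subgroup it belongs to and passes use_local_synchronization=True so that only that group's members synchronize:

# Hypothetical example; launch with e.g. `torchrun --nproc_per_node=4 demo.py`.
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")  # env:// rendezvous set up by torchrun
    rank = dist.get_rank()

    # Pair up ranks: (0, 1) form one subgroup, (2, 3) another.
    pg_idx = rank // 2
    members = [pg_idx * 2, pg_idx * 2 + 1]

    # With use_local_synchronization=True, only the ranks listed in `members`
    # call new_group and block on the group's barrier; other ranks never
    # have to participate in creating this subgroup.
    pg = dist.new_group(ranks=members, use_local_synchronization=True)
    print(f"rank {rank} joined subgroup {members} of size {pg.size()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()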

Test Plan: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx_hpc-xlformers_chinch70B_4096_xdwang_0426190441?env=PRODUCTION&job_name=torchx_hpc-xlformers_chinch70B_4096_xdwang_0426190441

Reviewed By: xunnanxu, eeggl

Differential Revision: D45315615

fbshipit-source-id: aaefe67f52bddd6e184496c1231dadb0dcca089b
xw285cornell authored and facebook-github-bot committed May 3, 2023
1 parent 1a6f613 commit 44aa965
Showing 4 changed files with 47 additions and 1 deletion.
37 changes: 37 additions & 0 deletions test/distributed/test_c10d_common.py
@@ -1417,6 +1417,43 @@ def _test_new_group_local_sync_sanity_check(self, backend):
        ]
        self.assertEqual(output_tensor_list, expected)

    def _test_new_group_local_sync_duplicate_pg(self, backend):
        """
        Users should be able to create multiple PGs with the same set of
        members without any conflict in group names.
        """
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        rank = dist.get_rank()

        # split the world into 2 PGs of 2 ranks each
        pg_idx = rank // 2
        ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
        new_pgs = []
        for _ in range(2):
            new_pgs.append(
                dist.new_group(ranks=ranks_in, use_local_synchronization=True)
            )

        input_tensor = torch.tensor([pg_idx, rank], device=self.device)
        for new_pg in new_pgs:
            output_tensor_list = [
                torch.tensor([-1, -1], device=self.device) for _ in range(new_pg.size())
            ]
            dist.all_gather(output_tensor_list, input_tensor, group=new_pg)

            expected = [
                torch.tensor([pg_idx, ranks_in[0]], device=self.device),
                torch.tensor([pg_idx, ranks_in[1]], device=self.device),
            ]
            self.assertEqual(output_tensor_list, expected)


class CommTest(AbstractCommTest, MultiProcessTestCase):
    def setUp(self):
4 changes: 4 additions & 0 deletions test/distributed/test_c10d_gloo.py
@@ -2511,6 +2511,10 @@ def test_new_group_local_sync(self):
    def test_new_group_local_sync_sanity_check(self):
        self._test_new_group_local_sync_sanity_check(backend="gloo")

    @requires_gloo()
    def test_new_group_local_sync_duplicate_pg(self):
        self._test_new_group_local_sync_duplicate_pg(backend="gloo")

if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -3043,6 +3043,11 @@ def test_new_group_local_sync(self):
    def test_new_group_local_sync_sanity_check(self):
        self._test_new_group_local_sync_sanity_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_new_group_local_sync_duplicated_pg(self):
        self._test_new_group_local_sync_duplicate_pg(backend="nccl")



if __name__ == "__main__":
2 changes: 1 addition & 1 deletion torch/distributed/distributed_c10d.py
@@ -3635,7 +3635,7 @@ def _process_group_name(ranks, use_hashed_name):
    global _world
    if use_hashed_name:
        pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
-       while pg_name in _world.pg_names:
+       while pg_name in _world.pg_names.values():
            pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
    else:
        pg_name = str(_world.group_count)
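
A toy illustration of why the one-line change above matters (a sketch only: the dict shape mirrors _world.pg_names, which maps each ProcessGroup object to its name, but the stand-in objects here are invented). The old check `pg_name in _world.pg_names` tests the dict's keys, which are ProcessGroup objects, so a hashed name generated for a second group over the same ranks never registers as a collision; checking .values() compares against the stored names and triggers the re-hash loop.

# Sketch of the membership check, not real c10d state.
import hashlib

def hashed_name(ranks):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()

existing_pg = object()                       # stands in for a ProcessGroup object
pg_names = {existing_pg: hashed_name([0, 1])}

candidate = hashed_name([0, 1])              # second PG over the same ranks

print(candidate in pg_names)                 # False: only compares against the keys,
                                             # so the duplicate name went unnoticed
print(candidate in pg_names.values())        # True: collision detected, so the name
                                             # is re-hashed with a trailing "_" until unique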
