Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in process_group_name when there is duplicate pgs #100518

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/distributed/test_c10d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,43 @@ def _test_new_group_local_sync_sanity_check(self, backend):
]
self.assertEqual(output_tensor_list, expected)

def _test_new_group_local_sync_duplicate_pg(self, backend):
"""
We should support users create multiple PGs with the same set of
members, and no conflict in group name
"""
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)
rank = dist.get_rank()

# split the world in 2 PGs
rank = dist.get_rank()
pg_idx = rank // 2
ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
new_pgs = []
for _ in range(2):
new_pgs.append(
dist.new_group(ranks=ranks_in, use_local_synchronization=True)
)

input_tensor = torch.tensor([pg_idx, rank], device=self.device)
for new_pg in new_pgs:
output_tensor_list = [
torch.tensor([-1, -1], device=self.device,) for _ in range(new_pg.size())
]
dist.all_gather(output_tensor_list, input_tensor, group=new_pg)

expected = [
torch.tensor([pg_idx, ranks_in[0]], device=self.device),
torch.tensor([pg_idx, ranks_in[1]], device=self.device)
]
self.assertEqual(output_tensor_list, expected)


class CommTest(AbstractCommTest, MultiProcessTestCase):
def setUp(self):
Expand Down
4 changes: 4 additions & 0 deletions test/distributed/test_c10d_gloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,6 +2511,10 @@ def test_new_group_local_sync(self):
def test_new_group_local_sync_sanity_check(self):
self._test_new_group_local_sync_sanity_check(backend="gloo")

@requires_gloo()
def test_new_group_local_sync_duplicate_pg(self):
self._test_new_group_local_sync_duplicate_pg(backend="gloo")

if __name__ == "__main__":
assert (
not torch.cuda._initialized
Expand Down
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3043,6 +3043,11 @@ def test_new_group_local_sync(self):
def test_new_group_local_sync_sanity_check(self):
self._test_new_group_local_sync_sanity_check(backend="nccl")

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_new_group_local_sync_duplicated_pg(self):
self._test_new_group_local_sync_duplicate_pg(backend="nccl")



if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,7 +3635,7 @@ def _process_group_name(ranks, use_hashed_name):
global _world
if use_hashed_name:
pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
while pg_name in _world.pg_names:
while pg_name in _world.pg_names.values():
pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
else:
pg_name = str(_world.group_count)
Expand Down