Skip to content

Commit

Permalink
[fairscale] Avoid creating unnecessary pgs
Browse files Browse the repository at this point in the history
Summary: with the new c10d API, we don't need all ranks to call new_group. Integrate with the new API, so that every rank calls new_group only 3 times, with a local barrier among the members of each group.

Test Plan: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx_hpc-xlformers_chinch70B_4096_xdwang_0426190441?env=PRODUCTION&job_name=torchx_hpc-xlformers_chinch70B_4096_xdwang_0426190441

Reviewed By: xunnanxu, eeggl

Differential Revision: D45315615

fbshipit-source-id: 73f57188da69ecd6466dba8ab1739b9e862d614b
  • Loading branch information
xw285cornell authored and facebook-github-bot committed May 3, 2023
1 parent 4b9ba3f commit 01eccfd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
37 changes: 37 additions & 0 deletions test/distributed/test_c10d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,43 @@ def _test_new_group_local_sync_sanity_check(self, backend):
]
self.assertEqual(output_tensor_list, expected)

def _test_new_group_local_sync_duplidate_pg(self, backend):
    """
    Users should be able to create multiple PGs with the same set of
    members without a conflict in group names.

    NOTE(review): "duplidate" in the method name is a typo for
    "duplicate"; kept as-is because subclass tests invoke this helper
    by name, so renaming would break those callers.
    """
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )
    rank = dist.get_rank()

    # Split the world into PGs of two consecutive ranks each.
    pg_idx = rank // 2
    ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
    # Create two PGs over the identical member set; the hashed-name
    # dedup logic must make both creations succeed without a collision.
    new_pgs = [
        dist.new_group(ranks=ranks_in, use_local_synchronization=True)
        for _ in range(2)
    ]

    input_tensor = torch.tensor([pg_idx, rank], device=self.device)
    for new_pg in new_pgs:
        output_tensor_list = [
            torch.tensor([-1, -1], device=self.device)
            for _ in range(new_pg.size())
        ]
        dist.all_gather(output_tensor_list, input_tensor, group=new_pg)

        expected = [
            torch.tensor([pg_idx, ranks_in[0]], device=self.device),
            torch.tensor([pg_idx, ranks_in[1]], device=self.device),
        ]
        self.assertEqual(output_tensor_list, expected)


class CommTest(AbstractCommTest, MultiProcessTestCase):
def setUp(self):
Expand Down
3 changes: 2 additions & 1 deletion torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,7 +3635,7 @@ def _process_group_name(ranks, use_hashed_name):
global _world
if use_hashed_name:
pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
while pg_name in _world.pg_names:
while pg_name in _world.pg_names.values():
pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
else:
pg_name = str(_world.group_count)
Expand Down Expand Up @@ -3785,6 +3785,7 @@ def _new_group_with_tag(
group_rank = global_rank

group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)
print(f"XDW: rank-{group_rank}, group_name: {group_name}")

with record_function(f"## process_group:init with ranks: {ranks}"):
pg, pg_store = _new_process_group_helper(
Expand Down

0 comments on commit 01eccfd

Please sign in to comment.