FSDP: support creating hybrid-sharded groups for a custom backend
medivh-xp committed May 15, 2023
1 parent 253b9d3 commit 4d42667
Showing 1 changed file with 11 additions and 10 deletions.
torch/distributed/fsdp/_init_utils.py (11 additions, 10 deletions)

@@ -136,7 +136,7 @@ def _init_process_group_state_for_hybrid_shard(
     if process_group is None:
         default_group = _get_default_group()
         intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
-            default_group
+            default_group, state._device_handle.device_count()
         )
         # we shard across intra-node
         state.process_group = intra_node_group
@@ -170,7 +170,7 @@ def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
 
 
 @no_type_check
-def _init_intra_node_process_group() -> dist.ProcessGroup:
+def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
     """
     Returns a process group across the current node.
     For example, given each row is a distinct node:
@@ -180,13 +180,14 @@ def _init_intra_node_process_group() -> dist.ProcessGroup:
     [0, 7] or [8, 15] depending on the process's rank.
     For example, rank 3 would get [0, 7].
     """
-    intra_node_subgroup, _ = dist.new_subgroups()
+    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
     return intra_node_subgroup
 
 
 @no_type_check
 def _init_inter_node_process_group(
     global_process_group: dist.ProcessGroup,
+    num_devices_per_node: int,
 ) -> dist.ProcessGroup:
     """
     Returns an inter-node process group where each contained rank has
@@ -202,12 +203,11 @@ def _init_inter_node_process_group(
     sharding_backend = dist.get_backend(global_process_group)
     world_size = dist.get_world_size(global_process_group)
     # Assuming fully homogeneous setup
-    num_devices = torch.cuda.device_count()
-    num_nodes = world_size // num_devices
-    my_local_rank = dist.get_rank(global_process_group) % num_devices
-    for local_rank in range(num_devices):
+    num_nodes = world_size // num_devices_per_node
+    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
+    for local_rank in range(num_devices_per_node):
         ranks_for_inter_group = [
-            local_rank + (i * num_devices) for i in range(num_nodes)
+            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
         ]
         # every rank always needs to call dist.new_group
         grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
@@ -223,6 +223,7 @@ def _init_inter_node_process_group(
 
 def _init_intra_and_inter_node_groups(
     global_process_group: dist.ProcessGroup,
+    num_devices_per_node: int,
 ) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
     """
     Initializes intra and inter-node process groups and returns the ones corresponding
@@ -234,8 +235,8 @@ def _init_intra_and_inter_node_groups(
         Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
     """
     return (
-        _init_intra_node_process_group(),
-        _init_inter_node_process_group(global_process_group),
+        _init_intra_node_process_group(num_devices_per_node),
+        _init_inter_node_process_group(global_process_group, num_devices_per_node),
     )


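For context, here is a minimal plain-Python sketch (not part of the commit) of the group layout these helpers compute once num_devices_per_node is passed in explicitly, using the illustrative two-node, eight-devices-per-node values from the docstring example. In the changed code, num_devices_per_node is supplied by the caller from the state's device handle (state._device_handle.device_count()) instead of the hard-coded torch.cuda.device_count(), which is what lets custom, non-CUDA backends build the hybrid-sharded groups.

# Illustrative sketch only; values are examples, not taken from a real run.
world_size = 16
num_devices_per_node = 8  # in FSDP: state._device_handle.device_count()
num_nodes = world_size // num_devices_per_node

# One intra-node group per node (the dimension FSDP shards across).
intra_node_groups = [
    list(range(node * num_devices_per_node, (node + 1) * num_devices_per_node))
    for node in range(num_nodes)
]
# [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]

# One inter-node group per local rank (same rank-list math as
# _init_inter_node_process_group's ranks_for_inter_group).
inter_node_groups = [
    [local_rank + i * num_devices_per_node for i in range(num_nodes)]
    for local_rank in range(num_devices_per_node)
]
# [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]

# Each global rank belongs to exactly one group of each kind, e.g. rank 3:
rank = 3
assert intra_node_groups[rank // num_devices_per_node] == list(range(0, 8))
assert inter_node_groups[rank % num_devices_per_node] == [3, 11]

In the actual helpers, every rank still calls dist.new_subgroups / dist.new_group collectively for each candidate group (as the "# every rank always needs to call dist.new_group" comment notes) and keeps only the intra-node group it belongs to and the inter-node group matching its local rank.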