FSDP: support creating hybrid-sharded process groups for custom backends #100622

Closed
wants to merge 7 commits
2 changes: 2 additions & 0 deletions torch/distributed/_composable/fully_shard.py
@@ -10,6 +10,7 @@
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handles_from_module,
_init_prefetching_state,
@@ -56,6 +57,7 @@ def fully_shard(
raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules)
state = _init_device_handle(state, module, state._ignored_params, device_id)
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
)
66 changes: 54 additions & 12 deletions torch/distributed/fsdp/_init_utils.py
@@ -136,7 +136,7 @@ def _init_process_group_state_for_hybrid_shard(
if process_group is None:
default_group = _get_default_group()
intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
default_group
default_group, state._device_handle.device_count()
)
# we shard across intra-node
state.process_group = intra_node_group
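
The hunk above is the core of the change: hybrid-shard group construction now asks the FSDP device handle for the per-node device count instead of calling torch.cuda.device_count() directly, which is what lets a custom (non-CUDA) backend participate. Below is a minimal sketch of how such a handle could forward device queries to the matching torch-level backend module (an assumed behavior for illustration, not the exact _FSDPDeviceHandle implementation):

import torch

class DeviceHandleSketch:
    """Forwards device queries to the torch-level module for a device type."""

    def __init__(self, device: torch.device):
        # torch.cuda for "cuda"; a custom backend would need to expose its own
        # torch-level module (e.g. torch.xpu) for this lookup to succeed.
        self._backend = getattr(torch, device.type)

    def device_count(self) -> int:
        return self._backend.device_count()

handle = DeviceHandleSketch(torch.device("cuda"))
print(handle.device_count())  # number of visible devices on this node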
@@ -170,7 +170,7 @@ def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:


@no_type_check
def _init_intra_node_process_group() -> dist.ProcessGroup:
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
"""
Returns a process group across the current node.
For example, given each row is a distinct node:
@@ -180,13 +180,14 @@ def _init_intra_node_process_group() -> dist.ProcessGroup:
[0, 7] or [8, 15] depending on the process's rank.
For example, rank 3 would get [0, 7].
"""
intra_node_subgroup, _ = dist.new_subgroups()
intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
return intra_node_subgroup
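
As a sanity check on the docstring's example, here is a minimal standalone sketch (no process group needed) of which ranks land in a given rank's intra-node group, assuming 16 ranks and 8 devices per node:

def intra_node_ranks(rank: int, num_devices_per_node: int = 8) -> list:
    # Ranks on the same node form one consecutive block of size num_devices_per_node.
    node_index = rank // num_devices_per_node
    start = node_index * num_devices_per_node
    return list(range(start, start + num_devices_per_node))

assert intra_node_ranks(3) == list(range(0, 8))    # rank 3 gets [0, ..., 7]
assert intra_node_ranks(10) == list(range(8, 16))  # rank 10 gets [8, ..., 15]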


@no_type_check
def _init_inter_node_process_group(
global_process_group: dist.ProcessGroup,
num_devices_per_node: int,
) -> dist.ProcessGroup:
"""
Returns an inter-node process group where each contained rank has
@@ -202,12 +203,11 @@ def _init_inter_node_process_group(
sharding_backend = dist.get_backend(global_process_group)
world_size = dist.get_world_size(global_process_group)
# Assuming fully homogeneous setup
num_devices = torch.cuda.device_count()
num_nodes = world_size // num_devices
my_local_rank = dist.get_rank(global_process_group) % num_devices
for local_rank in range(num_devices):
num_nodes = world_size // num_devices_per_node
my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
for local_rank in range(num_devices_per_node):
ranks_for_inter_group = [
local_rank + (i * num_devices) for i in range(num_nodes)
local_rank + (i * num_devices_per_node) for i in range(num_nodes)
]
# every rank always needs to call dist.new_group
grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
@@ -223,6 +223,7 @@ def _init_inter_node_process_group(

def _init_intra_and_inter_node_groups(
global_process_group: dist.ProcessGroup,
num_devices_per_node: int,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
"""
Initializes intra and inter-node process groups and returns the ones corresponding
@@ -234,8 +235,8 @@ def _init_intra_and_inter_node_groups(
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
"""
return (
_init_intra_node_process_group(),
_init_inter_node_process_group(global_process_group),
_init_intra_node_process_group(num_devices_per_node),
_init_inter_node_process_group(global_process_group, num_devices_per_node),
)
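
For the inter-node side, the loop in _init_inter_node_process_group builds one group per local rank, each containing that local rank's counterpart on every node. A minimal sketch of the resulting rank layout, again assuming a homogeneous setup with 2 nodes and 8 devices per node:

def inter_node_ranks(local_rank: int, world_size: int, num_devices_per_node: int) -> list:
    # Each inter-node group holds the same local rank from every node.
    num_nodes = world_size // num_devices_per_node
    return [local_rank + i * num_devices_per_node for i in range(num_nodes)]

assert inter_node_ranks(0, world_size=16, num_devices_per_node=8) == [0, 8]
assert inter_node_ranks(3, world_size=16, num_devices_per_node=8) == [3, 11]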


@@ -264,6 +265,49 @@ def _init_ignored_module_states(
return state


@no_type_check
def _init_device_handle(
state: _FSDPState,
module: nn.Module,
ignored_params: Set[nn.Parameter],
device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
"""
Determines the device handle used for initializing FSDP. If ``device_id`` is specified,
then this returns the device handle corresponding to that device type. Otherwise, if the
module is already on a non-CPU device, then the device type is that non-CPU device type.
If the module is on CPU or meta, then the device type is the current CUDA device.

This method should be called after the ignored parameters have been determined, as the
device handle may be needed for other initialization.
"""
determined_device = None
if device_id is not None:
determined_device = (
device_id
if isinstance(device_id, torch.device)
else torch.device(device_id)
)
if determined_device is None:
for param in _get_orig_params(module, ignored_params):
if param.device.type in {"cpu", "meta"}:
continue
if determined_device is None:
determined_device = param.device
else:
if param.device.type != determined_device.type:
raise RuntimeError(
f"FSDP not supports modules on different device type "
f"but got params on {determined_device.type} and {param.device.type}"
)
determined_device = determined_device or torch.device(
"cuda", torch.cuda.current_device()
)

state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
return state
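
The device-selection order implemented by _init_device_handle can be shown in isolation. A minimal sketch of just that logic, where param_devices stands in for the devices of the non-ignored parameters (the CUDA fallback in step 3 naturally requires a visible CUDA device):

from typing import Optional, Sequence, Union

import torch

def resolve_device(
    device_id: Optional[Union[int, torch.device]],
    param_devices: Sequence[torch.device],
) -> torch.device:
    # 1) An explicit device_id always wins.
    if device_id is not None:
        return device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    # 2) Otherwise, take the first non-CPU/meta parameter device; all such
    #    devices are required to share one device type.
    for device in param_devices:
        if device.type not in ("cpu", "meta"):
            return device
    # 3) Fall back to the current CUDA device.
    return torch.device("cuda", torch.cuda.current_device())

assert resolve_device(torch.device("cuda", 0), []) == torch.device("cuda", 0)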


@no_type_check
def _init_buffer_state(
state: _FSDPState,
@@ -429,7 +473,6 @@ def _init_param_handle_from_module(
device_from_device_id,
state.rank,
)
state._device_handle = _FSDPDeviceHandle.from_device(state.compute_device)

managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
if sync_module_states:
@@ -514,7 +557,6 @@ def _init_param_handles_from_module(
device_from_device_id,
state.rank,
)
state._device_handle = _FSDPDeviceHandle.from_device(state.compute_device)
if sync_module_states:
_sync_module_states(params, buffers, state.process_group)
_init_param_handle_from_params(state, params, fully_sharded_module)
2 changes: 2 additions & 0 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -41,6 +41,7 @@
_get_default_comm_hook,
_init_buffer_state,
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state,
@@ -389,6 +390,7 @@ def __init__(
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
_init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
_init_device_handle(self, module, self._ignored_params, device_id)

# Add module annotations for Dynamo support (see function for details)
_annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
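
Taken together, the new _init_device_handle call lets hybrid sharding derive num_devices_per_node from the resolved device handle rather than from torch.cuda. A hypothetical usage sketch under torchrun, shown with the NCCL backend purely for concreteness (a custom backend registered for its own device type would follow the same pattern):

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group(backend="nccl")  # a custom backend name would go here instead
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

fsdp_model = FSDP(
    nn.Linear(16, 16),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_id=local_rank,  # feeds _init_device_handle, which sizes the intra-/inter-node groups
)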