init device handle before any other initializations after determining ignored params
medivh-xp committed May 16, 2023
2 parents e50a264 + 16fc0f3 commit 20bfb47
Showing 3 changed files with 49 additions and 2 deletions.
2 changes: 2 additions & 0 deletions torch/distributed/_composable/fully_shard.py
@@ -10,6 +10,7 @@
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handles_from_module,
    _init_prefetching_state,
@@ -56,6 +57,7 @@ def fully_shard(
        raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    state = _init_process_group_state(
        state, process_group, ShardingStrategy.FULL_SHARD, policy
    )
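For context, here is a hedged sketch of how the composable `fully_shard` entry point might be called with a `device_id`, so that the new `_init_device_handle` step resolves the device before the remaining initialization. The launcher setup (torchrun, NCCL, a CUDA build) is an assumption for illustration, not part of this commit.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import fully_shard

# Assumes launch via torchrun so RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set,
# plus a CUDA build with NCCL available.
dist.init_process_group("nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))  # built on CPU
# `device_id` is forwarded to `_init_device_handle`, which now runs right after
# `_init_ignored_module_states` and before `_init_process_group_state`.
fully_shard(model, device_id=local_rank)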
47 changes: 45 additions & 2 deletions torch/distributed/fsdp/_init_utils.py
@@ -265,6 +265,51 @@ def _init_ignored_module_states(
    return state


@no_type_check
def _init_device_handle(
    state: _FSDPState,
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
    """
    Determines the device handle used for initializing FSDP. If a device is
    specified by ``device_id``, then the device handle corresponds to that
    device type. Otherwise, if the module is already on a non-CPU device, then
    the device type is that non-CPU device type. If the module is on CPU or
    meta, then the device type is the current CUDA device.

    This method should be called once the ignored parameters are determined,
    as the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        # An explicit ``device_id`` takes precedence.
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        # Otherwise, infer the device from the non-ignored parameters, which
        # must all share a single (non-CPU, non-meta) device type.
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules on different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        # Fall back to the current CUDA device for CPU/meta modules.
        determined_device = determined_device or torch.device(
            "cuda", torch.cuda.current_device()
        )

    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state


@no_type_check
def _init_buffer_state(
    state: _FSDPState,
@@ -430,7 +475,6 @@ def _init_param_handle_from_module(
        device_from_device_id,
        state.rank,
    )
    state._device_handle = _FSDPDeviceHandle.from_device(state.compute_device)

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    if sync_module_states:
@@ -515,7 +559,6 @@ def _init_param_handles_from_module(
device_from_device_id,
state.rank,
)
state._device_handle = _FSDPDeviceHandle.from_device(state.compute_device)
if sync_module_states:
_sync_module_states(params, buffers, state.process_group)
_init_param_handle_from_params(state, params, fully_sharded_module)
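The resolution order implemented by `_init_device_handle` is: an explicit `device_id` wins; otherwise the device of the first non-CPU/non-meta parameter is used, with all such parameters required to share one device type; otherwise the current CUDA device is the fallback. Below is a minimal, self-contained sketch of that order; `resolve_fsdp_device` is a hypothetical helper written only for illustration and is not part of the FSDP API.

from typing import Iterable, Optional, Union

import torch
import torch.nn as nn


def resolve_fsdp_device(
    params: Iterable[nn.Parameter],
    device_id: Optional[Union[int, torch.device]] = None,
) -> torch.device:
    # 1. An explicit device_id takes precedence.
    if device_id is not None:
        return (
            device_id if isinstance(device_id, torch.device) else torch.device(device_id)
        )
    # 2. Otherwise use the device of the first non-CPU/non-meta parameter; all
    #    such parameters must share a single device type.
    resolved: Optional[torch.device] = None
    for param in params:
        if param.device.type in {"cpu", "meta"}:
            continue
        if resolved is None:
            resolved = param.device
        elif param.device.type != resolved.type:
            raise RuntimeError(
                f"Parameters on {resolved.type} and {param.device.type} are not supported"
            )
    # 3. Fall back to the current CUDA device for CPU/meta modules
    #    (requires a CUDA-enabled build, like the original code).
    return resolved or torch.device("cuda", torch.cuda.current_device())


if __name__ == "__main__":
    model = nn.Linear(4, 4)  # CPU parameters
    print(resolve_fsdp_device(model.parameters(), device_id="cuda:0"))  # cuda:0
    print(resolve_fsdp_device(model.parameters()))  # current CUDA device (needs CUDA)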
2 changes: 2 additions & 0 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -41,6 +41,7 @@
    _get_default_comm_hook,
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
@@ -389,6 +390,7 @@ def __init__(
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        _init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
        _init_device_handle(self, module, self._ignored_params, device_id)

        # Add module annotations for Dynamo support (see function for details)
        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
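On the wrapper path, the same reordering means the device handle exists as soon as the ignored parameters are known, before Dynamo annotation and the rest of `__init__`. A hedged usage sketch follows, assuming a CUDA build with NCCL and a torchrun-style launcher.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes launch via torchrun (RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set),
# a CUDA build, and NCCL.
dist.init_process_group("nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))  # built on CPU
# `device_id` feeds `_init_device_handle`, which now runs immediately after
# `_init_ignored_module_states` in `FullyShardedDataParallel.__init__`.
fsdp_model = FSDP(model, device_id=local_rank)

dist.destroy_process_group()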
