Skip to content

Commit

Permalink
[FSDP][3/N] Unify fully_shard auto wrap (#104408)
Browse files Browse the repository at this point in the history
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: #104408
Approved by: https://github.com/rohan-varma
  • Loading branch information
awgu authored and pytorchmergebot committed Jul 8, 2023
1 parent 6d71b4f commit d9be036
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 184 deletions.
15 changes: 12 additions & 3 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,20 @@ def test_nested_fully_shard_shared_state(self):
Tests that nested applications of ``fully_shard`` share the expected
data structure state.
"""
self.run_subtests(
{"use_policy": [False, True]},
self._test_nested_fully_shard_shared_state,
)

def _test_nested_fully_shard_shared_state(self, use_policy: bool):
device = torch.device("cuda")
composable_module = CompositeParamModel(device=device)
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)
if use_policy:
fully_shard(composable_module, policy=ModuleWrapPolicy({UnitModule}))
else:
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)

# Run a forward pass to trigger lazy initialization
inp = torch.randn((2, 100), device=device)
Expand Down
61 changes: 40 additions & 21 deletions torch/distributed/_composable/fully_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handles_from_module,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hooks,
_register_pre_forward_hooks,
_register_post_forward_hook,
_register_pre_forward_hook,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
Expand Down Expand Up @@ -66,39 +68,56 @@ def fully_shard(
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
)
limit_all_gathers = True
use_orig_params = True
backward_prefetch_limit = 1
forward_prefetch_limit = 1
if policy is not None:
fsdp_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
"cpu_offload": cpu_offload,
"ignored_modules": ignored_modules,
"device_id": device_id,
"param_init_fn": param_init_fn,
"sync_module_states": sync_module_states,
"forward_prefetch": forward_prefetch,
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
fsdp_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
fsdp_kwargs,
fully_shard,
)
state = _init_core_state(
state,
strategy or ShardingStrategy.FULL_SHARD,
mixed_precision,
cpu_offload,
limit_all_gathers,
use_orig_params,
backward_prefetch_limit,
forward_prefetch_limit,
limit_all_gathers=True,
use_orig_params=True,
backward_prefetch_limit=1,
forward_prefetch_limit=1,
)
state = _init_runtime_state(state)
state = _init_prefetching_state(
state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
)
state = _init_buffer_state(state, module)
state = _init_param_handles_from_module(
state,
module,
policy,
device_id,
param_init_fn,
sync_module_states,
state = _init_param_handle_from_module(
state, module, device_id, param_init_fn, sync_module_states
)
state = _init_state_dict_state(state)
_register_all_state_dict_hooks(state)
modules = list(module.modules())
_register_pre_forward_hooks(state, modules)
_register_post_forward_hooks(state, modules)
_register_pre_forward_hook(state, module)
_register_post_forward_hook(state, module)
_register_root_pre_forward_hook(state, module) # prepend last
# Always insert the state for the passed-in module even if it has no
# managed parameters, in which case it has no handles and does not appear
# in `_fully_sharded_module_to_handles`
_insert_module_state(module, state)
for submodule in module.modules():
if (
submodule in state._fully_sharded_module_to_handles
Expand Down
4 changes: 4 additions & 0 deletions torch/distributed/fsdp/_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def _module_handles(state: _FSDPState, module: nn.Module) -> List:
the handles that contain some parameter in ``module``.
"""
if _is_composable(state):
# A valid FSDP state may have no managed parameters and hence no
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if len(state._handles) == 0:
return []
assert (
module in state._fully_sharded_module_to_handles
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
Expand Down
90 changes: 1 addition & 89 deletions torch/distributed/fsdp/_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
Optional,
Set,
Tuple,
Type,
Union,
)

Expand All @@ -37,7 +36,6 @@
TrainingState,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
Expand Down Expand Up @@ -485,11 +483,9 @@ def _init_param_handle_from_module(
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
module_wrapper_cls: Type,
) -> _FSDPState:
"""
Initializes a ``FlatParamHandle`` from a module ``fully_sharded_module``.
This is the module wrapper code path.
"""
_check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
Expand All @@ -504,7 +500,7 @@ def _init_param_handle_from_module(
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
fully_sharded_module,
check_fn=lambda k: not isinstance(k, module_wrapper_cls),
check_fn=lambda k: _get_module_fsdp_state(k) is None,
)
_move_module_to_device(
fully_sharded_module, state._ignored_params, device_from_device_id
Expand All @@ -525,90 +521,6 @@ def _init_param_handle_from_module(
return state


@no_type_check
def _init_param_handles_from_module(
state: _FSDPState,
root_module: nn.Module,
policy: _FSDPPolicy,
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
) -> _FSDPState:
"""
Initializes all ``FlatParamHandle`` s from a module ``root_module``. This
is the non-module-wrapper code path. ``root_module`` is guaranteed to be
a fully sharded module, and some of its submodules may be as well,
depending on ``policy``. See [Note: Fully Sharded Module].
"""
fully_sharded_module_to_states = _get_fully_sharded_module_to_states(
root_module,
policy,
state._ignored_modules,
state._ignored_params,
)
_check_single_device_module(root_module, state._ignored_params, device_id)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
# Initialize and shard `FlatParamHandle`s one by one following reverse
# depth-first order (i.e. reverse `.modules()` order), which represents a
# reverse topological sort order. This avoids increasing peak GPU memory
# usage when the unsharded model exists on CPU or meta device.
# NOTE: This order differs from that followed by the wrapper path when
# using auto wrapping, which also represents a valid reverse topological
# sort order, but the difference does not matter.
materialized_module = False
for fully_sharded_module, (params, buffers) in reversed(
fully_sharded_module_to_states.items()
):
# Materialize the module if needed
is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
fully_sharded_module, state._ignored_params, state._ignored_modules
)
if is_meta_module or is_torchdistX_deferred_init:
materialized_module = True
# Save the parameter and buffer names to reacquire references after
# after materialization since their variables may change
param_names, buffer_names = _get_state_names_for_states(
fully_sharded_module, params, buffers
)
if (
is_meta_module or is_torchdistX_deferred_init
) and param_init_fn is not None:
_materialize_with_param_init_fn(fully_sharded_module, param_init_fn)
elif is_meta_module:
_materialize_meta_module(fully_sharded_module, device_id)
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
root_module,
check_fn=lambda _: True,
)
if materialized_module:
# Reacquire references using the pre-computed state names
params = [
fully_sharded_module.get_parameter(param_name)
for param_name in param_names
]
buffers = [
fully_sharded_module.get_buffer(buffer_name)
for buffer_name in buffer_names
]
_move_states_to_device(params, buffers, device_from_device_id)
if state.compute_device is None: # only need to set once
state.compute_device = _get_compute_device(
fully_sharded_module,
state._ignored_params,
device_from_device_id,
state.rank,
)
if sync_module_states:
_sync_module_states(params, buffers, state.process_group)
_init_param_handle_from_params(state, params, fully_sharded_module)
# Reverse `_handles` to preserve depth-first `.modules()` order for
# consistency with the wrapper path (namely, so that `_get_fsdp_handles()`
# returns the same ordering for both paths).
state._handles.reverse()
return state


@no_type_check
def _init_param_handle_from_params(
state: _FSDPState,
Expand Down
Loading

0 comments on commit d9be036

Please sign in to comment.