Skip to content

Commit

Permalink
[FSDP][3/N] Unify fully_shard auto wrap
Browse files Browse the repository at this point in the history
ghstack-source-id: 9e0cf806b4bc63ef5bb5361c6a4c1cb33cd80c7c
Pull Request resolved: #104408
  • Loading branch information
awgu committed Jun 29, 2023
1 parent 4fb6a63 commit 022eca5
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 109 deletions.
15 changes: 12 additions & 3 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,20 @@ def test_nested_fully_shard_shared_state(self):
Tests that nested applications of ``fully_shard`` share the expected
data structure state.
"""
self.run_subtests(
{"use_policy": [False, True]},
self._test_nested_fully_shard_shared_state,
)

def _test_nested_fully_shard_shared_state(self, use_policy: bool):
device = torch.device("cuda")
composable_module = CompositeParamModel(device=device)
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)
if use_policy:
fully_shard(composable_module, policy=ModuleWrapPolicy({UnitModule}))
else:
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)

# Run a forward pass to trigger lazy initialization
inp = torch.randn((2, 100), device=device)
Expand Down
48 changes: 32 additions & 16 deletions torch/distributed/_composable/fully_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handles_from_module,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hooks,
_register_pre_forward_hooks,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
Expand Down Expand Up @@ -66,32 +68,46 @@ def fully_shard(
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
)
limit_all_gathers = True
use_orig_params = True
backward_prefetch_limit = 1
forward_prefetch_limit = 1
if policy is not None:
fsdp_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
"cpu_offload": cpu_offload,
"ignored_modules": ignored_modules,
"device_id": device_id,
"param_init_fn": param_init_fn,
"sync_module_states": sync_module_states,
"forward_prefetch": forward_prefetch,
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
fsdp_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
fsdp_kwargs,
fully_shard,
)
state = _init_core_state(
state,
strategy or ShardingStrategy.FULL_SHARD,
mixed_precision,
cpu_offload,
limit_all_gathers,
use_orig_params,
backward_prefetch_limit,
forward_prefetch_limit,
limit_all_gathers=True,
use_orig_params=True,
backward_prefetch_limit=1,
forward_prefetch_limit=1,
)
state = _init_runtime_state(state)
state = _init_prefetching_state(
state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
)
state = _init_buffer_state(state, module)
state = _init_param_handles_from_module(
state,
module,
policy,
device_id,
param_init_fn,
sync_module_states,
state = _init_param_handle_from_module(
state, module, device_id, param_init_fn, sync_module_states
)
state = _init_state_dict_state(state)
_register_all_state_dict_hooks(state)
Expand Down
90 changes: 1 addition & 89 deletions torch/distributed/fsdp/_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Optional,
Set,
Tuple,
Type,
Union,
)

Expand All @@ -35,7 +34,6 @@
TrainingState,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
Expand Down Expand Up @@ -447,11 +445,9 @@ def _init_param_handle_from_module(
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
module_wrapper_cls: Type,
) -> _FSDPState:
"""
Initializes a ``FlatParamHandle`` from a module ``fully_sharded_module``.
This is the module wrapper code path.
"""
_check_single_device_module(fully_sharded_module, state._ignored_params)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
Expand All @@ -466,7 +462,7 @@ def _init_param_handle_from_module(
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
fully_sharded_module,
check_fn=lambda k: not isinstance(k, module_wrapper_cls),
check_fn=lambda k: _get_module_fsdp_state(k) is None,
)
_move_module_to_device(
fully_sharded_module, state._ignored_params, device_from_device_id
Expand All @@ -487,90 +483,6 @@ def _init_param_handle_from_module(
return state


@no_type_check
def _init_param_handles_from_module(
state: _FSDPState,
root_module: nn.Module,
policy: _FSDPPolicy,
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
) -> _FSDPState:
"""
Initializes all ``FlatParamHandle`` s from a module ``root_module``. This
is the non-module-wrapper code path. ``root_module`` is guaranteed to be
a fully sharded module, and some of its submodules may be as well,
depending on ``policy``. See [Note: Fully Sharded Module].
"""
fully_sharded_module_to_states = _get_fully_sharded_module_to_states(
root_module,
policy,
state._ignored_modules,
state._ignored_params,
)
_check_single_device_module(root_module, state._ignored_params)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
# Initialize and shard `FlatParamHandle`s one by one following reverse
# depth-first order (i.e. reverse `.modules()` order), which represents a
# reverse topological sort order. This avoids increasing peak GPU memory
# usage when the unsharded model exists on CPU or meta device.
# NOTE: This order differs from that followed by the wrapper path when
# using auto wrapping, which also represents a valid reverse topological
# sort order, but the difference does not matter.
materialized_module = False
for fully_sharded_module, (params, buffers) in reversed(
fully_sharded_module_to_states.items()
):
# Materialize the module if needed
is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
fully_sharded_module, state._ignored_params
)
if is_meta_module or is_torchdistX_deferred_init:
materialized_module = True
# Save the parameter and buffer names to reacquire references after
# after materialization since their variables may change
param_names, buffer_names = _get_state_names_for_states(
fully_sharded_module, params, buffers
)
if (
is_meta_module or is_torchdistX_deferred_init
) and param_init_fn is not None:
_materialize_with_param_init_fn(fully_sharded_module, param_init_fn)
elif is_meta_module:
_materialize_meta_module(fully_sharded_module, device_id)
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
root_module,
check_fn=lambda _: True,
)
if materialized_module:
# Reacquire references using the pre-computed state names
params = [
fully_sharded_module.get_parameter(param_name)
for param_name in param_names
]
buffers = [
fully_sharded_module.get_buffer(buffer_name)
for buffer_name in buffer_names
]
_move_states_to_device(params, buffers, device_from_device_id)
if state.compute_device is None: # only need to set once
state.compute_device = _get_compute_device(
fully_sharded_module,
state._ignored_params,
device_from_device_id,
state.rank,
)
if sync_module_states:
_sync_module_states(params, buffers, state.process_group)
_init_param_handle_from_params(state, params, fully_sharded_module)
# Reverse `_handles` to preserve depth-first `.modules()` order for
# consistency with the wrapper path (namely, so that `_get_fsdp_handles()`
# returns the same ordering for both paths).
state._handles.reverse()
return state


@no_type_check
def _init_param_handle_from_params(
state: _FSDPState,
Expand Down
1 change: 0 additions & 1 deletion torch/distributed/fsdp/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,6 @@ def __init__(
device_id,
param_init_fn,
sync_module_states,
FullyShardedDataParallel,
)
self._fsdp_wrapped_module = module
if not use_orig_params:
Expand Down

0 comments on commit 022eca5

Please sign in to comment.