[FSDP] Skip unshard call during checkpointing for NO_SHARD sharding strategy (#101095)

Pull Request resolved: #101095
Approved by: https://github.com/fegin
zhaojuanmao authored and pytorchmergebot committed May 12, 2023
1 parent aec11b8 commit 5ac48eb
Showing 2 changed files with 60 additions and 4 deletions.
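
For context: the commit targets composable `fully_shard` modules whose root uses `ShardingStrategy.NO_SHARD`. With NO_SHARD the parameters are never sharded, so the full/sharded state_dict hooks no longer need to enter the unshard context for that state. Below is a minimal sketch of the mixed-sharding setup, mirroring the new test's `_create_mixed_shard_on_model` helper; `ToyModel` and `apply_mixed_sharding` are hypothetical names, and the commented-out calls assume an initialized process group with one CUDA device per rank.

# Minimal sketch (not from the diff) of the mixed-sharding setup this commit targets.
# `ToyModel` is a hypothetical stand-in; the real test uses CompositeParamModel.
import torch
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp.api import ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.u1 = nn.Linear(8, 8)  # this submodule gets sharded across ranks
        self.u2 = nn.Linear(8, 8)  # stays whole under the root's NO_SHARD

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.u2(self.u1(x))


def apply_mixed_sharding(model: nn.Module) -> nn.Module:
    fully_shard(model.u1)  # default FULL_SHARD on the inner module
    fully_shard(model, strategy=ShardingStrategy.NO_SHARD)  # root keeps full params
    return model


# With NO_SHARD the root's parameters are already unsharded, so the state_dict
# hooks patched in this commit can skip the unshard context for it:
#   model = apply_mixed_sharding(ToyModel().cuda())
#   full_sd = model.state_dict()
#   model.load_state_dict(full_sd)
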
33 changes: 33 additions & 0 deletions
@@ -11,6 +11,7 @@
 from torch.distributed._composable import fully_shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
+from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.testing._internal.common_dist_composable import (
@@ -170,6 +171,38 @@ def _test_save_dict_save_load_flow(
         load_model.load_state_dict(state_dict)
         self._check_model_parity(load_model, save_model)
 
+    @skip_if_lt_x_gpu(2)
+    def test_full_state_dict_save_load_mixed_sharding(self):
+        """
+        Tests that the full state dict saved from a module with ``fully_shard``
+        and ``no_shard`` applied on the module matches that of an equivalent
+        local module. Also ensures that this state_dict can be reloaded into
+        a composable module and is equivalent to the original composable module.
+        """
+        local_model = CompositeParamModel(device=torch.device("cuda"))
+
+        def _create_mixed_shard_on_model(mod: nn.Module):
+            fully_shard(mod.u1)
+            fully_shard(mod, strategy=ShardingStrategy.NO_SHARD)
+            return mod
+
+        save_composable = copy.deepcopy(local_model)
+        save_composable = _create_mixed_shard_on_model(save_composable)
+        local_sd = local_model.state_dict()
+        composable_sd = save_composable.state_dict()
+        self._check_state_dict_parity(local_sd, composable_sd)
+
+        # Validate load
+        load_composable = copy.deepcopy(local_model)
+        load_composable = _create_mixed_shard_on_model(load_composable)
+        _zero_model(load_composable, summon_full=False)
+        for p in load_composable.parameters():
+            self.assertEqual(0, p.sum())
+
+        sd = {k: v.clone() for k, v in composable_sd.items()}
+        load_composable.load_state_dict(sd)
+        self._check_model_parity(load_composable, save_composable)
+
     def _check_state_dict_parity(self, local_sd: Dict, composable_sd: Dict):
         """Checks that ``local_sd`` and ``composable_sd`` are the same."""
         # Check that all keys match
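
The hunk above ends where the body of `_check_state_dict_parity` is collapsed. Purely as an illustration (this is not the repository's helper; the name and comparison choices here are hypothetical), a parity check of this kind typically compares key sets and then tensor values:

# Illustrative state-dict parity check; not the repository's implementation.
from typing import Dict

import torch


def check_state_dict_parity(
    local_sd: Dict[str, torch.Tensor], composable_sd: Dict[str, torch.Tensor]
) -> None:
    # Check that all keys match, then compare values entry by entry.
    assert set(local_sd.keys()) == set(composable_sd.keys())
    for key, local_tensor in local_sd.items():
        torch.testing.assert_close(local_tensor, composable_sd[key])
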
31 changes: 27 additions & 4 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -146,6 +146,12 @@ def _common_unshard_pre_state_dict_hook(
     Performs the pre-state_dict tasks shared by all state_dict types that require
     ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
     """
+    # For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases.
+    if (
+        _is_composable(fsdp_state)
+        and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+    ):
+        return
     _enter_unshard_params_ctx(
         module,
         fsdp_state,
@@ -172,7 +178,11 @@ def _common_unshard_post_state_dict_hook(
     _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
     # Return early for trivial cases
     if not state_dict or not _has_fsdp_params(fsdp_state, module):
-        _exit_unshard_params_ctx(module, fsdp_state)
+        if not (
+            _is_composable(fsdp_state)
+            and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+        ):
+            _exit_unshard_params_ctx(module, fsdp_state)
         return state_dict
 
     # If a rank does not have unsharded parameters(when `rank0_only=True`
@@ -215,7 +225,12 @@ def _common_unshard_post_state_dict_hook(
                 )
 
         param_hook(state_dict, prefix, fqn)
-    _exit_unshard_params_ctx(module, fsdp_state)
+
+    if not (
+        _is_composable(fsdp_state)
+        and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+    ):
+        _exit_unshard_params_ctx(module, fsdp_state)
 
     cpu_device = torch.device("cpu")
     buffer_clean_fqns = []
@@ -335,7 +350,11 @@ def _full_pre_load_state_dict_hook(
     prefix: str,
 ) -> None:
     _lazy_init(fsdp_state, module)
-    _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
+    if not (
+        _is_composable(fsdp_state)
+        and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+    ):
+        _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
     # Add FSDP_PREFIX only for wrapper-based FSDP.
     if not _is_composable(fsdp_state):
         _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
@@ -344,7 +363,11 @@ def _full_post_load_state_dict_hook(
 def _full_post_load_state_dict_hook(
     module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
 ) -> None:
-    _exit_unshard_params_ctx(module, fsdp_state)
+    if not (
+        _is_composable(fsdp_state)
+        and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+    ):
+        _exit_unshard_params_ctx(module, fsdp_state)
 
 
 def _local_pre_state_dict_hook(
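
The same guard (`_is_composable(fsdp_state) and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD`) now appears in several hooks. Purely as an illustration and not part of this commit, a possible follow-up would factor it into a small predicate inside `_state_dict_utils.py`; `_skip_unshard_for_no_shard` is a hypothetical name, and the sketch assumes that module's existing `_is_composable` and `ShardingStrategy` imports.

# Hypothetical refactor sketch (not in the commit): centralize the NO_SHARD skip check.
# Assumes the surrounding _state_dict_utils.py context (_is_composable, ShardingStrategy).
def _skip_unshard_for_no_shard(fsdp_state) -> bool:
    # Composable `fully_shard` with NO_SHARD never shards parameters, so
    # entering/exiting the unshard context is unnecessary.
    return (
        _is_composable(fsdp_state)
        and fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
    )


# Each hook would then read, for example:
#   if not _skip_unshard_for_no_shard(fsdp_state):
#       _exit_unshard_params_ctx(module, fsdp_state)
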
