[FSDP][state_dict] Restore the state_dict_config for NO_SHARD #100855

Closed · wants to merge 6 commits
2 changes: 2 additions & 0 deletions test/distributed/_composable/test_compose.py
@@ -302,6 +302,8 @@ def test_state_dict_fsdp_submodules(self):
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)
+        # Ensure that get_state_dict_type can still correctly get the settings.
+        _ = FSDP.get_state_dict_type(model)
Contributor commented:
Is this meant to have some assert on the return value?

Contributor (author) replied:
get_state_dict_type will assert if the state_dict_config is not restored.
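
For readers following the thread: a minimal, self-contained sketch of why the bare call doubles as a check. The real `FSDP.get_state_dict_type` walks the FSDP-managed submodules and asserts that they all report identical state_dict settings; the `FakeState` class and the function below are illustrative stand-ins, not the actual implementation.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

@dataclass
class FakeState:
    # Illustrative stand-in for an FSDP state object.
    state_dict_type: str
    state_dict_config: Optional[object]

def get_state_dict_type(states: Iterable[FakeState]) -> Tuple[str, Optional[object]]:
    # Every FSDP-managed submodule must report identical settings; a hook
    # that swapped the config without restoring it trips the assertion.
    prev = None
    for state in states:
        settings = (state.state_dict_type, state.state_dict_config)
        if prev is None:
            prev = settings
        elif settings != prev:
            raise AssertionError(
                "All FSDP modules must have the same state_dict settings."
            )
    assert prev is not None, "no FSDP states found"
    return prev
```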



instantiate_parametrized_tests(TestFSDPCheckpoint)
151 changes: 98 additions & 53 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -1,7 +1,18 @@
+import contextlib
import functools
import math
+import warnings
-from typing import Any, Callable, cast, Dict, Iterator, List, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)

import torch
import torch.distributed as dist
@@ -635,6 +646,20 @@ def _sharded_pre_load_state_dict_hook(
_deregister_orig_params(module, fsdp_state)


+@contextlib.contextmanager
+def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
+    old_state_dict_config = fsdp_state._state_dict_config
+    old_state_dict_type = fsdp_state._state_dict_type
+    fsdp_state._state_dict_config = FullStateDictConfig()
+    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+    try:
+        yield
+    finally:
+        # Restore the original settings even if the wrapped code raises.
+        fsdp_state._state_dict_config = old_state_dict_config
+        fsdp_state._state_dict_type = old_state_dict_type
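
As a standalone illustration of the save/restore pattern this context manager uses (plain Python objects here, not the real FSDP types), the `finally` clause guarantees restoration even when the wrapped code raises:

```python
import contextlib
from types import SimpleNamespace

@contextlib.contextmanager
def _swap_settings(state, new_type):
    # Same shape as _replace_with_full_state_dict_type: save, swap, restore.
    old = state.state_dict_type
    state.state_dict_type = new_type
    try:
        yield
    finally:
        state.state_dict_type = old  # runs even if the body raises

state = SimpleNamespace(state_dict_type="SHARDED_STATE_DICT")
try:
    with _swap_settings(state, "FULL_STATE_DICT"):
        raise RuntimeError("hook failed mid-flight")
except RuntimeError:
    pass
assert state.state_dict_type == "SHARDED_STATE_DICT"  # not stuck on FULL
```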


@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
@@ -650,17 +675,23 @@ def _post_state_dict_hook(
what postprocessing will be done.
"""
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
-
-    _post_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
-    }
-    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _post_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
+        }
+        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
    return processed_state_dict
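
The `context = ... else contextlib.nullcontext()` shape lets the dispatch body be written once for both sharding strategies. A runnable sketch of the same pattern with generic names (not FSDP code):

```python
import contextlib

@contextlib.contextmanager
def _force_full(settings: dict):
    # Temporarily override a setting, restoring it on exit.
    old = settings["state_dict_type"]
    settings["state_dict_type"] = "FULL_STATE_DICT"
    try:
        yield
    finally:
        settings["state_dict_type"] = old

def post_hook(settings: dict, no_shard: bool) -> str:
    # Pick a real context manager or a no-op, then run the shared body
    # exactly once under whichever was chosen.
    context = _force_full(settings) if no_shard else contextlib.nullcontext()
    with context:
        return "processed as " + settings["state_dict_type"]

settings = {"state_dict_type": "SHARDED_STATE_DICT"}
assert post_hook(settings, no_shard=True) == "processed as FULL_STATE_DICT"
assert settings["state_dict_type"] == "SHARDED_STATE_DICT"  # restored
```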


@@ -678,24 +709,26 @@ def _pre_state_dict_hook(
be done.
"""
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
-
-    _pre_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
-    }
-    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
-        fsdp_state,
-        module,
-        *args,
-        **kwargs,
-    )
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _pre_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
+        }
+        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
+            fsdp_state,
+            module,
+            *args,
+            **kwargs,
+        )
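
All four hooks share this table-driven dispatch shape: a dict keyed by `StateDictType` maps each mode to its handler, so an unhandled mode fails loudly with a `KeyError`. A self-contained sketch of the pattern (illustrative enum and handlers, not FSDP code):

```python
from enum import Enum, auto
from typing import Callable, Dict

class Mode(Enum):
    # Illustrative stand-in for StateDictType.
    FULL = auto()
    LOCAL = auto()
    SHARDED = auto()

def _full_hook(prefix: str) -> str:
    return prefix + "full"

def _local_hook(prefix: str) -> str:
    return prefix + "local"

def _sharded_hook(prefix: str) -> str:
    return prefix + "sharded"

_HOOKS: Dict[Mode, Callable[[str], str]] = {
    Mode.FULL: _full_hook,
    Mode.LOCAL: _local_hook,
    Mode.SHARDED: _sharded_hook,
}

def dispatch(mode: Mode, prefix: str) -> str:
    # An unknown mode raises KeyError, mirroring how the hook tables above
    # would fail for a StateDictType they do not cover.
    return _HOOKS[mode](prefix)

assert dispatch(Mode.LOCAL, "state_dict.") == "state_dict.local"
```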


@no_type_check
@@ -713,21 +746,27 @@ def _pre_load_state_dict_hook(
be done.
"""
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
-
-    _pre_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    if fsdp_state._device_handle.is_available():
-        fsdp_state._device_handle.synchronize()
-    # Dispatch into state_dict specific implementation of pre-hook.
-    _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _pre_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        if fsdp_state._device_handle.is_available():
+            fsdp_state._device_handle.synchronize()
+        # Dispatch into state_dict specific implementation of pre-hook.
+        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
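
The `_device_handle` indirection keeps the pre-load synchronization device-agnostic. On a CUDA build it behaves roughly like the sketch below (a simplification; the real handle abstracts over non-CUDA accelerators as well):

```python
import torch

def _synchronize_if_available() -> None:
    # Wait for all pending device work before the load hook starts
    # touching state_dict tensors; a no-op on CPU-only builds.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
```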


@no_type_check
@@ -738,18 +777,24 @@ def _post_load_state_dict_hook(
*args: Any,
) -> None:
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
-
-    _post_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    # Dispatch into state_dict type specific implementation of post-hook for
-    # loading state_dict.
-    _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _post_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        # Dispatch into state_dict type specific implementation of post-hook for
+        # loading state_dict.
+        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
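
A side note on the repeated `warnings.warn` calls in these hooks: under Python's default filters a `UserWarning` is reported once per call site, so a NO_SHARD job that saves and loads every epoch warns once rather than flooding the log. A quick runnable check:

```python
import warnings

def save_with_no_shard():
    # Same message from the same line: the default filter reports only
    # the first occurrence per (module, lineno).
    warnings.warn(
        "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
        "be returned."
    )

save_with_no_shard()  # prints the UserWarning
save_with_no_shard()  # suppressed as a duplicate
```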


def _register_all_state_dict_hooks(state: _FSDPState):