[FSDP][optim_state_dict] Enable cpu_offload config for optimizer state_dict

Pull Request resolved: #108434

We had the cpu_offload option but never used it, because optimizer state_dict offloads the tensors to CPU by default. This default is usually what users want, since the tensors eventually have to be moved to CPU anyway. However, we may want to disable offloading to CPU in some cases, especially for debugging. This PR lets optimizer state_dict read the flag (see the usage sketch below).
ghstack-source-id: 199739177
@exported-using-ghexport

Differential Revision: [D48913340](https://our.internmc.facebook.com/intern/diff/D48913340/)
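
For reference, a minimal usage sketch of the flag this PR wires up, modeled on the sharded-state-dict test added below. The toy model, single training step, and torchrun/NCCL launch are illustrative assumptions, not part of this change.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# Assumes a distributed launch (e.g. torchrun) with one CUDA device per rank.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer has state to gather.
model(torch.randn(4, 8, device="cuda")).sum().backward()
optim.step()

# offload_to_cpu=False keeps the gathered optimizer state on GPU,
# e.g. for debugging; the default (True) moves it to CPU.
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(),
    ShardedOptimStateDictConfig(offload_to_cpu=False),
):
    osd = FSDP.optim_state_dict(model, optim)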
fegin committed Sep 5, 2023
1 parent c8d4159 commit 179bd9e
Showing 3 changed files with 82 additions and 15 deletions.
37 changes: 35 additions & 2 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_WRAPPED_MODULE,
apply_activation_checkpointing,
@@ -1864,8 +1865,40 @@ def step():
optim.state_dict(), original_osd, check_same_param_keys=True
)

# TODO: add local/sharded/full state_dict and CPU offloading and rank0
# interface test here, https://github.com/pytorch/pytorch/issues/97163
with FSDP.state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
ShardedStateDictConfig(),
ShardedOptimStateDictConfig(offload_to_cpu=False),
):
osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
for fqn, state in osd["state"].items():
for s in state.values():
if s.dim() == 0:
continue
self.assertTrue(isinstance(s, ShardedTensor))
if s._local_shards[0]:
self.assertTrue(s._local_shards[0].tensor.is_cuda)

with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(),
FullOptimStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
),
):
osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
if dist.get_rank() > 0:
self.assertEqual(osd, {})
else:
for fqn, state in osd["state"].items():
for s in state.values():
if s.dim() == 0:
continue
self.assertFalse(s.is_cuda)
self.assertFalse(isinstance(s, ShardedTensor))

@skip_if_lt_x_gpu(2)
def test_state_dict_with_none_tensor_state(self):
55 changes: 42 additions & 13 deletions torch/distributed/fsdp/_optim_utils.py
@@ -124,6 +124,7 @@ def _unflatten_optim_state(
flat_param_state: Dict[str, Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool,
) -> List[Dict[str, Any]]:
"""
Unflattens the optimizer state, consisting of the "state" part and the
@@ -162,9 +163,11 @@
)
for optim_state in unflat_param_state:
# We can't use .items() below cuz we'd run into a concurrent modification error
for key in list(optim_state.keys()):
state = optim_state[key]
if isinstance(state, torch.Tensor):
if cpu_offload:
for key in list(optim_state.keys()):
state = optim_state[key]
if not isinstance(state, torch.Tensor):
continue
optim_state[key] = state.cpu()
return unflat_param_state
else:
@@ -1388,6 +1391,7 @@ def _unflatten_orig_param_states(
state_name: str,
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> None:
"""
Given a output state dict, ``output_states``, which the keys are FQNs to the
@@ -1426,8 +1430,13 @@
value = _ext_chunk_dtensor(
value, fsdp_state.rank, fsdp_state._device_mesh
)
with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
value = value.cpu()
elif not cpu_offload:
with SimpleProfiler.profile("clone"):
value = value.detach().clone()

if cpu_offload:
with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
value = value.cpu()
gathered_state[state_name] = value

logger.warning(
@@ -1441,6 +1450,7 @@ def _allgather_orig_param_states(
input_states: Dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Dict[str, Any]]:
"""
Given the ``gathered_state_info`` and ``input_states``, the API allgather
@@ -1548,9 +1558,10 @@
state_name,
shard_state,
to_save,
cpu_offload,
)
del gathered_tensor

del gathered_tensor
return output_states


@@ -1559,6 +1570,7 @@ def _gather_all_orig_param_state(
input_states: Dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Any]:
"""
Given a optimizer state dict, ``input_states``, which the keys are FQNs to the
@@ -1577,7 +1589,12 @@
with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ):
gathered_state_info = _allgather_state_info(fsdp_state, input_states)
output_states = _allgather_orig_param_states(
fsdp_param_info, gathered_state_info, input_states, shard_state, to_save
fsdp_param_info,
gathered_state_info,
input_states,
shard_state,
to_save,
cpu_offload,
)
if to_save:
assert set(output_states.keys()) == set(fsdp_param_info.param_indices.keys())
@@ -1593,6 +1610,7 @@ def _convert_state_with_orig_params(
optim_state_dict: Dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool = True,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
all_states: Dict[int, Dict[str, Any]] = {}
@@ -1622,10 +1640,12 @@
fsdp_osd_state[unflat_param_name] = copy.copy(
optim_state_dict[param_key]
)
for state_name, value in sorted_items(
fsdp_osd_state[unflat_param_name]
):
if torch.is_tensor(value):
if cpu_offload:
for state_name, value in sorted_items(
fsdp_osd_state[unflat_param_name]
):
if not torch.is_tensor(value):
continue
fsdp_osd_state[unflat_param_name][state_name] = value.cpu()

# Instead of gathering the state of each parameter individually, we perform
@@ -1640,6 +1660,7 @@
_all_states,
shard_state,
to_save,
cpu_offload,
)
)

@@ -1653,6 +1674,7 @@ def _convert_state_with_flat_params(
optim_state_dict: Dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
@@ -1678,6 +1700,7 @@
optim_state_dict[param_key],
to_save,
shard_state,
cpu_offload,
)
if to_save:
assert len(unflat_state) == len(optim_state_key.unflat_param_names)
@@ -1690,8 +1713,12 @@
assert len(optim_state_key.unflat_param_names) == 1
unflat_param_name = optim_state_key.unflat_param_names[0]
fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
for state_name, value in sorted_items(fsdp_osd_state[unflat_param_name]):
if torch.is_tensor(value):
if cpu_offload:
for state_name, value in sorted_items(
fsdp_osd_state[unflat_param_name]
):
if not torch.is_tensor(value):
continue
fsdp_osd_state[unflat_param_name][state_name] = value.cpu()

return fsdp_osd_state
@@ -1713,6 +1740,7 @@ def _optim_state_dict(
group: Optional[dist.ProcessGroup],
using_optim_input: bool,
use_orig_params: bool = False,
cpu_offload: bool = True,
) -> Dict[str, Any]:
"""
Consolidates the optimizer state and returns it as a :class:`dict`
@@ -1811,6 +1839,7 @@
optim_state_dict["state"],
to_save,
shard_state,
cpu_offload,
)

# At this point, communication is complete and ranks can return early if nothing
5 changes: 5 additions & 0 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1219,6 +1219,7 @@ def _optim_state_dict_impl(
rank0_only: bool = True,
full_state_dict: bool = True,
group: Optional[dist.ProcessGroup] = None,
cpu_offload: bool = True,
) -> Dict[str, Any]:
"""
The internal API that is used by all the optim_state_dict implementations.
@@ -1253,6 +1254,7 @@
group=group,
using_optim_input=using_optim_input,
use_orig_params=use_orig_params,
cpu_offload=cpu_offload,
)

@staticmethod
@@ -1826,6 +1828,9 @@ def optim_state_dict(
full_state_dict=state_dict_settings.state_dict_type
== StateDictType.FULL_STATE_DICT,
group=group,
cpu_offload=getattr(
state_dict_settings.optim_state_dict_config, "offload_to_cpu", True
),
)

@staticmethod
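
For completeness, a companion sketch of the full-state-dict path exercised by the new test: with offload_to_cpu=True and rank0_only=True, rank 0 receives the full optimizer state as CPU tensors while all other ranks receive an empty dict. `model` and `optim` are assumed to be set up as in the sharded sketch above; the checkpoint filename is a placeholder.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(),
    FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    osd = FSDP.optim_state_dict(model, optim)

if dist.get_rank() == 0:
    # Rank 0 holds the full, CPU-offloaded state; other ranks get {}.
    torch.save(osd, "optim_osd.pt")  # placeholder path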
