[state_dict][11/N] Implement cpu_offload and full_state_dict for get_state_dict (#112837)

As title

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
Pull Request resolved: #112837
Approved by: https://github.com/LucasLLC, https://github.com/wz337
ghstack dependencies: #112836, #112885
fegin authored and pytorchmergebot committed Nov 13, 2023
1 parent b910d9e commit 2bcff4d
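
For context, a minimal usage sketch of the two options this commit adds (a hedged example, assuming an already-initialized process group, an FSDP-wrapped model, and a torch.optim optimizer; model and optim are placeholder names):

    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict

    # Sharded state_dicts with every tensor offloaded to CPU on all ranks.
    model_sd, optim_sd = get_state_dict(
        model, optim, options=StateDictOptions(cpu_offload=True)
    )

    # Fully gathered state_dicts; combined with cpu_offload, only rank 0
    # receives the tensors and the remaining ranks get empty dicts.
    model_sd, optim_sd = get_state_dict(
        model, optim, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
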
Showing 3 changed files with 122 additions and 18 deletions.
71 changes: 71 additions & 0 deletions test/distributed/checkpoint/test_state_dict.py
@@ -37,6 +37,7 @@
DTensorTestBase,
with_comms,
)
from torch.utils._pytree import tree_all, tree_all_only


if not dist.is_available():
@@ -458,6 +459,76 @@ def test_partial(self) -> None:
self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
self.assertEqual(model.l.bias, model_state_dict1["l.bias"])

@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_full_state_dict(self) -> None:
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
copy_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
device_mesh = init_device_mesh("cuda", (self.world_size,))
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
use_orig_params=True,
device_mesh=device_mesh,
)

dist_optim = torch.optim.Adam(dist_model.parameters(), lr=1e-3)

mst, ost = get_state_dict(
dist_model,
dist_optim,
options=StateDictOptions(cpu_offload=True),
)

cpu_device = torch.device("cpu")

def is_cpu(v):
if isinstance(v, DTensor):
return v.device == cpu_device
elif isinstance(v, ShardedTensor):
shards = v.local_shards()
if not shards:
return True
return shards[0].tensor.device == cpu_device
else:
return v.device == cpu_device

self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
)
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
)

mst, ost = get_state_dict(
dist_model, dist_optim, options=StateDictOptions(full_state_dict=True)
)

self.assertTrue(
tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), mst)
)
self.assertTrue(
tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), ost)
)

mst, ost = get_state_dict(
dist_model,
dist_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)

if self.rank == 0:
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
)
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
)
else:
self.assertEqual(mst, {})
self.assertEqual(ost, {})


if __name__ == "__main__":
run_tests()
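
A note on the pytree helpers used in the test above: tree_all_only(types, pred, tree) applies pred only to leaves of the given types and ignores other leaves, while tree_all applies pred to every leaf. A small standalone illustration (the dict below is made up for the example):

    import torch
    from torch.utils._pytree import tree_all_only

    sd = {"weight": torch.zeros(2), "meta": {"step": 3}}
    # Only the tensor leaf is tested; the integer leaf is skipped.
    assert tree_all_only(torch.Tensor, lambda t: t.device.type == "cpu", sd)
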
67 changes: 50 additions & 17 deletions torch/distributed/checkpoint/state_dict.py
@@ -18,9 +18,14 @@
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._state_dict_utils import (
_gather_state_dict,
_offload_state_dict_to_cpu,
)
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
@@ -77,9 +82,13 @@ class StateDictOptions:
"""
This dataclass specifies how get_state_dict/set_state_dict will work.
-    - ``fsdp_state_dict_type``: if the model contains FSDP sharded submodules,
-      what FSDP state_dict type should be used.
-      The default value is SHARDED_STATE_DICT.
+    - ``full_state_dict``: if this is set to True, all the tensors in the
+      returned state_dict will be gathered. No ShardedTensor or DTensor
+      will be in the returned state_dict.
+    - ``cpu_offload``: offload all the tensors to CPU. To prevent CPU OOM, if
+      ``full_state_dict`` is also true, then only rank0 will get the
+      state_dict and all other ranks will get an empty state_dict.
- ``ignore_frozen_params``: if the value is True, the returned state_dict
won't contain any frozen parameters -- the ``requires_grad`` is False.
@@ -100,7 +109,8 @@ class StateDictOptions:
The default value is False.
"""

-    fsdp_state_dict_type: StateDictType = StateDictType.SHARDED_STATE_DICT
+    full_state_dict: bool = False
+    cpu_offload: bool = False
ignore_frozen_params: bool = False
keep_submodule_prefixes: bool = True
strict: bool = True
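
In other words, when both flags are set only rank 0 ends up holding data. A hedged sketch of what a caller can expect (distributed setup, model, and optim assumed to exist elsewhere):

    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict

    msd, osd = get_state_dict(
        model, optim, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
    if dist.get_rank() == 0:
        # Rank 0 holds plain CPU tensors (no DTensor / ShardedTensor leaves).
        assert all(t.device.type == "cpu" for t in msd.values())
    else:
        assert msd == {} and osd == {}
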
@@ -210,25 +220,25 @@ def _verify_options(
fsdp_context: Callable
if fsdp_modules:
# FSDP API only work if at least one FSDP instance exists.
-        if options.fsdp_state_dict_type == StateDictType.FULL_STATE_DICT:
+        if options.full_state_dict:
            state_dict_config = FullStateDictConfig(
-                offload_to_cpu=True, rank0_only=True
+                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            optim_state_dict_config = FullOptimStateDictConfig(
-                offload_to_cpu=True, rank0_only=True
+                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
-        elif options.fsdp_state_dict_type == StateDictType.SHARDED_STATE_DICT:
-            state_dict_config = ShardedStateDictConfig()
-            optim_state_dict_config = ShardedOptimStateDictConfig()
+            state_dict_type = StateDictType.FULL_STATE_DICT
        else:
-            raise RuntimeError(
-                "state_dict currently support only FSDP "
-                "FULL_STATE_DICT and SHARDED_STATE_DICT"
+            state_dict_config = ShardedStateDictConfig()
+            optim_state_dict_config = ShardedOptimStateDictConfig(
+                offload_to_cpu=options.cpu_offload,
            )
+            state_dict_type = StateDictType.SHARDED_STATE_DICT

fsdp_context = functools.partial(
FSDP.state_dict_type,
module=model,
-            state_dict_type=options.fsdp_state_dict_type,
+            state_dict_type=state_dict_type,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
)
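
For readers less familiar with the FSDP side, the full_state_dict=True, cpu_offload=True branch above is roughly equivalent to entering the following context by hand (a sketch, assuming an FSDP-wrapped module named model):

    from torch.distributed.fsdp import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        FullyShardedDataParallel as FSDP,
        StateDictType,
    )

    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        optim_state_dict_config=FullOptimStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        ),
    ):
        model_sd = model.state_dict()
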
@@ -270,15 +280,19 @@ def _verify_state_dict(
and not model_state_dict
and not info.submodule_prefixes
and not info.ignore_frozen_params
and not (info.cpu_offload and info.full_state_dict)
and info.strict
):
raise RuntimeError(
"The option indicates that model state_dict is required to save "
"or load, but model state_dict is empty."
f"rank = {dist.get_rank()=}."
)

if info.handle_optim:
-        if not (optim_state_dict and optim_state_dict[STATE]):
+        if not (optim_state_dict and optim_state_dict[STATE]) and not (
+            info.cpu_offload and info.full_state_dict
+        ):
raise RuntimeError(
"The option indicates that model state_dict is required to save, "
f"or load but optim state_dict is empty. {optim_state_dict}"
@@ -362,7 +376,15 @@ def verify(key, fqn) -> bool:
if p.is_meta:
state_dict.pop(key)

-    return state_dict
+    if info.full_state_dict:
+        ranks_only = tuple() if not info.cpu_offload else (0,)
+        return _gather_state_dict(
+            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
+        )
+    elif info.cpu_offload:
+        return _offload_state_dict_to_cpu(state_dict)
+    else:
+        return state_dict


def _load_model_state_dict(
@@ -448,10 +470,21 @@ def _get_optim_state_dict(
for group in osd[PG]:
group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]]

if not osd:
continue

cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE])
cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG])

-    return optim_state_dict
+    if info.full_state_dict:
+        ranks_only = tuple() if not info.cpu_offload else (0,)
+        return _gather_state_dict(
+            optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
+        )
+    elif info.cpu_offload:
+        return _offload_state_dict_to_cpu(optim_state_dict)
+    else:
+        return optim_state_dict


def _split_optim_state_dict(
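Conceptually, the two new return paths in this file post-process the collected state_dict: _gather_state_dict materializes ShardedTensor/DTensor leaves into regular tensors (and, with ranks_only=(0,), leaves the dict empty on non-zero ranks), while _offload_state_dict_to_cpu only moves leaves to CPU. A simplified stand-in for the offload path, not the actual helper:

    import torch
    from torch.utils._pytree import tree_map

    def offload_to_cpu_sketch(state_dict):
        # Move every tensor leaf to CPU and leave other values untouched.
        # The real _offload_state_dict_to_cpu also has to handle ShardedTensor
        # and DTensor shards; this sketch ignores those details.
        return tree_map(
            lambda v: v.cpu() if isinstance(v, torch.Tensor) else v, state_dict
        )
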
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -1476,7 +1476,7 @@ def _unflatten_orig_param_states(
)
elif not cpu_offload:
with SimpleProfiler.profile("clone"):
-                value = value.detach.clone()
+                value = value.detach().clone()

if cpu_offload:
with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
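The one-character fix above matters because Tensor.detach is a method: without the call parentheses, value.detach is a bound method object, so the subsequent .clone() raises AttributeError instead of copying the tensor. A standalone illustration:

    import torch

    t = torch.ones(2)
    t.detach().clone()   # correct: detach() returns a tensor, which is then cloned
    # t.detach.clone()   # AttributeError: 'builtin_function_or_method' object
    #                    # has no attribute 'clone'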
