[DCP] Support partial load (#122829)
Adds the ability to load a subset of keys directly from a checkpoint, avoiding the need to initialize the state dict first

Differential Revision: [D55441391](https://our.internmc.facebook.com/intern/diff/D55441391/)

Pull Request resolved: #122829
Approved by: https://github.com/fegin
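
As orientation, here is a minimal sketch of the two partial-load paths this enables; the checkpoint directory, module, and key names below are hypothetical, and the exact key-matching behavior is defined by the `_EmptyStateDictLoadPlanner` added in this PR:

```python
import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys

# Assumed to hold a checkpoint saved as
# DCP.save({"model": ..., "optimizer": ...}, checkpoint_id=CHECKPOINT_DIR)
CHECKPOINT_DIR = "/tmp/dcp_checkpoint"

# Path 1: in-place partial load -- pass a state_dict containing only the entries you want.
# The module is assumed to match the architecture saved under "model".
model = torch.nn.Linear(8, 8)
DCP.load({"model": model}, checkpoint_id=CHECKPOINT_DIR)

# Path 2 (new here): materialize selected keys into a fresh dict, no model required.
model_sd = _load_state_dict_from_keys(keys={"model"}, checkpoint_id=CHECKPOINT_DIR)
```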
LucasLLC authored and pytorchmergebot committed Apr 2, 2024
1 parent feabb64 commit bcb6e5a
Showing 4 changed files with 143 additions and 43 deletions.
32 changes: 29 additions & 3 deletions test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
@@ -17,8 +17,10 @@
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
_patch_optimizer_state_dict,
get_model_state_dict,
get_state_dict,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
@@ -242,9 +244,6 @@ def _run_e2e_test(self, compile, model_type, async_op=False):
loaded_train_state = TestTrainState()
dist_model, dist_optim = self._create_model(compile, model_type)

loaded_stateful_obj = TestStatefulObj()
dist_model, dist_optim = self._create_model(compile, model_type)

DCP.load(
state_dict={
"model": dist_model,
@@ -328,6 +327,33 @@ def test_no_dist(self):
DCP.save({}, checkpoint_id=self.temp_dir)
DCP.load({}, checkpoint_id=self.temp_dir)

@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_partial_load(self):
model, optim = self._create_model(compile=False, model_type=ModelType.NONE)
_train(model, optim, train_steps=2)

dist_model, dist_optim = self._create_model(
compile=False, model_type=ModelType.FSDP
)
_train(dist_model, dist_optim, train_steps=2)

DCP.save(
{"model": dist_model, "optimizer": dist_optim}, checkpoint_id=self.temp_dir
)

dist_model, _ = self._create_model(compile=False, model_type=ModelType.FSDP)
DCP.load({"model": dist_model}, checkpoint_id=self.temp_dir)

dist_msd = get_model_state_dict(dist_model)
model_sd = get_model_state_dict(model)
self._verify_msd(model_sd, dist_msd)

# another way: read the "model" entry directly from the checkpoint, without a module
loaded_model_sd = _load_state_dict_from_keys("model", checkpoint_id=self.temp_dir)
self._verify_msd(model_sd, loaded_model_sd)


class TestNoCPU(DTensorTestBase):
@property
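Side note on the test above: to discover which keys a checkpoint actually contains before issuing a partial load, its metadata can be read directly with the existing reader API. A small sketch with a hypothetical path:

```python
from torch.distributed.checkpoint import FileSystemReader

# read_metadata() returns the same Metadata object the planner below iterates
# over when rebuilding the state dict.
reader = FileSystemReader("/tmp/dcp_checkpoint")  # hypothetical checkpoint location
metadata = reader.read_metadata()
print(sorted(metadata.state_dict_metadata.keys()))  # flattened keys stored in the checkpoint
```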
41 changes: 41 additions & 0 deletions torch/distributed/checkpoint/default_planner.py
@@ -216,6 +216,47 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
Useful for loading a state_dict without first initializing a model, such as
when converting a DCP checkpoint into a Torch save file.
.. note:: `state_dict` must be an empty dictionary when used with this LoadPlanner
.. warning::
Because the entire state dict is initialized, it's recommended to only use
this LoadPlanner on a single rank or process to avoid OOM.
"""

def __init__(self, keys=None, *args, **kwargs):
self.keys = keys
super().__init__(*args, **kwargs)

def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Metadata,
is_coordinator: bool,
) -> None:
assert not state_dict

# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if self.keys and k not in self.keys:
continue

if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v

super().set_up_planner(state_dict, metadata, is_coordinator)


def create_default_local_load_plan(
state_dict: Dict[str, Any],
metadata: Metadata,
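To make the rebuild loop in `_EmptyStateDictLoadPlanner.set_up_planner` more concrete, here is a toy sketch of how `set_element` turns a `planner_data` path back into a nested dictionary; the path and tensor are made up for illustration:

```python
import torch
from torch.distributed.checkpoint._traverse import set_element

state_dict = {}
# planner_data maps a flattened key such as "model.layer.weight" to a nested
# path such as ("model", "layer", "weight"); set_element creates the
# intermediate containers and places the value at that path.
set_element(state_dict, ("model", "layer", "weight"), torch.empty(4, 4))
assert state_dict["model"]["layer"]["weight"].shape == (4, 4)
```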
43 changes: 4 additions & 39 deletions torch/distributed/checkpoint/format_utils.py
@@ -8,8 +8,10 @@
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.default_planner import (
_EmptyStateDictLoadPlanner,
DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
Metadata,
STATE_DICT_TYPE,
@@ -33,43 +35,6 @@
]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
Useful for loading in state_dict without first initializing a model, such as
when converting a DCP checkpoint into a Torch save file.
. N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner
.. warning::
Because the entire state dict is initialized, It's recommended to only utilize
this LoadPlanner on a single rank or process to avoid OOM.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Metadata,
is_coordinator: bool,
) -> None:
assert not state_dict

# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v

super().set_up_planner(state_dict, metadata, is_coordinator)


class BroadcastingTorchSaveReader(StorageReader):
"""
StorageReader for reading a Torch Save file. This reader will read the entire checkpoint
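The class removed from format_utils here now lives in default_planner, but its main consumer is unchanged: rebuilding a full state dict when converting a DCP checkpoint into a regular Torch save file. As a reminder (not part of this diff), that conversion can be driven by the `dcp_to_torch_save` helper in this module; the paths below are hypothetical:

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Rebuilds the full state dict from DCP metadata via _EmptyStateDictLoadPlanner
# and writes it out with torch.save().
dcp_to_torch_save("/tmp/dcp_checkpoint", "/tmp/checkpoint.pth")  # hypothetical paths
```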
70 changes: 69 additions & 1 deletion torch/distributed/checkpoint/state_dict_loader.py
@@ -4,6 +4,7 @@

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.stateful import Stateful

from ._storage_utils import _storage_setup
@@ -73,7 +74,7 @@ def load(
post-processing and non-tensor data properly propagates.
.. note:
If no process group is initialized, this function can assumesbe the intent
If no process group is initialized, this function will assume the intent
is to load a checkpoint into the local process. This can be useful in the
case of local inference, and when using regular Tensors (as opposed to DTensor
or ShardedTensor)
@@ -215,3 +216,70 @@ def read_data():
return None

_ = distW.all_gather("read", read_data)


def _load_state_dict_from_keys(
keys: Optional[set] = None,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Load only the specified keys from the checkpoint. If no keys are specified, the entire
checkpoint is loaded. Note that this method completely loads the checkpoint into the
current process and is not distributed.
.. warning::
All non-tensor data is loaded using `torch.load()`
.. note:
As opposed to the usual pattern, this function does not take a state dict as input
and does not load inplace. Instead, a new state dict is directly initialized and read
from file.
.. note:
If no process group is initialized, this function will assume the intent
is to load a checkpoint into the local process. This can be useful in the
case of local inference, and when using regular Tensors (as opposed to DTensor
or ShardedTensor)
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
keys (Optional[set]):
The set of keys to load from the checkpoint. If no keys are specified,
the entire checkpoint is loaded. (Default: ``None``)
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
storage_reader (Optional[StorageReader]):
Instance of StorageReader used to perform reads. If this is not
specified, DCP will automatically infer the reader based on the
checkpoint_id. If checkpoint_id is also None, an exception will
be raised. (Default: ``None``)
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
Returns:
State dict from specified keys
"""
torch._C._log_api_usage_once(
"torch.distributed.checkpoint._load_state_dict_from_keys"
)

sd: Dict[str, Any] = {}
load(
sd,
storage_reader=storage_reader,
planner=_EmptyStateDictLoadPlanner(keys=keys or set()),
process_group=process_group,
)

return sd
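
As a usage note, `checkpoint_id` and `storage_reader` are alternatives; a short sketch passing an explicit reader instead of a checkpoint_id (the path and keys are hypothetical):

```python
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys

# Equivalent to passing checkpoint_id: hand DCP an explicit reader instead.
reader = FileSystemReader("/tmp/dcp_checkpoint")  # hypothetical checkpoint location
optim_sd = _load_state_dict_from_keys(keys={"optimizer"}, storage_reader=reader)
```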
