diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
index 9093419a3075c..534c525d0317e 100644
--- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
+++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
@@ -278,8 +278,10 @@ def load_state_dict(self, state_dict):

     @with_temp_dir
     def test_no_dist(self):
-        DCP.save({}, checkpoint_id=self.temp_dir, no_dist=True)
-        DCP.load({}, checkpoint_id=self.temp_dir, no_dist=True)
+        # since comms are not initialized in this method, `no_dist`
+        # is assumed True
+        DCP.save({}, checkpoint_id=self.temp_dir)
+        DCP.load({}, checkpoint_id=self.temp_dir)


 class TestNoCPU(DTensorTestBase):
diff --git a/test/distributed/checkpoint/test_compatibility.py b/test/distributed/checkpoint/test_compatibility.py
index c71c97e49cfef..f5950e5476ec9 100644
--- a/test/distributed/checkpoint/test_compatibility.py
+++ b/test/distributed/checkpoint/test_compatibility.py
@@ -50,13 +50,11 @@ def test_sharded_tensor_dependency(self) -> None:

         dcp.save(
             {"a": torch.zeros(4, 4)},
             dcp.FileSystemWriter("/tmp/dcp_testing"),
-            no_dist=True,
         )

         dcp.load(
             {"a": torch.zeros(4, 4)},
             dcp.FileSystemReader("/tmp/dcp_testing"),
-            no_dist=True,
         )

diff --git a/torch/distributed/checkpoint/format_utils.py b/torch/distributed/checkpoint/format_utils.py
index e9e1f541a8c2a..40d6c9b4332bc 100644
--- a/torch/distributed/checkpoint/format_utils.py
+++ b/torch/distributed/checkpoint/format_utils.py
@@ -19,6 +19,7 @@
 from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
 from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
 from torch.distributed.checkpoint.storage import StorageReader
 from torch.futures import Future

@@ -265,4 +266,6 @@ def torch_save_to_dcp(
     """
     state_dict = torch.load(torch_save_path)
-    dcp.save(state_dict, checkpoint_id=dcp_checkpoint_dir, no_dist=True)
+    # we don't need stateful behavior here because the expectation is that anything
+    # loaded by torch.load would not contain stateful objects.
+    _save_state_dict(state_dict, storage_writer=dcp.FileSystemWriter(dcp_checkpoint_dir), no_dist=True)

diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index a89ae4b7dae3c..1b103e8895403 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -49,8 +49,6 @@ def load(
     storage_reader: Optional[StorageReader] = None,
     planner: Optional[LoadPlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
-    coordinator_rank: int = 0,
-    no_dist: Optional[bool] = None,
 ) -> None:
     """
     Load a distributed ``state_dict`` in SPMD style.
@@ -75,9 +73,13 @@ def load(
         pos-processing and non-tensor data properly propagates.

     .. note:
-        This function can be used for local inference and load a checkpoint
-        produced by ``save_state_dict`` without having a process group initialized
-        by passing ``no_dist=True`` and by using Tensors instead of ShardedTensors.
+        If no process group is initialized, this function assumes the intent
+        is to load a checkpoint into the local process. This can be useful in the
+        case of local inference, and when using regular Tensors (as opposed to DTensor
+        or ShardedTensor).
+
+    .. note::
+        Rank 0 is assumed to be the coordinator rank.

     Args:
         state_dict (Dict[str, Any]): The state_dict to save.
@@ -97,10 +99,6 @@ def load(
         process_group (Optional[ProcessGroup]):
             ProcessGroup to be used for cross-rank synchronization.
             (Default: ``None``)
-        coordinator_rank (int): Rank to use to coordinate the checkpoint.
-            rank0 is used by default. (Default: ``0``)
-        no_dist (Optional[bool]): If ``True``, distributed checkpoint will not save
-            in SPMD style. (Default: ``False`` when torch.distributed is available and initialized)

     Returns:
         None.
@@ -131,13 +129,11 @@ def load(
         rank has an individual GPU, via ``torch.cuda.set_device()``.
     """

-    if no_dist is None:
-        no_dist = not (dist.is_available() and dist.is_initialized())
-        if no_dist:
-            warnings.warn(
-                "Loading with `no_dist` set to True because torch.distributed"
-                " is unavailable or uninitialized."
-            )
+    no_dist = not (dist.is_available() and dist.is_initialized())
+    if no_dist:
+        warnings.warn(
+            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
+        )

     with _profile():
         storage_reader = cast(
@@ -164,12 +160,11 @@ def load(
         )

         _load_state_dict(
-            statetful_sd,
-            storage_reader,
-            process_group,
-            coordinator_rank,
-            no_dist,
-            planner,
+            state_dict=statetful_sd,
+            storage_reader=storage_reader,
+            process_group=process_group,
+            no_dist=no_dist,
+            planner=planner,
         )
         for key in keys:
             if key not in state_dict:
diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py
index 59f0db33ccd3b..0a4592c178da7 100644
--- a/torch/distributed/checkpoint/state_dict_saver.py
+++ b/torch/distributed/checkpoint/state_dict_saver.py
@@ -56,8 +56,6 @@ def save(
     storage_writer: Optional[StorageWriter] = None,
     planner: Optional[SavePlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
-    coordinator_rank: int = 0,
-    no_dist: Optional[bool] = None,
 ) -> Metadata:
     """
     Save a distributed model in SPMD style.
@@ -82,9 +80,11 @@ def save(
         group needs to be passed in.

     .. note::
-        This function can be used to save a state_dict without having a process group
-        initialized by passing ``no_dist=True``.
-        (Default: ``False`` when torch.distributed is available and initialized)
+        If no process group is available, this function assumes the intent is to save the
+        state_dict in the local process.
+
+    .. note::
+        Rank 0 is assumed to be the coordinator rank.


     Args:
@@ -105,10 +105,6 @@ def save(
         process_group (Optional[ProcessGroup]):
             ProcessGroup to be used for cross-rank synchronization.
             (Default: ``None``)
-        coordinator_rank (int): Rank to use to coordinate the checkpoint.
-            rank0 is used by default. (Default: ``0``)
-        no_dist (bool): If ``True``, distributed checkpoint will not save
-            in SPMD style. (Default: ``False``)

     Returns:
         Metadata: Metadata object for the saved checkpoint.
@@ -135,25 +131,23 @@ def save(
     """
     torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

-    if no_dist is None:
-        no_dist = not (dist.is_available() and dist.is_initialized())
-        if no_dist:
-            warnings.warn(
-                "Saving with `no_dist` set to True because torch.distributed"
-                " is unavailable or uninitialized."
-            )
+    no_dist = not (dist.is_available() and dist.is_initialized())
+    if no_dist:
+        warnings.warn(
+            "torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
+        )

     with _profile():
         storage_writer = cast(
             StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
         )
+
         return _save_state_dict(
-            _stateful_to_state_dict(state_dict),
-            storage_writer,
-            process_group,
-            coordinator_rank,
-            no_dist,
-            planner,
+            state_dict=_stateful_to_state_dict(state_dict),
+            storage_writer=storage_writer,
+            process_group=process_group,
+            no_dist=no_dist,
+            planner=planner,
         )


@@ -164,8 +158,6 @@ def _async_save(
     storage_writer: Optional[StorageWriter] = None,
     planner: Optional[SavePlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
-    coordinator_rank: int = 0,
-    no_dist: bool = False,
 ) -> Future:
     """Asynchronous version of ``save_state_dict``. This code first de-stages the
     state_dict on CPU, and then calls `save` in a separate thread.
@@ -191,10 +183,6 @@ def _async_save(
         process_group (Optional[ProcessGroup]):
             ProcessGroup to be used for cross-rank synchronization.
             (Default: ``None``)
-        coordinator_rank (int): Rank to use to coordinate the checkpoint.
-            rank0 is used by default. (Default: ``0``)
-        no_dist (bool): If ``True``, distributed checkpoint will not save
-            in SPMD style. (Default: ``False``)

     Returns:
         Future: A future holding the resultant Metadata object from `save`.
@@ -216,9 +204,7 @@ def _async_save(
             checkpoint_id=checkpoint_id,
             storage_writer=storage_writer,
             planner=planner,
-            process_group=process_group,
-            coordinator_rank=coordinator_rank,
-            no_dist=no_dist,
+            process_group=process_group,
         )

     f.add_done_callback(lambda f: executor.shutdown(wait=False))
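
Post-change usage, for reference (a minimal sketch, not taken from this patch; the checkpoint path and tensor names are illustrative): callers simply drop `no_dist` and `coordinator_rank`. When torch.distributed is unavailable or uninitialized, `save`/`load` emit the new warning and fall back to single-process behavior; otherwise rank 0 acts as the coordinator.

import torch
import torch.distributed.checkpoint as dcp

# No init_process_group() call here, so DCP infers single-process (non-SPMD)
# behavior and warns, replacing the old explicit `no_dist=True` argument.
state_dict = {"weight": torch.randn(4, 4)}
dcp.save(state_dict, checkpoint_id="/tmp/dcp_demo")  # illustrative path

# load() restores data in place into a pre-allocated state_dict with matching keys/shapes.
loaded = {"weight": torch.zeros(4, 4)}
dcp.load(loaded, checkpoint_id="/tmp/dcp_demo")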