[DCP] Removes `no_dist` and `coordinator_rank` from public DCP APIs

Differential Revision: [D54591181](https://our.internmc.facebook.com/intern/diff/D54591181/)

ghstack-source-id: 217653051
Pull Request resolved: #121317
LucasLLC committed Mar 6, 2024
1 parent 9deaa2e commit 31ff6c6
Showing 5 changed files with 42 additions and 58 deletions.
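
For context, a minimal before/after sketch of the caller-side change, assuming a single-process setting with no process group initialized (the path and tensors below are illustrative):

import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.zeros(4, 4)}

# Before this change, callers opted out of SPMD behavior explicitly:
#   dcp.save(state_dict, checkpoint_id="/tmp/ckpt", no_dist=True)
#   dcp.load(state_dict, checkpoint_id="/tmp/ckpt", no_dist=True)

# After this change, single-process behavior is inferred whenever
# torch.distributed is unavailable or uninitialized:
dcp.save(state_dict, checkpoint_id="/tmp/ckpt")
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")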
6 changes: 4 additions & 2 deletions test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
@@ -278,8 +278,10 @@ def load_state_dict(self, state_dict):

@with_temp_dir
def test_no_dist(self):
DCP.save({}, checkpoint_id=self.temp_dir, no_dist=True)
DCP.load({}, checkpoint_id=self.temp_dir, no_dist=True)
# since comms are not initialized in this method, `no_dist`
# is assumed True
DCP.save({}, checkpoint_id=self.temp_dir)
DCP.load({}, checkpoint_id=self.temp_dir)


class TestNoCPU(DTensorTestBase):
2 changes: 0 additions & 2 deletions test/distributed/checkpoint/test_compatibility.py
@@ -50,13 +50,11 @@ def test_sharded_tensor_dependency(self) -> None:
dcp.save(
{"a": torch.zeros(4, 4)},
dcp.FileSystemWriter("/tmp/dcp_testing"),
no_dist=True,
)

dcp.load(
{"a": torch.zeros(4, 4)},
dcp.FileSystemReader("/tmp/dcp_testing"),
no_dist=True,
)


5 changes: 4 additions & 1 deletion torch/distributed/checkpoint/format_utils.py
@@ -19,6 +19,7 @@
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future

@@ -265,4 +266,6 @@ def torch_save_to_dcp(
"""

state_dict = torch.load(torch_save_path)
dcp.save(state_dict, checkpoint_id=dcp_checkpoint_dir, no_dist=True)
# we don't need stateful behavior here because anything loaded by
# torch.load is not expected to contain stateful objects.
_save_state_dict(state_dict, checkpoint_id=dcp_checkpoint_dir, no_dist=True)
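
A usage sketch for the helper above, assuming paths of your choosing (both are illustrative):

import torch
from torch.distributed.checkpoint.format_utils import torch_save_to_dcp

# A regular torch.save checkpoint, e.g. one produced from model.state_dict().
torch.save({"weight": torch.zeros(4, 4)}, "/tmp/model.pt")

# Convert it into a DCP-format checkpoint directory, entirely in the local process.
torch_save_to_dcp("/tmp/model.pt", "/tmp/model_dcp")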
39 changes: 17 additions & 22 deletions torch/distributed/checkpoint/state_dict_loader.py
@@ -49,8 +49,6 @@ def load(
storage_reader: Optional[StorageReader] = None,
planner: Optional[LoadPlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: Optional[bool] = None,
) -> None:
"""
Load a distributed ``state_dict`` in SPMD style.
@@ -75,9 +73,13 @@ def load(
post-processing and non-tensor data properly propagates.
.. note:
This function can be used for local inference and load a checkpoint
produced by ``save_state_dict`` without having a process group initialized
by passing ``no_dist=True`` and by using Tensors instead of ShardedTensors.
If no process group is initialized, this function assumes the intent
is to load a checkpoint into the local process. This can be useful in the
case of local inference, and when using regular Tensors (as opposed to DTensor
or ShardedTensor).
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
@@ -97,10 +99,6 @@ def load(
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
coordinator_rank (int): Rank to use to coordinate the checkpoint.
rank0 is used by default. (Default: ``0``)
no_dist (Optional[bool]): If ``True``, distributed checkpoint will not save
in SPMD style. (Default: ``False`` when torch.distributed is available and initialized)
Returns:
None.
@@ -131,13 +129,11 @@ def load(
rank has an individual GPU, via ``torch.cuda.set_device()``.
"""

if no_dist is None:
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"Loading with `no_dist` set to True because torch.distributed"
" is unavailable or uninitialized."
)
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
)

with _profile():
storage_reader = cast(
@@ -164,12 +160,11 @@
)

_load_state_dict(
statetful_sd,
storage_reader,
process_group,
coordinator_rank,
no_dist,
planner,
state_dict=statetful_sd,
storage_reader=storage_reader,
process_group=process_group,
no_dist=no_dist,
planner=planner,
)
for key in keys:
if key not in state_dict:
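
To illustrate the local-inference note above, a minimal sketch assuming a checkpoint already exists at an illustrative path and matches the model's keys:

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(4, 4)
state_dict = model.state_dict()

# No process group is initialized, so the load happens in the local process.
# `load` operates in place on the passed-in state_dict.
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")
model.load_state_dict(state_dict)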
48 changes: 17 additions & 31 deletions torch/distributed/checkpoint/state_dict_saver.py
@@ -56,8 +56,6 @@ def save(
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: Optional[bool] = None,
) -> Metadata:
"""
Save a distributed model in SPMD style.
@@ -82,9 +80,11 @@
group needs to be passed in.
.. note::
This function can be used to save a state_dict without having a process group
initialized by passing ``no_dist=True``.
(Default: ``False`` when torch.distributed is available and initialized)
If no process group is available, this function assumes the intention is to save the
state_dict in the local process.
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
@@ -105,10 +105,6 @@
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
coordinator_rank (int): Rank to use to coordinate the checkpoint.
rank0 is used by default. (Default: ``0``)
no_dist (bool): If ``True``, distributed checkpoint will not save
in SPMD style. (Default: ``False``)
Returns:
Metadata: Metadata object for the saved checkpoint.
@@ -135,25 +131,23 @@
"""
torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

if no_dist is None:
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"Saving with `no_dist` set to True because torch.distributed"
" is unavailable or uninitialized."
)
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
)

with _profile():
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
)

return _save_state_dict(
_stateful_to_state_dict(state_dict),
storage_writer,
process_group,
coordinator_rank,
no_dist,
planner,
state_dict=_stateful_to_state_dict(state_dict),
storage_writer=storage_writer,
process_group=process_group,
no_dist=no_dist,
planner=planner,
)


@@ -164,8 +158,6 @@ def _async_save(
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
) -> Future:
"""Asynchronous version of ``save_state_dict``. This code first de-stages the state_dict on CPU, and then calls
`save` in a separate thread.
@@ -191,10 +183,6 @@ def _async_save(
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
coordinator_rank (int): Rank to use to coordinate the checkpoint.
rank0 is used by default. (Default: ``0``)
no_dist (bool): If ``True``, distributed checkpoint will not save
in SPMD style. (Default: ``False``)
Returns:
Future: A future holding the resultant Metadata object from `save`.
@@ -216,9 +204,7 @@ def _async_save(
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
process_group=process_group,
coordinator_rank=coordinator_rank,
no_dist=no_dist,
process_group=process_group,
)
f.add_done_callback(lambda f: executor.shutdown(wait=False))

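
And for the SPMD path described in the notes above, a sketch assuming torch.distributed is initialized by a launcher (backend, environment variables, and path are illustrative); every rank calls `save`, with rank 0 acting as the coordinator:

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

dist.init_process_group("gloo")  # assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set

state_dict = {"weight": torch.zeros(4, 4)}  # in practice, a sharded or DTensor state_dict
dcp.save(state_dict, checkpoint_id="/tmp/ckpt_spmd")  # collective: call on every rank

dist.destroy_process_group()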
