[state_dict] Add cpu_only and ranks_only support for _gather_state_dict #112836

Closed. Wants to merge 10 commits.
25 changes: 24 additions & 1 deletion test/distributed/checkpoint/test_state_dict_utils.py
@@ -18,7 +18,7 @@
 class TestStateDictUtils(DTensorTestBase):
     @property
     def world_size(self):
-        return 2
+        return min(4, torch.cuda.device_count())
 
     @with_comms
     @skip_if_lt_x_gpu(2)
@@ -35,6 +35,29 @@ def test_gather_state_dict_dtensor(self):
             dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
         )
         self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
+        self.assertEqual(gathered_state_dict["dtensor"].is_cuda, True)
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_cpu_and_ranks_only(self):
+        device_mesh = self.build_device_mesh()
+        shard_spec = [Shard(0)]
+        torch.random.manual_seed(dist.get_rank())
+        local_tensor = torch.randn(3, 3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        state_dict = {"dtensor": dist_tensor}
+
+        gathered_state_dict = _gather_state_dict(
+            state_dict, cpu_offload=True, ranks_only=(0, 2)
+        )
+        expected_gathered_dtensor = funcol.all_gather_tensor(
+            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
+        )
+        if dist.get_rank() in (0, 2):
+            self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
+            self.assertEqual(gathered_state_dict["dtensor"].is_cuda, False)
+        else:
+            self.assertEqual(gathered_state_dict, {})
 
 
 if __name__ == "__main__":
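For readers skimming the diff, a hedged single-process sketch of what the new test expects, assuming the full 4-rank configuration; torch.cat only mimics the all-gather and is not how _gather_state_dict is implemented:

# Single-process sketch of the shapes and ranks_only filtering exercised by
# test_cpu_and_ranks_only, assuming 4 ranks each holding a (3, 3, 3) shard.
import torch

world_size = 4
ranks_only = (0, 2)

# Gathering the per-rank (3, 3, 3) shards along dim 0 yields (12, 3, 3).
local_shards = [torch.randn(3, 3, 3) for _ in range(world_size)]
gathered = torch.cat(local_shards, dim=0)  # stand-in for the real all-gather
assert gathered.shape == (12, 3, 3)

# Only ranks listed in ranks_only keep the gathered state_dict; the rest get {}.
per_rank_result = {
    rank: ({"dtensor": gathered} if rank in ranks_only else {})
    for rank in range(world_size)
}
assert per_rank_result[0]["dtensor"].shape == (12, 3, 3)
assert per_rank_result[1] == {} and per_rank_result[3] == {}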
35 changes: 30 additions & 5 deletions torch/distributed/checkpoint/_state_dict_utils.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -49,13 +49,34 @@ def _all_gather_sharded_tensor(
 
 def _gather_state_dict(
     state_dict: Dict[str, Any],
+    *,
     pg: Optional[dist.ProcessGroup] = None,
     device: Optional[torch.device] = None,
+    cpu_offload: bool = False,
+    ranks_only: Tuple[int, ...] = tuple(),
 ) -> Dict[str, Any]:
     """
-    Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
+    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
+    the state_dict.
 
+
+    Args:
+        state_dict (Dict[str, Any]): the target sharded state_dict.
+        pg (Optional[dist.ProcessGroup]): the process group that is used to
+            gather ShardedTensor.
Review comment (Contributor): nit: gather DTensor or ShardedTensor.

+        device: (Optional[torch.device]): the device that is used to
+            perform allgather for ShardedTensor.

Review comment (Contributor): nit: gather DTensor or ShardedTensor.

Reply (PR author): ShardedTensor only. DTensor uses DeviceMesh to gather the tensors.

+        cpu_offload (bool): whether to offload the tensors to CPU memory. The
+            default value is False.
+        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
+            have the same state_dicts. Otherwise, only ranks that are in
+            ``ranks_only`` have the same state_dicts; other ranks will get
+            empty state_dicts.
 
     Returns:
         The gathered state dictionary.
     """
     new_state_dict = {}
+    cpu_device = torch.device("cpu")
     for key, value in state_dict.items():
         if isinstance(value, ShardedTensor):
             # ShardedTensor does not seem to record the original device type.
@@ -65,7 +86,7 @@ def _gather_state_dict(
             local_shard_device = (
                 value.local_shards()[0].tensor.device
                 if value.local_shards()
-                else torch.device("cpu")
+                else cpu_device
             )
             if output_tensor.device != local_shard_device:
                 value = output_tensor.to(local_shard_device)
@@ -86,7 +107,11 @@
             )
             value = value.to_local()
         elif isinstance(value, dict):
-            value = _gather_state_dict(value, pg, device)
+            value = _gather_state_dict(value, pg=pg, device=device)
+
+        if isinstance(value, torch.Tensor) and cpu_offload:
+            value = value.to(cpu_device)
 
-        new_state_dict[key] = value
+        if not cpu_offload or len(ranks_only) == 0 or dist.get_rank(pg) in ranks_only:

Review comment (wz337, Contributor, Nov 9, 2023): Maybe the condition is a bit complicated and I can't think thru lol, but is dist.get_rank(pg) in ranks_only alone not sufficient?

Reply (PR author): ranks_only can be empty (the default value). But it seems that we should not restrict this function to be cpu_offload only.

Reply (wz337): I see. Makes sense!

+            new_state_dict[key] = value
     return new_state_dict
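A hedged usage sketch of the extended API, based on the docstring and the test above rather than taken from the PR itself. It assumes a host with at least two GPUs, a torchrun launch, and the module path shown in this diff (torch/distributed/checkpoint/_state_dict_utils.py), which may move in later releases.

# example_gather.py -- run with: torchrun --nproc_per_node=2 example_gather.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Build a DTensor sharded along dim 0 across all ranks.
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
local = torch.randn(4, 8, device="cuda")
state_dict = {
    "weight": DTensor.from_local(local, mesh, [Shard(0)]),
    "step": torch.tensor(10, device="cuda"),  # plain tensors are handled too
}

# New keyword-only arguments from this PR:
#   cpu_offload=True moves every gathered tensor to CPU memory;
#   ranks_only=(0,) keeps the gathered state_dict on rank 0 only,
#   while all other ranks receive an empty dict.
gathered = _gather_state_dict(state_dict, cpu_offload=True, ranks_only=(0,))

if rank == 0:
    print(gathered["weight"].shape, gathered["weight"].device)  # (8, 8) on cpu with 2 ranks
else:
    print(gathered)  # {}

dist.destroy_process_group()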