[state_dict] Add cpu_only and ranks_only support for _gather_state_dict #112836
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import math | ||
from typing import Any, Dict, Optional | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
@@ -49,13 +49,34 @@ def _all_gather_sharded_tensor( | |
|
||
def _gather_state_dict( | ||
state_dict: Dict[str, Any], | ||
*, | ||
pg: Optional[dist.ProcessGroup] = None, | ||
device: Optional[torch.device] = None, | ||
cpu_offload: bool = False, | ||
ranks_only: Tuple[int, ...] = tuple(), | ||
) -> Dict[str, Any]: | ||
""" | ||
Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict. | ||
Given a state_dict, this API gathers all the ShardedTensors or DTensors in | ||
the state_dict. | ||
|
||
|
||
Args: | ||
state_dict (Dict[str, Any]): the target sharded state_dict. | ||
pg (Optional[dist.ProcessGroup]): the process group that is used to | ||
gather ShardedTensor. | ||
device (Optional[torch.device]): the device that is used to | ||
perform allgather for ShardedTensor. | ||
Review comment: nit: gather DTensor or ShardedTensor.
Reply: ShardedTensor only. DTensor uses DeviceMesh to gather the tensors. |
||
cpu_offload (bool): whether to offload the tensors to CPU memory. The | ||
default value is False. | ||
ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will | ||
have the same state_dicts. Otherwise only ranks that are in ``ranks_only`` | ||
have the same state_dicts. Other ranks will get empty state_dicts. | ||
|
||
Returns: | ||
The gathered state dictionary. | ||
""" | ||
new_state_dict = {} | ||
cpu_device = torch.device("cpu") | ||
for key, value in state_dict.items(): | ||
if isinstance(value, ShardedTensor): | ||
# ShardedTensor does not seem to record the original device type. | ||
|
@@ -65,7 +86,7 @@ def _gather_state_dict( | |
local_shard_device = ( | ||
value.local_shards()[0].tensor.device | ||
if value.local_shards() | ||
else torch.device("cpu") | ||
else cpu_device | ||
) | ||
if output_tensor.device != local_shard_device: | ||
value = output_tensor.to(local_shard_device) | ||
|
@@ -86,7 +107,11 @@ def _gather_state_dict( | |
) | ||
value = value.to_local() | ||
elif isinstance(value, dict): | ||
value = _gather_state_dict(value, pg, device) | ||
value = _gather_state_dict(value, pg=pg, device=device) | ||
|
||
if isinstance(value, torch.Tensor) and cpu_offload: | ||
value = value.to(cpu_device) | ||
|
||
new_state_dict[key] = value | ||
if not cpu_offload or len(ranks_only) == 0 or dist.get_rank(pg) in ranks_only: | ||
Review comment: Maybe the condition is a bit complicated and I can't think thru lol, but is
Reply: ranks_only can be empty (the default value). But it seems that we should not restrict this function to be cpu_offload only.
Reply: I see. Makes sense! |
||
new_state_dict[key] = value | ||
return new_state_dict |
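
The condition discussed in the review thread can be easier to reason about when enumerated case by case. The following is a minimal sketch, not part of the PR: `should_keep_entry` is a hypothetical helper that copies the boolean expression from the diff, and its `rank` argument stands in for `dist.get_rank(pg)` so the snippet runs without initializing a process group.

```python
from typing import Tuple


def should_keep_entry(
    cpu_offload: bool, ranks_only: Tuple[int, ...], rank: int
) -> bool:
    """Hypothetical helper mirroring the condition in the diff above.

    `rank` replaces dist.get_rank(pg) (an assumption made so that the
    sketch is runnable in a single process).
    """
    return not cpu_offload or len(ranks_only) == 0 or rank in ranks_only


# Enumerate every combination for a two-rank job.
for cpu_offload in (False, True):
    for ranks_only in ((), (0,)):
        for rank in (0, 1):
            kept = should_keep_entry(cpu_offload, ranks_only, rank)
            print(
                f"cpu_offload={cpu_offload} ranks_only={ranks_only} "
                f"rank={rank} kept={kept}"
            )
```

As written at this commit, the expression drops an entry only when `cpu_offload` is true, `ranks_only` is non-empty, and the current rank is not listed. With `cpu_offload=False`, every rank keeps the entry regardless of `ranks_only`.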