[FSDP][optim_state_dict] Add device to _shard_utils.py to explicitly use the device from fsdp_state #109631

Closed · wants to merge 3 commits (diff shown from 1 commit)
torch/distributed/fsdp/_optim_utils.py (8 changes: 6 additions & 2 deletions)
@@ -370,7 +370,9 @@ def _shard_orig_param_state(
     flat_param = fsdp_param_info.handle.flat_param
     param_idx = fsdp_param_info.param_indices[fqn]
     shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
-    optim_state = _gather_state_dict(optim_state, fsdp_state.process_group)
+    optim_state = _gather_state_dict(
+        optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device
+    )
     if not shard_param_info.in_shard:
         return {}
     # Flatten and shard the state.
@@ -583,7 +585,9 @@ def _flatten_optim_state(
     # without
     unflat_param_states = [
         _gather_state_dict(
-            unflat_osd_state[unflat_param_name], pg=fsdp_state.process_group
+            unflat_osd_state[unflat_param_name],
+            pg=fsdp_state.process_group,
+            device=fsdp_state.compute_device,
         )
         if unflat_param_name in unflat_osd_state
         else None
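Both call sites above now forward fsdp_state.compute_device rather than letting the gather helper infer a device from the process group. The selection rule they opt into is the pg_device expression in the _shard_utils.py diff below; as a minimal standalone sketch (the helper name here is hypothetical, not part of the PR):

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d


def _resolve_gather_device(
    pg: dist.ProcessGroup, device: Optional[torch.device]
) -> torch.device:
    # Prefer the caller-supplied device (FSDP's compute_device); fall back
    # to the process group's default device only when none is given.
    if device is not None:
        return device
    return distributed_c10d._get_pg_default_device(pg)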
torch/distributed/fsdp/_shard_utils.py (27 changes: 20 additions & 7 deletions)
@@ -19,7 +19,9 @@


 def _all_gather_sharded_tensor(
-    sharded_tensor: ShardedTensor, pg: Optional[dist.ProcessGroup] = None
+    sharded_tensor: ShardedTensor,
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
 ) -> torch.Tensor:
     if pg is None:
         pg = distributed_c10d._get_default_group()
@@ -28,7 +30,9 @@ def _all_gather_sharded_tensor(
     dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
     tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
     chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
-    pg_device = distributed_c10d._get_pg_default_device(pg)
+    pg_device = (
+        distributed_c10d._get_pg_default_device(pg) if device is None else device
+    )
     if shards:
         local_tensor = shards[0].tensor.flatten()
         with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
@@ -47,7 +51,7 @@ def _all_gather_sharded_tensor(
         dtype=local_tensor.dtype,
         device=pg_device,
     )
-    dist._all_gather_base(tensor, local_tensor, group=pg)
+    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

     tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
     return tensor
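This hunk also modernizes the collective: dist._all_gather_base is the deprecated spelling, and dist.all_gather_into_tensor is its current public name for gathering equal-sized shards into one preallocated tensor. A hedged standalone sketch of the call pattern (the demo function is not from the PR; it assumes an initialized default process group, e.g. under torchrun):

import torch
import torch.distributed as dist


def demo_all_gather(local: torch.Tensor) -> torch.Tensor:
    # Gather each rank's flat shard into a single preallocated output
    # living on the same device as the local shard.
    world_size = dist.get_world_size()
    out = torch.empty(
        local.numel() * world_size, dtype=local.dtype, device=local.device
    )
    dist.all_gather_into_tensor(out, local)
    return out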
@@ -58,14 +62,18 @@ def _all_gather_sharded_tensor(
 def _gather_state_dict(
     state_dict: Dict[str, Any],
     pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
 ) -> Dict[str, Any]:
     """
     Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
     """
     new_state_dict = {}
     for key, tensor in state_dict.items():
         if isinstance(tensor, ShardedTensor):
-            output_tensor = _all_gather_sharded_tensor(tensor, pg)
+            # ShardedTensor does not seem to record the original device type.
+            # So if the tensor is moved to CPU, we won't know the original type.
+            # As a result, we have to rely on the user to tell us the correct one.
+            output_tensor = _all_gather_sharded_tensor(tensor, pg, device)
             local_shard_device = (
                 tensor.local_shards()[0].tensor.device
                 if tensor.local_shards()
@@ -92,7 +100,7 @@ def _gather_state_dict(
     return new_state_dict


-def _get_remove_device_str(rank, device_type, num_devices_per_node):
+def _get_remote_device_str(rank, device_type, num_devices_per_node):
     if device_type.lower() == "cpu":
         return f"rank:{rank}/{device_type}"
     else:
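For reference, the CPU branch shown above yields placements with no device index; a hedged illustration (import path valid after this PR's typo rename):

from torch.distributed.fsdp._shard_utils import _get_remote_device_str

# Illustrative values, not from the PR:
assert _get_remote_device_str(3, "cpu", 8) == "rank:3/cpu"
# The non-CPU branch is collapsed in this view; num_devices_per_node being a
# parameter suggests it maps each rank onto a per-node device index, but that
# body is not shown here.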
@@ -105,6 +113,7 @@ def _create_chunk_sharded_tensor(
     world_size: int,
     num_devices_per_node: int,
     pg: dist.ProcessGroup,
+    device: Optional[torch.device] = None,
 ) -> ShardedTensor:
     """
     Shard a tensor to chunks along the first dimension. The local rank will get its
@@ -126,9 +135,13 @@ def _create_chunk_sharded_tensor(
     )[:-1]
     offsets = [0] * (len(chunk_sizes[0]) - 1)
     chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
-    device_type = distributed_c10d._get_pg_default_device(pg).type
+    device_type = (
+        distributed_c10d._get_pg_default_device(pg).type
+        if device is None
+        else device.type
+    )
     placements = [
-        _get_remove_device_str(r, device_type, num_devices_per_node)
+        _get_remote_device_str(r, device_type, num_devices_per_node)
         for r in range(len(chunk_sizes))
     ]
     assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
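To make the sharding scheme concrete, here is a hedged standalone sketch (hypothetical helper, no process group required) of the dim-0 chunk sizes and offsets that feed the placements above:

import torch


def demo_chunk_layout(tensor: torch.Tensor, world_size: int):
    # Chunk along dim 0, one chunk per rank; offsets accumulate along
    # dim 0 and are zero along every remaining dim.
    chunks = tensor.chunk(world_size, dim=0)
    sizes = [list(c.size()) for c in chunks]
    offsets, off = [], 0
    for c in chunks:
        offsets.append([off] + [0] * (c.dim() - 1))
        off += c.size(0)
    return sizes, offsets


# e.g. a (10, 4) tensor over 4 ranks -> sizes [3, 4], [3, 4], [3, 4], [1, 4]
# with dim-0 offsets 0, 3, 6, 9
sizes, offsets = demo_chunk_layout(torch.randn(10, 4), world_size=4)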