[2D] Enable 2D DTensor state_dict for FSDP + TP (#110846)
This PR adds a `chunk_dtensor()` method to fsdp/_fsdp_extensions.py and the actual implementation of `chunk_dtensor()` in tensor/parallel/fsdp.py. This enables FSDP to return a 2D DTensor state_dict when composing FSDP with TP.
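
For context, a minimal usage sketch of the flow this enables, mirroring the new test added below. It is illustrative only and assumes 4 GPUs launched via `torchrun` and a toy `SimpleModel` like the one defined in the test file:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# 2x2 device mesh over 4 GPUs: the outer "dp" dim is sharded by FSDP,
# the inner "tp" dim by tensor parallelism.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel())
model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)

# With SHARDED_STATE_DICT, every value in the state_dict is now a 2D DTensor
# placed on mesh_2d, e.g. (Shard(0), <tp placement>).
FSDP.set_state_dict_type(model_2d, StateDictType.SHARDED_STATE_DICT)
state_dict_2d = model_2d.state_dict()
```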

cc. @fegin
Pull Request resolved: #110846
Approved by: https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #110831
wz337 authored and pytorchmergebot committed Oct 11, 2023
1 parent 0bd4ce7 commit 6c136c3
Showing 5 changed files with 152 additions and 15 deletions.
53 changes: 52 additions & 1 deletion test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -8,7 +8,7 @@

import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
-from torch.distributed._tensor import DTensor as DT, init_device_mesh, Replicate
+from torch.distributed._tensor import DTensor as DT, init_device_mesh, Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
@@ -401,6 +401,57 @@ def test_2d_e2e_training_use_orig_params(self):
def test_2d_e2e_training_not_use_orig_params(self):
self._test_2d_e2e_training(recompute_activation=True)

@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_state_dict(self):
# Create a model without wrapper
torch.manual_seed(0)
simple_model = SimpleModel().cuda(self.rank)
no_wrap_state_dict = simple_model.state_dict()

        # Create a model and shard it with 2D FSDP + TP
torch.manual_seed(0)
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]
model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel())
model_2d = FSDP(
model_2d,
device_mesh=dp_mesh,
use_orig_params=True,
)

FSDP.set_state_dict_type(
model_2d,
StateDictType.SHARDED_STATE_DICT,
)
state_dict_2d = model_2d.state_dict()

for no_wrap_items, two_d_items in zip(
no_wrap_state_dict.items(), state_dict_2d.items()
):
no_wrap_k, no_wrap_v = no_wrap_items
two_d_k, two_d_v = two_d_items

self.assertEqual(no_wrap_k, two_d_k)

            # check that all values in the 2D state_dict are DTensors
self.assertTrue(isinstance(two_d_v, DT))
self.assertEqual(len(two_d_v.placements), 2)
# the outer dimension is the FSDP dimension and the placement is always Shard(0)
self.assertEqual(two_d_v.placements[0], Shard(0))
self.assertEqual(two_d_v.device_mesh, mesh_2d)

# check if the parameter value is the same between 2D model and the model without wrapper
all_gather_two_d_v = two_d_v.redistribute(
mesh_2d, (Replicate(), Replicate())
)
self.assertEqual(
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
)


if __name__ == "__main__":
run_tests()
19 changes: 14 additions & 5 deletions torch/distributed/fsdp/_fsdp_extensions.py
@@ -49,6 +49,16 @@ def chunk_tensor(
"""Shards a tensor to chunks and returns the local chunk."""
...

@abstractmethod
def chunk_dtensor(
self,
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> torch.Tensor:
"""Shards a tensor/DTensor to DTensor and returns the local DTensor."""
...

@abstractmethod
def pre_load_state_dict_transform(
self,
@@ -114,11 +124,10 @@ def _ext_chunk_dtensor(
rank: int,
device_mesh: DeviceMesh,
) -> torch.Tensor:
-    # TODO: Address composability issue and remove the assertion.
-    assert (
-        _extensions is None
-    ), "Currently does not support composability when _use_dtensor = True"
-    return _create_chunk_dtensor(
+    chunk_dtensor_fn = (
+        _extensions.chunk_dtensor if _extensions is not None else _create_chunk_dtensor
+    )
+    return chunk_dtensor_fn(
tensor,
rank,
device_mesh,
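
With the assertion gone, `_ext_chunk_dtensor` prefers a registered extension's `chunk_dtensor` and only falls back to the default `_create_chunk_dtensor`. A rough sketch of the registration side, assuming the wiring suggested by the imports and `__all__` of tensor/parallel/fsdp.py below (not new API introduced by this PR):

```python
from torch.distributed.fsdp._fsdp_extensions import _set_fsdp_extensions
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions

# With an FSDPExtensions instance registered, _ext_chunk_dtensor dispatches to
# its chunk_dtensor() (i.e. _chunk_dtensor in tensor/parallel/fsdp.py) instead
# of the default _create_chunk_dtensor().
_set_fsdp_extensions(DTensorExtensions())
```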
1 change: 0 additions & 1 deletion torch/distributed/fsdp/_shard_utils.py
@@ -185,7 +185,6 @@ def _create_chunk_dtensor(
replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements[-1] = DShard(0) # type: ignore[call-overload]
-    shard_placements = tuple(shard_placements)

return DTensor.from_local(tensor, device_mesh, replicate_placements).redistribute(
device_mesh=device_mesh,
9 changes: 9 additions & 0 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -18,6 +18,7 @@
ShardedTensor,
)
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed._tensor.device_mesh import mesh_resources

from torch.distributed.distributed_c10d import _get_pg_default_device
from torch.distributed.fsdp._common_utils import (
@@ -289,6 +290,14 @@ def _full_pre_state_dict_hook(
is supported in ``nn.Module``, this hook will be registered as a hook in
``nn.Module``.
"""
if getattr(fsdp_state, "_device_mesh", False):
parent_mesh = mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
if parent_mesh:
raise RuntimeError(
f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
"We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
)

_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(
module,
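
A sketch of what the new guard means in practice, continuing the usage sketch from the PR description (same hypothetical 4-GPU `model_2d`; the exact error text comes from the hook above):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# FULL_STATE_DICT is explicitly rejected for 2D FSDP + TP ...
FSDP.set_state_dict_type(model_2d, StateDictType.FULL_STATE_DICT)
try:
    model_2d.state_dict()  # _full_pre_state_dict_hook finds a parent mesh and raises
except RuntimeError:
    # ... so fall back to the supported sharded form.
    FSDP.set_state_dict_type(model_2d, StateDictType.SHARDED_STATE_DICT)
    state_dict_2d = model_2d.state_dict()
```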
85 changes: 77 additions & 8 deletions torch/distributed/tensor/parallel/fsdp.py
@@ -16,7 +16,8 @@

from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed._tensor import DTensor as DistributedTensor, Shard as DShard
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
+from torch.distributed._tensor.device_mesh import mesh_resources

from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import _set_fsdp_extensions, FSDPExtensions
@@ -30,7 +31,7 @@
__all__ = ["enable_2d_with_fsdp", "DTensorExtensions"]


-def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
+def _get_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

@@ -46,19 +47,19 @@ def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
return (torch.Size(offsets), tensor._local_tensor.size())


-def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
+def _get_box_for(tensor: DTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
offsets, size = _get_box(tensor)
return (torch.Size([val * idx for val in offsets]), size)


-def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
+def _get_local_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
coord = device_mesh.get_coordinate()
assert coord is not None
return _get_box_for(tensor, coord[0])


-def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
+def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

@@ -71,7 +72,7 @@ def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardM


def _create_sharded_tensor_md_from_dt(
-    dt: DistributedTensor, dt_pg: c10d.ProcessGroup
+    dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
# This is where it gets tricky, we have to produce a ShardedTensor that has full coverage
# and yet has only one valid shard for the current rank.
@@ -109,7 +110,7 @@ def _create_sharded_tensor_md_from_dt(
)


-def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
+def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
dim_groups = mesh.get_dim_groups()
@@ -179,7 +180,7 @@ def _chunk_tensor(
init_rrefs=False,
)
return st_outer
-    elif type(tensor) is DistributedTensor:
+    elif type(tensor) is DTensor:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

@@ -220,6 +221,66 @@ def _chunk_tensor(
)


def _chunk_dtensor(
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> DTensor:
"""
    Shard a tensor into chunks along the first dimension. The local rank gets its
corresponding chunk as the local tensor to create a DTensor.
"""
parent_mesh = mesh_resources.get_parent_mesh(device_mesh)
if parent_mesh is None:
raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
if parent_mesh.ndim != 2:
raise RuntimeError(
f"Found parent device_mesh of ndim={parent_mesh.ndim},",
"but only 2D meshes are currently supported.",
)

# We need to explicitly call .detach() to return a new tensor detached from the current graph.
tensor = tensor.clone().detach()

    # if a tensor has not yet been sharded by TP
if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):

        # A plain tensor is replicated across the TP dimension and sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), Replicate()).
replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
shard_placements = [Replicate() for _ in range(parent_mesh.ndim)]
shard_placements[0] = DShard(0) # type: ignore[call-overload]

return DTensor.from_local(
tensor, parent_mesh, replicate_placements
).redistribute(
device_mesh=parent_mesh,
placements=shard_placements,
)

else:
tp_placements = tensor.placements
tp_placement = tp_placements[0]

tensor = tensor.to_local()

        # A DTensor is sharded across the TP dimension first and then sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), tp_placement).
replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
replicate_placements[-1] = tp_placement # type: ignore[call-overload]
shard_placements = [DShard(0) for _ in range(parent_mesh.ndim)] # type: ignore[misc]
shard_placements[-1] = tp_placement # type: ignore[call-overload]

return DTensor.from_local(
tensor, parent_mesh, replicate_placements
).redistribute(
device_mesh=parent_mesh,
placements=shard_placements,
)


def _pre_load_state_dict(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
@@ -263,6 +324,14 @@ def chunk_tensor(
) -> torch.Tensor:
return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

def chunk_dtensor(
self,
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> torch.Tensor:
return _chunk_dtensor(tensor, rank, device_mesh)

def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
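
To make the placement logic of `_chunk_dtensor` concrete: on the 2D (dp, tp) mesh the outer dimension always carries FSDP's `Shard(0)`, while the inner dimension keeps whatever placement TP chose (or `Replicate()` for parameters TP did not shard). A small sketch continuing the earlier example and mirroring the assertions in the new test:

```python
from torch.distributed._tensor import DTensor, Replicate, Shard

for name, value in state_dict_2d.items():
    assert isinstance(value, DTensor)
    assert value.device_mesh == mesh_2d

    dp_placement, tp_placement = value.placements
    assert dp_placement == Shard(0)                      # outer dim: FSDP shard on dim 0
    assert isinstance(tp_placement, (Shard, Replicate))  # inner dim: TP's placement

    # Reassembling the full tensor for inspection: replicate on both mesh dims.
    full = value.redistribute(mesh_2d, (Replicate(), Replicate())).to_local()
```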
