diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 391ae74df..6597c45f9 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -185,6 +185,7 @@ def __init__(
         mode: str,
         mp_policy: MixedPrecisionPolicy | None,
         reduction_divide_factor: float | None,
+        full_dtensor: bool = False,
     ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
@@ -201,6 +202,7 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype: torch.dtype | None = mp_policy.param_dtype
         self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
+        self.full_dtensor = full_dtensor
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -210,6 +212,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
+            if self.full_dtensor:
+                raise NotImplementedError(
+                    "full_dtensor not implemented for nD parallelisms"
+                )
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -245,7 +251,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
                 backward_dtype=self.reduce_dtype,
-            ).to_local(grad_placements=self.grad_placements)
+            )
+
+            if not self.full_dtensor:
+                output = output.to_local(grad_placements=self.grad_placements)
         else:
             raise AssertionError(
                 f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}"
@@ -274,6 +283,7 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
+    full_dtensor: bool = False,
 ) -> nn.Module:
     param_sharding: tuple[Placement, ...]
     if mode == "replicate":
@@ -333,6 +343,7 @@ def data_parallel(
                 mode,
                 mp_policy=mp_policy,
                 reduction_divide_factor=reduction_divide_factor,
+                full_dtensor=full_dtensor,
             ),
         )
     return model
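
Usage sketch (not part of the patch): a minimal example of how the new `full_dtensor` flag could be threaded through `data_parallel`. The mesh setup, placeholder model, and positional argument order are assumptions for illustration; only the `full_dtensor=` keyword and the `"replicate"` mode string are confirmed by the diff above. With `full_dtensor=True`, `replicate_compute` keeps the replicated parameter as a `DTensor` instead of calling `.to_local()`, and raises `NotImplementedError` when the parameter lives on an nD mesh (DP + EP/TP).

```python
# Hypothetical usage -- mesh construction, model, and argument order are
# illustrative assumptions; only the full_dtensor keyword comes from this patch.
# Run under torchrun so the process group / device mesh can be initialized.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))  # assumed 1D DP mesh
model = nn.Linear(1024, 1024)  # placeholder model

# full_dtensor=True skips the final .to_local() in replicate_compute(), so the
# parametrized parameters stay DTensors end to end; combining it with extra
# EP/TP mesh dims raises NotImplementedError per the guard added above.
model = data_parallel(
    model,
    dp_mesh,
    mode="replicate",   # mode string shown in the diff
    full_dtensor=True,  # new flag introduced by this patch
)
```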