From 25920217a7d4df18b21000b1f5dffb823524f663 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 10 Nov 2025 18:49:05 -0800
Subject: [PATCH 1/2] Update (base update)

[ghstack-poisoned]

From 7b6ff5efac21247f6dcb07e3cc67cf0b038548bf Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 10 Nov 2025 18:49:06 -0800
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 torchtitan/experiments/simple_fsdp/simple_fsdp.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 391ae74df..6597c45f9 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -185,6 +185,7 @@ def __init__(
         mode: str,
         mp_policy: MixedPrecisionPolicy | None,
         reduction_divide_factor: float | None,
+        full_dtensor: bool = False,
     ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
@@ -201,6 +202,7 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype: torch.dtype | None = mp_policy.param_dtype
         self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
+        self.full_dtensor = full_dtensor
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -210,6 +212,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
+            if self.full_dtensor:
+                raise NotImplementedError(
+                    "full_dtensor not implemented for nD parallelisms"
+                )
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -245,7 +251,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
                 backward_dtype=self.reduce_dtype,
-            ).to_local(grad_placements=self.grad_placements)
+            )
+
+            if not self.full_dtensor:
+                output = output.to_local(grad_placements=self.grad_placements)
         else:
             raise AssertionError(
                 f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}"
@@ -274,6 +283,7 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
+    full_dtensor: bool = False,
 ) -> nn.Module:
     param_sharding: tuple[Placement, ...]
     if mode == "replicate":
@@ -333,6 +343,7 @@ def data_parallel(
             mode,
             mp_policy=mp_policy,
             reduction_divide_factor=reduction_divide_factor,
+            full_dtensor=full_dtensor,
         ),
     )
     return model
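
For context, a minimal usage sketch (not part of the patch) of the new `full_dtensor` flag. It assumes the positional signature `data_parallel(model, device_mesh, ...)`, the `"replicate"` mode string that appears in the diff, and a torchrun-launched 2-GPU job; treat it as an illustration of the new keyword argument, not documented API.

```python
# Hypothetical usage sketch for the full_dtensor flag added in this patch.
# Assumptions: data_parallel(model, device_mesh, ...) positional order,
# mode="replicate", and a 2-process run (e.g. torchrun --nproc-per-node=2).
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

# 1D data-parallel mesh over 2 ranks.
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("dp",))

# Default behavior: replicated parameters are unwrapped to local tensors
# for compute (the existing .to_local(...) path).
local_model = data_parallel(nn.Linear(16, 16), mesh, mode="replicate")

# full_dtensor=True keeps the replicated parameter as a DTensor during compute;
# per the patch, combining it with nD parallelisms raises NotImplementedError.
dtensor_model = data_parallel(
    nn.Linear(16, 16), mesh, mode="replicate", full_dtensor=True
)
```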