
Commit 93cadab

anshul-sietaf authored and committed
[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp (#166433)
**Summary:** I have created a new composable replicate api that's integrated into FSDP's codebase with minimal changes. The key changes I made are when we use DDPMeshInfo, we use Replicate placements, prevent initial sharding of parameters, set worldsize to 1 to skip allgathers and reducescatter. **Test Cases** 1. pytest test/distributed/_composable/test_replicate_training.py 2. pytest test_pp_composability.py 3. pytest test_replicate_with_fsdp.py Pull Request resolved: #166433 Approved by: https://github.com/weifengpy
1 parent ce91b40 · commit 93cadab
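For orientation, here is a minimal usage sketch of the API after this change, distilled from the updated tests below. The import path for `replicate` and the 8-rank, single-node setup are assumptions for illustration, not part of the PR.

```python
# Minimal sketch, assuming torch.distributed is already initialized across 8 ranks
# and that `replicate` is importable from the composable namespace (assumed path).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.replicate_with_fsdp import replicate  # assumed import

# A 1D "replicate" mesh replaces the previous 2D ("replicate", "shard") mesh.
mesh = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("replicate",))

model = nn.Linear(16, 16, device="cuda")
replicate(model, device_mesh=mesh)

# Parameters become DTensors with a single Replicate() placement
# instead of (Replicate(), Shard(dim=0)) as before this PR.
for p in model.parameters():
    print(p.placements)  # (Replicate(),)
```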

7 files changed (+116 −77 lines)


test/distributed/_composable/test_composability/test_pp_composability.py

Lines changed: 6 additions & 6 deletions
@@ -392,11 +392,11 @@ def test_replicate_pp(self, ScheduleClass, MixedPrecisionParam):
         replicate_size = self.world_size // (pp_size)
         device_mesh = init_device_mesh(
             device_type,
-            mesh_shape=(replicate_size, 1, pp_size),
-            mesh_dim_names=("replicate", "shard", "pp"),
+            mesh_shape=(replicate_size, pp_size),
+            mesh_dim_names=("replicate", "pp"),
         )
         torch.manual_seed(42)
-        dp_mesh = device_mesh["replicate", "shard"]
+        dp_mesh = device_mesh["replicate"]
         pp_mesh = device_mesh["pp"]
         pp_group = device_mesh["pp"].get_group()

@@ -582,11 +582,11 @@ def test_replicate_pp_grads(self, ScheduleClass):
         replicate_size = self.world_size // (pp_size)
         device_mesh = init_device_mesh(
             device_type,
-            mesh_shape=(replicate_size, 1, pp_size),
-            mesh_dim_names=("replicate", "shard", "pp"),
+            mesh_shape=(replicate_size, pp_size),
+            mesh_dim_names=("replicate", "pp"),
         )
         torch.manual_seed(42)
-        dp_mesh = device_mesh["replicate", "shard"]
+        dp_mesh = device_mesh["replicate"]
         pp_mesh = device_mesh["pp"]
         pp_group = device_mesh["pp"].get_group()
         dp_group = device_mesh["replicate"].get_group()

test/distributed/_composable/test_replicate_training.py

Lines changed: 15 additions & 17 deletions
@@ -108,7 +108,7 @@ def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
         device = torch.device(device_type.type, 0)
         # Single Replicate group
-        for reshard_after_forward in (True, False, None):
+        for reshard_after_forward in (False,):
             torch.manual_seed(42)
             model = MLP(3, device)
             # Since seed is per process, not per thread, we broadcast to ensure
@@ -131,7 +131,7 @@ def test_param_registration_after_forward(self):
             self._assert_same_params(model.parameters(), ref_model.parameters())

         # Multiple Replicate groups
-        for reshard_after_forward in (True, False, None):
+        for reshard_after_forward in (False,):
             torch.manual_seed(42)
             model = nn.Sequential(MLP(3, device), MLP(3, device))
             for param in model.parameters():
@@ -405,8 +405,8 @@ def _test_train_parity_multi_group(
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         mesh = init_device_mesh(
             test_device_type,
-            (self.world_size, 1),
-            mesh_dim_names=("replicate", "shard"),
+            (self.world_size,),
+            mesh_dim_names=("replicate",),
         )
         fully_shard_fn = functools.partial(
             replicate,
@@ -740,8 +740,8 @@ def _test_train_parity_with_activation_checkpointing(
         # Apply Replicate
         device_mesh = init_device_mesh(
             test_device_type,
-            (self.world_size, 1),
-            mesh_dim_names=("replicate", "shard"),
+            (self.world_size,),
+            mesh_dim_names=("replicate",),
         )
         fsdp_kwargs = {
             "reshard_after_forward": reshard_after_forward,
@@ -868,11 +868,11 @@ def test_gradient_accumulation(self):
         with/without resharding after backward.
         """

-        shard_size, replicate_size = 1, self.world_size
+        replicate_size = self.world_size
         meshes = init_device_mesh(
             device_type.type,
-            (replicate_size, shard_size),
-            mesh_dim_names=("replicate", "shard"),
+            (replicate_size,),
+            mesh_dim_names=("replicate",),
         )
         self.run_subtests(
             {
@@ -1145,8 +1145,8 @@ def world_size(self) -> int:
     def init_global_mesh(self) -> DeviceMesh:
         return init_device_mesh(
             device_type.type,
-            (2, 1, 2),
-            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
+            (2, 2),
+            mesh_dim_names=("dp_replicate", "tp"),
         )

     @skip_if_lt_x_gpu(8)
@@ -1170,7 +1170,7 @@ def _test_replicate_tp(
         mlp_dim: int,
         foreach: bool,
     ):
-        dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
+        dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
         dp_pg = dp_mesh._flatten().get_group()  # used for `replicate()`

         torch.manual_seed(42)
@@ -1229,11 +1229,9 @@ def _test_replicate_tp(

         for _, p in model.named_parameters():
             self.assertIsInstance(p, DTensor)
-            self.assertEqual(p.device_mesh.ndim, 3)
-            self.assertEqual(len(p.placements), 3)
-            self.assertEqual(
-                p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
-            )
+            self.assertEqual(p.device_mesh.ndim, 2)
+            self.assertEqual(len(p.placements), 2)
+            self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))


 if __name__ == "__main__":

test/distributed/_composable/test_replicate_with_fsdp.py

Lines changed: 5 additions & 8 deletions
@@ -120,7 +120,7 @@ def _test_replicate_transformer(self, sharding_strategy):
             if i % 2 == 0:
                 self.assertTrue("replicate" in _get_registry(layer))
                 for parameter in layer.parameters():
-                    self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
+                    self.assertEqual(parameter.placements, (Replicate(),))
             elif i % 2 == 1:
                 self.assertTrue("fully_shard" in _get_registry(layer))
                 for parameter in layer.parameters():
@@ -197,14 +197,14 @@ def test_replicate_tp_device_mesh(self):
         ]

         global_mesh = self.init_replicate_tp_mesh()
-        replicate_mesh = global_mesh["replicate", "shard"]
+        replicate_mesh = global_mesh["replicate"]

         for layer in layers:
             replicate(layer, device_mesh=replicate_mesh)

             for parameter in layer.parameters():
-                self.assertEqual(parameter.device_mesh.shape, (2, 1))
-                self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
+                self.assertEqual(parameter.device_mesh.shape, (2,))
+                self.assertEqual(parameter.placements, (Replicate(),))

     @skip_if_lt_x_gpu(2)
     def test_train_replicate_fsdp(self):
@@ -263,7 +263,6 @@ def test_train_parity_2d_mlp(self):
         run_subtests(
             self,
             {
-                "reshard_after_forward": [False, True],
                 "use_activation_checkpointing": [False, True],
                 "mlp_dim": [3, 16, 17],
             },
@@ -273,7 +272,6 @@ def test_train_parity_2d_mlp(self):
     def _test_train_parity_2d_mlp(
         self,
         global_mesh: DeviceMesh,
-        reshard_after_forward: bool,
         use_activation_checkpointing: bool,
         mlp_dim: int,
     ):
@@ -287,13 +285,12 @@ def _test_train_parity_2d_mlp(
         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
         ref_model = copy.deepcopy(model).cuda()
-        replicate(ref_model, device_mesh=replicate_shard_mesh)
+        replicate(ref_model, device_mesh=replicate_mesh)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
             tp_mesh,
             replicate_shard_mesh,
             use_activation_checkpointing,
-            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

torch/distributed/_composable/replicate_with_fsdp.py

Lines changed: 9 additions & 14 deletions
@@ -14,13 +14,12 @@
     OffloadPolicy,
 )
 from torch.distributed.fsdp._fully_shard._fsdp_common import (
+    DDPMeshInfo,
     detect_compiled_autograd,
-    HSDPMeshInfo,
 )
 from torch.distributed.fsdp._fully_shard._fsdp_init import (
     _get_device_from_mesh,
     _get_managed_states,
-    _get_post_forward_mesh_info,
     _init_default_fully_shard_mesh,
     _move_states_to_device,
 )
@@ -184,23 +183,19 @@ def replicate_impl(
         )

         mesh = mesh or _init_default_fully_shard_mesh()
-        if mesh.ndim != 2:
-            raise ValueError(f"replicate expects a 2D DeviceMesh but got {mesh}")
+        if mesh.ndim != 1:
+            raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}")

     else:
         if mesh.mesh_dim_names is None:
             raise AssertionError(
                 "Please init the 2D mesh for HSDP with mesh_dim_names specified"
             )
-    mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
+    mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0)
     device = _get_device_from_mesh(mesh)
     auto_reshard_after_forward = reshard_after_forward is None
-    # If the user does not provide ``reshard_after_forward``, we set it to True.
-    # During lazy_init, we identify which module is the root and override its value to False
-    post_forward_mesh_info = _get_post_forward_mesh_info(
-        reshard_after_forward if not auto_reshard_after_forward else True,  # type: ignore[arg-type]
-        mesh_info,
-    )
+
+    post_forward_mesh_info = None

     arg_module = module
     modules = (
@@ -217,7 +212,7 @@ def replicate_impl(
         state._fsdp_param_group = FSDPParamGroup(
             params,
             modules,
-            mesh_info,
+            mesh_info,  # type: ignore[arg-type]
             post_forward_mesh_info,
             device,
             shard_placement_fn,
@@ -341,8 +336,8 @@ def replicate_mesh():
     device = torch._C._get_accelerator()
     mesh = init_device_mesh(
         device.type,
-        mesh_shape=(default_pg.size(), 1),
-        mesh_dim_names=("replicate", "shard"),
+        mesh_shape=(default_pg.size(),),
+        mesh_dim_names=("replicate",),
     )
     return mesh
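To make the control flow above easier to follow, here is a hedged, standalone sketch of how the mesh handling now behaves: accept or build a 1D mesh and wrap it in `DDPMeshInfo`, with no post-forward mesh. The helper name `build_replicate_mesh_info` is hypothetical and not part of this PR.

```python
# Hypothetical helper mirroring the mesh handling in replicate_impl above; a sketch, not the PR's code.
# Assumes torch.distributed is already initialized.
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo


def build_replicate_mesh_info(mesh: Optional[DeviceMesh] = None) -> DDPMeshInfo:
    if mesh is None:
        # Default: a 1D mesh spanning the default process group, as in replicate_mesh().
        mesh = init_device_mesh(
            torch._C._get_accelerator().type,
            mesh_shape=(dist.get_world_size(),),
            mesh_dim_names=("replicate",),
        )
    if mesh.ndim != 1:
        raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}")
    # Every rank in the single mesh dim holds a full replica; there is no shard dim,
    # so no post-forward (resharded) mesh is needed either.
    return DDPMeshInfo(mesh, replicate_mesh_dim=0)
```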

torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py

Lines changed: 11 additions & 4 deletions
@@ -492,7 +492,11 @@ def foreach_reduce(
             force_sum_reduction_for_comms,
         )
     )
-    world_size = reduce_scatter_group.size()
+
+    if reduce_scatter_group is None:
+        world_size = 1
+    else:
+        world_size = reduce_scatter_group.size()
     device_handle = _get_device_handle(device.type)
     current_stream = device_handle.current_stream()

@@ -547,7 +551,7 @@ def foreach_reduce(
         reduce_output.copy_(reduce_scatter_input)
         reduce_scatter_event = reduce_scatter_stream.record_event()
         post_reduce_stream = reduce_scatter_stream
-        if all_reduce_group is not None:  # HSDP
+        if all_reduce_group is not None:  # HSDP or DDP/replicate
            # Accumulations must run in the reduce-scatter stream
            if not all_reduce_grads:
                if partial_reduce_output is not None:
@@ -690,7 +694,7 @@ def _get_all_gather_input_metadatas(


 def _get_gradient_divide_factors(
-    reduce_scatter_group: dist.ProcessGroup,
+    reduce_scatter_group: Optional[dist.ProcessGroup],
     all_reduce_group: Optional[dist.ProcessGroup],
     reduce_dtype: torch.dtype,
     device_type: str = "",
@@ -709,8 +713,11 @@ def _get_gradient_divide_factors(
     # For fp32/bf16, we do not need to worry about overflow/underflow, so we
     # use NCCL's built-in division to avoid separate div kernels
     overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
+    if reduce_scatter_group is not None:
+        data_parallel_size = reduce_scatter_group.size()
+    else:
+        data_parallel_size = 1

-    data_parallel_size = reduce_scatter_group.size()
     if all_reduce_group is not None:
         data_parallel_size *= all_reduce_group.size()
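The net effect of making `reduce_scatter_group` optional shows up in the gradient-division factor: with pure replicate there is no shard group, so only the all-reduce group contributes to the data-parallel size. Below is a standalone sketch of that arithmetic; the function is illustrative, not the library's `_get_gradient_divide_factors`.

```python
# Illustrative sketch of the effective data-parallel size used for gradient averaging;
# mirrors the Optional[reduce_scatter_group] handling in the hunks above.
from typing import Optional


def effective_data_parallel_size(
    reduce_scatter_world: Optional[int], all_reduce_world: Optional[int]
) -> int:
    # Replicate/DDP: no reduce-scatter group, so the shard factor is treated as 1.
    size = reduce_scatter_world if reduce_scatter_world is not None else 1
    # HSDP replicas (or all ranks for pure replicate) still contribute via the all-reduce group.
    if all_reduce_world is not None:
        size *= all_reduce_world
    return size


# 8-rank pure replicate: no reduce-scatter group, 8-way all-reduce -> divide by 8.
assert effective_data_parallel_size(None, 8) == 8
# 2x4 HSDP: 4-way reduce-scatter within a replica, 2-way all-reduce across replicas -> 8.
assert effective_data_parallel_size(4, 2) == 8
```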

torch/distributed/fsdp/_fully_shard/_fsdp_param.py

Lines changed: 44 additions & 21 deletions
@@ -11,6 +11,7 @@
 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo
 from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -306,22 +307,29 @@ def _init_sharded_param(
                    f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
                )
            self._spmd_placements: tuple[Placement, ...]
-            dp_shard_tp_placement = (
-                (
-                    _StridedShard(shard_dim, split_factor=split_factor)
-                    if split_factor > 1
-                    else fsdp_placement
-                ),
-                *self._tp_spec.placements,
-            )
-            if dp_mesh.ndim == 1:  # FSDP
-                self._spmd_placements = dp_shard_tp_placement
-            else:  # HSDP
+            if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
+                dp_shard_tp_placement = (
+                    (
+                        _StridedShard(shard_dim, split_factor=split_factor)
+                        if split_factor > 1
+                        else fsdp_placement
+                    ),
+                    *self._tp_spec.placements,
+                )
+            else:  # DDP
+                dp_shard_tp_placement = (
+                    (Replicate()),
+                    *self._tp_spec.placements,
+                )
+            if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
                if self.mesh_info.replicate_mesh_dim != 0:
                    raise AssertionError(
                        f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}"
                    )
                self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
+            else:  # FSDP or DDP
+                self._spmd_placements = dp_shard_tp_placement
+
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
@@ -330,10 +338,12 @@ def _init_sharded_param(
            param_data = cast(DTensor, param)._local_tensor
        else:
            self._spmd_mesh = self.mesh_info.mesh
-            if isinstance(self.mesh_info, HSDPMeshInfo):
+            if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
                self._spmd_placements = (Replicate(), fsdp_placement)
-            else:
+            elif isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP
                self._spmd_placements = (fsdp_placement,)
+            elif isinstance(self.mesh_info, DDPMeshInfo):  # DDP
+                self._spmd_placements = (Replicate(),)
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
@@ -351,8 +361,13 @@ def _init_sharded_param(
        )
        self._orig_size = param_data.size()
        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
-        shard_rank = self.mesh_info.shard_mesh_rank
-        shard_world_size = self.mesh_info.shard_mesh_size
+        if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
+            shard_rank = self.mesh_info.shard_mesh_rank
+            shard_world_size = self.mesh_info.shard_mesh_size
+        else:  # DDP
+            shard_rank = 0
+            shard_world_size = 1
+
        if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
            # If sharding on nonzero dim, require even sharding for now because
            # the uneven sharding (1) requires extra copies before/after FSDP
@@ -401,12 +416,20 @@ def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None
        if mesh_info is None:
            raise AssertionError("Expected post_forward_mesh_info to not be None")
        param_data = param._local_tensor if isinstance(param, DTensor) else param
-        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
-        self.sharded_post_forward_size = _get_dim_chunked_size(
-            chunks[mesh_info.shard_mesh_rank],
-            param_data.size(),
-            dim=self.fsdp_placement.dim,
-        )
+        if isinstance(mesh_info, FSDPMeshInfo):
+            chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[mesh_info.shard_mesh_rank],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
+        else:  # DDP
+            chunks = _chunk_with_empty(param_data, 1, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[0],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
            self.sharded_post_forward_size
        )
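The placement selection in `_init_sharded_param` now branches on the mesh-info type rather than the mesh's rank. Below is a compact, hedged restatement of that decision for the no-TP case; `data_parallel_placements` is a hypothetical helper, and dim-0 sharding is assumed for the FSDP placement.

```python
# Hypothetical helper restating the placement choice per mesh-info type (no-TP case),
# as in the _init_sharded_param hunks above; dim-0 sharding assumed for FSDP/HSDP.
from torch.distributed.fsdp._fully_shard._fsdp_common import (
    DDPMeshInfo,
    FSDPMeshInfo,
    HSDPMeshInfo,
)
from torch.distributed.tensor import Replicate, Shard


def data_parallel_placements(mesh_info) -> tuple:
    # HSDP is checked first; per the hunks above, FSDPMeshInfo covers both FSDP and HSDP.
    if isinstance(mesh_info, HSDPMeshInfo):
        # HSDP: replicate across replica groups, shard within each group.
        return (Replicate(), Shard(0))
    if isinstance(mesh_info, FSDPMeshInfo):
        # FSDP: shard parameters on dim 0 across the 1D mesh.
        return (Shard(0),)
    if isinstance(mesh_info, DDPMeshInfo):
        # DDP/replicate: every rank keeps a full copy; no sharding at all.
        return (Replicate(),)
    raise TypeError(f"unexpected mesh info type: {type(mesh_info)}")
```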
