 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo
 from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -306,22 +307,29 @@ def _init_sharded_param(
306307 f"or 4 (HSDP+EP+TP) but got { self ._spmd_mesh .ndim } ."
307308 )
308309 self ._spmd_placements : tuple [Placement , ...]
309- dp_shard_tp_placement = (
310- (
311- _StridedShard (shard_dim , split_factor = split_factor )
312- if split_factor > 1
313- else fsdp_placement
314- ),
315- * self ._tp_spec .placements ,
316- )
317- if dp_mesh .ndim == 1 : # FSDP
318- self ._spmd_placements = dp_shard_tp_placement
319- else : # HSDP
310+ if isinstance (self .mesh_info , FSDPMeshInfo ): # FSDP or HSDP
311+ dp_shard_tp_placement = (
312+ (
313+ _StridedShard (shard_dim , split_factor = split_factor )
314+ if split_factor > 1
315+ else fsdp_placement
316+ ),
317+ * self ._tp_spec .placements ,
318+ )
319+ else : # DDP
320+ dp_shard_tp_placement = (
321+ (Replicate ()),
322+ * self ._tp_spec .placements ,
323+ )
324+ if isinstance (self .mesh_info , HSDPMeshInfo ): # HSDP
320325 if self .mesh_info .replicate_mesh_dim != 0 :
321326 raise AssertionError (
322327 f"Expected replicate_mesh_dim to be 0, got { self .mesh_info .replicate_mesh_dim } "
323328 )
324329 self ._spmd_placements = (Replicate (),) + dp_shard_tp_placement
330+ else : # FSDP or DDP
331+ self ._spmd_placements = dp_shard_tp_placement
332+
325333 self ._sharding_spec = DTensorSpec (
326334 self ._spmd_mesh ,
327335 self ._spmd_placements ,
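
The hunk above switches the placement choice from dp_mesh.ndim to the mesh-info type, so a pure DDP mesh (no shard dim) can flow through the same DTensor path. A minimal standalone sketch of the resulting data-parallel placements, assuming illustrative names (dp_tp_placements, dp_strategy, and tp_placements are not from this PR) and omitting the _StridedShard special case:

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.placement_types import Placement


def dp_tp_placements(
    dp_strategy: str,  # "ddp", "fsdp", or "hsdp" (illustrative flag, not a real API)
    tp_placements: tuple[Placement, ...] = (),
    shard_dim: int = 0,
) -> tuple[Placement, ...]:
    # Mirrors the selection logic in the hunk above (sketch only).
    if dp_strategy == "ddp":
        # DDP: the parameter is fully replicated on the data-parallel mesh dim.
        return (Replicate(), *tp_placements)
    if dp_strategy == "fsdp":
        # FSDP: the parameter is sharded along shard_dim on the shard mesh dim.
        return (Shard(shard_dim), *tp_placements)
    if dp_strategy == "hsdp":
        # HSDP: replicate on the outer mesh dim, shard on the inner one.
        return (Replicate(), Shard(shard_dim), *tp_placements)
    raise ValueError(f"unknown dp_strategy: {dp_strategy}")


# Example: HSDP with a 1-D TP submesh that shards on dim 0
print(dp_tp_placements("hsdp", (Shard(0),)))  # (Replicate(), Shard(0), Shard(0))
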
@@ -330,10 +338,12 @@ def _init_sharded_param(
             param_data = cast(DTensor, param)._local_tensor
         else:
             self._spmd_mesh = self.mesh_info.mesh
-            if isinstance(self.mesh_info, HSDPMeshInfo):
+            if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
                 self._spmd_placements = (Replicate(), fsdp_placement)
-            else:
+            elif isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP
                 self._spmd_placements = (fsdp_placement,)
+            elif isinstance(self.mesh_info, DDPMeshInfo):  # DDP
+                self._spmd_placements = (Replicate(),)
             self._sharding_spec = DTensorSpec(
                 self._spmd_mesh,
                 self._spmd_placements,
@@ -351,8 +361,13 @@ def _init_sharded_param(
             )
         self._orig_size = param_data.size()
         self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
-        shard_rank = self.mesh_info.shard_mesh_rank
-        shard_world_size = self.mesh_info.shard_mesh_size
+        if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
+            shard_rank = self.mesh_info.shard_mesh_rank
+            shard_world_size = self.mesh_info.shard_mesh_size
+        else:  # DDP
+            shard_rank = 0
+            shard_world_size = 1
+
         if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
             # If sharding on nonzero dim, require even sharding for now because
             # the uneven sharding (1) requires extra copies before/after FSDP
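
For intuition on the guard at the end of that hunk: dim-0 sharding tolerates uneven sizes (the trailing rank simply gets a smaller or empty chunk), whereas sharding on a nonzero dim with an indivisible size would force padding and extra copies around the all-gather, so it is rejected up front. A small illustration using torch.chunk as a stand-in for the internal _chunk_with_empty helper:

import torch

world_size = 4
param_data = torch.randn(10, 6)

# Dim-0 sharding tolerates unevenness: the trailing chunk is simply smaller.
dim0_chunks = torch.chunk(param_data, world_size, dim=0)
print([c.shape[0] for c in dim0_chunks])  # [3, 3, 3, 1]

# On a nonzero dim, 6 % 4 != 0: this is the case the guard above errors on, since
# reassembling uneven column shards would need padding/copies around all-gather.
assert param_data.size(1) % world_size != 0
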
@@ -401,12 +416,20 @@ def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None
         if mesh_info is None:
             raise AssertionError("Expected post_forward_mesh_info to not be None")
         param_data = param._local_tensor if isinstance(param, DTensor) else param
-        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
-        self.sharded_post_forward_size = _get_dim_chunked_size(
-            chunks[mesh_info.shard_mesh_rank],
-            param_data.size(),
-            dim=self.fsdp_placement.dim,
-        )
+        if isinstance(mesh_info, FSDPMeshInfo):
+            chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[mesh_info.shard_mesh_rank],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
+        else:  # DDP
+            chunks = _chunk_with_empty(param_data, 1, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[0],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
         self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
             self.sharded_post_forward_size
         )
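
The DDP branch in this last hunk is effectively a no-op shard: chunking into a single piece along dim 0 returns the whole tensor, so the post-forward "sharded" size equals the full parameter size. A quick standalone check of that intuition, again using torch.chunk in place of the internal _chunk_with_empty:

import torch

param_data = torch.randn(8, 16)

# FSDP-style: a 4-way shard along dim 0 gives each rank a (2, 16) slice.
fsdp_chunks = torch.chunk(param_data, 4, dim=0)
assert fsdp_chunks[0].shape == torch.Size([2, 16])

# DDP-style (shard world size 1): the single "chunk" is the full tensor, so the
# sharded post-forward size collapses to the original parameter size.
ddp_chunks = torch.chunk(param_data, 1, dim=0)
assert ddp_chunks[0].shape == param_data.shape
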