Description
Work Items
- Meta-device initialization / `_apply()` methods
  - Support initial meta-device initialization using the `swap_tensors` path
  - Remove manual padding logic after [Feature][DTensor] Manage additional `_padded_local_tensor` attribute #113045 @wz337
  - Outcome: Once `DTensor` manages the padded storage, FSDP only needs to maintain a reference to the `DTensor`, not its `_local_tensor`. We should be able to do meta-device initialization via `nn.Module.to_empty()` (using the `swap_tensors` path) followed by normal initialization ops (e.g. `torch.nn.init.<...>`) that run through `DTensor`'s op dispatch without any padding logic in `_apply()` (see the sketch below).
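A minimal sketch of that intended flow, assuming the per-parameter FSDP `fully_shard` entry point (the import path and exact API may differ by PyTorch version) and that modules re-initialize via `reset_parameters()`:

```python
# Sketch only: meta-device init with per-parameter FSDP (FSDP2).
# Assumes torch.distributed is already initialized (e.g. via torchrun) and
# that `fully_shard` lives in torch.distributed._composable.fsdp.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

with torch.device("meta"):
    # Parameters are allocated on the meta device (no real storage yet).
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

fully_shard(model)  # parameters become sharded DTensors, still on meta

# Materialize sharded parameters via the swap_tensors path, then run normal
# init ops, which dispatch through DTensor with no padding logic in _apply().
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```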
- `clip_grad_norm_()` / sharded gradient scaler
  - Support `DTensor`-based `clip_grad_norm_()`, e.g. implementing `torch.linalg.vector_norm()`
  - Support `DTensor`-based `clip_grad_norm_()` with `foreach=True`
  - Support `DTensor`-based sharded gradient scaler (low priority since only used for fp16) Gradient scaler for DTensor #132816
  - Outcome: We can replace existing FSDP with per-parameter FSDP in torchtrain (see the sketch below).
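A hedged sketch of how `DTensor`-based gradient clipping should look from user code, continuing from the model above; the `full_tensor()` call on the returned norm is an assumption about the sharded return type:

```python
# Sketch: gradient clipping over sharded DTensor parameters.
import torch

inp = torch.randn(8, 1024, device="cuda")
model(inp).sum().backward()

# Expected to compute the global grad norm across shards via DTensor ops
# (e.g. torch.linalg.vector_norm under the hood).
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# The returned norm may itself be a DTensor; convert for logging (assumption).
if hasattr(total_norm, "full_tensor"):
    total_norm = total_norm.full_tensor()
```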
- 2D sharded state dict
  - Implement strided sharding placement to handle `(Shard(0), Shard(0))` placements where both FSDP and TP shard dim-0 [1/N][dtensor] introduce StridedShard placement type and _split_tensor() logic #126697
  - Validate 2D sharded state dict integration with distributed checkpointing excluding checkpoint resharding (so just saving/loading at the same world size)
  - Validate 2D sharded state dict integration with distributed checkpointing including checkpoint resharding @wz337
  - Outcome: We can do 2D training including checkpointing without communication and can reshard checkpoints to different world sizes (see the sketch below).
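A hedged sketch of the intended distributed-checkpointing integration, assuming `torch.distributed.checkpoint`'s `save`/`load` entry points with `checkpoint_id` (names vary across versions):

```python
# Sketch: save/load a 2D (FSDP + TP) sharded state dict with distributed
# checkpointing. Each rank writes only its own DTensor shards.
import torch.distributed.checkpoint as dcp

state_dict = {"model": model.state_dict()}        # values are sharded DTensors
dcp.save(state_dict, checkpoint_id="/tmp/ckpt")

# Later, possibly at a different world size once checkpoint resharding lands:
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")   # loads in place
model.load_state_dict(state_dict["model"])
```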
- `distribute_tensor()`
  - Allow `distribute_tensor()` to take in a `DTensor` to simplify/robust-ify construction of the sharded parameter and sharded post-forward parameter @wz337
  - Outcome: We remove all custom `DTensor` construction code from per-parameter FSDP (`_init_sharded_param`, `_init_sharded_post_forward_param_metadata`, `to_sharded_post_forward`). A usage sketch is below.
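For reference, `distribute_tensor()` today takes a regular tensor; the work item extends it to also accept a `DTensor` input. A minimal sketch of current usage (import paths assume the pre-public `torch.distributed._tensor` module):

```python
# Sketch: shard a regular tensor across a 1D mesh with distribute_tensor().
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (8,))                 # 1D mesh over 8 ranks
full = torch.randn(1024, 1024)
sharded = distribute_tensor(full, mesh, [Shard(0)])   # shard dim-0 across mesh
# Work item: allow `full` to already be a DTensor (e.g. to re-distribute it).
```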
- Optimizer <> `DTensor`
  - Enable `foreach=True` by default for Adam/AdamW @wz337
  - Support `fused=True` for Adam/AdamW @wz337
  - Support `foreach=True` for all torch-native optimizers @wz337
  - Outcome: We have competitive optimizer performance on common optimizers without user-code changes (like passing `foreach=True` explicitly). This includes new CPU fused Adam/AdamW kernels. (See the sketch below.)
  - Support custom (user-defined) ops [DTensor] add support for custom op registration #131108
  - Outcome: We can support custom optimizers (e.g. Apex, low-precision, etc.).
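A short sketch of the target user experience with torch-native optimizers over `DTensor` parameters, continuing the model above (flags shown explicitly here; the goal is for the fast paths to apply by default):

```python
# Sketch: standard optimizer step over sharded DTensor parameters.
import torch

optim = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=True)
# Once supported for DTensor, the fused path should also work:
# optim = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

model(inp).sum().backward()
optim.step()
optim.zero_grad()
```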
- Reduce-scatter copy-in kernel/fast-path
  - Add a new aten op for chunk-cat @BoyuanFeng
  - Outcome: We can achieve competitive performance with existing FSDP even on low-compute-density workloads (e.g. recommendation models).
- HSDP
  - Add hybrid sharding (when passing a 2D `mesh` arg) @weifengpy
  - Validate HSDP sharded state dict integration with distributed checkpointing
  - Validate HSDP + TP sharded state dict integration with distributed checkpointing
  - Outcome: Existing HSDP has adoption both internally and externally. We should support it under per-parameter FSDP and migrate users (see the sketch below).
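A hedged sketch of what passing a 2D mesh for hybrid sharding might look like; the mesh dim names and the `mesh` keyword are assumptions about the eventual API:

```python
# Sketch: hybrid sharding (HSDP) via a 2D device mesh, e.g. 2 replica groups
# of 4 shards each on 8 ranks. Replicate across dim-0, shard across dim-1.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
fully_shard(model, mesh=mesh)
```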
- CPU offloading
  - Add CPU offloading to per-parameter FSDP with async H2D parameter copy and async D2H gradient copy (at least when not accumulating gradient), as sketched below
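A hedged sketch of how enabling CPU offloading might surface, assuming a `CPUOffloadPolicy`-style knob on `fully_shard` (the name, import path, and arguments are assumptions here):

```python
# Sketch: CPU offloading with per-parameter FSDP. `CPUOffloadPolicy` is an
# assumed API: parameters/optimizer state live in pinned CPU memory and are
# copied H2D around forward/backward, with gradients copied D2H.
from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))
```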
- FSDP extensions
  - Add pre/post-all-gather extensions support @awgu
  - Add pre/post-all-gather extensions for float8_experimental @awgu (dynamic scaling eager done)
  - Add pre/post-all-gather extensions for QLoRA @weifengpy
References
RFC: #114299