
[FSDP2] Eager-Mode Execution Tracker #120003

@awgu

Description


Work Items

  • Meta-device initialization / _apply() methods
    • Support initial meta-device initialization using swap_tensors path
    • Remove manual padding logic after [Feature][DTensor] Manage additional _padded_local_tensor attribute #113045 @wz337
    • Outcome: Once DTensor manages the padded storage, FSDP only needs to keep a reference to the DTensor, not its _local_tensor. We should be able to do meta-device initialization via nn.Module.to_empty() (using the swap_tensors path) followed by normal initialization ops (e.g. torch.nn.init.<...>) that run through DTensor's op dispatch, with no padding logic in _apply() (sketch below).
  • clip_grad_norm_() / sharded gradient scaler
    • Support DTensor-based clip_grad_norm_(), e.g. by implementing torch.linalg.vector_norm() for DTensor (sketch below)
    • Support DTensor-based clip_grad_norm_() with foreach=True
    • Support DTensor-based sharded gradient scaler (low priority since only used for fp16) Gradient scaler for DTensor #132816
    • Outcome: We can replace existing FSDP with per-parameter FSDP in torchtrain.
  • 2D sharded state dict
    • Implement a strided sharding placement to handle (Shard(0), Shard(0)) placements where both FSDP and TP shard dim-0 [1/N][dtensor] introduce StridedShard placement type and _split_tensor() logic #126697
    • Validate 2D sharded state dict integration with distributed checkpointing, excluding checkpoint resharding (i.e. saving and loading at the same world size) (sketch below)
    • Validate 2D sharded state dict integration with distributed checkpointing including checkpoint resharding @wz337
    • Outcome: We can do 2D training including checkpointing without communication and can reshard checkpoints to different world sizes.
  • distribute_tensor()
    • Allow distribute_tensor() to accept a DTensor input to make construction of the sharded parameter and the sharded post-forward parameter simpler and more robust (sketch below) @wz337
    • Outcome: We remove all custom DTensor construction code from per-parameter FSDP (_init_sharded_param, _init_sharded_post_forward_param_metadata, to_sharded_post_forward).
  • Optimizer <> DTensor
    • Enable foreach=True by default for Adam/AdamW @wz337
    • Support fused=True for Adam/AdamW @wz337
    • Support foreach=True for all torch-native optimizers @wz337
    • Outcome: We get competitive performance for common optimizers without user-code changes (e.g. passing foreach=True explicitly), including the new CPU fused Adam/AdamW kernels (sketch below).
    • Support custom (user-defined) ops [DTensor] add support for custom op registration #131108
    • Outcome: We can support custom optimizers (e.g. Apex, low-precision, etc.).
  • Reduce-scatter copy-in kernel/fast-path
    • Add a new aten op for chunk-cat (eager reference sketch below) @BoyuanFeng
    • Outcome: We can achieve competitive performance with existing FSDP even on low-compute-density workloads (e.g. recommendation models).
  • HSDP
    • Add hybrid sharding, enabled by passing a 2D mesh arg (sketch below) @weifengpy
    • Validate HSDP sharded state dict integration with distributed checkpointing
    • Validate HSDP + TP sharded state dict integration with distributed checkpointing
    • Outcome: Existing HSDP has adoption both internally and externally; we should support it under per-parameter FSDP and migrate users.
  • CPU offloading
    • Add CPU offloading to per-parameter FSDP with async H2D parameter copies and async D2H gradient copies (at least when not accumulating gradients) (sketch below)
  • FSDP extensions
    • Add pre/post-all-gather extension support (sketch below) @awgu
    • Add pre/post-all-gather extensions for float8_experimental @awgu (dynamic scaling eager done)
    • Add pre/post all-gather extensions for QLoRA @weifengpy
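
Code Sketches

The following are hedged, non-authoritative sketches for several of the work items above; anything not yet landed is called out as an assumption.

Meta-device initialization: a minimal sketch of the intended flow, assuming the swap_tensors path and DTensor-managed padding; the fully_shard call that would swap parameters for DTensors is an assumption and is shown commented out.

```python
import torch
import torch.nn as nn

# Enable the swap_tensors path for nn.Module._apply()-based conversions (recent PyTorch).
torch.__future__.set_swap_module_params_on_conversion(True)

with torch.device("meta"):
    model = nn.Linear(1024, 1024, bias=False)
# fully_shard(model)  # assumption: would replace the parameters with (meta) DTensors here

model.to_empty(device="cpu")  # allocate uninitialized storage via swap_tensors
for param in model.parameters():
    nn.init.trunc_normal_(param)  # with DTensor params, runs through DTensor op dispatch
```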
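
clip_grad_norm_(): a small sketch of the call the work item targets, using a plain nn.Linear as a stand-in for DTensor parameters; the DTensor behavior noted in the comments is the goal, not current behavior.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in; under per-parameter FSDP the params/grads are DTensors
model(torch.randn(4, 8)).sum().backward()

# Relies on torch.linalg.vector_norm (and its foreach path) supporting DTensor inputs.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
# With DTensor gradients, total_norm is expected to represent the global grad norm.
```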
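
2D sharded state dict: a sketch of the save/load round trip to validate (no resharding). The model is a stand-in; with a real 2D FSDP + TP module the state dict holds DTensors and this runs under torchrun. "ckpt" is a placeholder path, and the exact DCP entry points (dcp.save/load vs. the older save_state_dict/load_state_dict) vary by release.

```python
import torch.nn as nn
import torch.distributed.checkpoint as dcp

model = nn.Linear(8, 8)  # stand-in for a 2D (FSDP + TP) module whose state dict holds DTensors

# Save, then later load at the same world size (i.e. no checkpoint resharding).
state_dict = {"model": model.state_dict()}
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("ckpt"))

state_dict = {"model": model.state_dict()}
dcp.load(state_dict, storage_reader=dcp.FileSystemReader("ckpt"))
model.load_state_dict(state_dict["model"])
```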
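
distribute_tensor(): a sketch contrasting the existing API with the proposed DTensor-input form (run under torchrun; world_size is a placeholder, and the DTensor-input call is the work item's proposal, not an existing overload).

```python
import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

world_size = 8  # placeholder for the launcher's world size
mesh = init_device_mesh("cuda", (world_size,))

full_param = torch.randn(1024, 1024, device="cuda")
sharded_param = distribute_tensor(full_param, mesh, [Shard(0)])  # existing API: regular tensor in

# Proposed (assumption per the work item): accept a DTensor input so FSDP can re-distribute
# it (e.g. for the sharded post-forward parameter) instead of hand-building DTensors.
```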
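
Optimizers: a sketch of the flags involved; the point of the work item is that these should work unchanged once the parameters are DTensors.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in; under per-parameter FSDP these parameters are DTensors

optim = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=True)  # multi-tensor path
# optim = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
# ^ fused=True requires a device/build with the fused kernel (CUDA; CPU fused kernels are newer).
```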
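
Chunk-cat: an eager reference for what the reduce-scatter copy-in computes per gradient bucket; the work item's aten op would fuse these pads/copies into one kernel (this helper is illustrative, not the op itself).

```python
import torch
import torch.nn.functional as F
from typing import List

def chunk_cat_reference(grads: List[torch.Tensor], world_size: int) -> torch.Tensor:
    """Pad each flattened gradient to a multiple of world_size, chunk it, and lay out
    the reduce-scatter input so rank i's segment is the concat of every i-th chunk."""
    per_grad_chunks = []
    for grad in grads:
        flat = grad.flatten()
        pad = -flat.numel() % world_size
        if pad:
            flat = F.pad(flat, (0, pad))
        per_grad_chunks.append(flat.chunk(world_size))
    return torch.cat(
        [torch.cat([chunks[i] for chunks in per_grad_chunks]) for i in range(world_size)]
    )

out = chunk_cat_reference([torch.randn(5, 3), torch.randn(7)], world_size=4)
```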
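
HSDP: a sketch of the 2D mesh argument (run under torchrun on 8 ranks for this shape); passing the mesh to per-parameter fully_shard() is the proposed behavior, so that call is an assumption and shown commented out.

```python
from torch.distributed.device_mesh import init_device_mesh

# Outer dim replicates (like DDP), inner dim shards (like FSDP).
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
# from torch.distributed._composable.fsdp import fully_shard
# fully_shard(model, mesh=mesh)  # assumption: a 2D mesh enables hybrid sharding
```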
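
CPU offloading: a sketch of the async copies involved, using pinned CPU buffers and a side stream; the function names are illustrative, and the caller must synchronize the stream before consuming the results.

```python
import torch

def offload_grad_to_cpu(grad: torch.Tensor, stream: torch.cuda.Stream) -> torch.Tensor:
    """Async D2H gradient copy into pinned memory (sync `stream` before reading the result)."""
    cpu_grad = torch.empty(grad.shape, dtype=grad.dtype, device="cpu", pin_memory=True)
    with torch.cuda.stream(stream):
        cpu_grad.copy_(grad, non_blocking=True)
    return cpu_grad

def load_param_to_gpu(cpu_param: torch.Tensor, stream: torch.cuda.Stream) -> torch.Tensor:
    """Async H2D parameter copy; `cpu_param` should live in pinned memory."""
    with torch.cuda.stream(stream):
        return cpu_param.to("cuda", non_blocking=True)
```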
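
Pre/post-all-gather extensions: a purely illustrative skeleton of what a scaled (float8-style) parameter subclass might expose; the hook names, signatures, and reconstruction logic here are assumptions, not a finalized extension API.

```python
import torch

class ScaledTensorSketch(torch.Tensor):
    _data: torch.Tensor   # low-precision payload that FSDP would all-gather
    _scale: torch.Tensor  # metadata needed to rebuild the subclass afterward

    def fsdp_pre_all_gather(self, mesh):  # hypothetical hook name/signature
        # Hand FSDP the raw payload(s) to all-gather plus metadata to carry along.
        return (self._data,), (self._scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        # Rebuild the unsharded subclass from the gathered payload (left abstract here).
        (data,) = all_gather_outputs
        (scale,) = metadata
        raise NotImplementedError("illustrative sketch only")
```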

References

RFC: #114299
