
[FSDP2] Factored out MLPStack to de-dup code #126070

Closed
wants to merge 2 commits

Conversation


pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126070

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7e02f46 with merge base afda668:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the ci-td-distributed, oncall: distributed, and release notes: distributed (fsdp) labels May 13, 2024
awgu added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 22426f4c807175d3b95faacef419fe635268eb30
Pull Request resolved: #126070
cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 2407ba423f3202f733711861259ff2f6c30ca408
Pull Request resolved: #126070
awgu added the release notes: distributed (fsdp2) label and removed the release notes: distributed (fsdp) label May 13, 2024
awgu marked this pull request as ready for review May 13, 2024 17:48
awgu requested review from wanchaol and weifengpy May 13, 2024 17:48
awgu added the ciflow/trunk label May 13, 2024
```
device_mesh=tp_mesh,
# Leave the layer norm as implicitly replicated
parallelize_plan={
    # Pass `use_local_output=False` to keep as DTensor to preserve
```
Contributor

Maybe we should also add `SequenceParallel(sequence_dim=0)` to the `parallelize_plan` to ensure the model params are all 1D DTensors. @wz337 was hitting an issue when enabling the foreach optimizer by default: we would need the params to be 1D DTensors on the TP mesh dim so that the foreach ops receive correct 2D sharding inputs.
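
For reference, a minimal sketch (not the test's actual code) of what such a plan could look like, assuming hypothetical submodule names `norm`, `in_proj`, and `out_proj` and an existing `model` and `tp_mesh`:

```
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# Sketch only: shard the norm weight with SequenceParallel(sequence_dim=0) so it
# becomes a 1D DTensor on the TP mesh rather than an implicitly replicated tensor.
parallelize_module(
    model,
    device_mesh=tp_mesh,
    parallelize_plan={
        "norm": SequenceParallel(sequence_dim=0),
        "in_proj": ColwiseParallel(use_local_output=False),
        "out_proj": RowwiseParallel(use_local_output=False),
    },
)
```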

Contributor Author

I think originally I had it like this to test the implicitly replicated norm weight (as opposed to having it explicitly be a DTensor). If I am following correctly, we can migrate to `SequenceParallel(sequence_dim=0)` later only if needed.

Contributor

I think implicit replication won't work for 2D cases, as it doesn't know how to implicitly replicate a 1D DTensor to 2D (it currently only works for torch.Tensor to DTensor). I will turn off foreach for the 2D test and submit a follow-up PR to update this.
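
For context, turning off foreach for the test's optimizer would look roughly like this; the optimizer class and hyperparameters here are illustrative, not the test's actual choices:

```
import torch

# Hypothetical: build the optimizer with the foreach implementation disabled so
# it does not have to handle mixed 1D/2D DTensor parameters in a single foreach op.
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, foreach=False)
```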

awgu (Contributor Author) commented May 14, 2024

@pytorchbot merge

pytorchmergebot (Collaborator)
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request May 14, 2024
This simplifies the test a bit.

**Context**
Option 1: The ref model is data parallel. Each rank's ref model receives its local batch, and we manually all-reduce the gradients and divide them by world size to match DDP/FSDP semantics.
Option 2: The ref model is not data parallel. Each rank's ref model receives the same global batch, and we manually divide the ref model's gradients by world size to match DDP/FSDP semantics. (All ranks have the same ref model and the same global batch.)

All of our other unit tests follow Option 1, which is simpler and a more direct comparison against our claimed semantics. This PR switches the gradient accumulation test from Option 2 to Option 1.
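
A rough Option 1 sketch; `ref_model` and `local_batch` are illustrative names, not the test's actual code:

```
import torch.distributed as dist

# Option 1 sketch: the ref model sees only the local batch, so its gradients
# are manually averaged across ranks to match DDP/FSDP mean-reduction semantics.
ref_model(local_batch).sum().backward()
world_size = dist.get_world_size()
for param in ref_model.parameters():
    if param.grad is not None:
        dist.all_reduce(param.grad)   # sum gradients across ranks
        param.grad.div_(world_size)   # turn the sum into a mean
```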

Pull Request resolved: #126161
Approved by: https://github.com/wanchaol
ghstack dependencies: #126067, #126070
pytorchmergebot pushed a commit that referenced this pull request May 16, 2024
**Context**
For FSDP, gradient accumulation across microbatches has two flavors: (1) reduce-scatter or (2) no reduce-scatter. (1) incurs the collective per microbatch backward but saves gradient memory (storing the sharded gradients), while (2) avoids the communication but uses more gradient memory (storing the unsharded gradients).
- FSDP2 offers (1) without any intervention. The user should simply make sure to run the optimizer step after `K` microbatches for `K > 1`.
- FSDP2 offers (2) via `module.set_requires_gradient_sync()` (e.g., `module.set_requires_gradient_sync(is_last_microbatch)`).

For HSDP, since we reduce-scatter and then all-reduce, we have additional flexibility and get three flavors: (1) reduce-scatter and all-reduce, (2) reduce-scatter but no all-reduce, and (3) no reduce-scatter and no all-reduce. This PR adds support for (2).
- FSDP2 offers (1) without any intervention, as mentioned above.
- FSDP2 offers (3) via `module.set_requires_gradient_sync()`, as mentioned above.
- FSDP2 offers (2) via `module.set_requires_all_reduce()`, similar to `set_requires_gradient_sync()`.

**Overview**
For HSDP, to reduce-scatter but not all-reduce during gradient accumulation, the user can do something like:
```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    model.set_requires_all_reduce(is_last_microbatch)
    # Run forward/backward
```
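
For comparison, a sketch of flavor (2) for plain FSDP, skipping reduce-scatter until the last microbatch; the loss computation and optimizer step here are illustrative:

```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    # Keep unsharded gradients (no reduce-scatter) until the last microbatch
    model.set_requires_gradient_sync(is_last_microbatch)
    model(microbatch).sum().backward()
optim.step()
optim.zero_grad()
```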

This PR also makes the minor change of making the `recurse: bool` argument in these setter methods keyword-only.
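
For example, after this change `recurse` must be passed by keyword (an illustrative call, not taken from the PR):

```
# Apply the setting to this module only, without recursing into FSDP submodules
model.set_requires_gradient_sync(False, recurse=False)
```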

**Developer Notes**
We choose to implement this by saving the partial reduce output to the `FSDPParamGroup` for simplicity, assuming that the set of parameters that receive gradients does not change across microbatches. An alternative would be to view into the partial reduce output per parameter and save that view on each parameter. We prefer to avoid this alternative for now because it adds complexity: extra viewing when saving the partial reduce output onto each parameter, when accumulating into those views, and when accumulating back into the last microbatch's reduce output.

Pull Request resolved: #126166
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #126067, #126070, #126161
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
github-actions bot deleted the gh/awgu/582/head branch June 14, 2024 01:54
Labels: ci-td-distributed, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp2)