[FSDP2] Factored out `MLPStack` to de-dup code (#126070)
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/126070
✅ No failures as of commit 7e02f46 with merge base afda668
```
device_mesh=tp_mesh,
# Leave the layer norm as implicitly replicated
parallelize_plan={
    # Pass `use_local_output=False` to keep as DTensor to preserve
```
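For context, a minimal sketch of a TP plan along these lines (the submodule names `"in_proj"` and `"out_proj"` are assumptions for illustration, not the actual test's names):

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical plan fragment; submodule names are illustrative only.
parallelize_plan = {
    # `use_local_output=False` keeps the module outputs as DTensors
    # rather than converting them back to plain local tensors.
    "in_proj": ColwiseParallel(use_local_output=False),
    "out_proj": RowwiseParallel(use_local_output=False),
    # The layer norm is omitted from the plan, so its weight stays a
    # plain torch.Tensor and is implicitly replicated.
}
```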
Maybe we should also add `SequenceParallel(sequence_dim=0)` to the parallelize_plan to ensure the model params are all 1D DTensors. @wz337 was hitting an issue when enabling the foreach optimizer by default: we would need params to be 1D DTensors on the TP mesh dim so that the foreach ops receive correctly 2D-sharded inputs.
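The suggestion above might look like the following plan fragment (a hedged sketch; the `"norm"` key is an assumed submodule name):

```python
from torch.distributed.tensor.parallel import SequenceParallel

# Hypothetical: adding the norm to the plan so its weight becomes an
# explicit 1D DTensor on the TP mesh dim (instead of an implicitly
# replicated plain tensor), which the foreach optimizer ops can consume.
extra_plan = {
    "norm": SequenceParallel(sequence_dim=0),
}
```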
I think I originally had it like this to test the implicitly replicated norm weight (as opposed to making it explicitly a DTensor). If I am following correctly, we can migrate to `SequenceParallel(sequence_dim=0)` later if needed.
I think implicit replication won't work for 2D cases, as it doesn't know how to implicitly replicate a 1D DTensor to 2D (it currently only works for `torch.Tensor` to DTensor). I will turn off foreach for the 2D test and submit a follow-up PR to update this.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
This simplifies the test a bit.

**Context**

Option 1: The ref model is data parallel. Each rank's ref model receives its local batch. We manually all-reduce the gradients and divide them by world size to match DDP/FSDP semantics.

Option 2: The ref model is not data parallel. Each rank's ref model receives the same global batch. We manually divide the ref model's gradients by world size to match DDP/FSDP semantics. (Note that all ranks have the same ref model and the same global batch.)

All of our other unit tests follow Option 1, which is simpler and a more direct comparison against our claimed semantics. This PR switches the gradient accumulation test from Option 2 to Option 1.

Pull Request resolved: #126161
Approved by: https://github.com/wanchaol
ghstack dependencies: #126067, #126070
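The Option 1 semantics can be illustrated with a small pure-Python sketch (no torch, no process group; the "ranks" are just lists): each simulated rank computes gradients on its local batch, then we all-reduce (sum) across ranks and divide by world size, matching DDP/FSDP gradient-averaging semantics.

```python
# Hypothetical simulation of Option 1's manual gradient averaging.
def all_reduce_mean(per_rank_grads):
    """Sum each parameter's gradient across ranks, then divide by world size."""
    world_size = len(per_rank_grads)
    num_params = len(per_rank_grads[0])
    return [
        sum(rank[p] for rank in per_rank_grads) / world_size
        for p in range(num_params)
    ]

# Two ranks, each holding gradients for two parameters from its local batch.
per_rank_grads = [[1.0, 2.0], [3.0, 4.0]]
print(all_reduce_mean(per_rank_grads))  # [2.0, 3.0]
```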
**Context**

For FSDP, gradient accumulation across microbatches has two flavors: (1) reduce-scatter or (2) no reduce-scatter. (1) incurs the collective per microbatch backward but saves gradient memory (storing the sharded gradients), while (2) avoids the communication but uses more gradient memory (storing the unsharded gradients).

- FSDP2 offers (1) without any intervention. The user should simply make sure to run the optimizer step after `K` microbatches for `K > 1`.
- FSDP2 offers (2) via `module.set_requires_gradient_sync()` (e.g. `module.set_requires_gradient_sync(is_last_microbatch)`).

For HSDP, since we reduce-scatter and then all-reduce, we have additional flexibility and get three flavors: (1) reduce-scatter and all-reduce, (2) reduce-scatter but no all-reduce, and (3) no reduce-scatter and no all-reduce. This PR adds support for (2).

- FSDP2 offers (1) without any intervention, as mentioned above.
- FSDP2 offers (3) via `module.set_requires_gradient_sync()`, as mentioned above.
- FSDP2 offers (2) via `module.set_requires_all_reduce()`, similar to `set_requires_gradient_sync()`.

**Overview**

For HSDP, to reduce-scatter but not all-reduce during gradient accumulation, the user can do something like:

```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    model.set_requires_all_reduce(is_last_microbatch)
    # Run forward/backward
```

This PR also makes the minor change of making the `recurse: bool` argument in these setter methods keyword-only.

**Developer Notes**

We choose to implement this by saving the partial reduce output to the `FSDPParamGroup` for simplicity, where we assume that the set of parameters that receive gradients does not change across microbatches. An alternative would be to view into the partial reduce output per parameter and save the view to each parameter. We prefer to avoid this alternative for now because it introduces more complexity: extra viewing when saving the partial reduce output to each parameter, accumulating into them, and accumulating back into the last microbatch's reduce output.

Pull Request resolved: #126166
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #126067, #126070, #126161
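The two gradient-accumulation flavors from the Context section can be contrasted with a pure-Python sketch (no torch; the "reduction" is just a sum over simulated ranks). Both flavors produce the same final gradient; they differ only in when the communication happens.

```python
# Hypothetical simulation: each inner list is one microbatch's per-rank grad.
def reduce_each_microbatch(microbatch_grads, world_size):
    """Flavor (1): reduce after every microbatch, accumulate the reduced grads."""
    acc = 0.0
    for per_rank in microbatch_grads:
        acc += sum(per_rank) / world_size  # one "reduce-scatter" per microbatch
    return acc

def reduce_only_last(microbatch_grads, world_size):
    """Flavor (2): accumulate unreduced grads locally, reduce once at the end."""
    local_sums = [0.0] * world_size
    for per_rank in microbatch_grads:
        for r in range(world_size):
            local_sums[r] += per_rank[r]  # no communication per microbatch
    return sum(local_sums) / world_size  # single reduction at the end

# 3 microbatches, 2 ranks: both flavors yield the same final gradient.
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
assert reduce_each_microbatch(grads, 2) == reduce_only_last(grads, 2)
```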
Stack from ghstack (oldest at bottom):

- `set_all_reduce_gradients=False` for HSDP (#126166)
- Factored out `MLPStack` to de-dup code (#126070)
- `CommDebugMode` in grad acc test (#126067)

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k