[FSDP2] Added APIs for explicit fwd/bwd prefetching #128884
Conversation
This is great!
Curious to know the use case. This will only help in the case of FSDPModules that have unbalanced compute, so that we can use the compute of a larger FSDPModule to overlap the prefetch of smaller ones.
Also, will your extension of wrapping `List[nn.Module]` eliminate the need for this?
@pytorchbot merge
This PR adds `set_post_optim_event`, which allows power users to provide their own CUDA event that is recorded after the optimizer step, for the FSDP root module to wait the all-gather streams on.

```
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
```

By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models. To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event for the all-gather streams to wait on.

We include both a correctness test (`test_fully_shard_training.py`) and an overlap test (`test_fully_shard_overlap.py`).

One possible way to use the API is to register a post-step hook on the optimizer. For example: https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552

Pull Request resolved: #128975
Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy
ghstack dependencies: #128884
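As an illustration (not the code from the linked test), one way to do this is with an optimizer post-step hook that records an event on the current stream right after the step and hands it to the root module. The `root_module` and `optim` names below are placeholders:

```python
import torch

# Sketch, assuming `root_module` is the FSDP root module (an FSDPModule after
# applying fully_shard) and `optim` is its optimizer; requires CUDA.
def post_optim_hook(optimizer, args, kwargs) -> None:
    # Record an event right after optimizer.step() so that FSDP's all-gather
    # streams wait only on the optimizer step, not on unrelated work that may
    # run afterward on the current stream.
    post_optim_event = torch.cuda.current_stream().record_event()
    root_module.set_post_optim_event(post_optim_event)

optim.register_step_post_hook(post_optim_hook)
```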
Stack from ghstack (oldest at bottom):
This PR adds two APIs, `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`, to enable explicit forward/backward all-gather prefetching, respectively.

Motivation
FSDP2 implements reasonable defaults for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order.
However, there may be cases where, with expert knowledge, we can reduce communication bubbles by moving all-gathers manually. One way to enable such behavior is to expose prefetching limits, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. IMHO, this leans toward easy, not simple (see PyTorch design principles).
The crux of the problem is that there may be special cases where manual intervention can give better performance. Exposing a prefetching limit and allowing users to pass a value >1 just smooths over the problem, since such a limit would generally apply over the entire model even though it possibly should not. Then, expert users will find a specific all-gather for which they want to deviate from this limit, and there is little we can do.
Thus, we instead choose to expose the most primitive extension point: namely, every `FSDPModule` gives an opportunity to prefetch other all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. Implementing the prefetch limit can be done using this extension point (e.g. record the post-forward order yourself using forward hooks, iterate over that order, and call the `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` APIs).
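A rough sketch of how these APIs might be used (the model structure, layer count, and prefetch distance here are hypothetical, and this assumes `torch.distributed` and a device mesh are already initialized):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Hypothetical model: a stack of blocks, each wrapped as its own FSDPModule,
# plus a root wrap.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
for layer in model:
    fully_shard(layer)
fully_shard(model)

layers = list(model)
for i, layer in enumerate(layers):
    # Forward: while layer i runs, prefetch the all-gathers of the next two layers.
    layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])
    # Backward: while layer i runs backward, prefetch the previous two layers'
    # all-gathers (in reverse post-forward order).
    layer.set_modules_to_backward_prefetch(list(reversed(layers[max(0, i - 2) : i])))
```

In this sketch, each `FSDPModule` is told explicitly which other modules' all-gathers to issue while it is running, which is the extension point described above.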
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k
Differential Revision: D58700346