
[FSDP2] Added APIs for explicit fwd/bwd prefetching #128884

Closed
wants to merge 4 commits

Conversation

awgu
Contributor

@awgu awgu commented Jun 17, 2024

Stack from ghstack (oldest at bottom):

This PR adds two APIs set_modules_to_forward_prefetch and set_modules_to_backward_prefetch to enable explicit forward/backward all-gather prefetching, respectively.

def set_modules_to_forward_prefetch(self, modules: List[FSDPModule]) -> None
def set_modules_to_backward_prefetch(self, modules: List[FSDPModule]) -> None
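As a purely illustrative sketch of the call pattern, assuming a model that is a stack of sharded transformer blocks (the `FSDPModule` class below is a hypothetical stand-in that only records its targets; the real class is provided by FSDP2):

```python
from typing import List

class FSDPModule:
    """Hypothetical stand-in for FSDP2's FSDPModule that only records
    its prefetch targets (the real class is provided by FSDP2)."""
    def __init__(self, name: str):
        self.name = name
        self.forward_prefetch: List["FSDPModule"] = []
        self.backward_prefetch: List["FSDPModule"] = []

    def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
        # When this module runs forward, also prefetch these modules' all-gathers.
        self.forward_prefetch = list(modules)

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        # When this module runs backward, also prefetch these modules' all-gathers.
        self.backward_prefetch = list(modules)

# Suppose the model is a stack of four sharded transformer blocks.
blocks = [FSDPModule(f"block{i}") for i in range(4)]

# Forward runs block0 -> block3: have each block prefetch the next one.
for cur, nxt in zip(blocks[:-1], blocks[1:]):
    cur.set_modules_to_forward_prefetch([nxt])

# Backward runs block3 -> block0: have each block prefetch the previous one.
for cur, prev in zip(blocks[1:], blocks[:-1]):
    cur.set_modules_to_backward_prefetch([prev])

print(blocks[0].forward_prefetch[0].name)   # block1
print(blocks[2].backward_prefetch[0].name)  # block1
```

Under this wiring, each block's forward kicks off the next block's all-gather, and each block's backward kicks off the previous block's, matching the reverse post-forward order.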

Motivation
FSDP2 implements reasonable defaults for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order.

However, there may be cases where, with expert knowledge, we can reduce communication bubbles by moving all-gathers manually. One way to expose such behavior is to expose prefetching limits, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. IMHO, this leans toward easy, not simple (see PyTorch design principles).

The crux of the problem is that there may be special cases where manual intervention can give better performance. Exposing a prefetching limit and allowing users to pass a value >1 just smooths over the problem, since such a limit would generally apply over the entire model even though it possibly should not. Then, an expert user may find a specific all-gather for which they want to deviate from this limit, and there is little we can do.

Thus, we instead choose to expose the most primitive extension point: namely, every FSDPModule gives an opportunity to prefetch other all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. Implementing the prefetch limit can be done using this extension point (e.g. record the post-forward order yourself using forward hooks, iterate over that order, and call the set_modules_to_forward_prefetch / set_modules_to_backward_prefetch APIs).
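For example, emulating a backward prefetch limit on top of this extension point might look like the following sketch (the `FSDPModule` stand-in and the `apply_backward_prefetch_limit` helper are hypothetical, written here for illustration only):

```python
from typing import List

class FSDPModule:
    """Hypothetical stand-in that only records backward prefetch targets."""
    def __init__(self, name: str):
        self.name = name
        self.backward_prefetch: List["FSDPModule"] = []

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        self.backward_prefetch = list(modules)

def apply_backward_prefetch_limit(
    post_forward_order: List[FSDPModule], limit: int
) -> None:
    """Emulate a backward 'prefetch limit' with the explicit API.

    Backward runs in reverse post-forward order, so the module at index i
    should prefetch up to `limit` modules that ran just before it in forward.
    """
    for i, module in enumerate(post_forward_order):
        targets = post_forward_order[max(0, i - limit):i][::-1]
        module.set_modules_to_backward_prefetch(targets)

# post_forward_order would come from forward hooks on the real modules.
mods = [FSDPModule(f"m{i}") for i in range(5)]
apply_backward_prefetch_limit(mods, limit=2)
print([t.name for t in mods[4].backward_prefetch])  # ['m3', 'm2']
print([t.name for t in mods[1].backward_prefetch])  # ['m0']
```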

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

Differential Revision: D58700346

@pytorch-bot pytorch-bot bot added the `oncall: distributed` and `release notes: distributed (fsdp)` labels Jun 17, 2024

pytorch-bot bot commented Jun 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128884

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 25fff88 with merge base 24443fe:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: eaa86396ed1254451dbd453616b0ffccd133dabe
Pull Request resolved: #128884
cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: 8f387bf5eb645d3d7417eecda01182dfc869e21f
Pull Request resolved: #128884
cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: 86e664adafe8e22d99a3809ba6712d54c25dc5e5
Pull Request resolved: #128884
@awgu awgu added the `release notes: distributed (fsdp2)` label and removed the `release notes: distributed (fsdp)` label Jun 17, 2024
This PR adds two APIs `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch` to enable explicit forward/backward all-gather prefetching, respectively.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: a41b1b6d326513a78dfccca5a7dfbb1a503c4713
Pull Request resolved: #128884
@awgu awgu marked this pull request as ready for review June 17, 2024 20:56
@awgu
Contributor Author

awgu commented Jun 17, 2024

Example of prefetching 1 FSDP module explicitly in backward:
[profiler trace screenshot]

Example of prefetching 2 FSDP modules explicitly in backward:
[profiler trace screenshot]

@awgu
Contributor Author

awgu commented Jun 17, 2024

Example of prefetching 1 FSDP module explicitly in forward:
[profiler trace screenshot]

Example of prefetching 2 FSDP modules explicitly in forward:
[profiler trace screenshot]

Example of prefetching 4 FSDP modules explicitly in forward:
[profiler trace screenshot]

Contributor

@ckluk2 ckluk2 left a comment


This is great!

@awgu
Contributor Author

awgu commented Jun 17, 2024

@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Jun 17, 2024
Contributor

@sanketpurandare sanketpurandare left a comment


Curious to know the use case. This will only help in case of FSDPModules that have unbalanced compute. So that we can use the compute of a larger FSDPModule to overlap the prefetch of smaller ones.

Also, will your extension of wrapping List[nn.Module] eliminate the need for this?

@awgu
Contributor Author

awgu commented Jun 18, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jun 18, 2024
This PR adds `set_post_optim_event`, which allows power users to provide their own CUDA event, recorded after the optimizer step, on which the FSDP root module's all-gather streams will wait.
```
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
```
By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models.

To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event for the all-gather streams to wait on.

We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`).

---

One possible way to use the API is to register a post-step hook on the optimizer. For example:
https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552
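A self-contained sketch of that hook pattern, using plain-Python stand-ins for the CUDA event, the FSDP root, and the optimizer (all three stub classes are hypothetical; in real code the hook would call `torch.cuda.Event().record()` and pass the event to the actual root module):

```python
from typing import Any, Callable, List

class _Event:
    """Stand-in for torch.cuda.Event (hypothetical mock)."""
    def __init__(self):
        self.recorded = False
    def record(self):
        # Real code would record the event on the current CUDA stream.
        self.recorded = True
        return self

class _Root:
    """Stand-in for the FSDP root module exposing set_post_optim_event."""
    def __init__(self):
        self.post_optim_event = None
    def set_post_optim_event(self, event: _Event) -> None:
        self.post_optim_event = event

class _Optimizer:
    """Minimal optimizer mimicking torch.optim.Optimizer.register_step_post_hook."""
    def __init__(self):
        self._post_hooks: List[Callable[..., Any]] = []
    def register_step_post_hook(self, hook: Callable[..., Any]) -> None:
        self._post_hooks.append(hook)
    def step(self) -> None:
        # ... parameter update would happen here ...
        for hook in self._post_hooks:
            hook(self, (), {})

root = _Root()
optim = _Optimizer()

def _post_step_hook(optimizer, args, kwargs):
    # Record an event right after the step and hand it to the FSDP root,
    # so the all-gather streams wait only on the optimizer step.
    root.set_post_optim_event(_Event().record())

optim.register_step_post_hook(_post_step_hook)
optim.step()
print(root.post_optim_event.recorded)  # True
```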

Pull Request resolved: #128975
Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy
ghstack dependencies: #128884
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp2) release notes category
5 participants