REDO of dropout support for mem eff #102038 #103704
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103704
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3c3a39b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
1 similar comment
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from dd9356b to dbb2d16
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from dbb2d16 to 053644c
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 053644c to 37640a2
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 37640a2 to c18880c
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from c18880c to 0ffaf1f
This pull request was exported from Phabricator. Differential Revision: D46778637
Summary: bypass-github-export-checks

This is a new PR with the changes from pytorch#102038 + pytorch#103201, plus namespacing changes to fix a bug.

This PR builds off of:
- pytorch#101847
- pytorch#100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so, roughly 3 changes were made:
- Update sdpa dispatching to allow inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention passes the rng state from forward to backward, in order to enable cuda_graph support
- Fix a bug in the kernel that produced incorrect gradients for num_keys > 64 with dropout and causal masking set (facebookresearch/xformers#755)

cc albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 anijain2305 aakhundov bertmaher

Pull Request resolved: pytorch#103704
Reviewed By: cpuhrsch
Differential Revision: D46778637
Pulled By: drisspg
fbshipit-source-id: 79cb8d9641b00c27b8ec669be1398b1bf36e7a6e
This pull request was exported from Phabricator. Differential Revision: D46778637
Force-pushed from 0ffaf1f to 5ab17c4
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
I bisected a 4% slowdown on huggingface to this PR (specifically using ElectraForQuestionAnswering). @drisspg any ideas what could cause this? Before this PR: baseline 102.8ms, 36.2ms with PT2
@davidberard98 This is likely because the model in question is using dropout. Previously, memory-efficient attention would not have been dispatched to; this PR now allows it. A good way to verify this would be to explicitly select which kernel runs: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#explicit-dispatcher-control
This defines the kernel ordering. I have one more PR to add attn_bias support to mem_eff, and then I was going to see if any updates are needed to the priority order.
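For concreteness, here is a minimal sketch of the explicit-dispatcher-control approach from that tutorial, using the torch.backends.cuda.sdp_kernel context manager (the interface exposed around the time of this PR); the shapes and dropout_p below are just placeholders:

```python
import torch
import torch.nn.functional as F

# Placeholder shapes: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

# Restrict dispatch to the FlashAttention backend only; with the other
# backends disabled, an unsupported input raises instead of silently
# falling back to a slower kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
```

Pinning the backend like this is also a quick way to check whether the regression comes from the dispatcher now preferring mem_eff + dropout over FlashAttention.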
@drisspg do we expect the other PR to recover perf? A 5% performance regression is significant. FWIW the memory compression stayed stable with this PR. I guess our heuristics on which backend to select didn't work great here?
@eellison I do not think the other PR will recover the perf. I suspect the heuristic is choosing the new mem_eff + dropout path now that it exists, when the more performant path is FlashAttn. I am planning, though, to create a third PR once the mem_eff attn_bias PR lands, to re-run the perf sweep and see if there are more modifications to be made to the heuristic.
@drisspg you mentioned selecting the kernel using the tutorial you linked - do you recommend this, or do you think it's better to rely on the heuristics (and wait for your updates to the heuristics)?
@davidberard98 It depends. For users who want maximum performance, I think they should always measure the 3 different options on their specific shapes and hardware and use the dispatcher to pick the fastest one (see the sketch below). I assume this only affects PT2 models because of the sdpa re-writer, and for that there are two options: similar to matmul, we could add an autotune pass? Otherwise I think we need to rely on the heuristic.
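As a rough illustration of measuring the three options, something like the following could time each backend on a representative shape (a sketch only; shapes, dropout_p, and iteration counts are arbitrary, and the mem_eff-with-dropout path assumes a build that already includes this PR):

```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q, k, v = (torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Enable exactly one backend at a time so each timing is attributable.
backends = {
    "flash":   dict(enable_flash=True,  enable_math=False, enable_mem_efficient=False),
    "mem_eff": dict(enable_flash=False, enable_math=False, enable_mem_efficient=True),
    "math":    dict(enable_flash=False, enable_math=True,  enable_mem_efficient=False),
}

for name, flags in backends.items():
    with sdp_kernel(**flags):
        for _ in range(3):  # warm-up iterations
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(20):
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        end.record()
        torch.cuda.synchronize()
        print(f"{name}: {start.elapsed_time(end) / 20:.3f} ms per call")
```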
Where are the current heuristics?
Linked up above in my response to David
cc @Chillee who updated the heuristics recently |
This is a new PR with the changes from #102038 + #103201, plus namespacing changes to fix a bug.
Summary
This PR builds off of:
- #101847
- #100583
It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so, roughly 3 changes were made (a minimal usage sketch follows this list):
- Update sdpa dispatching to allow inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention passes the rng state from forward to backward, in order to enable cuda_graph support
- Fix a bug in the kernel that produced incorrect gradients for num_keys > 64 with dropout and causal masking set (facebookresearch/xformers#755)
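A minimal sketch of the combination this PR newly supports, gradients plus dropout routed through the memory-efficient backend (shapes and dropout_p are illustrative; the backend is pinned via torch.backends.cuda.sdp_kernel only to make the dispatch explicit):

```python
import torch
import torch.nn.functional as F

# Inputs that require grad, with dropout enabled: the combination this PR
# newly routes to the memory-efficient kernel.
q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False,
                                    enable_mem_efficient=True):
    # seq_len > 64 with dropout and causal masking exercises the gradient
    # bug fixed in this PR (facebookresearch/xformers#755).
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.2, is_causal=True)
    out.sum().backward()  # backward consumes the rng state saved by the forward
```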
cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @anijain2305 @aakhundov @bertmaher