
Conversation

@drisspg (Contributor) commented Jun 15, 2023

This is a new PR with the changes from #102038 and #103201, plus namespacing changes to fix a bug.

Summary

This PR builds off of:

  • #101847
  • #100583

It specifically adds dropout support to the memory-efficient attention kernel. In the process, roughly three changes were made:

  • Update SDPA dispatching to allow inputs requiring grad to be sent to efficient attention (see the sketch after this list)
  • Update how memory-efficient attention passes the RNG state from forward to backward, in order to enable CUDA graph support
  • Fix a kernel bug that produced incorrect gradients for num_keys > 64 with dropout and causal masking set; use the absolute index for RNG state indexing (facebookresearch/xformers#755)
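
A minimal sketch (not from this PR) of what the dispatch change enables: scaled_dot_product_attention with dropout_p > 0 on inputs that require grad can now be routed to the memory-efficient backend. It uses the torch.backends.cuda.sdp_kernel context manager from the SDPA tutorial to pin the backend; the shapes, dtype, and dropout probability are illustrative only.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q, k, v = (
    torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

# Restrict dispatch to the memory-efficient kernel only.
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
    out.sum().backward()  # backward now runs through the dropout-enabled mem-efficient kernel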

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @anijain2305 @aakhundov @bertmaher

@pytorch-bot bot commented Jun 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103704

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3c3a39b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment

@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from dd9356b to dbb2d16 on June 16, 2023 01:15
@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from dbb2d16 to 053644c on June 16, 2023 04:07
@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from 053644c to 37640a2 on June 16, 2023 18:41
@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from 37640a2 to c18880c on June 21, 2023 16:29
@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Jun 21, 2023
@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from c18880c to 0ffaf1f on June 22, 2023 23:56
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D46778637

Summary:
bypass-github-export-checks

This is a new PR with the changes from pytorch#102038 and pytorch#103201, plus namespacing changes to fix a bug.

This PR builds off of:
- pytorch#101847
- pytorch#100583

It specifically adds dropout support to the memory-efficient attention kernel. In the process, roughly three changes were made:
- Update SDPA dispatching to allow inputs requiring grad to be sent to efficient attention
- Update how memory-efficient attention passes the RNG state from forward to backward, in order to enable CUDA graph support
- Fix a kernel bug that produced incorrect gradients for num_keys > 64 with dropout and causal masking set (facebookresearch/xformers#755)

cc albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 anijain2305 aakhundov bertmaher

Pull Request resolved: pytorch#103704

Reviewed By: cpuhrsch

Differential Revision: D46778637

Pulled By: drisspg

fbshipit-source-id: 79cb8d9641b00c27b8ec669be1398b1bf36e7a6e
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D46778637

@drisspg force-pushed the redo_dropout_support_for_mem_eff branch from 0ffaf1f to 5ab17c4 on June 23, 2023 00:06
@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@davidberard98 (Contributor) commented Jun 27, 2023

I bisected a 4% slowdown on huggingface to this PR (specifically using ElectraForQuestionAnswering). @drisspg any ideas what could cause this?

before this: baseline 102.8ms, 36.2ms PT2
after this: baseline 102.8ms, 41.9ms PT2

python benchmarks/dynamo/huggingface.py --output=bisect.csv --training -dcuda --no-skip --dashboard --cold_start_latency --inductor --performance --amp --only ElectraForQuestionAnswering

@drisspg (Contributor, Author) commented Jun 27, 2023

@davidberard98 This is likely because the model in question uses dropout. Previously, memory-efficient attention would not have been dispatched to; this PR allows that now. A good way to verify would be to explicitly select the kernel: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#explicit-dispatcher-control

std::array<SDPBackend, num_backends> priority_order(sdp_params params) {

This defines the kernel ordering. I have one more PR to add attn_bias support to mem_eff, and then I was going to see if any updates are needed to the priority order.
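
For reference, a minimal sketch of the explicit dispatcher control mentioned above, using the torch.backends.cuda.sdp_kernel context manager from the linked tutorial. It forces a single backend at a time so a slowdown can be attributed to the backend choice; the shapes and dropout probability are placeholders, not values from this PR.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Force the flash-attention kernel only.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)

# Force the memory-efficient kernel only.
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out_mem_eff = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)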

@eellison (Contributor)

@drisspg do we expect the other PR to recover the perf? A 5% performance regression is significant. FWIW, the memory compression stayed stable with this PR. I guess our heuristics for which backend to select didn't work great here?

@drisspg (Contributor, Author) commented Jun 28, 2023

@eellison I do not think the other PR will recover the perf. I suspect the heuristic is choosing the new mem_eff + dropout path now that it exists, when the more performant path is FlashAttention.

I am planning, though, to create a third PR once the mem_eff attn PR lands, to re-run the perf sweep and see if more modifications to the heuristic are needed.

@davidberard98 (Contributor)

@drisspg you mentioned selecting the kernel using the tutorial you linked - do you recommend this, or do you think it's better to rely on the heuristics (and wait for your updates to see if the heuristics recover the perf)?

@drisspg (Contributor, Author) commented Jun 28, 2023

@davidberard98 It depends. For users who want maximum performance, I think they should always measure the three different options on their specific shapes and hardware and use the dispatcher to pick the fastest one. I assume this only affects PT2 models because of the SDPA rewriter, and for that there are two options: similar to matmul, we could add an autotune pass? Otherwise I think we need to rely on the heuristic.
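
To illustrate the "measure on your specific shapes and hardware" suggestion, here is a rough timing sketch that runs each backend in isolation via torch.backends.cuda.sdp_kernel. The shapes, dropout probability, and iteration counts are placeholders, not values from this PR.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

def time_backend(enable_flash, enable_mem_efficient, enable_math, iters=50):
    # Placeholder problem size: (batch, num_heads, seq_len, head_dim)
    q, k, v = (
        torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with sdp_kernel(enable_flash=enable_flash, enable_math=enable_math,
                    enable_mem_efficient=enable_mem_efficient):
        for _ in range(5):  # warmup
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print("flash        :", time_backend(True, False, False))
print("mem_efficient:", time_backend(False, True, False))
print("math         :", time_backend(False, False, True))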

@eellison (Contributor)

Where are the current heuristics?

@drisspg (Contributor, Author) commented Jun 28, 2023

Where are the current heuristics?

Linked above in my response to David.

@eellison (Contributor)

cc @Chillee who updated the heuristics recently
