Disable FlashAttention for is_causal=True when seqlen q not equal kv #111007
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111007
Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit be9bb0e with merge base 652f4c6.
(This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from 5dfadec to 044a820
Force-pushed from e9f3845 to be9bb0e
@pytorchbot merge
Merge started: the change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -f "unrelated failures"
@pytorchbot merge
Merge started: the change will be merged once all checks pass (ETA 0-4 hours).
Commit message from follow-up PR #111886, which referenced this pull request:
Summary
We were restricted from updating to the newest version of FlashAttention because of the changes to is_causal described in #108108. Prior to that PR we landed #111007, which enabled us to update beyond 9e5e8bc91e on FlashAttention v2. With #111886 we have updated to this commit: Dao-AILab/flash-attention@02ac572, i.e. tag 2.3.2.
Plans
Following that PR I plan to work more on #110681 in order to expose a CausalVariant attn_mask, with the potential for also exposing a kvcache attn_mask.
Pull Request resolved: #111886
Approved by: https://github.com/cpuhrsch
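The CausalVariant attn_mask mentioned above did not exist at the time of this PR. As a rough sketch only (not the API proposed in #110681), a lower-right-aligned causal mask can already be expressed as an explicit boolean attn_mask; note that explicit masks are not served by the FlashAttention backend, so this routes to the memory-efficient or math kernels.

```python
import torch
import torch.nn.functional as F

def lower_right_causal_mask(q_len: int, kv_len: int, device=None) -> torch.Tensor:
    # Bool mask where True means "may attend". Pinning the diagonal to the
    # bottom-right corner lets the last query token see every key, which is
    # what FlashAttention 2 calls lower_right causal alignment.
    offset = kv_len - q_len
    return torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(diagonal=offset)

# Toy inputs: 2 query tokens attending over an 8-token KV sequence.
q = torch.randn(1, 4, 2, 64)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 4, 8, 64)
v = torch.randn(1, 4, 8, 64)

mask = lower_right_causal_mask(q_len=2, kv_len=8)
# An explicit attn_mask cannot be combined with is_causal=True, and it
# dispatches to the memory-efficient or math backend rather than flash.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
```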
Summary:
This pull request removes support for non-square sequence lengths in causal attention when using FlashAttention V2.
Why are we doing this
// FlashAttention 2 updated the default mask meaning for causal in this PR:
// 9e5e8bc91e. It is now aligned to lower_right, which would be a BC break
// for non-square masks. We will not support non-square masks for causal w/ FAV2.
For more context see:
#108108
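To make the BC break concrete, the following sketch (illustrative only, not code from this PR) builds both mask variants and shows that they coincide for square shapes but diverge as soon as seqlen_q != seqlen_kv:

```python
import torch

def causal_mask(q_len: int, kv_len: int, align: str) -> torch.Tensor:
    # True = attend. upper_left pins the diagonal to the top-left corner
    # (query i sees keys 0..i); lower_right pins it to the bottom-right
    # corner (query i sees keys 0..i + kv_len - q_len), matching FAv2.
    offset = 0 if align == "upper_left" else kv_len - q_len
    return torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=offset)

# Square case: both alignments agree, so is_causal=True is unambiguous.
assert torch.equal(causal_mask(4, 4, "upper_left"), causal_mask(4, 4, "lower_right"))

# Non-square case (e.g. 2 new queries against 8 keys): the masks differ,
# which is why is_causal=True with unequal lengths no longer routes to FAv2.
print(causal_mask(2, 8, "upper_left").int())
print(causal_mask(2, 8, "lower_right").int())
```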
Followup
A large number of people will likely want to use FAV2 with lower_right causal attention for unequal sequence lengths. See this RFC: #110681
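For completeness, here is a minimal sketch of the resulting dispatch behavior, assuming PyTorch 2.1+ on a FlashAttention-capable GPU: with default settings SDPA quietly falls back to another backend for this case, while restricting dispatch to the flash backend alone is expected to raise an error.

```python
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    # FlashAttention requires CUDA and half-precision inputs.
    q = torch.randn(1, 4, 2, 64, device="cuda", dtype=torch.float16)  # seqlen_q = 2
    k = torch.randn(1, 4, 8, 64, device="cuda", dtype=torch.float16)  # seqlen_kv = 8
    v = torch.randn_like(k)

    # Default dispatch: is_causal=True with seqlen_q != seqlen_kv falls back
    # to the memory-efficient or math backend after this PR.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Forcing the flash backend should now refuse this configuration.
    try:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
    except RuntimeError as e:
        print("flash backend unavailable for non-square causal:", e)
```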