[Executorch][SDPA] Fix bug in sdpa #9105
Conversation
This diff fixes two bugs:

1. When doing flash attention, the partial q @ k block may contain entries that need to be masked out, and this logic had a bug. The bug may also exist in PyTorch core; I will look into adding a test there to see if I can prove it.
2. Due to the special handling via start_pos in SDPA, the bug in (1) was also exposed when prefilling a really long sequence in chunks. It is probably better to just use a mask.

The code has detailed comments on the issue and the fix.

Differential Revision: [D70922039](https://our.internmc.facebook.com/intern/diff/D70922039/)

[ghstack-poisoned]
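To make bug (1) concrete, here is a minimal sketch of causal masking over a partial q @ k^T tile. It assumes a row-major score buffer and illustrative names (mask_partial_tile, qk_block, qBlockSize, kvBlockSize, qStart, kvStart, start_pos); it is not the ExecuTorch kernel's actual code or identifiers.

```cpp
#include <cstdint>
#include <limits>

// Illustrative sketch only: mask out the entries of one (possibly partial)
// q @ k^T tile that a causal model must not attend to. qk_block holds
// qBlockSize x kvBlockSize scores in row-major order; start_pos is the number
// of tokens already sitting in the KV cache before this chunk.
void mask_partial_tile(
    float* qk_block,
    int64_t qBlockSize,   // rows actually present in this tile
    int64_t kvBlockSize,  // cols actually present in this tile
    int64_t qStart,       // first query index of the tile within the chunk
    int64_t kvStart,      // first key/value index of the tile in the cache
    int64_t start_pos) {
  for (int64_t r = 0; r < qBlockSize; ++r) {
    // Absolute position of this query token in the full sequence.
    const int64_t q_pos = start_pos + qStart + r;
    for (int64_t c = 0; c < kvBlockSize; ++c) {
      // Absolute position of this key token.
      const int64_t k_pos = kvStart + c;
      // Causal rule: a query may only attend to keys at or before its own
      // position; violating entries must be set to -inf before the softmax.
      if (k_pos > q_pos) {
        qk_block[r * kvBlockSize + c] =
            -std::numeric_limits<float>::infinity();
      }
    }
  }
}
```

The easy mistake in this kind of code is comparing tile-local indices instead of absolute positions once the last tile of a chunk is shorter than the split size; the PR description does not spell out the exact failure, only that the partial-block masking logic was wrong.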
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9105
Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 16c1fce with merge base 6daff83.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D70922039
Looks good, thanks for adding the test
```diff
       /* dst */ qSplitSize * headSize;

-  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
```
Is *4 for fp32?
Oh, really good catch. That was left over from local debugging; I need to remove it.
please finalize so that @spalatinate has the correct version
oh totally forgot about this
Here is the fix PR: #9492
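For readers following the thread, a hedged sketch of what the follow-up presumably restores, based only on the discussion above: the scratch buffer is sized from the query's element size, without the leftover * 4 debug factor. The wrapper function and its parameters are illustrative, not the kernel's real code.

```cpp
#include <cstdint>

// Illustrative sketch only: per-thread scratch buffer sizing with the
// accidental "* 4" debug multiplier removed.
int64_t scratch_size_bytes(
    int64_t size_per_thread, int64_t num_thread, int64_t element_size) {
  return size_per_thread * num_thread * element_size;
}
```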
Merged commit 76e3fd9 into gh/kimishpatel/159/base
Stack from ghstack (oldest at bottom):

This diff fixes two bugs: the masking of partial q @ k blocks in flash attention, and the start_pos handling in SDPA that exposed that bug during long chunked prefill. It is probably better to just use a mask. The code has detailed comments on the issue and fix.

Differential Revision: D70922039

cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng
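To illustrate the setting of bug (2), here is a minimal sketch of chunked prefill, assuming a hypothetical run_sdpa_chunk placeholder rather than any real ExecuTorch API: start_pos advances by the chunk length on every iteration, and whenever the chunk length or the KV range is not a multiple of the tile size, the kernel ends up processing partial q @ k blocks, which is exactly the case the masking fix targets.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the real SDPA-with-KV-cache op; not an ExecuTorch API.
void run_sdpa_chunk(int64_t start_pos, int64_t chunk_len) {
  std::printf("chunk: start_pos=%lld len=%lld\n",
              static_cast<long long>(start_pos),
              static_cast<long long>(chunk_len));
}

// Prefill a long prompt in fixed-size chunks, advancing start_pos each time.
void chunked_prefill(int64_t prompt_len, int64_t chunk_size) {
  int64_t start_pos = 0; // tokens already written to the KV cache
  while (start_pos < prompt_len) {
    const int64_t chunk_len = std::min(chunk_size, prompt_len - start_pos);
    // Queries for this chunk sit at absolute positions
    // [start_pos, start_pos + chunk_len) and may attend to keys in
    // [0, start_pos + chunk_len); the last chunk is usually shorter than
    // chunk_size, which is where partial tiles show up.
    run_sdpa_chunk(start_pos, chunk_len);
    start_pos += chunk_len;
  }
}

int main() {
  chunked_prefill(/*prompt_len=*/1000, /*chunk_size=*/128);
  return 0;
}
```

Chunked prefill keeps peak memory bounded on device, which is why very long prompts are fed this way instead of in a single pass.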