
Conversation

kimishpatel
Contributor

@kimishpatel kimishpatel commented Mar 10, 2025

Stack from ghstack (oldest at bottom):

This diff fixes two bugs:

  1. When doing flash attention, the partial q @ k block may contain some entries that need to be masked out. This logic had a bug (see the sketch below). Maybe this bug also exists in PT core; I will look into that, add a test, and see if I can prove it.
  2. The special handling via start_pos in SDPA also exposed the bug in 1 when doing really long-sequence prefill in a chunked manner.

It is probably better to just use a mask, though.

The code has detailed comments on the issue and the fix.
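
For readers outside the diff, here is a minimal, hypothetical sketch of the kind of causal-masking logic at issue. It is not the ExecuTorch kernel code; the function and parameter names (mask_partial_block, qBlock, kvBlock, q_start, kv_start) are made up for illustration, and only start_pos comes from the description above.

```cpp
// Hypothetical sketch, not the actual ExecuTorch kernel: causal masking of one
// [qBlock x kvBlock] tile of attention scores when the query chunk starts at
// an arbitrary start_pos (as in chunked prefill).
#include <cstdint>
#include <limits>
#include <vector>

void mask_partial_block(
    std::vector<float>& scores, // row-major tile, qBlock * kvBlock entries
    int64_t qBlock,
    int64_t kvBlock,
    int64_t q_start,     // index of the tile's first query row within the chunk
    int64_t kv_start,    // global position of the tile's first key column
    int64_t start_pos) { // global position where this prefill chunk begins
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (int64_t i = 0; i < qBlock; ++i) {
    // The query's global position must include start_pos; deciding what to
    // mask from the local row index alone is the kind of logic that breaks
    // on partial blocks.
    const int64_t q_pos = start_pos + q_start + i;
    for (int64_t j = 0; j < kvBlock; ++j) {
      const int64_t k_pos = kv_start + j;
      if (k_pos > q_pos) { // causal: a query attends only to keys at or before it
        scores[i * kvBlock + j] = neg_inf;
      }
    }
  }
}
```

The point of the sketch is only that the mask condition has to be computed in global positions (including start_pos), which is where a partial q @ k block and chunked prefill interact.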

Differential Revision: D70922039

cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng


pytorch-bot bot commented Mar 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9105

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 16c1fce with merge base 6daff83:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kimishpatel added a commit that referenced this pull request Mar 10, 2025
Differential Revision: [D70922039](https://our.internmc.facebook.com/intern/diff/D70922039/)

ghstack-source-id: 270854238
Pull Request resolved: #9105
@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 10, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70922039

Contributor

@larryliu0820 larryliu0820 left a comment


Looks good, thanks for adding the test

/* dst */ qSplitSize * headSize;

-  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
Contributor


Is *4 for fp32?

Contributor Author


Oh, really good catch. That was left over from local debugging; need to remove it.

Contributor


Please finalize so that @spalatinate has the correct version.

Contributor Author


Oh, totally forgot about this.

Contributor Author


Here is the fix PR: #9492
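
For context on the thread above, here is a rough sketch of what the allocation computes. The breakdown of size_per_thread is an assumption based on the typical flash-attention scratch layout (only the /* dst */ term is visible in the diff hunk); the parameter names follow the hunk. The key point is that element_size() already returns the per-element byte width, so the stray * 4 over-allocates the buffer by 4x rather than selecting fp32.

```cpp
#include <cstdint>

// Rough sketch, not the kernel source. element_size() is the per-element byte
// width of the query tensor (4 for fp32, 2 for fp16/bf16), so it already
// accounts for the dtype; an extra "* 4" only inflates the allocation.
int64_t scratch_size_bytes(
    int64_t qSplitSize,
    int64_t kvSplitSize,
    int64_t headSize,
    int64_t num_thread,
    int64_t element_size) { // e.g. query.element_size()
  const int64_t size_per_thread =
      /* qk tile */ qSplitSize * kvSplitSize +
      /* row max */ qSplitSize +
      /* row sum */ qSplitSize +
      /* dst     */ qSplitSize * headSize; // assumed layout
  return size_per_thread * num_thread * element_size;
}
```

The debug leftover flagged in the review would correspond to returning size_per_thread * num_thread * element_size * 4, which was removed in the follow-up PR referenced above.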

@kimishpatel kimishpatel added the module: llm label Mar 18, 2025
@kimishpatel kimishpatel requested a review from swolchok as a code owner March 18, 2025 04:36
kimishpatel added a commit that referenced this pull request Mar 18, 2025
Pull Request resolved: #9105

Differential Revision: [D70922039](https://our.internmc.facebook.com/intern/diff/D70922039/)
ghstack-source-id: 272375653
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70922039

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70922039

@kimishpatel kimishpatel added the release notes: ops & kernels label Mar 20, 2025
@facebook-github-bot facebook-github-bot merged commit 76e3fd9 into gh/kimishpatel/159/base Mar 20, 2025
82 of 83 checks passed
@facebook-github-bot facebook-github-bot deleted the gh/kimishpatel/159/head branch March 20, 2025 23:43
kedarnath03 pushed a commit to kedarnath03/executorch that referenced this pull request Jun 25, 2025
Pull Request resolved: pytorch/executorch#9105

ghstack-source-id: 272776939

Differential Revision: [D70922039](https://our.internmc.facebook.com/intern/diff/D70922039/)