Fix flash attention for GQA (Phi4) #23850

Merged
merged 1 commit into main from user/sushraja/fix_gqa_min on Feb 28, 2025

Conversation

@sushraja-msft (Contributor) commented Feb 28, 2025

Description

This change fixes GQA for Flash Attention on Nvidia GPUs. The root cause appears to be the
`k_start + capped_sg_id < seq_causal_length`
check. This is either because:
a. `seq_causal_length` varies per lane, so the check becomes non-uniform control flow, which interacts badly with `subgroupShuffle` (see the sketch after this list),
or
b. the check itself is incorrect: it wipes out values of `v` based on the source lane's `seq_causal_length`, when values of `v` actually need to be causal with respect to the lane that will multiply them with `qkt`.
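
For case (a), a minimal WGSL sketch of the problematic pattern (variable names follow the description above; the buffer, workgroup size, and values are illustrative, not taken from the actual shader):

```wgsl
enable subgroups;

@group(0) @binding(0) var<storage, read_write> output_buf : array<f32>;

@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid : vec3<u32>,
        @builtin(subgroup_invocation_id) sg_id : u32) {
  // Stand-ins for the shader's real state; values are illustrative only.
  let k_start = 0u;
  let capped_sg_id = sg_id;
  let seq_causal_length = (gid.x % 4u) + 1u; // differs per lane
  let v_value = f32(sg_id);
  var acc = 0.0;

  // Pre-fix pattern: the condition differs across lanes, so this branch is
  // non-uniform control flow. Calling subgroupShuffle here can read from
  // lanes that did not take the branch, with indeterminate results (a WGSL
  // compiler may also warn about subgroup uniformity at this call).
  if (k_start + capped_sg_id < seq_causal_length) {
    let v_shared = subgroupShuffle(v_value, 0u);
    acc += v_shared;
  }
  output_buf[gid.x] = acc;
}
```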

`qkt` is already causal because earlier values of `qk` for out-of-bounds `k` are set to `min_value`, and `exp(x)` for `x < -4` is effectively 0.
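
A sketch of that earlier masking (again illustrative; `MIN_VALUE` stands in for the shader's actual `min_value` constant):

```wgsl
// Hypothetical reduction of the masking that makes qkt causal.
const MIN_VALUE : f32 = -3.0e38;

fn masked_qk(qk : f32, k_index : u32, seq_causal_length : u32) -> f32 {
  // Out-of-bounds / non-causal k positions are forced to min_value, so the
  // softmax's exp() reduces their contribution to effectively zero.
  if (k_index >= seq_causal_length) {
    return MIN_VALUE;
  }
  return qk;
}
```

With this in place, `exp(qk - row_max)` is effectively 0 for every masked position, which is what makes a separate causal check on `v` redundant.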

This fix works by removing that causal check and relying on `qk` being wiped out earlier. Documentation of the causality behavior for GQA is missing, so it is not possible to determine which of these reasons is the true one.
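
Conceptually, the post-fix pattern leaves the shuffle unguarded (a sketch, not the literal diff in this PR):

```wgsl
enable subgroups;

// Post-fix sketch: v is shuffled unconditionally, so the shuffle executes in
// uniform control flow. Causality is enforced solely by the earlier qk
// masking; no check on seq_causal_length gates v anymore.
fn gather_v(v_value : f32, source_lane : u32) -> f32 {
  return subgroupShuffle(v_value, source_lane);
}
```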

Prior to this change, prompts with sequence length greater than 16 but less than 32, or around 1k, would break with Phi 4, while smaller prompts worked.
Tested on Intel Alderlake and Nvidia 4070.

@sushraja-msft sushraja-msft marked this pull request as ready for review February 28, 2025 05:14
@sushraja-msft sushraja-msft changed the title from Fix GQA' to Fix flash attention for GQA Feb 28, 2025
@sushraja-msft sushraja-msft changed the title from Fix flash attention for GQA to Fix flash attention for GQA (Phi4) Feb 28, 2025
@guschmue guschmue merged commit 1be64f8 into main Feb 28, 2025
95 of 99 checks passed
@guschmue guschmue deleted the user/sushraja/fix_gqa_min branch February 28, 2025 16:02
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Feb 28, 2025
guschmue pushed a commit that referenced this pull request Mar 6, 2025