Fix flash attention for GQA (Phi4) #23850
Merged
Description
This change fixes GQA for Flash Attention on Nvidia GPUs. The root cause appears to be the
k_start + capped_sg_id < seq_causal_length
check. Either:
a. seq_causal_length varies per lane, so the check becomes non-uniform control flow, which interacts badly with subgroupShuffle;
or
b. the check itself is incorrect: it wipes out values of v based on the source lane's seq_causal_length, whereas the values of v actually need to be causal with respect to the lane that will multiply them with qkt.
qkt is already causal, because earlier values of qk for out-of-bounds k are set to min_value, and exp of such a large negative value is 0.
The fix removes that causal check and relies on qk having been wiped out earlier. Documentation of the intended causality behavior for GQA is missing, so it is unclear which of the two reasons above is the true cause.
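A minimal numeric sketch of the masking argument (plain Python, not the actual WGSL shader; `MIN_VALUE` and the row layout are stand-ins for illustration): once out-of-bounds qk entries are set to a large negative sentinel, their softmax weights underflow to zero, so the corresponding v values contribute nothing even without a separate causal check on v.

```python
import math

MIN_VALUE = -3.0e38  # stand-in for the shader's min_value sentinel

def attention_row(qk, v, seq_causal_length):
    # Mask non-causal / out-of-bounds positions in qk, as the shader
    # already does before the removed check would run.
    masked = [x if i < seq_causal_length else MIN_VALUE
              for i, x in enumerate(qk)]
    m = max(masked)
    weights = [math.exp(x - m) for x in masked]
    # No separate causal check on v: exp(MIN_VALUE - m) underflows to 0,
    # so masked positions drop out of the weighted sum on their own.
    return sum(w * vi for w, vi in zip(weights, v)) / sum(weights)

qk = [0.5, 1.0, 2.0, 9.9]   # last position is past seq_causal_length
v  = [1.0, 2.0, 3.0, 99.0]  # v = 99.0 must not leak into the output
out = attention_row(qk, v, seq_causal_length=3)
```

Here `out` matches the result computed over only the first three (causal) positions, which is the property the fix relies on.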
Prior to this fix, prompts with sequence lengths greater than 16 but less than 32, or around 1k, would break with Phi 4, while smaller prompts worked.
Tested on Intel Alder Lake and Nvidia 4070.