[BugFix] Fix FI accuracy issue when used for MLA prefill #26063
Conversation
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Code Review
This pull request addresses an accuracy issue in the FlashInfer MLA prefill implementation. The root cause appears to be an inconsistent shape of the log-sum-exp (LSE) tensor returned by FlashInfer: it is (q_len, num_heads) instead of the expected (num_heads, q_len). The changes correctly transpose the LSE tensor in both _run_prefill_new_tokens_fi and _run_prefill_context_chunk_fi to align with the other backends, which should resolve the accuracy problem. The fix is logical and well-targeted. I have one suggestion to improve the code's robustness: use isinstance() for type checking, in line with Python best practices.
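To make the shape fix concrete, here is a minimal sketch of the transpose. This is not the actual vLLM code; the helper name and toy shapes are assumptions for illustration, and the FlashInfer call itself is elided:

```python
import torch

def _fix_fi_lse(lse: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not the actual vLLM code): FlashInfer's
    prefill returns the log-sum-exp as (q_len, num_heads), while the
    MLA merge path -- like the other backends -- expects
    (num_heads, q_len), so transpose before merging."""
    return lse.transpose(0, 1)

# Toy shapes: 7 query tokens, 16 attention heads.
lse_from_fi = torch.randn(7, 16)       # (q_len, num_heads) as returned
lse_merged = _fix_fi_lse(lse_from_fi)  # (num_heads, q_len) as expected
assert lse_merged.shape == (16, 7)
```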
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
I ran this to validate on 8xB200:

Where it completely failed before, it now passes GSM8k. I also ran it with MTP on my development branch (#25984), which relies more heavily on this prefill functionality, and it passed as well:
I, too, am curious why this didn't / doesn't cause a shape mismatch. However, it clearly solves the problem.
I would approve if I had any understanding as to why this solves the issue.
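One plausible mechanism for the silent failure (my assumption, not confirmed in the thread): if the downstream merge consumes the LSE as a flat buffer, both layouts contain the same number of elements, so no shape check fires and the values are simply read in the wrong order. A toy illustration:

```python
import torch

# Hypothetical illustration: a (q_len, num_heads) tensor and its
# (num_heads, q_len) transpose have identical numel, so a kernel that
# only sees a flat buffer of q_len * num_heads floats raises no error --
# it just consumes the LSE values in the wrong order.
q_len, num_heads = 4, 8
lse = torch.arange(q_len * num_heads, dtype=torch.float32).reshape(q_len, num_heads)

flat_as_returned = lse.reshape(-1)                   # layout FlashInfer produced
flat_as_expected = lse.t().contiguous().reshape(-1)  # layout the merge expects

assert flat_as_returned.numel() == flat_as_expected.numel()  # same size: no error
assert not torch.equal(flat_as_returned, flat_as_expected)   # different value order
```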
FIX #26042
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>

…t#26063)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>