
Conversation

@wenscarl (Contributor) commented Sep 25, 2025

Fix the routing_bias dtype to bf16 for flashinfer.fused_moe.trtllm_fp4_block_scale_moe
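
For illustration, a minimal self-contained sketch of the cast this PR introduces before the FlashInfer kernel call (the tensor below is a stand-in for the layer's e_score_correction_bias; the actual change appears in the diff further down):

    import torch

    # Stand-in for e_score_correction_bias; in vLLM this is the MoE layer's
    # expert-score correction bias.
    e_score_correction_bias = torch.randn(128, dtype=torch.float32)

    # The fix: cast routing_bias to bfloat16 (when present) before handing it
    # to flashinfer.fused_moe.trtllm_fp4_block_scale_moe.
    routing_bias = e_score_correction_bias
    if routing_bias is not None:
        routing_bias = routing_bias.to(torch.bfloat16)

    print(routing_bias.dtype)  # torch.bfloat16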

Purpose

Test Plan

Test Result



Signed-off-by: Shu Wang. <shuw@nvidia.com>
@gemini-code-assist (bot) left a comment


Code Review

This pull request aims to fix the data type of routing_bias before passing it to a FlashInfer kernel. While the change correctly identifies the need for a type cast, it hardcodes torch.bfloat16. This can lead to data type mismatches and potential runtime errors if the model uses a different activation data type, such as torch.float16. My review provides a more robust solution that dynamically casts routing_bias to match the data type of routing_logits, ensuring consistency.

Comment on lines +1457 to +1459
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)

critical

Hardcoding the routing_bias dtype to torch.bfloat16 can cause a dtype mismatch with routing_logits, which could lead to runtime errors or incorrect behavior for models using other dtypes like torch.float16.

The routing_logits passed to the kernel have dtype router_logits.dtype (when use_llama4_routing is true) or torch.float32 otherwise. To ensure consistency, routing_bias should be cast to the same dtype as the routing_logits being passed to the kernel.

The suggested change determines the target dtype dynamically.

            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                target_dtype = router_logits.dtype if use_llama4_routing else torch.float32
                routing_bias = routing_bias.to(target_dtype)
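
To make the dtype argument concrete, here is a small standalone sketch (not from the PR) contrasting the hardcoded bfloat16 cast with the dynamic cast above; it assumes the non-Llama4 path where the kernel receives float32 routing logits, and the tensor shapes and flag value are illustrative:

    import torch

    # Illustrative: the module's router logits and the expert-score correction bias.
    router_logits = torch.randn(4, 8, dtype=torch.float32)
    e_score_correction_bias = torch.randn(8, dtype=torch.float32)

    # Non-Llama4 path: the kernel receives float32 routing logits.
    use_llama4_routing = False
    routing_logits = router_logits if use_llama4_routing else router_logits.float()

    # Hardcoded cast: the bias becomes bfloat16 while the logits are float32.
    hardcoded_bias = e_score_correction_bias.to(torch.bfloat16)
    assert hardcoded_bias.dtype != routing_logits.dtype

    # Dynamic cast from the suggestion: the bias matches the kernel's logits.
    target_dtype = router_logits.dtype if use_llama4_routing else torch.float32
    dynamic_bias = e_score_correction_bias.to(target_dtype)
    assert dynamic_bias.dtype == routing_logits.dtype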

Contributor


Agree with this comment

@alexm-redhat (Collaborator) left a comment


@wenscarl thanks for the fix!

@alexm-redhat enabled auto-merge (squash) September 25, 2025 21:26
@github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 25, 2025
@alexm-redhat merged commit 081b559 into vllm-project:main Sep 25, 2025
58 of 60 checks passed
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>