Fix routing_bias dtype #25711
Conversation
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Code Review
This pull request aims to fix the data type of routing_bias before passing it to a FlashInfer kernel. While the change correctly identifies the need for a type cast, it hardcodes torch.bfloat16. This can lead to data type mismatches and potential runtime errors if the model uses a different activation data type, such as torch.float16. My review provides a more robust solution that dynamically casts routing_bias to match the data type of routing_logits, ensuring consistency.
routing_bias = e_score_correction_bias
if routing_bias is not None:
    routing_bias = routing_bias.to(torch.bfloat16)
Hardcoding the routing_bias dtype to torch.bfloat16 can cause a dtype mismatch with routing_logits, which could lead to runtime errors or incorrect behavior for models using other dtypes like torch.float16.

The routing_logits passed to the kernel are either of router_logits.dtype (when use_llama4_routing is true) or torch.float32. To ensure consistency, routing_bias should be cast to the same dtype as the routing_logits being passed to the kernel.

The suggested change determines the target dtype dynamically:
routing_bias = e_score_correction_bias
if routing_bias is not None:
    target_dtype = router_logits.dtype if use_llama4_routing else torch.float32
    routing_bias = routing_bias.to(target_dtype)
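For illustration only, here is a minimal, self-contained sketch of the dynamic cast described above; the helper name cast_routing_bias and the example tensors are made up for this snippet and are not vLLM code:

import torch

def cast_routing_bias(routing_bias, router_logits, use_llama4_routing):
    # Match the bias dtype to whatever dtype the routing logits take on the
    # kernel path: router_logits.dtype for llama4 routing, float32 otherwise.
    if routing_bias is None:
        return None
    target_dtype = router_logits.dtype if use_llama4_routing else torch.float32
    return routing_bias.to(target_dtype)

bias = torch.randn(8, dtype=torch.float32)
logits = torch.randn(4, 8, dtype=torch.float16)
print(cast_routing_bias(bias, logits, use_llama4_routing=True).dtype)   # torch.float16
print(cast_routing_bias(bias, logits, use_llama4_routing=False).dtype)  # torch.float32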
Agree with this comment
According to the flashinfer kernel, it looks like bfloat16 is hardcoded: https://github.com/flashinfer-ai/flashinfer/blob/59a0c514b1fbdf1868466d6e7a98a7c4a1e26908/csrc/trtllm_fused_moe_kernel_launcher.cu#L79
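If that launcher does require a bf16 routing_bias, the caller-side cast in this PR is what keeps its check from firing. Below is a hedged sketch of that interaction; launch_moe_kernel is a hypothetical stand-in for the linked C++ check, not FlashInfer's actual API:

import torch

def launch_moe_kernel(routing_logits, routing_bias):
    # Stand-in for the dtype check the linked launcher appears to perform (assumed behavior).
    if routing_bias is not None and routing_bias.dtype != torch.bfloat16:
        raise TypeError(f"routing_bias must be torch.bfloat16, got {routing_bias.dtype}")

bias = torch.randn(8, dtype=torch.float32)  # e_score_correction_bias as typically stored
launch_moe_kernel(torch.randn(4, 8), bias.to(torch.bfloat16))  # passes with the cast from this PR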
@wenscarl thanks for the fix!
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Purpose
Fix the routing_bias dtype to bf16 for flashinfer.fused_moe.trtllm_fp4_block_scale_moe.
Test Plan
Test Result