Conversation

@jinzhen-lin
Contributor

fix #28220

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a critical fix for a race condition in the MoE Marlin kernel. The change in csrc/moe/marlin_moe_wna16/marlin_template.h correctly addresses an issue where multiple threads could write to the same shared memory location, which would lead to incorrect results. By introducing a proper boundary check, the fix ensures memory safety and correctness. Additionally, this PR enables overlapped execution for Marlin kernels in vllm/model_executor/layers/fused_moe/shared_fused_moe.py, a performance optimization that was likely blocked by this bug. The changes are well-implemented and address the issue effectively.

Comment on lines 492 to 501
if (idx < block_num_valid_tokens) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
}
Contributor


critical

This change correctly fixes a critical bug. The previous logic, which reset out-of-bounds idx values to 0, could lead to a race condition where multiple threads would write to sh_block_topk_weights[0] simultaneously. This would cause incorrect results and undefined behavior. By wrapping the operation in an if (idx < block_num_valid_tokens) check, you ensure that out-of-bounds accesses are safely skipped. This is the correct and robust approach to prevent this issue.
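
For readers skimming the diff, a minimal sketch of the two patterns being contrasted may help. This is not the real Marlin kernel: the names mirror the snippet above, but the types are plain float instead of half2 and the FP4/FP8 scaling branch is dropped.

// Simplified illustration only (assumed shapes and types, not vLLM source).
// Pre-fix pattern: an out-of-bounds idx was clamped to 0, so several threads
// could write sh_block_topk_weights[0] at once and could read
// sh_block_sorted_ids[0] before it held a valid token id.
__device__ void load_topk_weights_before_fix(
    float* sh_block_topk_weights, const int* sh_block_sorted_ids,
    const float* topk_weights_ptr, int idx, int block_num_valid_tokens) {
  if (idx >= block_num_valid_tokens) idx = 0;  // clamp -> data race
  sh_block_topk_weights[idx] = topk_weights_ptr[sh_block_sorted_ids[idx]];
}

// Post-fix pattern (this PR): out-of-bounds threads simply skip the load, so
// each shared-memory slot is written by at most one thread.
__device__ void load_topk_weights_after_fix(
    float* sh_block_topk_weights, const int* sh_block_sorted_ids,
    const float* topk_weights_ptr, int idx, int block_num_valid_tokens) {
  if (idx < block_num_valid_tokens) {
    sh_block_topk_weights[idx] = topk_weights_ptr[sh_block_sorted_ids[idx]];
  }
}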

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug in the MoE marlin kernel and re-enables a related optimization. The change in csrc/moe/marlin_moe_wna16/marlin_template.h correctly fixes a race condition where multiple threads could write to the same shared memory location when handling tokens near block boundaries. The previous logic incorrectly reset an out-of-bounds index to 0, causing this data race. The new implementation properly guards the memory access with a conditional check, resolving the issue. The second change in vllm/model_executor/layers/fused_moe/shared_fused_moe.py re-enables overlapped execution for marlin kernels. This optimization was likely disabled due to the bug, and its re-introduction is a good performance improvement. The changes are correct and effectively address the underlying problem.

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
@youkaichao
Member

cc @mgoin @vadiklyutiy

@vadiklyutiy
Collaborator

vadiklyutiy commented Nov 13, 2025

So, the illegal memory access appeared when we read a garbage value x = sh_block_sorted_ids[0] and then used it as the index in topk_weights_ptr[x]?

@jinzhen-lin
Contributor Author

So, the illegal memory access appeared when we read a garbage value x = sh_block_sorted_ids[0] and then used it as the index in topk_weights_ptr[x]?

Yes.

@vadiklyutiy
Collaborator

So, the illegal memory access appeared when we read a garbage value x = sh_block_sorted_ids[0] and then used it as the index in topk_weights_ptr[x]?

Yes.

Just wondering how multi-stream execution affects the appearance of this issue...

@jinzhen-lin
Contributor Author

jinzhen-lin commented Nov 13, 2025

Just wondering how multi-stream execution affects the appearance of this issue...

I’m not sure. Before we load values from global memory into sh_block_sorted_ids[0], that slot might still hold a leftover value from a previous kernel execution, or some other garbage value. In addition, the layout of allocations in GPU memory might also affect whether an illegal memory access (IMA) actually occurs.
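
To make the allocation-layout point concrete, here is a hypothetical toy kernel (assumed names, not vLLM code) that reads a stale shared-memory slot and uses it as a global index, the pattern described above:

// Hypothetical toy kernel, not vLLM source: reads shared memory before any
// thread has written it, then uses the stale value as a global-memory index.
__global__ void stale_index_demo(const float* topk_weights_ptr, float* out) {
  __shared__ int sh_block_sorted_ids[64];
  // Intentionally no initialization and no __syncthreads():
  // sh_block_sorted_ids[0] may hold whatever a previous kernel left behind,
  // or any other garbage value.
  int x = sh_block_sorted_ids[0];
  // Whether this access faults depends on how large the stale value is and on
  // where topk_weights_ptr sits relative to neighboring allocations, which is
  // why the illegal memory access shows up only intermittently.
  out[threadIdx.x] = topk_weights_ptr[x];
}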

@vadiklyutiy
Collaborator

Although I still have a slight feeling that something is left unsaid, because I don't understand how multi-stream execution impacted this.
Meanwhile, the changes look reasonable: I don't see a reason to write to idx = 0 when we are out of bounds.

@youkaichao youkaichao enabled auto-merge (squash) November 20, 2025 16:19
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 20, 2025
Member

@yewentao256 yewentao256 left a comment


Let's try again before force merging

@jinzhen-lin
Contributor Author

@yewentao256 the failed LoRA test seems unrelated

@vllm-bot vllm-bot merged commit a67dec7 into vllm-project:main Nov 27, 2025
88 of 90 checks passed

Labels

ready (ONLY add when PR is ready to merge/full CI is needed)


Development

Successfully merging this pull request may close these issues.

[Bug]: Find the root cause of SHARED_EXPERTS_STREAM fail

6 participants