From 6583e3be2c6eb6d49d6aac4d999e517258c38a92 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 13 Nov 2025 14:27:40 +0800 Subject: [PATCH 1/2] fix ima Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/marlin_template.h | 17 +++++++++-------- .../layers/fused_moe/shared_fused_moe.py | 1 - 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index dd86a9a5ba6e..e1287f07691c 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -489,14 +489,15 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < 4; i++) { int idx = tid4 * 4 + i; - idx = idx < block_num_valid_tokens ? idx : 0; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = __hmul2( - global_scale, Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = Dtype::num2num2( - Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + if (idx < block_num_valid_tokens) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } } } } diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 6ec8b33ed930..9aaeec4f98a6 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -38,7 +38,6 @@ def __init__( # TODO(wentao): find the root cause and remove this condition self.enable_eplb or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) - or self.use_marlin_kernels ) and self._shared_experts is not None ) From 1d3d82f7fbb0ba484aeeb4777c996313b9e5f4cc Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 13 Nov 2025 15:05:24 +0800 Subject: [PATCH 2/2] fix pre-commit Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/marlin_template.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index e1287f07691c..4dbca30da57a 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -491,9 +491,10 @@ __global__ void Marlin( int idx = tid4 * 4 + i; if (idx < block_num_valid_tokens) { if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = __hmul2( - global_scale, Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); + sh_block_topk_weights[idx] = + __hmul2(global_scale, + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); } else { sh_block_topk_weights[idx] = Dtype::num2num2( Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));