diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index dd86a9a5ba6e..4dbca30da57a 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -489,14 +489,16 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < 4; i++) { int idx = tid4 * 4 + i; - idx = idx < block_num_valid_tokens ? idx : 0; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = __hmul2( - global_scale, Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = Dtype::num2num2( - Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + if (idx < block_num_valid_tokens) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + sh_block_topk_weights[idx] = + __hmul2(global_scale, + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } } } } diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 6ec8b33ed930..9aaeec4f98a6 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -38,7 +38,6 @@ def __init__( # TODO(wentao): find the root cause and remove this condition self.enable_eplb or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) - or self.use_marlin_kernels ) and self._shared_experts is not None )