From 6c13d63867fc1c672aef023eaf2b5e69fbc7b51b Mon Sep 17 00:00:00 2001
From: zhyajie
Date: Tue, 18 Nov 2025 12:39:30 +0000
Subject: [PATCH] [Bugfix] fix(fused_moe): Fix precision corruption when
 shared_experts_stream=None

When shared_experts_stream is None, the shared experts previously ran
after quant_method.apply(), whose fused-MoE matrix multiply may modify
hidden_states in place, so the shared experts consumed corrupted
activations. Run them before the matrix multiply instead, and let
aux_stream() create a stream on all cuda-alike platforms (e.g. ROCm),
not just CUDA.

Signed-off-by: zhyajie
---
 vllm/model_executor/layers/fused_moe/layer.py | 11 +++++++----
 vllm/utils/torch_utils.py                     |  3 +--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 023132acfed3..a77ce6a217e7 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -335,8 +335,8 @@ def __init__(
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            # TODO(rob): enable shared expert overlap with non-cuda.
-            # aux_stream() returns None on non-cuda platforms.
+            # TODO(rob): enable shared expert overlap with non-cuda-alike.
+            # aux_stream() returns None on non-cuda-alike platforms.
             self.shared_experts_stream = aux_stream()
         if self.shared_experts_stream is not None:
             logger.info_once("Enabled separate cuda stream for MoE shared_experts")
@@ -1752,6 +1752,11 @@ def forward_impl(
             hidden_states_combined, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits, self.is_sequence_parallel
             )
+        # Run the shared experts before the MoE matrix multiply,
+        # because the matrix multiply may modify hidden_states in place.
+        if has_separate_shared_experts and not use_shared_experts_stream:
+            assert self.shared_experts is not None
+            shared_output = self.shared_experts(hidden_states)
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
@@ -1795,8 +1800,6 @@ def forward_impl(
                 # conflict with the main stream
                 shared_output = self.shared_experts(hidden_states_clone)
                 current_stream().wait_stream(self.shared_experts_stream)
-            else:
-                shared_output = self.shared_experts(hidden_states)
 
             final_hidden_states = (
                 shared_output,
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index 7c094e14cff7..3661dfd09047 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:
 
     from vllm.platforms import current_platform
 
-    # TODO: validate this works properly on ROCm platform.
-    if _aux_stream is None and current_platform.is_cuda():
+    if _aux_stream is None and current_platform.is_cuda_alike():
         _aux_stream = torch.cuda.Stream()
 
     return _aux_stream
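
Note for reviewers (not part of the commit): below is a minimal, self-contained
sketch of the ordering hazard this patch fixes. The routed_experts and
shared_experts functions are hypothetical stand-ins, not vLLM code; in vLLM the
in-place mutation would come from a fused-MoE kernel inside quant_method.apply().

import torch


def routed_experts(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused-MoE matrix multiply; some kernels reuse the
    # input activation buffer as scratch space, i.e. mutate it in place.
    hidden_states.mul_(0.0)  # simulate scribbling over the input buffer
    return hidden_states + 1.0


def shared_experts(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for the shared-expert MLP.
    return hidden_states * 2.0


x = torch.ones(4, 8)

# Buggy order (before this patch, when shared_experts_stream is None):
# the routed-expert matmul runs first and clobbers the activations, so
# the shared experts read a corrupted buffer.
x_buggy = x.clone()
_ = routed_experts(x_buggy)
shared_bad = shared_experts(x_buggy)

# Fixed order (this patch): shared experts read the original activations
# before the matmul gets a chance to mutate them.
x_fixed = x.clone()
shared_good = shared_experts(x_fixed)
_ = routed_experts(x_fixed)

print(torch.equal(shared_good, 2 * x))  # True
print(torch.equal(shared_bad, 2 * x))   # False: the precision corruption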