11 changes: 7 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -371,8 +371,8 @@ def __init__(
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            # TODO(rob): enable shared expert overlap with non-cuda.
-            # aux_stream() returns None on non-cuda platforms.
+            # TODO(rob): enable shared expert overlap with non-cuda-alike.
+            # aux_stream() returns None on non-cuda-alike platforms.
             self.shared_experts_stream = aux_stream()
             if self.shared_experts_stream is not None:
                 logger.info_once("Enabled separate cuda stream for MoE shared_experts")
@@ -1865,6 +1865,11 @@ def forward_impl(
             hidden_states_combined, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits, self.is_sequence_parallel
             )
+        # Run the shared experts before the matrix multiply,
+        # because the matrix multiply may modify hidden_states.
+        if has_separate_shared_experts and not use_shared_experts_stream:
+            assert self.shared_experts is not None
+            shared_output = self.shared_experts(hidden_states)

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
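To see why the ordering matters, here is a toy illustration (shapes and ops are made up, not vLLM code): once a kernel clobbers its input buffer, any consumer scheduled after it reads the clobbered values, so the shared experts must consume hidden_states first.

import torch

hidden_states = torch.randn(4, 8)
w_shared = torch.randn(8, 8)

# Read hidden_states BEFORE the mutating op, as the hunk above does.
shared_output = hidden_states @ w_shared

# Stand-in for a fused-MoE kernel that modifies its input in place.
hidden_states.zero_()

# Computing shared_output here instead would operate on zeroed data.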
@@ -1908,8 +1913,6 @@ def forward_impl(
                     # conflict with the main stream
                     shared_output = self.shared_experts(hidden_states_clone)
                 current_stream().wait_stream(self.shared_experts_stream)
-            else:
-                shared_output = self.shared_experts(hidden_states)

             final_hidden_states = (
                 shared_output,
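The stream path that remains after this hunk follows the usual CUDA overlap pattern. A hedged sketch, with `shared_experts` and `routed_experts` as stand-in callables rather than vLLM's modules:

import torch

def overlapped_forward(x, shared_experts, routed_experts, side_stream):
    # Clone so the side stream never races kernels that may modify x.
    x_clone = x.clone()
    # Make sure the clone has finished before the side stream consumes it.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_output = shared_experts(x_clone)
    # Routed experts run concurrently on the main stream.
    routed_output = routed_experts(x)
    # The main stream must not touch shared_output until the side stream is done.
    torch.cuda.current_stream().wait_stream(side_stream)
    return shared_output + routed_output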
3 changes: 1 addition & 2 deletions vllm/utils/torch_utils.py
@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:

     from vllm.platforms import current_platform

-    # TODO: validate this works properly on ROCm platform.
-    if _aux_stream is None and current_platform.is_cuda():
+    if _aux_stream is None and current_platform.is_cuda_alike():
         _aux_stream = torch.cuda.Stream()

     return _aux_stream
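A self-contained sketch of the lazy singleton this hunk widens, substituting `torch.cuda.is_available()` for `current_platform.is_cuda_alike()` (as I understand it, "cuda-alike" covers CUDA and ROCm, and ROCm builds of PyTorch expose the torch.cuda API):

import torch

_aux_stream: "torch.cuda.Stream | None" = None

def aux_stream() -> "torch.cuda.Stream | None":
    """Lazily create one shared auxiliary stream, or return None off-GPU."""
    global _aux_stream
    # Stand-in check for current_platform.is_cuda_alike(); assumption, not vLLM code.
    if _aux_stream is None and torch.cuda.is_available():
        _aux_stream = torch.cuda.Stream()
    return _aux_stream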