
Conversation

@zhyajie (Contributor) commented Nov 18, 2025

Fix precision issues when shared_experts_stream=None

Purpose

Fix precision issues when shared_experts_stream=None by executing the shared experts before the routed-experts kernel, so that hidden_states is not mutated in place before the shared experts read it.
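
For context, a minimal sketch of the reordering, with hypothetical names (moe_forward, routed_experts_kernel are illustrative, not vLLM's actual API; the stream synchronization is simplified):

import torch

def moe_forward(hidden_states, shared_experts, routed_experts_kernel,
                shared_experts_stream=None):
    # Hedged sketch of the fix, not vLLM's actual code.
    if shared_experts_stream is None:
        # Fix: run the shared experts first, while hidden_states is intact;
        # the routed-experts kernel may mutate hidden_states in place.
        shared_out = shared_experts(hidden_states)
        routed_out = routed_experts_kernel(hidden_states)
    else:
        # With an auxiliary stream (simplified; real code needs careful
        # synchronization), the shared experts read a copy of the input so
        # the routed kernel's in-place writes cannot corrupt it.
        shared_experts_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(shared_experts_stream):
            shared_out = shared_experts(hidden_states.clone())
        routed_out = routed_experts_kernel(hidden_states)
        torch.cuda.current_stream().wait_stream(shared_experts_stream)
    return shared_out + routed_out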

Test Plan

Tested on ROCm (gfx942) with DeepSeek-R1 model using TP=8.

Start vLLM Server

export GPU_ARCHS=gfx942
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0
MODEL=path/deepseek-ai/DeepSeek-R1

AITER_ENABLE_VSKIP=0 \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
vllm serve $MODEL \
  --tensor-parallel-size 8 \
  --disable-log-requests \
  --no-enable-prefix-caching \
  --trust-remote-code \
  --block-size 1 \
  --enforce-eager \
  2>&1 | tee vllm_serve.log

Send Test Request

curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "The capital of China",
    "temperature": 0,
    "top_p": 1,
    "top_k": 0,
    "repetition_penalty": 1.0,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "stream": false,
    "ignore_eos": false,
    "n": 1,
    "seed": 123
  }'

Test Result

Before Fix (Corrupted Output)

{
  "id": "cmpl-00b08a50df1b4a968ab5e41ccbeeced8",
  "object": "text_completion",
  "created": 1763468384,
  "model": "path/deepseek-ai/DeepSeek-R1",
  "choices": [{
    "index": 0,
    "text": " Beijingisnownsgrmetically Beijing BeijingBeijingBeijingBeijingBeijingBeijingBeijingBeijingBeijing",
    "finish_reason": "length"
  }],
  "usage": {
    "prompt_tokens": 5,
    "total_tokens": 21,
    "completion_tokens": 16
  }
}

After Fix (Correct Output)

{
  "id": "cmpl-dc933ac440a3460ca2922efcaff43f0e",
  "object": "text_completion",
  "created": 1763468925,
  "model": "path/deepseek-ai/DeepSeek-R1",
  "choices": [{
    "index": 0,
    "text": " is Beijing, and the capital of Russia is Moscow. Both China and Russia have",
    "finish_reason": "length"
  }],
  "usage": {
    "prompt_tokens": 5,
    "total_tokens": 21,
    "completion_tokens": 16
  }
}

…_stream=None

Signed-off-by: zhyajie <yajizhan@amd.com>
@gemini-code-assist (bot) left a comment


Code Review

This pull request effectively addresses a precision corruption bug that occurs when shared_experts_stream is None. The root cause, an in-place mutation of the hidden_states tensor by the routed expert computation before it's used by the shared experts, is correctly identified. The fix involves reordering the operations to execute the shared experts before the routed expert computation, which is a logical and direct solution. Additionally, the changes correctly extend auxiliary stream support to cuda-alike platforms like ROCm, which is crucial for the tested environment. The provided test plan is comprehensive and clearly demonstrates the fix's effectiveness. The code changes are clean, targeted, and appear to be correct. I have no further recommendations.

@zejunchen-zejun (Contributor) commented:

Hi, @gshtras @HaiShaw @tjtanaa

Could you help review this PR? It fixes a common accuracy issue affecting all MoE-style models on ROCm devices.

Thank you.

@tjtanaa (Collaborator) left a comment


LGTM

@tjtanaa (Collaborator) commented Nov 18, 2025

@zhyajie are the hidden states mutated only by the ROCm AITER fused MoE, or does the Triton fused MoE also mutate them?

@zhyajie (Contributor, Author) commented Nov 18, 2025

@zhyajie are the hidden states mutated only by the ROCm AITER fused MoE, or does the Triton fused MoE also mutate them?

Both are expected to mutate the hidden states. The precision issue I initially discovered was in the Kimi v2 thinking model, which uses the Triton fused MoE kernel.
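
For anyone reproducing this, a quick way to check whether a kernel mutates its input (a hedged sketch; kernel here stands in for either fused MoE implementation):

import torch

def mutates_input(kernel, hidden_states):
    # Snapshot the input, run the kernel, then compare: if the kernel
    # wrote its output back into hidden_states, the two tensors differ.
    snapshot = hidden_states.clone()
    kernel(hidden_states)
    return not torch.equal(snapshot, hidden_states)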

@mgoin mgoin added the bug (Something isn't working) and rocm (Related to AMD ROCm) labels Nov 18, 2025
@sighingnow (Collaborator) commented:

@zhyajie I made a similar change in #28939.

forward_impl_chunked needs a similar bugfix as well.
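
For reference, the same reordering would apply inside each iteration of the chunked path; a hedged sketch with hypothetical names (not the actual forward_impl_chunked signature):

import torch

def forward_impl_chunked(hidden_states, chunk_size, shared_experts,
                         routed_experts_kernel):
    outputs = []
    for chunk in hidden_states.split(chunk_size):
        # Within each chunk, the shared experts must still run before the
        # routed kernel can mutate the chunk in place.
        shared_out = shared_experts(chunk)
        routed_out = routed_experts_kernel(chunk)
        outputs.append(shared_out + routed_out)
    return torch.cat(outputs)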

@tjtanaa (Collaborator) commented Nov 18, 2025

@robertgshaw2-redhat @SageMoore @alexm-redhat could you please take a look at both PRs? Which one is the better solution?

@SageMoore (Contributor) left a comment


Looks reasonable to me. Thanks for the fix!

@tjtanaa tjtanaa added the ready (ONLY add when PR is ready to merge/full CI is needed) label Nov 19, 2025
@tjtanaa tjtanaa enabled auto-merge (squash) November 19, 2025 04:45
@vllmellm (Contributor) commented:

Hi, I have successfully verified the fix in this PR using the RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic model on a ROCm/MI300X environment (TP=4).
Before the fix (main branch): output was corrupted/garbled, confirming the precision issue.
After the fix (PR applied): output was correct and accurate.
The fix works for Llama-4 MoE models. Thanks for the work!

@tjtanaa (Collaborator) commented Nov 19, 2025


Thanks for validating with another model as well.

@tjtanaa tjtanaa merged commit 9d2d561 into vllm-project:main Nov 19, 2025
47 checks passed
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
…lm-project#28942)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
…lm-project#28942)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
…8942)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
…lm-project#28942)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
