Conversation

toncao
Contributor

@toncao toncao commented Sep 16, 2025

Purpose

qwen3-next does not pass prefixes when it loads the shared_expert modules, so those modules cannot be matched against the ignore list in the quantization config, which ultimately leads to errors such as:

KeyError: 'layers.31.mlp.shared_expert.down_proj.weight'

Please visit this discussion and model for more information.

If the shared_expert modules are not added to the ignore list (i.e., they are quantized, which deteriorates model outputs), no loading errors occur.
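
For context, a minimal self-contained sketch (toy code, not the actual vLLM implementation) of why the missing prefix breaks ignore-list matching and ends in the KeyError above:

```python
# Toy illustration (not vLLM code): the quantization config carries an
# `ignore` list of fully qualified module names that must stay unquantized.
# If the shared expert is built without its parent prefix, its submodules are
# registered under bare names and never match the ignore list, so the loader
# later expects quantized weights that the checkpoint does not contain.

ignore_list = {
    "layers.31.mlp.shared_expert.gate_up_proj",
    "layers.31.mlp.shared_expert.down_proj",
}

def is_ignored(module_name: str) -> bool:
    return module_name in ignore_list

# Before the fix: shared_expert built with prefix="" -> bare submodule name.
print(is_ignored("down_proj"))                              # False: treated as quantized
# After the fix: prefix=f"{prefix}.shared_expert" is threaded through.
print(is_ignored("layers.31.mlp.shared_expert.down_proj"))  # True: left unquantized as intended
```

The PR addresses this by threading a prefix through the shared_expert (and, for qwen2moe, the MLP) so the submodules receive fully qualified names like the ones above.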


Essential Elements of an Effective PR Description Checklist
  • [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: toncao <cpatonn@gmail.com>
@toncao toncao requested a review from sighingnow as a code owner September 16, 2025 11:14

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the qwen Related to Qwen models label Sep 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a KeyError during the loading of quantized models for qwen3-next and qwen2moe by adding necessary prefixes to shared_expert and mlp modules. The changes correctly add a prefix parameter to Qwen2MoeMLP and use it to construct proper names for its submodules. The fix for qwen3-next appears complete, as it passes the correct prefix to its shared_expert. However, the fix for qwen2moe seems incomplete, as the new prefix parameter is not used at the call sites within qwen2_moe.py, which could lead to the same loading errors. I've left a specific comment on this.

hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
Contributor


critical

While adding the prefix parameter here is a good step, it seems the fix for qwen2moe is incomplete. This Qwen2MoeMLP is instantiated in two places in this file, but the prefix is not passed in either case:

  1. In Qwen2MoeSparseMoeBlock for the shared_expert (L131).
  2. In Qwen2MoeDecoderLayer for the non-MoE mlp (L302).

Without passing the prefix, the submodules inside Qwen2MoeMLP won't have the correct names, which will likely lead to the same KeyError for quantized qwen2moe models. This seems critical to fix for the PR to fully achieve its goal.
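
To make the reviewer's point concrete, here is a small self-contained sketch (toy classes with simplified names and argument lists, not the real vLLM layers) of what threading the prefix through both call sites achieves:

```python
# Toy stand-ins for the real layers: all that matters here is the name
# (prefix) each submodule ends up with, since that is what gets matched
# against the quantization config.

class ToyLinear:
    def __init__(self, prefix: str):
        self.prefix = prefix

class ToyQwen2MoeMLP:
    def __init__(self, prefix: str = ""):
        # Mirrors the PR: the MLP forwards its own prefix to its projections.
        self.gate_up_proj = ToyLinear(prefix=f"{prefix}.gate_up_proj")
        self.down_proj = ToyLinear(prefix=f"{prefix}.down_proj")

# Call site 1: shared expert inside the sparse MoE block.
shared_expert = ToyQwen2MoeMLP(prefix="layers.0.mlp.shared_expert")
# Call site 2: dense MLP in a non-MoE decoder layer.
dense_mlp = ToyQwen2MoeMLP(prefix="layers.1.mlp")

print(shared_expert.down_proj.prefix)  # layers.0.mlp.shared_expert.down_proj
print(dense_mlp.down_proj.prefix)      # layers.1.mlp.down_proj
# Without passing `prefix` at the call sites, both would print ".down_proj".
```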

Collaborator

@jeejeelee jeejeelee Sep 16, 2025


@toncao could you please take a look at this; the bot is right

Contributor Author


Yes, absolutely! I have updated the PR to reflect this in the most recent commit ed5e2a29cdde3a34dfaa3843d52a8f4d2e01543e.

Collaborator

@jeejeelee jeejeelee left a comment


Thank you

@sighingnow sighingnow changed the title [Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models [WIP][Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models Sep 16, 2025
@sighingnow
Collaborator

sighingnow commented Sep 16, 2025

Work in progress, a further patch for varlen verifying batch is on the way.

@jeejeelee Apologies that I commented on the wrong thread. Could you please convert this PR to "ready to review" again? Thanks.

@jeejeelee jeejeelee marked this pull request as draft September 16, 2025 13:10
Signed-off-by: toncao <cpatonn@gmail.com>
@sighingnow sighingnow changed the title [WIP][Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models [Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models Sep 16, 2025
@sighingnow sighingnow marked this pull request as ready for review September 16, 2025 16:06
@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 17, 2025
@DarkLight1337 DarkLight1337 merged commit 027d37d into vllm-project:main Sep 18, 2025
48 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 18, 2025
…litPR into model_register

* 'model_register' of https://github.com/dsxsteven/vllm_splitPR: (138 commits)
  Retrieve `sliding_window` from text config in Gemma3 MM (vllm-project#25085)
  [Docs] Fix API Reference (vllm-project#25140)
  [Kernel] Better inf handling for grouped topk cu (vllm-project#24886)
  [CLI] Use streaming in CLI chat and completion commands (vllm-project#23769)
  [benchmark] add peak throughput metrics and plot (vllm-project#23867)
  [Spec Decode] Efficient padded speculation (vllm-project#24539)
  [V0 Deprecation] Remove more V0 tests (vllm-project#25117)
  [EPLB] Add EPLB support for hunyuan_v1 (vllm-project#23078)
  [XPU] Whisper model support on XPU Platform (vllm-project#25123)
  Mark prompt logprobs as incompatible with prompt embeds at API level (vllm-project#25077)
  [Model] enable data parallel for InternVL vision encoder (vllm-project#23909)
  [Kernels] Overlap shared experts with combine instead of dispatch (vllm-project#24254)
  [Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)
  [Core][MM] Cleanup `MultiModalCache` (vllm-project#25006)
  [Docs] Clean up the contributing README (vllm-project#25099)
  [MM Encoder] Apply DP ViT for Qwen3-VL model series (vllm-project#24955)
  [Kernels] Enable DeepGEMM by default (vllm-project#24462)
  [V0 Deprecation] Skip PP test (vllm-project#25128)
  [V0 Deprecation] Remove misc V0 tests (vllm-project#25118)
  [V0 Deprecation] Remove V0 Tracing & Metrics tests (vllm-project#25115)
  ...
@halexan

halexan commented Sep 18, 2025

I built vLLM from the latest source:

Successfully installed vllm-0.10.2rc3.dev215+gc9ff9e6f0.precompiled

The error still exists:

KeyError: 'layers.11.mlp.shared_expert_gate.weight_scale'

Launch parameters:

python3 -m vllm.entrypoints.openai.api_server --served-model-name qwen3-next-80b-a3b-instruct \
    --model /data/model-cache/qwen3-80b-fp8-dynamic/ \
    --tensor-parallel-size 2

Using model nm-testing/qwen3-80b-fp8-dynamic.

My hardware is 2 x A100.

@shanjiaz

shanjiaz commented Sep 19, 2025

> I built vLLM from the latest source:
>
> Successfully installed vllm-0.10.2rc3.dev215+gc9ff9e6f0.precompiled
>
> The error still exists:
>
> KeyError: 'layers.11.mlp.shared_expert_gate.weight_scale'
>
> Launch parameters:
>
> python3 -m vllm.entrypoints.openai.api_server --served-model-name qwen3-next-80b-a3b-instruct \
>     --model /data/model-cache/qwen3-80b-fp8-dynamic/ \
>     --tensor-parallel-size 2
>
> Using model nm-testing/qwen3-80b-fp8-dynamic.
>
> My hardware is 2 x A100.

@halexan Thanks for trying the new model! Can you try using this model on H100 machines? I don't think A100 would work. Thanks!

@toncao
Contributor Author

toncao commented Sep 19, 2025

I might be wrong, but I think this is because shared_expert_gate is instantiated as a plain torch.nn.Linear, which does not take a quant_config. As a result, the quantization config from llm-compressor/compressed-tensors is not passed through to vLLM, so vLLM disregards shared_expert_gate.weight_scale and does not load it from the quantized model.

Edit:
Usually, gate params such as shared_expert_gate and mlp.gate should have been ignored during the quantization process, in which case the KeyError does not occur.
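
To illustrate the point above, a quick sketch (requires torch; the layer sizes are made up) showing that a plain torch.nn.Linear exposes no weight_scale parameter, so a quantized checkpoint entry for shared_expert_gate.weight_scale has nowhere to load into:

```python
import torch

# A plain nn.Linear mirrors how shared_expert_gate is created; it only owns a
# `weight` (and optionally a `bias`) parameter, never a `weight_scale`.
gate = torch.nn.Linear(2048, 1, bias=False)
print([name for name, _ in gate.named_parameters()])  # ['weight']

# Keys a quantized checkpoint might ship for this module vs. what the model has:
checkpoint_keys = [
    "layers.11.mlp.shared_expert_gate.weight",
    "layers.11.mlp.shared_expert_gate.weight_scale",
]
model_keys = {f"layers.11.mlp.shared_expert_gate.{n}" for n, _ in gate.named_parameters()}
print([k for k in checkpoint_keys if k not in model_keys])
# ['layers.11.mlp.shared_expert_gate.weight_scale'] -> the reported KeyError
```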

debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
…mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)

Signed-off-by: toncao <cpatonn@gmail.com>
Co-authored-by: toncao <cpatonn@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)

Signed-off-by: toncao <cpatonn@gmail.com>
Co-authored-by: toncao <cpatonn@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)

Signed-off-by: toncao <cpatonn@gmail.com>
Co-authored-by: toncao <cpatonn@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Labels
qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed)
7 participants