
Conversation

Contributor

@zhoukezi zhoukezi commented Sep 29, 2025

Purpose

This PR addresses several issues discovered while testing MiDashengLM quantization.

  1. Add packed_modules_mapping to MiDashengLMModel: The MiDashengLMModel was missing the packed_modules_mapping attribute, which caused an error when loading models quantized with bitsandbytes. The attribute has been added to restore compatibility (see the sketch after this list).
  2. Fix Parameter Mismatch in DashengAudioTransformer: The DashengAudioTransformer was providing an incorrect prefix to DashengBlock, resulting in a parameter mismatch error when loading quantized models. This PR provides the correct prefix to resolve the loading issue.
  3. Update Audio Encoder Frontend in vLLM: The current vLLM implementation was using an outdated audio encoder frontend. It has been updated to align with the latest implementation from the Hugging Face version, ensuring consistency.
  4. Correct Audio Encoder Attention Mechanism: The attention mechanism in the vLLM audio encoder was incorrectly implemented, leading to significant deviations in the encoder's output. This has been reverted to the manual implementation from the Hugging Face version to produce the correct output.
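
For reference on item 1, here is a minimal sketch of what a packed_modules_mapping class attribute typically looks like on a vLLM model. The fused-layer keys and submodule names shown for MiDashengLMModel are assumptions for illustration and may differ from what this PR actually adds.

    import torch.nn as nn

    class MiDashengLMModel(nn.Module):
        # Hypothetical mapping: the real attribute added by this PR may use
        # different fused-layer and submodule names.
        # Each fused ("packed") projection is mapped to the original
        # checkpoint submodules it combines, so weight loaders and
        # quantization backends such as bitsandbytes can route the
        # per-module quantized weights into the fused vLLM layers.
        packed_modules_mapping = {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
            "gate_up_proj": ["gate_proj", "up_proj"],
        }

Without such a mapping, the loader cannot tell which original-checkpoint modules make up each fused projection, which is consistent with the loading error described in item 1.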

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zhoukz <me@zhoukz.com>
…lock` loads quantized model parameters correctly

Signed-off-by: zhoukz <me@zhoukz.com>
…mplementation

Signed-off-by: zhoukz <me@zhoukz.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several important bug fixes for the MiDashengLM implementation, particularly for quantized models. The changes include adding packed_modules_mapping for bitsandbytes compatibility, correcting parameter prefixes, updating the audio encoder frontend, and fixing the audio encoder attention mechanism.

My review identified a critical issue in the updated DashengAttention.forward method related to tensor reshaping when using tensor parallelism. The current implementation will likely cause errors or incorrect behavior when tp_size > 1. I've provided a detailed comment and a code suggestion to address this. The other changes appear correct and align with the goals of the PR.

Comment on lines 217 to 234
        qkv, _ = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if self.causal:
            mask_value = -torch.finfo(attn.dtype).max
            i, j = attn.shape[-2:]
            mask = torch.ones(i, j, device=q.device,
                              dtype=torch.bool).triu(j - i + 1)
            attn = attn.masked_fill(mask, mask_value)
        if mask is not None:
            mask_value = torch.finfo(attn.dtype).min
            attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
            attn = attn.masked_fill(attn_mask, mask_value)
        attn = attn.softmax(dim=-1)
        attn = torch.nan_to_num(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x, _ = self.proj(x)

critical

The tensor reshaping logic in DashengAttention.forward appears to be incorrect for tensor parallelism (tp_size > 1), which could lead to runtime errors or incorrect outputs.

  1. Incorrect reshape of qkv: On line 218, C // self.num_heads is used for the head dimension. This is only correct when tp_size is 1. For tp_size > 1, it should be self.head_dim.

  2. Incorrect reshape of attention output: On line 235, the attention output is reshaped to (B, N, C). The number of elements in (attn @ v).transpose(1, 2) is B * N * self.num_heads * self.head_dim, which equals B * N * self.embed_dim / tp_size. Reshaping this to (B, N, C) (where C is self.embed_dim) will fail if tp_size > 1. It should be reshaped to (B, N, -1) to match the partitioned input size expected by the subsequent RowParallelLinear layer.

Here is a suggested correction for this part of the forward method:

        qkv, _ = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if self.causal:
            mask_value = -torch.finfo(attn.dtype).max
            i, j = attn.shape[-2:]
            mask = torch.ones(i, j, device=q.device,
                              dtype=torch.bool).triu(j - i + 1)
            attn = attn.masked_fill(mask, mask_value)
        if mask is not None:
            mask_value = torch.finfo(attn.dtype).min
            attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
            attn = attn.masked_fill(attn_mask, mask_value)
        attn = attn.softmax(dim=-1)
        attn = torch.nan_to_num(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x, _ = self.proj(x)

@DarkLight1337
Member

Correct Audio Encoder Attention Mechanism: The attention mechanism in the vLLM audio encoder was incorrectly implemented, leading to significant deviations in the encoder's output. This has been reverted to the manual implementation from the Hugging Face version to produce the correct output.

cc @Isotr0py regarding this

…anual implementation

Signed-off-by: zhoukz <me@zhoukz.com>
Member

@Isotr0py Isotr0py left a comment


LGTM!

@Isotr0py Isotr0py enabled auto-merge (squash) September 29, 2025 09:13
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 29, 2025
@Isotr0py Isotr0py merged commit 8616300 into vllm-project:main Sep 29, 2025
50 checks passed
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…d models (#25854)

Signed-off-by: zhoukz <me@zhoukz.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>