
Conversation

@Isotr0py Isotr0py (Member) commented Nov 19, 2025

Purpose

Test Plan

pytest -s -v tests/models/multimodal/generation/test_common.py -k gemma3-test

Test Result

The failing gemma3 test should pass now.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request partially reverts changes related to custom attention mask generation for Gemma3 multimodal models, which is intended to fix a failing test. The changes involve removing the uses_custom_attention_masks logic from ModelConfig, GpuModelRunner, and transformers_utils.config. The custom mask generation methods generate_attention_masks and prepare_attn_masks are also removed from gemma3_mm.py. Additionally, the get_multimodal_embeddings method in gemma3_mm.py has been renamed to embed_multimodal to align with the current interface, which is a good refactoring. The changes are clean, consistent, and effectively address the stated purpose of the PR. I find no issues with this change.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

```python
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for the text-only
            # inputs are not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output
        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.
        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
```

P1: Restore custom masks when Gemma3 sees images

Gemma3’s attention module only enables the bidirectional mask for image tokens when has_images is passed in kwargs, and it expects accompanying seq_lens and mask tensors (global_attn_masks/local_attn_masks). After this change the GPU model runner no longer sets has_images or builds those masks when preparing multimodal batches, so the check in Gemma3Attention.forward is always false and the code falls back to the standard causal attention. That causes all multimodal Gemma3 requests to run with purely causal masks, preventing image patches from attending to each other as the model definition requires. Any inference that includes images will therefore produce incorrect attention patterns.
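
For illustration, here is a minimal, self-contained sketch of the mask pattern described above: tokens that belong to the same image attend to each other bidirectionally, while everything else stays causal. The helper name `build_gemma3_attn_mask` and the `image_ids` encoding are assumptions made for this example, not the actual vLLM implementation (which also builds separate global and sliding-window masks per sequence):

```python
import torch


def build_gemma3_attn_mask(image_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: build a boolean attention mask for one sequence.

    image_ids[i] == -1 marks a text token; any other value is the index of the
    image that token i belongs to. Text tokens attend causally; tokens of the
    same image may attend to each other in both directions.
    """
    seq_len = image_ids.shape[0]
    # Start from a standard causal mask (True = may attend).
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Allow bidirectional attention within each image's token span.
    same_image = (image_ids.unsqueeze(0) == image_ids.unsqueeze(1)) & (
        image_ids.unsqueeze(0) != -1
    )
    return mask | same_image


# Example: 2 text tokens, 3 tokens of image 0, then 1 more text token.
ids = torch.tensor([-1, -1, 0, 0, 0, -1])
print(build_gemma3_attn_mask(ids).int())
```

Without the `has_images` path, the `same_image` part of such a mask is never applied, which is the incorrect-attention-pattern concern raised in this review.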


@DarkLight1337 DarkLight1337 (Member) commented

cc @lucianommartins

@DarkLight1337 DarkLight1337 added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) on Nov 19, 2025
@lucianommartins lucianommartins (Contributor) commented Nov 19, 2025

hey @Isotr0py,

This breaks the mm support for Gemma3 GGUF. Also, the test96 failure comes from different output tokens between the Transformers and vLLM V1 engines in a very specific non-GGUF pan-and-scan scenario - test95 also uses pan-and-scan and works fine.

Are you sure it is the best way out?

As an alternative, couldn't we suspend test96 (it is largely a duplicate of test95) while we investigate why the generation differs? I volunteer for that investigation.

@Isotr0py Isotr0py (Member, Author) commented

> This breaks the mm support for Gemma3 GGUF.

Not really, I verified that the GGUF e2e multimodal test still passed when submitting this PR:

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:87: UserWarning: Test0:
  Matched tokens:       [108, 2094]
  original:     '\n\nThis image captures a vibrant street scene in a Chinatown district, likely in Australia. The focal point is a grand, ornate Chinese gate, painted in vibrant red'  {2471: Logprob(logprob=-1.2075515985488892, rank=2, decoded_token=' image'), 10807: Logprob(logprob=-1.2075515985488892, rank=1, decoded_token=' photograph'), 4429: Logprob(logprob=-1.9575515985488892, rank=3, decoded_token=' photo'), 563: Logprob(logprob=-1.9575515985488892, rank=4, decoded_token=' is'), 28239: Logprob(logprob=-2.4575514793395996, rank=5, decoded_token=' vibrant'), 5777: Logprob(logprob=-4.7075514793396, rank=6, decoded_token=' wide'), 7804: Logprob(logprob=-5.3325514793396, rank=7, decoded_token=' bright'), 2258: Logprob(logprob=-6.2075514793396, rank=8, decoded_token=' color'), 11690: Logprob(logprob=-6.2075514793396, rank=9, decoded_token=' outdoor'), 5719: Logprob(logprob=-6.3325514793396, rank=10, decoded_token=' shot')}
  gguf: '\n\nThis photo captures a vibrant street scene in Chinatown, likely in Sydney, Australia. The focal point is a traditional Chinese gate, adorned with red and gold decorations'        {4429: Logprob(logprob=-1.0050928592681885, rank=1, decoded_token=' photo'), 10807: Logprob(logprob=-1.2550928592681885, rank=2, decoded_token=' photograph'), 2471: Logprob(logprob=-1.5050928592681885, rank=3, decoded_token=' image'), 563: Logprob(logprob=-2.7550928592681885, rank=4, decoded_token=' is'), 28239: Logprob(logprob=-3.3800928592681885, rank=5, decoded_token=' vibrant'), 11690: Logprob(logprob=-5.005092620849609, rank=6, decoded_token=' outdoor'), 5777: Logprob(logprob=-5.505092620849609, rank=7, decoded_token=' wide'), 6083: Logprob(logprob=-5.755092620849609, rank=8, decoded_token=' picture'), 5719: Logprob(logprob=-6.005092620849609, rank=9, decoded_token=' shot'), 2258: Logprob(logprob=-6.255092620849609, rank=10, decoded_token=' color')}
    check_logprobs_close(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================== 1 passed, 13 warnings in 195.93s (0:03:15) ===============================================================================

Therefore, given that the incorrectly implemented custom mask has broken existing CI, we should revert it to make CI pass again rather than hiding the failure by suspending the test.

@Isotr0py Isotr0py merged commit 64192d5 into vllm-project:main Nov 20, 2025
53 checks passed
@Isotr0py Isotr0py deleted the revert-custom-attn branch November 20, 2025 05:23
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: LuminolT <lumischen01@gmail.com>
lucianommartins added a commit to lucianommartins/my-vllm that referenced this pull request Nov 21, 2025
Restores custom attention mask generation for Gemma3 GGUF multimodal models
that was partially reverted in vllm-project#28995. Implements robust GGUF-only guards to
ensure the feature only applies to GGUF models and does not affect HF models.

Changes:
- Add uses_custom_attention_masks() utility with GGUF file format check
- Add uses_custom_attention_masks property to ModelConfig
- Initialize uses_custom_attention_masks in GPUModelRunner
- Restore generate_attention_masks() method to Gemma3ForConditionalGeneration
- Implement 3-layer defense-in-depth guard mechanism

The implementation uses check_gguf_file() to guarantee that custom attention
mask logic only triggers for GGUF files, preventing the issue that caused
the original revert where HF models incorrectly triggered the custom logic.

Tested with GGUF models (1B, 4B, 270M) for both text-only and multimodal
inference. HF model compatibility verified via pytest multimodal test suite.

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
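
For reference, a rough sketch of the GGUF-only guard described in the commit message above. Per the commit, vLLM's own `check_gguf_file()` helper from `transformers_utils` performs the file check; the standalone check below and the architecture list are assumptions for illustration, not the actual implementation:

```python
from pathlib import Path

# Hypothetical architecture allow-list for the custom-mask path (assumption).
_CUSTOM_MASK_ARCHS = {"Gemma3ForConditionalGeneration"}


def is_gguf_model(model: str) -> bool:
    """Return True if `model` points at a local GGUF file (suffix or magic bytes)."""
    path = Path(model)
    if not path.is_file():
        return False
    if path.suffix.lower() == ".gguf":
        return True
    with path.open("rb") as f:
        return f.read(4) == b"GGUF"  # GGUF files start with this magic.


def uses_custom_attention_masks(model: str, architectures: list[str]) -> bool:
    """Guard: only GGUF Gemma3 multimodal models take the custom-mask path."""
    return is_gguf_model(model) and any(
        arch in _CUSTOM_MASK_ARCHS for arch in architectures
    )
```

The point of the guard is that HF checkpoints never satisfy the GGUF check, so they cannot trigger the custom mask logic that caused the original revert.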
Labels: ready, v1