5 changes: 0 additions & 5 deletions vllm/config/model.py
@@ -31,7 +31,6 @@
try_get_generation_config,
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_custom_attention_masks,
uses_mrope,
)
from vllm.transformers_utils.gguf_utils import (
@@ -1624,10 +1623,6 @@ def uses_alibi(self) -> bool:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)

@property
def uses_custom_attention_masks(self) -> bool:
return uses_custom_attention_masks(self.hf_config)

@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
138 changes: 1 addition & 137 deletions vllm/model_executor/models/gemma3_mm.py
@@ -596,7 +596,7 @@ def _process_image_input(
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -644,142 +644,6 @@ def forward(

return hidden_states

def generate_attention_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
) -> dict[str, Any]:
"""Generate custom attention masks for Gemma3 multimodal inputs.

This is called by V1 engine's gpu_model_runner during preprocessing
to generate attention masks that allow bidirectional attention between
image tokens while maintaining causal attention for text.
"""
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i]
end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
seq_lens.append(end_idx - start_idx)

global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_idx, seq_len in enumerate(seq_lens):
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]

# Find image token positions
img_pos = input_token_ids == self.config.image_token_index

start_idx = end_idx

# Create a global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0 (causal attention)
global_attn_mask = global_attn_mask.triu(diagonal=1)

# Enable bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

# GGUF compatibility: config might be Gemma3TextConfig directly
text_config = getattr(self.config, "text_config", self.config)
sliding_window = text_config.sliding_window
if sliding_window is not None:
# Create a local causal mask with sliding window (1024)
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)

return {
"has_images": True,
"seq_lens": seq_lens,
"global_attn_masks": global_attn_masks,
"local_attn_masks": local_attn_masks,
}

def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
):
kwargs["has_images"] = True
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i].item()
if i < num_seqs - 1:
end_idx = start_indices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens

global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create a global causal mask.
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)

# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = input_token_ids == self.config.image_token_index
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

sliding_window = self.config.text_config.sliding_window
if sliding_window is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs

def compute_logits(
self,
hidden_states: torch.Tensor,
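For reference, here is a minimal standalone sketch of the mask logic that the removed generate_attention_masks() / prepare_attn_masks() methods implemented: a causal mask with bidirectional attention between image tokens, plus an optional sliding-window variant. The image token id (99) and the window size used below are placeholders for this example, not Gemma3's actual values.

```python
import torch

IMAGE_TOKEN_ID = 99  # placeholder for config.image_token_index


def build_gemma3_style_masks(input_token_ids, sliding_window, mask_dtype=torch.float32):
    seq_len = input_token_ids.shape[0]
    img_pos = input_token_ids == IMAGE_TOKEN_ID

    # Global causal mask: 0 on/below the diagonal, -inf above it.
    global_attn_mask = torch.full(
        (1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype
    ).triu(diagonal=1)

    # Query/key pairs where both tokens are image tokens sum to 2 and get
    # unmasked, giving bidirectional attention within the image.
    img_mask = torch.zeros_like(global_attn_mask)
    img_mask[:, :, :, img_pos] += 1
    img_mask[:, :, img_pos, :] += 1
    global_attn_mask = torch.where(img_mask == 2, 0.0, global_attn_mask)

    local_attn_mask = None
    if sliding_window is not None:
        # Local variant: keep the global mask inside the window; mask out keys
        # more than `sliding_window` positions in the past.
        too_far_back = torch.tril(
            torch.ones_like(global_attn_mask), diagonal=-sliding_window
        )
        local_attn_mask = torch.where(
            too_far_back == 0, global_attn_mask, float("-inf")
        )
    return global_attn_mask, local_attn_mask


tokens = torch.tensor([5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7, 8])
global_mask, local_mask = build_gemma3_style_masks(tokens, sliding_window=2)
```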
11 changes: 0 additions & 11 deletions vllm/transformers_utils/config.py
@@ -477,17 +477,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
return False


def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
"""Detect if model uses custom attention mask generation for multimodal.

Some multimodal models require custom attention masks that enable
bidirectional attention between image tokens while maintaining causal
attention for text tokens. Currently applies to Gemma3 multimodal models.
"""
architectures = getattr(config, "architectures", [])
return "Gemma3ForConditionalGeneration" in architectures


def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
"""
Update kwargs for AutoConfig initialization based on model_type
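A tiny usage sketch of the removed helper (illustrative only; the config object below is constructed by hand rather than loaded from a checkpoint): the detection is nothing more than an architecture-name lookup on the Hugging Face config.

```python
from transformers import PretrainedConfig


def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
    # Architecture-name lookup, as in the removed helper above.
    architectures = getattr(config, "architectures", None) or []
    return "Gemma3ForConditionalGeneration" in architectures


cfg = PretrainedConfig(architectures=["Gemma3ForConditionalGeneration"])
print(uses_custom_attention_masks(cfg))  # True
```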
19 changes: 0 additions & 19 deletions vllm/v1/worker/gpu_model_runner.py
@@ -324,7 +324,6 @@ def __init__(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -2351,24 +2350,6 @@ def _preprocess(
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}

# Generate custom attention masks for models that require them.
# V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
# Check mm_features (mm_embeds is empty during decode).
has_mm_features = any(
req_state.mm_features for req_state in self.requests.values()
)
if (
self.uses_custom_attention_masks
and has_mm_features
and hasattr(self.model, "generate_attention_masks")
):
mask_kwargs = self.model.generate_attention_masks(
self.input_ids.gpu[:num_scheduled_tokens],
self.positions.gpu[:num_scheduled_tokens],
mask_dtype=self.model.dtype,
)
model_kwargs.update(mask_kwargs)
elif self.enable_prompt_embeds and is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
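For context, a worked example of the sequence-boundary trick the removed call path relied on: the runner passed the flattened input_ids and positions slices, and the model split them into per-request sequences wherever the position id resets to 0 (the "HACK" flagged in the removed comments). The tensor below is made up for illustration.

```python
import torch

# Two packed requests of lengths 4 and 3 in one flattened batch.
positions = torch.tensor([0, 1, 2, 3, 0, 1, 2])

start_indices = (positions == 0).nonzero().flatten().tolist()  # [0, 4]
seq_lens = []
for i, start in enumerate(start_indices):
    end = start_indices[i + 1] if i + 1 < len(start_indices) else len(positions)
    seq_lens.append(end - start)

print(seq_lens)  # [4, 3]
```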