diff --git a/vllm/config/model.py b/vllm/config/model.py
index b563a40eb8fc..4efecb1b4a72 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -31,7 +31,6 @@
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
-    uses_custom_attention_masks,
     uses_mrope,
 )
 from vllm.transformers_utils.gguf_utils import (
@@ -1624,10 +1623,6 @@ def uses_alibi(self) -> bool:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
 
-    @property
-    def uses_custom_attention_masks(self) -> bool:
-        return uses_custom_attention_masks(self.hf_config)
-
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index fe83c8b63b01..43c69e5e1399 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -596,7 +596,7 @@ def _process_image_input(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -644,142 +644,6 @@ def forward(
 
         return hidden_states
 
-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-
-        This is called by V1 engine's gpu_model_runner during preprocessing
-        to generate attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i]
-            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_idx, seq_len in enumerate(seq_lens):
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-
-            # Find image token positions
-            img_pos = input_token_ids == self.config.image_token_index
-
-            start_idx = end_idx
-
-            # Create a global causal mask
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention)
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Enable bidirectional attention between image tokens
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            # GGUF compatibility: config might be Gemma3TextConfig directly
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024)
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
-
-    def prepare_attn_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-        **kwargs,
-    ):
-        kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_indices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = input_token_ids == self.config.image_token_index
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            sliding_window = self.config.text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
-        return kwargs
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index ac4a71648cec..49250e071eab 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -477,17 +477,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
     return False
 
 
-def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
-    """Detect if model uses custom attention mask generation for multimodal.
-
-    Some multimodal models require custom attention masks that enable
-    bidirectional attention between image tokens while maintaining causal
-    attention for text tokens. Currently applies to Gemma3 multimodal models.
-    """
-    architectures = getattr(config, "architectures", [])
-    return "Gemma3ForConditionalGeneration" in architectures
-
-
 def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
     """
     Update kwargs for AutoConfig initialization based on model_type
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3b00085b6bb9..b54af3cc1e1a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -324,7 +324,6 @@ def __init__(
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
             model_config
         )
@@ -2351,24 +2350,6 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-
-            # Generate custom attention masks for models that require them.
-            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
-            # Check mm_features (mm_embeds is empty during decode).
-            has_mm_features = any(
-                req_state.mm_features for req_state in self.requests.values()
-            )
-            if (
-                self.uses_custom_attention_masks
-                and has_mm_features
-                and hasattr(self.model, "generate_attention_masks")
-            ):
-                mask_kwargs = self.model.generate_attention_masks(
-                    self.input_ids.gpu[:num_scheduled_tokens],
-                    self.positions.gpu[:num_scheduled_tokens],
-                    mask_dtype=self.model.dtype,
-                )
-                model_kwargs.update(mask_kwargs)
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
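For context, here is a minimal standalone sketch (not part of the patch) of the mask construction performed by the removed prepare_attn_masks / generate_attention_masks code: a causal mask is built with -inf above the diagonal, then entries where both the query and key positions hold image tokens are unmasked. The token values and image_token_id below are invented for illustration.

# Standalone sketch, not vLLM code: causal mask plus bidirectional attention
# between image tokens, mirroring the logic removed above.
import torch

def build_global_attn_mask(input_token_ids: torch.Tensor,
                           image_token_id: int,
                           mask_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    seq_len = input_token_ids.shape[0]
    # Causal mask: -inf above the diagonal, 0 on and below it.
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype)
    mask = mask.triu(diagonal=1)
    # Unmask entries where both the query and the key position are image tokens,
    # so image tokens attend to each other bidirectionally.
    img_pos = input_token_ids == image_token_id
    img_mask = torch.zeros_like(mask)
    img_mask[:, :, :, img_pos] += 1
    img_mask[:, :, img_pos, :] += 1
    return torch.where(img_mask == 2, 0.0, mask)

# Example: tokens at positions 3..5 are image tokens (id 99 here is made up);
# they can see each other even "ahead" in the sequence, text stays causal.
tokens = torch.tensor([11, 12, 13, 99, 99, 99, 14])
print(build_global_attn_mask(tokens, image_token_id=99)[0, 0])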