diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 6af6faa2b296..79f4203018cf 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -274,6 +274,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ def __init__(
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.proj",
                                       disable_tp=use_data_parallel)
-
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -443,6 +428,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -627,17 +616,35 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f3f11438eeee..f1aeb99a4d37 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -63,7 +63,7 @@
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of
@@ -158,6 +158,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -170,7 +172,9 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
@@ -287,19 +291,6 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
-
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
@@ -325,10 +316,42 @@ def __init__(

         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        if current_platform.is_device_capability(
+                100) and self.attn_backend != _Backend.TORCH_SDPA:
+            # TODO(Roger/Wentao): remove this after FA
+            # or XFORMERS's issue fixed on Blackwell
+            logger.info_once("Qwen3-VL vision attention does not support "
+                             f"{self.attn_backend} backend on Blackwell now. "
+                             "Vision attention backend is set to TORCH_SDPA.")
+            self.attn_backend = _Backend.TORCH_SDPA
+
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa)
+            for layer_idx in range(vision_config.depth)
+        ])

     @property
     def dtype(self) -> torch.dtype:
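Below is a minimal, self-contained sketch of the pattern this diff applies: the vision tower resolves the attention backend (including the upstream-flash-attention fallback and the Blackwell SDPA override) once in its own constructor and passes the result down to every block, instead of each attention layer repeating the detection. The `Backend` enum, `resolve_vit_attn_backend`, and the `upstream_fa_available` / `is_blackwell` flags are illustrative stand-ins, not vLLM APIs; in the actual change they correspond to `_Backend`, `get_vit_attn_backend`, `check_upstream_fa_availability`, and `current_platform.is_device_capability(100)`.

```python
# Illustrative sketch only -- not vLLM code. It mirrors the "detect once in the
# transformer, hand the decision to every block" refactor in this diff.
from enum import Enum, auto

import torch.nn as nn


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()


def resolve_vit_attn_backend(preferred: Backend,
                             upstream_fa_available: bool,
                             is_blackwell: bool) -> tuple[Backend, bool]:
    """Pick the vision attention backend once; return (backend, use_upstream_fa)."""
    backend, use_upstream_fa = preferred, False
    # Prefer upstream flash-attn when the selected backend is not FA already.
    if backend != Backend.FLASH_ATTN and upstream_fa_available:
        backend, use_upstream_fa = Backend.FLASH_ATTN, True
    if backend not in {Backend.FLASH_ATTN, Backend.TORCH_SDPA,
                       Backend.XFORMERS, Backend.ROCM_AITER_FA}:
        raise RuntimeError(f"unsupported vision attention backend: {backend}")
    # Mirrors the Qwen3-VL fallback in the diff: force SDPA on Blackwell.
    if is_blackwell and backend != Backend.TORCH_SDPA:
        backend = Backend.TORCH_SDPA
    return backend, use_upstream_fa


class VisionBlock(nn.Module):

    def __init__(self, attn_backend: Backend, use_upstream_fa: bool) -> None:
        super().__init__()
        # The block only records the decision; it no longer re-detects it.
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa


class VisionTransformer(nn.Module):

    def __init__(self, depth: int = 4) -> None:
        super().__init__()
        # Resolve once, then pass the same values to every block.
        backend, use_upstream_fa = resolve_vit_attn_backend(
            Backend.TORCH_SDPA,
            upstream_fa_available=False,
            is_blackwell=False)
        self.blocks = nn.ModuleList([
            VisionBlock(backend, use_upstream_fa) for _ in range(depth)
        ])
```

Keeping the decision in one place means the Blackwell override added for Qwen3-VL is applied exactly once, and every block constructed from it sees the same `attn_backend` and `use_upstream_fa` values.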