diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 1b392cd7c88d..02b4492485f5 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -34,6 +34,14 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
+    # Whether this backend supports receiving pre-quantized query input.
+    # If True, the attention layer will handle query quantization instead
+    # of the backend, allowing torch.compile to fuse quantization with
+    # previous operations.
+    # Needs to be worked through for all backends
+    # https://github.com/vllm-project/vllm/issues/25584
+    supports_quant_query_input: bool = False
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index baa83e29bdd0..17281c89516d 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -22,7 +22,10 @@
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e
 
+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
@@ -270,11 +280,22 @@ def forward(
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+        output_dtype = query.dtype
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
+            query, _ = self.query_quant(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
             output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a2e18f970bec..7a50bb5d3134 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 
-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
@@ -38,6 +37,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
+    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,16 +506,11 @@ def forward(
         )
 
         if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                 self.kv_cache_dtype)
             key_cache = key_cache.view(dtype)
             value_cache = value_cache.view(dtype)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
 
         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
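
The snippet below is a minimal, self-contained sketch (not vLLM's actual QuantFP8 implementation) of the idea this diff relies on: static per-tensor FP8 quantization of the query written with plain torch operations, which torch.compile can fuse into the preceding projection instead of breaking the graph at a custom quantization op. The function name, tensor shapes, and the example scale value are illustrative assumptions.

# Sketch only: plain-torch static per-tensor FP8 quantization of the query.
# Names and values here are illustrative, not vLLM's QuantFP8.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quantize_query_per_tensor(query: torch.Tensor,
                              q_scale: torch.Tensor) -> torch.Tensor:
    # divide by the precomputed (static) scale, clamp into the fp8 range,
    # then cast to float8_e4m3fn -- all simple ops that torch.compile can fuse
    q = (query / q_scale).clamp(min=-FP8_MAX, max=FP8_MAX)
    return q.to(torch.float8_e4m3fn)

# usage: the attention layer quantizes here and passes the fp8 query on to a
# backend that advertises supports_quant_query_input = True
query = torch.randn(4, 8, 128, dtype=torch.bfloat16)
q_scale = torch.tensor(0.05)  # stands in for layer._q_scale
query_fp8 = quantize_query_per_tensor(query, q_scale)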