8 changes: 8 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -34,6 +34,14 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False

# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Support still needs to be added for the remaining backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False

@staticmethod
@abstractmethod
def get_name() -> str:
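For backends that have not opted in, the flag stays `False` and query quantization remains inside the backend. A minimal sketch of how a backend would opt in (the class name is hypothetical, and the required abstract methods such as `get_name()` are omitted):

```python
from vllm.attention.backends.abstract import AttentionBackend


class MyQueryQuantBackend(AttentionBackend):
    """Hypothetical backend, shown only to illustrate the new flag."""

    accept_output_buffer: bool = True
    # Opting in shifts query quantization from the backend into the
    # attention layer, so forward() must expect an fp8 query whenever
    # the KV cache dtype is fp8.
    supports_quant_query_input: bool = True
```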
23 changes: 22 additions & 1 deletion vllm/attention/layer.py
@@ -22,7 +22,10 @@
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
"This may be caused by insufficient memory to allocate "
"kv cache.") from e

# For attention backends that support receiving a pre-quantized query.
self.query_quant = None
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)

def forward(
self,
query: torch.Tensor,
@@ -270,11 +280,22 @@ def forward(
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)

output_dtype = query.dtype
if self.query_quant is not None:
# Quantizing the query with plain torch ops lets torch.compile
# fuse the quantization into the preceding ops, which reduces
# overhead during decoding. Otherwise the query is quantized by
# the backend with custom ops, which cannot be fused and add
# decoding overhead.
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)

if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=query.dtype,
dtype=output_dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
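A minimal usage sketch of the per-tensor static `QuantFP8` op added above, following the same construct-once, apply-in-forward pattern as the layer; the tensor shapes, device, and scale value are placeholders rather than vLLM's real ones:

```python
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)

# Placeholder query and scale; in the layer, the scale is the pre-computed
# self._q_scale and the query comes from the model's QKV projection.
query = torch.randn(16, 32 * 128, dtype=torch.bfloat16, device="cuda")
q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# Built once in __init__ ...
query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
# ... and applied in forward(); returns the fp8 query and the scale used.
query_fp8, _ = query_quant(query, q_scale)
```

Per the comment in `forward()` above, expressing this quantization through torch operations is what allows torch.compile to fuse it with the ops that produce the query, which is where the decode-time saving comes from.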
9 changes: 2 additions & 7 deletions vllm/v1/attention/backends/flash_attn.py
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
@@ -38,6 +37,7 @@
class FlashAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True
supports_quant_query_input: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,16 +506,11 @@ def forward(
)

if self.kv_cache_dtype.startswith("fp8"):
# Queries have already been quantized in the attention layer.
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
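With the custom-op quantization of the query removed, the backend's fp8 path reduces to reinterpreting its cache buffers. A standalone sketch of that reinterpretation, where the cache shape is a placeholder and the uint8 allocation is an assumption about how the fp8 cache buffer is stored:

```python
import torch

# Placeholder KV cache buffers; an fp8 cache is assumed to be allocated as
# raw bytes and reinterpreted as the fp8 dtype flash-attention expects.
fp8_dtype = torch.float8_e4m3fn
key_cache = torch.zeros(4, 16, 8, 64, dtype=torch.uint8)
value_cache = torch.zeros(4, 16, 8, 64, dtype=torch.uint8)

# torch.Tensor.view(dtype) reinterprets the underlying storage without
# copying; both dtypes are one byte wide, so the shape is unchanged.
key_cache = key_cache.view(fp8_dtype)
value_cache = value_cache.view(fp8_dtype)
assert key_cache.dtype == fp8_dtype and value_cache.dtype == fp8_dtype
```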