From bfd2cd6a33a97f5fcaf4cabb6b0033e48c6bc701 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Mon, 15 Sep 2025 08:27:55 +0000
Subject: [PATCH 01/13] add env flag to make query quantization fusable

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py                  |  4 ++++
 vllm/envs.py                             |  7 +++++++
 vllm/v1/attention/backends/flash_attn.py | 14 ++++++++------
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index baa83e29bdd0..85d84357f498 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -280,6 +280,10 @@ def forward(
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.
+            if envs.VLLM_FUSE_QUERY_QUANT:
+                assert self._q_scale.numel() == 1
+                query = (query / self._q_scale).to(torch.float8_e4m3fn)
+
             if not self.use_mla:
                 # Reshape the query, key, and value tensors.
                 # NOTE(woosuk): We do this outside the custom op to minimize the
diff --git a/vllm/envs.py b/vllm/envs.py
index 689428ec5910..0b95d4aa7764 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -160,6 +160,7 @@
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
+    VLLM_FUSE_QUERY_QUANT: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
     VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None
@@ -1256,6 +1257,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_SLEEP_WHEN_IDLE":
     lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),

+    # Fuse query quantization into the attention layer instead of
+    # doing it in the attention backend. Then torch.compile can
+    # fuse it into previous ops and reduce overhead.
+    "VLLM_FUSE_QUERY_QUANT":
+    lambda: bool(int(os.getenv("VLLM_FUSE_QUERY_QUANT", "0"))),
+
     # Control the max chunk bytes (in MB) for the rpc message queue.
     # Object larger than this threshold will be broadcast to worker
     # processes via zmq.
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a2e18f970bec..f70ee0e5f56d 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -510,12 +511,13 @@ def forward(
                 self.kv_cache_dtype)
             key_cache = key_cache.view(dtype)
             value_cache = value_cache.view(dtype)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
+            if not envs.VLLM_FUSE_QUERY_QUANT:
+                num_tokens, num_heads, head_size = query.shape
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale)
+                query = query.reshape((num_tokens, num_heads, head_size))

         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc

From f150176eda1687e7077e218217217a202f45a4b8 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Tue, 16 Sep 2025 19:10:22 +0000
Subject: [PATCH 02/13] fix to only apply to fp8 layers

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 85d84357f498..008e52465e54 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -280,7 +280,8 @@ def forward(
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.
-            if envs.VLLM_FUSE_QUERY_QUANT:
+            if (envs.VLLM_FUSE_QUERY_QUANT
+                    and self.kv_cache_dtype.startswith("fp8")):
                 assert self._q_scale.numel() == 1
                 query = (query / self._q_scale).to(torch.float8_e4m3fn)

From 8a36973c938e69768b4e40d27d580ae17b6a9feb Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Thu, 18 Sep 2025 09:02:14 +0000
Subject: [PATCH 03/13] add e5m2 handling

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 008e52465e54..9d760949fdcf 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -277,14 +277,26 @@ def forward(
                                  dtype=query.dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
+
+            if envs.VLLM_FUSE_QUERY_QUANT:
+                # quantizing with a simple torch operation enables
+                # torch.compile to fuse this into previous ops
+                # which reduces overheads during decoding.
+                # Otherwise queries are quantized using custom ops
+                # which causes decoding overheads
+                assert self._q_scale.numel() == 1
+                if self.kv_cache_dtype in ["fp8", "fp8_e4m3"]:
+                    query = (query / self._q_scale).to(torch.float8_e4m3fn)
+                elif self.kv_cache_dtype == "fp8_e5m2":
+                    query = (query / self._q_scale).to(torch.float8_e5m2)
+                else:
+                    raise NotImplementedError(
+                        "VLLM_FUSE_QUERY_QUANT only supported for fp8_e4m3 "
+                        "and fp8_e5m2")
+
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.
-            if (envs.VLLM_FUSE_QUERY_QUANT
-                    and self.kv_cache_dtype.startswith("fp8")):
-                assert self._q_scale.numel() == 1
-                query = (query / self._q_scale).to(torch.float8_e4m3fn)
-
             if not self.use_mla:
                 # Reshape the query, key, and value tensors.
                 # NOTE(woosuk): We do this outside the custom op to minimize the

From 7ea71ef7ff2215cb7085d85baf88e0240b2bd234 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Thu, 18 Sep 2025 09:13:45 +0000
Subject: [PATCH 04/13] update variable description

Signed-off-by: Jonas Kuebler
---
 vllm/envs.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 0b95d4aa7764..3d76fd34b18f 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1257,9 +1257,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_SLEEP_WHEN_IDLE":
     lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),

-    # Fuse query quantization into the attention layer instead of
-    # doing it in the attention backend. Then torch.compile can
-    # fuse it into previous ops and reduce overhead.
+    # Quantize the query in the attention layer with a simple
+    # pytorch operation instead of a custom op in the attention backend.
+    # Then torch.compile can fuse it and reduce overhead.
+    # Only relevant for quantized attention w/ FA3
     "VLLM_FUSE_QUERY_QUANT":
     lambda: bool(int(os.getenv("VLLM_FUSE_QUERY_QUANT", "0"))),

@@ -1522,6 +1523,7 @@ def compute_hash() -> str:
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
         "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
+        "VLLM_FUSE_QUERY_QUANT",
         "VLLM_ROCM_USE_AITER",
         "VLLM_ROCM_USE_AITER_PAGED_ATTN",
         "VLLM_ROCM_USE_AITER_LINEAR",

From 2bf80a1eee2687c7efb03b9f90c23c483d392a53 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Tue, 23 Sep 2025 06:54:33 +0000
Subject: [PATCH 05/13] fix bug. don't try to quantize auto layers

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9d760949fdcf..bc687997f424 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -278,7 +278,7 @@ def forward(
                                  device=query.device)
             hidden_size = output_shape[-1]

-            if envs.VLLM_FUSE_QUERY_QUANT:
+            if envs.VLLM_FUSE_QUERY_QUANT and self.kv_cache_dtype != "auto":
                 # quantizing with a simple torch operation enables
                 # torch.compile to fuse this into previous ops
                 # which reduces overheads during decoding.
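Up to this point the fused path gated by VLLM_FUSE_QUERY_QUANT is simply a plain-PyTorch per-tensor quantization: divide the query by the static scale and cast to an fp8 dtype, instead of calling the ops.scaled_fp8_quant custom op inside the FlashAttention backend. A minimal sketch of that idea (the helper name is hypothetical, and the clamp to the fp8 range is an extra safety step not present in the patches above):

    import torch

    def fused_query_quant_sketch(query: torch.Tensor,
                                 q_scale: torch.Tensor) -> torch.Tensor:
        # Per-tensor static scale, mirroring `assert self._q_scale.numel() == 1`.
        assert q_scale.numel() == 1
        finfo = torch.finfo(torch.float8_e4m3fn)
        # Ordinary tensor ops: divide, clamp, cast. Because this is plain
        # PyTorch rather than an opaque custom op, torch.compile can fuse it
        # into the preceding projection instead of launching a separate kernel.
        scaled = (query / q_scale).clamp(min=finfo.min, max=finfo.max)
        return scaled.to(torch.float8_e4m3fn)

This is the motivation stated in the env-var description in vllm/envs.py: the custom-op path is opaque to torch.compile, while the expression above is traceable and can be folded into previous ops, which mainly matters for decode-time overhead.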
From 296050a1732663594f3102bbe8c9cad138666d44 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Tue, 23 Sep 2025 07:14:58 +0000
Subject: [PATCH 06/13] fix pre-commit

Signed-off-by: Jonas Kuebler
---
 vllm/v1/attention/backends/flash_attn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f70ee0e5f56d..2cdae97e2fe8 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch

-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

From b30425b0ee11ba4e7e53464e0ef516bc2232a850 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 07:59:53 +0000
Subject: [PATCH 07/13] rework

Signed-off-by: Jonas Kuebler
---
 vllm/attention/backends/abstract.py      |  6 ++++
 vllm/attention/layer.py                  | 35 +++++++++++++-----------
 vllm/v1/attention/backends/flash_attn.py | 10 ++-----
 3 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 1b392cd7c88d..dfd56e5673ee 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -34,6 +34,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False

+    # Whether this backend supports receiving pre-quantized query input.
+    # If True, the attention layer will handle query quantization instead
+    # of the backend, allowing torch.compile to fuse quantization with
+    # previous operations.
+    supports_quant_query_input: bool = False
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index bc687997f424..648f9c715dd5 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -22,7 +22,10 @@
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e

+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
@@ -270,6 +280,15 @@ def forward(
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            query, _ = self.query_quant.forward_native(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
@@ -278,22 +297,6 @@ def forward(
                                  device=query.device)
             hidden_size = output_shape[-1]

-            if envs.VLLM_FUSE_QUERY_QUANT and self.kv_cache_dtype != "auto":
-                # quantizing with a simple torch operation enables
-                # torch.compile to fuse this into previous ops
-                # which reduces overheads during decoding.
-                # Otherwise queries are quantized using custom ops
-                # which causes decoding overheads
-                assert self._q_scale.numel() == 1
-                if self.kv_cache_dtype in ["fp8", "fp8_e4m3"]:
-                    query = (query / self._q_scale).to(torch.float8_e4m3fn)
-                elif self.kv_cache_dtype == "fp8_e5m2":
-                    query = (query / self._q_scale).to(torch.float8_e5m2)
-                else:
-                    raise NotImplementedError(
-                        "VLLM_FUSE_QUERY_QUANT only supported for fp8_e4m3 "
-                        "and fp8_e5m2")
-
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 2cdae97e2fe8..7a50bb5d3134 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch

-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
@@ -38,6 +37,7 @@
 class FlashAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True
+    supports_quant_query_input: bool = True

     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,17 +506,11 @@ def forward(
         )

         if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
             dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                 self.kv_cache_dtype)
             key_cache = key_cache.view(dtype)
             value_cache = value_cache.view(dtype)
-            if not envs.VLLM_FUSE_QUERY_QUANT:
-                num_tokens, num_heads, head_size = query.shape
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape(
-                        (num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale)
-                query = query.reshape((num_tokens, num_heads, head_size))

         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc

From f9936d37ff197ad8b0e5dc18875c339e5c1a0788 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 12:26:25 +0000
Subject: [PATCH 08/13] track original input/output dtype

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 648f9c715dd5..bfb897d7463f 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -281,6 +281,7 @@ def forward(
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)

+        output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables
             # torch.compile to fuse this into previous ops
@@ -293,7 +294,7 @@ def forward(
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
             output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]

From 932715b1db411624e0143067629acbb40213da02 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 12:43:58 +0000
Subject: [PATCH 09/13] revert env changes

Signed-off-by: Jonas Kuebler
---
 vllm/envs.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 3d76fd34b18f..689428ec5910 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -160,7 +160,6 @@
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
-    VLLM_FUSE_QUERY_QUANT: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
     VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None
@@ -1257,13 +1256,6 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_SLEEP_WHEN_IDLE":
     lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),

-    # Quantize the query in the attention layer with a simple
-    # pytorch operation instead of a custom op in the attention backend.
-    # Then torch.compile can fuse it and reduce overhead.
-    # Only relevant for quantized attention w/ FA3
-    "VLLM_FUSE_QUERY_QUANT":
-    lambda: bool(int(os.getenv("VLLM_FUSE_QUERY_QUANT", "0"))),
-
     # Control the max chunk bytes (in MB) for the rpc message queue.
     # Object larger than this threshold will be broadcast to worker
     # processes via zmq.
@@ -1523,7 +1515,6 @@ def compute_hash() -> str:
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
         "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
-        "VLLM_FUSE_QUERY_QUANT",
         "VLLM_ROCM_USE_AITER",
         "VLLM_ROCM_USE_AITER_PAGED_ATTN",
         "VLLM_ROCM_USE_AITER_LINEAR",

From 40de7b4cc5bfec1cb2f962dee2de66a41c0eb8b6 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 12:46:02 +0000
Subject: [PATCH 10/13] style

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index bfb897d7463f..afd61bdc2fc8 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -297,7 +297,6 @@ def forward(
                                  dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
-
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.

From 3473b18bb8c1662ec3febc2b72e1f14bd81fb30a Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 13:11:05 +0000
Subject: [PATCH 11/13] add assert

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index afd61bdc2fc8..616dacba1057 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -288,6 +288,7 @@ def forward(
             # which reduces overheads during decoding.
             # Otherwise queries are quantized using custom ops
             # which causes decoding overheads
+            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
             query, _ = self.query_quant.forward_native(query, self._q_scale)

         if self.use_output:

From 9e027a726e4f2bec8013df5d91bc65200899cd3f Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 13:21:42 +0000
Subject: [PATCH 12/13] link issue to rework other backends

Signed-off-by: Jonas Kuebler
---
 vllm/attention/backends/abstract.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index dfd56e5673ee..02b4492485f5 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -38,6 +38,8 @@ class AttentionBackend(ABC):
     # If True, the attention layer will handle query quantization instead
     # of the backend, allowing torch.compile to fuse quantization with
     # previous operations.
+    # Needs to be worked through for all backends
+    # https://github.com/vllm-project/vllm/issues/25584
     supports_quant_query_input: bool = False

     @staticmethod

From 5ec35b4ab88d0ca040fc7a53507bf01b9ac61d08 Mon Sep 17 00:00:00 2001
From: Jonas Kuebler
Date: Wed, 24 Sep 2025 17:09:27 +0000
Subject: [PATCH 13/13] use direct call

Signed-off-by: Jonas Kuebler
---
 vllm/attention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 616dacba1057..17281c89516d 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -289,7 +289,7 @@ def forward(
             # Otherwise queries are quantized using custom ops
             # which causes decoding overheads
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
-            query, _ = self.query_quant.forward_native(query, self._q_scale)
+            query, _ = self.query_quant(query, self._q_scale)

         if self.use_output:
             output_shape = (output_shape
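After the rework, the env flag is gone and the behavior is keyed off the backend: a backend that can consume an already-quantized query advertises supports_quant_query_input, the attention layer then builds a static per-tensor QuantFP8 and quantizes the query itself, and the backend only reinterprets its fp8 KV cache. A self-contained sketch of that control flow (these classes are simplified stand-ins, not vLLM's real Attention or AttentionBackend):

    import torch

    class SketchBackend:
        # Mirrors AttentionBackend.supports_quant_query_input from the series.
        supports_quant_query_input: bool = True

    class SketchAttention:
        def __init__(self, kv_cache_dtype: str, q_scale: float):
            self.kv_cache_dtype = kv_cache_dtype
            self._q_scale = torch.tensor([q_scale])
            # Only created when the KV cache is fp8 and the backend opts in;
            # stand-in for QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR).
            self.query_quant = None
            if (kv_cache_dtype.startswith("fp8")
                    and SketchBackend.supports_quant_query_input):
                self.query_quant = lambda q, s: ((q / s).to(torch.float8_e4m3fn), s)

        def forward(self, query: torch.Tensor) -> torch.Tensor:
            # Track the original dtype so the output buffer keeps it (patch 08).
            output_dtype = query.dtype
            if self.query_quant is not None:
                query, _ = self.query_quant(query, self._q_scale)
            # The backend would now receive an fp8 query; here we only return
            # a dummy output buffer allocated in the original dtype.
            return torch.zeros(query.shape, dtype=output_dtype)

    layer = SketchAttention("fp8", q_scale=0.05)
    print(layer.forward(torch.randn(2, 8, 64)).dtype)  # original dtype preserved

Per the comment added in patch 12, only the FlashAttention backend opts in so far; extending the same hand-off to the other backends is tracked in https://github.com/vllm-project/vllm/issues/25584.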