From fbedc35d44d5af9fb73e504cc8ffe507592c0eb2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 11 Sep 2025 19:52:54 +0000 Subject: [PATCH 1/4] Enable FP8 Signed-off-by: Matthew Bonanni --- vllm/envs.py | 2 ++ vllm/platforms/cuda.py | 4 +++- vllm/v1/attention/backends/mla/flashinfer_mla.py | 13 ++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 8ca7e0f19428..6d7caf257237 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -475,6 +475,8 @@ def get_vllm_port() -> Optional[int]: # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA + # - "FLASHINFER_MLA": use FlashInfer for MLA + # - "CUTLASS_MLA": use CUTLASS for MLA "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e40b6eb2b5a4..525b1cb535f1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -535,7 +535,9 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, attention_backend = "FLASHMLA" # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]: + if attention_backend in [ + "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA" + ]: supported = True else: supported = (not fp8_attention) diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 71eb9e0ce70e..352b7f4df351 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -6,8 +6,7 @@ import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, @@ -69,10 +68,6 @@ def __init__( "are not implemented for " "FlashInferMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashInferMLA V1 with FP8 KV cache not yet supported") - self._workspace_buffer = g_fi_workspace def _forward_decode( @@ -92,6 +87,9 @@ def _forward_decode( # trtllm API requires extra dimension q_len_per_request for MTP q = q.unsqueeze(1) + bmm1_scale = layer._q_scale.item() * layer._k_scale.item() * self.scale + bmm2_scale = layer._v_scale.item() + o = trtllm_batch_decode_with_kv_cache_mla( query=q, kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), @@ -102,7 +100,8 @@ def _forward_decode( block_tables=attn_metadata.decode.block_table, seq_lens=attn_metadata.decode.seq_lens, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=self.scale, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, ) # TODO: Return LSE pending support from Flashinfer API: From 433b0861b7b1b7f83dd4e99dfc3cb43d07742870 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 11 Sep 2025 17:03:43 -0700 Subject: [PATCH 2/4] Whitespace Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 525b1cb535f1..2f5e6cfdcb63 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -179,6 +179,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = 128 logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") + if use_flashinfer_mla and 
cache_config.block_size not in [32, 64]: cache_config.block_size = 64 logger.info( From 4af0e82cda2673178fef35f8d7e0a879c270733b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 11 Sep 2025 17:35:40 -0700 Subject: [PATCH 3/4] Run prefill at q precision Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 036a281f1d26..a990cb2f1a97 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -584,7 +584,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) # Prepare context prefills @@ -605,7 +604,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): logits_soft_cap=self._global_hyperparameters. logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) prefill.prefill_main = self._fi_prefill_main From b34819e9b4a2465e1762292825a0872c5c1e9fa5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 12 Sep 2025 07:07:48 -0700 Subject: [PATCH 4/4] Only compute bmm scales once Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_attention_backends.py | 1 + tests/v1/tpu/test_pallas.py | 2 ++ vllm/attention/backends/abstract.py | 1 + vllm/model_executor/layers/quantization/kv_cache.py | 2 ++ vllm/v1/attention/backends/mla/flashinfer_mla.py | 13 +++++++++---- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 1ae8b91c347a..0b7e103beca6 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -178,6 +178,7 @@ def __init__(self, device: torch.device): self._k_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(1.0, device=device) # Add float versions for flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index bfba3af57f71..1bc8dff317a7 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -33,10 +33,12 @@ def test_ragged_paged_attention(): ) class FakeAttentionLayer: + _q_scale_float: float _k_scale_float: float _v_scale_float: float layer = FakeAttentionLayer() + layer._q_scale_float = 1.0 layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0217bff6adaf..75bcdc4bbcf0 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -240,6 +240,7 @@ class AttentionLayer(Protocol): _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor + _q_scale_float: float _k_scale_float: float _v_scale_float: float _prob_scale: torch.Tensor diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index e5604670fb4c..4c6fcda893a0 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -88,6 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "Setting it to k_scale. 
This only matters for " "the flash-attn backend.") layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) @@ -124,6 +125,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale layer._prob_scale.copy_(prob_scale) if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 352b7f4df351..701248670f72 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -69,6 +69,8 @@ def __init__( "FlashInferMLAImpl") self._workspace_buffer = g_fi_workspace + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None def _forward_decode( self, @@ -87,8 +89,11 @@ def _forward_decode( # trtllm API requires extra dimension q_len_per_request for MTP q = q.unsqueeze(1) - bmm1_scale = layer._q_scale.item() * layer._k_scale.item() * self.scale - bmm2_scale = layer._v_scale.item() + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float o = trtllm_batch_decode_with_kv_cache_mla( query=q, @@ -100,8 +105,8 @@ def _forward_decode( block_tables=attn_metadata.decode.block_table, seq_lens=attn_metadata.decode.seq_lens, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, ) # TODO: Return LSE pending support from Flashinfer API:
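
Note on the kv-cache dtype check in vllm/platforms/cuda.py: the helper below is an illustrative sketch, not the actual is_kv_cache_dtype_supported() signature, and it assumes "fp8 attention" simply means a kv_cache_dtype string that starts with "fp8". It restates the branch this series extends so that FLASHINFER_MLA, like FLASHMLA and CUTLASS_MLA, accepts an FP8 KV cache.

def mla_backend_supports_kv_cache_dtype(attention_backend: str,
                                        kv_cache_dtype: str) -> bool:
    # Assumed definition of "fp8 attention" for this sketch only.
    fp8_attention = kv_cache_dtype.startswith("fp8")
    # FLASHINFER_MLA joins FLASHMLA and CUTLASS_MLA as FP8-capable backends.
    if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
        return True
    # Other MLA backends remain supported only for non-FP8 KV caches.
    return not fp8_attention

assert mla_backend_supports_kv_cache_dtype("FLASHINFER_MLA", "fp8_e4m3")
assert not mla_backend_supports_kv_cache_dtype("FLASH_ATTN_MLA", "fp8_e4m3")
assert mla_backend_supports_kv_cache_dtype("FLASH_ATTN_MLA", "auto")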
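
Note on the decode scales in vllm/v1/attention/backends/mla/flashinfer_mla.py: the trtllm MLA decode call takes a combined scale for the Q.K^T logits (bmm1_scale = q_scale * k_scale * softmax_scale) and an output scale (bmm2_scale = v_scale). Patch 1 derived these with per-call .item() reads; patch 4 switches to the float mirrors (_q_scale_float, _k_scale_float, _v_scale_float) and caches the products so they are computed at most once. The sketch below restates that caching pattern outside vLLM; FakeLayer and DecodeScaleCache are stand-ins, not vLLM classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeLayer:
    # Float mirrors of the per-layer scale tensors, as added in patch 4.
    _q_scale_float: float = 1.0
    _k_scale_float: float = 1.0
    _v_scale_float: float = 1.0


class DecodeScaleCache:

    def __init__(self, softmax_scale: float):
        self.scale = softmax_scale
        self.bmm1_scale: Optional[float] = None  # q_scale * k_scale * softmax_scale
        self.bmm2_scale: Optional[float] = None  # v_scale

    def get(self, layer: FakeLayer) -> tuple[float, float]:
        # Computed on the first decode call and reused afterwards, so the
        # per-call .item() device sync from the first revision is avoided.
        if self.bmm1_scale is None:
            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
                               self.scale)
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float
        return self.bmm1_scale, self.bmm2_scale


if __name__ == "__main__":
    cache = DecodeScaleCache(softmax_scale=0.125)
    layer = FakeLayer(_q_scale_float=1.0, _k_scale_float=0.5,
                      _v_scale_float=0.5)
    print(cache.get(layer))  # (0.0625, 0.5)

With the default scales of 1.0, bmm1_scale reduces to the plain softmax scale and bmm2_scale to 1.0, so non-FP8 runs keep their previous scaling.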