diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 1ae8b91c347a..0b7e103beca6 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -178,6 +178,7 @@ def __init__(self, device: torch.device):
         self._k_scale = torch.tensor(1.0, device=device)
         self._v_scale = torch.tensor(1.0, device=device)
         # Add float versions for flashinfer
+        self._q_scale_float = 1.0
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
 
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
index bfba3af57f71..1bc8dff317a7 100644
--- a/tests/v1/tpu/test_pallas.py
+++ b/tests/v1/tpu/test_pallas.py
@@ -33,10 +33,12 @@ def test_ragged_paged_attention():
     )
 
     class FakeAttentionLayer:
+        _q_scale_float: float
         _k_scale_float: float
         _v_scale_float: float
 
     layer = FakeAttentionLayer()
+    layer._q_scale_float = 1.0
     layer._k_scale_float = 1.0
     layer._v_scale_float = 1.0
 
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 0217bff6adaf..75bcdc4bbcf0 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -240,6 +240,7 @@ class AttentionLayer(Protocol):
     _q_scale: torch.Tensor
     _k_scale: torch.Tensor
     _v_scale: torch.Tensor
+    _q_scale_float: float
     _k_scale_float: float
     _v_scale_float: float
     _prob_scale: torch.Tensor
diff --git a/vllm/envs.py b/vllm/envs.py
index 8ca7e0f19428..6d7caf257237 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -475,6 +475,8 @@ def get_vllm_port() -> Optional[int]:
     # - "FLASHINFER": use flashinfer
     # - "FLASHMLA": use FlashMLA
     # - "FLASH_ATTN_MLA": use FlashAttention for MLA
+    # - "FLASHINFER_MLA": use FlashInfer for MLA
+    # - "CUTLASS_MLA": use CUTLASS for MLA
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
 
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index e5604670fb4c..4c6fcda893a0 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -88,6 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 "Setting it to k_scale. This only matters for "
                 "the flash-attn backend.")
             layer._q_scale.copy_(k_scale)
+            layer._q_scale_float = k_scale
 
         # These are used in the final Attention.forward()
         layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
+        layer._q_scale_float = q_scale
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index e40b6eb2b5a4..2f5e6cfdcb63 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -179,6 +179,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
                             "CUTLASS_MLA backend.")
+
             if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
                 cache_config.block_size = 64
                 logger.info(
@@ -535,7 +536,9 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                     attention_backend = "FLASHMLA"
 
             # Only FlashMLA and CUTLASS_MLA support fp8
-            if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
+            if attention_backend in [
+                    "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
+            ]:
                 supported = True
             else:
                 supported = (not fp8_attention)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 036a281f1d26..a990cb2f1a97 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -584,7 +584,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
             q_data_type=self.model_config.dtype,
-            kv_data_type=self.kv_cache_spec.dtype,
         )
 
         # Prepare context prefills
@@ -605,7 +604,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
                 logits_soft_cap=self._global_hyperparameters.
                 logits_soft_cap,
                 q_data_type=self.model_config.dtype,
-                kv_data_type=self.kv_cache_spec.dtype,
             )
 
         prefill.prefill_main = self._fi_prefill_main
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 71eb9e0ce70e..701248670f72 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -6,8 +6,7 @@
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
-from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
@@ -69,11 +68,9 @@ def __init__(
                 "are not implemented for "
                 "FlashInferMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashInferMLA V1 with FP8 KV cache not yet supported")
-
         self._workspace_buffer = g_fi_workspace
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
 
     def _forward_decode(
         self,
@@ -92,6 +89,12 @@
         # trtllm API requires extra dimension q_len_per_request for MTP
         q = q.unsqueeze(1)
 
+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -102,7 +105,8 @@
             block_tables=attn_metadata.decode.block_table,
             seq_lens=attn_metadata.decode.seq_lens,
             max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.scale,
+            bmm1_scale=self.bmm1_scale,
+            bmm2_scale=self.bmm2_scale,
        )

         # TODO: Return LSE pending support from Flashinfer API:
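
For reference, a minimal sketch of the scale arithmetic the patched _forward_decode now performs: the first batched matmul covers Q @ K^T, so its scale folds the query and key dequantization factors (_q_scale_float, _k_scale_float) into the usual softmax scale, while the second batched matmul covers softmax(QK^T) @ V and only needs the value dequantization factor. The helper name mla_decode_scales below is hypothetical and not part of vLLM; it only mirrors the computation added in flashinfer_mla.py.

# Illustrative sketch only -- not vLLM API; mirrors the bmm1_scale/bmm2_scale
# computation introduced for FlashInferMLA above.
def mla_decode_scales(q_scale: float, k_scale: float, v_scale: float,
                      softmax_scale: float) -> tuple[float, float]:
    bmm1_scale = q_scale * k_scale * softmax_scale  # applied to Q @ K^T before softmax
    bmm2_scale = v_scale                            # applied to the attention @ V output
    return bmm1_scale, bmm2_scale

# With an unquantized KV cache all float scales are 1.0, so this reduces to the
# previous behaviour: bmm1_scale == softmax_scale and bmm2_scale == 1.0.
print(mla_decode_scales(1.0, 1.0, 1.0, 0.125))  # -> (0.125, 1.0)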