1 change: 1 addition & 0 deletions tests/v1/attention/test_attention_backends.py
@@ -178,6 +178,7 @@ def __init__(self, device: torch.device):
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
# Add float versions for flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0

2 changes: 2 additions & 0 deletions tests/v1/tpu/test_pallas.py
@@ -33,10 +33,12 @@ def test_ragged_paged_attention():
)

class FakeAttentionLayer:
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float

layer = FakeAttentionLayer()
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0

1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -240,6 +240,7 @@ class AttentionLayer(Protocol):
_q_scale: torch.Tensor
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
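Because `AttentionLayer` is a `typing.Protocol`, any object exposing these attributes satisfies it structurally. A minimal sketch of backend-style code typed against the protocol and consuming the new float mirrors (the helper name and signature are illustrative assumptions, not code from this PR):

```python
# Sketch only: a helper typed against the AttentionLayer protocol that
# reads the new plain-float mirrors instead of the device tensors.
from vllm.attention.backends.abstract import AttentionLayer


def decode_scales(layer: AttentionLayer,
                  softmax_scale: float) -> tuple[float, float]:
    # Host-side floats: no .item() call on the _q_scale/_k_scale/_v_scale
    # device tensors is needed when a kernel expects Python scalars.
    bmm1_scale = layer._q_scale_float * layer._k_scale_float * softmax_scale
    bmm2_scale = layer._v_scale_float
    return bmm1_scale, bmm2_scale
```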
2 changes: 2 additions & 0 deletions vllm/envs.py
@@ -475,6 +475,8 @@ def get_vllm_port() -> Optional[int]:
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
# - "FLASHINFER_MLA": use FlashInfer for MLA
# - "CUTLASS_MLA": use CUTLASS for MLA
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

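For reference, the backend override is read from the environment; a minimal usage sketch, assuming the variable is set before the engine is constructed:

```python
# Sketch: opting into the FlashInfer MLA backend via the environment
# variable documented above, before engine construction so the platform
# backend-selection logic sees it.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"
```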
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -88,6 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"Setting it to k_scale. This only matters for "
"the flash-attn backend.")
layer._q_scale.copy_(k_scale)
layer._q_scale_float = k_scale

# These are used in the final Attention.forward()
layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._q_scale_float = q_scale
layer._prob_scale.copy_(prob_scale)
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
or prob_scale == 1.0):
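The float mirror duplicates the value already copied into the `_q_scale` tensor. A sketch of the assumed motivation (consistent with the "Add float versions for flashinfer" comment in the tests, not stated verbatim in the PR):

```python
# Sketch (assumed motivation): reading a scale out of a CUDA tensor forces
# a blocking device-to-host copy, while the float mirror written once at
# weight-loading time is a plain attribute read at forward time.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
_q_scale = torch.tensor(0.5, device=device)  # tensor form, used by fused kernels
_q_scale_float = 0.5                         # host mirror, used by e.g. FlashInfer

scale_via_sync = _q_scale.item()  # device -> host copy on every call
scale_cached = _q_scale_float     # free to read
assert scale_via_sync == scale_cached
```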
5 changes: 4 additions & 1 deletion vllm/platforms/cuda.py
@@ -179,6 +179,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.")

if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
cache_config.block_size = 64
logger.info(
@@ -535,7 +536,9 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
attention_backend = "FLASHMLA"

# Only FlashMLA, CUTLASS_MLA and FLASHINFER_MLA support fp8
if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
if attention_backend in [
"FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
]:
supported = True
else:
supported = (not fp8_attention)
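Condensed, the updated support check behaves as in the sketch below (simplified; the function and argument names are assumptions based on the surrounding code):

```python
# Simplified sketch of the updated check: FP8 KV cache is reported as
# supported when the resolved MLA backend can consume it, which now
# includes FLASHINFER_MLA alongside FLASHMLA and CUTLASS_MLA.
def fp8_kv_cache_supported(attention_backend: str, fp8_attention: bool) -> bool:
    if attention_backend in ("FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"):
        return True
    return not fp8_attention
```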
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -584,7 +584,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)

# Prepare context prefills
@@ -605,7 +604,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
logits_soft_cap=self._global_hyperparameters.
logits_soft_cap,
q_data_type=self.model_config.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)

prefill.prefill_main = self._fi_prefill_main
18 changes: 11 additions & 7 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -6,8 +6,7 @@
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
@@ -69,11 +68,9 @@ def __init__(
"are not implemented for "
"FlashInferMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashInferMLA V1 with FP8 KV cache not yet supported")

self._workspace_buffer = g_fi_workspace
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None

def _forward_decode(
self,
@@ -92,6 +89,12 @@ def _forward_decode(
# trtllm API requires extra dimension q_len_per_request for MTP
q = q.unsqueeze(1)

if self.bmm1_scale is None:
self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
self.scale)
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float

o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -102,7 +105,8 @@
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.scale,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)

# TODO: Return LSE pending support from Flashinfer API:
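The scales are computed lazily on the first decode call and cached. With an unquantized KV cache they collapse to the previous behaviour, as the small check below illustrates (the softmax scale value is illustrative):

```python
# Sketch: with an unquantized KV cache all per-layer float scales are 1.0,
# so the cached values reduce to the old call: bmm1_scale equals the
# softmax scale and bmm2_scale is 1.0.
q_scale_float = k_scale_float = v_scale_float = 1.0
softmax_scale = 0.125  # illustrative value

bmm1_scale = q_scale_float * k_scale_float * softmax_scale
bmm2_scale = v_scale_float

assert bmm1_scale == softmax_scale
assert bmm2_scale == 1.0
```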