From 164c0f381bc007fe7779aef98df6c25197936ec3 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Tue, 12 Aug 2025 01:36:44 +0000
Subject: [PATCH 1/3] fp8 kv cache support for fp4 llama 3.1 405B

Signed-off-by: Aleksandr Malyshev
---
 vllm/attention/layer.py                             | 5 +++++
 vllm/model_executor/layers/quantization/kv_cache.py | 2 ++
 vllm/v1/attention/backends/triton_attn.py           | 2 +-
 3 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 1a9c0e26b53c..75c860401f81 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -128,6 +128,10 @@ def __init__(
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
         self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

+        # Keeping a float32 version of the _q_scale tensor for assertions
+        # during graph capture. Otherwise the asserts trigger a HIP error.
+        self._q_scale_float = 1.0
+
         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)
         self._k_scale_float = 1.0
@@ -291,6 +295,7 @@ def calc_kv_scales(self, query, key, value):
         self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
+        self._q_scale_float = self._q_scale.item()
         self._k_scale_float = self._k_scale.item()
         self._v_scale_float = self._v_scale.item()
         # We only calculate the scales once

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index e5604670fb4c..7a869c59117f 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -124,6 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
+        layer._q_scale_float = q_scale
+
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index c33afbfebcde..43b06d697bbb 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -356,7 +356,7 @@ def forward(
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
-            assert layer._q_scale == 1.0, \
+            assert layer._q_scale_float == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
             if not current_platform.is_rocm():
                 # Skip Q quantization on ROCm, since dequantizing back to

From c9e1469a5ce111753142435c8db4bef861ee4385 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Tue, 12 Aug 2025 21:23:59 +0000
Subject: [PATCH 2/3] minor fix for q_scale possibly being a tensor

Signed-off-by: Gregory Shtrasberg
---
 vllm/model_executor/layers/quantization/kv_cache.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 7a869c59117f..eb38ace619a3 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -124,7 +124,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale

         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0

From c29c2fd5604f3d19a2d470387de000cd59720dbe Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Thu, 14 Aug 2025 21:42:36 +0000
Subject: [PATCH 3/3] lint error fix

Signed-off-by: Aleksandr Malyshev
---
 vllm/model_executor/layers/quantization/kv_cache.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index eb38ace619a3..79752c257b18 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -124,7 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(
+            q_scale, torch.Tensor) else q_scale

         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
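
Taken together, the series keeps a host-side Python float (_q_scale_float) in sync with the device-side _q_scale tensor, so correctness asserts evaluated while a HIP/CUDA graph is being captured compare plain numbers instead of forcing a device read. The sketch below illustrates that pattern in isolation; AttentionScales, update_q_scale, and the q_range value are simplified stand-ins for this note, not vLLM classes, methods, or constants.

    # Minimal sketch of the host-side float mirror pattern, assuming a toy
    # AttentionScales class (not vLLM's Attention layer).
    import torch

    class AttentionScales:
        def __init__(self) -> None:
            # Device-friendly tensor consumed by the quantized attention kernels.
            self._q_scale = torch.tensor(1.0, dtype=torch.float32)
            # Host-side mirror; asserting on a Python float never touches the
            # device, so it stays safe while a graph is being captured.
            self._q_scale_float = 1.0

        def update_q_scale(self, query: torch.Tensor, q_range: float) -> None:
            # Recompute the scale and keep the tensor and its float mirror in
            # sync, similar in spirit to calc_kv_scales() in PATCH 1.
            self._q_scale.copy_(query.abs().max() / q_range)
            self._q_scale_float = self._q_scale.item()

    scales = AttentionScales()
    scales.update_q_scale(torch.randn(16, 64), q_range=448.0)
    assert scales._q_scale_float > 0.0  # host-only check, no GPU sync needed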
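
PATCH 2 and PATCH 3 only harden how that mirror is filled in when the loaded q_scale may arrive either as a plain number or as a torch.Tensor. A standalone version of that guard is sketched below; to_host_float is an illustrative helper name, not a vLLM function.

    # Tensor-or-float normalization guard from PATCH 2/3, as a hypothetical
    # standalone helper.
    from typing import Union

    import torch

    def to_host_float(scale: Union[torch.Tensor, float]) -> float:
        # .item() converts a one-element tensor to a Python float;
        # plain numbers are passed through as floats.
        return scale.item() if isinstance(scale, torch.Tensor) else float(scale)

    assert to_host_float(torch.tensor(0.5)) == 0.5
    assert to_host_float(0.5) == 0.5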