diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 4c6fcda893a0..275a1c43fdd2 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -125,7 +125,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale
+        layer._q_scale_float = q_scale.item() if isinstance(
+            q_scale, torch.Tensor) else q_scale
+
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index c294a5a73cbd..784912a122f6 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -361,7 +361,7 @@ def forward(
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
-            assert layer._q_scale == 1.0, \
+            assert layer._q_scale_float == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
             if current_platform.is_cuda():
                 # Skip Q quantization on ROCm and XPU, enable this on cuda
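
For reference, a minimal sketch of the pattern this change applies: keep the tensor copy of the scale for the kernels, but also cache a plain Python float so host-side checks (such as the assert in `triton_attn.py`) compare a number rather than a device tensor. The `FakeAttentionLayer` class and `set_q_scale` helper below are hypothetical stand-ins for illustration only; they are not part of vLLM.

```python
import torch


class FakeAttentionLayer:
    """Hypothetical stand-in for an attention layer holding a q_scale."""

    def __init__(self) -> None:
        # Tensor copy of the scale, as the attention kernels expect.
        self._q_scale = torch.tensor(1.0)
        # Host-side float mirror of the same value.
        self._q_scale_float = 1.0

    def set_q_scale(self, q_scale) -> None:
        # Same pattern as the diff: copy into the tensor for the kernels,
        # and cache a plain float for host-side comparisons.
        self._q_scale.copy_(torch.as_tensor(q_scale))
        self._q_scale_float = q_scale.item() if isinstance(
            q_scale, torch.Tensor) else float(q_scale)


layer = FakeAttentionLayer()
layer.set_q_scale(torch.tensor(1.0))

# The host-side check reads the cached float, so no tensor comparison
# (and no potential device sync) happens in the hot path.
assert layer._q_scale_float == 1.0, \
    "A non 1.0 q_scale is not currently supported."
```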