From 0ee023ecd82f015fabd16854156c695ed5e4576f Mon Sep 17 00:00:00 2001
From: adabeyta
Date: Mon, 22 Sep 2025 23:19:29 +0000
Subject: [PATCH 1/3] Resolve kv_scale/dynamo issue

Signed-off-by: adabeyta
---
 vllm/attention/layer.py | 56 +++++++++++++++++++++++++++++++++--------
 1 file changed, 46 insertions(+), 10 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 326fe6dd048a..58b5690ea29b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -341,16 +341,6 @@ def forward(
         return torch.ops.vllm.unified_attention(
             query, key, value, self.layer_name)
 
-    def calc_kv_scales(self, query, key, value):
-        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
-        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
-        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
-        self._q_scale_float = self._q_scale.item()
-        self._k_scale_float = self._k_scale.item()
-        self._v_scale_float = self._v_scale.item()
-        # We only calculate the scales once
-        self.calculate_kv_scales = False
-
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
@@ -554,6 +544,52 @@ def maybe_save_kv_layer_to_connector(
                             attn_metadata[layer_name])
 
 
+def unified_kv_scale_calc(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    q_scale: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    q_range: torch.Tensor,
+    k_range: torch.Tensor,
+    v_range: torch.Tensor,
+    scale_calc: bool,
+) -> None:
+
+    if not scale_calc:
+        return
+
+    q_scale.copy_(torch.abs(query).max() / q_range)
+    k_scale.copy_(torch.abs(key).max() / k_range)
+    v_scale.copy_(torch.abs(value).max() / v_range)
+
+
+def unified_kv_scale_calc_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    q_scale: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    q_range: torch.Tensor,
+    k_range: torch.Tensor,
+    v_range: torch.Tensor,
+    scale_calc: bool,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_kv_scale_calc",
+    op_func=unified_kv_scale_calc,
+    mutates_args=["q_scale", "k_scale", "v_scale"],
+    fake_impl=unified_kv_scale_calc_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=tag_cudagraph_unsafe,
+)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,

From 180e26460b72485d5c4eea6d35105a6e2182a6ff Mon Sep 17 00:00:00 2001
From: adabeyta
Date: Wed, 24 Sep 2025 22:38:57 +0000
Subject: [PATCH 2/3] Move KV scales logic to custom operator for torch.compile compatibility

Signed-off-by: adabeyta
---
 vllm/attention/layer.py            | 57 ++++++++++++++++--------------
 vllm/v1/worker/gpu_model_runner.py |  9 +++++
 2 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 58b5690ea29b..fc8ca0593b99 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -277,9 +277,8 @@ def forward(
             `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)
 
         output_dtype = query.dtype
         if self.query_quant is not None:
@@ -341,6 +340,16 @@ def forward(
         return torch.ops.vllm.unified_attention(
             query, key, value, self.layer_name)
 
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
+        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
+        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
+        self._q_scale_float = self._q_scale.item()
+        self._k_scale_float = self._k_scale.item()
+        self._v_scale_float = self._v_scale.item()
+        # We only calculate the scales once
+        self.calculate_kv_scales = False
+
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
@@ -544,47 +553,41 @@ def maybe_save_kv_layer_to_connector(
                             attn_metadata[layer_name])
 
 
-def unified_kv_scale_calc(
+def maybe_calc_kv_scales(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:
 
-    if not scale_calc:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
         return
 
-    q_scale.copy_(torch.abs(query).max() / q_range)
-    k_scale.copy_(torch.abs(key).max() / k_range)
-    v_scale.copy_(torch.abs(value).max() / v_range)
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
 
 
-def unified_kv_scale_calc_fake(
+def maybe_calc_kv_scales_fake(
     query: torch.Tensor,
     key: torch.Tensor,
    value: torch.Tensor,
-    q_scale: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    q_range: torch.Tensor,
-    k_range: torch.Tensor,
-    v_range: torch.Tensor,
-    scale_calc: bool,
+    layer_name: str,
 ) -> None:
     return
 
 
 direct_register_custom_op(
-    op_name="unified_kv_scale_calc",
-    op_func=unified_kv_scale_calc,
-    mutates_args=["q_scale", "k_scale", "v_scale"],
-    fake_impl=unified_kv_scale_calc_fake,
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=[],
+    fake_impl=maybe_calc_kv_scales_fake,
     dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f199dbd991f4..91c2df3a291c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2275,6 +2275,15 @@ def execute_model(
             cudagraph_runtime_mode, batch_descriptor = \
                 self.cudagraph_dispatcher.dispatch(batch_descriptor)
 
+            # Set cudagraph mode to none if calc_kv_scales is true.
+            if attn_metadata is not None:
+                metadata_list = (attn_metadata.values() if isinstance(
+                    attn_metadata, dict) else [attn_metadata])
+                if any(
+                        getattr(m, 'enable_kv_scales_calculation', False)
+                        for m in metadata_list):
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             # This is currently to get around the assert in the DPMetadata
             # where it wants `num_tokens_across_dp` to align with `num_tokens`
             if ubatch_slices is not None:

From 8586d2bf05d3ebd149a31372b5b172ea451ccf54 Mon Sep 17 00:00:00 2001
From: adabeyta
Date: Fri, 26 Sep 2025 19:30:06 +0000
Subject: [PATCH 3/3] Update custom op reg and add e2e testing

Signed-off-by: adabeyta
---
 tests/compile/test_full_graph.py | 15 +++++++++++++++
 vllm/attention/layer.py          |  4 +---
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 870aa553ca62..f9f146810924 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -139,6 +139,21 @@ def test_custom_compile_config(
     run_model(compilation_config, model, model_kwargs)
 
 
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+)
+def test_fp8_kv_scale_compile(optimization_level: int):
+    model = "Qwen/Qwen2-0.5B"
+    model_kwargs = {
+        "quantization": "fp8",
+        "kv_cache_dtype": "fp8_e4m3",
+        "calculate_kv_scales": True,
+        "max_model_len": 512,
+    }
+    run_model(optimization_level, model, model_kwargs)
+
+
 def test_inductor_graph_partition_attn_fusion(caplog_vllm):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available "
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index fc8ca0593b99..d97c87d96e99 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -586,10 +586,8 @@ def maybe_calc_kv_scales_fake(
 direct_register_custom_op(
     op_name="maybe_calc_kv_scales",
     op_func=maybe_calc_kv_scales,
-    mutates_args=[],
+    mutates_args=["query", "key", "value"],
     fake_impl=maybe_calc_kv_scales_fake,
-    dispatch_key=current_platform.dispatch_key,
-    tags=tag_cudagraph_unsafe,
 )
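
The new test drives this path end to end through run_model. For a quick manual check outside pytest, a minimal sketch looks like the following; it assumes the test's model_kwargs map directly onto vLLM's LLM(...) engine arguments (as run_model appears to do), and the prompt and sampling settings are illustrative only.

    from vllm import LLM, SamplingParams

    # Same configuration as test_fp8_kv_scale_compile: FP8 weights, FP8 KV
    # cache, and on-the-fly KV scale calculation, which now runs through the
    # torch.ops.vllm.maybe_calc_kv_scales custom op registered above.
    llm = LLM(
        model="Qwen/Qwen2-0.5B",
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        max_model_len=512,
    )

    # The first forward pass computes the q/k/v scales once and then flips
    # calculate_kv_scales off; later passes reuse the cached scales.
    print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8)))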