
Conversation

adabeyta
Contributor

@adabeyta adabeyta commented Sep 23, 2025

Purpose

Fix KV scale calculation incompatibility with torch.compile and enforce_eager mode (#21640)

The conditional check `if attn_metadata.enable_kv_scales_calculation:` at line 274 in vllm/attention/layer.py caused two failures:

  • Dynamo compilation error (enforce_eager=False): data-dependent branching that torch.compile cannot trace through
  • AttributeError (enforce_eager=True): attn_metadata could be None, causing crashes

This PR moves the scale calculation logic into a custom operator (unified_kv_scale_calc) that avoids dynamic control flow and safely handles None values.
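
For illustration, a minimal, generic sketch of the custom-op idea is below (this is not the vLLM implementation: the demo op name, the scale formula, and the torch.library.custom_op registration, which assumes PyTorch 2.4 or newer, are all illustrative). The data-dependent branch lives inside the op body, which Dynamo treats as an opaque call, so torch.compile with fullgraph=True never sees dynamic control flow.

# Illustrative sketch only; names and formula are made up and do not match
# vLLM's unified_kv_scale_calc / maybe_calc_kv_scales ops.
import torch

@torch.library.custom_op("demo::maybe_calc_scale", mutates_args=("scale",))
def maybe_calc_scale(x: torch.Tensor, scale: torch.Tensor, enabled: bool) -> None:
    # Data-dependent branching is fine here: Dynamo treats a registered
    # custom op as an opaque call and never traces into its body.
    if enabled:
        scale.copy_(x.abs().amax().clamp(min=1e-12).to(scale.dtype))

@maybe_calc_scale.register_fake
def _(x, scale, enabled) -> None:
    # The op returns nothing; it only mutates `scale` in place.
    return None

@torch.compile(fullgraph=True)
def forward(x: torch.Tensor, scale: torch.Tensor, enabled: bool) -> torch.Tensor:
    maybe_calc_scale(x, scale, enabled)
    return x / scale

x = torch.randn(4, 8)
scale = torch.ones(())
print(forward(x, scale, True))   # scale is updated inside the op
print(forward(x, scale, False))  # branch is skipped inside the op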

Test Plan

If enforce_eager=False

With code:

import vllm

def main():
    engine = vllm.LLM(
        model="Qwen/Qwen3-8B",
        tensor_parallel_size=2,
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
    )
    output = engine.generate("Hello, world!")
    print(output)


if __name__ == "__main__":
    main()

Before Fix:

RuntimeError: Worker failed with error 'Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with GetAttrVariable(ConstantVariable(NoneType: None), enable_kv_scales_calculation)

After Fix:

RequestOutput(request_id=0, prompt='Hello, world!', prompt_token_ids=[9707, 11, 1879, 0], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" I'm a new user. I'm trying to find out if there's a", token_ids=[358, 2776, 264, 501, 1196, 13, 358, 2776, 4460, 311, 1477, 700, 421, 1052, 594, 264], cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})

If enforce_eager=True

With code:

import vllm

def main():
    engine = vllm.LLM(
        model="Qwen/Qwen3-8B",
        tensor_parallel_size=2,
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        enforce_eager=True,
    )
    output = engine.generate("Hello, world!")
    print(output)


if __name__ == "__main__":
    main()

Test Result

Before Fix:

(VllmWorker rank=1 pid=87573) ERROR 07-26 00:49:21 [multiproc_executor.py:546]   File "/home/ubuntu/miniconda3/envs/vllm-new/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=87573) ERROR 07-26 00:49:21 [multiproc_executor.py:546]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=87573) ERROR 07-26 00:49:21 [multiproc_executor.py:546]   File "/home/ubuntu/miniconda3/envs/vllm-new/lib/python3.10/site-packages/vllm/attention/layer.py", line 239, in forward
(VllmWorker rank=1 pid=87573) ERROR 07-26 00:49:21 [multiproc_executor.py:546]     if attn_metadata.enable_kv_scales_calculation:
(VllmWorker rank=1 pid=87573) ERROR 07-26 00:49:21 [multiproc_executor.py:546] AttributeError: 'NoneType' object has no attribute 'enable_kv_scales_calculation'

After Fix:

RequestOutput(request_id=0, prompt='Hello, world!', prompt_token_ids=[9707, 11, 1879, 0], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" I'm a new user here. How can I get started? Welcome! To", token_ids=[358, 2776, 264, 501, 1196, 1588, 13, 2585, 646, 358, 633, 3855, 30, 20166, 0, 2014], cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses two issues related to FP8 quantization's KV scale calculation: a torch.compile failure due to data-dependent branching, and an AttributeError in eager mode when attention metadata is None. The fix refactors the scale calculation logic into a new custom operator, unified_kv_scale_calc. This effectively hides the dynamic control flow from torch.compile, resolving the compilation error. Additionally, the code now safely handles cases where attention metadata might be None, preventing the AttributeError. The changes are well-contained and directly address the reported problems. The implementation appears correct and I did not find any high or critical severity issues.

Collaborator

@ProExpertProg ProExpertProg left a comment


Thanks for taking this on! I think this is a good start but I believe there are still issues:

  1. Attention metadata might change between requests, but I think the current approach will bake the value seen at compilation time into the control flow.
  2. I don't think this is cudagraph-compatible.

I suggested a resolution for 1 inline. For 2, we'll have to modify the model runner and explicitly set the runtime cudagraph mode to none if calc_kv_scales is true on the metadata object.

Comment on lines 273 to 294
forward_ctx = get_forward_context()
attn_metadata = (forward_ctx.attn_metadata
                 if forward_ctx else None)

scale_calc = bool(
    getattr(attn_metadata, 'enable_kv_scales_calculation', False)
    if attn_metadata is not None else False)

torch.ops.vllm.unified_kv_scale_calc(query, key, value,
                                     self._q_scale, self._k_scale,
                                     self._v_scale, self.q_range,
                                     self.k_range, self.v_range,
                                     scale_calc)

if scale_calc:
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()

    # We only calculate the scales once.
    self.calculate_kv_scales = False

Collaborator


The code can look like this:

Suggested change
-forward_ctx = get_forward_context()
-attn_metadata = (forward_ctx.attn_metadata
-                 if forward_ctx else None)
-scale_calc = bool(
-    getattr(attn_metadata, 'enable_kv_scales_calculation', False)
-    if attn_metadata is not None else False)
-torch.ops.vllm.unified_kv_scale_calc(query, key, value,
-                                     self._q_scale, self._k_scale,
-                                     self._v_scale, self.q_range,
-                                     self.k_range, self.v_range,
-                                     scale_calc)
-if scale_calc:
-    self._q_scale_float = self._q_scale.item()
-    self._k_scale_float = self._k_scale.item()
-    self._v_scale_float = self._v_scale.item()
-    # We only calculate the scales once.
-    self.calculate_kv_scales = False
+torch.ops.vllm.maybe_calc_kv_scales(query, key, value, <layer name>)

Inside the op, you can look at the forward context to extract the layer instance and the attention metadata to do the control flow as well as save the scale values (and manipulate them in any other way, perhaps just call calc_kv_scales). You can take a look at the implementation of unified_attention to see how it's done.
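
For reference, the suggested op could look roughly like the sketch below. This is illustrative, not the merged code: get_forward_context and direct_register_custom_op exist in vLLM, but the per-layer metadata dict, the no_compile_layers lookup, and the calc_kv_scales(query, key, value) signature are assumptions modeled on unified_attention.

# Sketch only, not the merged change.
import torch

from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op

def maybe_calc_kv_scales(query: torch.Tensor, key: torch.Tensor,
                         value: torch.Tensor, layer_name: str) -> None:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):  # per-layer metadata (assumption)
        attn_metadata = attn_metadata[layer_name]

    # The control flow lives inside the op, so Dynamo never traces it, and a
    # None metadata (e.g. during dummy/profiling runs) is handled safely.
    if attn_metadata is None or not attn_metadata.enable_kv_scales_calculation:
        return

    layer = forward_context.no_compile_layers[layer_name]  # assumed lookup
    layer.calc_kv_scales(query, key, value)  # assumed signature

def maybe_calc_kv_scales_fake(query: torch.Tensor, key: torch.Tensor,
                              value: torch.Tensor, layer_name: str) -> None:
    return None

direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=[],
    fake_impl=maybe_calc_kv_scales_fake,
)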

Contributor Author


Thanks for the feedback! I updated the PR to address both points:

  1. Metadata handling: I moved the control flow and scale updates into the new maybe_calc_kv_scales op. It now looks up the layer and metadata from the forward context directly (following the same pattern as unified_attention).

  2. CUDA graph compatibility: In gpu_model_runner.py, I added logic to explicitly set cudagraph_runtime_mode = NONE if any layer's attn_metadata.enable_kv_scales_calculation is true (see the sketch below). That way we skip CUDA graph capture only when scale calculation is required, and fall back to the normal cudagraph path otherwise.

Let me know if you’d like me to tweak the control flow further or if there are any other key parts I should adjust.
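
A rough sketch of the model-runner check described in point 2 follows (illustrative only; the helper name, the per-layer metadata mapping, and how the mode values are passed in are assumptions, not the merged gpu_model_runner.py code):

# Illustrative helper, not the merged gpu_model_runner.py change. It only
# captures the intent: fall back to eager execution (cudagraph mode NONE)
# for steps in which any layer still needs to calculate its KV scales.
def pick_cudagraph_runtime_mode(attn_metadata_per_layer: dict,
                                default_mode, none_mode):
    needs_kv_scale_calc = any(
        getattr(meta, "enable_kv_scales_calculation", False)
        for meta in attn_metadata_per_layer.values()
    )
    # Scales are only calculated once, so later steps take the normal
    # cudagraph path again.
    return none_mode if needs_kv_scale_calc else default_mode

# Hypothetical usage inside the runner:
#   mode = pick_cudagraph_runtime_mode(attn_metadata, CUDAGraphMode.FULL,
#                                      CUDAGraphMode.NONE)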

Collaborator

@ProExpertProg ProExpertProg left a comment


This is overall the right approach. Please also post lm_eval numbers and performance numbers both when kv_cache calc is used and when it's not used. We especially want to make sure this doesn't slow down the general case. Also please add an e2e test for the issue that you fixed.
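
For concreteness, such an e2e test might look roughly like the hypothetical sketch below (not the test added in this PR; the model choice is a placeholder to keep the test cheap, while the engine arguments mirror the repro above):

# Hypothetical e2e test sketch.
import pytest
import vllm

@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_scale_calculation(enforce_eager: bool):
    llm = vllm.LLM(
        model="Qwen/Qwen3-0.6B",  # placeholder small model
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        enforce_eager=enforce_eager,
    )
    outputs = llm.generate("Hello, world!")
    # The run should complete in both eager and compiled modes and return text.
    assert outputs and outputs[0].outputs[0].text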


mergify bot commented Sep 26, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @adabeyta.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 26, 2025
Signed-off-by: adabeyta <aabeyta@redhat.com>
@adabeyta adabeyta force-pushed the kv_scales_dynamo_fix branch from ee03301 to 6ec29a1 on September 26, 2025 19:59
@mergify mergify bot removed the needs-rebase label Sep 26, 2025
Signed-off-by: adabeyta <aabeyta@redhat.com>
@adabeyta adabeyta force-pushed the kv_scales_dynamo_fix branch from 6ec29a1 to 8586d2b on September 26, 2025 20:04
@adabeyta
Contributor Author

e2e

Here are the lm_eval results. The normal inference path (without KV scale calculation) is not affected.

| Configuration              | GSM8K Accuracy | Throughput | Time  |
|----------------------------|----------------|------------|-------|
| Without KV calc (baseline) | 89.5% ± 3.1%   | 2.80 req/s | 35s   |
| With KV calc               | 88.0% ± 3.3%   | 2.77 req/s | 36s   |
| Delta                      | -1.5%          | -1.1%      | +2.9% |

Test case reproduction steps

Commands Used

Server with KV scale calculation ON (remove --calculate-kv-scales flag for OFF case):

vllm serve Qwen/Qwen3-8B \
    --tensor-parallel-size 2 \
    --quantization fp8 \
    --kv-cache-dtype fp8_e4m3 \
    --calculate-kv-scales \
    --port 8000

Eval Command:

lm_eval --model local-completions \
    --model_args model=Qwen/Qwen3-8B,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3 \
    --tasks gsm8k \
    --num_fewshot 5 \
    --batch_size auto \
    --limit 100

@ProExpertProg
Collaborator

I'm not sure I understand the perf comparison. Could you compare main (no kv calc), pr (no kv calc), and pr (kv calc), and clarify which is which?

Collaborator

@ProExpertProg ProExpertProg left a comment


I see now torch.ops.vllm.maybe_calc_kv_scales doesn't even appear in the graph if disabled. Still would be good to see the numbers just in case.

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 26, 2025
@adabeyta
Contributor Author

adabeyta commented Sep 29, 2025

> I see now torch.ops.vllm.maybe_calc_kv_scales doesn't even appear in the graph if disabled. Still would be good to see the numbers just in case.

Here are the main (no kv calc), PR (no kv calc), and PR (kv calc) numbers:

PR: No KV Calc

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.89  | ± 0.0314 |
|       |         | strict-match     | 5      | exact_match | 0.89  | ± 0.0314 |

Main No KV Calc

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.90  | ± 0.0302 |
|       |         | strict-match     | 5      | exact_match | 0.89  | ± 0.0314 |

PR: KV Calc

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.89  | ± 0.0314 |
|       |         | strict-match     | 5      | exact_match | 0.90  | ± 0.0302 |

Throughput results

Perf Command:

vllm bench throughput --model Qwen/Qwen3-8B --quantization fp8 --kv-cache-dtype fp8_e4m3 --tensor-parallel-size 2 --dataset-name random --input-len 1024 --output-len 256

Main No KV Calc
Throughput: 1.49 requests/s, 1910.60 total tokens/s, 382.12 output tokens/s
Total num prompt tokens: 1024000
Total num output tokens: 256000



PR: No KV Calc
Throughput: 1.50 requests/s, 1911.01 total tokens/s, 383.02 output tokens/s
Total num prompt tokens: 1021253
Total num output tokens: 256000



PR KV Calc
Throughput: 1.50 requests/s, 1911.61 total tokens/s, 383.14 output tokens/s
Total num prompt tokens: 1021253
Total num output tokens: 256000

@ProExpertProg ProExpertProg merged commit c42ff4f into vllm-project:main Sep 29, 2025
44 checks passed
@ProExpertProg ProExpertProg changed the title [Bugfix] KV scale calculation issues with FP8 quantization (#21640) [BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#21640) Sep 29, 2025
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…tion (#25513)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1

Successfully merging this pull request may close these issues.

[Bug]: calculate_kv_scales leads to dynamo compilation issue; enforce_eager=True leads to another issue