[BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#21640) #25513
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request addresses two issues related to FP8 quantization's KV scale calculation: a torch.compile failure due to data-dependent branching, and an AttributeError in eager mode when attention metadata is None. The fix refactors the scale calculation logic into a new custom operator, unified_kv_scale_calc. This effectively hides the dynamic control flow from torch.compile, resolving the compilation error. Additionally, the code now safely handles cases where attention metadata might be None, preventing the AttributeError. The changes are well-contained and directly address the reported problems. The implementation appears correct and I did not find any high or critical severity issues.
Thanks for taking this on! I think this is a good start, but I believe there are still issues:
- attention metadata might change between requests, but I think the current approach will bake the value set during compilation into the control flow
- I don't think this is cudagraph-compatible.
I suggested a resolution for 1 inline. For 2, we'll have to modify the model runner and explicitly set the runtime cudagraph mode to NONE if calc_kv_scales is true on the metadata object.
vllm/attention/layer.py (outdated)
forward_ctx = get_forward_context()
attn_metadata = (forward_ctx.attn_metadata
                 if forward_ctx else None)

scale_calc = bool(
    getattr(attn_metadata, 'enable_kv_scales_calculation', False
            ) if attn_metadata is not None else False)

torch.ops.vllm.unified_kv_scale_calc(query, key, value,
                                     self._q_scale, self._k_scale,
                                     self._v_scale, self.q_range,
                                     self.k_range, self.v_range,
                                     scale_calc)

if scale_calc:
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()

    # We only calculate the scales once.
    self.calculate_kv_scales = False
The code can look like this:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, <layer name>)

(i.e. the entire block above is collapsed into this single call.)
Inside the op, you can look at the forward context to extract the layer instance and the attention metadata, do the control flow there, and save the scale values (and manipulate them in any other way; perhaps just call calc_kv_scales). You can take a look at the implementation of unified_attention to see how it's done.
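For reference, a minimal sketch of what such an op could look like, assuming the forward context exposes `attn_metadata` and a per-layer registry the way `unified_attention` uses them; the field names, the `direct_register_custom_op` helper location, and the `calc_kv_scales` call follow that pattern and are not taken from this PR's final diff:

```python
import torch

from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op  # assumed helper location


def maybe_calc_kv_scales(query: torch.Tensor, key: torch.Tensor,
                         value: torch.Tensor, layer_name: str) -> None:
    # All control flow happens here at runtime, outside the compiled graph.
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None or not getattr(
            attn_metadata, "enable_kv_scales_calculation", False):
        return
    # Resolve the attention layer instance registered under this name,
    # mirroring how unified_attention looks itself up.
    layer = forward_context.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(query: torch.Tensor, key: torch.Tensor,
                              value: torch.Tensor, layer_name: str) -> None:
    # No-op fake impl so torch.compile can treat the call as opaque.
    return


direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=[],
    fake_impl=maybe_calc_kv_scales_fake,
)
```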
Thanks for the feedback! I updated the PR to address both points:
- Metadata handling: I moved the control flow and scale updates into the new maybe_calc_kv_scales op. It now looks up the layer and metadata from the forward context directly (following the same pattern as unified_attention).
- CUDA graph compatibility: In gpu_model_runner.py, I added logic to explicitly set cudagraph_runtime_mode = NONE if any layer's attn_metadata.enable_kv_scales_calculation is true. That way we skip CUDA graph capture only when scale calculation is required, and fall back to the normal cudagraph path otherwise (see the sketch after this list).
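A rough sketch of the kind of runner-side guard described in the second point; the helper name, the `CUDAGraphMode` import location, and the shape of the per-layer metadata dict are assumptions for illustration, not the exact gpu_model_runner.py change:

```python
from vllm.config import CUDAGraphMode  # assumed enum location


def resolve_cudagraph_runtime_mode(attn_metadata_per_layer: dict,
                                   default_mode: CUDAGraphMode) -> CUDAGraphMode:
    """Force eager execution for steps that still need KV scale calculation."""
    needs_kv_scale_calc = any(
        getattr(meta, "enable_kv_scales_calculation", False)
        for meta in attn_metadata_per_layer.values() if meta is not None)
    # CUDA graphs replay a fixed captured graph, so the one-off scale
    # calculation has to run outside of them; once the flag flips off,
    # the normal cudagraph path is used again.
    return CUDAGraphMode.NONE if needs_kv_scale_calc else default_mode
```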
Let me know if you’d like me to tweak the control flow further or if there are any other key parts I should adjust.
Force-pushed from e3ba09a to a719898.
This is overall the right approach. Please also post lm_eval numbers and performance numbers both when kv_cache calc is used and when it's not used. We especially want to make sure this doesn't slow down the general case. Also please add an e2e test for the issue that you fixed.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: adabeyta <aabeyta@redhat.com>
Force-pushed from ee03301 to 6ec29a1.
Signed-off-by: adabeyta <aabeyta@redhat.com>
Force-pushed from 6ec29a1 to 8586d2b.
Here are the lm_eval results. The normal inference path (without KV scale calculation) is not affected.
Test case reproduction steps

Commands used. Server with KV scale calculation ON (remove the --calculate-kv-scales flag for the OFF case):

Eval command:
I'm not sure I understand the perf comparison. Could you compare main (no kv calc), PR (no kv calc), and PR (kv calc), and clarify which is which?
I see now torch.ops.vllm.maybe_calc_kv_scales doesn't even appear in the graph if disabled. Still would be good to see the numbers just in case.
Here are the main (no kv calc), PR (no kv calc), and PR (kv calc) perf numbers:
- PR: No KV Calc
- Main: No KV Calc
- PR: KV Calc

Throughput results (perf command):
- Main: No KV Calc
- PR: No KV Calc
- PR: KV Calc
…tion (vllm-project#25513) Signed-off-by: adabeyta <aabeyta@redhat.com>
…tion (#25513) Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
Purpose
Fix KV scale calculation incompatibility with torch.compile and enforce_eager mode (#21640)
The conditional check if attn_metadata.enable_kv_scales_calculation: at line 274 in vllm/attention/layer.py caused two failures: a torch.compile error from the data-dependent branch, and an AttributeError in eager mode when attn_metadata is None. This PR moves the scale calculation logic into a custom operator (unified_kv_scale_calc) that avoids dynamic control flow and safely handles None values.
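The core idea, shown below as a standalone toy example (the demo:: op name and the runtime-state dict are illustrative, not vLLM code): the data-dependent branch is hidden inside a custom op, so torch.compile traces a single opaque call and the decision is made at run time from ambient state instead of being baked into the graph.

```python
import torch

_RUNTIME_STATE = {"calc_scales": True}  # stand-in for vLLM's forward context


@torch.library.custom_op("demo::maybe_calc_scale", mutates_args=["scale"])
def maybe_calc_scale(x: torch.Tensor, scale: torch.Tensor) -> None:
    # The branch executes eagerly inside the op on every call.
    if _RUNTIME_STATE["calc_scales"]:
        scale.copy_(x.abs().amax() / 448.0)  # 448 ~ fp8 e4m3 max


@maybe_calc_scale.register_fake
def _(x: torch.Tensor, scale: torch.Tensor) -> None:
    pass


@torch.compile
def step(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.maybe_calc_scale(x, scale)
    return x / scale


x, scale = torch.randn(4, 8), torch.ones(())
step(x, scale)                         # scale gets computed on this call
_RUNTIME_STATE["calc_scales"] = False
step(x, scale)                         # reuses the stored scale, same graph
```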
Test Plan
If enforce_eager=False
With code:
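The exact script is collapsed in the PR; the following is an illustrative reconstruction based on the issue, where the model name is a placeholder rather than the reporter's value, while the prompt and token count match the RequestOutput shown below:

```python
from vllm import LLM, SamplingParams

# FP8 KV cache with on-the-fly scale calculation; enforce_eager=False takes
# the torch.compile path that used to fail before this PR.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,
    enforce_eager=False,
)

outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=16))
print(outputs[0])
```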
Before Fix:
After Fix:
RequestOutput(request_id=0, prompt='Hello, world!', prompt_token_ids=[9707, 11, 1879, 0], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" I'm a new user. I'm trying to find out if there's a", token_ids=[358, 2776, 264, 501, 1196, 13, 358, 2776, 4460, 311, 1477, 700, 421, 1052, 594, 264], cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
If enforce_eager=True
With code:
Before Fix:
After Fix:
RequestOutput(request_id=0, prompt='Hello, world!', prompt_token_ids=[9707, 11, 1879, 0], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" I'm a new user here. How can I get started? Welcome! To", token_ids=[358, 2776, 264, 501, 1196, 1588, 13, 2585, 646, 358, 633, 3855, 30, 20166, 0, 2014], cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
Essential Elements of an Effective PR Description Checklist
- Update supported_models.md and examples for a new model.