fp8 kv cache support fix for torch.compile #22758
Conversation
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request introduces a fix for torch.compile errors related to q_scale assertions in the attention layer, particularly for FP8 KV cache on HIP devices. The approach of creating a float copy, _q_scale_float, for assertions is sound and consistent with the existing patterns for k_scale and v_scale. The changes in vllm/attention/layer.py and vllm/v1/attention/backends/triton_attn.py are correct. However, there is a potential issue in vllm/model_executor/layers/quantization/kv_cache.py where _q_scale_float might be assigned a tensor instead of a float, which could lead to the same torch.compile issues. I've added a comment with a suggested fix.
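A minimal sketch of the pattern the review is pointing at, using simplified, hypothetical function and attribute names (the actual code in vllm/model_executor/layers/quantization/kv_cache.py differs):

```python
import torch


def apply_kv_cache_scales(layer: torch.nn.Module,
                          q_scale: torch.Tensor,
                          k_scale: torch.Tensor,
                          v_scale: torch.Tensor) -> None:
    # Illustrative only: keep the tensor scales for the kernels, but mirror
    # them as plain Python floats for host-side checks, so torch.compile /
    # graph capture never has to read a device tensor.
    layer._q_scale = q_scale
    layer._k_scale = k_scale
    layer._v_scale = v_scale
    # .item() yields a Python float; assigning the tensor itself to
    # _q_scale_float would reintroduce the torch.compile failure mode the
    # review describes.
    layer._q_scale_float = q_scale.item()
    layer._k_scale_float = k_scale.item()
    layer._v_scale_float = v_scale.item()


# Usage example with a stand-in module.
layer = torch.nn.Linear(8, 8)
apply_kv_cache_scales(layer, torch.tensor(1.0), torch.tensor(0.5),
                      torch.tensor(0.5))
assert isinstance(layer._q_scale_float, float)
```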
Thanks for the fix.
To clarify: this issue would present itself when using full_cuda_graph: true together with the unified attention backend.
It would happen on CUDA and ROCm >= 7.0.
ROCm < 7.0 allows accessing tensor contents on the CPU side (an assert is one example of such access) during graph capture.
cc @SageMoore
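For context, a hedged sketch of a launch configuration along those lines; the compilation_config contents and the attention backend value are assumptions that may differ across vLLM versions:

```python
import os

from vllm import LLM

# Assumption: the unified Triton attention backend is selected via this
# environment variable; the exact value depends on the vLLM version.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    kv_cache_dtype="fp8",       # FP8 KV cache
    # Assumption: full CUDA graph capture enabled through compilation_config.
    compilation_config={"full_cuda_graph": True},
)
```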
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Force-pushed from 76153d5 to c9e1469
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
This pull request has merge conflicts that must be resolved before it can be merged.
LGTM, thanks for the work!
Please try merging from main to fix the CI issue.
Please merge from main to solve the pre-commit issue
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Signed-off-by: charlifu <charlifu@amd.com>
torch.compile was erroring out in the attention layer on the assert that q_scale equals 1.0. The error originally came from HIP, which reports that the operation is not allowed during CUDA graph capture. The fix therefore introduces a float copy of q_scale, _q_scale_float (similar to _k_scale_float and _v_scale_float), and uses it for the assertion.
PS: q_scale needs to be 1.0 because query upscaling does not happen on AMD from the preceding GEMMs, and scales are only applied to k and v if those are in fp8.
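A minimal sketch of the resulting assert pattern, with simplified attribute names modeled on the description above rather than the exact vllm/attention/layer.py code:

```python
import torch


class AttentionStub(torch.nn.Module):
    # Illustrative stand-in for the attention layer; not the real vLLM class.

    def __init__(self) -> None:
        super().__init__()
        # Tensor scale that the fused attention kernels consume.
        self._q_scale = torch.tensor(1.0)
        # Plain Python float mirror used only for host-side assertions.
        self._q_scale_float = 1.0

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Comparing a Python float stays on the host, so it is safe under
        # torch.compile and CUDA/HIP graph capture. Asserting on the tensor
        # self._q_scale would require reading device memory, which HIP
        # rejects during capture.
        assert self._q_scale_float == 1.0, (
            "q_scale must be 1.0: query upscaling is not applied here")
        return query  # attention math elided


out = AttentionStub()(torch.randn(2, 4))
```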