[torch.compile] Make Query Quantization Fusable #24914
Conversation
This is a good idea, but let's make it more extensible to other attention backends and lose the environment variable.
I think we should add a class property to the attention backend called `supports_quant_query_input` (False by default). If that property is True, we initialize a `QuantFP8` object in the attention layer init:
```python
self.query_quant = None
if self.kv_cache_dtype.startswith("fp8") and self.attn_backend.supports_quant_query_input:
    # We'll need to add support for e5m2 to QuantFP8, shouldn't be hard!
    # Just add a `qdtype` param with `current_platform.fp8_dtype()` as the default
    self.query_quant = QuantFP8(True, GroupShape.PER_TENSOR, qdtype=...)
```

Layer forward:

```python
if self.query_quant is not None:
    query = self.query_quant(query, self._q_scale)
```
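For reference, the backend side of this proposal could look roughly like the minimal sketch below (simplified stand-in classes, not the actual vLLM backend code):

```python
class AttentionBackend:
    # Default: backends do not accept pre-quantized (FP8) query input.
    supports_quant_query_input: bool = False


class FlashAttentionBackend(AttentionBackend):
    # A backend that can consume FP8 queries directly opts in, which makes
    # the layer-level QuantFP8 path above take effect.
    supports_quant_query_input = True
```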
Thanks @ProExpertProg!
@jmkuebler I think we should always do it. If you choose the gradual migration path, please create an issue for the rest of the migration and link it next to the flag.
Force-pushed from 65a2337 to f9936d3
Minor WIP notes
@ProExpertProg thanks a ton for the guidance and the comments. I have now done it for the FA backend and would defer other backends to a future PR (I need some time to onboard on those). Created an issue to track this: #25584. My changes are ready for another round of review; also note the updated PR description with the latest testing results.
Nice and clean addition, thanks!
@jmkuebler could you run lm_eval just to check E2E correctness? Both kvcache={auto,fp8} for main and PR.
@ProExpertProg added the results to the PR description. Good idea to add this here to track.
@jmkuebler not to be the worst, but I noticed that you used enforce_eager in the lm_eval. That will enable the CUDA kernel for quant instead of the torch impl (at least by default). Could you run without enforce_eager as well?
@ProExpertProg sure thing. With compilation it also looks good (the tiny variations are expected when using slightly different graphs).
Main+fp8 (w/ compile)
PR+fp8 (w/ compile)
Purpose
When running attention in FP8, quantizing the queries causes overhead: since it is done via a custom op, it cannot be fused by torch.compile. Profiling shows that this is a (context-length-independent) overhead.
This PR moves the query quantization out of the attention backend into `attention/layer.py` and uses a simple torch implementation, so torch.compile is able to fuse it and reduce the quantization overhead. It is currently only done for the FA backend, but the design is generalizable; created an issue to track the remaining backends: #25584.
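For illustration, a per-tensor FP8 query quantization expressed in plain torch ops could look roughly like the sketch below. The helper name and the e4m3 default are assumptions for illustration only; this is not the exact QuantFP8 code used in the layer.

```python
import torch


def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    """Per-tensor FP8 quantization written in plain torch ops (illustrative sketch).

    Because this is ordinary tensor code rather than a custom CUDA op,
    torch.compile can fuse the scale/clamp/cast sequence into the
    surrounding graph instead of breaking at an opaque custom op.
    """
    finfo = torch.finfo(fp8_dtype)
    q = query.to(torch.float32) / q_scale
    q = q.clamp(min=finfo.min, max=finfo.max)
    return q.to(fp8_dtype)
```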
Test Plan
Spin up server
Benchmark
unit test covering these changes:
pytest /efs/kuebj/kv_cache/vllm_source/tests/quantization/test_fp8.py::test_kv_cache_model_load_and_run
✅ Accuracy
To ensure there is no accidental accuracy degradation, we also run with `kv_cache_dtype` in `{auto, fp8}`, both on this PR and on mainline. We also run without `enforce_eager=True` for the FP8 variants.
Test Result
Accuracy on GSM8k ✅
The PR preserves the accuracy numbers for both FP8 and auto. It is expected that FP8 can be slightly worse than auto.
Main+auto
PR + auto
Main+fp8 (enforce-eager)
PR+fp8 (enforce-eager)
Main+fp8 (w/ compile)
PR+fp8 (w/compile)
[old description until 09/24]
## Purpose
This PR is part of a series to make FP8 attention fast for GPT-OSS, see https://github.com//issues/24916 for context. However, this specific PR would make quantization faster for any model using quantized queries. Quantizing the queries causes overhead: since it is done via a custom operation, it cannot be fused by torch.compile. Profiling shows that this is a (context-length-independent) overhead. This PR adds an env variable that allows performing the query quantization via simple torch operations. When running with torch.compile enabled, this can automatically be fused into previous ops and removes the quantization overhead.
Test Plan
Setup server via
Benchmark via
Test Result
We tested with concurrency 1, TP1 on GPT-OSS for 25k and 1k inputs (`bench serve`). As intended, the query quantization overheads disappear, and profiling showed that the quantization gets fused by torch.compile. We also tested accuracy through GPQA and AIME25 via the GPT-OSS repo and found no differences beyond the usual noise.