
Conversation

@jmkuebler (Contributor) commented Sep 15, 2025

Purpose

When running attention in FP8, quantizing the queries causes overhead, since this is done via a custom operation, and custom ops cannot be fused by torch.compile. Profiling shows that this is a (context-length-independent) overhead.

This PR moves the query quantization out of the attention backend into attention/layer.py and uses a simple torch implementation. torch.compile is then able to fuse it, reducing the quantization overhead.

This is currently only done for the FlashAttention (FA) backend, but the design is generalizable; issue #25584 tracks extending it to the other backends.
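
For illustration, here is a minimal sketch of the idea, assuming that per-tensor FP8 query quantization amounts to scale, clamp, and cast, and that it sits in the same compiled region as the query projection; the helper names (`quantize_query_fp8`, `attention_prologue`) are hypothetical, not the actual vLLM code:

```python
import torch

# Hypothetical sketch: per-tensor FP8 query quantization written in plain
# torch ops so torch.compile can fuse it with surrounding elementwise work,
# instead of calling an opaque custom quantization op inside the backend.
def quantize_query_fp8(query: torch.Tensor,
                       q_scale: torch.Tensor,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    fp8_max = torch.finfo(fp8_dtype).max
    # Scale, clamp to the representable FP8 range, then cast.
    return (query / q_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)


@torch.compile
def attention_prologue(hidden: torch.Tensor, wq: torch.Tensor,
                       q_scale: torch.Tensor) -> torch.Tensor:
    # Projection and quantization live in one compiled graph, so the
    # quantization no longer shows up as a separate, unfusable launch.
    query = hidden @ wq
    return quantize_query_fp8(query, q_scale)
```

Because the quantization is ordinary traced PyTorch, the compiler can fold the divide/clamp/cast into neighboring kernels rather than emitting a standalone quantization kernel.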

Test Plan

Spin up server

vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --kv-cache-dtype fp8 \
  --compilation-config '{"compile_sizes": [1,2,4,8], "cudagraph_capture_sizes": [1,2,4,8], "cudagraph_mode": "FULL_AND_PIECEWISE"}' \
  --no-enable-prefix-caching

Benchmark

vllm bench serve \
    --backend vllm \
    --model $model_name \
    --dataset-name sonnet \
    --dataset-path vllm/benchmarks/sonnet.txt \
    --sonnet-input-len 1000 \
    --sonnet-output-len 200 \
    --port 8000 \
    --num-prompts 20 \
    --max-concurrency 1

Unit test covering these changes:
pytest /efs/kuebj/kv_cache/vllm_source/tests/quantization/test_fp8.py::test_kv_cache_model_load_and_run

Accuracy

To ensure there is no accidental accuracy degradation, we also run:

lm_eval \
  --model vllm \
  --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,kv_cache_dtype=auto,tensor_parallel_size=8,enforce_eager=True \
  --tasks gsm8k \
  --batch_size auto

with kv_cache_dtype in {auto,fp8}, both on this PR and on mainline. We also run without enforce_eager=True for the FP8 variants.

Test Result

This PR:

============ Serving Benchmark Result ============
Successful requests:                     20        
Maximum request concurrency:             1         
Benchmark duration (s):                  27.16     
Total input tokens:                      18248     
Total generated tokens:                  4000      
Request throughput (req/s):              0.74      
Output token throughput (tok/s):         147.27    
Peak output token throughput (tok/s):    151.00    
Peak concurrent requests:                2.00      
Total Token throughput (tok/s):          819.09    
---------------Time to First Token----------------
Mean TTFT (ms):                          32.36     
Median TTFT (ms):                        31.82     
P99 TTFT (ms):                           37.59     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.66      
Median TPOT (ms):                        6.66      
P99 TPOT (ms):                           6.68      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.66      
Median ITL (ms):                         6.66      
P99 ITL (ms):                           6.95
==================================================



Mainline:
============ Serving Benchmark Result ============
Successful requests:                     20        
Maximum request concurrency:             1         
Benchmark duration (s):                  27.44     
Total input tokens:                      18248     
Total generated tokens:                  4000      
Request throughput (req/s):              0.73      
Output token throughput (tok/s):         145.77    
Peak output token throughput (tok/s):    149.00    
Peak concurrent requests:                2.00      
Total Token throughput (tok/s):          810.80    
---------------Time to First Token----------------
Mean TTFT (ms):                          30.39     
Median TTFT (ms):                        30.26     
P99 TTFT (ms):                           31.77     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.74      
Median TPOT (ms):                        6.74      
P99 TPOT (ms):                           6.75      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.74      
Median ITL (ms):                         6.74      
P99 ITL (ms):                            7.01      
==================================================

Accuracy on GSM8k ✅

The PR preserves the accuracy numbers both for FP8 and auto. That FP8 can be slightly worse than auto is expected.

Main+auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7801|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7597|±  |0.0118|

PR + auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7801|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7597|±  |0.0118|

Main+fp8 (enforce-eager)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7559|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7384|±  |0.0121|

PR+fp8 (enforce-eager)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7559|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7384|±  |0.0121|

Main+fp8 (w/ compile)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7703|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.7460|±  |0.0120|

PR+fp8 (w/compile)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7650|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.7445|±  |0.0120|

[old description until 09/24]

Purpose

This PR is part of a series to make FP8 attention fast for GPT-OSS; see #24916 for context. However, this specific PR makes quantization faster for any model using quantized queries.

Quantizing the queries causes overhead: since it is done via a custom operation, it cannot be fused by torch compile. Profiling shows that this is a (context-length-independent) overhead. This PR adds an env variable that allows performing the query quantization via simple torch operations. When running with torch.compile enabled, these can automatically be fused into previous ops, removing the quantization overhead.
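
For context, a rough sketch of how such an env-var gate could look (conceptual only; the flag name matches the test plan below, but the helper and the fallback are illustrative, not the superseded implementation):

```python
import os
import torch

# Illustrative gate from the superseded draft: choose between a plain-torch
# quantization path (traceable/fusable by torch.compile) and the existing
# custom-op path, controlled by an environment variable.
FUSE_QUERY_QUANT = os.environ.get("VLLM_FUSE_QUERY_QUANT", "0") == "1"


def maybe_quantize_query(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    if FUSE_QUERY_QUANT:
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        return (query / q_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Custom-op fallback elided in this sketch; torch.compile treats such an
    # op as opaque, so it cannot be fused with the surrounding graph.
    raise NotImplementedError
```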

Test Plan

Set up the server via

export VLLM_FUSE_QUERY_QUANT=1
python -m vllm.entrypoints.openai.api_server \
    --model $MODEL_PATH \
    --tensor-parallel-size 1 \
    --max-num-seqs 1 \
    --port 8088 \
    --kv-cache-dtype "fp8" \
    --served-model-name gpt-oss-KV8-fuse-q-quant \
    --compilation-config '{"compile_sizes": [1,2,4,8], "cudagraph_capture_sizes": [1,2,4,8], "cudagraph_mode": "FULL_AND_PIECEWISE"}' \
    --no-enable-prefix-caching

Benchmark via

model_name=gpt-oss-KV8-fuse-q-quant
vllm bench serve \
    --backend vllm \
    --model $model_name \
    --dataset-name sonnet \
    --dataset-path /home/ec2-user/gpt-oss-attn/vllm/benchmarks/sonnet.txt \
    --sonnet-input-len 25000 \
    --sonnet-output-len 200 \
    --port 8088 \
    --tokenizer openai/gpt-oss-120b \
    --num-prompts 10 \
    --max-concurrency 1

Test Result

We tested with concurrency 1 and TP1 on GPT-OSS for 25k and 1k inputs (bench serve). As intended, the query quantization overheads disappear, and profiling showed that the quantization gets fused by torch compile. We also tested accuracy on GPQA and AIME25 via the GPT-OSS repo and found no differences beyond the usual noise.

| model | applied optimization | input | output | median ITL [ms] | median TTFT [ms] |
|---|---|---:|---:|---:|---:|
| gpt-oss (full bf16) | mainline | 25000 | 200 | 5.08 | 1252.43 |
| gpt-oss (full bf16) | mainline | 1000 | 200 | 4.83 | 76.16 |
| kv fp8 | mainline | 25000 | 200 | 5.38 | 1211.33 |
| kv fp8 | mainline | 1000 | 200 | 5.06 | 76.36 |
| gpt-oss-KV8-fuse-q-quant | this PR | 25000 | 200 | 5.3 | 1208.56 |
| gpt-oss-KV8-fuse-q-quant | this PR | 1000 | 200 | 4.98 | 78.89 |


@ProExpertProg (Collaborator) left a comment

This is a good idea but let's make it more extensible to other attention backends and lose the environment variable.

I think we should add a class property to the attention backend called supports_quant_query_input (False by default). If that property is True, we initialize a QuantFP8 object in the attention layer init:

self.query_quant = None
if self.kv_cache_dtype.startswith("fp8") and self.attn_backend.supports_quant_query_input:
    # We'll need to add support for e5m2 to QuantFP8, shouldn't be hard! 
    # Just add a `qdtype` param with `current_platform.fp8_dtype()` as the default
    self.query_quant = QuantFP8(True, GroupShape.PER_TENSOR, qdtype=...)

Layer forward:

if self.query_quant is not None:
    query = self.query_quant(query, self._q_scale)
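
Put together, a self-contained sketch of the suggested wiring (simplified: `supports_quant_query_input` is the flag proposed above, while `PerTensorQuantFP8` is a stand-in for vLLM's `QuantFP8`, not the real class):

```python
import torch


class FlashAttentionBackend:
    # Backends that can consume an already-quantized query opt in, so the
    # attention layer performs the quantization in plain torch ops.
    supports_quant_query_input: bool = True


class PerTensorQuantFP8:
    """Stand-in for QuantFP8: static per-tensor scale, FP8 output."""

    def __init__(self, fp8_dtype: torch.dtype = torch.float8_e4m3fn):
        self.fp8_dtype = fp8_dtype
        self.fp8_max = torch.finfo(fp8_dtype).max

    def __call__(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return (x / scale).clamp(-self.fp8_max, self.fp8_max).to(self.fp8_dtype)


class AttentionLayer:
    def __init__(self, attn_backend, kv_cache_dtype: str, q_scale: torch.Tensor):
        self._q_scale = q_scale
        self.query_quant = None
        if kv_cache_dtype.startswith("fp8") and getattr(
                attn_backend, "supports_quant_query_input", False):
            self.query_quant = PerTensorQuantFP8()

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Quantize here (fusable by torch.compile); the backend then skips
        # its own query quantization for the fp8 KV-cache path.
        if self.query_quant is not None:
            query = self.query_quant(query, self._q_scale)
        return query  # handed off to the attention backend kernel
```
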

@robertgshaw2-redhat changed the title from "add env flag to make query quantization fusable" to "[torch.compile] Make Query Quantization Fusable" on Sep 23, 2025
@jmkuebler (Contributor, Author) commented

Thanks @ProExpertProg !
Just to be sure I make the changes accordingly: shall we always move the query quantization out of the attention backend into the attention layer, or shall we make that configurable? I currently don't know of a case where we would want to keep it inside the backend, but I want to be sure.

@ProExpertProg (Collaborator) commented Sep 23, 2025

@jmkuebler I think we should always do it. The reason I mention the supports_quant_query_input flag is so that we can gradually migrate the backends. But if you're able to migrate all of them in this PR and test them that's even better and we don't need the flag! Either way, for a backend that supports it, quant should always just happen outside.

If you choose the gradual migration path, please create an issue for the rest of the migration and link it next to the flag.

@jmkuebler force-pushed the make_query_quantization_fusable branch from 65a2337 to f9936d3 on September 24, 2025 at 12:40
@ProExpertProg (Collaborator) left a comment

Minor WIP notes

@jmkuebler (Contributor, Author) commented

@ProExpertProg thanks a ton for the guidance and the comments.

I have now done this for the FA backend and would defer the other backends to a future PR (I need some time to onboard on those). Created issue #25584 to track the rest.

My changes are ready for another round of review, also note the updated PR description with latest testing results.

@ProExpertProg added the ready label on Sep 24, 2025
@ProExpertProg (Collaborator) left a comment

Nice and clean addition, thanks!

@ProExpertProg (Collaborator) commented

@jmkuebler could you run lm_eval just to check E2E correctness? Both kvcache={auto,fp8} for main and PR.

@jmkuebler (Contributor, Author) commented Sep 25, 2025

> @jmkuebler could you run lm_eval just to check E2E correctness? Both kvcache={auto,fp8} for main and PR.

@ProExpertProg added the results to the PR description. Good idea to add this here to track.
The PR preserves the gsm8k accuracy numbers both for FP8 and auto.

@ProExpertProg (Collaborator) commented

@jmkuebler not to be the worst but I noticed that you used enforce_eager in the lm eval. That will enable the cuda kernel for quant instead of the torch impl (at least by default). Could you run without enforce eager as well?

@jmkuebler (Contributor, Author) commented Sep 25, 2025

@ProExpertProg sure thing. With compilation it looks also good (the tiny variations are expected when using slightly different graphs).

lm_eval \
  --model vllm \
  --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,kv_cache_dtype=fp8,tensor_parallel_size=8 \
  --tasks gsm8k \
  --batch_size auto

Main+fp8 (w/ compile)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7703|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.7460|±  |0.0120|

PR+fp8 (w/compile)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7650|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.7445|±  |0.0120|

@ProExpertProg merged commit 69a8c8e into vllm-project:main on Sep 25, 2025 (45 checks passed).
The github-project-automation bot moved this from In progress to Done in torch.compile integration on Sep 25, 2025.
Zhuul pushed a commit to Zhuul/vllm that referenced this pull request on Sep 26, 2025.
yewentao256 pushed a commit that referenced this pull request on Oct 3, 2025.
Labels: ready, torch.compile, v1
Projects: torch.compile integration (Status: Done)