Conversation

@MatthewBonanni MatthewBonanni commented Sep 12, 2025

Purpose

Enable FP8 KV cache for the FLASHINFER_MLA backend.
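
For context, this is roughly how a user would exercise the new path from the offline API (a usage sketch; the prompt and sampling settings are placeholders):

```python
import os

# Select the FlashInfer MLA attention backend before vLLM is imported.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM, SamplingParams

# kv_cache_dtype="fp8" enables the FP8 KV cache path added by this PR.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    trust_remote_code=True,
    kv_cache_dtype="fp8",
)

outputs = llm.generate(
    ["Explain FP8 KV caches in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```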

Test Plan

Correctness

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA lm_eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": <dtype>}' --tasks gsm8k --batch_size auto

Performance

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm bench throughput --model=deepseek-ai/DeepSeek-V2-Lite-Chat --dataset-name=random --input-len=8192 --output-len=1024 --num-prompts=100 --kv-cache-dtype=<dtype>

Test Result

Correctness

with <dtype>="auto":

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6520|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.6444|±  |0.0132|

with <dtype>="fp8":

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6384|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6262|±  |0.0133|

Performance

with <dtype>="auto": Throughput: 3.22 requests/s, 29668.09 total tokens/s, 3296.62 output tokens/s
with <dtype>="fp8": Throughput: 3.55 requests/s, 32757.46 total tokens/s, 3639.90 output tokens/s



Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@MatthewBonanni MatthewBonanni changed the title Enable FP8 FlashInfer MLA decode Enable FP8 FlashInfer (TRTLLM) MLA decode Sep 12, 2025
@mergify mergify bot added the v1 label Sep 12, 2025
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@MatthewBonanni MatthewBonanni marked this pull request as ready for review September 12, 2025 00:36
@MatthewBonanni MatthewBonanni changed the title Enable FP8 FlashInfer (TRTLLM) MLA decode [Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode Sep 12, 2025
@mgoin mgoin added the ready and deepseek labels Sep 12, 2025
Comment on lines 90 to 91
bmm1_scale = layer._q_scale.item() * layer._k_scale.item() * self.scale
bmm2_scale = layer._v_scale.item()
Member

Could you use the _float version of these?

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

Also could this be calculated ahead of time rather than each forward pass?

Contributor Author

I can definitely switch to the float ones. It's possible that each layer could have a different scale though, right?

Member

Yes, definitely. Sorry, I wasn't looking carefully at the difference between layer and self here; we can keep the local computation. I was just trying to avoid CPU ops if possible.

Contributor Author

No problem, thanks for the review!
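
For reference, the resolution above amounts to reading the host-side `_float` copies of the per-layer scales instead of calling `.item()` on device tensors, while keeping the computation in the forward pass. A minimal sketch (the method name and signature are hypothetical; the attribute names follow the snippets quoted above):

```python
def _trtllm_mla_scales(self, layer):
    # _q/_k/_v_scale_float are host (CPU) copies of the per-layer
    # quantization scales, so no device-to-host sync (.item()) is needed.
    bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
    bmm2_scale = layer._v_scale_float
    return bmm1_scale, bmm2_scale
```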

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@mergify mergify bot added the tpu label Sep 12, 2025
@mgoin mgoin merged commit 7ba32aa into vllm-project:main Sep 12, 2025
48 checks passed
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Sep 15, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025