[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode #24705
Conversation
bmm1_scale = layer._q_scale.item() * layer._k_scale.item() * self.scale
bmm2_scale = layer._v_scale.item()
Could you use the _float versions of these?
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
Also could this be calculated ahead of time rather than each forward pass?
I can definitely switch to the float ones. It's possible that each layer could have a different scale though, right?
Yes, definitely. Sorry, I wasn't looking carefully at the difference between layer and self here; we can keep the local computation. I was just trying to avoid CPU ops if possible.
No problem, thanks for the review!
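For reference, a minimal sketch of the change this thread converges on, assuming the cached host-side _q_scale_float / _k_scale_float / _v_scale_float attributes quoted above (the exact code merged in the PR may differ):

# Hypothetical sketch: read the host-side float copies of the scales instead of
# calling .item() on device tensors, avoiding a GPU->CPU sync on each forward pass.
bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
bmm2_scale = layer._v_scale_float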
Purpose
Enable FP8 kv cache for the FLASHINFER_MLA backend.
Test Plan
Correctness
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA lm_eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": <dtype>}' --tasks gsm8k --batch_size auto
Performance
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm bench throughput --model=deepseek-ai/DeepSeek-V2-Lite-Chat --dataset-name=random --input-len=8192 --output-len=1024 --num-prompts=100 --kv-cache-dtype=<dtype>
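For a quick interactive check, a rough offline-inference equivalent of the commands above using the vLLM Python API is sketched below; the backend is still selected via the VLLM_ATTENTION_BACKEND environment variable, and the prompt is only illustrative:

import os

# Select the FlashInfer MLA attention backend before vLLM initializes.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM, SamplingParams

# kv_cache_dtype="fp8" enables the FP8 KV cache path; use "auto" for the baseline run.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    trust_remote_code=True,
    kv_cache_dtype="fp8",
)

outputs = llm.generate(["Explain the FP8 KV cache in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)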
Test Result
Correctness
with <dtype>="auto":
with <dtype>="fp8":
Performance
with <dtype>="auto": Throughput: 3.22 requests/s, 29668.09 total tokens/s, 3296.62 output tokens/s
with <dtype>="fp8": Throughput: 3.55 requests/s, 32757.46 total tokens/s, 3639.90 output tokens/s