[V1][Kernel] Add triton implementation for reshape_and_cache_flash
#24503
Conversation
Code Review
This PR introduces a Triton implementation for the `reshape_and_cache_flash` kernel, aiming to reduce dependencies on non-PyTorch/Triton kernels for the `triton_attn` backend. The changes include adding a new Triton kernel file, modifying the test file to include the new implementation, and updating the attention backend to use the Triton kernel. The review focuses on correctness and potential issues arising from the new implementation.
)
else:
-    ops.reshape_and_cache_flash(
+    triton_reshape_and_cache_flash(
     key,
The change from `ops.reshape_and_cache_flash` to `triton_reshape_and_cache_flash` directly replaces the CUDA implementation with the Triton one. It's crucial to ensure that this replacement doesn't introduce any regressions in performance or correctness across all supported configurations. A more robust approach would involve a mechanism to dynamically switch between implementations based on hardware, input size, or other relevant factors, allowing for a fallback to the CUDA implementation if necessary. Without such a mechanism, any unforeseen issues with the Triton implementation could severely impact the overall system performance. This needs to be addressed immediately.
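For illustration only, a hypothetical dispatch wrapper along the lines this review suggests could look as follows. It is not part of this PR; the function name, the `use_triton` flag, and the assumption that the Triton kernel mirrors the CUDA op's argument list are all illustrative.

```python
# Hypothetical sketch of a fallback mechanism as suggested above; not part of
# this PR. Assumes the Triton kernel accepts the same arguments as the CUDA op.
from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)


def reshape_and_cache_flash_dispatch(key, value, key_cache, value_cache,
                                     slot_mapping, kv_cache_dtype,
                                     k_scale, v_scale, *, use_triton: bool):
    """Route to the Triton kernel, with the CUDA op as fallback."""
    if use_triton:
        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                       slot_mapping, kv_cache_dtype,
                                       k_scale, v_scale)
    else:
        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype,
                                    k_scale, v_scale)
```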
Thanks for the work!
Could you also test on the E2E throughput benchmark? `vllm bench throughput ...`
And I don't fully understand why we don't want to use the `reshape_and_cache_flash` CUDA kernel; could you give more context here?
Sure:
with main:
with this PR:
Using CUDA kernels creates difficulties for some platforms (see e.g. #24149 for XPUs). Since vLLM embraced torch.compile, and with the Triton attention backend, we want to reduce the dependency on non-PyTorch/Triton kernels. However, one problem right now is the numerical accuracy in case an fp8 kv-cache is used.
Are you sure these are using different kernels? They are nearly the same. I don't think Triton would have the same performance as the CUDA kernel.
My understanding is this …
Yes, it is using the Triton kernel. Also, the micro-benchmarks of just the kernels above show that the Triton kernel has the same performance. We have shown repeatedly that Triton can have performance comparable to CUDA kernels, see e.g. https://arxiv.org/abs/2505.03780, or https://www.youtube.com/watch?v=GG1qi82J8Hg&t=500s, or if you benchmark the …
In the meantime, I was able to fix the problem with an fp8 kv-cache on MI300. The kernel now works without problems for fp16 and fp8 on H100 and MI300. Also, when testing the end-to-end accuracy with fp8 on MI300, it results in (nearly) the same score as with fp16:
and unit tests on MI300:
Background: if using fp8, vLLM allocates the kv-cache as … (see lines 176 to 186 at fdb09c7). The FP8-FNUZ (Finite and NaN Only) datatype is used, as declared here (lines 408 to 412 at fdb09c7). So, the … (see vllm/attention/ops/triton_reshape_and_cache_flash.py, lines 25 to 36 at 55e829e).

However, the Triton language (as of version 3.4 and below) does not have a datatype for …. So what happened was that the reshape-and-cache kernel cast bfloat16 explicitly to …. The good news is that the Triton backend does support …
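As a side illustration of why mixing the two fp8 flavors matters (a minimal sketch, not the PR's kernel code): the same bit pattern decodes to different values in OCP `float8_e4m3fn` and ROCm's `float8_e4m3fnuz`, because the two formats use different exponent biases. It assumes a PyTorch build that exposes both fp8 dtypes.

```python
# Minimal sketch, not the PR's kernel: writing fp8 bytes in the OCP e4m3fn
# format and decoding them as the ROCm e4m3fnuz format silently changes the
# stored values, which is the kind of accuracy problem described above.
import torch

x = torch.tensor([0.5, 1.0, 2.0, -3.0], dtype=torch.bfloat16)

as_fn = x.to(torch.float8_e4m3fn)                              # intended format
misread = as_fn.view(torch.uint8).view(torch.float8_e4m3fnuz)  # same bits, other format

print(as_fn.float())    # -> 0.5, 1.0, 2.0, -3.0 (round-trips exactly)
print(misread.float())  # roughly half the magnitude: fnuz uses exponent bias 8 instead of 7
```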
Looks reasonable @bringlein
Thanks for the work!
I have tested locally using `benchmarks/kernels/benchmark_reshape_and_cache_flash.py`.
Here is what I get on B200:
| num_tokens | layout | latency_cuda (µs) | latency_triton (µs) |
|-----------:|:------:|------------------:|--------------------:|
| 2 | NHD | 6.902 | 56.777 |
| 4 | NHD | 6.714 | 55.243 |
| 8 | NHD | 6.519 | 55.221 |
| 16 | NHD | 6.273 | 54.965 |
| 32 | NHD | 6.355 | 55.046 |
| 64 | NHD | 6.399 | 56.348 |
| 128 | NHD | 6.546 | 55.745 |
| 256 | NHD | 6.431 | 54.853 |
| 512 | NHD | 10.428 | 54.352 |
| 1024 | NHD | 25.994 | 55.091 |
| 2048 | NHD | 47.105 | 55.002 |
| 4096 | NHD | 87.363 | 101.223 |
| 8192 | NHD | 168.325 | 211.682 |
| 16384 | NHD | 329.958 | 424.305 |
| 32768 | NHD | 654.033 | 843.652 |
| 65536 | NHD | 1303.03 | 1690.64 |
No HND results since the Triton kernel doesn't support that layout.
From both a performance and a functional point of view, I don't think it is a good idea to replace the whole CUDA kernel.
Hi @yewentao256, thanks for your feedback and engagement!

Regarding your performance numbers: I guess you measured without CUDA graphs? In that case you are measuring the software overhead of Triton's JIT compiler, which is why the latency is more or less the same for 2 and 2048 tokens. If you measure with CUDA graphs, i.e. without that software overhead, I expect you will see very similar performance between the CUDA and Triton kernels (as shown in the plots in the initial post).

Regarding the layout: the HND layout can be implemented, but I'm unsure if it is necessary. In my current understanding, we don't want to replace the CUDA kernel; we only want to add a Triton kernel so the Triton backend has a full Triton path. And in the Triton backend, we don't have the HND layout.
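For context, here is a rough sketch of timing a kernel under CUDA graphs so that Triton's JIT and Python launch overhead are excluded from the measurement. This is not vLLM's benchmark script; `run_kernel` is a placeholder for whichever implementation is being timed, and the warmup/iteration counts are arbitrary.

```python
# Rough sketch (not vLLM's benchmark script): capture the kernel in a CUDA
# graph and time only graph replays, so JIT/launch overhead is excluded.
import time

import torch


def time_with_cuda_graph(run_kernel, iters: int = 100) -> float:
    for _ in range(3):          # warmup; Triton JIT compilation happens here
        run_kernel()
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_kernel()            # capture a single launch into the graph

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        graph.replay()          # replays bypass the Python-side launch path
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e6  # microseconds per call
```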
@bringlein Maybe you could update the benchmark then? We do use the HND layout for flashinfer/trtllm attention, so if this kernel were to be used there, that would be a requirement. We are working towards fusing RoPE + reshape_and_cache though, so that should be prioritized over improvements to just reshape_and_cache.
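For readers unfamiliar with the two layouts being discussed, a small shapes-only illustration (sizes are made up): NHD stores tokens before kv heads within a block, HND the other way around.

```python
# Illustration only: the two paged kv-cache layouts mentioned above,
# with N = tokens per block, H = kv heads, D = head size.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 128

# NHD: [num_blocks, block_size, num_kv_heads, head_size]
key_cache_nhd = torch.empty(num_blocks, block_size, num_kv_heads, head_size)

# HND: [num_blocks, num_kv_heads, block_size, head_size]
key_cache_hnd = key_cache_nhd.permute(0, 2, 1, 3).contiguous()

print(key_cache_nhd.shape)  # torch.Size([8, 16, 4, 128])
print(key_cache_hnd.shape)  # torch.Size([8, 4, 16, 128])
```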
This pull request has merge conflicts that must be resolved before it can be merged.
@mgoin @yewentao256 I updated the benchmark. On an H100, I get the following results:

and for triton:

…which is what I would expect, and is also shown in the plots above.
Please fix the pre-commit first.
@mgoin yes, I've seen that, but I couldn't really explain where it comes from. All the variables are defined (obviously) and I didn't touch the lines where they are defined. So I have muted these false positives for now, but I don't know if that is the best approach.
LGTM, thanks!
Thanks for the work!
Purpose
This PR adds a Triton implementation of the `reshape_and_cache_flash` kernel, which helps reduce the dependency on non-PyTorch/Triton kernels for the `triton_attn` backend. The kernel itself has the same (or slightly better) performance on H100 and MI300 as the CUDA kernel.


CC: @tdoublep @SageMoore @jvlunteren @jikunshang @cyang49
Test Plan
unit tests and end-to-end correctness tests
Test Result
unit tests:
end2end correctness tests:
on main with H100:
with this PR using H100:
(results are the same if running on MI300).