
Conversation

bringlein
Contributor

@bringlein bringlein commented Sep 9, 2025

Purpose

This PR adds a Triton implementation of the reshape_and_cache kernel, which helps reduce the dependency on non-PyTorch/Triton kernels for the triton_attn backend.

The kernel itself has the same (or slightly better) performance as the CUDA kernel on H100 and MI300.
[Micro-benchmark plots: reshape_and_cache_flash latency, Triton vs. CUDA, on H100 and MI300]

CC: @tdoublep @SageMoore @jvlunteren @jikunshang @cyang49

Test Plan

unit tests and end-to-end correctness tests

Test Result

unit tests:

python3 -m pytest tests/kernels/attention/test_cache.py::test_reshape_and_cache_flash
....
216 passed, 72 skipped in 79.67s (0:01:19) 

end-to-end correctness tests:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

on main with H100:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

with this PR using H100:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.800|±  |0.0179|
|     |       |strict-match    |     5|exact_match|↑  |0.784|±  |0.0184|

(results are the same if running on MI300).


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

bringlein and others added 10 commits September 5, 2025 10:47
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces a Triton implementation for the reshape_and_cache kernel, aiming to reduce dependencies on non-PyTorch/Triton kernels for the triton_attn kernel. The changes include adding a new Triton kernel file, modifying the test file to include the new implementation, and updating the attention backend to use the Triton kernel. The review focuses on correctness and potential issues arising from the new implementation.

Comment on lines 344 to 347
            )
        else:
-           ops.reshape_and_cache_flash(
+           triton_reshape_and_cache_flash(
                key,
Contributor


critical

The change from ops.reshape_and_cache_flash to triton_reshape_and_cache_flash directly replaces the CUDA implementation with the Triton one. It's crucial to ensure that this replacement doesn't introduce any regressions in performance or correctness across all supported configurations. A more robust approach would involve a mechanism to dynamically switch between implementations based on hardware, input size, or other relevant factors, allowing for a fallback to the CUDA implementation if necessary. Without such a mechanism, any unforeseen issues with the Triton implementation could severely impact the overall system performance. This needs to be addressed immediately.
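For illustration, a minimal sketch of the kind of fallback mechanism suggested here, assuming an environment-variable switch; the variable name, the Triton import path and the dispatch helper are hypothetical and not part of this PR:

```python
import os

from vllm import _custom_ops as ops
# import path of the new Triton kernel is an assumption
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)


def reshape_and_cache_flash_dispatch(key, value, key_cache, value_cache,
                                     slot_mapping, kv_cache_dtype, k_scale,
                                     v_scale):
    """Hypothetical wrapper: fall back to the CUDA op when requested."""
    if os.environ.get("VLLM_USE_CUDA_RESHAPE_AND_CACHE", "0") == "1":
        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype, k_scale,
                                    v_scale)
    else:
        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                       slot_mapping, kv_cache_dtype, k_scale,
                                       v_scale)
```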

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Member

@yewentao256 yewentao256 left a comment


Thanks for the work!
Could you also test on the E2E throughput benchmark? vllm bench throughput ...
And I don't fully understand why we don't want to use the reshape_and_cache_flash CUDA kernel; could you share more context here?

@bringlein
Contributor Author

Could you also test on the E2E throughput benchmark? vllm bench throughput ...

Sure:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 vllm bench throughput --model meta-llama/Llama-3.1-8B-Instruct/ --dataset-name random --input-len 4096 --output-len 4096 --prefix-len 0 --num-prompts 100

with main:

Throughput: 0.44 requests/s, 3576.79 total tokens/s, 1788.60 output tokens/s
Total num prompt tokens:  409505
Total num output tokens:  409600

with this PR:

Throughput: 0.44 requests/s, 3573.81 total tokens/s, 1787.11 output tokens/s
Total num prompt tokens:  409505
Total num output tokens:  409600

And I don't fully understand why we don't want to use the reshape_and_cache_flash CUDA kernel, could you show more context here?

Using CUDA kernels creates difficulties for some platforms (see e.g. #24149 for XPUs). Since vLLM embraced torch.compile, and with the Triton attention backend, the reshape_and_cache kernel is (one of) the last kernels that is neither PyTorch nor Triton native (when using the TritonAttention backend). Hence, changing this improves the cross-platform portability of vLLM.

However, one problem right now is numerical accuracy when the reshape_and_cache kernel needs to cast the FP16 input into an fp8 cache on AMD GPUs. I'm still debugging this, but right now we can't use the Triton kernel for this scenario.

@yewentao256
Member

with this PR:

Throughput: 0.44 requests/s, 3573.81 total tokens/s, 1787.11 output tokens/s

Are you sure these are using different kernels? They are nearly the same. I don't think Triton would have the same perf as the CUDA kernel.

@jikunshang
Collaborator

with this PR:

Throughput: 0.44 requests/s, 3573.81 total tokens/s, 1787.11 output tokens/s

Are you sure these are using different kernels? They are nearly the same. I don't think Triton would have the same perf as the CUDA kernel.

My understanding is that this reshape_and_cache_flash kernel is a copy kernel, not a compute kernel, and Triton can implement such a kernel efficiently.
I think it makes sense to have a full Triton path :)
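For illustration, a heavily simplified sketch of such a Triton copy kernel (one program per token; the real kernel in this PR also handles the value cache, fp8 casting and scales, and the names and layout here are illustrative):

```python
import triton
import triton.language as tl


@triton.jit
def _copy_key_to_cache_kernel(
    key_ptr,           # [num_tokens, entry_size], contiguous
    key_cache_ptr,     # [num_blocks, block_size, entry_size], contiguous
    slot_mapping_ptr,  # [num_tokens]
    entry_size: tl.constexpr,  # num_heads * head_size
    block_size: tl.constexpr,
    BLOCK: tl.constexpr,       # next power of two >= entry_size
):
    token_idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot < 0:
        # Padding token that should not be cached.
        return
    block_idx = slot // block_size
    block_off = slot % block_size
    offs = tl.arange(0, BLOCK)
    mask = offs < entry_size
    vals = tl.load(key_ptr + token_idx * entry_size + offs, mask=mask)
    dst = key_cache_ptr + (block_idx * block_size + block_off) * entry_size + offs
    # tl.store casts implicitly to the cache dtype (e.g. fp8)
    tl.store(dst, vals, mask=mask)


# host side: one program per token, e.g.
# _copy_key_to_cache_kernel[(num_tokens,)](
#     key, key_cache, slot_mapping,
#     entry_size=num_heads * head_size, block_size=block_size,
#     BLOCK=triton.next_power_of_2(num_heads * head_size))
```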

@bringlein
Contributor Author

Are you sure these are using different kernels? They are nearly the same. I don't think Triton would have the same perf as the CUDA kernel.

Yes, it is using the Triton kernel. Also, the micro-benchmarks of just the kernels above show that the Triton kernel has the same performance. We have shown repeatedly that Triton can have performance comparable to CUDA kernels, see e.g. https://arxiv.org/abs/2505.03780 or https://www.youtube.com/watch?v=GG1qi82J8Hg&t=500s, or benchmark the triton_attn backend vs. the flash_attn backend (we will publish something on that soon).

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein
Contributor Author

In the meantime, I was able to fix the problem with an fp8 kv-cache on MI300. The kernel now works without problems for fp16 and fp8 on H100 and MI300. Also, testing the end-to-end accuracy with fp8 on MI300 results in (nearly) the same score as with fp16:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct/,kv_cache_dtype=fp8 --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

...

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.788|±  |0.0183|
|     |       |strict-match    |     5|exact_match|↑  |0.770|±  |0.0188|

and unit tests on MI300:

python3 -m pytest tests/kernels/attention/test_cache.py::test_reshape_and_cache_flash -x
....
216 passed, 72 skipped in 66.31s

Background:

If using fp8, vLLM allocates the kv-cache as uint8 (due to the mapping here:

vllm/vllm/utils/__init__.py

Lines 176 to 186 in fdb09c7

STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
    "int8": torch.int8,
    "fp8_inc": torch.float8_e4m3fn,
}
). On the ROCm platform, the FP8-FNUZ datatype (finite and NaN only, unsigned zero) is used, as declared here:

vllm/vllm/platforms/rocm.py

Lines 408 to 412 in fdb09c7

def fp8_dtype(cls) -> torch.dtype:
    if cls.is_fp8_fnuz():
        return torch.float8_e4m3fnuz
    else:
        return torch.float8_e4m3fn

So the reshape_and_cache kernel would then need to explicitly scale and cast the fp16 value to fp8, and then bitcast (i.e. reinterpret without any data changes) the fp8 value to uint8, as was done, e.g., in the initial version of this PR:

@triton.jit
def quant_symmetric_per_tensor_fp8e4nv(x, scale=None):
    if scale is None:
        # Compute scale
        max_val = tl.max(tl.abs(x))
        scale = max_val / 448.0
        scale = tl.where(scale == 0.0, 1.0, scale)  # Avoid div-by-zero
    # Quantize to float8e4nv
    x_scaled = x / scale
    x_clipped = tl.clamp(x_scaled, -448.0, 448.0)
    return x_clipped.to(tl.float8e4nv)

However, the Triton language (as of version 3.4 and below) does not have a datatype for torch.float8_e4m3fnuz, only a language representation for torch.float8_e4m3fn: https://github.com/triton-lang/triton/blob/c817b9b63d40ead1ed023b7663f5ea14f676f4bc/python/triton/language/core.py#L400

So what happened was that the reshape_and_cache kernel explicitly cast bfloat16 to fp8_e4m3nv and then stored it in the uint8 tensor (via a bitcast). The subsequent attention kernel then read the cache as torch.float8_e4m3fnuz, resulting in mostly wrong values.

The good news is that the Triton backend does support torch.float8_e4m3fnuz on AMD, and hence the implicit cast happening with every tl.store is correct (as it would be if we accessed e.g. key_cache_ptr.dtype.element_ty). But this can only happen implicitly, not with an explicit tensor.to(tl.float...).
The solution is quite simple: key_cache = key_cache.view(current_platform.fp8_dtype())
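A minimal sketch of this fix on the host side, assuming the kernel wrapper receives the uint8-allocated caches (the helper name and the kv_cache_dtype check are illustrative; current_platform.fp8_dtype() is the call referenced above):

```python
import torch

from vllm.platforms import current_platform


def _maybe_view_cache_as_fp8(key_cache: torch.Tensor,
                             value_cache: torch.Tensor,
                             kv_cache_dtype: str):
    """Reinterpret the uint8-allocated caches as the platform's fp8 dtype
    (torch.float8_e4m3fnuz on ROCm, torch.float8_e4m3fn otherwise), so that
    the implicit cast done by tl.store targets the correct format."""
    if kv_cache_dtype.startswith("fp8"):
        fp8_dtype = current_platform.fp8_dtype()
        key_cache = key_cache.view(fp8_dtype)
        value_cache = value_cache.view(fp8_dtype)
    return key_cache, value_cache
```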

@yewentao256 yewentao256 changed the title [V1][Kernel] Add triton implementation for reshape_and_cache [V1][Kernel] Add triton implementation for reshape_and_cache_flash Sep 15, 2025
Contributor

@SageMoore SageMoore left a comment


Looks reasonable @bringlein

Member

@yewentao256 yewentao256 left a comment


Thanks for the work!

I have tested locally using vllm-source/benchmarks/kernels/benchmark_reshape_and_cache_flash.py

Here is what I get on B200:

  num_tokens  layout      latency_cuda (µs)    latency_triton (µs)
------------  --------  -------------------  ---------------------
           2  NHD                     6.902                 56.777
           4  NHD                     6.714                 55.243
           8  NHD                     6.519                 55.221
          16  NHD                     6.273                 54.965
          32  NHD                     6.355                 55.046
          64  NHD                     6.399                 56.348
         128  NHD                     6.546                 55.745
         256  NHD                     6.431                 54.853
         512  NHD                    10.428                 54.352
        1024  NHD                    25.994                 55.091
        2048  NHD                    47.105                 55.002
        4096  NHD                    87.363                101.223
        8192  NHD                   168.325                211.682
       16384  NHD                   329.958                424.305
       32768  NHD                   654.033                843.652
       65536  NHD                  1303.03                1690.64

No HND results since the Triton kernel doesn't support it.

From both a speed and a functional point of view, I don't think it is a good idea to replace the whole CUDA kernel.

@bringlein
Contributor Author

Hi @yewentao256,

thanks for your feedback and engagement!

Regarding your performance numbers: I guess you measured without CUDA graphs? In that case you are measuring the software overhead of Triton's JIT launcher, which is why the latency is more or less the same for 2 and for 2048 tokens. If you measure with CUDA graphs, i.e. without the software overhead, I expect you will see very similar performance between the CUDA and Triton kernels (as shown in the plots in the initial post).
(What puzzles me is why you see a difference at 65k tokens, where the overhead should not be visible even without CUDA graphs... but Triton performance on Blackwell is known to not be very good yet.)

Regarding the layout: the HND layout could be implemented, but I'm unsure whether it is necessary.

In my current understanding, we don't want to replace the CUDA kernel; we only want to add a Triton kernel for the Triton backend in order to have a full Triton path. And in the Triton backend we don't have the HND layout, or am I missing something?

@mgoin
Member

mgoin commented Sep 17, 2025

@bringlein Maybe you could update benchmark_reshape_and_cache_flash.py to compare and use triton.testing.do_bench_cudagraph so we can have a benchmark in vllm?

We do use the HND layout for flashinfer/trtllm attention, so if this were to be used there, that would be a requirement. We are working towards fusing RoPE + reshape_and_cache though, so that should be prioritized over improvements to just reshape_and_cache.
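As a rough sketch of this suggestion (assuming the helper below; the actual benchmark changes may be structured differently), triton.testing.do_bench_cudagraph can replace plain timing so that launch overhead is excluded:

```python
import torch
from triton.testing import do_bench, do_bench_cudagraph

from vllm import _custom_ops as ops


def bench_us(fn, use_cudagraph: bool) -> float:
    """Return the mean latency of fn in microseconds."""
    if use_cudagraph:
        # some Triton versions expect a non-default stream for graph capture
        with torch.cuda.stream(torch.cuda.Stream()):
            ms = do_bench_cudagraph(fn)
    else:
        ms = do_bench(fn)
    return ms * 1000.0


# usage, with tensors set up as in the benchmark script:
# lat = bench_us(lambda: ops.reshape_and_cache_flash(
#     key, value, key_cache, value_cache, slot_mapping,
#     kv_cache_dtype, k_scale, v_scale), use_cudagraph=True)
```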


mergify bot commented Sep 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bringlein.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 22, 2025
…d_cache_pr and fix merge conflict

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@mergify mergify bot added performance Performance-related issues and removed needs-rebase labels Sep 22, 2025
@bringlein
Contributor Author

@mgoin @yewentao256 I updated benchmarks/kernels/benchmark_reshape_and_cache_flash.py and added two new flags: implementation (i.e. cuda or triton) and mode (i.e. with or without CUDA graphs).
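A hedged sketch of how the two flags could be wired in the script (flag and value names follow the description above; the actual implementation may differ):

```python
import argparse

parser = argparse.ArgumentParser(
    description="Benchmark reshape_and_cache_flash (CUDA vs. Triton)")
parser.add_argument("--implementation", choices=["cuda", "triton"],
                    default="cuda",
                    help="which reshape_and_cache_flash kernel to benchmark")
parser.add_argument("--mode", choices=["cudagraph", "no_graph"],
                    default="cudagraph",
                    help="measure under CUDA graphs or with eager launches")
args = parser.parse_args()
```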

On an H100, I get the following results:

python3 ./benchmarks/kernels/benchmark_reshape_and_cache_flash.py --implementation=cuda
Benchmark results for implementation cuda (measuring with cudagraph):
  num_tokens  layout      latency (µs)
------------  --------  --------------
           2  NHD               12.393
           4  NHD               12.789
           8  NHD               14.348
          16  NHD               26.702
          32  NHD               12.422
          64  NHD               12.187
         128  NHD               12.93
         256  NHD               16.437
         512  NHD               37.987
        1024  NHD               59.177
        2048  NHD              104.466
        4096  NHD              195.447
        8192  NHD              378.028
       16384  NHD              741.654
       32768  NHD             1469.05
       65536  NHD             2925.6

and for triton

python3 ./benchmarks/kernels/benchmark_reshape_and_cache_flash.py --implementation=triton
Benchmark results for implementation triton (measuring with cudagraph):
  num_tokens  layout      latency (µs)
------------  --------  --------------
           2  NHD                8.763
           4  NHD                9.735
           8  NHD               16.91 
          16  NHD               21.476
          32  NHD                9.13
          64  NHD               10.125
         128  NHD               11.34 
         256  NHD               16.787
         512  NHD               35.567
        1024  NHD               56.811
        2048  NHD              103.081
        4096  NHD              199.439
        8192  NHD              395.063
       16384  NHD              766.764
       32768  NHD             1526.15 
       65536  NHD             3050.05 

...which is what I would expect, and which is also what the plots in the initial post show.

@mgoin
Member

mgoin commented Sep 22, 2025

Please fix the pre-commit first

Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:89:13: F821 Undefined name `key`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:90:13: F821 Undefined name `value`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:91:13: F821 Undefined name `key_cache`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:92:13: F821 Undefined name `value_cache`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:93:13: F821 Undefined name `slot_mapping`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:100:13: F821 Undefined name `key`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:101:13: F821 Undefined name `value`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:102:13: F821 Undefined name `key_cache`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:103:13: F821 Undefined name `value_cache`
Error: benchmarks/kernels/benchmark_reshape_and_cache_flash.py:104:13: F821 Undefined name `slot_mapping`

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein
Contributor Author

@mgoin yes, I've seen that, but I couldn't really explain where it comes from. All the variables are defined (obviously) and I didn't touch the lines where they are defined. So I muted these false positives for now, but I don't know if that is the best approach.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 23, 2025
Member

@mgoin mgoin left a comment


LGTM, thanks!

Member

@yewentao256 yewentao256 left a comment


Thanks for the work!

@mgoin mgoin merged commit 100b630 into vllm-project:main Sep 23, 2025
53 checks passed
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…llm-project#24503)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
yewentao256 added a commit that referenced this pull request Oct 3, 2025
…24503)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Labels
performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed v1