
Conversation

@bringlein (Contributor) commented Sep 11, 2025

Purpose

This PR splits the triton_attn backend of V1 into two backends: a Triton-only, platform-independent triton_attn and a ROCm-specific rocm_attn that includes the aiter kernels.
This makes both backends easier to maintain. Changes to the vLLM-internal Triton backend also no longer need to ensure full compatibility with the external aiter library; for example, the standardization of kv-cache layouts in #21624 was blocked by this (and is solved by this PR as well).
Backend selection is updated to pick the new rocm_attn whenever any of the aiter-specific variables are set (see the sketch below).
Both backends still use the same Triton kernels, but the logic that selects them is now separated.
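
For illustration only, a minimal Python sketch of the selection behaviour described above. The helper name select_attn_backend and the exact environment-variable names are assumptions for the example; the real selection logic lives in the platform code and may differ:

    import os

    def select_attn_backend(is_rocm: bool) -> str:
        """Sketch: route to the ROCm-specific backend when aiter is requested."""
        # Any aiter-specific toggle being set (variable names assumed here)
        # implies the rocm_attn backend on ROCm platforms.
        aiter_requested = any(
            os.environ.get(var, "0") == "1"
            for var in ("VLLM_ROCM_USE_AITER", "VLLM_ROCM_USE_AITER_MHA")
        )
        if is_rocm and aiter_requested:
            return "ROCM_ATTN_VLLM_V1"    # rocm_attn (may use aiter kernels)
        return "TRITON_ATTN_VLLM_V1"      # triton_attn (platform-independent)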

CC: @tdoublep @SageMoore @LucasWilkinson @jvlunteren

Test Plan

Correctness tests (lm_eval on GSM8K) and a manual comparison of the file diff between the old backend and the new, split backends.

Test Result

Correctness

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

on main, H100:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

with this PR on H100:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

on main, MI300:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.796|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.776|±  |0.0187|

with this PR on MI300:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.778|±  |0.0186|

To test the "new" ROCm backend on MI300:

VLLM_ATTENTION_BACKEND=ROCM_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
....
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.778|±  |0.0186|

git diff

Initially, I renamed triton_attn.py to rocm_attn.py with git mv .... However, GitHub now shows rocm_attn.py as a completely new file, even though it differs from triton_attn.py on main only by renamings.

So, to facilitate the PR review, I created the diff manually:

git show origin/main:vllm/v1/attention/backends/triton_attn.py > tmp.py

diff tmp.py vllm/v1/attention/backends/rocm_attn.py 
9a10
> from vllm import _custom_ops as ops                       
27,31d27
< if current_platform.is_cuda_alike():                        
<     from vllm import _custom_ops as ops
< elif current_platform.is_xpu():                           
<     from vllm._ipex_ops import ipex_ops as ops
<                                                                                                                                           
36c32
< class TritonAttentionMetadata:                                                                                                            
---    
> class RocmAttentionMetadata:                                                                                                              
65,66c61,62
< class TritonAttentionMetadataBuilder(                                                                                                     
<         AttentionMetadataBuilder[TritonAttentionMetadata]):
---                                          
> class RocmAttentionMetadataBuilder(
>         AttentionMetadataBuilder[RocmAttentionMetadata]):
84c80                                                                                                                                       
<     ) -> TritonAttentionMetadata:              
---
>     ) -> RocmAttentionMetadata:
95c91
<               fast_build: bool = False) -> TritonAttentionMetadata:
---
>               fast_build: bool = False) -> RocmAttentionMetadata:
123c119
<         attn_metadata = TritonAttentionMetadata(
---
>         attn_metadata = RocmAttentionMetadata(
141c137
< class TritonAttentionBackend(AttentionBackend):
---
> class RocmAttentionBackend(AttentionBackend):
166c162
<         return "TRITON_ATTN_VLLM_V1"
---
>         return "ROCM_ATTN_VLLM_V1"
169,170c165,166
<     def get_impl_cls() -> type["TritonAttentionImpl"]:
<         return TritonAttentionImpl
---
>     def get_impl_cls() -> type["RocmAttentionImpl"]:
>         return RocmAttentionImpl
174c170
<         return TritonAttentionMetadata
---
>         return RocmAttentionMetadata
192,193c188,189
<     def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: 
<         return TritonAttentionMetadataBuilder
---
>     def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
>         return RocmAttentionMetadataBuilder
205c201
< class TritonAttentionImpl(AttentionImpl):
---
> class RocmAttentionImpl(AttentionImpl):
244c240
<         TritonAttentionBackend.validate_head_size(head_size)
---
>         RocmAttentionBackend.validate_head_size(head_size)
250c246
<                                       "TritonAttentionImpl")
---
>                                       "RocmAttentionImpl")
261c257
<                     "Using aiter unified attention for TritonAttentionImpl")
---
>                     "Using aiter unified attention for RocmAttentionImpl")
267c263
<                     "Using vllm unified attention for TritonAttentionImpl")
---
>                     "Using vllm unified attention for RocmAttentionImpl")
308c304
<                 " for TritonAttentionImpl")
---
>                 " for RocmAttentionImpl")



Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@mergify bot added the rocm (Related to AMD ROCm) and v1 labels Sep 11, 2025
@gemini-code-assist (bot) left a comment:

Code Review

This PR splits the triton_attn backend into triton_attn and rocm_attn to facilitate easier maintenance and platform-specific adaptations. The review focuses on identifying potential issues related to correctness and maintainability, particularly in the new rocm_attn.py file and modifications to platforms/rocm.py.

"RocmAttentionImpl")

self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = \
A Collaborator commented:

Maybe for simplicity we can remove this variable altogether, so that we don't have two identical unified-attention code paths. The logic (for ROCm) could be just:

    if backend == ROCM_ATTN_VLLM_V1:
        use rocm_attn with split attention (or aiter)
    else:
        use triton_attn with unified attention
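
For illustration, a minimal Python sketch of this proposed simplification (hypothetical helper and return values, not the actual vLLM selection code): only the ROCm backend would take the split-attention/aiter path, so the force_prefill_decode_attn toggle could go away:

    def attention_code_path(backend_name: str, use_aiter: bool) -> str:
        # Proposed rule: the backend name alone decides the code path.
        if backend_name == "ROCM_ATTN_VLLM_V1":
            # ROCm backend: split prefill/decode attention, optionally via aiter.
            return "aiter split attention" if use_aiter else "triton split attention"
        # Default backend: single unified-attention Triton kernel.
        return "triton unified attention"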

@bringlein (Contributor, Author) replied:

Could be a good idea, but maybe as a follow-up PR? I think the focus of this PR should be to make such optimizations easy to add later.

@tdoublep (Member) left a comment:

LGTM

@tdoublep tdoublep enabled auto-merge (squash) September 22, 2025 13:22
@github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 22, 2025
@tdoublep tdoublep merged commit 175811e into vllm-project:main Sep 22, 2025
55 checks passed
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…ckends (vllm-project#24648)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…ckends (vllm-project#24648)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: charlifu <charlifu@amd.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…ckends (#24648)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>