From 91e4a581ba8cef55ce7abcbf9f1f08f99b8077c3 Mon Sep 17 00:00:00 2001 From: apinge Date: Mon, 10 Nov 2025 10:50:16 +0800 Subject: [PATCH 1/4] add support for whisper v1 using aiter unified attention and aiter flash attention Signed-off-by: apinge --- vllm/v1/attention/backends/rocm_aiter_fa.py | 20 +++++++++++-------- .../backends/rocm_aiter_unified_attn.py | 11 ++++++++-- vllm/v1/attention/backends/rocm_attn.py | 7 ++----- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index c7f925817a6a..941676ae8687 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -513,12 +513,9 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if attn_type != AttentionType.DECODER: + if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl" + "Encoder self-attention is not implemented for FlashAttentionImpl" ) def extend_forward( @@ -674,7 +671,14 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping @@ -700,8 +704,8 @@ def forward( # decode:extend:prefill query = query[:num_actual_tokens] - key = key[:num_actual_tokens] - value = value[:num_actual_tokens] + key = key[:num_actual_tokens] if key is not None else key_cache[:num_actual_tokens] + value = value[:num_actual_tokens] if value is not None else value_cache[:num_actual_tokens] output_actual_tokens = output[:num_actual_tokens] diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index b2639c0df041..df63880c43a3 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -142,7 +142,14 @@ def forward( key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. ops.reshape_and_cache_flash( @@ -169,7 +176,7 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1] if key is not None else self.num_kv_heads) self.unified_attention( q=query[:num_actual_tokens], diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 6dfdfc19ccba..868143cc192e 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -238,12 +238,9 @@ def __init__( RocmAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: + if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "RocmAttentionImpl" + "Encoder self-attention is not implemented for RocmAttentionImpl" ) self.fp8_dtype = current_platform.fp8_dtype() From 723bf67b3afd40ef7844781ce58c840412f420aa Mon Sep 17 00:00:00 2001 From: apinge Date: Mon, 10 Nov 2025 11:49:34 +0800 Subject: [PATCH 2/4] update key and value for the None condition Signed-off-by: apinge --- vllm/v1/attention/backends/rocm_aiter_fa.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 941676ae8687..8376c36f00e4 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -704,8 +704,11 @@ def forward( # decode:extend:prefill query = query[:num_actual_tokens] - key = key[:num_actual_tokens] if key is not None else key_cache[:num_actual_tokens] - value = value[:num_actual_tokens] if value is not None else value_cache[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + output_actual_tokens = output[:num_actual_tokens] From 1c69959b732fc908c08410920ca0913c4a7a6406 Mon Sep 17 00:00:00 2001 From: apinge Date: Mon, 10 Nov 2025 12:00:55 +0800 Subject: [PATCH 3/4] update format for rocm_aiter_unified_attn.py Signed-off-by: apinge --- vllm/v1/attention/backends/rocm_aiter_unified_attn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index df63880c43a3..16fb52ab501c 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -176,7 +176,10 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1] if key is not None else self.num_kv_heads) + descale_shape = ( + cu_seqlens_q.shape[0] - 1, + key.shape[1] if key is not None else self.num_kv_heads, + ) self.unified_attention( q=query[:num_actual_tokens], From 110481d0a91c6058c26ad0e7fd2243fd4f8c267c Mon Sep 17 00:00:00 2001 From: apinge Date: Mon, 10 Nov 2025 12:09:01 +0800 Subject: [PATCH 4/4] update format Signed-off-by: apinge --- vllm/v1/attention/backends/rocm_aiter_fa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 8376c36f00e4..a45b4a4707fd 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -709,7 +709,6 @@ def forward( if value is not None: value = value[:num_actual_tokens] - output_actual_tokens = output[:num_actual_tokens] num_decodes = attn_metadata.num_decodes