diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index c8742e983520..ea911af3d19c 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -517,12 +517,9 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if attn_type != AttentionType.DECODER: + if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl" + "Encoder self-attention is not implemented for FlashAttentionImpl" ) def extend_forward( @@ -678,7 +675,14 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping @@ -704,8 +708,10 @@ def forward( # decode:extend:prefill query = query[:num_actual_tokens] - key = key[:num_actual_tokens] - value = value[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] output_actual_tokens = output[:num_actual_tokens] diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index b2639c0df041..16fb52ab501c 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -142,7 +142,14 @@ def forward( key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. ops.reshape_and_cache_flash( @@ -169,7 +176,10 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = ( + cu_seqlens_q.shape[0] - 1, + key.shape[1] if key is not None else self.num_kv_heads, + ) self.unified_attention( q=query[:num_actual_tokens], diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 6dfdfc19ccba..868143cc192e 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -238,12 +238,9 @@ def __init__( RocmAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: + if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "RocmAttentionImpl" + "Encoder self-attention is not implemented for RocmAttentionImpl" ) self.fp8_dtype = current_platform.fp8_dtype()