tests/kernels/test_encoder_decoder_attn.py (37 changes: 26 additions & 11 deletions)
@@ -18,8 +18,10 @@
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
+from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config

# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
+vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run encoder attention.
@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
-with set_forward_context(attn_metadata):
+with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
+vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run decoder self-attention test.
@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
-with set_forward_context(attn_metadata):
+with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
+vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
-with set_forward_context(attn_metadata):
+with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@@ -839,7 +844,9 @@ def test_encoder_only(

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
-test_rsrcs = _make_test_resources(test_pt)
+vllm_config = VllmConfig()
+with set_current_vllm_config(vllm_config):
+    test_rsrcs = _make_test_resources(test_pt)

# Construct encoder attention test params (only used
# during prefill)
@@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
-test_pt=test_pt))
+test_pt=test_pt,
+vllm_config=vllm_config))

# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
-test_rsrcs = _make_test_resources(test_pt)
+vllm_config = VllmConfig()
+with set_current_vllm_config(vllm_config):
+    test_rsrcs = _make_test_resources(test_pt)

# Construct encoder attention test params (only used
# during prefill)
@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
-test_pt=test_pt)
+test_pt=test_pt,
+vllm_config=vllm_config)

# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
-test_pt=test_pt)
+test_pt=test_pt,
+vllm_config=vllm_config))

# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
-test_pt=test_pt)
+test_pt=test_pt,
+vllm_config=vllm_config))

# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
-test_pt=test_pt)
+test_pt=test_pt,
+vllm_config=vllm_config))

# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params,
None,
decphase_attn_metadata,
-test_pt=test_pt)
+test_pt=test_pt,
+vllm_config=vllm_config))

# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
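The test-side changes above all follow one pattern: build a `VllmConfig`, construct the attention resources while it is the current config, and pass it through to `set_forward_context` alongside the attention metadata. A minimal sketch of that pattern follows; `_make_test_resources`, `test_pt`, and `attn_metadata` refer to helpers and fixtures defined in this test file, so this is an illustrative outline rather than a standalone test.

```python
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig()

# Attention backend/wrapper and KV cache are now constructed while this
# config is the "current" one, so layer construction can read it.
with set_current_vllm_config(vllm_config):
    test_rsrcs = _make_test_resources(test_pt)  # helper from this test module

# Each per-phase helper then opens the forward context with the config as
# well as the attention metadata before invoking the backend.
with set_forward_context(attn_metadata, vllm_config):
    ...  # attn.forward(...), unchanged from before
```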
vllm/attention/backends/abstract.py (23 changes: 14 additions & 9 deletions)
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
-from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)

@@ -15,13 +14,19 @@
ModelRunnerInputBuilderBase)


-class AttentionType(Enum):
-    DECODER = auto()  # Decoder attention between previous layer Q/K/V
-    ENCODER = auto(
-    )  # Encoder attention between previous layer Q/K/V for encoder-decoder
-    ENCODER_ONLY = auto()  # Encoder attention between previous layer Q/K/V
-    ENCODER_DECODER = auto(
-    )  # Attention between dec. Q and enc. K/V for encoder-decoder
+class AttentionType:
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V for encoder-decoder
+    ENCODER = "encoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+    # Attention between dec. Q and enc. K/V for encoder-decoder
+    ENCODER_DECODER = "encoder_decoder"


class AttentionBackend(ABC):
@@ -241,6 +246,6 @@ def forward(
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
-attn_type: AttentionType = AttentionType.DECODER,
+attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
raise NotImplementedError
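The new class docstring gives the motivation: the constants become plain strings so code that branches on `attn_type` stays `torch.compile`-friendly, while call sites keep the same attribute-style spelling (`AttentionType.DECODER`). A self-contained toy sketch of that usage pattern is below; it is not vLLM code, and `toy_dispatch` is a hypothetical stand-in for a backend `forward()`.

```python
import torch


class AttentionType:
    """Illustrative copy of the new string-constant class."""
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"


def toy_dispatch(q: torch.Tensor,
                 attn_type: str = AttentionType.DECODER) -> torch.Tensor:
    # The branch is an ordinary string comparison, which torch.compile can
    # guard on and fold to a constant when tracing.
    if attn_type == AttentionType.ENCODER_DECODER:
        return q + 1.0
    return q * 2.0


compiled = torch.compile(toy_dispatch)
print(compiled(torch.ones(2, 2), AttentionType.DECODER))
```

Callers are otherwise unaffected; only type annotations change from `AttentionType` to `str`, as in the `forward` signatures above and below.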
vllm/attention/backends/blocksparse_attn.py (2 changes: 1 addition & 1 deletion)
@@ -354,7 +354,7 @@ def forward(
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
-attn_type: AttentionType = AttentionType.DECODER,
+attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
