diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py index d8d61667f688..c8d0d91ce7b5 100644 --- a/examples/offline_inference/qwen_1m.py +++ b/examples/offline_inference/qwen_1m.py @@ -5,7 +5,6 @@ from vllm import LLM, SamplingParams -os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index b6bebbba915b..c3f1c7481d1b 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -334,8 +334,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): [7, 256, 533] if current_platform.is_cuda() else [8]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("model_name, model_class", MODELS) -@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if - current_platform.is_cuda() else [_Backend.ROCM_FLASH]) +@pytest.mark.parametrize("backend", + [_Backend.FLASHINFER] if current_platform.is_cuda() + else [_Backend.TRITON_ATTN_VLLM_V1]) @pytest.mark.parametrize( "split_attention", [False, True] if current_platform.is_rocm() else [False]) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 7083661575ef..c7abf652f111 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -18,7 +18,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - from vllm.attention.backends.xformers import _make_alibi_bias + from tests.kernels.utils import make_alibi_bias FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. @@ -429,8 +429,8 @@ def test_multi_query_kv_attention( alibi_bias = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, - seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, + seq_lens) output = torch.empty_like(query) start = 0 # Dynamic sequence length not supported with custom attn_bias. diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index f8454ad0a4c4..38ab40f88ae0 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -67,6 +67,7 @@ def generate_params(): return params +@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) def test_env( diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 8544eab3accc..0695f84aea1a 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -11,7 +11,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -from vllm.attention.backends.xformers import _make_alibi_bias +from tests.kernels.utils import make_alibi_bias from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.prefix_prefill import context_attention_fwd @@ -470,7 +470,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: key = key.unsqueeze(0) value = value.unsqueeze(0) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) seq_start = 0 query_start = 0 @@ -479,7 +479,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # FIXME(DefTruth): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 + # modified from: vllm/v1/attention/backends/xformers.py#L343 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index d56d3f4638f1..af301d9de435 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -16,6 +16,7 @@ def clear_cache(): _cached_get_attn_backend.cache_clear() +@pytest.mark.skip(reason="Skipped for now. Should be revisited.") def test_selector(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c9bf85f6e2a5..8d6ce381976b 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend: Construct the backend instance determined by the backend_name string argument. - "XFORMERS" -> construct xformers backend - - TODO: other backends - Note: at time of writing the Attention wrapper automatically selects its own backend for Attention.forward(); so the backend instance which you generate with this function is not meant to be used for *running* @@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' - if backend_name == STR_XFORMERS_ATTN_VAL: - # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. - from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() - elif backend_name == STR_FLASH_ATTN_VAL: - from vllm.attention.backends.flash_attn import FlashAttentionBackend + if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"): + from vllm.v1.attention.backends.xformers import ( + XFormersAttentionBackend) + return XFormersAttentionBackend() + if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"): + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend return FlashAttentionBackend() + if backend_name == "TRITON_ATTN_VLLM_V1": + from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionBackend) + return TritonAttentionBackend() + if backend_name == "FLEX_ATTENTION": + from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionBackend) + return FlexAttentionBackend() + if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"): + from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend + return TorchSDPABackend() + if backend_name == "FLASHINFER": + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") +def make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: list[int], +) -> list[Any]: + """Create ALiBi biases compatible with xFormers attention tests.""" + from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias + + if alibi_slopes is None: + return [None for _ in seq_lens] + + attn_biases: list[Any] = [] + num_heads = alibi_slopes.shape[0] + assert num_heads >= num_kv_heads, ( + "ALiBi slopes expect at least as many heads as KV heads") + + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + bias_tensor = torch.empty( + 1, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias_tensor.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor)) + + return attn_biases + + def _make_metadata_tensors( seq_lens: Optional[list[int]], context_lens: Optional[list[int]], diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index b9601114a318..bfde6e20a3b1 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -78,9 +78,8 @@ def _initialize_kv_caches_v1(self, vllm_config): return if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): - # Phi4FlashForCausalLM and MotifForCausalLM - # only supports DIFFERENTIAL_FLASH_ATTN backend - m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") + pytest.skip( + "Differential Flash Attention backend has been removed.") if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py deleted file mode 100644 index 87a4558e377d..000000000000 --- a/vllm/attention/backends/differential_flash_attn.py +++ /dev/null @@ -1,931 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""" An implementation of https://arxiv.org/pdf/2410.05258 """ -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -from einops import rearrange - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.flash_attn import FlashAttentionBackend -# yapf: enable -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, - is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -logger = init_logger(__name__) - - -class DifferentialFlashAttentionBackend(AttentionBackend): - accept_output_buffer = False - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" - return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) - - @staticmethod - def get_name() -> str: - return "DIFFERENTIAL_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: - return DifferentialFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: - return DifferentialFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: - return DifferentialFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class DifferentialFlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - # Cross-layer shared attention block tables - cross_layer_shared_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[self.num_prefills:]) - self._cached_decode_metadata = DifferentialFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class DifferentialFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.cross_layer_shared_block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - # TODO: add support for chunked prefill and prefix caching. - assert not chunked_prefill_enabled, \ - "chunked prefill is not supported for now" - assert not prefix_cache_hit, "prefix caching is not supported for now" - - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - cross_layer_shared_block_table = [] - if prefix_cache_hit: - cross_layer_shared_block_table = block_tables[seq_id] - elif block_tables is not None: - if curr_sliding_window_block == 0: - cross_layer_shared_block_table = block_tables[seq_id] - else: - cross_layer_shared_block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.cross_layer_shared_block_tables.append( - cross_layer_shared_block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables(self, num_seqs: int, - block_tables: List[List[int]], - graph_block_tables) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - # max_batch_size, max_blocks = self.runner.graph_block_tables.shape - max_batch_size, max_blocks = graph_block_tables.shape - assert max_batch_size >= num_seqs - - # graph_block_tables = self.runner.graph_block_tables[:num_seqs] - graph_block_tables = graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - - self.cross_layer_shared_block_tables.extend([] * - cuda_graph_pad_size) - - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables, self.runner.graph_block_tables) - cross_layer_shared_block_tables = \ - self._get_graph_runner_block_tables( - num_seqs, self.cross_layer_shared_block_tables, - self.runner.cross_layer_shared_graph_block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - cross_layer_shared_block_tables = make_tensor_with_pad( - self.cross_layer_shared_block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class DifferentialFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - differential_flash_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if differential_flash_attention_config is None: - differential_flash_attention_config = {} - self.differential_flash_attention_config = \ - differential_flash_attention_config - self.used_shared_kv_cache = kv_sharing_target_layer_name is not None - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - self.lambda_full = None - self.subln = self.differential_flash_attention_config["subln"] - - def split_heads(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. - x = rearrange(x, "... (H two) D -> ... H two D", two=2) - x1 = x[..., 0, :] - x2 = x[..., 1, :] - return x1.contiguous(), x2.contiguous() - - def split_kv_cache(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. - if x.numel() == 0: - return torch.empty(0), torch.empty(0) - - x1, x2 = x[0], x[1] - return x1, x2 - - def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, - value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata): - if kv_cache.numel() > 0 and key is not None and value is not None: - updated_slot_mapping = attn_metadata.slot_mapping - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - def forward_generate_kv_cache( - self, query: torch.Tensor, key: Optional[torch.Tensor], - value: Optional[torch.Tensor], k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: - - head_size = self.head_size - num_heads = self.num_heads // 2 - num_kv_heads = self.num_kv_heads // 2 - - query = query.view(-1, num_heads, head_size) - if key is not None: - assert value is not None - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - else: - assert value is None - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" - assert value.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens, "query shape mismatch" - assert decode_query.shape[ - 0] == num_decode_tokens, "decode query shape mismatch" - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if k_cache.numel() == 0 \ - or prefill_meta.block_tables is None \ - or prefill_meta.block_tables.numel() == 0: - # normal attention - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ) - assert prefill_output.shape == output[: - num_prefill_tokens].shape - output[:num_prefill_tokens] = prefill_output - else: - raise Exception("prefix caching not supported") - - if decode_meta := attn_metadata.decode_metadata: - block_tables_arg = decode_meta.block_tables - try: - output[num_prefill_tokens:] = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ).squeeze(1) - except Exception as e: - logger.error("Error in PagedAttention.forward_decode: %s", - str(e)) - raise e - - # Reshape the output tensor. - return output.view(-1, num_heads, head_size) - - def forward_with_kv_cache_only( - self, - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - ): - if not attn_metadata.decode_metadata: - block_tables_arg = attn_metadata.cross_layer_shared_block_tables - else: - block_tables_arg = attn_metadata.block_tables - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=attn_metadata.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ).squeeze(1) - return output - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - layer: Attention layer instance. - q: Query tensor with shape = [num_tokens, num_heads, head_size] - k: Key tensor with shape = [num_tokens, num_kv_heads, head_size] - v: Value tensor with shape = [num_tokens, num_kv_heads, head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size, num_kv_heads, head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Output tensor with shape [num_tokens, num_heads, head_size] - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for DifferentialFlashAttentionImpl") - - if self.lambda_full is None: - self.lambda_init = self.differential_flash_attention_config[ - "lambda_init"] - lambda_q1 = self.differential_flash_attention_config["lambda_q1"] - lambda_k1 = self.differential_flash_attention_config["lambda_k1"] - lambda_q2 = self.differential_flash_attention_config["lambda_q2"] - lambda_k2 = self.differential_flash_attention_config["lambda_k2"] - lambda_1 = torch.exp( - torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp( - torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) - self.lambda_full = lambda_1 - lambda_2 + self.lambda_init - - if not self.used_shared_kv_cache: # need to generate kv-cache - q = q.view(-1, self.num_heads, self.head_size) - k = k.view(-1, self.num_kv_heads, self.head_size) - v = v.view(-1, self.num_kv_heads, self.head_size) - - q1, q2 = self.split_heads(q) - k1, k2 = self.split_heads(k) - v1, v2 = self.split_heads(v) - - # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 - # Split by half along the first dimension. - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" - assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - - if kv_cache1.numel() != 0: - self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) - self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) - - key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) - key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) - else: - key_cache1, value_cache1 = torch.empty(0), torch.empty(0) - key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - - else: # reuse the kv cache, full attention - q = q.view(-1, self.num_heads, self.head_size) - q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] - key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - - attn11 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - attn_output = attn_output.view(-1, self.num_heads * self.head_size) - return attn_output diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py deleted file mode 100644 index de47bb8ebd8f..000000000000 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ /dev/null @@ -1,1495 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with Dual chunk flash attention and sparse attention. -""" -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -import torch.distributed -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionImpl, - FlashAttentionMetadata, - FlashAttentionMetadataBuilder) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.logger import init_logger -from vllm.utils import async_tensor_h2d -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache, sparse_attn_func) - -logger = init_logger(__name__) - - -class DualChunkFlashAttentionBackend(FlashAttentionBackend): - - accept_output_buffer: bool = False - - @staticmethod - def get_name() -> str: - return "DUAL_CHUNK_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: - return DualChunkFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: - return DualChunkFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: - return DualChunkFlashAttentionMetadataBuilder - - -@dataclass -class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): - # Block size of the paged kv cache. - block_size: int = 16 - - # Original max position embeddings. - original_max_position_embeddings: int = 0 - - # Chunk size - chunk_size: int = 8192 - - # Local size - local_size: int = 1024 - - # (batch_size,). The orig sequence length per sequence. - orig_seq_lens: Optional[List[int]] = None - - # orig_seq_lens stored as a tensor. - orig_seq_lens_tensor: Optional[torch.Tensor] = None - - # Length scaling factor - scaling_factor: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for intra attention. - seq_lens_intra: Optional[torch.Tensor] = None - - # Max sequence length for intra attention. - max_seq_len_intra: Optional[int] = None - - # (batch_size, num_blocks). Block table for intra attention. - block_tables_intra: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for succ attention. - seq_lens_succ: Optional[torch.Tensor] = None - - # Max sequence length for succ attention. - max_seq_len_succ: Optional[int] = None - - # (batch_size, num_blocks). Block table for succ attention. - block_tables_succ: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for inter attention. - seq_lens_inter: Optional[torch.Tensor] = None - - # Max sequence length for inter attention. - max_seq_len_inter: Optional[int] = None - - _cached_prefill_metadata: Optional[ - "DualChunkFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - prefill_metadata = super().prefill_metadata - if prefill_metadata is None: - return None - - prefill_metadata = DualChunkFlashAttentionMetadata( - **prefill_metadata.asdict_zerocopy()) - - prefill_metadata.orig_seq_lens = ( - None if self.orig_seq_lens is None else - self.orig_seq_lens[:self.num_prefills]) - prefill_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[:self.num_prefills]) - - if self.original_max_position_embeddings > 0: - assert prefill_metadata.orig_seq_lens_tensor is not None - prefill_metadata.scaling_factor = ( - 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / - self.original_max_position_embeddings) + - 1.0).clip(min=1) - - self._cached_prefill_metadata = prefill_metadata - return prefill_metadata - - @property - def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - - decode_metadata = super().decode_metadata - if decode_metadata is None: - return None - - decode_metadata = DualChunkFlashAttentionMetadata( - **decode_metadata.asdict_zerocopy()) - - decode_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[self.num_prefills:]) - - assert decode_metadata.orig_seq_lens_tensor is not None - assert decode_metadata.block_tables is not None - - cache_seq_lens = decode_metadata.orig_seq_lens_tensor - chunk_len = self.chunk_size - self.local_size - chunk_num_curr = (cache_seq_lens - 1) // chunk_len - batch_size = decode_metadata.num_decode_tokens - - if self.original_max_position_embeddings > 0: - decode_metadata.scaling_factor = (0.1 * torch.log( - cache_seq_lens / self.original_max_position_embeddings) + - 1.0).clip(min=1) - - seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len - max_seq_len_intra = seq_lens_intra.max().item() - decode_metadata.seq_lens_intra = seq_lens_intra - decode_metadata.max_seq_len_intra = max_seq_len_intra - - block_tables_intra = torch.zeros( - batch_size, - (max_seq_len_intra - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - st = chunk_num_curr[i] * chunk_len // self.block_size - ed = min( - st + (max_seq_len_intra - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_intra[i, :ed - - st] = decode_metadata.block_tables[i, st:ed] - decode_metadata.block_tables_intra = block_tables_intra - - seq_lens_succ = (chunk_num_curr - - (chunk_num_curr - 1).clip(min=0)) * chunk_len - max_seq_len_succ = seq_lens_succ.max().item() - decode_metadata.seq_lens_succ = seq_lens_succ - decode_metadata.max_seq_len_succ = max_seq_len_succ - if max_seq_len_succ: - block_tables_succ = torch.zeros( - batch_size, - (max_seq_len_succ - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len // - self.block_size) - end = min( - start + (max_seq_len_succ - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_succ[ - i, :end - start] = decode_metadata.block_tables[i, - start:end] - decode_metadata.block_tables_succ = block_tables_succ - - seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len - max_seq_len_inter = seq_lens_inter.max().item() - decode_metadata.seq_lens_inter = seq_lens_inter - decode_metadata.max_seq_len_inter = max_seq_len_inter - - self._cached_decode_metadata = decode_metadata - return decode_metadata - - -class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): - - def prepare(self): - super().prepare() - self.orig_seq_lens: List[int] = [] - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - super()._add_seq_group(inter_data, chunked_prefill_enabled, - prefix_cache_hit) - for prompt_len, seq_len in zip(inter_data.prompt_lens, - inter_data.seq_lens): - self.orig_seq_lens.append(max(prompt_len, seq_len)) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - attn_metadata = super().build(seq_lens, query_lens, - cuda_graph_pad_size, batch_size) - attn_metadata = DualChunkFlashAttentionMetadata( - **attn_metadata.asdict_zerocopy()) - - device = self.runner.device - attn_metadata.orig_seq_lens = self.orig_seq_lens - attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( - self.orig_seq_lens, torch.int, device, self.runner.pin_memory) - - attn_metadata.block_size = self.runner.block_size - dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, - "dual_chunk_attention_config", {}) - attn_metadata.original_max_position_embeddings = \ - dual_chunk_attn_config.get("original_max_position_embeddings", 0) - attn_metadata.chunk_size = dual_chunk_attn_config.get( - "chunk_size", 8192) - attn_metadata.local_size = dual_chunk_attn_config.get( - "local_size", 1024) - - return attn_metadata - - -class DualChunkFlashAttentionImpl(FlashAttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - The prompts might have different lengths, while the generation tokens - always have length 1. - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - layer_idx: int = -1, - dual_chunk_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "DUAL_CHUNK_FLASH_ATTN backend.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - - support_head_sizes = ( - DualChunkFlashAttentionBackend.get_supported_head_sizes()) - - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - - assert dual_chunk_attention_config is not None - self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) - self.local_size = dual_chunk_attention_config.get("local_size", 1024) - self.original_max_position_embeddings = dual_chunk_attention_config.get( - "original_max_position_embeddings", 0) - self.sparse_attention_config = dual_chunk_attention_config.get( - "sparse_attention_config", None) - if not self.sparse_attention_config: - logger.warning_once("Sparse attention will not be enabled as " - "sparse attention config is not provided.") - self.sparse_attention_enabled = dual_chunk_attention_config.get( - "sparse_attention_enabled", self.sparse_attention_config - is not None) - self.sparse_attention_threshold = dual_chunk_attention_config.get( - "sparse_attention_threshold", 32768) - self.sparse_attention_last_q = dual_chunk_attention_config.get( - "sparse_attention_last_q", 64) - self.layer_idx = layer_idx - self.dual_chunk_attention_config = dual_chunk_attention_config - - if self.sparse_attention_config: - self.sparse_attention_config = { - int(i): j - for i, j in self.sparse_attention_config[ - self.layer_idx].items() - } - start_head = self.num_heads * get_tensor_model_parallel_rank() - end_head = start_head + self.num_heads - self.sparse_attention_config = [ - self.sparse_attention_config[i] - for i in range(start_head, end_head) - ] - - if self.sparse_attention_enabled: - self.arange = torch.arange(self.sparse_attention_last_q, - device="cuda") - self.last_q_mask = (self.arange[None, None, :, None] - >= self.arange[None, None, None, :]) - - def forward( # type: ignore - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DualChunkFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with DualChunkFlashAttention. - Args: - query: shape = [num_tokens, num_heads * head_size] - query_succ: shape = [num_tokens, num_heads * head_size] - query_inter: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is None, "Output tensor not supported for DualChunk" - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - ( - query, - query_succ, - query_inter, - query_succ_critical, - query_inter_critical, - ) = torch.split(query, query.shape[-1] // 5, dim=-1) - - assert ( - query_succ is not None and query_inter is not None - ), "query_succ and query_inter are required in Dual Chunk Attention." - - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - query_succ = query_succ.view(-1, self.num_heads, self.head_size) - query_inter = query_inter.view(-1, self.num_heads, self.head_size) - query_succ_critical = query_succ_critical.view(-1, self.num_heads, - self.head_size) - query_inter_critical = query_inter_critical.view( - -1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.original_max_position_embeddings > 0: - if prefill_meta := attn_metadata.prefill_metadata: - assert prefill_meta.scaling_factor is not None - assert prefill_meta.query_start_loc is not None - assert prefill_meta.orig_seq_lens is not None - current_start = 0 - query_start_loc_cpu = prefill_meta.query_start_loc.cpu() - for i in range(len(prefill_meta.orig_seq_lens)): - current_end = (current_start + - (query_start_loc_cpu[i + 1] - - query_start_loc_cpu[i]).item()) - key[current_start:current_end].mul_( - prefill_meta.scaling_factor[i]) - current_start = current_end - assert current_end <= attn_metadata.num_prefill_tokens - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - key[attn_metadata.num_prefill_tokens:].mul_( - scaling_factor.unsqueeze(-1).unsqueeze(-1)) - - if kv_cache is not None and kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - output = torch.empty_like(query) - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - decode_query_succ = query_succ[num_prefill_tokens:] - decode_query_inter = query_inter[num_prefill_tokens:] - - # QKV for prefill. - query = query[:num_prefill_tokens] - query_succ = query_succ[:num_prefill_tokens] - query_inter = query_inter[:num_prefill_tokens] - query_succ_critical = query_succ_critical[:num_prefill_tokens] - query_inter_critical = query_inter_critical[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention, called during the profiling run. - out = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - assert prefill_meta.orig_seq_lens is not None - output[:num_prefill_tokens] = ( - self._dual_chunk_flash_attn_prefill( - q=query, - q_succ=query_succ, - q_inter=query_inter, - q_succ_critical=query_succ_critical, - q_inter_critical=query_inter_critical, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - orig_seq_lens=prefill_meta.orig_seq_lens, - scaling_factor=prefill_meta.scaling_factor, - softmax_scale=self.scale, - causal=True, - window_size=(-1, -1), - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - chunk_size=self.chunk_size, - local_size=self.local_size, - )) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = ( - self._dual_chunk_flash_attn_decoding( - decode_query.unsqueeze(1), - decode_query_succ.unsqueeze(1), - decode_query_inter.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - chunk_size=self.chunk_size, - local_size=self.local_size, - original_max_position_embeddings=self. - original_max_position_embeddings, - decode_meta=decode_meta, - ).squeeze(1)) - # Reshape the output tensor. - return output.view(num_tokens, hidden_size) - - def _dual_chunk_flash_attn_prefill( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - orig_seq_lens: List[int], - scaling_factor: torch.Tensor, - softmax_scale: float, - causal: Optional[bool] = True, - window_size: Tuple[int, int] = (-1, -1), - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - chunk_size: int = 8192, - local_size: int = 1024, - ): - if alibi_slopes is not None: - raise ValueError( - "Dual Chunk Attention does not support alibi_slopes") - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - if window_size != (-1, -1): - raise ValueError( - "Dual Chunk Attention does not support window_size") - - cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() - cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() - all_outputs = [] - - for i in range(0, len(cu_seqlens_q_cpu) - 1): - qs = cu_seqlens_q_cpu[i] - qe = cu_seqlens_q_cpu[i:i + 2][-1] - ks = cu_seqlens_k_cpu[i] - ke = cu_seqlens_k_cpu[i:i + 2][-1] - - current_q = q[qs:qe] - current_q_succ = q_succ[qs:qe] - current_q_inter = q_inter[qs:qe] - current_q_succ_critical = q_succ_critical[qs:qe] - current_q_inter_critical = q_inter_critical[qs:qe] - - if block_table is None: - current_k = k[ks:ke] - current_v = v[ks:ke] - current_block_table = None - current_orig_seq_len = orig_seq_lens[i] - else: - current_block_table = block_table[i] - current_orig_seq_len = orig_seq_lens[i] - current_k = k - current_v = v - sparse_attn_enabled = (self.sparse_attention_enabled - and current_orig_seq_len - > self.sparse_attention_threshold) - - if current_q.shape[0] == 0: - continue - - if current_k.shape[0] == 0: - all_outputs.append( - torch.zeros( - (current_q.shape[0], current_q.shape[1], v.shape[2]), - device=q.device, - dtype=q.dtype, - )) - continue - - current_output = torch.empty_like(current_q) - group_size = int(current_q.size(-2) / current_k.size(-2)) - - if sparse_attn_enabled: - num_device_q_heads = current_q.size(-2) - heads_vertical_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - heads_slash_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - for head_id in range(current_q.size(-2)): - ( - ty, - vertical_size, - slash_size, - _, - ) = self.sparse_attention_config[head_id] - assert ty == "vertical_and_slash", "only support slash mode" - - if vertical_size == 30: - vertical_size += 100 - heads_vertical_size[head_id] = vertical_size - heads_slash_size[head_id] = slash_size - - current_output = self._dual_chunk_flash_attn_prefill_func( - current_q, # allheads - current_q_succ, - current_q_inter, - current_q_succ_critical, - current_q_inter_critical, - current_k, - current_v, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - heads_vertical_size=heads_vertical_size, - heads_slash_size=heads_slash_size, - group_size=group_size) - else: - for head_id in range(current_q.size(-2)): - # (seq_len, num_heads, head_size) - current_q_head = current_q[:, head_id, :].unsqueeze(1) - current_q_succ_head = \ - current_q_succ[:, head_id, :].unsqueeze(1) - current_q_inter_head = \ - current_q_inter[:, head_id, :].unsqueeze(1) - current_q_succ_head_critical = \ - current_q_succ_critical[:, head_id, :].unsqueeze(1) - current_q_inter_head_critical = \ - current_q_inter_critical[:, head_id, :].unsqueeze(1) - if block_table is not None: - current_k_head = current_k[..., head_id // - group_size, :].unsqueeze(2) - current_v_head = current_v[..., head_id // - group_size, :].unsqueeze(2) - - else: - current_k_head = current_k[:, head_id, :].unsqueeze(1) - current_v_head = current_v[:, head_id, :].unsqueeze(1) - - current_out = self._dual_chunk_flash_attn_prefill_func( - current_q_head, - current_q_succ_head, - current_q_inter_head, - current_q_succ_head_critical, - current_q_inter_head_critical, - current_k_head, - current_v_head, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - ) - current_output[:, head_id:head_id + 1, :] = current_out - all_outputs.append(current_output) - return torch.cat(all_outputs, dim=0) - - def _dual_chunk_flash_attn_prefill_func( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - block_table, - softmax_scale: float, - chunk_size: int, - local_size: int, - scaling_factor: float, - k_length: int, - sparse_attn_enabled: Optional[bool] = True, - heads_vertical_size=None, - heads_slash_size=None, - group_size=None, - ): - flash_results = [] - chunk_len = chunk_size - local_size - - if block_table is not None: - block_size = v.shape[1] - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - else: - block_size = 1 - - if self.original_max_position_embeddings > 0: - softmax_scale = softmax_scale * scaling_factor - - begin = k_length - q.shape[0] - while begin < k_length: - flash_per_chunk = [] - - prev_chunk_end_pos = (begin // chunk_len) * chunk_len - next_chunk_end_pos = prev_chunk_end_pos + chunk_len - end = min(next_chunk_end_pos, k_length) - qbegin = begin - (k_length - q.shape[0]) - qend = end - (k_length - q.shape[0]) - - qk_chunks = [] - q_states_intra = q[qbegin:qend] - # choose critical token - if block_table is not None: - block_tables_intra = _get_block(block_table, block_size, - prev_chunk_end_pos, end) - k_states_intra = k[block_tables_intra].view( - -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] - v_states_intra = v[block_tables_intra].view( - -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] - else: - block_tables_intra = None - k_states_intra = k[prev_chunk_end_pos:end] - v_states_intra = v[prev_chunk_end_pos:end] - - if sparse_attn_enabled: - last_q_size = min(qend - qbegin, self.sparse_attention_last_q) - _, num_device_k_heads, head_dim = k_states_intra.shape - k_states_intra = (k_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - v_states_intra = (v_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - qk_chunks.append( - (q_states_intra.transpose(0, 1)[:, -last_q_size:] * - softmax_scale) @ k_states_intra.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len >= 0: - q_states_succ = q_succ[qbegin:qend] - q_states_succ_critical = q_succ_critical[qbegin:qend] - if block_table is not None: - block_tables_succ = _get_block( - block_table, block_size, - prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) - k_states_succ = k[block_tables_succ].view( - -1, *k.shape[-2:])[:chunk_len] - v_states_succ = v[block_tables_succ].view( - -1, *v.shape[-2:])[:chunk_len] - else: - k_states_succ = k[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - v_states_succ = v[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - - if sparse_attn_enabled: - k_states_succ = (k_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_succ = (v_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_succ_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_succ.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - q_states_inter = q_inter[qbegin:qend] - q_states_inter_critical = q_inter_critical[qbegin:qend] - if block_table is not None: - block_tables_inter = _get_block( - block_table, block_size, 0, - prev_chunk_end_pos - chunk_len) - k_states_inter = k[block_tables_inter].view( - -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - v_states_inter = v[block_tables_inter].view( - -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - else: - k_states_inter = k[:prev_chunk_end_pos - chunk_len] - v_states_inter = v[:prev_chunk_end_pos - chunk_len] - - if sparse_attn_enabled: - k_states_inter = (k_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_inter = (v_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_inter_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_inter.permute(1, 2, 0)) - - if sparse_attn_enabled: - reversed_qk = qk_chunks[::-1] - qk = torch.cat(reversed_qk, dim=-1) - - qk[:, :, -last_q_size:] = torch.where( - self.last_q_mask[..., -last_q_size:, - -last_q_size:].to(qk.device), - qk[:, :, -last_q_size:], -torch.inf) - qk = F.softmax(qk, dim=-1, dtype=torch.float32) - - vertical = qk.sum(-2, keepdim=True) - vertical[..., :30] = torch.inf - - # Avoid sorting by using the min/max ints to fill the indexer - # buffers. - int32_max = torch.iinfo(torch.int32).max - int32_min = torch.iinfo(torch.int32).min - n_heads = qk.size()[0] - max_slash_topk = torch.max(heads_slash_size).item() - max_vertical_topk = torch.max(heads_vertical_size).item() - # store each head's slash topk, vertical topk - vertical = vertical.reshape((n_heads, -1)) - # prevent out of range when prompt size < max_vertical_topk - max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) - vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, - -1).indices - slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), - dtype=torch.int64, - device=qk.device) - for head_i in range(n_heads): - # (nqheads=1, lastq, k_len) - head_score = qk[head_i:head_i + 1, :, :] - slash_scores = _sum_all_diagonal_matrix(head_score) - if head_score.size(1) != 1: - # drop right up corner - slash_scores = slash_scores[..., :-last_q_size + 1] - slash_scores[..., -100:] = torch.inf - - head_slash_size = heads_slash_size[head_i] - head_slash_size = min(head_slash_size, vertical.size(-1)) - slash_topk = torch.topk(slash_scores, head_slash_size, - -1).indices - #(nheads, max_topk) - slash_topk_buffer[head_i, :head_slash_size] = slash_topk - - # reset heads topk - heads_slash_size[head_i] = head_slash_size - heads_vertical_size[head_i] = min( - heads_vertical_size[head_i], max_vertical_topk) - - # store - vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - succ_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - inter_vertical_buffer = torch.full( - (n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - inter_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - - vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - - for head_i in range(n_heads): - vertical_topk = vertical_topk_buffer[ - head_i, :heads_vertical_size[head_i]] - # intra - intra_vertical_indices = vertical_topk[ - vertical_topk >= - prev_chunk_end_pos] - prev_chunk_end_pos - if intra_vertical_indices.nelement() == 0: - intra_vertical_indices = torch.cat([ - intra_vertical_indices, - torch.arange(0, - k_states_intra.size(0), - max(1, - k_states_intra.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - slash_topk = slash_topk_buffer[ - head_i, :heads_slash_size[head_i]] - intra_slash_indices = ( - (qk.size(-1) - 1) - - slash_topk[slash_topk >= prev_chunk_end_pos]) - # fill buffer - v_count = intra_vertical_indices.nelement() - s_count = intra_slash_indices.nelement() - vertical_size_buffer[head_i] = v_count - slash_sizes_buffer[head_i] = s_count - vertical_buffer[head_i, :v_count].copy_( - intra_vertical_indices) - slash_buffer[head_i, :s_count].copy_(intra_slash_indices) - # succ - if prev_chunk_end_pos - chunk_len >= 0: - succ_vertical_indices = vertical_topk[ - (vertical_topk < prev_chunk_end_pos) - & (vertical_topk >= prev_chunk_end_pos - - chunk_len)] - (prev_chunk_end_pos - chunk_len) - # TODO: support no vertical - if succ_vertical_indices.nelement() == 0: - succ_vertical_indices = torch.cat([ - succ_vertical_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - succ_slash_indices = ( - (prev_chunk_end_pos + (qend - qbegin) - 1) - - slash_topk[((slash_topk >= - (prev_chunk_end_pos - chunk_len)) & - (slash_topk < (prev_chunk_end_pos + - (qend - qbegin))))]) - if succ_slash_indices.nelement() == 0: - succ_slash_indices = torch.cat([ - succ_slash_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = succ_vertical_indices.nelement() - s_count = succ_slash_indices.nelement() - succ_vertical_size_buffer[head_i] = v_count - succ_slash_sizes_buffer[head_i] = s_count - succ_vertical_buffer[head_i, :v_count].copy_( - succ_vertical_indices) - succ_slash_buffer[head_i, :s_count].copy_( - succ_slash_indices) - - if prev_chunk_end_pos - 2 * chunk_len >= 0: - inter_vertical_indices = vertical_topk[ - vertical_topk < prev_chunk_end_pos - chunk_len] - - if inter_vertical_indices.nelement() == 0: - inter_vertical_indices = torch.cat([ - inter_vertical_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - inter_slash_indices = ( - (prev_chunk_end_pos - chunk_len + - (qend - qbegin) - 1) - - slash_topk[slash_topk < (prev_chunk_end_pos - - chunk_len + - (qend - qbegin))]) - if inter_slash_indices.nelement() == 0: - inter_slash_indices = torch.cat([ - inter_slash_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = inter_vertical_indices.nelement() - s_count = inter_slash_indices.nelement() - inter_vertical_size_buffer[head_i] = v_count - inter_slash_sizes_buffer[head_i] = s_count - inter_vertical_buffer[head_i, :v_count].copy_( - inter_vertical_indices) - inter_slash_buffer[head_i, :s_count].copy_( - inter_slash_indices) - else: - intra_vertical_indices, intra_slash_indices = None, None - succ_vertical_indices, succ_slash_indices = None, None - inter_vertical_indices, inter_slash_indices = None, None - - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=vertical_buffer, - slash_indices=slash_buffer, - vertical_indices_count=vertical_size_buffer, - slash_indices_count=slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=intra_vertical_indices, - slash_indices=intra_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len >= 0: - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_buffer, - slash_indices=succ_slash_buffer, - vertical_indices_count=succ_vertical_size_buffer, - slash_indices_count=succ_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_indices, - slash_indices=succ_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_buffer, - slash_indices=inter_slash_buffer, - vertical_indices_count=inter_vertical_size_buffer, - slash_indices_count=inter_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_indices, - slash_indices=inter_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - flash_results.append(flash_per_chunk) - begin = end - - attn_output = self._merge_attn_outputs(flash_results) - del flash_results - return attn_output - - def _do_flash_attn( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - softmax_scale: float, - causal: bool = True, - max_seqlen_k: Optional[int] = None, - stage: str = "intra", - vertical_indices: Optional[torch.Tensor] = None, - slash_indices: Optional[torch.Tensor] = None, - vertical_indices_count: Optional[torch.Tensor] = None, - slash_indices_count: Optional[torch.Tensor] = None, - mergehead_softmax_scale: Optional[float] = None, - sparse_attn_enabled: Optional[bool] = False, - ): - if max_seqlen_k is None: - max_seqlen_k = key_states.shape[0] - - q_len = query_states.shape[0] - q_heads = query_states.shape[1] - h_dim = query_states.shape[-1] - - if sparse_attn_enabled: - assert slash_indices is not None - if stage == "intra": - assert causal - else: - assert not causal - - query_states = query_states.unsqueeze(0).transpose(1, 2) - key_states = key_states.unsqueeze(0).transpose(1, 2) - value_states = value_states.unsqueeze(0).transpose(1, 2) - - q = query_states - k = key_states - v = value_states - - if (vertical_indices_count is not None and \ - slash_indices_count is not None): - assert mergehead_softmax_scale is not None - - res, s_lse = _vertical_slash_sparse_attention( - q, - k, - v, - vertical_indices, - slash_indices, - mergehead_softmax_scale, - causal=causal, - stage=stage, - vertical_indices_count=vertical_indices_count, - slash_indices_count=slash_indices_count) - res = res.view(q_heads, q_len, - h_dim).transpose(0, 1) # (qlen,nhead,h_dim) - s_lse = s_lse.view( - q_heads, q_len, - 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) - else: - res, s_lse = _vertical_slash_sparse_attention(q, - k, - v, - vertical_indices, - slash_indices, - softmax_scale, - causal=causal, - stage=stage) - res = res.view(q_len, q_heads, h_dim) - s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() - return res, s_lse - - output, softmax_lse = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - softmax_scale=softmax_scale, - cu_seqlens_q=torch.tensor([0, query_states.shape[0]], - dtype=torch.int32, - device=query_states.device), - max_seqlen_q=query_states.shape[0], - cu_seqlens_k=torch.tensor([0, max_seqlen_k], - dtype=torch.int32, - device=query_states.device), - max_seqlen_k=max_seqlen_k, - causal=causal, - return_softmax_lse=True, - ) - softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, - 2).float() - return output, softmax_lse - - def _merge_attn_outputs( - self, - flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], - return_lse: Optional[bool] = False, - ) -> torch.Tensor: - attn_outputs_all = [] - logits_all = [] - - for flash_per_chunk in flash_results: - if len(flash_per_chunk) == 1: - attn_outputs_all.append(flash_per_chunk[0][0]) - if return_lse: - logits_all.append(flash_per_chunk[0][1]) - continue - - attn_outputs = torch.stack([ - flash_attn_output[0] for flash_attn_output in flash_per_chunk - ]) - logits = torch.stack([ - flash_attn_output[1] for flash_attn_output in flash_per_chunk - ]) - logits = logits.to(torch.float32) - - if return_lse: - max_val = torch.max(logits, dim=0).values - diff = torch.abs(logits[0] - logits[1]) - log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) - logits_all.append(log_sum_exp) - - max_logits = torch.max(logits, dim=0).values - stable_logits = logits - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) - attn_outputs_all.append(attn_outputs.sum(dim=0)) - - if return_lse: - return (torch.cat(attn_outputs_all, - dim=0), torch.cat(logits_all, dim=-1)) - else: - return torch.cat(attn_outputs_all, dim=0) - - def _dual_chunk_flash_attn_decoding( - self, - query: torch.Tensor, - query_succ: torch.Tensor, - query_inter: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - causal: bool, - alibi_slopes: Optional[torch.Tensor], - chunk_size: int, - local_size: int, - original_max_position_embeddings: int, - decode_meta: DualChunkFlashAttentionMetadata, - ): - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - - block_size = value_cache.shape[1] - chunk_len = chunk_size - local_size - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - if original_max_position_embeddings > 0: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - query = (query * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype - ) # possible for numerical issue, need to fused in the kernel - query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - outputs_list = [] - softmax_lses_list = [] - - # intra-attention - intra_output, intra_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query, - key_cache, - value_cache, - decode_meta.block_tables_intra, - decode_meta.seq_lens_intra, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(intra_output) - softmax_lses_list.append(intra_softmax_lse) - - # succ-attention - if decode_meta.max_seq_len_succ: - succ_output, succ_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_succ, - key_cache, - value_cache, - decode_meta.block_tables_succ, - decode_meta.seq_lens_succ, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(succ_output) - softmax_lses_list.append(succ_softmax_lse) - - # inter-attention - if decode_meta.max_seq_len_inter: - inter_output, inter_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_inter, - key_cache, - value_cache, - block_table[:, :decode_meta.max_seq_len_inter], - decode_meta.seq_lens_inter, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(inter_output) - softmax_lses_list.append(inter_softmax_lse) - outputs = torch.stack(outputs_list, dim=0) - del outputs_list - softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) - del softmax_lses_list - max_logits = torch.max(softmax_lses, dim=0).values - stable_logits = softmax_lses - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - outputs *= lse_s.unsqueeze(-1).transpose(2, 3) - return outputs.sum(0) - - def _dual_chunk_flash_attn_decoding_with_exp_sums( - self, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - ): - out, softmax_lse = flash_attn_with_kvcache( - q=query, - k_cache=key_cache, - v_cache=value_cache, - block_table=block_table, - cache_seqlens=cache_seqlens, - softmax_scale=softmax_scale, - alibi_slopes=alibi_slopes, - causal=causal, - return_softmax_lse=True, - ) - mask = (cache_seqlens == 0) - out[mask] = 0 - softmax_lse[mask] = -float("inf") - return out, softmax_lse - - -def _vertical_slash_sparse_attention( - query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] - key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] - s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] - softmax_scale: float, - causal: bool = True, - stage: str = "intra", - block_size_M: int = 64, - block_size_N: int = 64, - vertical_indices_count: torch.Tensor = None, # [N_HEADS,] - slash_indices_count: torch.Tensor = None, -): - if stage == "intra": - assert causal - else: - assert not causal - - batch_size, num_heads, context_size, head_dim = query.shape - _, _, kv_seq_len, _ = key.shape - - if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim - query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) - key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) - value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) - - v_idx = v_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] - q_seqlens = torch.tensor([context_size], - dtype=torch.int32, - device=query.device) - kv_seqlens = torch.tensor([kv_seq_len], - dtype=torch.int32, - device=query.device) - - if vertical_indices_count is not None and slash_indices_count is not None: - ( - block_count, - block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes_mergehead( - q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, - slash_indices_count, context_size, block_size_M, block_size_N, - causal) - else: - ( - block_count, - block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, - s_idx, context_size, - block_size_M, block_size_N, - causal) - - q = query.transpose(1, 2).contiguous() - k = key.transpose(1, 2).contiguous() - v = value.transpose(1, 2).contiguous() - out, lse = sparse_attn_func( - q, - k, - v, - block_count, - block_offset, - column_count, - column_index, - causal=causal, - softmax_scale=softmax_scale, - return_softmax_lse=True, - ) - out = out.transpose(1, 2).contiguous() - softmax_lse = lse.reshape(*lse.shape, 1) - return (out[..., :context_size, :head_dim], - softmax_lse[..., :context_size, :]) - - -def _sum_all_diagonal_matrix(mat: torch.tensor): - h, n, m = mat.shape - # Zero matrix used for padding - zero_mat = torch.zeros((h, n, n), device=mat.device) - # pads the matrix on left and right - mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) - # Change the strides - mat_strided = mat_padded.as_strided((1, n, n + m), - (n * (2 * n + m), 2 * n + m + 1, 1)) - # Sums the resulting matrix's columns - sum_diags = torch.sum(mat_strided, 1) - return sum_diags[:, 1:] # drop left bottom corner - - -def _get_block(block_table: torch.Tensor, block_size: int, begin: int, - end: int): - begin_block = begin // block_size - end_block = (end - 1) // block_size + 1 - return block_table[begin_block:end_block] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py deleted file mode 100755 index edb3afb4aa07..000000000000 --- a/vllm/attention/backends/flash_attn.py +++ /dev/null @@ -1,929 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import Dict, List, Optional, Tuple, Type - -import torch - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -# yapf: enable -from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, - get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -logger = init_logger(__name__) - - -class FlashAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_name() -> str: - return "FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: - return FlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return FlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class FlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = FlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = FlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return FlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class FlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "FLASH_ATTN backend.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size, num_kv_heads, head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. - if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: - assert ( - layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( - "key/v_scale is only supported in FlashAttention 3 with " - "base dtype bfloat16") - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes - logits_soft_cap: Optional[float] = self.logits_soft_cap - fp8_attention = kv_cache_dtype.startswith("fp8") - - if fp8_attention and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support FP8 kv-cache on this device.") - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the - # KV cache is already populated with the cross-attention - # tensor. Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if fp8_attention: - kv_cache = kv_cache.view(torch.float8_e4m3fn) - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) - - if fp8_attention: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - decode_output = output[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - prefill_output = output[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - if fp8_attention: - num_kv_tokens, num_kv_heads, head_size = key.shape - - key, _ = ops.scaled_fp8_quant( - key.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._k_scale) - key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) - - value, _ = ops.scaled_fp8_quant( - value.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._v_scale) - value = value.reshape( - (num_kv_tokens, num_kv_heads, head_size)) - - descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) - flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - assert prefill_meta.query_start_loc is not None - max_seq_len = max(prefill_meta.seq_lens) - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens_tensor, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. - - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1" - ) - assert decode_meta.query_start_loc is not None - descale_shape = (decode_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - seqused_k=decode_meta.seq_lens_tensor, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - out=decode_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) - flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=decode_output.unsqueeze(1), - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - return output - - -def _get_query_key_seq_metadata( - attn_metadata: FlashAttentionMetadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - """ - Returns sequence metadata for key and query based on the specified - attention type and whether input is a prompt. - - This function computes the starting locations and maximum sequence lengths - for key and query sequences for different attention types. - - Args: - attn_metadata: The attention metadata object - is_prompt (bool): A flag indicating if the input is a prompt - attn_type (AttentionType): The type of attention being used. - - Returns: - tuple: A tuple containing four integers: - - Starting location for the query sequence. - - Maximum sequence length for the query sequence. - - Starting location for the key sequence. - - Maximum sequence length for the key sequence. - - Raises: - AttributeError: If an invalid attention type is provided. - """ - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.seq_start_loc, max_seq_len) - - elif attn_type == AttentionType.ENCODER_DECODER: - # This is cross attention between the where the key - # is the precomputed encoder attention and query - # is the input sequence. - # Choose query max length based on whether it is prompt - # or not. - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e. the - # encoder sequence. - return (attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, - attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _get_causal_option(attn_type: str) -> bool: - """ - Determine whether the given attention type is suitable for causal - attention mechanisms. - - Args: - attn_type (AttentionType): The type of attention being evaluated - - Returns: - bool: Returns `True` if the attention type is suitable for causal - attention (i.e., not encoder, encoder-only, or encoder-decoder), - otherwise returns `False`. - """ - return not (attn_type == AttentionType.ENCODER - or attn_type == AttentionType.ENCODER_ONLY - or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py deleted file mode 100644 index aeaa0ab631cf..000000000000 --- a/vllm/attention/backends/flashmla.py +++ /dev/null @@ -1,227 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) - - -class FlashMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "FLASHMLA" - - @staticmethod - def get_impl_cls() -> Type["FlashMLAImpl"]: - return FlashMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["FlashMLAMetadata"]: - return FlashMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: - return FlashMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashMLAState"]: - return FlashMLAState - - -@dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - # TODO: cache assignment? - if decode_metadata is not None: - decode_metadata.decode_tile_scheduler_metadata=\ - self.decode_tile_scheduler_metadata - decode_metadata.decode_num_splits=\ - self.decode_num_splits - return decode_metadata - - -class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - m = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - - if m.num_decode_tokens > 0: - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ - get_mla_metadata( - m.seq_lens_tensor[m.num_prefills:], - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m - - -class FlashMLAState(MLACommonState[FlashMLAMetadata]): - - def __init__(self, *args, **kwds): - super().__init__(*args, **kwds) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - @contextmanager - def graph_capture(self, max_batch_size: int): - # Run a dummy `get_mla_metadata` so we can get the right shapes - self._graph_decoder_tile_scheduler_metadata, \ - self._graph_decode_num_splits = get_mla_metadata( - torch.ones( - max_batch_size, dtype=torch.int32, device=self.runner.device), - self.num_q_heads, - 1, # MQA for the decode path - ) - - with super().graph_capture(max_batch_size): - yield - - del self._graph_decoder_tile_scheduler_metadata - del self._graph_decode_num_splits - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - assert metadata.num_decode_tokens > 0 - - decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( - self._graph_seq_lens[:batch_size], - self.num_q_heads, - 1, # MQA for the decode path - ) - - self._graph_decoder_tile_scheduler_metadata.copy_( - decoder_tile_scheduler_metadata) - self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) - - metadata.decode_tile_scheduler_metadata=\ - self._graph_decoder_tile_scheduler_metadata - metadata.decode_num_splits=\ - self._graph_decode_num_splits[:batch_size + 1] - - return metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers["decode_tile_scheduler_metadata"] = \ - attn_metadata.decode_metadata.decode_tile_scheduler_metadata - input_buffers["decode_num_splits"] = \ - attn_metadata.decode_metadata.decode_num_splits - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - input_buffers["decode_tile_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_tile_scheduler_metadata) - input_buffers["decode_num_splits"].copy_( - attn_metadata.decode_metadata.decode_num_splits) - - -class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str] = None, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - is_supported, reason = is_flashmla_supported() - assert is_supported, reason - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) - - o, _ = flash_mla_with_kvcache( - q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, - num_splits=decode_meta.decode_num_splits, - softmax_scale=self.scale, - causal=True, - ) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/attention/backends/mla/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py deleted file mode 100644 index 826b63e1ccda..000000000000 --- a/vllm/attention/backends/mla/common.py +++ /dev/null @@ -1,1305 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -# MLA Common Components - -This file implements common components for MLA implementations. - -First we define: - -Sq as Q sequence length -Skv as KV sequence length - -MLA has two possible ways of computing, a data-movement friendly approach and a -compute friendly approach, we generally want to use the compute friendly -approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) -and the data-movement friendly approach for "decode" (i.e. the ratio -Sq / Skv is "large"). - -NOTE what we deem small and large is currently determined by if its labelled -prefill or decode by the scheduler, but this is something we should probably -tune. - -Main reference: DeepseekV2 paper, and FlashInfer Implementation -(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - -Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. -* For decode (i.e. the memory friendly approach) the attention "simulates" a -multi-head attention, while the compute is similar to multi-query attention. - -Below is example of both paths assuming batchsize = 1 - -## More Extent Definitions: - -C Context length, `Skv - Sq` -H hidden size -N number of attention heads -Lq latent dimension for Q 1536 in DSV3 -Lkv latent dimension for K/V 512 in DSV3 -P nope dimension, no rope. 128 in DSV3 -R rope dimension, goes through rope. 64 in DSV3 -V V head dim. 128 in DSV3 - -## Vector/Matrix Definitions - -h_t hidden states (input to attention) shape [Sq, H] -q_c latent/compressed Q shape [Sq, Lq] -q_nope uncompressed Q (no-rope) shape [Sq, N, P] -q_pe uncompressed Q (rope) shape [Sq, N, R] -kv_c latent/compressed KV shape [Skv, Lkv] -k_pe decoupled k position embeddings shape [Skv, R] -new_kv_c new kv_c from current iter shape [Sq, Lkv] -new_k_pe new k_pe from current iter shape [Sq, R] -cache_kv_c cached k_c from previous iters shape [C, Lkv] -cache_k_pe cached k_pe from previous iters shape [C, R] -W_DQ project h_t to q_c shape [H, Lq] -W_UQ project q_c to q_nope shape [Lq, N * P] -W_QR project q_c to q_pe shape [Lq, N * R] -W_DKV project h_t to kv_c shape [H, Lkv] -W_UK project kv_c to k_nope shape [Lkv, N, P] -W_KR project h_t to k_pe shape [H, R] -W_UV project kv_c to v shape [Lkv, N, V] -W_O project v to h_t shape [N * V, H] - - -## Compute Friendly Approach (i.e. "_forward_prefill"): - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) -k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) -v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) - -// MHA with QK headdim = P + R -// V headdim = V -// spda_o shape [Sq, N, V] -spda_o = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - v -) -return spda_o @ W_O - -NOTE: in the actual code, - `kv_b_proj` is [W_UK; W_UV] concatenated per head - `q_b_proj` is [W_UQ; W_QR] concatenated per head - `out_proj` is W_O - - -## Data-Movement Friendly Approach (i.e. "_forward_decode"): - -Runtime -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(-1, N, P) -ql_nope = einsum("snh,lnh->snl", q, W_UK) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) - -// MQA with QK headdim = Lkv + R -// V headdim = Lkv -// spda_o shape [Sq, N, Lkv] -// NOTE: this is less compute-friendly since Lkv > P -// but is more data-movement friendly since its MQA vs MHA -spda_o = scaled_dot_product_attention( - torch.cat([ql_nope, q_pe], dim=-1), - torch.cat([kv_c, k_pe], dim=-1), - kv_c -) - -o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) -return o.view(-1, N * V) @ self.num_heads @ W_O - - -## Chunked Prefill - -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to -the data-movement friendly approach if the chunk (i.e. `Sq`) is small. - -However, the compute-friendly approach can potentially run out of memory if Skv -is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` - -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a -fixed workspace size. - -The chunked prefill approach is as follows: - -MCC Max chunk of context to process per iter, computed dynamically, - used to bound the memory usage - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) -new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) - -// MHA between queries and new KV -// with QK headdim = P + R -// V headdim = V -// curr_o shape [Sq, N, V] -// curr_lse shape [N, Sq], this is just order FA returns -curr_o, curr_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - new_v, - casual=True, - return_softmax_lse=True -) - -// Compute attention with the already existing context -for chunk_idx in range(cdiv(C, MCC)): - chunk_start = chunk_idx * MCC - chunk_end = min(chunk_start + MCC, C) - Sc = chunk_end - chunk_start - cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] - cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] - cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) - cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) - - chunk_o, chunk_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([cache_k_nope_chunk, - cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], - dim=-1), - cache_v_chunk, - casual=False, - return_softmax_lse=True - ) - - curr_o, curr_lse = merge_attn_states( - suffix_output=curr_o, - suffix_lse=curr_lse, - prefix_output=chunk_o, - prefix_lse=chunk_lse, - ) - -return curr_o @ W_O -""" - -import functools -from abc import abstractmethod -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar - -import torch - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, MLAAttentionImpl) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import get_flash_attn_version -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON -from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down - -if HAS_TRITON: - from vllm.attention.ops.triton_flash_attention import triton_attention -else: - triton_attention = None - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func - is_vllm_fa = True -except ImportError: - is_vllm_fa = False - try: - # For rocm use upstream flash attention - from flash_attn import flash_attn_varlen_func - except ImportError: - flash_attn_varlen_func = None - -is_hip = current_platform.is_rocm() - - -class MLACommonBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return MLACommonMetadata - - @staticmethod - def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: - return MLACommonMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["MLACommonState"]: - return MLACommonState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -T = TypeVar("T", bound="MLACommonMetadata") - - -class MLACommonState(AttentionState, Generic[T]): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - scheduler_config = runner.scheduler_config - self.model_config = runner.model_config - cache_config = runner.cache_config - - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - self.context_chunk_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.context_chunk_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> T: - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - use_cuda_graph=True, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - if self.chunked_prefill_enabled or self.enable_prefix_caching: - if not hasattr(self, "context_chunk_workspace"): - # not self.runner.device does not return the correct device - # for this process, (init_device sets the correct device but - # only on the Worker). The only way Ive figured out to get the - # correct device is to allocate the workspace on the first call - # to begin_forward and use the device of the input tokens - assert model_input.input_tokens is not None - self.context_chunk_workspace = torch.empty( - (self.context_chunk_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=model_input.input_tokens.device, - ) - - model_input.attn_metadata.context_chunk_workspace = \ - self.context_chunk_workspace - - -@dataclass -class MLACommonMetadata(AttentionMetadata): - """Metadata for MLACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[Any] = None - _cached_decode_metadata: Optional[Any] = None - - num_prefill_tokens: int - - # The dimension of the attention heads - head_dim: Optional[int] = None - - # Used when chunked prefill is enabled to simulate worst case workspace - # allocations, hopefully to avoid going OOM - is_profile_run: bool = False - - # New for MLA (compared to FlashAttention) - # For chunked prefill - context_chunk_cu_seq_lens: Optional[torch.Tensor] = None - context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[List[int]] = None - context_chunk_max_seq_lens: Optional[List[int]] = None - # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted - context_chunk_workspace: Optional[torch.Tensor] = None - - def __post_init__(self): - supported_head_sizes = MLACommonBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - @property - def prefill_metadata(self): - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=False, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, - context_chunk_starts=self.context_chunk_starts, - context_chunk_seq_tot=self.context_chunk_seq_tot, - context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=self.use_cuda_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run) - return self._cached_decode_metadata - - -class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - BLOCK_TABLE_EXTENDER: list[list[int]] = [] - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - self.chunked_prefill_enabled = \ - self.runner.scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = \ - self.runner.cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - attn_state = self.input_builder.runner.attn_state - self.context_chunk_workspace_size = \ - attn_state.context_chunk_workspace_size - self.page_size = self.runner.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * - cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - - context_chunk_cu_seq_lens = None - context_chunk_starts = None - context_chunk_seq_tot = None - context_chunk_max_seq_lens = None - - if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ - and self.num_prefills > 0 \ - and context_lens_tensor is not None \ - and context_lens_tensor[:self.num_prefills].max() > 0: - - # NOTE: it is recommended you read the `Chunked Prefill` section in - # the comment at the top of the file before trying to understand - # the following code - - num_prefills_with_context = \ - (context_lens_tensor[:self.num_prefills] > 0).sum().item() - - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = \ - self.context_chunk_workspace_size // num_prefills_with_context - - # align max_context_chunk to page_size by rounding down, - # currently the `gather_and_maybe_dequant_cache` kernel cannot - # handle `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, self.page_size) - assert max_context_chunk > 0 - num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) - - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - context_chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32)\ - .unsqueeze(1).expand(-1, self.num_prefills)\ - * max_context_chunk - chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ - .unsqueeze(0), context_chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) - _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ - .unsqueeze(-1) - context_chunk_cu_seq_lens = \ - torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) - context_chunk_max_seq_lens = \ - chunk_seq_lens.max(dim=1).values.tolist() - context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() - assert max(context_chunk_seq_tot) <= \ - self.context_chunk_workspace_size - - return self.runner.attn_backend.make_metadata( - # Required by ModelRunner - use_cuda_graph=use_captured_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, # Not Attention Related - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.runner.model_config.get_head_size(), - is_profile_run=self.runner.in_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, - context_chunk_starts=context_chunk_starts, - context_chunk_seq_tot=context_chunk_seq_tot, - context_chunk_max_seq_lens=context_chunk_max_seq_lens, - ) - - -class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - q_lora_rank: Optional[int], - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - kv_b_proj: ColumnParallelLinear, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing not supported in V0.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - self.kv_b_proj = kv_b_proj - - self.triton_fa_func = triton_attention - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) - - def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, - return_softmax_lse, **kwargs): - maybe_padded_v = v - if self._pad_v: - maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ - and not return_softmax_lse: - attn_out = self.triton_fa_func( - q, - k, - maybe_padded_v, - None, # output - kwargs["cu_seqlens_q"], - kwargs["cu_seqlens_k"], - kwargs["max_seqlen_q"], - kwargs["max_seqlen_k"], - kwargs["causal"], - softmax_scale, - None, # bias - ) - elif is_vllm_fa: - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - else: - # Use return_attn_probs instead of return_softmax_lse for RoCM - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_attn_probs=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - - # Unpack the output if there is multiple results, - # triton always returns (output, softmax_lse), - # vllm_flash_attn returns (output, softmax_lse) when - # `return_softmax_lse = True` - # flash_attn (RoCM) returns (output, softmax_lse, ...) when - # `return_attn_probs = True` - rest = None - if isinstance(attn_out, tuple): - attn_out, *rest = attn_out - - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. - if return_softmax_lse: - assert rest is not None - return attn_out, rest[0] - return attn_out - - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - - def _compute_prefill_context( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ): - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - assert prefill_metadata.context_chunk_seq_tot is not None - assert prefill_metadata.context_chunk_cu_seq_lens is not None - assert prefill_metadata.context_chunk_starts is not None - assert prefill_metadata.context_chunk_max_seq_lens is not None - assert prefill_metadata.context_lens_tensor is not None - - output = None - iters = len(prefill_metadata.context_chunk_seq_tot) - - # Fetch from attn_metadata directly, since it late bound by - # MLAAttentionState, grabbing it directly `attn_metadata` can avoid - # any weirdness around prefill_metadata caching - assert attn_metadata.context_chunk_workspace is not None - workspace = attn_metadata.context_chunk_workspace - - for i in range(iters): - toks = prefill_metadata.context_chunk_seq_tot[i] - - ops.gather_and_maybe_dequant_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], - batch_size=prefill_metadata.num_prefills, - kv_cache_dtype=self.kv_cache_dtype, - scale=k_scale, - seq_starts=prefill_metadata.context_chunk_starts[i], - ) - - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp - - return output, output_lse - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - output = self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - - if has_context: - # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, - ) - - # unpad if necessary - if self._pad_v: - output = output[..., :v.shape[-1]] - - return output.flatten(start_dim=-2) - - @abstractmethod - def _forward_decode( - self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, # query in unified attn - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: T, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if output is not None: - raise NotImplementedError( - "output is not yet supported for MLAImplBase") - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLAImplBase") - - if attn_metadata.is_profile_run and \ - attn_metadata.context_chunk_workspace is not None: - # During the profile run try to simulate to worse case output size - # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` - # since this can be large - _ = torch.empty( - (attn_metadata.context_chunk_workspace.shape[0], - self.num_heads, self.qk_nope_head_dim + self.v_head_dim), - device=k_c_normed.device, - dtype=k_c_normed.dtype, - ) - - has_decode = attn_metadata.decode_metadata is not None - has_prefill = attn_metadata.prefill_metadata is not None - - num_prefill_tokens: int = attn_metadata.num_prefill_tokens - q = q.view(-1, self.num_heads, self.qk_head_dim) - - decode_q = q[num_prefill_tokens:] - - prefill_q = q[:num_prefill_tokens] - prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_k_c_normed = k_c_normed[:num_prefill_tokens] - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - output = torch.empty(attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens, - self.v_head_dim * self.num_heads, - device=q.device, - dtype=q.dtype) - if has_prefill: - output[:num_prefill_tokens] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) - - if has_decode: - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - - output[num_prefill_tokens:] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) - - return output diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py deleted file mode 100644 index 587d08858b92..000000000000 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ /dev/null @@ -1,407 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Optional, Type, Union - -import torch - -import vllm.envs as envs -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.backends.utils import (compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, - get_aiter_mla_metadata) - - -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA - - -class AiterMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "ROCM_AITER_MLA" - - @staticmethod - def get_impl_cls() -> Type["AiterMLAImpl"]: - return AiterMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["AiterMLAMetadata"]: - return AiterMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: - return AiterMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["AiterMLAState"]: - return AiterMLAState - - -@dataclass -class AiterMLAMetadata(MLACommonMetadata): - # The following 5 tensors are for current version of AITER MLA - block_table_bound: Optional[torch.Tensor] = None - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_lens: Optional[torch.Tensor] = None - - # This is just to make new AITER MLA API work - # -- MTP support is not added yet. - qo_indptr: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self): - prefill_metadata = super().prefill_metadata - self._cached_prefill_metadata = prefill_metadata - - if prefill_metadata is not None: - prefill_metadata.paged_kv_indptr = self.paged_kv_indptr - prefill_metadata.paged_kv_indices = self.paged_kv_indices - prefill_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - prefill_metadata.block_table_bound = self.block_table_bound - prefill_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_prefill_metadata = self.__class__( - **prefill_metadata.__dict__) - - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - - self._cached_decode_metadata = decode_metadata - - if decode_metadata is not None: - decode_metadata.paged_kv_indptr = self.paged_kv_indptr - decode_metadata.paged_kv_indices = self.paged_kv_indices - decode_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - decode_metadata.block_table_bound = self.block_table_bound - decode_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_decode_metadata = self.__class__( - **decode_metadata.__dict__) - - return self._cached_decode_metadata - - -class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - - def __init__(self, input_builder): - super().__init__(input_builder) - assert self.block_size == 1, "AITER MLA requires only block size 1." - - def prepare(self): - super().prepare() - self.paged_kv_indices: list[int] = [] - self.paged_kv_indptr: list[int] = [0] - self.paged_kv_last_page_lens: list[int] = [] - self.total_blocks = 0 - self.qo_indptr: list[int] = [0] - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - if is_profile_run: - return - - # Update paged_kv_* tensors only for non-profile run - block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - self.qo_indptr.append(self.qo_indptr[-1] + 1) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_lens.append(last_page_len) - - def build(self, seq_lens: list[int], query_lens: list[int], - cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: - metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - if use_captured_graph: - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) - last_qo_indptr = self.qo_indptr[-1] - self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) - - # For current version of AITER MLA - if len(self.paged_kv_indptr) > 0: - # extend to the maximum number of blocks as returned by the - # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device=device, - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device=device, - dtype=torch.int) - paged_kv_last_page_lens_tensor = torch.tensor( - self.paged_kv_last_page_lens, device=device, dtype=torch.int) - block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - - 1, - device=device, - dtype=torch.int) - - qo_indptr = torch.tensor(self.qo_indptr, - device=device, - dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_lens_tensor = None - block_table_bound_tensor = None - qo_indptr = None - - metadata.paged_kv_indptr = paged_kv_indptr_tensor - metadata.paged_kv_indices = paged_kv_indices_tensor - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor - metadata.block_table_bound = block_table_bound_tensor - metadata.qo_indptr = qo_indptr - - return metadata - - -class AiterMLAState(MLACommonState[AiterMLAMetadata]): - - @contextmanager - def graph_capture(self, max_batch_size: int): - kv_indices, kv_indptr, last_page_lens, qo_indptr = \ - get_aiter_mla_metadata( - max_batch_size=max_batch_size, - block_size=self.runner.block_size, - max_block_per_batch=\ - self.runner.get_max_block_per_batch(), - device=self.runner.device) - self._paged_kv_indices_tensor = kv_indices - self._paged_kv_indptr_tensor = kv_indptr - self._paged_kv_last_page_lens_tensor = last_page_lens - self._qo_indptr_tensor = qo_indptr - - with super().graph_capture(max_batch_size): - yield - - del self._paged_kv_indices_tensor - del self._paged_kv_indptr_tensor - del self._paged_kv_last_page_lens_tensor - del self._qo_indptr_tensor - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: - - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - - paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] - paged_kv_indices = self._paged_kv_indices_tensor - paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: - batch_size] - qo_indptr = self._qo_indptr_tensor[:batch_size + 1] - - metadata.paged_kv_indptr = paged_kv_indptr - metadata.paged_kv_indices = paged_kv_indices - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens - metadata.qo_indptr = qo_indptr - - return metadata - - def get_graph_input_buffers(self, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers[ - 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr - input_buffers[ - "paged_kv_indices"] = attn_metadata.\ - decode_metadata.paged_kv_indices - input_buffers[ - "paged_kv_last_page_lens"] = attn_metadata.\ - decode_metadata.paged_kv_last_page_lens - input_buffers['qo_indptr'] = attn_metadata.qo_indptr - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ - 0] - input_buffers["paged_kv_indptr"].copy_( - attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) - input_buffers["paged_kv_indices"][:num_total_blocks].copy_( - attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) - input_buffers["paged_kv_last_page_lens"].copy_( - attn_metadata.decode_metadata.paged_kv_last_page_lens, - non_blocking=True) - input_buffers["qo_indptr"].copy_( - attn_metadata.decode_metadata.qo_indptr, non_blocking=True) - - -class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - from aiter import flash_attn_varlen_func - self.flash_attn_varlen_func = flash_attn_varlen_func - - def _flash_attn_varlen_diff_headdims( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - softmax_scale: float, return_softmax_lse: bool, - **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: - output = self.flash_attn_varlen_func( - q, - k, - v, - **kwargs, - ) - - return output - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: AiterMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.empty(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.qo_indptr, - attn_metadata.max_query_len, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_lens) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py deleted file mode 100644 index 9262144e37b5..000000000000 --- a/vllm/attention/backends/rocm_flash_attn.py +++ /dev/null @@ -1,953 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer ROCm GPUs.""" -import itertools -from dataclasses import dataclass -from functools import cache -from typing import List, Optional, Tuple, Type - -import torch - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) -from vllm.platforms import current_platform - -logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 256 - - -@cache -def is_rocm_aiter_paged_attn_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ - and envs.VLLM_ROCM_USE_AITER \ - - -@cache -def _get_paged_attn_module() -> PagedAttention: - """ - Initializes the appropriate PagedAttention module from `attention/ops`, - which is used as helper function - by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. - - The choice of attention module depends on whether - AITER paged attention is enabled: - - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - - Otherwise, it defaults to using the original `PagedAttention`. - """ - if is_rocm_aiter_paged_attn_enabled(): - # Import AITERPagedAttention only when the flag is enabled - from vllm.attention.ops.rocm_aiter_paged_attn import ( - AITERPagedAttention) - return AITERPagedAttention() - return PagedAttention() - - -class ROCmFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ROCM_FLASH" - - @staticmethod - def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: - return ROCmFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return ROCmFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: - return ROCmFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - paged_attn = _get_paged_attn_module() - return paged_attn.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.block_tables is not None - - self._cached_prefill_metadata = ROCmFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = ROCmFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -class ROCmFlashAttentionMetadataBuilder( - CommonMetadataBuilder[ROCmFlashAttentionMetadata]): - - _metadata_cls = ROCmFlashAttentionMetadata - - -def _make_alibi_bias(alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: Optional[List[int]], - make_attn_mask: bool = True) -> List[torch.Tensor]: - attn_biases = [] - if seq_lens: - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat( - (num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( - alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) - - return attn_biases - - -def _get_seq_len_block_table_args( - attn_metadata: ROCmFlashAttentionMetadata, - attn_type: str, -) -> tuple: - ''' - The particular choice of sequence-length - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths - Encoder attn -> select encoder sequence lengths fields - Encoder-only attn -> select prefill sequence lengths with - bidirectional attention - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention, encoder-only - - Returns: - - * Appropriate sequence-lengths tensors for query and key - * Appropriate max sequence-length scalar - * Causal masking flag - ''' - - if attn_type == AttentionType.ENCODER: - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - causal_mask = False - - # No block tables associated with encoder attention - return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, - query_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_lens, causal_mask) - - elif attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, we use the prefill sequence lengths - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - # Encoder-only models typically use bidirectional attention - causal_mask = False - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - - elif attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - causal_mask = True - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - key_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - causal_mask = False - - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (query_start_loc, attn_metadata.max_prefill_seq_len, - key_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.seq_lens, causal_mask) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class ROCmFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| - |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "ROCM_FLASH backend.") - if use_irope: - logger.warning_once( - "Using irope in ROCm Flash Attention is not supported yet, it " - "will fail back to global attention for long context.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - self.logits_soft_cap = 0.0 - else: - self.logits_soft_cap = logits_soft_cap - self.attn_type = attn_type - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.paged_attn_module = _get_paged_attn_module() - supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( - ) - - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.use_naive_attn = False - # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN - if self.use_triton_flash_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Triton FlashAttention does not support attention" - " logits soft capping." - " please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - - from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 - triton_attention) - self.triton_attn_func = triton_attention - logger.debug("Using Triton FA in ROCmBackend") - if self.sliding_window != (-1, -1): - logger.warning("ROCm Triton FA does not currently support " - "sliding window attention. If using half " - "precision, please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - else: - # if not using triton, navi3x/navi21/navi10 do not use flash-attn - # either - if not current_platform.has_device_capability(90): - self.use_naive_attn = True - else: - try: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.fa_attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") - except ModuleNotFoundError: - self.use_naive_attn = True - - if self.use_naive_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Naive FlashAttention does not support " - "attention logits soft capping.") - - self.sdpa_attn_func = _sdpa_attention - logger.debug("Using naive (SDPA) attention in ROCmBackend") - - self.aiter_kv_scales_initialized = False - self.force_fp8_attention = ( - get_current_vllm_config() is not None - and get_current_vllm_config().model_config.override_attention_dtype - == "fp8") - - def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - tokens, n_kv_heads, head_dim = x.shape - return (x[:, :, - None, :].expand(tokens, n_kv_heads, n_rep, - head_dim).reshape(tokens, n_kv_heads * n_rep, - head_dim)) - - def fused_output_quant_supported(self, quant_key: QuantKey): - if self.use_triton_flash_attn: - return quant_key == kFp8StaticTensorSym - - # Only supported in the Triton backend - return False - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * ROCmFlashAttentionImpl.forward() may be invoked for both self- and - cross-attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. - * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - * ENCODER_ONLY: bidirectional attention with no KV caching; - use prefill sequence attributes - - Args: - layer: Attention layer instance. - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size * num_kv_heads * head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Optional output tensor. - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None and not self.use_triton_flash_attn: - raise NotImplementedError( - "fused output quantization only supported for Triton" - " implementation in ROCMFlashAttentionImpl for now") - - if output_block_scale is not None: - raise NotImplementedError( - "fused nvfp4 output quantization is not supported" - " for ROCMFlashAttentionImpl") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - paged_attn = self.paged_attn_module - - # Reshaping kv tensors is required for AITER paged attention kernel - # because it works on a different tensor shape, - # when the size of one element is one byte (int8/fp8 dtypes). - # This reshaping is only required on the first forward call - # and the kv cache must not be empty. - if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 - and not self.aiter_kv_scales_initialized - and kv_cache.shape != torch.Size([0])): - num_blocks = kv_cache.shape[1] - block_size = kv_cache.shape[2] // (self.num_kv_heads * - self.head_size) - k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - self.aiter_kv_scales_initialized = True - k_scale.fill_(layer._k_scale.item()) - v_scale.fill_(layer._v_scale.item()) - layer._k_scale = k_scale - layer._v_scale = v_scale - - # Only update KV cache for decoder self-attention - # and encoder-decoder cross-attention - if self.attn_type not in [ - AttentionType.ENCODER, AttentionType.ENCODER_ONLY - ] and kv_cache.numel() > 0: - key_cache, value_cache = paged_attn.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if key is not None and value is not None: - # Reshape the input keys and values and store them in the - # cache. If kv_cache is not provided, the new key and value - # tensors are not cached. This happens during the initial - # memory profiling run. - paged_attn.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping - if self.attn_type != AttentionType.ENCODER_DECODER else - attn_metadata.cross_slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.attn_type != AttentionType.ENCODER: - num_prefill_tokens = attn_metadata.num_prefill_tokens - elif self.attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, all tokens are processed in one go - num_prefill_tokens = query.shape[0] - else: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - - # For encoder-only and encoder models, - # we process all tokens at once - # For decoder and encoder-decoder, - # we may need to limit key/value to prefill tokens - if key is not None and value is not None \ - and self.attn_type not in [AttentionType.ENCODER_DECODER, - AttentionType.ENCODER_ONLY]: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - # normal attention and DECODER - if self.attn_type == AttentionType.DECODER and ( - kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = (prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - attn_metadata.seq_lens, True) - # prefix-enabled attention and ENCODER/ENCODER_DECODER - else: - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = _get_seq_len_block_table_args( - prefill_meta, self.attn_type) - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # triton attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - attn_masks = None - if self.use_triton_flash_attn: - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - seq_lens, - make_attn_mask=causal_mask) # type: ignore - - use_fp8_scales = (layer._q_scale and layer._k_scale - and layer._v_scale and layer._prob_scale - and (self.kv_cache_dtype == "fp8" - or self.force_fp8_attention)) - - full_scales = ( - layer._q_scale.item(), layer._k_scale.item(), - layer._v_scale.item(), - layer._prob_scale.item()) if use_fp8_scales else None - self.triton_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - key_seq_start_loc, - query_max_seq_len, - key_max_seq_len, - causal_mask, - self.scale, - attn_masks[0][None] - if attn_masks is not None else None, - full_scales, - output_scale, - ) - elif self.use_naive_attn: - if self.num_kv_heads != self.num_heads: - # Interleave for MQA workaround. - key = self.repeat_kv(key, self.num_queries_per_kv) - value = self.repeat_kv(value, self.num_queries_per_kv) - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - attn_metadata.seq_lens, - make_attn_mask=causal_mask) # type: ignore - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - # sdpa math backend attention - self.sdpa_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - num_prefill_tokens, - self.num_heads, - self.head_size, - self.scale, - attn_masks, - ) - else: - # upstream FA does not support an output arg, copy - output[:num_prefill_tokens] = self.fa_attn_func( - q=query, - k=key, - v=value, - cu_seqlens_q=query_seq_start_loc, - cu_seqlens_k=key_seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=key_max_seq_len, - softmax_scale=self.scale, - causal=causal_mask, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - - else: - # prefix-enabled attention - - # not applicable for encoder-only models - if self.attn_type != AttentionType.ENCODER_ONLY: - output[:num_prefill_tokens] = paged_attn.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) - # Skip decode phase for encoder-only models - if (decode_meta := attn_metadata.decode_metadata) and ( - self.attn_type != AttentionType.ENCODER_ONLY): - # Decoding run. - # Whether to use rocm custom paged attention or not - num_seqs, num_heads, head_size = decode_query.shape - block_size = value_cache.shape[3] - gqa_ratio = num_heads // self.num_kv_heads - from vllm.platforms.rocm import use_rocm_custom_paged_attention - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window, - self.kv_cache_dtype, self.alibi_slopes) - - if use_custom: - max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type - != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len) - assert max_seq_len is not None - max_num_partitions = ( - (max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) - assert _PARTITION_SIZE_ROCM % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=query.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - - query_start_loc = None - ops.paged_attention_rocm( - output[num_prefill_tokens:], - exp_sums, - max_logits, - tmp_output, - decode_query, - key_cache, - value_cache, - self.num_kv_heads, - self.scale, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - query_start_loc, - block_size, - max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - output_scale, - ) - else: - # PagedAttention does not support fused quant, manually quantize - if output_scale is None: - out_pa = output[num_prefill_tokens:] - else: - out_pa = torch.empty_like(output[num_prefill_tokens:], - dtype=query.dtype) - - out_pa[:] = paged_attn.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - decode_meta.max_decode_seq_len - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Manually perform quantization - if output_scale is not None: - out_uq = out_pa.view(-1, self.num_heads * self.head_size) - out_q = output.view(-1, self.num_heads * self.head_size) - ops.scaled_fp8_quant(out_uq, - output_scale, - output=out_q[num_prefill_tokens:]) - - # Reshape the output tensor. - return output.view(-1, self.num_heads * self.head_size) - - -def _sdpa_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - seq_lens: torch.Tensor, - num_tokens: int, - num_heads: int, - head_size: int, - scale: float, - attn_masks: Optional[List[torch.Tensor]] = None, -) -> torch.Tensor: - start = 0 - assert output.shape == (num_tokens, num_heads, head_size) - assert output.dtype == query.dtype - assert output.device == query.device - - for i, seq_len in enumerate(seq_lens): - end = start + seq_len - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH): - sub_out = torch.nn.functional.scaled_dot_product_attention( - query[:, start:end, :], - key[:, start:end, :], - value[:, start:end, :], - dropout_p=0.0, - is_causal=attn_masks is None, - attn_mask=attn_masks[i] if attn_masks else None, - scale=scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out - start = end - - return output diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py deleted file mode 100644 index fba5b5f6bca8..000000000000 --- a/vllm/attention/backends/triton_mla.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) -from vllm.attention.ops.triton_decode_attention import decode_attention_fwd - - -class TritonMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_impl_cls() -> Type["TritonMLAImpl"]: - return TritonMLAImpl - - -class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - num_kv_splits = 4 # TODO: heuristic - - # TODO(lucas) Allocate ahead of time - attn_logits = torch.empty( - ( - B, - self.num_heads, - num_kv_splits, - # NOTE(lucas) idk why the +1 is here but sglang has it so we - # just mirror that - self.kv_lora_rank + 1, - ), - dtype=torch.float32, - device=q.device, - ) - - # Add a head dim of 1 - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - - # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index b28e6a4237cb..3f15580872c7 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -338,10 +338,9 @@ def graph_capture_get_metadata_for_batch( # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -360,10 +359,9 @@ def get_graph_input_buffers( # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" self._add_additional_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py deleted file mode 100644 index 302d3d7ea903..000000000000 --- a/vllm/attention/backends/xformers.py +++ /dev/null @@ -1,805 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with xFormers and PagedAttention.""" -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type - -import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMaskWithTensorBias) - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import ( - CommonAttentionState, CommonMetadataBuilder, - get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class XFormersBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "XFORMERS" - - @staticmethod - def get_impl_cls() -> Type["XFormersImpl"]: - return XFormersImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return XFormersMetadata - - @staticmethod - def get_builder_cls() -> Type["XFormersMetadataBuilder"]: - return XFormersMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for XFormersbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # FIXME: It is for flash attn. - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] = None - - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - - # Self-attention prefill/decode metadata cache - _cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cached_decode_metadata: Optional["XFormersMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[AttentionBias]] = None - self.encoder_attn_bias: Optional[List[AttentionBias]] = None - self.cross_attn_bias: Optional[List[AttentionBias]] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - # Recover cached prefill-phase attention - # metadata structure - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - # Construct & cache prefill-phase attention metadata structure - self._cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - # Recover cached decode-phase attention - # metadata structure - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - # Construct & cache decode-phase attention metadata structure - self._cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -def _get_attn_bias( - attn_metadata: XFormersMetadata, - attn_type: str, -) -> Optional[AttentionBias]: - ''' - Extract appropriate attention bias from attention metadata - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate attention bias value given the attention type - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - return attn_metadata.attn_bias - elif attn_type == AttentionType.ENCODER: - return attn_metadata.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return attn_metadata.cross_attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _set_attn_bias( - attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]], - attn_type: str, -) -> None: - ''' - Update appropriate attention bias field of attention metadata, - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_bias: The desired attention bias value - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - attn_metadata.attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER: - attn_metadata.encoder_attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - attn_metadata.cross_attn_bias = attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): - - _metadata_cls = XFormersMetadata - - -class XFormersImpl(AttentionImpl[XFormersMetadata]): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "XFORMERS backend.") - if logits_soft_cap is not None: - logger.warning_once("XFormers does not support logits soft cap. " - "Outputs may be slightly off.") - if use_irope: - logger.warning_once( - "Using irope in XFormers is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: "XFormersMetadata", - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with xFormers and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * XFormersImpl.forward() may be invoked for both self- and cross- - attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. - * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). - Used for encoder branch of encoder-decoder models. - * ENCODER_ONLY: no kv_caching, uses the normal attention - attributes (seq_lens/seq_lens_tensor/max_seq_len). - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - - Args: - layer: Attention layer instance. - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size * num_kv_heads * head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Optional output tensor. - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for XFormersImpl") - - attn_type = self.attn_type - # Check that appropriate attention metadata attributes are - # selected for the desired attention type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - # Self-attention vs. cross-attention will impact - # which KV cache memory-mapping & which - # seqlen datastructures we utilize - - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): - # KV-cache during decoder-self- or - # encoder-decoder-cross-attention, but not - # during encoder attention. - # - # Even if there are no new key/value pairs to cache, - # we still need to break out key_cache and value_cache - # i.e. for later use by paged attention - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if (key is not None) and (value is not None): - - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - PagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - if key is not None and value is not None: - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # normal attention. - # block tables are empty if the prompt does not have a cached - # prefix. - out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_query_tokens].shape - output[:num_prefill_query_tokens] = out - else: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have prefix attention.") - - assert prefill_meta.query_start_loc is not None - assert prefill_meta.max_query_len is not None - - # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - out = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window, - layer._k_scale, - layer._v_scale, - ) - assert output[:num_prefill_query_tokens].shape == out.shape - output[:num_prefill_query_tokens] = out - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") - - ( - seq_lens_arg, - max_seq_len_arg, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - - output[num_prefill_query_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Reshape the output tensor. - return output.view(-1, self.num_heads * self.head_size) - - def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: str = AttentionType.DECODER, - ) -> torch.Tensor: - """Attention for 1D query of multiple prompts. Multiple prompt - tokens are flattened in to `query` input. - - See https://facebookresearch.github.io/xformers/components/ops.html - for API spec. - - Args: - query: shape = [num_prefill_tokens, num_heads, head_size] - key: shape = [num_prefill_tokens, num_kv_heads, head_size] - value: shape = [num_prefill_tokens, num_kv_heads, head_size] - attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally - """ - - original_query = query - if self.num_kv_heads != self.num_heads: - # GQA/MQA requires the shape [B, M, G, H, K]. - # Note that the output also has the same shape (which is different - # from a spec from the doc). - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - attn_bias = _get_attn_bias(attn_metadata, attn_type) - if attn_bias is None: - if self.alibi_slopes is None: - - # Cross attention block of decoder branch of encoder-decoder - # model uses seq_lens for dec / encoder_seq_lens for enc - if (attn_type == AttentionType.ENCODER_DECODER): - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens is not None - - # Cross-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, - attn_metadata.encoder_seq_lens, - device=query.device) - - # Encoder branch of encoder-decoder model uses - # attn_metadata.encoder_seq_lens - elif attn_type == AttentionType.ENCODER: - - assert attn_metadata.encoder_seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens, device=query.device) - - # Self-attention block of encoder-only model just - # uses the seq_lens directly. - elif attn_type == AttentionType.ENCODER_ONLY: - assert attn_metadata.seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - - # Self-attention block of decoder branch just - # uses the seq_lens directly - elif attn_type == AttentionType.DECODER: - assert attn_metadata.seq_lens is not None - - # Decoder self-attention mask is causal - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - else: - raise ValueError("Unknown AttentionType: %s", attn_type) - - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - attn_bias = [attn_bias] - else: - assert attn_type == AttentionType.DECODER - assert attn_metadata.seq_lens is not None - attn_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) - - _set_attn_bias(attn_metadata, attn_bias, attn_type) - - # No alibi slopes. - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - # Add the batch dimension. - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias[0], - p=0.0, - scale=self.scale) - return out.view_as(original_query) - - # Attention with alibi slopes. - # FIXME(woosuk): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - assert attn_metadata.seq_lens is not None - output = torch.empty_like(original_query) - start = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens): - end = start + seq_len - out = xops.memory_efficient_attention_forward( - query[None, start:end], - key[None, start:end], - value[None, start:end], - attn_bias=attn_bias[i], - p=0.0, - scale=self.scale) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.view_as(original_query[start:end])) - start += seq_len - return output - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[AttentionBias]: - attn_biases: List[AttentionBias] = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - # Calculate a matrix where each element represents ith element- jth - # element. - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) - - return attn_biases diff --git a/vllm/config/model.py b/vllm/config/model.py index 95fe52883db0..33e5d3ea04a4 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -32,8 +32,7 @@ from vllm.transformers_utils.runai_utils import (ObjectStorageModel, is_runai_obj_uri) from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, - LazyLoader, common_broadcastable_dtype) +from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig @@ -1103,10 +1102,6 @@ def verify_dual_chunk_attention_config( self.hf_config.dual_chunk_attention_config[ "sparse_attention_enabled"] = True - if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL: - raise ValueError("please set VLLM_ATTENTION_BACKEND to " - f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") - def verify_with_parallel_config( self, parallel_config: ParallelConfig, diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 911d77ba36fa..efa4c9abf47f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -44,7 +44,7 @@ def get_model_args(self, model_executable: torch.nn.Module): # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading # to a kv_cache shape of [2, num_blks, blk_size, # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. + # For more details, see vllm/v1/attention/backends/mla/common.py. if self.is_deepseek_mla and self.use_mla_opt: head_size = model_config.kv_lora_rank + \ model_config.qk_rope_head_dim diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e00260caa39..b09d43f70558 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,8 +44,8 @@ from vllm.transformers_utils.config import (get_model_path, is_interleaved, maybe_override_with_speculators) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, - GiB_bytes, get_ip, is_in_ray_actor) +from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip, + is_in_ray_actor) from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -1163,17 +1163,6 @@ def create_engine_config( self._set_default_args_v0(model_config) assert self.enable_chunked_prefill is not None - if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: - assert self.enforce_eager, ( - "Cuda graph is not supported with DualChunkFlashAttention. " - "To run the model in eager mode, set 'enforce_eager=True' " - "or use '--enforce-eager' in the CLI.") - assert current_platform.is_cuda(), ( - "DualChunkFlashAttention is only supported on CUDA platform.") - assert not use_v1, ( - "DualChunkFlashAttention is not supported on V1 engine. " - "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") - sliding_window: Optional[int] = None if not is_interleaved(model_config.hf_text_config): # Only set CacheConfig.sliding_window if the model is all sliding diff --git a/vllm/envs.py b/vllm/envs.py index 3991a789d80f..cbd1d5474e60 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -529,7 +529,6 @@ def get_vllm_port() -> Optional[int]: # - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "FLASH_ATTN": use FlashAttention # - "XFORMERS": use XFormers - # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index c926e17a2c19..7f376b70a7ae 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -53,13 +53,18 @@ class Mamba2Metadata: def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: """Returns the appropriate metadata classes for the current platform.""" if current_platform.is_rocm(): - from vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata) - return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) - elif current_platform.is_cuda(): - from vllm.attention.backends.flash_attn import FlashAttentionMetadata - from vllm.attention.backends.xformers import XFormersMetadata - return (FlashAttentionMetadata, XFormersMetadata, + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionMetadata) + return (AiterFlashAttentionMetadata, TritonAttentionMetadata, + PlaceholderAttentionMetadata) + if current_platform.is_cuda(): + from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata) + from vllm.v1.attention.backends.xformers import ( + XFormersAttentionMetadata) + return (FlashAttentionMetadata, XFormersAttentionMetadata, PlaceholderAttentionMetadata) raise ValueError( f"Unsupported platform for Mamba2: {current_platform.device_type}") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a99a6679a569..415d36c681d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module): Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py """ def __init__( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c263e2afe83b..05f129f513a0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -226,8 +226,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: - # TODO(lucas): refactor to be more concise - # we should probably consider factoring out V1 here + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them.") from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla @@ -246,35 +248,17 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) - def _get_version(name, import_suffix) -> str: - if use_v1: - logger.info_once(f"Using {name} backend on V1 engine.") - return f"vllm.v1.attention.backends.mla.{import_suffix}" - else: - logger.info_once(f"Using {name} backend.") - return f"vllm.attention.backends.{import_suffix}" - if use_cutlassmla: - if use_v1: - logger.info_once("Using Cutlass MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "cutlass_mla.CutlassMLABackend") - else: - logger.warning( - "Cutlass MLA backend is only supported on V1 engine") + logger.info_once("Using Cutlass MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "cutlass_mla.CutlassMLABackend") if use_flashinfermla: - if use_v1: - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) - set_kv_cache_layout("HND") - logger.info_once( - "Using FlashInfer MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashinfer_mla.FlashInferMLABackend") - else: - logger.warning( - "FlashInfer MLA backend is only supported on V1 engine" - ) + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + logger.info_once("Using FlashInfer MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashinfer_mla.FlashInferMLABackend") if use_flashmla: if block_size != 64: logger.warning( @@ -282,20 +266,18 @@ def _get_version(name, import_suffix) -> str: " (currently only supports block size 64).", block_size) else: - return _get_version("FlashMLA", "flashmla.FlashMLABackend") - if use_flashattn: - if use_v1: - logger.info_once( - "Using FlashAttention MLA backend on V1 engine.") + logger.info_once("Using FlashMLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." - "flashattn_mla.FlashAttnMLABackend") - else: - logger.warning( - "FlashAttention MLA backend is only supported on V1 " - "engine.") + "flashmla.FlashMLABackend") + if use_flashattn: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashattn_mla.FlashAttnMLABackend") if use_triton: - return _get_version("Triton MLA", - "triton_mla.TritonMLABackend") + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 @@ -382,78 +364,9 @@ def _get_version(name, import_suffix) -> str: ) return FLEX_ATTENTION_V1 - # Backends for V0 engine - if selected_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: - logger.info("Using DualChunkFlashAttention backend.") - return ("vllm.attention.backends.dual_chunk_flash_attn." - "DualChunkFlashAttentionBackend") - elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: - logger.info("Using DifferentialFlashAttention backend.") - return ("vllm.attention.backends.differential_flash_attn." - "DifferentialFlashAttentionBackend") - elif selected_backend == _Backend.FLASH_ATTN: - pass - elif selected_backend: - raise ValueError( - f"Invalid attention backend for {cls.device_name}, " - f"with use_v1: {use_v1} use_mla: {use_mla}") - - target_backend = _Backend.FLASH_ATTN - if not cls.has_device_capability(80): - # Volta and Turing NVIDIA GPUs. - logger.info( - "Cannot use FlashAttention-2 backend for Volta and Turing " - "GPUs.") - target_backend = _Backend.XFORMERS - elif dtype not in (torch.float16, torch.bfloat16): - logger.info( - "Cannot use FlashAttention-2 backend for dtype other than " - "torch.float16 or torch.bfloat16.") - target_backend = _Backend.XFORMERS - elif block_size % 16 != 0: - logger.info( - "Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - target_backend = _Backend.XFORMERS - - # FlashAttn is valid for the model, checking if the package is - # installed. - if target_backend == _Backend.FLASH_ATTN: - try: - import vllm.vllm_flash_attn # noqa: F401 - from vllm.attention.backends.flash_attn import ( # noqa: F401 - FlashAttentionBackend, flash_attn_supports_fp8) - - supported_sizes = \ - FlashAttentionBackend.get_supported_head_sizes() - if head_size not in supported_sizes: - logger.info( - "Cannot use FlashAttention-2 backend for head size %d.", - head_size) - target_backend = _Backend.XFORMERS - fp8_kv_cache = (kv_cache_dtype is not None - and kv_cache_dtype.startswith("fp8")) - if (fp8_kv_cache and not flash_attn_supports_fp8()): - logger.info( - "Cannot use FlashAttention backend for FP8 KV cache.") - target_backend = _Backend.XFORMERS - except ImportError: - logger.info( - "Cannot use FlashAttention-2 backend because the " - "vllm.vllm_flash_attn package is not found. " - "Make sure that vllm_flash_attn was built and installed " - "(on by default).") - target_backend = _Backend.XFORMERS - - if target_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - - logger.info("Using Flash Attention backend.") - return "vllm.attention.backends.flash_attn.FlashAttentionBackend" + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend.") @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index dce2924ac7a9..9470434aa428 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -191,6 +191,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them.") + from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) @@ -201,39 +206,24 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - if use_v1: - logger.info_once( - "Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA \ - or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}.") + if selected_backend in (_Backend.ROCM_AITER_MLA, + _Backend.ROCM_AITER_MLA_VLLM_V1): if block_size == 1: - if use_v1: - logger.info("Using AITER MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - logger.info("Using AITER MLA backend") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}." - "(currently only supports block size 1)") - else: + logger.info("Using AITER MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 raise ValueError( f" The selected backend, {selected_backend.name}," - f"is not MLA type while requested for MLA backend.") - - if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: - selected_backend = _Backend.ROCM_FLASH + f"does not support block size {block_size}." + "(currently only supports block size 1)") + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend.") if envs.VLLM_USE_V1: if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ @@ -245,14 +235,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "triton_attn.TritonAttentionBackend") - if selected_backend == _Backend.ROCM_FLASH: - if not cls.has_device_capability(90): - # not Instinct series GPUs. - logger.info("flash_attn is not supported on NAVI GPUs.") - else: - logger.info("%s is not supported in AMD GPUs.", selected_backend) - logger.info("Using ROCmFlashAttention backend.") - return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend.") @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 968bba664f0a..834ec9b1d30b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -157,10 +157,8 @@ # register, corresponding to possible backends STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" MB_bytes = 1_000_000