diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 98700ff73fd1..99793ab95581 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -236,6 +236,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -327,6 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -336,6 +338,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 8b571f95c5ec..3dd83a71432f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -62,6 +62,9 @@ def initialize_kv_cache(runner: GPUModelRunner): block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) runner.initialize_attn_backend(kv_cache_config) @@ -838,3 +841,152 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): conv_blocks_constant) assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant) + + +def test_hybrid_block_table_initialization(): + """Test hybrid block table with different kernel and kvcache_manager block + sizes.""" + from vllm.v1.worker.block_table import BlockTable + + # Test configuration: kvcache_manager block size = 32, + # kernel block size = 16 + block_size = 32 + kernel_block_sizes = [16] + max_num_reqs = 10 + max_num_blocks_per_req = 20 + max_num_batched_tokens = 512 + + block_table = BlockTable(block_size=block_size, + max_num_reqs=max_num_reqs, + max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=False, + device=torch.device(DEVICE), + kernel_block_size=kernel_block_sizes[0]) + + # Verify hybrid block configuration + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_block_sizes[0] + assert block_table.blocks_per_phys_block == ( + block_size // kernel_block_sizes[0]) # Changed to use first element + + # Test block table conversion logic + # One kvcache_manager block should map to multiple kernel blocks + kvcache_manager_blocks = [0, 1, 2] + + # Verify that kvcache_manager blocks can be converted to kernel blocks + # and that block table operations work correctly. + req_index = 0 + block_table.append_row(kvcache_manager_blocks, req_index) + # Get expected kernel blocks from the implementation for verification. 
+ expected_kernel_blocks = block_table._map_to_kernel_blocks( + np.array(kvcache_manager_blocks)) + # Verify block table state + assert block_table.num_blocks_per_row[req_index] == len( + expected_kernel_blocks) + assert np.array_equal( + block_table.block_table.np[req_index, :len(expected_kernel_blocks)], + expected_kernel_blocks) + + +def test_input_batch_with_kernel_block_sizes(): + """Test InputBatch initialization with kernel_block_sizes parameter.""" + max_num_reqs = 10 + max_model_len = 512 + max_num_batched_tokens = 512 + device = torch.device(DEVICE) + pin_memory = False + vocab_size = 50272 + + # Test with different kernel block sizes + block_sizes = [32, 64] + kernel_block_sizes = [16, 32] + + input_batch = InputBatch(max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + device=device, + pin_memory=pin_memory, + vocab_size=vocab_size, + block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes) + + # Verify that block tables were created with kernel block sizes + assert len(input_batch.block_table.block_tables) == len(block_sizes) + + for i, (phys_size, + kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)): + block_table = input_batch.block_table.block_tables[i] + if phys_size != kernel_size: + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_size + else: + assert block_table.use_hybrid_blocks is False + assert block_table.block_size == kernel_size + + +def test_hybrid_cache_integration(model_runner, dist_init): + """Test hybrid cache architecture integration with GPUModelRunner.""" + # Create a new model runner with hybrid cache configuration + vllm_config = get_vllm_config() + + # Configure hybrid cache with different kvcache_manager block size + vllm_config.cache_config.block_size = 32 + + model_config = vllm_config.model_config + num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = model_config.get_head_size() + vllm_config.compilation_config.static_forward_context[ + "layer.0"] = Attention(num_heads, head_size, 0.1) + + runner = GPUModelRunner(vllm_config, DEVICE) + + # Initialize KV cache with configuration + attn_spec = FullAttentionSpec( + block_size=16, # Use kernel block size directly + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + use_mla=False, + ) + tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS + kv_cache_config = KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[ + KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) + ], + ) + runner.kv_cache_config = kv_cache_config + + # Initialize input batch with kernel block sizes + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], + kernel_block_sizes=[16]) # Use kernel block size + + runner.initialize_attn_backend(kv_cache_config) + + # Verify hybrid block table configuration + block_table = runner.input_batch.block_table.block_tables[0] + assert block_table.block_size == ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size) + + # Test request processing with 
hybrid blocks + req_id = "hybrid_req_0" + scheduler_output = _schedule_new_request(req_id) + + # Update states should work with hybrid blocks + runner._update_states(scheduler_output) + assert _is_req_scheduled(runner, req_id) + assert _is_req_state_block_table_match(runner, req_id) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 629e42a8b902..6414cf2fff88 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar +from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar, + Union) import torch @@ -24,6 +25,13 @@ class AttentionType: """Attention between dec. Q and enc. K/V for encoder-decoder.""" +class MultipleOf: + base: int + + def __init__(self, base: int): + self.base = base + + class AttentionBackend(ABC): """Abstract class for attention backends.""" # For some attention backends, we allocate an output tensor before @@ -54,6 +62,10 @@ def get_impl_cls() -> Type["AttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError + @classmethod + def get_supported_block_size(cls) -> list[Union[int, MultipleOf]]: + return cls.get_impl_cls().get_supported_block_size() + @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -154,6 +166,11 @@ def __init__( ) -> None: raise NotImplementedError + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + # TODO: implement this function for all backends. + return [MultipleOf(1)] + @abstractmethod def forward( self, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cab85ea347f4..894700d942d4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -361,16 +361,34 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + # Model may be marked as is_hybrid + # but mamba is skipped via config, + # return directly + if mamba_page_size == 0: + return + + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: 128-byte alignment + # * Other MLA backends: 64-byte alignment + if model_config.use_mla: + use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 + else: + kernel_block_alignment_size = 16 + + # Calculate minimum attention block size that satisfies both: + # 1. Backend alignment requirements (kernel_block_alignment_size) + # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, + kernel_block_alignment_size * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. 
+ # if (cache_config.block_size is None or cache_config.block_size < attn_block_size): cache_config.block_size = attn_block_size diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 58ba08101bc9..88cad70740d4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -128,7 +128,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # and doesn't need to be reinitialized here + if model_config is not None and model_config.use_mla \ + and cache_config.block_size is not None: # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the @@ -159,17 +163,18 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: from vllm.attention.ops.flashmla import is_flashmla_supported if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + and cache_config.block_size % 64 != 0: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - if use_cutlass_mla and cache_config.block_size != 128: + if use_cutlass_mla and cache_config.block_size % 128 != 0: cache_config.block_size = 128 logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") - if use_flashinfer_mla and cache_config.block_size not in [32, 64]: + if use_flashinfer_mla and cache_config.block_size != 32 and \ + cache_config.block_size % 64 != 0: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashInferMLA " @@ -237,10 +242,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) - and block_size == 128) + and block_size % 128 == 0) use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None and cls.is_device_capability(100) - and block_size in [32, 64]) + selected_backend is None and cls.is_device_capability(100) and + (block_size == 32 or block_size % 64 == 0)) use_flashmla = selected_backend == _Backend.FLASHMLA or ( selected_backend is None and is_flashmla_supported()[0]) use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( @@ -260,7 +265,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return ("vllm.v1.attention.backends.mla." 
"flashinfer_mla.FlashInferMLABackend") if use_flashmla: - if block_size != 64: + if block_size % 64 != 0: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f284847dd9e9..c9f734646ed3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import numpy as np import torch @@ -10,6 +10,7 @@ from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, is_quantized_kv_cache) from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -47,6 +48,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index a4bf3635bbca..626e9af6b59d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,7 +17,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) + AttentionType, MultipleOf) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -154,6 +154,10 @@ def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(1)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index d44e20f2cb6b..c67b71615bce 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -8,6 +8,7 @@ import vllm._custom_ops as ops from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -39,6 +40,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [128] + class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index ac0524ba088b..f2079d3f19f5 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,7 +6,8 @@ import torch -from vllm.attention.backends.abstract import AttentionLayer, 
AttentionType +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + MultipleOf) from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -41,6 +42,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [64] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 96f8e92a2039..82c09472646a 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -327,6 +328,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 1d4ab4c96728..719f61d32817 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,12 +4,13 @@ import ast from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -39,6 +40,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index fc5ecf6ed3b6..2fd28f5d8f3b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """High-Performance Triton-only Attention layer.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash) from 
vllm.attention.ops.triton_unified_attention import unified_attention @@ -143,6 +144,10 @@ class TritonAttentionBackend(AttentionBackend): def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16, torch.float32] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: # Triton Attention supports any head size above 32 diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index f739e6832274..f799a2361f5b 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,12 +3,13 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -77,6 +78,10 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 82b6d1b514d5..b4139930d175 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -23,21 +23,63 @@ def __init__( max_num_batched_tokens: int, pin_memory: bool, device: torch.device, + kernel_block_size: int, ): - self.block_size = block_size + """ + Args: + block_size: Block size used for KV cache memory allocation + max_num_reqs: Maximum number of concurrent requests supported. + max_num_blocks_per_req: Maximum number of blocks per request. + max_num_batched_tokens: Maximum number of tokens in a batch. + pin_memory: Whether to pin memory for faster GPU transfers. + device: Target device for the block table. + kernel_block_size: The block_size of underlying attention kernel. + Will be the same as `block_size` if `block_size` is supported + by the attention kernel. 
+ """ self.max_num_reqs = max_num_reqs - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device - self.block_table = self._make_buffer(max_num_reqs, - max_num_blocks_per_req, + if kernel_block_size == block_size: + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping + self.block_size = block_size + self.blocks_per_phys_block = 1 + self.use_hybrid_blocks = False + else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks + if block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"kv_manager_block_size size {block_size} evenly") + + self.block_size = kernel_block_size + self.blocks_per_phys_block = (block_size // kernel_block_size) + self.use_hybrid_blocks = True + + self.max_num_blocks_per_req = max_num_blocks_per_req * \ + self.blocks_per_phys_block + + self.block_table = self._make_buffer(self.max_num_reqs, + self.max_num_blocks_per_req, dtype=torch.int32) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.slot_mapping = self._make_buffer(self.max_num_batched_tokens, dtype=torch.int64) + + if self.use_hybrid_blocks: + self._bias_array = np.arange(0, + self.blocks_per_phys_block).reshape( + 1, -1) + else: + self._bias_array = None + try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -53,6 +95,10 @@ def append_row( ) -> None: if not block_ids: return + + if self.use_hybrid_blocks: + block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks @@ -66,7 +112,6 @@ def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] block_table_np = self.block_table.np block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks] - self.num_blocks_per_row[tgt] = num_blocks def swap_row(self, src: int, tgt: int) -> None: src_tgt, tgt_src = [src, tgt], [tgt, src] @@ -89,8 +134,10 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size + block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // virtual_block_size) + block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. @@ -106,6 +153,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray, else: block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // self.block_size) + block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size np.add(block_numbers * self.block_size, @@ -122,6 +170,30 @@ def clear(self) -> None: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) + def _map_to_kernel_blocks(self, + kv_manager_block_ids: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_id IDs to kernel block IDs. 
+ + Example: + # kv_manager_block_ids: 32 tokens, + # Kernel block size: 16 tokens + # blocks_per_phys_block = 2 + >>> kv_manager_block_ids = np.array([0, 1, 2]) + >>> Result: [0, 1, 2, 3, 4, 5] + + # Each kv_manager_block_id maps to 2 kernel block id: + # kv_manager_block_id 0 → kernel block id [0, 1] + # kv_manager_block_id 1 → kernel block id [2, 3] + # kv_manager_block_id 2 → kernel block id [4, 5] + """ + if not self.use_hybrid_blocks: + return kv_manager_block_ids + + kernel_block_ids = kv_manager_block_ids.reshape( + -1, 1) * self.blocks_per_phys_block + self._bias_array + + return kernel_block_ids.reshape(-1) + def get_device_tensor(self, num_reqs: int) -> torch.Tensor: """Returns the device tensor of the block table.""" return self.block_table.gpu[:num_reqs] @@ -152,6 +224,7 @@ def __init__(self, pin_memory: bool, device: torch.device, block_sizes: list[int], + kernel_block_sizes: list[int], num_speculative_tokens: int = 0) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -163,12 +236,18 @@ def __init__(self, # DCP might not be initialized in testing dcp_world_size = 1 + if len(kernel_block_sizes) != len(block_sizes): + raise ValueError( + f"kernel_block_sizes length ({len(kernel_block_sizes)}) " + f"must match block_sizes length ({len(block_sizes)})") + self.block_tables = [ BlockTable( block_size, max_num_reqs, max(cdiv(max_model_len, block_size * dcp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device) for block_size in block_sizes + pin_memory, device, kernel_block_size) for block_size, + kernel_block_size in zip(block_sizes, kernel_block_sizes) ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 67fb9864b19c..7100cfdad657 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -87,6 +87,7 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, is_pooling_model: bool = False, @@ -143,8 +144,8 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, - num_speculative_tokens=num_speculative_tokens, - ) + kernel_block_sizes=kernel_block_sizes, + num_speculative_tokens=num_speculative_tokens) # Sampling-related. 
self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1bae0d4ce4d1..380fc4dfb9f2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,7 +19,7 @@ import vllm.envs as envs from vllm.attention import Attention, AttentionType -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, MultipleOf from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper @@ -318,6 +318,7 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, self.device, self.pin_memory, @@ -3740,6 +3741,100 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i + def _find_compatible_block_sizes(self, + kv_manager_block_size: int, + backend_cls: type[AttentionBackend], + return_all: bool = False) -> list[int]: + """ + Find compatible block sizes for a backend. + + Args: + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + return_all: Return all compatible sizes if True, max size if False + + Returns: + Compatible block size(s) based on return_all parameter + + Raises: + ValueError: If no compatible block size found + """ + supported_block_size = backend_cls.get_supported_block_size() + compatible_sizes = [] + + for block_size in supported_block_size: + if isinstance(block_size, int): + if kv_manager_block_size % block_size == 0: + compatible_sizes.append(block_size) + elif (isinstance(block_size, MultipleOf) + and kv_manager_block_size % block_size.base == 0): + compatible_sizes.append(kv_manager_block_size) + + if not compatible_sizes: + raise ValueError( + f"No compatible block size for {kv_manager_block_size}") + + return compatible_sizes if return_all else [max(compatible_sizes)] + + def _get_all_compatible_block_sizes( + self, kv_manager_block_size: int, + backend_cls: type[AttentionBackend]) -> list[int]: + """ + Get all compatible block sizes for a backend. + + Args: + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + + Returns: + List of all compatible block sizes in descending order + + Raises: + ValueError: If no compatible block size found + """ + compatible_sizes = self._find_compatible_block_sizes( + kv_manager_block_size, backend_cls, return_all=True) + + return sorted(list(set(compatible_sizes)), reverse=True) + + def _select_common_block_size(self, kv_manager_block_size: int, + attn_groups: list[AttentionGroup]) -> int: + """ + Select common block size for all backends. 
+ + Args: + kv_manager_block_size: Physical block size of KV cache + attn_groups: List of attention groups + + Returns: + Block size supported by all backends, + prioritizing cache_config.block_size + + Raises: + ValueError: If no common block size found + """ + all_backend_supports = [] + + for attn_group in attn_groups: + supported_sizes = self._get_all_compatible_block_sizes( + kv_manager_block_size, attn_group.backend) + all_backend_supports.append(set(supported_sizes)) + + common_supported_sizes = set.intersection(*all_backend_supports) + + if not common_supported_sizes: + error_msg = (f"No common block size for {kv_manager_block_size}. ") + for i, attn_group in enumerate(attn_groups): + supported = all_backend_supports[i] + error_msg += (f"Backend {attn_group.backend} supports: " + f"{sorted(supported)}. ") + raise ValueError(error_msg) + + if self.cache_config.block_size in common_supported_sizes: + return self.cache_config.block_size + + return max(common_supported_sizes) + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -3753,8 +3848,16 @@ def may_reinitialize_input_batch(self, block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec) ] - if block_sizes != [self.cache_config.block_size]: + + # Generate kernel_block_sizes that matches each block_size + kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + + if block_sizes != [ + self.cache_config.block_size + ] or kernel_block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 @@ -3767,6 +3870,7 @@ def may_reinitialize_input_batch(self, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, is_pooling_model=self.is_pooling_model, @@ -3814,6 +3918,46 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups + def _prepare_kernel_block_sizes( + self, kv_cache_config: KVCacheConfig) -> list[int]: + """ + Generate kernel_block_sizes that matches each block_size. + + For attention backends that support virtual block splitting, + use the supported block sizes from the backend. + For other backends (like Mamba), use the same block size (no splitting). + + Args: + kv_cache_config: The KV cache configuration. + + Returns: + list[int]: List of kernel block sizes for each cache group. + """ + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # all backends in the group. 
+ attn_groups = self.attn_groups[kv_cache_group_id] + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size + selected_kernel_size = self._select_common_block_size( + kv_manager_block_size, attn_groups) + kernel_block_sizes.append(selected_kernel_size) + elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append( + kv_cache_group.kv_cache_spec.block_size) + else: + raise NotImplementedError( + f"unknown kv cache spec {kv_cache_group.kv_cache_spec}") + return kernel_block_sizes + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -3844,8 +3988,17 @@ def _reshape_kv_cache_tensors( kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True + kv_manager_block_size = kv_cache_spec.block_size + logical_kernel_size_list = \ + self._find_compatible_block_sizes( + kv_manager_block_size, attn_backend, return_all=False) + logical_kernel_size = logical_kernel_size_list[0] + num_blocks_per_phys_block = (kv_manager_block_size // + logical_kernel_size) + logical_num_blocks = num_blocks * num_blocks_per_phys_block + kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, + logical_num_blocks, logical_kernel_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: @@ -3997,10 +4150,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 4cd0ac352de0..42273b0df498 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -20,14 +20,15 @@ class InputBatch: def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -70,6 +71,7 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, ) # Sampling-related. 
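The hybrid-block expansion that the new tests and the InputBatch/BlockTable plumbing above rely on is easiest to see in isolation. Below is a minimal, self-contained sketch of the manager-block to kernel-block ID mapping performed by BlockTable._map_to_kernel_blocks in vllm/v1/worker/block_table.py; the free function and its parameters are illustrative stand-ins for the method and its instance state (block_size, blocks_per_phys_block, _bias_array), assuming 32-token KV-cache-manager blocks and 16-token kernel blocks.

import numpy as np


def map_to_kernel_blocks(kv_manager_block_ids: np.ndarray,
                         kv_manager_block_size: int,
                         kernel_block_size: int) -> np.ndarray:
    # Expand KV-cache-manager block IDs into kernel block IDs.
    if kv_manager_block_size % kernel_block_size != 0:
        raise ValueError("kernel block size must evenly divide the "
                         "kv-cache-manager block size")
    blocks_per_phys_block = kv_manager_block_size // kernel_block_size
    if blocks_per_phys_block == 1:
        # Non-hybrid case: the mapping is the identity.
        return kv_manager_block_ids
    # Manager block i covers kernel blocks
    # [i * ratio, i * ratio + 1, ..., (i + 1) * ratio - 1].
    bias = np.arange(blocks_per_phys_block).reshape(1, -1)
    kernel_ids = (kv_manager_block_ids.reshape(-1, 1) * blocks_per_phys_block
                  + bias)
    return kernel_ids.reshape(-1)


# 32-token manager blocks with 16-token kernel blocks (ratio 2):
print(map_to_kernel_blocks(np.array([0, 1, 2]), 32, 16))  # -> [0 1 2 3 4 5]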
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2405f978ca73..fb38757a9bb7 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -225,6 +225,7 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + kernel_block_sizes=[self.cache_config.block_size], ) # Cached torch/numpy tensor @@ -1669,6 +1670,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) # Verify dtype compatibility between block_table_cpu and input_batch assert self.block_table_cpu.dtype == self.input_batch.block_table[
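To make the block-size negotiation concrete, here is a self-contained sketch of how a kernel block size could be chosen from the sizes each backend declares via get_supported_block_size(), mirroring the intent of _find_compatible_block_sizes and _select_common_block_size in gpu_model_runner.py. The MultipleOf class below is a local stand-in for the marker added to vllm/attention/backends/abstract.py, and the helper names compatible_kernel_sizes and select_common_kernel_size are illustrative, not part of the patch.

from dataclasses import dataclass
from typing import Union


@dataclass
class MultipleOf:
    # Local stand-in for the marker class added in abstract.py.
    base: int


def compatible_kernel_sizes(
        kv_manager_block_size: int,
        supported: list[Union[int, MultipleOf]]) -> set[int]:
    # Kernel block sizes one backend can use with this manager block size.
    sizes: set[int] = set()
    for entry in supported:
        if isinstance(entry, int):
            # A fixed kernel size is usable only if it divides the
            # manager block size evenly.
            if kv_manager_block_size % entry == 0:
                sizes.add(entry)
        elif kv_manager_block_size % entry.base == 0:
            # MultipleOf(base): the manager block size itself can be used
            # as the kernel block size.
            sizes.add(kv_manager_block_size)
    if not sizes:
        raise ValueError(
            f"No compatible kernel block size for {kv_manager_block_size}")
    return sizes


def select_common_kernel_size(kv_manager_block_size: int,
                              backends: list[list[Union[int, MultipleOf]]],
                              preferred: int) -> int:
    # Pick one size every backend supports, preferring `preferred`
    # (cache_config.block_size), otherwise the largest common size.
    common = set.intersection(
        *(compatible_kernel_sizes(kv_manager_block_size, b)
          for b in backends))
    if not common:
        raise ValueError("No common kernel block size across backends")
    return preferred if preferred in common else max(common)


# A FlashMLA-style backend declares [64]; with 128-token manager blocks the
# kernel runs on 64-token blocks, and each manager block maps to 2 of them.
print(select_common_kernel_size(128, [[64]], preferred=128))  # -> 64

# FlashAttention-style (MultipleOf(16)) and FlashInfer-style (MultipleOf(1))
# backends both accept a 32-token manager block directly.
print(select_common_kernel_size(32, [[MultipleOf(16)], [MultipleOf(1)]],
                                preferred=32))  # -> 32

The same rounding idea drives the Mamba/attention sizing in vllm/model_executor/models/config.py: with a hypothetical kernel_block_alignment_size of 64 and a mamba_page_size equal to 10x attn_page_size_1_token, attn_block_size = 64 * cdiv(10 * attn_page_size_1_token, 64 * attn_page_size_1_token) = 64, the smallest aligned block size whose attention page is at least as large as the Mamba page.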