From 05ef7f1e8f78951e1b54bc56bb08dd31489c3a68 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 07:44:19 +0000 Subject: [PATCH 01/37] [Hybrid]: Decouple Logical Block Size from Physical Page Size Signed-off-by: lizhiyuan --- tests/v1/worker/test_gpu_model_runner.py | 156 ++++++++++++++++++ vllm/attention/backends/abstract.py | 10 ++ vllm/model_executor/models/config.py | 22 ++- vllm/platforms/cuda.py | 6 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 4 + vllm/v1/attention/backends/mla/flashmla.py | 4 + vllm/v1/worker/block_table.py | 110 +++++++++++- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 37 ++++- 9 files changed, 334 insertions(+), 18 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 6d99029e404e..53a222f69c13 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -841,3 +841,159 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): conv_blocks_constant) assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant) + + +def test_hybrid_block_table_initialization(): + """Test hybrid block table with different kernel and physical block + sizes.""" + from vllm.v1.worker.block_table import BlockTable + + # Test configuration: physical block size = 32, kernel block size = 16 + physical_block_size = 32 + kernel_block_sizes = [16] + max_num_reqs = 10 + max_num_blocks_per_req = 20 + max_num_batched_tokens = 512 + + block_table = BlockTable(block_size=physical_block_size, + max_num_reqs=max_num_reqs, + max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=False, + device=torch.device(DEVICE), + kernel_sizes=kernel_block_sizes) + + # Verify hybrid block configuration + assert block_table.use_hybrid_blocks is True + assert block_table.physical_block_size == physical_block_size + assert block_table.logical_block_size == kernel_block_sizes[ + 0] # Changed to use first element + assert block_table.blocks_per_phys_block == ( + physical_block_size // kernel_block_sizes[0] + ) # Changed to use first element + + # Test block table conversion logic + # One physical block should map to multiple logical blocks + physical_blocks = [0, 1, 2] + + # Verify that physical blocks can be converted to logical blocks + # and that block table operations work correctly. + req_index = 0 + block_table.append_row(physical_blocks, req_index) + # Get expected logical blocks from the implementation for verification. + expected_logical_blocks = block_table._convert_physical_to_logical_blocks( + np.array(physical_blocks)) + # Verify block table state + assert block_table.num_blocks_per_row[req_index] == len( + expected_logical_blocks) + assert np.array_equal( + block_table.block_table_np[req_index, :len(expected_logical_blocks)], + expected_logical_blocks) + + +def test_input_batch_with_kernel_block_sizes(): + """Test InputBatch initialization with kernel_block_sizes parameter.""" + max_num_reqs = 10 + max_model_len = 512 + max_num_batched_tokens = 512 + device = torch.device(DEVICE) + pin_memory = False + vocab_size = 50272 + + # Test with different kernel block sizes + physical_block_sizes = [32, 64] + kernel_block_sizes = [[16], [32]] + + input_batch = InputBatch(max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + device=device, + pin_memory=pin_memory, + vocab_size=vocab_size, + block_sizes=physical_block_sizes, + kernel_block_sizes=kernel_block_sizes) + + # Verify that block tables were created with kernel block sizes + assert len( + input_batch.block_table.block_tables) == len(physical_block_sizes) + + for i, (phys_size, kernel_size_list) in enumerate( + zip(physical_block_sizes, kernel_block_sizes)): + block_table = input_batch.block_table.block_tables[i] + kernel_size = kernel_size_list[0] # Use first element from list + if phys_size != kernel_size: + assert block_table.use_hybrid_blocks is True + assert block_table.physical_block_size == phys_size + assert block_table.logical_block_size == kernel_size + else: + assert block_table.use_hybrid_blocks is False + assert block_table.physical_block_size == phys_size + assert block_table.logical_block_size == phys_size + + +def test_hybrid_cache_integration(model_runner, dist_init): + """Test hybrid cache architecture integration with GPUModelRunner.""" + # Create a new model runner with hybrid cache configuration + vllm_config = get_vllm_config() + + # Configure hybrid cache with different physical block size + vllm_config.cache_config.block_size = 32 + + model_config = vllm_config.model_config + num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = model_config.get_head_size() + vllm_config.compilation_config.static_forward_context[ + "layer.0"] = Attention(num_heads, head_size, 0.1) + + runner = GPUModelRunner(vllm_config, DEVICE) + + # Initialize KV cache with configuration + attn_spec = FullAttentionSpec( + block_size=16, # Use logical block size directly + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + use_mla=False, + ) + tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS + kv_cache_config = KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[ + KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) + ], + ) + runner.kv_cache_config = kv_cache_config + + # Initialize input batch with kernel block sizes + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], + kernel_block_sizes=[[16]]) # Use logical block size list + + runner.initialize_attn_backend(kv_cache_config) + + # Verify hybrid block table configuration + block_table = runner.input_batch.block_table.block_tables[0] + assert block_table.physical_block_size == ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size) + assert block_table.logical_block_size == 16 + + # Test request processing with hybrid blocks + req_id = "hybrid_req_0" + scheduler_output = _schedule_new_request(req_id) + + # Update states should work with hybrid blocks + runner._update_states(scheduler_output) + assert _is_req_scheduled(runner, req_id) + assert _is_req_state_block_table_match(runner, req_id) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0217bff6adaf..6df6eb8ef126 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -60,6 +60,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_state_cls() -> Type["AttentionState"]: raise NotImplementedError + @classmethod + def get_supported_block_size(cls) -> list[int]: + return cls.get_impl_cls().get_supported_block_size() + @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -299,6 +303,12 @@ def __init__( ) -> None: raise NotImplementedError + @staticmethod + def get_supported_block_size() -> list[int]: + # [0] is a placeholder: the actual block size will be determined + # by config.block_size at runtime. + return [0] + @abstractmethod def forward( self, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 687af7a189ce..d25c13bd3160 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -374,12 +374,22 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: 128-byte alignment + # * Other MLA backends: 64-byte alignment + if model_config.use_mla: + use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + block_alignment_bytes = 128 if use_cutlass_mla else 64 + else: + block_alignment_bytes = 16 + + # Calculate minimum attention block size that satisfies both: + # 1. Backend alignment requirements (block_alignment_bytes) + # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) + attn_block_size = block_alignment_bytes * cdiv( + mamba_page_size, block_alignment_bytes * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dc94cfcc3ce8..084413608053 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -139,7 +139,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # and doesn't need to be reinitialized here + if model_config is not None and model_config.use_mla \ + and cache_config.block_size is not None: # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6017445402ec..2f57184161c6 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -39,6 +39,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder + @staticmethod + def get_supported_block_size() -> list[int]: + return [128] + class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 549af1a06225..b3ec16e7270b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -42,6 +42,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @staticmethod + def get_supported_block_size() -> list[int]: + return [64] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 1901de6d2e5b..158d8474931e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + import numpy as np import torch @@ -21,21 +23,53 @@ def __init__( max_num_batched_tokens: int, pin_memory: bool, device: torch.device, + kernel_block_size: int, ): - self.block_size = block_size self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device + self.physical_block_size = block_size + # Validate kernel_block_size and set up logical block configuration + if kernel_block_size <= 0: + raise ValueError(f"kernel_block_size must be positive, got {kernel_block_size}") + + if kernel_block_size == block_size: + # No splitting - use physical block size directly + self.block_size = block_size + self.logical_block_size = block_size + self.blocks_per_phys_block = 1 + self.use_hybrid_blocks = False + else: + # Validate that kernel_block_size divides physical_block_size evenly + if self.physical_block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"physical block size {self.physical_block_size} evenly") + + self.block_size = kernel_block_size + self.logical_block_size = kernel_block_size + self.blocks_per_phys_block = (self.physical_block_size // + self.logical_block_size) + if self.blocks_per_phys_block > 1: + self.use_hybrid_blocks = True + else: + self.use_hybrid_blocks = False + + if self.use_hybrid_blocks: + logical_table_size = (max_num_blocks_per_req * + self.blocks_per_phys_block) + else: + logical_table_size = max_num_blocks_per_req self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, logical_table_size), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, logical_table_size), device="cpu", dtype=torch.int32, pin_memory=pin_memory, @@ -51,6 +85,14 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_batched_tokens, dtype=torch.int64, device=self.device) + + # Pre-compute bias array for physical to logical block conversion + if self.use_hybrid_blocks: + self._bias_array = np.arange(0, + self.blocks_per_phys_block).reshape( + 1, -1) + else: + self._bias_array = None try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -66,6 +108,11 @@ def append_row( ) -> None: if not block_ids: return + + if self.use_hybrid_blocks: + block_ids = self._convert_physical_to_logical_blocks( + np.array(block_ids)) + num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks @@ -105,8 +152,19 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // virtual_block_size) + + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // virtual_block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + block_numbers = self.block_table_np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. @@ -120,8 +178,18 @@ def compute_slot_mapping(self, req_indices: np.ndarray, self.slot_mapping_np[:req_indices.shape[0]] = np.where( mask, slot_mapping, -1) else: - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // self.block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + block_numbers = self.block_table_np.ravel()[block_table_indices] block_offsets = positions % self.block_size np.add(block_numbers * self.block_size, @@ -140,6 +208,17 @@ def clear(self) -> None: self.block_table.fill_(0) self.block_table_cpu.fill_(0) + def _convert_physical_to_logical_blocks( + self, physical_blocks: np.ndarray) -> np.ndarray: + """Convert physical block IDs to logical block IDs.""" + if not self.use_hybrid_blocks: + return physical_blocks + + logical_blocks = physical_blocks.reshape( + -1, 1) * self.blocks_per_phys_block + self._bias_array + + return logical_blocks.reshape(-1) + def get_device_tensor(self) -> torch.Tensor: """Returns the device tensor of the block table.""" return self.block_table @@ -163,7 +242,8 @@ def __init__(self, pin_memory: bool, device: torch.device, block_sizes: list[int], - num_speculative_tokens: int = 0) -> None: + num_speculative_tokens: int = 0, + kernel_sizes: Optional[list[list[int]]] = None) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -174,12 +254,24 @@ def __init__(self, # DCP might not be initialized in testing dcp_world_size = 1 + if kernel_sizes is None: + kernel_sizes = [[0]] * len(block_sizes) + # Ensure kernel_sizes matches block_sizes length + elif len(kernel_sizes) == 1 and len(block_sizes) > 1: + kernel_sizes = kernel_sizes * len(block_sizes) + elif len(kernel_sizes) != len(block_sizes): + raise ValueError( + f"kernel_sizes length ({len(kernel_sizes)}) must match " + f"block_sizes length ({len(block_sizes)})") + + # Use zip to pair block_sizes with kernel_sizes one-to-one self.block_tables = [ BlockTable( block_size, max_num_reqs, max(cdiv(max_model_len, block_size * dcp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device) for block_size in block_sizes + pin_memory, device, kernel_size_list) + for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1cf56656d7ad..386b218d5b35 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -84,6 +84,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, + kernel_block_sizes: Optional[list[list[int]]] = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -129,7 +130,7 @@ def __init__( device=device, block_sizes=block_sizes, num_speculative_tokens=num_speculative_tokens, - ) + kernel_sizes=kernel_block_sizes) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b785af96a9a..c8f02a17fb4c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -291,6 +291,7 @@ def __init__( self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, + kernel_block_sizes=None, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -3328,6 +3329,37 @@ def may_reinitialize_input_batch(self, kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups ] + + # Generate kernel_block_sizes that matches each block_size + # For attention backends that support virtual block splitting, + # use the supported block sizes from the backend + # For other backends (like Mamba), use [0] (no splitting) + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # the backend. + attn_groups = self.attn_groups[kv_cache_group_id] + if attn_groups: + # Use the backend's supported block size list + backend_cls = attn_groups[0].backend + supported_sizes = backend_cls.get_supported_block_size() + # If no specific sizes supported, use cache config + # block_size + kernel_block_size_list = (supported_sizes + if supported_sizes else + [self.cache_config.block_size]) + else: + # Fallback to cache config block_size if no backend found + kernel_block_size_list = [self.cache_config.block_size] + kernel_block_sizes.append(kernel_block_size_list) + else: + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append([0]) + if block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -3347,6 +3379,7 @@ def may_reinitialize_input_batch(self, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0), + kernel_block_sizes=kernel_block_sizes, ) def _allocate_kv_cache_tensors( @@ -3420,6 +3453,7 @@ def _reshape_kv_cache_tensors( kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True + # FIXME here kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -3570,10 +3604,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): From 0d1866835b1339e711cc894f29003de2e0a0ac0c Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 07:49:07 +0000 Subject: [PATCH 02/37] Change kernel_sizes parameter to kernel_block_size in BlockTable Signed-off-by: lizhiyuan --- tests/v1/worker/test_gpu_model_runner.py | 9 +++---- vllm/v1/worker/block_table.py | 31 +++++++++++------------- vllm/v1/worker/gpu_input_batch.py | 4 +-- vllm/v1/worker/gpu_model_runner.py | 25 +++++++++++-------- 4 files changed, 35 insertions(+), 34 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 53a222f69c13..d052b7b6fd24 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -861,7 +861,7 @@ def test_hybrid_block_table_initialization(): max_num_batched_tokens=max_num_batched_tokens, pin_memory=False, device=torch.device(DEVICE), - kernel_sizes=kernel_block_sizes) + kernel_block_size=kernel_block_sizes[0]) # Verify hybrid block configuration assert block_table.use_hybrid_blocks is True @@ -902,7 +902,7 @@ def test_input_batch_with_kernel_block_sizes(): # Test with different kernel block sizes physical_block_sizes = [32, 64] - kernel_block_sizes = [[16], [32]] + kernel_block_sizes = [16, 32] input_batch = InputBatch(max_num_reqs=max_num_reqs, max_model_len=max_model_len, @@ -917,10 +917,9 @@ def test_input_batch_with_kernel_block_sizes(): assert len( input_batch.block_table.block_tables) == len(physical_block_sizes) - for i, (phys_size, kernel_size_list) in enumerate( + for i, (phys_size, kernel_size) in enumerate( zip(physical_block_sizes, kernel_block_sizes)): block_table = input_batch.block_table.block_tables[i] - kernel_size = kernel_size_list[0] # Use first element from list if phys_size != kernel_size: assert block_table.use_hybrid_blocks is True assert block_table.physical_block_size == phys_size @@ -979,7 +978,7 @@ def test_hybrid_cache_integration(model_runner, dist_init): block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], - kernel_block_sizes=[[16]]) # Use logical block size list + kernel_block_sizes=[16]) # Use logical block size runner.initialize_attn_backend(kv_cache_config) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 158d8474931e..de935c17328e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional import numpy as np import torch @@ -31,10 +31,6 @@ def __init__( self.pin_memory = pin_memory self.device = device self.physical_block_size = block_size - # Validate kernel_block_size and set up logical block configuration - if kernel_block_size <= 0: - raise ValueError(f"kernel_block_size must be positive, got {kernel_block_size}") - if kernel_block_size == block_size: # No splitting - use physical block size directly self.block_size = block_size @@ -243,7 +239,7 @@ def __init__(self, device: torch.device, block_sizes: list[int], num_speculative_tokens: int = 0, - kernel_sizes: Optional[list[list[int]]] = None) -> None: + kernel_block_sizes: Optional[list[int]] = None) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -254,24 +250,25 @@ def __init__(self, # DCP might not be initialized in testing dcp_world_size = 1 - if kernel_sizes is None: - kernel_sizes = [[0]] * len(block_sizes) - # Ensure kernel_sizes matches block_sizes length - elif len(kernel_sizes) == 1 and len(block_sizes) > 1: - kernel_sizes = kernel_sizes * len(block_sizes) - elif len(kernel_sizes) != len(block_sizes): + if kernel_block_sizes is None: + # Use physical block size by default + kernel_block_sizes = block_sizes + # Ensure kernel_block_sizes matches block_sizes length + elif len(kernel_block_sizes) == 1 and len(block_sizes) > 1: + kernel_block_sizes = kernel_block_sizes * len(block_sizes) + elif len(kernel_block_sizes) != len(block_sizes): raise ValueError( - f"kernel_sizes length ({len(kernel_sizes)}) must match " - f"block_sizes length ({len(block_sizes)})") + f"kernel_block_sizes length ({len(kernel_block_sizes)}) " + f"must match block_sizes length ({len(block_sizes)})") - # Use zip to pair block_sizes with kernel_sizes one-to-one + # Use zip to pair block_sizes with kernel_block_sizes one-to-one self.block_tables = [ BlockTable( block_size, max_num_reqs, max(cdiv(max_model_len, block_size * dcp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device, kernel_size_list) - for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + pin_memory, device, kernel_block_size) for block_size, + kernel_block_size in zip(block_sizes, kernel_block_sizes) ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 386b218d5b35..4e03cd31eda2 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -84,7 +84,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, - kernel_block_sizes: Optional[list[list[int]]] = None, + kernel_block_sizes: Optional[list[int]] = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -130,7 +130,7 @@ def __init__( device=device, block_sizes=block_sizes, num_speculative_tokens=num_speculative_tokens, - kernel_sizes=kernel_block_sizes) + kernel_block_sizes=kernel_block_sizes) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c8f02a17fb4c..cc682c0e0297 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -291,7 +291,7 @@ def __init__( self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, - kernel_block_sizes=None, + kernel_block_sizes=[self.cache_config.block_size], ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -3342,23 +3342,28 @@ def may_reinitialize_input_batch(self, # block splitting. Get the supported block sizes from # the backend. attn_groups = self.attn_groups[kv_cache_group_id] + physical_block_size = kv_cache_group.kv_cache_spec.block_size if attn_groups: # Use the backend's supported block size list backend_cls = attn_groups[0].backend supported_sizes = backend_cls.get_supported_block_size() - # If no specific sizes supported, use cache config - # block_size - kernel_block_size_list = (supported_sizes - if supported_sizes else - [self.cache_config.block_size]) + # Select the first supported size that divides physical + # block size evenly + selected_kernel_size = physical_block_size + if supported_sizes: + for kernel_size in supported_sizes: + if (kernel_size > 0 and + physical_block_size % kernel_size == 0): + selected_kernel_size = kernel_size + break + kernel_block_sizes.append(selected_kernel_size) else: - # Fallback to cache config block_size if no backend found - kernel_block_size_list = [self.cache_config.block_size] - kernel_block_sizes.append(kernel_block_size_list) + kernel_block_sizes.append(physical_block_size) else: # This is likely Mamba or other non-attention cache, # no splitting. - kernel_block_sizes.append([0]) + kernel_block_sizes.append( + kv_cache_group.kv_cache_spec.block_size) if block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( From bded2b4e38528ac64778982f85df20dbdbd948e2 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 07:55:58 +0000 Subject: [PATCH 03/37] fix condition for may_reinitialize_input_batch Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc682c0e0297..034351b324d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3365,7 +3365,8 @@ def may_reinitialize_input_batch(self, kernel_block_sizes.append( kv_cache_group.kv_cache_spec.block_size) - if block_sizes != [self.cache_config.block_size]: + if block_sizes != [self.cache_config.block_size + ] or block_sizes != kernel_block_sizes: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 From 90c14ab11f3023b8b59c9b29c80be4ceb5f206e9 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 08:05:37 +0000 Subject: [PATCH 04/37] update shapes when getting kv_cache_shape in ModelRunner Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 44 ++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 034351b324d8..39deb383ad62 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3315,6 +3315,28 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i + def _select_kernel_block_size(self, physical_block_size: int, + backend_cls: type[AttentionBackend]) -> int: + """ + Select the optimal kernel block size for a given physical block size. + + Args: + physical_block_size: The physical block size of the KV cache + backend_cls: The attention backend class + + Returns: + The selected kernel block size + """ + supported_sizes = backend_cls.get_supported_block_size() + selected_kernel_size = physical_block_size + if supported_sizes: + for kernel_size in supported_sizes: + if (kernel_size > 0 + and physical_block_size % kernel_size == 0): + selected_kernel_size = kernel_size + break + return selected_kernel_size + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -3346,16 +3368,8 @@ def may_reinitialize_input_batch(self, if attn_groups: # Use the backend's supported block size list backend_cls = attn_groups[0].backend - supported_sizes = backend_cls.get_supported_block_size() - # Select the first supported size that divides physical - # block size evenly - selected_kernel_size = physical_block_size - if supported_sizes: - for kernel_size in supported_sizes: - if (kernel_size > 0 and - physical_block_size % kernel_size == 0): - selected_kernel_size = kernel_size - break + selected_kernel_size = self._select_kernel_block_size( + physical_block_size, backend_cls) kernel_block_sizes.append(selected_kernel_size) else: kernel_block_sizes.append(physical_block_size) @@ -3459,9 +3473,15 @@ def _reshape_kv_cache_tensors( kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True - # FIXME here + physical_block_size = kv_cache_spec.block_size + logical_kernel_size = self._select_kernel_block_size( + physical_block_size, attn_backend) + num_blocks_per_phys_block = (physical_block_size // + logical_kernel_size) + logical_num_blocks = num_blocks * num_blocks_per_phys_block + kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, + logical_num_blocks, logical_kernel_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: From 613f4c63225d01e5bd7f50f27040a793a1e68a91 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 09:28:19 +0000 Subject: [PATCH 05/37] minor fix Signed-off-by: lizhiyuan --- vllm/attention/backends/abstract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6df6eb8ef126..05831c36426a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -305,9 +305,9 @@ def __init__( @staticmethod def get_supported_block_size() -> list[int]: - # [0] is a placeholder: the actual block size will be determined + # [16] is a placeholder: the actual block size will be determined # by config.block_size at runtime. - return [0] + return [16] @abstractmethod def forward( From edfdf8d94695aa3fc9e0e0652306279238f41bbc Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 10:19:20 +0000 Subject: [PATCH 06/37] Fix embedding test Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39deb383ad62..a082a15eb36d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3157,12 +3157,25 @@ def _capture_cudagraphs(self, compilation_cases: list[int], remove_lora=False) self.maybe_remove_all_loras(self.lora_config) - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + def initialize_attn_backend(self, + kv_cache_config: KVCacheConfig, + incremental: bool = False) -> None: """ Initialize the attention backends and attention metadata builders. + + Args: + kv_cache_config: The KV cache configuration + incremental: If True, only initialize backends + for newly added groups. + If False, initialize all groups + (asserts attn_groups is empty). """ - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + if not incremental: + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + # If incremental, start from the existing group count + start_group_idx = len(self.attn_groups) if incremental else 0 def get_attn_backends_for_layers( layer_names: list[str] @@ -3212,7 +3225,9 @@ def create_attn_groups( attn_groups.append(attn_group) return attn_groups - for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + # Only process groups starting from start_group_idx + for i in range(start_group_idx, len(kv_cache_config.kv_cache_groups)): + kv_cache_group_spec = kv_cache_config.kv_cache_groups[i] kv_cache_spec = kv_cache_group_spec.kv_cache_spec attn_backends = get_attn_backends_for_layers( kv_cache_group_spec.layer_names) @@ -3630,11 +3645,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_add_encoder_only_layers_to_kv_cache_config() - self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) - # Reinitialize need to after initialize_attn_backend + # Reinitialize input batch (depends on attn_groups) self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) + self.initialize_attn_backend(kv_cache_config, incremental=True) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): From 28e94eb470d3551a3bdc26e1c3af3809c4ecb942 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 12:31:33 +0000 Subject: [PATCH 07/37] making max_num_blocks_per_req to save max number of logical blocks Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index de935c17328e..8575521f6ae8 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -26,7 +26,6 @@ def __init__( kernel_block_size: int, ): self.max_num_reqs = max_num_reqs - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -53,6 +52,8 @@ def __init__( else: self.use_hybrid_blocks = False + self.max_num_blocks_per_req = max_num_blocks_per_req * \ + self.blocks_per_phys_block if self.use_hybrid_blocks: logical_table_size = (max_num_blocks_per_req * self.blocks_per_phys_block) @@ -157,8 +158,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # (always needed with unified tensor) # Each physical block is split into multiple logical blocks # The logical table has been expanded to accommodate this - block_table_indices = (req_indices * self.max_num_blocks_per_req * - self.blocks_per_phys_block + + block_table_indices = (req_indices * self.max_num_blocks_per_req + logical_block_idx) block_numbers = self.block_table_np.ravel()[block_table_indices] @@ -182,8 +182,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # (always needed with unified tensor) # Each physical block is split into multiple logical blocks # The logical table has been expanded to accommodate this - block_table_indices = (req_indices * self.max_num_blocks_per_req * - self.blocks_per_phys_block + + block_table_indices = (req_indices * self.max_num_blocks_per_req + logical_block_idx) block_numbers = self.block_table_np.ravel()[block_table_indices] From b1d3dcc49585c05c9ccabd7e312454bee34d5e63 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 12:48:38 +0000 Subject: [PATCH 08/37] change design Signed-off-by: lizhiyuan --- vllm/attention/backends/abstract.py | 15 +++++++++++---- vllm/v1/attention/backends/mla/cutlass_mla.py | 5 +++-- vllm/v1/attention/backends/mla/flashmla.py | 5 +++-- vllm/v1/worker/gpu_model_runner.py | 18 +++++++++++------- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 05831c36426a..4649fc8c74c7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from dataclasses import dataclass, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, - Protocol, Set, Tuple, Type, TypeVar) + Protocol, Set, Tuple, Type, TypeVar, Union) import torch @@ -33,6 +33,13 @@ class AttentionType: ENCODER_DECODER = "encoder_decoder" +class MultipleOf: + base: int + + def __init__(self, base: int): + self.base = base + + class AttentionBackend(ABC): """Abstract class for attention backends.""" # For some attention backends, we allocate an output tensor before @@ -61,7 +68,7 @@ def get_state_cls() -> Type["AttentionState"]: raise NotImplementedError @classmethod - def get_supported_block_size(cls) -> list[int]: + def get_supported_block_size(cls) -> list[Union[int, MultipleOf]]: return cls.get_impl_cls().get_supported_block_size() @classmethod @@ -304,10 +311,10 @@ def __init__( raise NotImplementedError @staticmethod - def get_supported_block_size() -> list[int]: + def get_supported_block_size() -> list[Union[int, MultipleOf]]: # [16] is a placeholder: the actual block size will be determined # by config.block_size at runtime. - return [16] + return [MultipleOf(16)] @abstractmethod def forward( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 2f57184161c6..1f4bc8b234ea 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch import vllm._custom_ops as ops from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -40,7 +41,7 @@ def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder @staticmethod - def get_supported_block_size() -> list[int]: + def get_supported_block_size() -> list[Union[int, MultipleOf]]: return [128] diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b3ec16e7270b..e83afe7ae0b0 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,7 +6,8 @@ import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + MultipleOf) from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -43,7 +44,7 @@ def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl @staticmethod - def get_supported_block_size() -> list[int]: + def get_supported_block_size() -> list[Union[int, MultipleOf]]: return [64] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a082a15eb36d..a74360ac9053 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,7 +18,7 @@ import vllm.envs as envs from vllm.attention import Attention, AttentionType -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, MultipleOf from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper @@ -3342,13 +3342,17 @@ def _select_kernel_block_size(self, physical_block_size: int, Returns: The selected kernel block size """ - supported_sizes = backend_cls.get_supported_block_size() + supported_constraints = backend_cls.get_supported_block_size() selected_kernel_size = physical_block_size - if supported_sizes: - for kernel_size in supported_sizes: - if (kernel_size > 0 - and physical_block_size % kernel_size == 0): - selected_kernel_size = kernel_size + if supported_constraints: + for constraint in supported_constraints: + if (isinstance(constraint, int) + and physical_block_size % constraint == 0): + selected_kernel_size = constraint + break + elif (isinstance(constraint, MultipleOf) + and physical_block_size % constraint.base == 0): + selected_kernel_size = constraint.base break return selected_kernel_size From e10d70a5700e1ef4d6fbb954b152edd637b4f014 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 15 Sep 2025 13:26:25 +0000 Subject: [PATCH 09/37] clean codes Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8575521f6ae8..49f73be8006a 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -37,7 +37,6 @@ def __init__( self.blocks_per_phys_block = 1 self.use_hybrid_blocks = False else: - # Validate that kernel_block_size divides physical_block_size evenly if self.physical_block_size % kernel_block_size != 0: raise ValueError( f"kernel_block_size {kernel_block_size} must divide " @@ -54,19 +53,14 @@ def __init__( self.max_num_blocks_per_req = max_num_blocks_per_req * \ self.blocks_per_phys_block - if self.use_hybrid_blocks: - logical_table_size = (max_num_blocks_per_req * - self.blocks_per_phys_block) - else: - logical_table_size = max_num_blocks_per_req self.block_table = torch.zeros( - (max_num_reqs, logical_table_size), + (max_num_reqs, self.max_num_blocks_per_req), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, logical_table_size), + (max_num_reqs, self.max_num_blocks_per_req), device="cpu", dtype=torch.int32, pin_memory=pin_memory, From 2ce97c4240364679a19f14f520f72ae70b87dfc5 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 04:29:05 +0000 Subject: [PATCH 10/37] clean imps Signed-off-by: lizhiyuan --- vllm/attention/backends/abstract.py | 5 +- vllm/v1/worker/block_table.py | 77 +++++++++++++++++------------ vllm/v1/worker/gpu_input_batch.py | 6 +-- vllm/v1/worker/gpu_model_runner.py | 66 +++++++++++++++++-------- vllm/v1/worker/tpu_input_batch.py | 18 ++++--- vllm/v1/worker/tpu_model_runner.py | 4 ++ 6 files changed, 111 insertions(+), 65 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 4649fc8c74c7..a0ed1e4ff835 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -312,9 +312,8 @@ def __init__( @staticmethod def get_supported_block_size() -> list[Union[int, MultipleOf]]: - # [16] is a placeholder: the actual block size will be determined - # by config.block_size at runtime. - return [MultipleOf(16)] + # TODO: implement this function for all backends. + return [MultipleOf(1)] @abstractmethod def forward( diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 49f73be8006a..14815db80552 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - import numpy as np import torch @@ -25,27 +23,42 @@ def __init__( device: torch.device, kernel_block_size: int, ): + """Manages the mapping between logical and physical memory blocks + for KV cache. + + The BlockTable handles the conversion between kv_manager_block_size size + (actual memory allocation) and kernel block size (computation + granularity). When these sizes differ, it implements a hybrid block + system for memory efficiency. + + Args: + block_size: kv_manager_block size + kernel_block_size: Kernel block size - the granularity at which + attention kernels operate during computation. + max_num_reqs: Maximum number of concurrent requests supported. + max_num_blocks_per_req: Maximum number of blocks per request. + max_num_batched_tokens: Maximum number of tokens in a batch. + pin_memory: Whether to pin memory for faster GPU transfers. + device: Target device for the block table. + """ self.max_num_reqs = max_num_reqs self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device - self.physical_block_size = block_size + if kernel_block_size == block_size: - # No splitting - use physical block size directly + # No splitting - use kv_manager_block_size size directly self.block_size = block_size - self.logical_block_size = block_size self.blocks_per_phys_block = 1 self.use_hybrid_blocks = False else: - if self.physical_block_size % kernel_block_size != 0: + if block_size % kernel_block_size != 0: raise ValueError( f"kernel_block_size {kernel_block_size} must divide " - f"physical block size {self.physical_block_size} evenly") + f"kv_manager_block_size size {block_size} evenly") self.block_size = kernel_block_size - self.logical_block_size = kernel_block_size - self.blocks_per_phys_block = (self.physical_block_size // - self.logical_block_size) + self.blocks_per_phys_block = (block_size // kernel_block_size) if self.blocks_per_phys_block > 1: self.use_hybrid_blocks = True else: @@ -144,13 +157,11 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - # IMPORTANT: In hybrid mode, positions are in logical block space, - # but we need to map them to the correct logical block table indices logical_block_idx = positions // virtual_block_size # Account for the expanded logical table # (always needed with unified tensor) - # Each physical block is split into multiple logical blocks + # Each kv_manager_block_size is split into multiple logical blocks # The logical table has been expanded to accommodate this block_table_indices = (req_indices * self.max_num_blocks_per_req + logical_block_idx) @@ -168,14 +179,7 @@ def compute_slot_mapping(self, req_indices: np.ndarray, self.slot_mapping_np[:req_indices.shape[0]] = np.where( mask, slot_mapping, -1) else: - # IMPORTANT: In hybrid mode, positions are in logical block space, - # but we need to map them to the correct logical block table indices logical_block_idx = positions // self.block_size - - # Account for the expanded logical table - # (always needed with unified tensor) - # Each physical block is split into multiple logical blocks - # The logical table has been expanded to accommodate this block_table_indices = (req_indices * self.max_num_blocks_per_req + logical_block_idx) @@ -198,12 +202,25 @@ def clear(self) -> None: self.block_table_cpu.fill_(0) def _convert_physical_to_logical_blocks( - self, physical_blocks: np.ndarray) -> np.ndarray: - """Convert physical block IDs to logical block IDs.""" + self, kv_manager_block_size: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_size IDs to logical block IDs. + + Example: + # kv_manager_block_size: 32 tokens, + # Kernel block size: 16 tokens + # blocks_per_phys_block = 2 + >>> kv_manager_block_size = np.array([0, 1, 2]) + >>> Result: [0, 1, 2, 3, 4, 5] + + # Each kv_manager_block_size maps to 2 logical blocks: + # kv_manager_block_size 0 → Logical blocks [0, 1] + # kv_manager_block_size 1 → Logical blocks [2, 3] + # kv_manager_block_size 2 → Logical blocks [4, 5] + """ if not self.use_hybrid_blocks: - return physical_blocks + return kv_manager_block_size - logical_blocks = physical_blocks.reshape( + logical_blocks = kv_manager_block_size.reshape( -1, 1) * self.blocks_per_phys_block + self._bias_array return logical_blocks.reshape(-1) @@ -231,8 +248,8 @@ def __init__(self, pin_memory: bool, device: torch.device, block_sizes: list[int], - num_speculative_tokens: int = 0, - kernel_block_sizes: Optional[list[int]] = None) -> None: + kernel_block_sizes: list[int], + num_speculative_tokens: int = 0) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -244,12 +261,10 @@ def __init__(self, dcp_world_size = 1 if kernel_block_sizes is None: - # Use physical block size by default + # Use kv_manager_block_size size by default kernel_block_sizes = block_sizes - # Ensure kernel_block_sizes matches block_sizes length - elif len(kernel_block_sizes) == 1 and len(block_sizes) > 1: - kernel_block_sizes = kernel_block_sizes * len(block_sizes) - elif len(kernel_block_sizes) != len(block_sizes): + + if len(kernel_block_sizes) != len(block_sizes): raise ValueError( f"kernel_block_sizes length ({len(kernel_block_sizes)}) " f"must match block_sizes length ({len(block_sizes)})") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 4e03cd31eda2..7cc3b7c8c8a9 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -80,11 +80,11 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, - kernel_block_sizes: Optional[list[int]] = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -129,8 +129,8 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, - num_speculative_tokens=num_speculative_tokens, - kernel_block_sizes=kernel_block_sizes) + kernel_block_sizes=kernel_block_sizes, + num_speculative_tokens=num_speculative_tokens) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a74360ac9053..81422fc74de0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -285,13 +285,13 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, self.device, self.pin_memory, self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, - kernel_block_sizes=[self.cache_config.block_size], ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -3330,30 +3330,56 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i - def _select_kernel_block_size(self, physical_block_size: int, + def _select_kernel_block_size(self, kv_manager_block_size: int, backend_cls: type[AttentionBackend]) -> int: """ Select the optimal kernel block size for a given physical block size. Args: - physical_block_size: The physical block size of the KV cache + kv_manager_block_size: The physical block size of the KV cache backend_cls: The attention backend class Returns: The selected kernel block size + + Raises: + ValueError: If no valid kernel block size can be found that + satisfies the backend's constraints """ supported_constraints = backend_cls.get_supported_block_size() - selected_kernel_size = physical_block_size - if supported_constraints: + selected_kernel_size = kv_manager_block_size + constraint_satisfied = False + + for constraint in supported_constraints: + if (isinstance(constraint, int) + and kv_manager_block_size % constraint == 0): + selected_kernel_size = constraint + constraint_satisfied = True + break + elif (isinstance(constraint, MultipleOf) + and kv_manager_block_size % constraint.base == 0): + selected_kernel_size = constraint.base + constraint_satisfied = True + break + + if not constraint_satisfied and supported_constraints: + # Only raise error if there are actual constraints to satisfy + # and none of them were met + constraint_strs = [] for constraint in supported_constraints: - if (isinstance(constraint, int) - and physical_block_size % constraint == 0): - selected_kernel_size = constraint - break - elif (isinstance(constraint, MultipleOf) - and physical_block_size % constraint.base == 0): - selected_kernel_size = constraint.base - break + if isinstance(constraint, int): + constraint_strs.append(f"{constraint}") + elif isinstance(constraint, MultipleOf): + constraint_strs.append(f"multiple of {constraint.base}") + + raise ValueError( + f"Physical block size {kv_manager_block_size} does not " + f"satisfy any constraints for {backend_cls.__name__} " + f"backend. Supported constraints: " + f"{', '.join(constraint_strs)}. " + f"The physical block size must be compatible with at least " + f"one constraint.") + return selected_kernel_size def may_reinitialize_input_batch(self, @@ -3383,15 +3409,15 @@ def may_reinitialize_input_batch(self, # block splitting. Get the supported block sizes from # the backend. attn_groups = self.attn_groups[kv_cache_group_id] - physical_block_size = kv_cache_group.kv_cache_spec.block_size + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size if attn_groups: # Use the backend's supported block size list backend_cls = attn_groups[0].backend selected_kernel_size = self._select_kernel_block_size( - physical_block_size, backend_cls) + kv_manager_block_size, backend_cls) kernel_block_sizes.append(selected_kernel_size) else: - kernel_block_sizes.append(physical_block_size) + kernel_block_sizes.append(kv_manager_block_size) else: # This is likely Mamba or other non-attention cache, # no splitting. @@ -3412,13 +3438,13 @@ def may_reinitialize_input_batch(self, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0), - kernel_block_sizes=kernel_block_sizes, ) def _allocate_kv_cache_tensors( @@ -3492,10 +3518,10 @@ def _reshape_kv_cache_tensors( kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True - physical_block_size = kv_cache_spec.block_size + kv_manager_block_size = kv_cache_spec.block_size logical_kernel_size = self._select_kernel_block_size( - physical_block_size, attn_backend) - num_blocks_per_phys_block = (physical_block_size // + kv_manager_block_size, attn_backend) + num_blocks_per_phys_block = (kv_manager_block_size // logical_kernel_size) logical_num_blocks = num_blocks * num_blocks_per_phys_block diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 81c798685cb3..123ba70a3ac5 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -20,14 +20,15 @@ class InputBatch: def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -70,6 +71,7 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, ) # Sampling-related. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 15af7ffac809..d4c74a39510f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -223,6 +223,7 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + kernel_block_sizes=[self.cache_config.block_size], ) # Cached torch/numpy tensor @@ -1617,6 +1618,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) # Verify dtype compatibility between block_table_cpu and input_batch assert self.block_table_cpu.dtype == self.input_batch.block_table[ From 0909efd8dcdd25fd4d5e5de2813c7aea6ac0c9e1 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 04:35:21 +0000 Subject: [PATCH 11/37] Revert "Fix embedding test" This reverts commit edfdf8d94695aa3fc9e0e0652306279238f41bbc. Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81422fc74de0..65e8fd3ba9ad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3157,25 +3157,12 @@ def _capture_cudagraphs(self, compilation_cases: list[int], remove_lora=False) self.maybe_remove_all_loras(self.lora_config) - def initialize_attn_backend(self, - kv_cache_config: KVCacheConfig, - incremental: bool = False) -> None: + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. - - Args: - kv_cache_config: The KV cache configuration - incremental: If True, only initialize backends - for newly added groups. - If False, initialize all groups - (asserts attn_groups is empty). """ - if not incremental: - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" - - # If incremental, start from the existing group count - start_group_idx = len(self.attn_groups) if incremental else 0 + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" def get_attn_backends_for_layers( layer_names: list[str] @@ -3225,9 +3212,7 @@ def create_attn_groups( attn_groups.append(attn_group) return attn_groups - # Only process groups starting from start_group_idx - for i in range(start_group_idx, len(kv_cache_config.kv_cache_groups)): - kv_cache_group_spec = kv_cache_config.kv_cache_groups[i] + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec attn_backends = get_attn_backends_for_layers( kv_cache_group_spec.layer_names) @@ -3675,12 +3660,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.initialize_attn_backend(kv_cache_config) - # Reinitialize input batch (depends on attn_groups) - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) - self.initialize_attn_backend(kv_cache_config, incremental=True) + self.initialize_attn_backend(kv_cache_config) + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): From 097c11c268090a6609abe7881eaf03a97125e6af Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 04:38:30 +0000 Subject: [PATCH 12/37] remove conditions since attn_groups won't be empty Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 65e8fd3ba9ad..38b1e12f0ed9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3395,14 +3395,10 @@ def may_reinitialize_input_batch(self, # the backend. attn_groups = self.attn_groups[kv_cache_group_id] kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size - if attn_groups: - # Use the backend's supported block size list - backend_cls = attn_groups[0].backend - selected_kernel_size = self._select_kernel_block_size( - kv_manager_block_size, backend_cls) - kernel_block_sizes.append(selected_kernel_size) - else: - kernel_block_sizes.append(kv_manager_block_size) + backend_cls = attn_groups[0].backend + selected_kernel_size = self._select_kernel_block_size( + kv_manager_block_size, backend_cls) + kernel_block_sizes.append(selected_kernel_size) else: # This is likely Mamba or other non-attention cache, # no splitting. From 0e6ae07e1eb44df03abe27f4b59bd0cc1ca90770 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 04:41:31 +0000 Subject: [PATCH 13/37] fix conditions Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 38b1e12f0ed9..884ad92300f3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3405,8 +3405,9 @@ def may_reinitialize_input_batch(self, kernel_block_sizes.append( kv_cache_group.kv_cache_spec.block_size) - if block_sizes != [self.cache_config.block_size - ] or block_sizes != kernel_block_sizes: + if block_sizes != [ + self.cache_config.block_size + ] or kernel_block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 From 3fd0727b9d2982bcf6fefb57a3e46b12d6281709 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 04:48:32 +0000 Subject: [PATCH 14/37] find largest block_size to reduce overhead Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 884ad92300f3..a9a0b5347993 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3325,7 +3325,7 @@ def _select_kernel_block_size(self, kv_manager_block_size: int, backend_cls: The attention backend class Returns: - The selected kernel block size + The selected kernel block size (largest available) Raises: ValueError: If no valid kernel block size can be found that @@ -3334,18 +3334,19 @@ def _select_kernel_block_size(self, kv_manager_block_size: int, supported_constraints = backend_cls.get_supported_block_size() selected_kernel_size = kv_manager_block_size constraint_satisfied = False + valid_constraints = [] for constraint in supported_constraints: if (isinstance(constraint, int) and kv_manager_block_size % constraint == 0): - selected_kernel_size = constraint - constraint_satisfied = True - break + valid_constraints.append(constraint) elif (isinstance(constraint, MultipleOf) and kv_manager_block_size % constraint.base == 0): - selected_kernel_size = constraint.base - constraint_satisfied = True - break + valid_constraints.append(constraint.base) + + if valid_constraints: + selected_kernel_size = max(valid_constraints) + constraint_satisfied = True if not constraint_satisfied and supported_constraints: # Only raise error if there are actual constraints to satisfy From ff983aff661cc9ed59280d7a38118ebb964e9b55 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 06:24:00 +0000 Subject: [PATCH 15/37] fix default block size Signed-off-by: lizhiyuan --- vllm/v1/attention/backends/flash_attn.py | 7 ++++++- vllm/v1/attention/backends/flashinfer.py | 6 +++++- vllm/v1/attention/backends/rocm_aiter_fa.py | 9 +++++++-- vllm/v1/attention/backends/tree_attn.py | 9 +++++++-- vllm/v1/attention/backends/triton_attn.py | 9 +++++++-- vllm/v1/attention/backends/xformers.py | 9 +++++++-- 6 files changed, 39 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 20f1904b3be6..8e62a18d7770 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import numpy as np import torch @@ -10,6 +10,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, is_quantized_kv_cache) from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -49,6 +50,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 9e05cc8ab2f1..73bdd6037fa6 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,7 +17,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) + AttentionType, MultipleOf) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -61,6 +61,10 @@ def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index a4e2758bd311..75546bf692da 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -327,6 +328,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 10238f36455d..904da6f9b8ef 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,12 +4,13 @@ import ast from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -39,6 +40,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c294a5a73cbd..e565f7d3ed85 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -3,13 +3,14 @@ """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass from functools import cache -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention @@ -150,6 +151,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index a6ca33491235..46f69e7da283 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,12 +3,13 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar, Optional, Union import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, AttentionType, + MultipleOf) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -77,6 +78,10 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] + @staticmethod + def get_supported_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() From e869bf0f4df4e1b1079458c88c39eb95773f5ae4 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 08:59:04 +0000 Subject: [PATCH 16/37] fix embedding test Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a9a0b5347993..d6edf3907d03 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3381,6 +3381,8 @@ def may_reinitialize_input_batch(self, block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec) ] # Generate kernel_block_sizes that matches each block_size @@ -3390,7 +3392,10 @@ def may_reinitialize_input_batch(self, kernel_block_sizes = [] for kv_cache_group_id, kv_cache_group in enumerate( kv_cache_config.kv_cache_groups): - if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + if isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): # This is an attention backend that supports virtual # block splitting. Get the supported block sizes from # the backend. From 3bb83b966f0f2ee5db38f06417d8e13a925c90fa Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 09:10:37 +0000 Subject: [PATCH 17/37] change params Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 14815db80552..6065808c31fa 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -202,28 +202,28 @@ def clear(self) -> None: self.block_table_cpu.fill_(0) def _convert_physical_to_logical_blocks( - self, kv_manager_block_size: np.ndarray) -> np.ndarray: - """Convert kv_manager_block_size IDs to logical block IDs. + self, kv_manager_block_id: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_id IDs to logical block IDs. Example: - # kv_manager_block_size: 32 tokens, + # kv_manager_block_id: 32 tokens, # Kernel block size: 16 tokens # blocks_per_phys_block = 2 - >>> kv_manager_block_size = np.array([0, 1, 2]) + >>> kv_manager_block_id = np.array([0, 1, 2]) >>> Result: [0, 1, 2, 3, 4, 5] - # Each kv_manager_block_size maps to 2 logical blocks: - # kv_manager_block_size 0 → Logical blocks [0, 1] - # kv_manager_block_size 1 → Logical blocks [2, 3] - # kv_manager_block_size 2 → Logical blocks [4, 5] + # Each kv_manager_block_id maps to 2 logical block id: + # kv_manager_block_id 0 → Logical block id [0, 1] + # kv_manager_block_id 1 → Logical block id [2, 3] + # kv_manager_block_id 2 → Logical block id [4, 5] """ if not self.use_hybrid_blocks: - return kv_manager_block_size + return kv_manager_block_id - logical_blocks = kv_manager_block_size.reshape( + logical_block_id = kv_manager_block_id.reshape( -1, 1) * self.blocks_per_phys_block + self._bias_array - return logical_blocks.reshape(-1) + return logical_block_id.reshape(-1) def get_device_tensor(self) -> torch.Tensor: """Returns the device tensor of the block table.""" From 8a7c2b6904fcd9223e41ea113b453b6df1e4be4c Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Tue, 16 Sep 2025 09:11:52 +0000 Subject: [PATCH 18/37] clean unused branch Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 6065808c31fa..4c1848ab165a 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -59,10 +59,7 @@ def __init__( self.block_size = kernel_block_size self.blocks_per_phys_block = (block_size // kernel_block_size) - if self.blocks_per_phys_block > 1: - self.use_hybrid_blocks = True - else: - self.use_hybrid_blocks = False + self.use_hybrid_blocks = False self.max_num_blocks_per_req = max_num_blocks_per_req * \ self.blocks_per_phys_block From 9620fe03b9ca3b04935cd6ccce0fec74f357e25c Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 17 Sep 2025 06:59:59 +0000 Subject: [PATCH 19/37] fix typos Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 4c1848ab165a..452ec8b3e814 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -59,7 +59,7 @@ def __init__( self.block_size = kernel_block_size self.blocks_per_phys_block = (block_size // kernel_block_size) - self.use_hybrid_blocks = False + self.use_hybrid_blocks = True self.max_num_blocks_per_req = max_num_blocks_per_req * \ self.blocks_per_phys_block From ddbaebbbf7c06e0765a6200fc0d899c6c31eb44b Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 17 Sep 2025 08:05:53 +0000 Subject: [PATCH 20/37] default block_size 16 Signed-off-by: lizhiyuan --- vllm/attention/backends/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9523243a33d9..5dd5cca2e4d9 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -314,7 +314,7 @@ def __init__( @staticmethod def get_supported_block_size() -> list[Union[int, MultipleOf]]: # TODO: implement this function for all backends. - return [MultipleOf(1)] + return [MultipleOf(16)] @abstractmethod def forward( From 698b55e0bc64cc7fa0082253b2675499bdfbcab4 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 17 Sep 2025 08:18:49 +0000 Subject: [PATCH 21/37] fix lint Signed-off-by: lizhiyuan --- tests/v1/worker/test_gpu_model_runner.py | 10 ++++------ vllm/v1/worker/block_table.py | 5 +++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b5470d5fda00..2bd7e588b013 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -867,8 +867,7 @@ def test_hybrid_block_table_initialization(): assert block_table.use_hybrid_blocks is True assert block_table.block_size == kernel_block_sizes[0] assert block_table.blocks_per_phys_block == ( - block_size // kernel_block_sizes[0] - ) # Changed to use first element + block_size // kernel_block_sizes[0]) # Changed to use first element # Test block table conversion logic # One physical block should map to multiple logical blocks @@ -912,11 +911,10 @@ def test_input_batch_with_kernel_block_sizes(): kernel_block_sizes=kernel_block_sizes) # Verify that block tables were created with kernel block sizes - assert len( - input_batch.block_table.block_tables) == len(block_sizes) + assert len(input_batch.block_table.block_tables) == len(block_sizes) - for i, (phys_size, kernel_size) in enumerate( - zip(block_sizes, kernel_block_sizes)): + for i, (phys_size, + kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)): block_table = input_batch.block_table.block_tables[i] if phys_size != kernel_size: assert block_table.use_hybrid_blocks is True diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9fe7ed0069b4..ac216a1a8a5d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -115,7 +115,8 @@ def add_row(self, block_ids: list[int], row_idx: int) -> None: def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] - self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks] + self.block_table.np[tgt, :num_blocks] = self.block_table.np[ + src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks def swap_row(self, src: int, tgt: int) -> None: @@ -303,4 +304,4 @@ def clear(self) -> None: def __getitem__(self, idx: int) -> "BlockTable": """Returns the BlockTable for the i-th KV cache group.""" - return self.block_tables[idx] \ No newline at end of file + return self.block_tables[idx] From e0130937a2f74c3318230326e630f60543ba7670 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Sun, 21 Sep 2025 07:46:15 +0000 Subject: [PATCH 22/37] fix part of reviews Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 54 +++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index ac216a1a8a5d..df9b1096a06d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -25,36 +25,52 @@ def __init__( device: torch.device, kernel_block_size: int, ): - """Manages the mapping between logical and physical memory blocks - for KV cache. - - The BlockTable handles the conversion between kv_manager_block_size size - (actual memory allocation) and kernel block size (computation - granularity). When these sizes differ, it implements a hybrid block - system for memory efficiency. + """Manages KV cache block allocation and token-to-block mapping for + efficient inference. + + The BlockTable manages the relationship between token positions and + their corresponding memory blocks in the KV cache, supporting flexible + block size configurations to optimize both memory usage and + computational efficiency. It implements a hybrid block system that + bridges potential differences between memory allocation granularity + and kernel computation requirements. + + Key functionality: + - Maps token positions to KV cache memory blocks for efficient lookup + - Handles hybrid block configurations when allocation and computation + sizes differ + - Manages slot mappings for batched processing of multiple requests + - Provides efficient GPU/CPU buffer management for block metadata + - Supports distributed processing with DCP (Distributed Context + Parallelism) Args: - block_size: kv_manager_block size - kernel_block_size: Kernel block size - the granularity at which - attention kernels operate during computation. + block_size: Block size used for KV cache memory allocation max_num_reqs: Maximum number of concurrent requests supported. max_num_blocks_per_req: Maximum number of blocks per request. max_num_batched_tokens: Maximum number of tokens in a batch. pin_memory: Whether to pin memory for faster GPU transfers. device: Target device for the block table. + kernel_block_size: The block_size of underlying attention kernel. + Will be the same as `block_size` if `block_size` is supported + by the attention kernel. """ self.max_num_reqs = max_num_reqs self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device - # Handle hybrid block system if kernel_block_size == block_size: - # No splitting - use kv_manager_block_size size directly + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping self.block_size = block_size self.blocks_per_phys_block = 1 self.use_hybrid_blocks = False else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks if block_size % kernel_block_size != 0: raise ValueError( f"kernel_block_size {kernel_block_size} must divide " @@ -67,7 +83,6 @@ def __init__( self.max_num_blocks_per_req = max_num_blocks_per_req * \ self.blocks_per_phys_block - # Use CpuGpuBuffer for unified memory management self.block_table = self._make_buffer(self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32) @@ -76,7 +91,6 @@ def __init__( self.slot_mapping = self._make_buffer(self.max_num_batched_tokens, dtype=torch.int64) - # Pre-compute bias array for physical to logical block conversion if self.use_hybrid_blocks: self._bias_array = np.arange(0, self.blocks_per_phys_block).reshape( @@ -101,8 +115,7 @@ def append_row( return if self.use_hybrid_blocks: - block_ids = self._convert_physical_to_logical_blocks( - np.array(block_ids)) + block_ids = self._map_to_kernel_blocks(np.array(block_ids)) num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] @@ -115,9 +128,8 @@ def add_row(self, block_ids: list[int], row_idx: int) -> None: def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] - self.block_table.np[tgt, :num_blocks] = self.block_table.np[ - src, :num_blocks] - self.num_blocks_per_row[tgt] = num_blocks + block_table_np = self.block_table.np + block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks] def swap_row(self, src: int, tgt: int) -> None: src_tgt, tgt_src = [src, tgt], [tgt, src] @@ -183,8 +195,8 @@ def clear(self) -> None: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) - def _convert_physical_to_logical_blocks( - self, kv_manager_block_id: np.ndarray) -> np.ndarray: + def _map_to_kernel_blocks(self, + kv_manager_block_id: np.ndarray) -> np.ndarray: """Convert kv_manager_block_id IDs to logical block IDs. Example: From 1a52e564ce80e2ab56655f44ce7a7e2897dc9420 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Sun, 21 Sep 2025 08:02:12 +0000 Subject: [PATCH 23/37] fix part of reviews Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 40 ++++++++++++++--------------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index df9b1096a06d..1d6c21d721fa 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -153,14 +153,12 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - logical_block_idx = positions // virtual_block_size - - # Account for the expanded logical table + # Account for the expanded kernel table # (always needed with unified tensor) - # Each kv_manager_block_size is split into multiple logical blocks - # The logical table has been expanded to accommodate this + # Each kv_manager_block_size is split into multiple kernel blocks + # The kernel table has been expanded to accommodate this block_table_indices = (req_indices * self.max_num_blocks_per_req + - logical_block_idx) + positions // virtual_block_size) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local @@ -175,9 +173,8 @@ def compute_slot_mapping(self, req_indices: np.ndarray, self.slot_mapping.np[:req_indices.shape[0]] = np.where( mask, slot_mapping, -1) else: - logical_block_idx = positions // self.block_size block_table_indices = (req_indices * self.max_num_blocks_per_req + - logical_block_idx) + positions // self.block_size) block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size @@ -196,28 +193,28 @@ def clear(self) -> None: self.block_table.cpu.fill_(0) def _map_to_kernel_blocks(self, - kv_manager_block_id: np.ndarray) -> np.ndarray: - """Convert kv_manager_block_id IDs to logical block IDs. + kv_manager_block_ids: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_id IDs to kernel block IDs. Example: - # kv_manager_block_id: 32 tokens, + # kv_manager_block_ids: 32 tokens, # Kernel block size: 16 tokens # blocks_per_phys_block = 2 - >>> kv_manager_block_id = np.array([0, 1, 2]) + >>> kv_manager_block_ids = np.array([0, 1, 2]) >>> Result: [0, 1, 2, 3, 4, 5] - # Each kv_manager_block_id maps to 2 logical block id: - # kv_manager_block_id 0 → Logical block id [0, 1] - # kv_manager_block_id 1 → Logical block id [2, 3] - # kv_manager_block_id 2 → Logical block id [4, 5] + # Each kv_manager_block_id maps to 2 kernel block id: + # kv_manager_block_id 0 → kernel block id [0, 1] + # kv_manager_block_id 1 → kernel block id [2, 3] + # kv_manager_block_id 2 → kernel block id [4, 5] """ if not self.use_hybrid_blocks: - return kv_manager_block_id + return kv_manager_block_ids - logical_block_id = kv_manager_block_id.reshape( + kernel_block_ids = kv_manager_block_ids.reshape( -1, 1) * self.blocks_per_phys_block + self._bias_array - return logical_block_id.reshape(-1) + return kernel_block_ids.reshape(-1) def get_device_tensor(self, num_reqs: int) -> torch.Tensor: """Returns the device tensor of the block table.""" @@ -261,16 +258,11 @@ def __init__(self, # DCP might not be initialized in testing dcp_world_size = 1 - if kernel_block_sizes is None: - # Use kv_manager_block_size size by default - kernel_block_sizes = block_sizes - if len(kernel_block_sizes) != len(block_sizes): raise ValueError( f"kernel_block_sizes length ({len(kernel_block_sizes)}) " f"must match block_sizes length ({len(block_sizes)})") - # Use zip to pair block_sizes with kernel_block_sizes one-to-one self.block_tables = [ BlockTable( block_size, max_num_reqs, From bbe2200127ffa697b5272802017c10e72e55bc3b Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 22 Sep 2025 08:03:13 +0000 Subject: [PATCH 24/37] clean imps Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 61 +++++++++++++----------------- 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e5a703c0ef61..4595f15191ae 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3541,47 +3541,40 @@ def _select_kernel_block_size(self, kv_manager_block_size: int, Returns: The selected kernel block size (largest available) + - return kv_manager_block_size if supported + - otherwise, return max supported block size that is a factor + of kv_manager_block_size Raises: ValueError: If no valid kernel block size can be found that satisfies the backend's constraints """ - supported_constraints = backend_cls.get_supported_block_size() - selected_kernel_size = kv_manager_block_size - constraint_satisfied = False - valid_constraints = [] - - for constraint in supported_constraints: - if (isinstance(constraint, int) - and kv_manager_block_size % constraint == 0): - valid_constraints.append(constraint) - elif (isinstance(constraint, MultipleOf) - and kv_manager_block_size % constraint.base == 0): - valid_constraints.append(constraint.base) - - if valid_constraints: - selected_kernel_size = max(valid_constraints) - constraint_satisfied = True - - if not constraint_satisfied and supported_constraints: - # Only raise error if there are actual constraints to satisfy - # and none of them were met - constraint_strs = [] - for constraint in supported_constraints: - if isinstance(constraint, int): - constraint_strs.append(f"{constraint}") - elif isinstance(constraint, MultipleOf): - constraint_strs.append(f"multiple of {constraint.base}") + supported_block_size = backend_cls.get_supported_block_size() + + # if kv_manager_block_size is supported by the attention backend, + # return it. + for block_size in supported_block_size: + if isinstance(block_size, int): + if kv_manager_block_size == block_size: + return kv_manager_block_size + elif (isinstance(block_size, MultipleOf) + and kv_manager_block_size % block_size.base == 0): + return kv_manager_block_size + + # Otherwise, we can't find a valid block_size from the + # `MultipleOf`-style candidates. So find the largest one from + # the `int`-style candidates. + compatible_sizes = [ + block_size for block_size in supported_block_size + if isinstance(block_size, int) and kv_manager_block_size % + block_size == 0 + ] + if not compatible_sizes: raise ValueError( - f"Physical block size {kv_manager_block_size} does not " - f"satisfy any constraints for {backend_cls.__name__} " - f"backend. Supported constraints: " - f"{', '.join(constraint_strs)}. " - f"The physical block size must be compatible with at least " - f"one constraint.") - - return selected_kernel_size + f"No compatible block size for {kv_manager_block_size}") + + return max(compatible_sizes) def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: From df485c39355c9b851adb5daeab7a7cc1c5d0cac9 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 22 Sep 2025 08:34:08 +0000 Subject: [PATCH 25/37] refactor: extract kernel block size logic into separate function Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 66 +++++++++++++++++++----------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4595f15191ae..0a9272fe26e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3594,30 +3594,7 @@ def may_reinitialize_input_batch(self, ] # Generate kernel_block_sizes that matches each block_size - # For attention backends that support virtual block splitting, - # use the supported block sizes from the backend - # For other backends (like Mamba), use [0] (no splitting) - kernel_block_sizes = [] - for kv_cache_group_id, kv_cache_group in enumerate( - kv_cache_config.kv_cache_groups): - if isinstance(kv_cache_group.kv_cache_spec, - EncoderOnlyAttentionSpec): - continue - elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): - # This is an attention backend that supports virtual - # block splitting. Get the supported block sizes from - # the backend. - attn_groups = self.attn_groups[kv_cache_group_id] - kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size - backend_cls = attn_groups[0].backend - selected_kernel_size = self._select_kernel_block_size( - kv_manager_block_size, backend_cls) - kernel_block_sizes.append(selected_kernel_size) - else: - # This is likely Mamba or other non-attention cache, - # no splitting. - kernel_block_sizes.append( - kv_cache_group.kv_cache_spec.block_size) + kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) if block_sizes != [ self.cache_config.block_size @@ -3685,6 +3662,47 @@ def _kv_cache_spec_attn_group_iterator( yield self.kv_cache_config.kv_cache_groups[ kv_cache_spec_id].kv_cache_spec, attn_group + def _prepare_kernel_block_sizes( + self, kv_cache_config: KVCacheConfig) -> list[int]: + """ + Generate kernel_block_sizes that matches each block_size. + + For attention backends that support virtual block splitting, + use the supported block sizes from the backend. + For other backends (like Mamba), use the same block size (no splitting). + + Args: + kv_cache_config: The KV cache configuration. + + Returns: + list[int]: List of kernel block sizes for each cache group. + """ + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # the backend. + attn_groups = self.attn_groups[kv_cache_group_id] + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size + backend_cls = attn_groups[0].backend + selected_kernel_size = self._select_kernel_block_size( + kv_manager_block_size, backend_cls) + kernel_block_sizes.append(selected_kernel_size) + elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append( + kv_cache_group.kv_cache_spec.block_size) + else: + raise NotImplementedError( + f"unknown kv cache spec {kv_cache_group.kv_cache_spec}") + return kernel_block_sizes + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, From f70aefa8e5631074f31107f5c95711e535090485 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 22 Sep 2025 08:39:18 +0000 Subject: [PATCH 26/37] fix flashinfer Signed-off-by: lizhiyuan --- vllm/v1/attention/backends/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index d4cb724a84ae..f5267f2cfa78 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -145,7 +145,7 @@ def get_supported_head_sizes(cls) -> list[int]: @staticmethod def get_supported_block_size() -> list[Union[int, MultipleOf]]: - return [MultipleOf(16)] + return [MultipleOf(1)] @classmethod def validate_head_size(cls, head_size: int) -> None: From beee4d314bef69b195aabbb3e4eb490794aff947 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 22 Sep 2025 08:53:13 +0000 Subject: [PATCH 27/37] fix tests Signed-off-by: lizhiyuan --- tests/v1/worker/test_gpu_model_runner.py | 31 ++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 2bd7e588b013..3dd83a71432f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -844,11 +844,12 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): def test_hybrid_block_table_initialization(): - """Test hybrid block table with different kernel and physical block + """Test hybrid block table with different kernel and kvcache_manager block sizes.""" from vllm.v1.worker.block_table import BlockTable - # Test configuration: physical block size = 32, kernel block size = 16 + # Test configuration: kvcache_manager block size = 32, + # kernel block size = 16 block_size = 32 kernel_block_sizes = [16] max_num_reqs = 10 @@ -870,22 +871,22 @@ def test_hybrid_block_table_initialization(): block_size // kernel_block_sizes[0]) # Changed to use first element # Test block table conversion logic - # One physical block should map to multiple logical blocks - physical_blocks = [0, 1, 2] + # One kvcache_manager block should map to multiple kernel blocks + kvcache_manager_blocks = [0, 1, 2] - # Verify that physical blocks can be converted to logical blocks + # Verify that kvcache_manager blocks can be converted to kernel blocks # and that block table operations work correctly. req_index = 0 - block_table.append_row(physical_blocks, req_index) - # Get expected logical blocks from the implementation for verification. - expected_logical_blocks = block_table._convert_physical_to_logical_blocks( - np.array(physical_blocks)) + block_table.append_row(kvcache_manager_blocks, req_index) + # Get expected kernel blocks from the implementation for verification. + expected_kernel_blocks = block_table._map_to_kernel_blocks( + np.array(kvcache_manager_blocks)) # Verify block table state assert block_table.num_blocks_per_row[req_index] == len( - expected_logical_blocks) + expected_kernel_blocks) assert np.array_equal( - block_table.block_table.np[req_index, :len(expected_logical_blocks)], - expected_logical_blocks) + block_table.block_table.np[req_index, :len(expected_kernel_blocks)], + expected_kernel_blocks) def test_input_batch_with_kernel_block_sizes(): @@ -929,7 +930,7 @@ def test_hybrid_cache_integration(model_runner, dist_init): # Create a new model runner with hybrid cache configuration vllm_config = get_vllm_config() - # Configure hybrid cache with different physical block size + # Configure hybrid cache with different kvcache_manager block size vllm_config.cache_config.block_size = 32 model_config = vllm_config.model_config @@ -942,7 +943,7 @@ def test_hybrid_cache_integration(model_runner, dist_init): # Initialize KV cache with configuration attn_spec = FullAttentionSpec( - block_size=16, # Use logical block size directly + block_size=16, # Use kernel block size directly num_kv_heads=runner.model_config.get_num_kv_heads( runner.parallel_config), head_size=runner.model_config.get_head_size(), @@ -972,7 +973,7 @@ def test_hybrid_cache_integration(model_runner, dist_init): block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], - kernel_block_sizes=[16]) # Use logical block size + kernel_block_sizes=[16]) # Use kernel block size runner.initialize_attn_backend(kv_cache_config) From adba4a59ade8365eeb6a2446528fd08fe6f81a75 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 24 Sep 2025 04:01:00 +0000 Subject: [PATCH 28/37] minor fixs Signed-off-by: lizhiyuan --- vllm/v1/worker/block_table.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 1d6c21d721fa..b4139930d175 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -25,25 +25,7 @@ def __init__( device: torch.device, kernel_block_size: int, ): - """Manages KV cache block allocation and token-to-block mapping for - efficient inference. - - The BlockTable manages the relationship between token positions and - their corresponding memory blocks in the KV cache, supporting flexible - block size configurations to optimize both memory usage and - computational efficiency. It implements a hybrid block system that - bridges potential differences between memory allocation granularity - and kernel computation requirements. - - Key functionality: - - Maps token positions to KV cache memory blocks for efficient lookup - - Handles hybrid block configurations when allocation and computation - sizes differ - - Manages slot mappings for batched processing of multiple requests - - Provides efficient GPU/CPU buffer management for block metadata - - Supports distributed processing with DCP (Distributed Context - Parallelism) - + """ Args: block_size: Block size used for KV cache memory allocation max_num_reqs: Maximum number of concurrent requests supported. @@ -54,7 +36,7 @@ def __init__( kernel_block_size: The block_size of underlying attention kernel. Will be the same as `block_size` if `block_size` is supported by the attention kernel. - """ + """ self.max_num_reqs = max_num_reqs self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory @@ -153,10 +135,6 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - # Account for the expanded kernel table - # (always needed with unified tensor) - # Each kv_manager_block_size is split into multiple kernel blocks - # The kernel table has been expanded to accommodate this block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // virtual_block_size) From 279d1d0576557538f1116d510433bc8380b6e537 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 24 Sep 2025 04:03:47 +0000 Subject: [PATCH 29/37] fix kernel block size for attn groups Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 108 ++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 32 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a3ad5cf43d33..e37c4d75f9d0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3621,51 +3621,96 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i - def _select_kernel_block_size(self, kv_manager_block_size: int, - backend_cls: type[AttentionBackend]) -> int: + def _find_compatible_block_sizes( + self, + kv_manager_block_size: int, + backend_cls: type[AttentionBackend], + return_all: bool = False) -> Union[int, list[int]]: """ - Select the optimal kernel block size for a given physical block size. + Find compatible block sizes for a backend. Args: - kv_manager_block_size: The physical block size of the KV cache - backend_cls: The attention backend class + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + return_all: Return all compatible sizes if True, max size if False Returns: - The selected kernel block size (largest available) - - return kv_manager_block_size if supported - - otherwise, return max supported block size that is a factor - of kv_manager_block_size + Compatible block size(s) based on return_all parameter Raises: - ValueError: If no valid kernel block size can be found that - satisfies the backend's constraints + ValueError: If no compatible block size found """ supported_block_size = backend_cls.get_supported_block_size() + compatible_sizes = [] - # if kv_manager_block_size is supported by the attention backend, - # return it. for block_size in supported_block_size: if isinstance(block_size, int): - if kv_manager_block_size == block_size: - return kv_manager_block_size + if kv_manager_block_size % block_size == 0: + compatible_sizes.append(block_size) elif (isinstance(block_size, MultipleOf) and kv_manager_block_size % block_size.base == 0): - return kv_manager_block_size - - # Otherwise, we can't find a valid block_size from the - # `MultipleOf`-style candidates. So find the largest one from - # the `int`-style candidates. - compatible_sizes = [ - block_size for block_size in supported_block_size - if isinstance(block_size, int) and kv_manager_block_size % - block_size == 0 - ] + compatible_sizes.append(kv_manager_block_size) if not compatible_sizes: raise ValueError( f"No compatible block size for {kv_manager_block_size}") - return max(compatible_sizes) + return compatible_sizes if return_all else max(compatible_sizes) + + def _get_all_compatible_block_sizes( + self, kv_manager_block_size: int, + backend_cls: type[AttentionBackend]) -> list[int]: + """ + Get all compatible block sizes for a backend. + + Args: + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + + Returns: + List of all compatible block sizes in descending order + + Raises: + ValueError: If no compatible block size found + """ + compatible_sizes = self._find_compatible_block_sizes( + kv_manager_block_size, backend_cls, return_all=True) + + return sorted(list(set(compatible_sizes)), reverse=True) + + def _select_common_block_size(self, kv_manager_block_size: int, + attn_groups: list[AttentionGroup]) -> int: + """ + Select common block size for all backends. + + Args: + kv_manager_block_size: Physical block size of KV cache + attn_groups: List of attention groups + + Returns: + Largest block size supported by all backends + + Raises: + ValueError: If no common block size found + """ + all_backend_supports = [] + + for attn_group in attn_groups: + supported_sizes = self._get_all_compatible_block_sizes( + kv_manager_block_size, attn_group.backend) + all_backend_supports.append(set(supported_sizes)) + + common_supported_sizes = set.intersection(*all_backend_supports) + + if not common_supported_sizes: + error_msg = (f"No common block size for {kv_manager_block_size}. ") + for i, attn_group in enumerate(attn_groups): + supported = all_backend_supports[i] + error_msg += (f"Backend {attn_group.backend} supports: " + f"{sorted(supported)}. ") + raise ValueError(error_msg) + + return max(common_supported_sizes) def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: @@ -3774,12 +3819,11 @@ def _prepare_kernel_block_sizes( elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): # This is an attention backend that supports virtual # block splitting. Get the supported block sizes from - # the backend. + # all backends in the group. attn_groups = self.attn_groups[kv_cache_group_id] kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size - backend_cls = attn_groups[0].backend - selected_kernel_size = self._select_kernel_block_size( - kv_manager_block_size, backend_cls) + selected_kernel_size = self._select_common_block_size( + kv_manager_block_size, attn_groups) kernel_block_sizes.append(selected_kernel_size) elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec): # This is likely Mamba or other non-attention cache, @@ -3822,8 +3866,8 @@ def _reshape_kv_cache_tensors( if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_manager_block_size = kv_cache_spec.block_size - logical_kernel_size = self._select_kernel_block_size( - kv_manager_block_size, attn_backend) + logical_kernel_size = self._find_compatible_block_sizes( + kv_manager_block_size, attn_backend, return_all=False) num_blocks_per_phys_block = (kv_manager_block_size // logical_kernel_size) logical_num_blocks = num_blocks * num_blocks_per_phys_block From 7cb4fc33a14f7e351cc8e7194c86f5f4decb26d3 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 24 Sep 2025 04:24:31 +0000 Subject: [PATCH 30/37] fix lint Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 840c70082bce..96a309ec228a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3629,11 +3629,10 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i - def _find_compatible_block_sizes( - self, - kv_manager_block_size: int, - backend_cls: type[AttentionBackend], - return_all: bool = False) -> Union[int, list[int]]: + def _find_compatible_block_sizes(self, + kv_manager_block_size: int, + backend_cls: type[AttentionBackend], + return_all: bool = False) -> list[int]: """ Find compatible block sizes for a backend. @@ -3663,7 +3662,7 @@ def _find_compatible_block_sizes( raise ValueError( f"No compatible block size for {kv_manager_block_size}") - return compatible_sizes if return_all else max(compatible_sizes) + return compatible_sizes if return_all else [max(compatible_sizes)] def _get_all_compatible_block_sizes( self, kv_manager_block_size: int, @@ -3874,8 +3873,10 @@ def _reshape_kv_cache_tensors( if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_manager_block_size = kv_cache_spec.block_size - logical_kernel_size = self._find_compatible_block_sizes( + logical_kernel_size_list = \ + self._find_compatible_block_sizes( kv_manager_block_size, attn_backend, return_all=False) + logical_kernel_size = logical_kernel_size_list[0] num_blocks_per_phys_block = (kv_manager_block_size // logical_kernel_size) logical_num_blocks = num_blocks * num_blocks_per_phys_block From 585f2bff2162aaa96cf4a21683ab99607a19e5d4 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 24 Sep 2025 05:42:04 +0000 Subject: [PATCH 31/37] fix test Signed-off-by: lizhiyuan --- tests/v1/worker/test_gpu_input_batch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 98700ff73fd1..99793ab95581 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -236,6 +236,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -327,6 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -336,6 +338,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] From 413272b8d3560b8dadc3a787447863ed98c960a8 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Wed, 24 Sep 2025 11:40:04 +0000 Subject: [PATCH 32/37] rename Signed-off-by: lizhiyuan --- vllm/model_executor/models/config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index fa51d0fb8e93..1ab735886b63 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -381,15 +381,16 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # * Other MLA backends: 64-byte alignment if model_config.use_mla: use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") - block_alignment_bytes = 128 if use_cutlass_mla else 64 + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 else: - block_alignment_bytes = 16 + kernel_block_alignment_size = 16 # Calculate minimum attention block size that satisfies both: - # 1. Backend alignment requirements (block_alignment_bytes) + # 1. Backend alignment requirements (kernel_block_alignment_size) # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) - attn_block_size = block_alignment_bytes * cdiv( - mamba_page_size, block_alignment_bytes * attn_page_size_1_token) + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, + kernel_block_alignment_size * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it From a51673b89731c7ad0801ce04854128a4ebce4c2d Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Thu, 25 Sep 2025 03:25:10 +0000 Subject: [PATCH 33/37] fix mla issue Signed-off-by: lizhiyuan --- vllm/platforms/cuda.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 28fa383cca59..c21abead80b5 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -163,17 +163,17 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: from vllm.attention.ops.flashmla import is_flashmla_supported if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + and cache_config.block_size % 64 != 0: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - if use_cutlass_mla and cache_config.block_size != 128: + if use_cutlass_mla and cache_config.block_size % 128 != 0: cache_config.block_size = 128 logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") - if use_flashinfer_mla and cache_config.block_size not in [32, 64]: + if use_flashinfer_mla and cache_config.block_size % 64 != 0: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashInferMLA " From d865f00c46a17d1130ca9672c010ed294e872997 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Fri, 26 Sep 2025 09:53:10 +0000 Subject: [PATCH 34/37] fix flashinfer mla Signed-off-by: lizhiyuan --- vllm/platforms/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index de1a3f9159af..f817d7015fe8 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -173,7 +173,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") - if use_flashinfer_mla and cache_config.block_size % 64 != 0: + if use_flashinfer_mla and cache_config.block_size != 32 and \ + cache_config.block_size % 64 != 0: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashInferMLA " From dd7bfc8625f5fb7eb175498c1cb61abfa714f191 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Fri, 26 Sep 2025 09:59:21 +0000 Subject: [PATCH 35/37] try to use default block_size Signed-off-by: lizhiyuan --- vllm/v1/worker/gpu_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1c22b068fdc0..90eb9169067f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3695,7 +3695,8 @@ def _select_common_block_size(self, kv_manager_block_size: int, attn_groups: List of attention groups Returns: - Largest block size supported by all backends + Block size supported by all backends, + prioritizing cache_config.block_size Raises: ValueError: If no common block size found @@ -3717,6 +3718,9 @@ def _select_common_block_size(self, kv_manager_block_size: int, f"{sorted(supported)}. ") raise ValueError(error_msg) + if self.cache_config.block_size in common_supported_sizes: + return self.cache_config.block_size + return max(common_supported_sizes) def may_reinitialize_input_batch(self, From c9231e8808e1c1137d99d73165b87928e23c78a2 Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Sat, 27 Sep 2025 11:45:53 +0000 Subject: [PATCH 36/37] fix bugs in mla Signed-off-by: lizhiyuan --- vllm/platforms/cuda.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f3341b0cd713..88cad70740d4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -242,10 +242,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) - and block_size == 128) + and block_size % 128 == 0) use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None and cls.is_device_capability(100) - and block_size in [32, 64]) + selected_backend is None and cls.is_device_capability(100) and + (block_size == 32 or block_size % 64 == 0)) use_flashmla = selected_backend == _Backend.FLASHMLA or ( selected_backend is None and is_flashmla_supported()[0]) use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( @@ -265,7 +265,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return ("vllm.v1.attention.backends.mla." "flashinfer_mla.FlashInferMLABackend") if use_flashmla: - if block_size != 64: + if block_size % 64 != 0: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", From 248dbd596db0b0d6f13697d296298a81efc8228f Mon Sep 17 00:00:00 2001 From: lizhiyuan Date: Mon, 29 Sep 2025 05:08:04 +0000 Subject: [PATCH 37/37] fix corner case Signed-off-by: lizhiyuan --- vllm/model_executor/models/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 812478ed1983..894700d942d4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -361,6 +361,12 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes + # Model may be marked as is_hybrid + # but mamba is skipped via config, + # return directly + if mamba_page_size == 0: + return + # Attention backend constraints: # - FlashAttention (FA) requires block size to be multiple of 16 # - MLA (Multi-head Latent Attention) requires larger alignment: @@ -382,6 +388,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. + # if (cache_config.block_size is None or cache_config.block_size < attn_block_size): cache_config.block_size = attn_block_size