Changes from all commits (46 commits)
05ef7f1
[Hybrid]: Decouple Logical Block Size from Physical Page Size
zhiyuan1i Sep 15, 2025
0d18668
Change kernel_sizes parameter to kernel_block_size in BlockTable
zhiyuan1i Sep 15, 2025
bded2b4
fix condition for may_reinitialize_input_batch
zhiyuan1i Sep 15, 2025
90c14ab
update shapes when getting kv_cache_shape in ModelRunner
zhiyuan1i Sep 15, 2025
613f4c6
minor fix
zhiyuan1i Sep 15, 2025
edfdf8d
Fix embedding test
zhiyuan1i Sep 15, 2025
28e94eb
making max_num_blocks_per_req to save max number of logical blocks
zhiyuan1i Sep 15, 2025
b1d3dcc
change design
zhiyuan1i Sep 15, 2025
e10d70a
clean codes
zhiyuan1i Sep 15, 2025
2ce97c4
clean imps
zhiyuan1i Sep 16, 2025
0909efd
Revert "Fix embedding test"
zhiyuan1i Sep 16, 2025
097c11c
remove conditions since attn_groups won't be empty
zhiyuan1i Sep 16, 2025
0e6ae07
fix conditions
zhiyuan1i Sep 16, 2025
3fd0727
find largest block_size to reduce overhead
zhiyuan1i Sep 16, 2025
ff983af
fix default block size
zhiyuan1i Sep 16, 2025
e869bf0
fix embedding test
zhiyuan1i Sep 16, 2025
3bb83b9
change params
zhiyuan1i Sep 16, 2025
8a7c2b6
clean unused branch
zhiyuan1i Sep 16, 2025
9620fe0
fix typos
zhiyuan1i Sep 17, 2025
5fe1e95
Merge remote-tracking branch 'upstream/main' into hybrid-cache-groups
zhiyuan1i Sep 17, 2025
ddbaebb
default block_size 16
zhiyuan1i Sep 17, 2025
698b55e
fix lint
zhiyuan1i Sep 17, 2025
e013093
fix part of reviews
zhiyuan1i Sep 21, 2025
1a52e56
fix part of reviews
zhiyuan1i Sep 21, 2025
bbe2200
clean imps
zhiyuan1i Sep 22, 2025
df485c3
refactor: extract kernel block size logic into separate function
zhiyuan1i Sep 22, 2025
29f9d30
Merge remote-tracking branch 'upstream/main' into hybrid-cache-groups
zhiyuan1i Sep 22, 2025
f70aefa
fix flashinfer
zhiyuan1i Sep 22, 2025
beee4d3
fix tests
zhiyuan1i Sep 22, 2025
40d7b95
Merge remote-tracking branch 'upstream/main' into hybrid-cache-groups
zhiyuan1i Sep 23, 2025
1710a7a
Merge branch 'main' into hybrid-cache-groups
zhiyuan1i Sep 23, 2025
5820c10
Merge branch 'main' into hybrid-cache-groups
zhiyuan1i Sep 23, 2025
adba4a5
minor fixs
zhiyuan1i Sep 24, 2025
279d1d0
fix kernel block size for attn groups
zhiyuan1i Sep 24, 2025
5d328d2
Merge remote-tracking branch 'upstream/main' into hybrid-cache-groups
zhiyuan1i Sep 24, 2025
7cb4fc3
fix lint
zhiyuan1i Sep 24, 2025
585f2bf
fix test
zhiyuan1i Sep 24, 2025
413272b
rename
zhiyuan1i Sep 24, 2025
a51673b
fix mla issue
zhiyuan1i Sep 25, 2025
5691f12
Merge remote-tracking branch 'upstream/main' into hybrid-cache-groups
zhiyuan1i Sep 25, 2025
d865f00
fix flashinfer mla
zhiyuan1i Sep 26, 2025
dd7bfc8
try to use default block_size
zhiyuan1i Sep 26, 2025
6a97abb
Merge branch 'main' into hybrid-cache-groups
zhiyuan1i Sep 26, 2025
74e0ff1
Merge branch 'main' into hybrid-cache-groups
zhiyuan1i Sep 27, 2025
c9231e8
fix bugs in mla
zhiyuan1i Sep 27, 2025
248dbd5
fix corner case
zhiyuan1i Sep 29, 2025
3 changes: 3 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -236,6 +236,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@@ -327,6 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -336,6 +338,7 @@
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
kernel_block_sizes=[1],
)

reqs: list[CachedRequestState] = []
152 changes: 152 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -62,6 +62,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
runner.initialize_attn_backend(kv_cache_config)

@@ -838,3 +841,152 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
conv_blocks_constant)
assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
ssm_blocks_constant)


def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""
from vllm.v1.worker.block_table import BlockTable

# Test configuration: kvcache_manager block size = 32,
# kernel block size = 16
block_size = 32
kernel_block_sizes = [16]
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512

block_table = BlockTable(block_size=block_size,
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0])

# Verify hybrid block configuration
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_block_sizes[0]
assert block_table.blocks_per_phys_block == (
block_size // kernel_block_sizes[0]) # Changed to use first element

# Test block table conversion logic
# One kvcache_manager block should map to multiple kernel blocks
kvcache_manager_blocks = [0, 1, 2]

# Verify that kvcache_manager blocks can be converted to kernel blocks
# and that block table operations work correctly.
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks))
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(
expected_kernel_blocks)
assert np.array_equal(
block_table.block_table.np[req_index, :len(expected_kernel_blocks)],
expected_kernel_blocks)


def test_input_batch_with_kernel_block_sizes():
"""Test InputBatch initialization with kernel_block_sizes parameter."""
max_num_reqs = 10
max_model_len = 512
max_num_batched_tokens = 512
device = torch.device(DEVICE)
pin_memory = False
vocab_size = 50272

# Test with different kernel block sizes
block_sizes = [32, 64]
kernel_block_sizes = [16, 32]

input_batch = InputBatch(max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
device=device,
pin_memory=pin_memory,
vocab_size=vocab_size,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes)

# Verify that block tables were created with kernel block sizes
assert len(input_batch.block_table.block_tables) == len(block_sizes)

for i, (phys_size,
kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
block_table = input_batch.block_table.block_tables[i]
if phys_size != kernel_size:
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_size
else:
assert block_table.use_hybrid_blocks is False
assert block_table.block_size == kernel_size


def test_hybrid_cache_integration(model_runner, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()

# Configure hybrid cache with different kvcache_manager block size
vllm_config.cache_config.block_size = 32

model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context[
"layer.0"] = Attention(num_heads, head_size, 0.1)

runner = GPUModelRunner(vllm_config, DEVICE)

# Initialize KV cache with configuration
attn_spec = FullAttentionSpec(
block_size=16, # Use kernel block size directly
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
],
)
runner.kv_cache_config = kv_cache_config

# Initialize input batch with kernel block sizes
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
kernel_block_sizes=[16]) # Use kernel block size

runner.initialize_attn_backend(kv_cache_config)

# Verify hybrid block table configuration
block_table = runner.input_batch.block_table.block_tables[0]
assert block_table.block_size == (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size)

# Test request processing with hybrid blocks
req_id = "hybrid_req_0"
scheduler_output = _schedule_new_request(req_id)

# Update states should work with hybrid blocks
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)
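
The hybrid block-table tests above rely on one kv-cache-manager (logical) block expanding into several contiguous kernel blocks. Below is a minimal sketch of that mapping, assuming the contiguous-expansion rule implied by `blocks_per_phys_block`; the standalone helper is illustrative and not the `_map_to_kernel_blocks` method itself.

```python
import numpy as np


def map_to_kernel_blocks(manager_block_ids: np.ndarray,
                         manager_block_size: int,
                         kernel_block_size: int) -> np.ndarray:
    """Expand kv-cache-manager block IDs into kernel block IDs.

    Assumes each manager block of size `manager_block_size` is backed by
    `manager_block_size // kernel_block_size` contiguous kernel blocks.
    """
    assert manager_block_size % kernel_block_size == 0
    ratio = manager_block_size // kernel_block_size
    # Manager block i covers kernel blocks [i * ratio, (i + 1) * ratio).
    return (manager_block_ids[:, None] * ratio + np.arange(ratio)).reshape(-1)


# Mirrors the test configuration: manager block size 32, kernel block size 16.
print(map_to_kernel_blocks(np.array([0, 1, 2]), 32, 16))
# -> [0 1 2 3 4 5]
```
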
19 changes: 18 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
Union)

import torch

@@ -24,6 +25,13 @@ class AttentionType:
"""Attention between dec. Q and enc. K/V for encoder-decoder."""


class MultipleOf:
base: int

def __init__(self, base: int):
self.base = base


class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
@@ -54,6 +62,10 @@ def get_impl_cls() -> Type["AttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError

@classmethod
def get_supported_block_size(cls) -> list[Union[int, MultipleOf]]:
return cls.get_impl_cls().get_supported_block_size()

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@@ -154,6 +166,11 @@ def __init__(
) -> None:
raise NotImplementedError

@staticmethod
def get_supported_block_size() -> list[Union[int, MultipleOf]]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]

@abstractmethod
def forward(
self,
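
A sketch of how a caller might resolve the `list[Union[int, MultipleOf]]` returned by `get_supported_block_size()` into a concrete kernel block size. The `MultipleOf` class below mirrors the one added in this diff; the selection policy in `pick_kernel_block_size` is an assumption for illustration, not the logic this PR adds.

```python
from typing import Union


class MultipleOf:
    def __init__(self, base: int):
        self.base = base


def pick_kernel_block_size(manager_block_size: int,
                           supported: list[Union[int, MultipleOf]]) -> int:
    """Choose a kernel block size compatible with the manager block size."""
    for entry in supported:
        if isinstance(entry, MultipleOf):
            # Any multiple of `base` works, so the manager block size itself
            # is usable as long as it is aligned to the base.
            if manager_block_size % entry.base == 0:
                return manager_block_size
        elif manager_block_size % entry == 0:
            # Exact size that evenly divides the manager block size.
            return entry
    raise ValueError(
        f"No supported kernel block size for manager block size "
        f"{manager_block_size}: {supported}")


# FlashAttention-style constraint (MultipleOf(16)) with a 32-token manager
# block keeps 32; a CUTLASS-MLA-style constraint ([128]) resolves to 128.
assert pick_kernel_block_size(32, [MultipleOf(16)]) == 32
assert pick_kernel_block_size(128, [128]) == 128
```
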
30 changes: 24 additions & 6 deletions vllm/model_executor/models/config.py
@@ -361,16 +361,34 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
block_size=model_config.max_model_len,
).page_size_bytes

# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
# Model may be marked as is_hybrid
Review comment (Collaborator): which model?

# but mamba is skipped via config,
# return directly
if mamba_page_size == 0:
return

# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
Review comment (Member) on lines +375 to +379: Is there a reason we can't get this info from the attention backend directly?

Review comment (Member) on lines +378 to +379: I think we also need to take into account FlashInfer here. If FlashInfer uses TRTLLM (the default on Blackwell), then the block size cannot be more than 128. Is there any way we could get this info from the attention backend directly? Or can we not, because we are at init time here?

# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size,
kernel_block_alignment_size * attn_page_size_1_token)

# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
#
if (cache_config.block_size is None
or cache_config.block_size < attn_block_size):
cache_config.block_size = attn_block_size
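
A worked example of the block-size calculation above; the page-size values are hypothetical and chosen only to show the arithmetic for the non-MLA (16-aligned) case.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division (vLLM provides an equivalent cdiv helper).
    return -(a // -b)


# Hypothetical sizes for illustration only.
mamba_page_size = 41_728           # bytes needed per mamba state page
attn_page_size_1_token = 576       # bytes of attention KV cache per token
kernel_block_alignment_size = 16   # non-MLA case (FlashAttention alignment)

# Smallest multiple of the alignment whose attention page is at least as
# large as the mamba page: 16 * ceil(41728 / (16 * 576)) = 16 * 5 = 80.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token)
assert attn_block_size == 80
assert attn_block_size * attn_page_size_1_token >= mamba_page_size
```
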
21 changes: 13 additions & 8 deletions vllm/platforms/cuda.py
@@ -128,7 +128,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:

# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
if model_config is not None and model_config.use_mla:
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# and doesn't need to be reinitialized here
Review comment (Member) on lines +131 to +133: This statement is true for hybrid models only, right?

if model_config is not None and model_config.use_mla \
and cache_config.block_size is not None:
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
@@ -159,17 +163,18 @@

from vllm.attention.ops.flashmla import is_flashmla_supported
if use_flashmla and is_flashmla_supported()[0] \
and cache_config.block_size != 64:
and cache_config.block_size % 64 != 0:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")

if use_cutlass_mla and cache_config.block_size != 128:
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.")

if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
if use_flashinfer_mla and cache_config.block_size != 32 and \
cache_config.block_size % 64 != 0:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA "
@@ -237,10 +242,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,

use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size == 128)
and block_size % 128 == 0)
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size in [32, 64])
selected_backend is None and cls.is_device_capability(100) and
(block_size == 32 or block_size % 64 == 0))
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
Expand All @@ -260,7 +265,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return ("vllm.v1.attention.backends.mla."
"flashinfer_mla.FlashInferMLABackend")
if use_flashmla:
if block_size != 64:
if block_size % 64 != 0:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
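
A condensed sketch of the relaxed MLA block-size checks in this hunk: the exact-equality tests become divisibility tests, so the kv-cache-manager block size only needs to be a multiple of what each kernel supports. The helper below is illustrative only; it is not a function in vllm/platforms/cuda.py.

```python
def mla_backend_accepts(backend: str, block_size: int) -> bool:
    """Mirror the relaxed checks: the manager block size only has to be a
    multiple of the kernel block size each MLA backend supports."""
    if backend == "CUTLASS_MLA":
        return block_size % 128 == 0
    if backend == "FLASHINFER_MLA":
        return block_size == 32 or block_size % 64 == 0
    if backend == "FLASHMLA":
        return block_size % 64 == 0
    return True  # other backends: no block-size constraint modelled here


# A 128-token manager block now satisfies both CUTLASS_MLA (128 % 128 == 0)
# and FLASHMLA (128 % 64 == 0), whereas the old equality checks rejected
# anything other than the single hard-coded size.
assert mla_backend_accepts("CUTLASS_MLA", 128)
assert mla_backend_accepts("FLASHMLA", 128)
assert not mla_backend_accepts("FLASHINFER_MLA", 96)
```
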
7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -2,14 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import numpy as np
import torch

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
MultipleOf,
is_quantized_kv_cache)
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
@@ -47,6 +48,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@staticmethod
def get_supported_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
6 changes: 5 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -17,7 +17,7 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
AttentionType, MultipleOf)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -154,6 +154,10 @@ def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]

@staticmethod
def get_supported_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(1)]
Review comment (Member) on lines +158 to +159: I think we need to take into account here whether FlashInfer is using TRTLLM (it has two different paths). When TRTLLM is enabled there are additional constraints on the block size (it can't be larger than 128).

@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
5 changes: 5 additions & 0 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -8,6 +8,7 @@

import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
MultipleOf,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -39,6 +40,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]:
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder

@staticmethod
def get_supported_block_size() -> list[Union[int, MultipleOf]]:
return [128]
Review comment (Member) on lines +44 to +45: Do I understand correctly that this means the block size has to be exactly 128? Or can it be a multiple? If the latter, should we return MultipleOf(128)?

Review comment (Member): I ask because the changes in platforms/cuda.py imply that it can be a multiple of 128.


class SM100Workspace:
