3 changes: 2 additions & 1 deletion vllm/v1/core/block_pool.py

@@ -10,6 +10,7 @@
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                          ExternalBlockHash,
                                          FreeKVCacheBlockQueue, KVCacheBlock,
+                                         SingleTypeKVCacheBlocks,
                                          get_block_hash,
                                          make_block_hash_with_group_id,
                                          maybe_convert_block_hash)
@@ -319,7 +320,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
                                  medium=MEDIUM_GPU))
         return True
 
-    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
+    def touch(self, blocks: tuple[SingleTypeKVCacheBlocks, ...]) -> None:
         """Touch a block increases its reference count by 1, and may remove
        the block from the free queue. This is used when a block is hit by
        another request with the same prefix.
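For readers skimming the diff: `touch` is the path a prefix-cache hit takes to pin shared blocks. Below is a minimal sketch of the semantics described in the docstring, with a hypothetical `TinyPool` and simplified refcounting; the real `BlockPool` uses a doubly linked `FreeKVCacheBlockQueue` and also manages block hashes.

```python
# Hypothetical, simplified sketch of touch(): bump refcounts and pull
# newly shared blocks out of the free queue so they cannot be evicted.
from dataclasses import dataclass

@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0

class TinyPool:
    def __init__(self, num_blocks: int) -> None:
        # Blocks with ref_cnt == 0 live in the free queue and may be evicted.
        self.free_queue: list[Block] = [Block(i) for i in range(num_blocks)]

    def touch(self, blocks: tuple[list[Block], ...]) -> None:
        for group in blocks:
            for block in group:
                # A zero-refcount block is still in the free queue; remove it
                # so it cannot be evicted while another request shares it.
                if block.ref_cnt == 0:
                    self.free_queue.remove(block)
                block.ref_cnt += 1

pool = TinyPool(num_blocks=4)
hit = pool.free_queue[0]
pool.touch(([hit],))
assert hit.ref_cnt == 1 and hit not in pool.free_queue
```

The signature change to `tuple[SingleTypeKVCacheBlocks, ...]` only widens the accepted container type; the loop body is unchanged.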
9 changes: 5 additions & 4 deletions vllm/v1/core/kv_cache_coordinator.py

@@ -4,7 +4,8 @@
 from typing import Optional
 
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
+                                         SingleTypeKVCacheBlocks)
 from vllm.v1.core.single_type_kv_cache_manager import (
     CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -46,7 +47,7 @@ def __init__(
 
     def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
                                    new_computed_blocks: tuple[
-                                        list[KVCacheBlock], ...],
+                                        SingleTypeKVCacheBlocks, ...],
                                    num_encoder_tokens: int) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
@@ -69,15 +70,15 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
                 # For cross-attention, we issue a single static allocation
                 # of blocks based on the number of encoder input tokens.
                 num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                    request_id, num_encoder_tokens, [])
+                    request_id, num_encoder_tokens, ())
             else:
                 num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                     request_id, num_tokens, new_computed_blocks[i])
         return num_blocks_to_allocate
 
     def save_new_computed_blocks(
             self, request_id: str,
-            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
+            new_computed_blocks: tuple[SingleTypeKVCacheBlocks, ...]) -> None:
         """
         Add the new computed blocks to the request.
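The `[]` → `()` swap above is not cosmetic: once the parameter is typed as `Sequence[KVCacheBlock]`, the immutable empty tuple is a valid argument and costs nothing to pass. A standalone illustration (the `takes_blocks` helper is hypothetical) of the behavior under CPython:

```python
# Why the call site passes () instead of []: under a Sequence[...] annotation
# both containers are valid, but CPython reuses a single empty-tuple object,
# so () allocates nothing, while [] builds a fresh list on every call.
from collections.abc import Sequence

def takes_blocks(blocks: Sequence[int]) -> int:
    # Callers only need len() and iteration, which Sequence guarantees.
    return len(blocks)

assert takes_blocks(()) == 0       # shared empty tuple, no per-call allocation
assert takes_blocks([1, 2]) == 2   # lists keep working unchanged

a, b = (), ()
assert a is b  # CPython interns the empty tuple
```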
45 changes: 28 additions & 17 deletions vllm/v1/core/kv_cache_manager.py

@@ -7,7 +7,7 @@
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
-from vllm.v1.core.kv_cache_utils import KVCacheBlock
+from vllm.v1.core.kv_cache_utils import KVCacheBlock, SingleTypeKVCacheBlocks
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -22,7 +22,7 @@ class KVCacheBlocks:
     Scheduler and KVCacheManager, to hide KVCacheManager's internal data
     structure from the Scheduler.
     """
-    blocks: tuple[list[KVCacheBlock], ...]
+    blocks: tuple[SingleTypeKVCacheBlocks, ...]
     """
     `blocks[i][j]` refers to the i-th kv_cache_group
     and the j-th block of tokens. We don't use block of
@@ -35,8 +35,9 @@ class KVCacheBlocks:
     def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
         """Adds two KVCacheBlocks instances."""
         return KVCacheBlocks(
-            tuple(blk1 + blk2
-                  for blk1, blk2 in zip(self.blocks, other.blocks)))
+            tuple(
+                list(blk1) + list(blk2)
+                for blk1, blk2 in zip(self.blocks, other.blocks)))
@@ -78,8 +79,10 @@ def get_unhashed_block_ids(self) -> list[int]:
         ]
 
     def new_empty(self) -> "KVCacheBlocks":
-        """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
+        """
+        Creates a new KVCacheBlocks instance with no blocks.
+        """
+        return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))))
@@ -130,6 +133,12 @@ def __init__(
         self.block_pool = self.coordinator.block_pool
         self.kv_cache_config = kv_cache_config
 
+        # Pre-constructed KVCacheBlocks with no blocks; callers should use
+        # this via create_kv_cache_blocks instead of creating new ones to
+        # avoid GC overhead.
+        self.empty_kv_cache_blocks = KVCacheBlocks(
+            tuple(() for _ in range(self.num_kv_cache_groups)))
+
     @property
     def usage(self) -> float:
         """Get the KV cache usage.
@@ -169,7 +178,7 @@ def get_computed_blocks(self,
         if (not self.enable_caching
                 or (request.sampling_params is not None
                     and request.sampling_params.prompt_logprobs is not None)):
-            return self.create_empty_block_list(), 0
+            return self.empty_kv_cache_blocks, 0
 
         # NOTE: When all tokens hit the cache, we must recompute the last token
         # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
@@ -196,7 +205,8 @@ def get_computed_blocks(self,
         self.prefix_cache_stats.queries += request.num_tokens
         self.prefix_cache_stats.hits += num_new_computed_tokens
 
-        return KVCacheBlocks(computed_blocks), num_new_computed_tokens
+        return (self.create_kv_cache_blocks(computed_blocks),
+                num_new_computed_tokens)
 
     def allocate_slots(
         self,
@@ -249,8 +259,7 @@ def allocate_slots(
         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
-            new_computed_block_list = tuple(
-                [] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
+            new_computed_block_list = self.empty_kv_cache_blocks.blocks
 
         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
@@ -299,7 +308,7 @@ def allocate_slots(
         # P/D: delay caching blocks if we have to recv from
         # remote. Update state for locally cached blocks.
         if not self.enable_caching or delay_cache_blocks:
-            return KVCacheBlocks(new_blocks)
+            return self.create_kv_cache_blocks(new_blocks)
 
         # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
         # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
@@ -309,7 +318,7 @@ def allocate_slots(
                                               request.num_tokens)
         self.coordinator.cache_blocks(request, num_tokens_to_cache)
 
-        return KVCacheBlocks(new_blocks)
+        return self.create_kv_cache_blocks(new_blocks)
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
@@ -390,7 +399,8 @@ def take_events(self) -> list[KVCacheEvent]:
 
     def get_blocks(self, request_id: str) -> KVCacheBlocks:
         """Get the blocks of a request."""
-        return KVCacheBlocks(self.coordinator.get_blocks(request_id))
+        return self.create_kv_cache_blocks(
+            self.coordinator.get_blocks(request_id))
 
     def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
@@ -401,7 +411,8 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         if self.enable_caching:
             self.coordinator.cache_blocks(request, num_computed_tokens)
 
-    def create_empty_block_list(self) -> KVCacheBlocks:
-        """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks(tuple([]
-                                   for _ in range(self.num_kv_cache_groups)))
+    def create_kv_cache_blocks(
+            self, blocks: tuple[list[KVCacheBlock], ...]) -> KVCacheBlocks:
+        # Only create new KVCacheBlocks for non-empty blocks.
+        return KVCacheBlocks(blocks) if any(
+            blocks) else self.empty_kv_cache_blocks
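`create_kv_cache_blocks` is the crux of the PR: all-empty inputs are redirected to the one pre-built `empty_kv_cache_blocks`, so steady-state scheduling stops churning short-lived `KVCacheBlocks` objects. A simplified, hypothetical sketch of the pattern, with integers standing in for blocks:

```python
# Hypothetical sketch of the shared-empty-instance pattern used above:
# build the empty container once, return it whenever the input holds no
# blocks, and only allocate a wrapper for genuinely non-empty inputs.
from collections.abc import Sequence
from dataclasses import dataclass

@dataclass(frozen=True)
class Blocks:
    blocks: tuple[Sequence[int], ...]

class Manager:
    def __init__(self, num_groups: int) -> None:
        # One immutable empty instance, constructed up front.
        self.empty = Blocks(tuple(() for _ in range(num_groups)))

    def create(self, blocks: tuple[list[int], ...]) -> Blocks:
        # any(blocks) is False iff every per-group list is empty, so the
        # all-empty case returns the shared object and allocates nothing.
        return Blocks(blocks) if any(blocks) else self.empty

mgr = Manager(num_groups=2)
assert mgr.create(([], [])) is mgr.empty       # common case: no allocation
assert mgr.create(([1], [])) is not mgr.empty  # real blocks get a new wrapper
```

Freezing the shared instance (here via `frozen=True`) is what makes handing out one object to many callers safe; the per-group empty tuples are immutable for the same reason.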
10 changes: 10 additions & 0 deletions vllm/v1/core/kv_cache_utils.py

@@ -9,6 +9,8 @@
 from dataclasses import dataclass
 from typing import Any, Callable, NewType, Optional, Union
 
+from typing_extensions import TypeAlias
+
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -213,6 +215,14 @@ def __repr__(self) -> str:
                 f"next_free_block={next_block_id})")
 
 
+# Represents the KVCacheBlocks associated with a request.
+# It can be represented as:
+# - a list[KVCacheBlock] when the request has one or more KVCacheBlocks
+# - an empty tuple for requests without any KVCacheBlock (a precomputed
+#   empty KVCacheBlocks is kept in KVCacheManager to avoid GC overhead)
+SingleTypeKVCacheBlocks: TypeAlias = Sequence[KVCacheBlock]
+
+
 class FreeKVCacheBlockQueue:
     """This class organizes a list of KVCacheBlock objects to a doubly linked
     list of free blocks. We implement this class instead of using Python
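A self-contained sketch of what the `Sequence`-based alias buys (the `count_blocks` helper is hypothetical): one annotation that accepts both the mutable per-request lists and the shared immutable empty tuple.

```python
# Both list and tuple are Sequences, so a single annotation covers the
# mutable per-request lists and the shared immutable empty tuple.
from collections.abc import Sequence
from typing_extensions import TypeAlias

class KVCacheBlock:  # stand-in for the real KVCacheBlock dataclass
    def __init__(self, block_id: int) -> None:
        self.block_id = block_id

SingleTypeKVCacheBlocks: TypeAlias = Sequence[KVCacheBlock]

def count_blocks(blocks: SingleTypeKVCacheBlocks) -> int:
    # Callers in this PR only need len() and iteration.
    return len(blocks)

assert count_blocks([KVCacheBlock(0), KVCacheBlock(1)]) == 2  # list: OK
assert count_blocks(()) == 0                                  # empty tuple: OK
```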
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py

@@ -407,7 +407,7 @@ def schedule(self) -> SchedulerOutput:
                 # after async KV recvs are completed.
                 else:
                     new_computed_blocks = (
-                        self.kv_cache_manager.create_empty_block_list())
+                        self.kv_cache_manager.empty_kv_cache_blocks)
                     num_new_local_computed_tokens = 0
                     num_computed_tokens = request.num_computed_tokens
24 changes: 19 additions & 5 deletions vllm/v1/core/single_type_kv_cache_manager.py

@@ -6,7 +6,8 @@
 
 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
+                                         SingleTypeKVCacheBlocks)
 from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                         CrossAttentionSpec, FullAttentionSpec,
                                         KVCacheSpec, MambaSpec,
@@ -58,7 +59,7 @@ def __init__(
 
     def get_num_blocks_to_allocate(
             self, request_id: str, num_tokens: int,
-            new_computed_blocks: list[KVCacheBlock]) -> int:
+            new_computed_blocks: SingleTypeKVCacheBlocks) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
@@ -87,7 +88,7 @@ def get_num_blocks_to_allocate(
 
     def save_new_computed_blocks(
             self, request_id: str,
-            new_computed_blocks: list[KVCacheBlock]) -> None:
+            new_computed_blocks: SingleTypeKVCacheBlocks) -> None:
         """
         Add the new computed blocks to the request.
@@ -564,9 +565,22 @@ def get_num_common_prefix_blocks(self, request_id: str,
 
     def get_num_blocks_to_allocate(
             self, request_id: str, num_tokens: int,
-            new_computed_blocks: list[KVCacheBlock]) -> int:
+            new_computed_blocks: SingleTypeKVCacheBlocks) -> int:
         # Allocate extra `num_speculative_blocks` blocks for
         # speculative decoding (MTP/EAGLE) with linear attention.
+        """
+        Get the number of blocks needed to be allocated for the request.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+            new_computed_blocks: The new computed blocks that just hit the
+                prefix cache.
+
+        Returns:
+            The number of blocks to allocate.
+        """
         assert isinstance(self.kv_cache_spec, MambaSpec)
         if self.kv_cache_spec.num_speculative_blocks > 0:
             num_tokens += (self.kv_cache_spec.block_size *
@@ -590,7 +604,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
 
     def save_new_computed_blocks(
             self, request_id: str,
-            new_computed_blocks: list[KVCacheBlock]) -> None:
+            new_computed_blocks: SingleTypeKVCacheBlocks) -> None:
         # We do not cache blocks for cross-attention to be shared between
         # requests, so `new_computed_blocks` should always be empty.
         assert len(new_computed_blocks) == 0
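As a worked example of the speculative-block accounting in the Mamba manager above, assuming tokens are rounded up to whole blocks with `cdiv` as elsewhere in this file (the concrete numbers are illustrative):

```python
# Illustrative numbers: inflating num_tokens by block_size *
# num_speculative_blocks before rounding reserves whole extra blocks.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, same contract as vllm.utils.cdiv.
    return -(a // -b)

block_size = 16
num_speculative_blocks = 2  # extra blocks reserved for MTP/EAGLE drafts
num_tokens = 100

# Without speculation: cdiv(100, 16) == 7 blocks.
num_tokens += block_size * num_speculative_blocks  # 100 + 32 == 132
assert cdiv(num_tokens, block_size) == 9           # two extra blocks reserved
```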