diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3cf9d9369676..37b4f9a08e40 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -14,10 +14,11 @@ MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor -from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - get_block_hash, get_group_id, +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock, get_block_hash, + get_group_id, get_request_block_hasher, hash_block_tokens, init_none_hash, make_block_hash_with_group_id) @@ -138,7 +139,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) # Check full block metadata parent_block_hash = None @@ -171,7 +172,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -207,7 +208,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([6], ) + assert blocks is not None and blocks.get_block_ids() == ([6], ) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. 
@@ -227,7 +228,9 @@ def test_prefill(hash_fn): len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. - assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) + assert blocks is not None and blocks.get_block_ids() == ([ + 7, 8, 9, 10, 4, 5, 6, 3, 2, 1 + ], ) assert free_block_queue.num_free_blocks == 0 assert (free_block_queue.fake_free_list_head.next_free_block @@ -261,8 +264,9 @@ def test_prefill_hybrid_model(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, - 8], [9, 10, 11, 12]) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [ + 5, 6, 7, 8 + ], [9, 10, 11, 12]) # Check full block metadata parent_block_hash = None @@ -298,7 +302,7 @@ def test_prefill_hybrid_model(): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([13], [14], [15]) + assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: if block != manager.block_pool.null_block: @@ -309,14 +313,15 @@ def test_prefill_hybrid_model(): manager.free(req1) cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block) + manager.block_pool.cached_block_hash_to_block._cache) - def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], + def test_partial_request_hit(request_id: str, + hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids, block_size, sha256) for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block.pop( + manager.block_pool.cached_block_hash_to_block._cache.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert 
len(req.block_hashes) == 3 @@ -324,7 +329,7 @@ def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block[ + manager.block_pool.cached_block_hash_to_block._cache[ hash_with_group_id] = cached_block_hash_to_block_bak[ hash_with_group_id] manager.free(req) @@ -362,7 +367,8 @@ def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], # total cache miss. # The cache hit length of full attention is 1 * block_size. # The cache hit length of sliding window is 2 * block_size. - # Then it is cache miss as the two type of layers have different hit length. + # Then it is cache miss as the two type of layers + # have different hit length. test_partial_request_hit("8", [ make_block_hash_with_group_id(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[0], 1), @@ -406,7 +412,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata @@ -441,7 +447,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -478,6 +484,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in 
blocks.blocks[0]] == req0_block_hashes @@ -513,7 +520,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -558,7 +565,8 @@ def test_evict(): blocks = manager.allocate_slots(req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial + # 5 full + 1 partial + assert blocks is not None and len(blocks.blocks[0]) == 6 # 3 blocks. req1 = make_request("1", list(range(last_token_id, @@ -570,7 +578,7 @@ def test_evict(): blocks = manager.allocate_slots(req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 3 # 3 full blocks + assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -592,7 +600,7 @@ def test_evict(): blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([10], ) + assert blocks is not None and blocks.get_block_ids() == ([10], ) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -617,7 +625,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 # Deallocate the block. 
manager.free(req) @@ -631,7 +639,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert manager.block_pool.blocks[blocks.blocks[0] [0].block_id].block_hash is None @@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. @@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 # Free the blocks. @@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled(): blocks = manager.allocate_slots(req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 3 + assert blocks is not None and len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled(): blocks = manager.allocate_slots(req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 4 + assert blocks is not None and len(blocks.blocks[0]) == 4 # New requests should not have any blocks. 
req3 = make_request("3", list(range(4)), block_size, sha256) @@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn): assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) - # Test that blocks that don't start from the beginning are cached correctly. + # Test that blocks that don't start from the beginning are cached + # correctly. blocks += [KVCacheBlock(block_id=2)] block_pool.cache_full_blocks( request=req, @@ -1101,7 +1110,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -1112,7 +1121,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) # Failed to reset prefix cache because some blocks are not freed yet. 
assert not manager.reset_prefix_cache() @@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block(): # Manually add all blocks to cached_blocks for block, block_hash in zip(pool.blocks, block_hashes): block.block_hash = block_hash - pool.cached_block_hash_to_block[block_hash][block.block_id] = block + pool.cached_block_hash_to_block.insert(block_hash, block) block0, block1, block2, block3 = pool.blocks - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block0.block_id: block0, - block3.block_id: block3 - }, - block_hash1: { - block1.block_id: block1 + block3.block_id: block3, }, - block_hash2: { - block2.block_id: block2 - } + block_hash1: block1, + block_hash2: block2, } # Evict block1 pool._maybe_evict_cached_block(block1) - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block0.block_id: block0, block3.block_id: block3 }, - block_hash2: { - block2.block_id: block2 - } + block_hash2: block2, } # Evict block0: block_hash0 entry should NOT be removed, as block3 # also use the same hash pool._maybe_evict_cached_block(block0) - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block3.block_id: block3 }, - block_hash2: { - block2.block_id: block2 - } + block_hash2: block2, } # Evict block2 pool._maybe_evict_cached_block(block2) - assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}} + assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}} # Evict block3 pool._maybe_evict_cached_block(block3) - assert pool.cached_block_hash_to_block == {} + assert pool.cached_block_hash_to_block._cache == {} @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) @@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window(): # Evict the first block in the request assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None - 
manager.block_pool.cached_block_hash_to_block.pop( + manager.block_pool.cached_block_hash_to_block._cache.pop( make_block_hash_with_group_id(block_hash_first_block, 0)) # New request @@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window(): # there will be no matched prefix. assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 + + +def test_block_lookup_cache_single_block_per_key(): + cache = BlockHashToBlockMap() + key0 = BlockHashWithGroupId(b"hash0") + key1 = BlockHashWithGroupId(b"hash1") + key2 = BlockHashWithGroupId(b"hash2") + block0 = KVCacheBlock(0) + block1 = KVCacheBlock(1) + + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + # key0 inserted + cache.insert(key0, block0) + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + # key1 inserted + cache.insert(key1, block1) + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # No block popped due to block_id mismatch + assert cache.pop(key0, 100) is None + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # block popped with (key0, block ID 0) + assert cache.pop(key0, 0) is block0 + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # No block popped due to block_id mismatch + assert cache.pop(key0, 1) is None + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # block popped with (key1, block ID 1) + assert cache.pop(key1, 1) is block1 + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + + +def test_block_lookup_cache_multi_blocks_per_key(): + 
cache = BlockHashToBlockMap() + key0 = BlockHashWithGroupId(b"hash0") + key1 = BlockHashWithGroupId(b"hash1") + block00 = KVCacheBlock(0) + block01 = KVCacheBlock(1) + block10 = KVCacheBlock(10) + block11 = KVCacheBlock(11) + + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is None + + cache.insert(key0, block00) + cache.insert(key0, block01) + cache.insert(key1, block10) + cache.insert(key1, block11) + + assert cache.get_one_block(key0) is block00 + assert cache.pop(key0, 0) is block00 + assert cache.get_one_block(key0) is block01 + assert cache.pop(key0, 1) is block01 + assert cache.get_one_block(key0) is None + assert cache.pop(key0, 2) is None + + assert cache.get_one_block(key1) is block10 + assert cache.pop(key1, 10) is block10 + assert cache.get_one_block(key1) is block11 + assert cache.pop(key1, 11) is block11 + assert cache.get_one_block(key1) is None + assert cache.pop(key1, 12) is None diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b70850a9bcff..01b54ae56e90 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -47,16 +47,15 @@ def run_one_case(block_is_cached, tail_token, expect_length): BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[ - make_block_hash_with_group_id(block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10]) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -112,16 +111,15 @@ def run_one_case(block_is_cached, 
expect_length): BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[ - make_block_hash_with_group_id(block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10]) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d1e1c1c8d038..3cc738304821 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable -from typing import Optional +from typing import Any, Optional, Union from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, BlockRemoved, BlockStored, @@ -19,6 +18,103 @@ logger = init_logger(__name__) +class BlockHashToBlockMap: + """ + Cache of blocks that are used for prefix caching. It caches blocks + from hash directly to a block or multiple blocks + (i.e. {block_hash: KVCacheBlocks}) + - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks + would simply be a KVCacheBlock. + - Otherwise, KVCacheBlocks is a dict from {block_id: KVCacheBlock} + + A cached block is a full block with a block hash that can be used + for prefix caching. + The cached block may be used by running requests or in the + free_block_queue that could potentially be evicted. 
+ + NOTE #1: We currently don't de-duplicate the blocks in the cache, + meaning that if a block becomes full and is cached, we don't check + if there is already an identical block in the cache. This is because + we want to make sure the allocated block IDs won't change so that + block tables are append-only. + NOTE #2: The union type is introduced in order to reduce GC costs + from the inner dict. + """ + + def __init__(self): + self._cache: dict[BlockHashWithGroupId, + Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {} + + def get_one_block(self, + key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + """ + Gets any block with the given block hash key. + """ + blocks = self._cache.get(key) + if blocks is not None: + if isinstance(blocks, KVCacheBlock): + return blocks + if isinstance(blocks, dict): + return next(iter(blocks.values())) + self._unexpected_blocks_type(blocks) + return None + + def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: + """ + Inserts the KVCacheBlock into the cache + """ + blocks = self._cache.get(key) + if blocks is None: + # When key is not found, attach a single block to the key + self._cache[key] = block + elif isinstance(blocks, KVCacheBlock): + # If there's a block with the same key, merge the original block + # and the new block into a dict + self._cache[key] = {blocks.block_id: blocks, block.block_id: block} + elif isinstance(blocks, dict): + # If it's already a dict, simply insert the block + blocks[block.block_id] = block + else: + self._unexpected_blocks_type(blocks) + + def pop(self, key: BlockHashWithGroupId, + block_id: int) -> Optional[KVCacheBlock]: + """ + Checks if block_hash exists and pops block_id from the cache + """ + blocks = self._cache.pop(key, None) + if blocks is None: + # block_hash not found in the cache + return None + # TODO(Jialin): If key is found, block_id should always be present + # in blocks. We currently keep the original behaviour for safety.
+ # + # Will add block_id == blocks.block_id assertion and + # use del blocks[block_id] instead as followup. + if isinstance(blocks, KVCacheBlock): + if blocks.block_id == block_id: + return blocks + # If the single block ID doesn't match, we should put the + # block back (it should happen rarely) + self._cache[key] = blocks + return None + if isinstance(blocks, dict): + # Try to pop block_id from the block dict, and if the dict still + # contains blocks, put it back in the cache. + block = blocks.pop(block_id, None) + if len(blocks) > 0: + self._cache[key] = blocks + return block + self._unexpected_blocks_type(blocks) + return None + + def __len__(self) -> int: + return len(self._cache) + + def _unexpected_blocks_type(self, blocks: Any) -> None: + raise AssertionError(f"Invalid KV cache block type {type(blocks)}") + + class BlockPool: """BlockPool that manages KVCacheBlocks. It provides methods to allocate, free and cache the kv cache blocks. The @@ -51,17 +147,9 @@ def __init__( # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[ - int, KVCacheBlock]] = defaultdict(dict) + # Cache for block lookup + self.cached_block_hash_to_block: BlockHashToBlockMap = \ + BlockHashToBlockMap() # To represent a placeholder block with block_id=0. 
# The ref_cnt of null_block is not maintained, needs special care to @@ -90,12 +178,11 @@ def get_cached_block( for group_id in kv_cache_group_ids: block_hash_with_group_id = make_block_hash_with_group_id( block_hash, group_id) - cached_blocks_one_group = self.cached_block_hash_to_block.get( + block = self.cached_block_hash_to_block.get_one_block( block_hash_with_group_id) - if not cached_blocks_one_group: + if not block: return None - first_block = next(iter(cached_blocks_one_group.values())) - cached_blocks.append(first_block) + cached_blocks.append(block) return cached_blocks def cache_full_blocks( @@ -140,8 +227,8 @@ def cache_full_blocks( block_hash_with_group_id = make_block_hash_with_group_id( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block[block_hash_with_group_id][ - blk.block_id] = blk + self.cached_block_hash_to_block.insert(block_hash_with_group_id, + blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) @@ -211,15 +298,14 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: if block_hash is None: # The block doesn't have hash, eviction is not needed return False - blocks_by_id = self.cached_block_hash_to_block.get(block_hash) - if blocks_by_id is None: - # block_hash not found in cached_block_hash_to_block, + + if self.cached_block_hash_to_block.pop(block_hash, + block.block_id) is None: + # block not found in cached_block_hash_to_block, # eviction is not needed return False + block.reset_hash() - blocks_by_id.pop(block.block_id, None) - if len(blocks_by_id) == 0: - del self.cached_block_hash_to_block[block_hash] if self.enable_kv_cache_events: # FIXME (Chen): Not sure whether we should return `hash_value` @@ -283,7 +369,7 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. 
- self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = BlockHashToBlockMap() # Remove all hashes from all blocks. for block in self.blocks: