From f2c599559bcbbb6ef9e237635aacb6e2a552a440 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Thu, 19 Dec 2024 14:47:26 +0000
Subject: [PATCH] Clean up and consolidate LRUCache

Signed-off-by: DarkLight1337
---
 vllm/adapter_commons/models.py                |  9 ++-
 .../tokenizer_group/tokenizer_group.py        |  2 +-
 vllm/utils.py                                 | 59 ++++++++-----------
 vllm/v1/engine/mm_input_mapper.py             |  6 +-
 vllm/v1/utils.py                              | 25 --------
 5 files changed, 34 insertions(+), 67 deletions(-)

diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py
index a5c04ab78fbe..468904c90fff 100644
--- a/vllm/adapter_commons/models.py
+++ b/vllm/adapter_commons/models.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+from typing import Any, Callable, Dict, Optional, TypeVar
 
 from torch import nn
 
@@ -24,14 +24,13 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
 T = TypeVar('T')
 
 
-class AdapterLRUCache(LRUCache[T]):
+class AdapterLRUCache(LRUCache[int, T]):
 
-    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
-                                                               None]):
+    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: int, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 761b07f34d2f..95a8f7098bba 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -22,7 +22,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
-        self.lora_tokenizers = LRUCache[AnyTokenizer](
+        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
             capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
 
     @classmethod
diff --git a/vllm/utils.py b/vllm/utils.py
index ba567feb1979..3934903385ad 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -21,14 +21,13 @@
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import UserDict, defaultdict
+from collections import OrderedDict, UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Hashable, List, Literal,
-                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
-                    overload)
+                    Optional, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np
@@ -154,10 +153,12 @@
 }
 
 P = ParamSpec('P')
-K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")
 
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
 
 class _Sentinel:
     ...
@@ -190,50 +191,48 @@ def reset(self) -> None:
         self.counter = 0
 
 
-class LRUCache(Generic[T]):
+class LRUCache(Generic[_K, _V]):
 
-    def __init__(self, capacity: int):
-        self.cache: OrderedDict[Hashable, T] = OrderedDict()
-        self.pinned_items: Set[Hashable] = set()
+    def __init__(self, capacity: int) -> None:
+        self.cache = OrderedDict[_K, _V]()
+        self.pinned_items = set[_K]()
         self.capacity = capacity
 
-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: _K) -> bool:
         return key in self.cache
 
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: _K) -> _V:
         value = self.cache[key]  # Raise KeyError if not exists
         self.cache.move_to_end(key)
         return value
 
-    def __setitem__(self, key: Hashable, value: T) -> None:
+    def __setitem__(self, key: _K, value: _V) -> None:
         self.put(key, value)
 
-    def __delitem__(self, key: Hashable) -> None:
+    def __delitem__(self, key: _K) -> None:
         self.pop(key)
 
-    def touch(self, key: Hashable) -> None:
+    def touch(self, key: _K) -> None:
         self.cache.move_to_end(key)
 
-    def get(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
-        value: Optional[T]
+    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
+        value: Optional[_V]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
-            value = default_value
+            value = default
         return value
 
-    def put(self, key: Hashable, value: T) -> None:
+    def put(self, key: _K, value: _V) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def pin(self, key: Hashable) -> None:
+    def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from
         being evicted in the LRU order.
@@ -242,13 +241,13 @@ def pin(self, key: Hashable) -> None:
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)
 
-    def _unpin(self, key: Hashable) -> None:
+    def _unpin(self, key: _K) -> None:
         self.pinned_items.remove(key)
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass
 
-    def remove_oldest(self, remove_pinned=False):
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
         if not self.cache:
             return
 
@@ -262,17 +261,15 @@ def remove_oldest(self, remove_pinned=False):
                                    "cannot remove oldest from the cache.")
         else:
             lru_key = next(iter(self.cache))
-        self.pop(lru_key)
+        self.pop(lru_key)  # type: ignore
 
     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
+    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
         run_on_remove = key in self.cache
-        value: Optional[T] = self.cache.pop(key, default_value)
+        value = self.cache.pop(key, default)
         # remove from pinned items
         if key in self.pinned_items:
             self._unpin(key)
@@ -280,7 +277,7 @@ def pop(self,
             self._on_remove(key, value)
         return value
 
-    def clear(self):
+    def clear(self) -> None:
         while len(self.cache) > 0:
             self.remove_oldest(remove_pinned=True)
         self.cache.clear()
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
-
 def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
     """
     Unlike :class:`itertools.groupby`, groups are not broken by
diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py
index bba71c29cc10..d03da8209310 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_mapper.py
@@ -8,7 +8,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.v1.utils import LRUDictCache
+from vllm.utils import LRUCache
 
 logger = init_logger(__name__)
 
@@ -44,7 +44,7 @@ def __init__(
 
         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
     def process_inputs(
         self,
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 5f327d706683..e802c6439b74 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
@@ -102,27 +101,3 @@ def make_zmq_socket(
 
     finally:
         ctx.destroy(linger=0)
-
-
-K = TypeVar('K')
-V = TypeVar('V')
-
-
-class LRUDictCache(Generic[K, V]):
-
-    def __init__(self, size: int):
-        self.cache: OrderedDict[K, V] = OrderedDict()
-        self.size = size
-
-    def get(self, key: K, default=None) -> V:
-        if key not in self.cache:
-            return default
-
-        self.cache.move_to_end(key)
-        return self.cache[key]
-
-    def put(self, key: K, value: V):
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        if len(self.cache) > self.size:
-            self.cache.popitem(last=False)
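
Usage note (illustrative, not part of the patch): `LRUCache` is now generic over both key and value types, so call sites parameterize it directly (`LRUCache[int, AnyTokenizer]`, `LRUCache[str, MultiModalKwargs]`) instead of going through the removed `LRUDictCache`. The sketch below exercises the consolidated interface as defined in `vllm/utils.py` after this change; the `TokenizerStub` class, the `EvictionLogger` subclass, and the capacity of 2 are hypothetical and only mirror the `AdapterLRUCache` pattern.

# Minimal sketch of the consolidated LRUCache[_K, _V]; assumes the patched
# vllm.utils is importable. TokenizerStub and EvictionLogger are hypothetical.
from typing import Optional

from vllm.utils import LRUCache


class TokenizerStub:
    """Hypothetical stand-in for a per-LoRA tokenizer object."""

    def __init__(self, lora_id: int) -> None:
        self.lora_id = lora_id


class EvictionLogger(LRUCache[int, TokenizerStub]):
    """Hypothetical subclass in the style of AdapterLRUCache: hook evictions."""

    def _on_remove(self, key: int, value: Optional[TokenizerStub]) -> None:
        print(f"evicting entry for id {key}")


cache = EvictionLogger(capacity=2)
cache.put(1, TokenizerStub(1))
cache[2] = TokenizerStub(2)     # __setitem__ delegates to put()
cache.pin(1)                    # pinned keys are skipped by remove_oldest()

cache.put(3, TokenizerStub(3))  # over capacity: evicts key 2, not the pinned key 1
assert 1 in cache and 3 in cache and 2 not in cache
assert cache.get(2) is None     # a miss returns the (renamed) `default` argument

Binding `_K` to `Hashable` preserves the old behaviour of accepting any hashable key, while letting call sites narrow the key type, which is what allows `AdapterLRUCache` to declare `int` keys above.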