9 changes: 4 additions & 5 deletions vllm/adapter_commons/models.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from typing import Any, Callable, Dict, Optional, TypeVar

from torch import nn

@@ -24,14 +24,13 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):
class AdapterLRUCache(LRUCache[int, T]):

def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn

def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: int, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
2 changes: 1 addition & 1 deletion vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -22,7 +22,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[AnyTokenizer](
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

@classmethod
59 changes: 26 additions & 33 deletions vllm/utils.py
@@ -21,14 +21,13 @@
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
Optional, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4

import numpy as np
@@ -154,10 +153,12 @@
}

P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")

_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")


class _Sentinel:
...
@@ -190,50 +191,48 @@ def reset(self) -> None:
self.counter = 0


class LRUCache(Generic[T]):
class LRUCache(Generic[_K, _V]):

def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
self.pinned_items = set[_K]()
self.capacity = capacity

def __contains__(self, key: Hashable) -> bool:
def __contains__(self, key: _K) -> bool:
return key in self.cache

def __len__(self) -> int:
return len(self.cache)

def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value

def __setitem__(self, key: Hashable, value: T) -> None:
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)

def __delitem__(self, key: Hashable) -> None:
def __delitem__(self, key: _K) -> None:
self.pop(key)

def touch(self, key: Hashable) -> None:
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)

def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
value: Optional[_V]
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
value = default
return value

def put(self, key: Hashable, value: T) -> None:
def put(self, key: _K, value: _V) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()

def pin(self, key: Hashable) -> None:
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
@@ -242,13 +241,13 @@ def pin(self, key: Hashable) -> None:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)

def _unpin(self, key: Hashable) -> None:
def _unpin(self, key: _K) -> None:
self.pinned_items.remove(key)

def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass

def remove_oldest(self, remove_pinned=False):
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache:
return

@@ -262,25 +261,23 @@ def remove_oldest(self, remove_pinned=False):
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
Collaborator:
Since the mm hash cache doesn't need the pin logic, could we optimize this function for the case where there are no pinned items? For example:

if remove_pinned and self.pinned_items:
    ...
else:
    lru_key = ...

Also, I'm wondering whether using .popitem(last=False) would be better?

Member (Author):
I think this optimization is somewhat unnecessary, as the first key is returned anyway if pinned_items is empty. So either way, we check the emptiness of pinned_items once.

Collaborator:
IIUC, in the branch where remove_pinned is False, we iterate over the entire cache regardless of whether pinned_items is empty?

Member (Author):
The next() call only consumes the first item from the generator expression; it doesn't consume the generator fully.

Collaborator:
Ah, you're right. I missed that it's a generator. Should be good then.
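To illustrate the last point, here is a minimal standalone sketch (toy keys and values, not code from this PR; ALL_PINNED_SENTINEL is used as a stand-in sentinel) showing that next() over a generator expression stops at the first unpinned key rather than walking the whole cache:

# Toy sketch (hypothetical data): next() advances the generator expression only
# until the first unpinned key is found, so the rest of the cache is never iterated.
from collections import OrderedDict

ALL_PINNED_SENTINEL = object()
cache = OrderedDict((key, None) for key in range(5))  # keys 0..4, oldest first
pinned_items = {0, 1}

lru_key = next(
    (key for key in cache if key not in pinned_items),
    ALL_PINNED_SENTINEL)
assert lru_key == 2  # keys 3 and 4 were never inspected

If every key were pinned, next() would fall through to the sentinel, which is what triggers the "cannot remove oldest from the cache" RuntimeError in remove_oldest.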

self.pop(lru_key)
self.pop(lru_key) # type: ignore

def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()

def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value

def clear(self):
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]


_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")


def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by
6 changes: 3 additions & 3 deletions vllm/v1/engine/mm_input_mapper.py
@@ -8,7 +8,7 @@
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache
from vllm.utils import LRUCache

logger = init_logger(__name__)

@@ -44,7 +44,7 @@ def __init__(

# Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:

def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

def process_inputs(
self,
25 changes: 0 additions & 25 deletions vllm/v1/utils.py
@@ -1,4 +1,3 @@
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
@@ -102,27 +101,3 @@ def make_zmq_socket(

finally:
ctx.destroy(linger=0)


K = TypeVar('K')
V = TypeVar('V')


class LRUDictCache(Generic[K, V]):

def __init__(self, size: int):
self.cache: OrderedDict[K, V] = OrderedDict()
self.size = size

def get(self, key: K, default=None) -> V:
if key not in self.cache:
return default

self.cache.move_to_end(key)
return self.cache[key]

def put(self, key: K, value: V):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
self.cache.popitem(last=False)