Skip to content
114 changes: 113 additions & 1 deletion vllm/distributed/kv_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from collections import Counter, deque
from collections.abc import Callable
from dataclasses import asdict
from itertools import count
Expand Down Expand Up @@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id: int | None
medium: str | None

def __hash__(self) -> int:
return hash(
(
tuple(self.block_hashes),
self.parent_block_hash,
tuple(self.token_ids),
self.block_size,
self.lora_id,
self.medium,
)
)


class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
medium: str | None

def __hash__(self) -> int:
return hash((tuple(self.block_hashes), self.medium))


class AllBlocksCleared(KVCacheEvent):
pass
Expand All @@ -68,6 +83,103 @@ class KVEventBatch(EventBatch):
events: list[BlockStored | BlockRemoved | AllBlocksCleared]


class KVEventAggregator:
"""
Aggregates KV events across multiple workers.
Tracks how many times each event appears and returns only those
that were emitted by all workers.
"""

__slots__ = ("_event_counter", "_num_workers")

def __init__(self, num_workers: int) -> None:
if num_workers <= 0:
raise ValueError("num_workers must be greater than zero.")
self._event_counter: Counter[KVCacheEvent] = Counter()
self._num_workers: int = num_workers

def add_events(self, events: list[KVCacheEvent]) -> None:
"""
Add events from a worker batch.

:param events: List of KVCacheEvent objects.
"""
if not isinstance(events, list):
raise TypeError("events must be a list of KVCacheEvent.")
self._event_counter.update(events)

def get_common_events(self) -> list[KVCacheEvent]:
"""
Return events that appeared in all workers.

:return: List of events present in all workers.
"""
return [
event
for event, count in self._event_counter.items()
if count == self._num_workers
]

def get_all_events(self) -> list[KVCacheEvent]:
"""
Return all events for all workers.

:return: List of events for all workers.
"""
return list(self._event_counter.elements())

def clear_events(self) -> None:
"""
Clear all tracked events.
"""
self._event_counter.clear()

def increment_workers(self, count: int = 1) -> None:
"""
Increment the number of workers contributing events.

:param count: Number of workers to add.
"""
if count <= 0:
raise ValueError("count must be positive.")
self._num_workers += count

def reset_workers(self) -> None:
"""
Reset the number of workers to 1.
"""
self._num_workers = 1

def __repr__(self) -> str:
return (
f"<KVEventAggregator workers={self._num_workers}, "
f"events={len(self._event_counter)}>"
)


class KVConnectorKVEvents(ABC):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""

@abstractmethod
def add_events(self, events: list[KVCacheEvent]) -> None:
raise NotImplementedError

@abstractmethod
def aggregate(self) -> "KVConnectorKVEvents":
raise NotImplementedError

@abstractmethod
def increment_workers(self, count: int = 1) -> None:
raise NotImplementedError

@abstractmethod
def get_all_events(self) -> list[KVCacheEvent]:
raise NotImplementedError


class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism
support.
Expand Down
21 changes: 21 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def update_finished_set(
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
combined_kv_cache_events = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
Expand Down Expand Up @@ -201,16 +202,36 @@ def update_finished_set(
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)

# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events = kv_output.kv_cache_events
elif kv_cache_events := kv_output.kv_cache_events:
assert isinstance(
combined_kv_cache_events,
type(kv_cache_events),
)
worker_kv_cache_events = kv_cache_events.get_all_events()
combined_kv_cache_events.add_events(worker_kv_cache_events)
combined_kv_cache_events.increment_workers()

invalid_block_ids |= kv_output.invalid_block_ids

# select output of the worker specified by output_rank
output = outputs[output_rank]

# Aggregate the events across workers.
# This operation needs to be done post worker processing so that we have all
# events for all workers.
if combined_kv_cache_events is not None:
combined_kv_cache_events = combined_kv_cache_events.aggregate()

assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)
Expand Down
10 changes: 9 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
Expand Down Expand Up @@ -379,6 +379,14 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return None

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import torch
Expand All @@ -8,13 +9,20 @@
)

from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
BlockStored,
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
Expand All @@ -26,6 +34,37 @@
logger = init_logger(__name__)


class LMCacheKVEvents(KVConnectorKVEvents):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""

def __init__(self, num_workers: int) -> None:
self._aggregator = KVEventAggregator(num_workers)

def add_events(self, events: list[KVCacheEvent]) -> None:
self._aggregator.add_events(events)

def aggregate(self) -> "LMCacheKVEvents":
"""
Aggregate KV events and retain only common events.
"""
common_events = self._aggregator.get_common_events()
self._aggregator.clear_events()
self._aggregator.add_events(common_events)
self._aggregator.reset_workers()
return self

def increment_workers(self, count: int = 1) -> None:
self._aggregator.increment_workers(count)

def get_all_events(self) -> list[KVCacheEvent]:
return self._aggregator.get_all_events()

def __repr__(self) -> str:
return f"<LMCacheKVEvents events={self.get_all_events()}>"


class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(
self,
Expand Down Expand Up @@ -54,6 +93,8 @@ def __init__(

self._lmcache_engine = cls(vllm_config, role, self)

self._kv_cache_events: list[KVCacheEvent] = []

# ==============================
# Worker-side methods
# ==============================
Expand Down Expand Up @@ -151,6 +192,31 @@ def get_block_ids_with_load_errors(self) -> set[int]:
# Fallback for older versions that don't support this method
return set()

def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
"""
Get the KV connector kv cache events collected during the last interval.
"""

events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
if not events:
return None

blocks: list[BlockStored] = [
BlockStored(
block_hashes=e.block_hashes,
parent_block_hash=e.parent_block_hash,
token_ids=e.token_ids,
lora_id=e.lora_id,
block_size=e.block_size,
medium=e.medium,
)
for e in events
]

lmcache_kv_events = LMCacheKVEvents(num_workers=1)
lmcache_kv_events.add_events(blocks)
return lmcache_kv_events

# ==============================
# Scheduler-side methods
# ==============================
Expand Down Expand Up @@ -198,6 +264,21 @@ def build_connector_meta(
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)

def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.

Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events = connector_output.kv_cache_events
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
return
self._kv_cache_events.extend(kv_cache_events.get_all_events())
return

def request_finished(
self,
request: "Request",
Expand All @@ -214,3 +295,14 @@ def request_finished(
returned by the engine.
"""
return self._lmcache_engine.request_finished(request, block_ids)

def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.

Yields:
New KV cache events since the last call.
"""
if self._kv_cache_events is not None:
yield from self._kv_cache_events
self._kv_cache_events.clear()
4 changes: 4 additions & 0 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.distributed.kv_events import KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
KVConnectorStats = object
KVConnectorKVEvents = object


class LogprobsLists(NamedTuple):
Expand Down Expand Up @@ -119,6 +121,7 @@ class KVConnectorOutput:
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None
kv_cache_events: KVConnectorKVEvents | None = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set)
Expand All @@ -134,6 +137,7 @@ def is_empty(self):
not self.finished_sending
and not self.finished_recving
and not self.kv_connector_stats
and not self.kv_cache_events
and not self.invalid_block_ids
)

Expand Down
Loading