From e2f0ed8a0a6155931e3f0bf79102024ba1493cd4 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 23 Sep 2025 20:55:21 -0700 Subject: [PATCH 01/12] Migrate the LMCache integration code to be vLLM native Signed-off-by: ApostaC --- .../kv_connector/v1/lmcache_connector.py | 17 +- .../v1/lmcache_integration/__init__.py | 2 + .../v1/lmcache_integration/utils.py | 211 +++ .../v1/lmcache_integration/vllm_v1_adapter.py | 1353 +++++++++++++++++ 4 files changed, 1581 insertions(+), 2 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 2b0abe983fbb..835c7f041df9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, Optional import torch -from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl +from lmcache.integration.vllm.vllm_v1_adapter import ( + LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl) from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -11,6 +12,9 @@ from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput +from .lmcache_integration.vllm_v1_adapter import ( + LMCacheConnectorV1Impl as LMCacheConnectorUpstreamImpl) # yapf: disable + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext @@ -24,7 +28,16 @@ class LMCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) - self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + use_native = vllm_config.kv_transfer_config.get_from_extra_config( + "use_native", True) + if use_native: + logger.info("Initializing native LMCache connector") + cls = LMCacheConnectorUpstreamImpl + else: + logger.info("Initializing latest dev LMCache connector") + cls = LMCacheConnectorLatestImpl + + self._lmcache_engine = cls(vllm_config, role, self) # ============================== # Worker-side methods diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py new file mode 100644 index 000000000000..995717c088a6 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Standard +import os +import threading +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from 
vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.request import Request + +# Third Party +import torch +# First Party +from lmcache.config import LMCacheEngineConfig as Config +from lmcache.logging import init_logger +from lmcache.v1.config import LMCacheEngineConfig as V1Config + +logger = init_logger(__name__) +ENGINE_NAME = "vllm-instance" + +# Thread-safe singleton storage +_config_instance: Union[Config, V1Config, None] = None +_config_lock = threading.Lock() + + +def is_false(value: str) -> bool: + """Check if the given string value is equivalent to 'false'.""" + return value.lower() in ("false", "0", "no", "n", "off") + + +def lmcache_get_or_create_config() -> Union[Config, V1Config]: + """Get the LMCache configuration from the environment variable + `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this + function will return the default configuration. + + This function is thread-safe and implements singleton pattern, + ensuring the configuration is loaded only once. + """ + global _config_instance + + # Double-checked locking for thread-safe singleton + if _config_instance is None: + with _config_lock: + if _config_instance is None: # Check again within lock + if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")): + logger.warning( + "Detected LMCACHE_USE_EXPERIMENTAL is set to False. " + "Using legacy configuration is deprecated and will " + "be remove soon! Please set LMCACHE_USE_EXPERIMENTAL " + "to True.") + LMCacheEngineConfig = Config # type: ignore[assignment] + else: + LMCacheEngineConfig = V1Config # type: ignore[assignment] + + if "LMCACHE_CONFIG_FILE" not in os.environ: + logger.warning( + "No LMCache configuration file is set. Trying to read" + " configurations from the environment variables.") + logger.warning( + "You can set the configuration file through " + "the environment variable: LMCACHE_CONFIG_FILE") + _config_instance = LMCacheEngineConfig.from_env() + else: + config_file = os.environ["LMCACHE_CONFIG_FILE"] + logger.info("Loading LMCache config file %s", config_file) + _config_instance = LMCacheEngineConfig.from_file( + config_file) + # Update config from environment variables + _config_instance.update_config_from_env() + return _config_instance + + +def hex_hash_to_int16(s: str) -> int: + """ + Convert a hex hash string to a 16-bit integer. + """ + return int(s, 16) & 0xFFFF + + +def apply_mm_hashes_to_token_ids( + token_ids: torch.Tensor, + mm_hashes: list[str], + mm_positions: list["PlaceholderRange"], +) -> torch.Tensor: + """ + Overwrite token_ids in-place for multimodal placeholders using + efficient slice assignments. + """ + n = token_ids.size(0) + for hash_str, placeholder in zip(mm_hashes, mm_positions, strict=False): + start, length = placeholder.offset, placeholder.length + if start >= n: + continue + end = min(start + length, n) + token_ids[start:end] = hex_hash_to_int16(hash_str) + return token_ids + + +def mla_enabled(model_config: "ModelConfig") -> bool: + return (hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla) + + +def create_lmcache_metadata(vllm_config=None, + model_config=None, + parallel_config=None, + cache_config=None): + """ + Create LMCacheEngineMetadata from vLLM configuration. + + This function extracts common metadata creation logic that was duplicated + across multiple files. 
+ + Args: + vllm_config: vLLM configuration object containing model, parallel, and + cache configs (alternative to individual config parameters) + model_config: Model configuration (alternative to vllm_config) + parallel_config: Parallel configuration (alternative to vllm_config) + cache_config: Cache configuration (alternative to vllm_config) + + Returns: + tuple: (LMCacheEngineMetadata, LMCacheEngineConfig) + """ + # Third Party + # First Party + from lmcache.config import LMCacheEngineMetadata + + from vllm.utils import get_kv_cache_torch_dtype + + config = lmcache_get_or_create_config() + # Support both vllm_config object and individual config parameters + if vllm_config is not None: + model_cfg = vllm_config.model_config + parallel_cfg = vllm_config.parallel_config + cache_cfg = vllm_config.cache_config + else: + model_cfg = model_config + parallel_cfg = parallel_config + cache_cfg = cache_config + + # Get KV cache dtype + kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype) + + # Check if MLA is enabled + use_mla = mla_enabled(model_cfg) + + # Construct KV shape (for memory pool) + num_layer = model_cfg.get_num_layers(parallel_cfg) + chunk_size = config.chunk_size + num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg) + head_size = model_cfg.get_head_size() + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, + head_size) + + # Create metadata + metadata = LMCacheEngineMetadata( + model_cfg.model, + parallel_cfg.world_size, + parallel_cfg.rank, + "vllm", + kv_dtype, + kv_shape, + use_mla, + ) + + return metadata, config + + +def extract_mm_features( + request: "Request", + modify: bool = False) -> tuple[list[str], list["PlaceholderRange"]]: + """ + Normalize multimodal information from a Request into parallel lists. + + This helper reads either: + 1) `request.mm_features` (objects each exposing `.identifier` and + `.mm_position`), or + 2) legacy fields `request.mm_hashes` and `request.mm_positions`. + + It returns two equally sized lists: the multimodal hash identifiers and + their corresponding positions. If the request contains no multimodal info, + it returns `([], [])`. + + Args: + request (Request): The source object. + modify (bool): + Controls copy semantics for the legacy-path return values. + - If True and legacy fields are used, shallow-copies are returned so + the caller can mutate the lists without affecting `request`. + - If False, the original legacy sequences are returned as-is + (zero-copy); treat them as read-only. + + Returns: + tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`). + May be `([], [])` when no multimodal data is present. 
+ """ + if getattr(request, "mm_features", None): + mm_hashes, mm_positions = zip(*((f.identifier, f.mm_position) + for f in request.mm_features), + strict=False) + return (list(mm_hashes), list(mm_positions)) + elif getattr(request, "mm_hashes", None): + if modify: + return (request.mm_hashes.copy(), request.mm_positions.copy()) + else: + return (request.mm_hashes, request.mm_positions) + else: + return ([], []) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py new file mode 100644 index 000000000000..57d5eae3ced9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -0,0 +1,1353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Standard +import os +import uuid +from collections.abc import Generator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch +from lmcache import utils +from lmcache.config import LMCacheEngineMetadata +from lmcache.logging import init_logger +from lmcache.observability import LMCStatsMonitor +from lmcache.utils import _lmcache_nvtx_annotate +from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder +from lmcache.v1.compute.blend import LMCBlenderBuilder +from lmcache.v1.config import (LMCacheEngineConfig, + _validate_and_set_config_value) +from lmcache.v1.gpu_connector import (VLLMBufferLayerwiseGPUConnector, + VLLMPagedMemGPUConnectorV2, + VLLMPagedMemLayerwiseGPUConnector) +from lmcache.v1.internal_api_server.api_server import InternalAPIServer +from lmcache.v1.lookup_client import LookupClientFactory +from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( + LMCacheAsyncLookupServer) +from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer +from lmcache.v1.plugin.plugin_launcher import PluginLauncher + +# Third Party +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + get_tp_group) +from vllm.sampling_params import SamplingParams +from vllm.utils import cdiv, get_kv_cache_torch_dtype +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.version import __version__ as VLLM_VERSION + +from .utils import (ENGINE_NAME, apply_mm_hashes_to_token_ids, + extract_mm_features, lmcache_get_or_create_config, + mla_enabled) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.kv_cache_manager import KVCacheManager + from vllm.v1.core.sched.output import NewRequestData + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in LMCache + lmcache_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + + +@dataclass +class DisaggSpec: + req_id: str + receiver_id: str + receiver_host: str + receiver_init_port: int + receiver_alloc_port: int + 
is_last_prefill: bool = False + num_transferred_tokens: int = 0 + + +tmp_disagg_tracker: dict[str, DisaggSpec] = {} + + +def extract_request_configs(sampling_params: SamplingParams) -> Optional[dict]: + request_configs = None + if sampling_params.extra_args is not None and \ + "kv_transfer_params" in sampling_params.extra_args: + kv_transfer_params = sampling_params.extra_args.get( + "kv_transfer_params") + for k, v in kv_transfer_params.items(): + if k.startswith("lmcache."): + if request_configs is None: + request_configs = {} + request_configs[k] = v + return request_configs + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # Total prompt token length + prompt_len: int + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been saved + num_saved_tokens: int = 0 + + # Disagg spec for the request + disagg_spec: Optional[DisaggSpec] = None + + # Multimodal hashes and positions + mm_hashes: Optional[list[str]] = None + mm_positions: Optional[list["PlaceholderRange"]] = None + + # The configs of the request, includes tags and other configs + request_configs: Optional[dict] = None + + # Whether the request is in decode phase + is_decode_phase = False + + # Whether the request cache should be saved + skip_save: bool = False + + @_lmcache_nvtx_annotate + @staticmethod + def from_new_request( + lmcache_config: LMCacheEngineConfig, + new_request: "NewRequestData", + num_tokens_to_compute: int, + lmcache_cached_tokens: int, + skip_save: bool, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + lmcache_config (LMCacheEngineConfig): the LMCache engine config. + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + lmcache_cached_tokens (int): the number of tokens that are + cached in LMCache. + request_priority (int): the priority of the request + skip_save (bool): whether the request cache should be saved + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + # According to the vLLM code + # (https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/ + # sched/scheduler.py#L943), + # only one KVCacheGroup is supported in connector for now. + + # TODO: Please support multiple KVCacheGroup in connector. + # NOTE: Also, `update` method in RequestTracker should be + # updated accordingly. + unfolded_block_ids = new_request.block_ids[0].copy() + + # NOTE: Initialized in `update_state_after_alloc` + disagg_spec = tmp_disagg_tracker.pop(new_request.req_id, None) + + request_configs = extract_request_configs(new_request.sampling_params) + + mm_hashes, mm_positions = extract_mm_features(new_request, modify=True) + + return RequestTracker( + req_id=new_request.req_id, + prompt_len=len(new_request.prompt_token_ids), + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. 
+ copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=lmcache_cached_tokens, + disagg_spec=disagg_spec, + mm_hashes=mm_hashes, + mm_positions=mm_positions, + skip_save=skip_save, + request_configs=request_configs, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[Optional[tuple[list[int], ...]], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if new_block_ids is None: + # https://github.com/vllm-project/vllm/commit/ + # b029de9902aa3ac58806c8c17776c7074175b6db + new_block_ids = [] + elif len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + # When a request is scheduled again, and the number of new tokens + # is 1 (excluding chunked prefill), the request is in decode phase. + # TODO: Need to further exclude the case of chunked prefill with 1 token + if len(new_token_ids) == 1: + self.is_decode_phase = True + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: list[int] # torch.Tensor + # Slot mapping + slot_mapping: torch.Tensor + + # Whether is last prefill or not + is_last_prefill: bool = False + + # Skip save or not + save_spec: Optional[SaveSpec] = None + # load_spec + load_spec: Optional[LoadSpec] = None + # disagg spec + disagg_spec: Optional[DisaggSpec] = None + # the configs of the request + request_configs: Optional[dict] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + lmcache_chunk_size: int = 256, + load_spec: Optional[LoadSpec] = None, + discard_partial_chunks: bool = True, + save_decode_cache: bool = False, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + lmcache_chunk_size (int): the chunk size for LMCache. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + discard_partial_chunks (bool): whether to discard partial chunks. + save_decode_cache (bool): whether to save the cache in decode phase. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + is_last_prefill = False + if input_token_len == tracker.prompt_len: + is_last_prefill = True + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. number of unsaved tokens is not reached the chunk boundary + # 3. if save_decode_cache is False and it is in decode phase + + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = ( + cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * + lmcache_chunk_size) + + # NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a + # trqansfer. 
Check if request_configs has lmcache.skip_save set to True
+        request_skip = (tracker.request_configs
+                        or {}).get("lmcache.skip_save", False)
+
+        skip_save = tracker.disagg_spec is None and (
+            tracker.skip_save or
+            (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary)
+            or (tracker.is_decode_phase and not save_decode_cache)
+            or request_skip)
+
+        if skip_save and load_spec is None:
+            return None
+
+        # Calculate number of tokens to save based on discard_partial_chunks
+        # setting
+
+        # NOTE(vladnosiv): for a chunked prefill that is not the last chunk,
+        # we are required to discard partial chunks,
+        # as new tokens will be added in the next iteration.
+        num_tokens_to_save = ((input_token_len // lmcache_chunk_size *
+                               lmcache_chunk_size) if not is_last_prefill
+                              or discard_partial_chunks else input_token_len)
+
+        # If we need to save, update the number of saved tokens
+        if not skip_save:
+            tracker.num_saved_tokens = num_tokens_to_save
+        save_spec = SaveSpec(skip_leading_tokens, not skip_save)
+
+        # Calculate the token ids and slot mappings for load and save
+        token_ids = input_token_ids[:num_tokens_to_save]
+
+        # If the request has multimodal hashes, apply them to the token ids
+        if tracker.mm_hashes:
+            # TODO: Optimize this
+            token_ids_tensor = torch.tensor(token_ids)
+            assert tracker.mm_positions is not None, (
+                "tracker got mm_hashes but no mm_positions")
+            apply_mm_hashes_to_token_ids(token_ids_tensor, tracker.mm_hashes,
+                                         tracker.mm_positions)
+            token_ids = token_ids_tensor.tolist()
+
+        num_blocks = len(tracker.allocated_block_ids)
+
+        if len(token_ids) > num_blocks * block_size:
+            logger.error(
+                "The number of tokens is more than the number of blocks. "
+                "Something might be wrong in scheduling logic!")
+            logger.error(
+                "Num tokens: %d, num blocks: %d, block size: %d",
+                len(token_ids),
+                num_blocks,
+                block_size,
+            )
+
+        block_ids = torch.tensor(tracker.allocated_block_ids, dtype=torch.long)
+        block_offsets = torch.arange(0, block_size, dtype=torch.long)
+        slot_mapping = (block_offsets.reshape(
+            (1, block_size)) + block_ids.reshape((num_blocks, 1)) * block_size)
+
+        slot_mapping = slot_mapping.flatten()[:len(token_ids)]
+        assert slot_mapping.dtype == torch.long  # TODO: this could be removed
+
+        # For load operation: check whether the request is scheduled to load
+        if load_spec is not None and load_spec.can_load:
+            logger.debug(
+                "Scheduled to load %d tokens for request %s",
+                load_spec.lmcache_cached_tokens,
+                tracker.req_id,
+            )
+        else:
+            # Do not load if not in `can_load` state
+            load_spec = None
+
+        return ReqMeta(
+            req_id=tracker.req_id,
+            token_ids=token_ids,
+            slot_mapping=slot_mapping,
+            is_last_prefill=is_last_prefill,
+            save_spec=save_spec,
+            load_spec=load_spec,
+            disagg_spec=tracker.disagg_spec,
+            request_configs=tracker.request_configs,
+        )
+
+
+def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig):
+    return lmcache_config.enable_pd
+
+
+def _calculate_mtp_layers(vllm_config, model_config):
+    num_mtp_layers = 0
+    if vllm_config is not None and vllm_config.speculative_config is not None:
+        logger.info("vllm_config.speculative_config: %s",
+                    vllm_config.speculative_config)
+        # TODO(baoloongmao): Support other MTP methods
+        if vllm_config.speculative_config.method == "deepseek_mtp":
+            num_mtp_layers = getattr(model_config.hf_config,
+                                     "num_nextn_predict_layers", 0)
+    return num_mtp_layers
+
+
+def _init_lmcache_engine(
+    lmcache_config: LMCacheEngineConfig,
+    vllm_config: "VllmConfig",
+) -> LMCacheEngine:
+    """Initialize the LMCache engine by the given model 
config and parallel + config. This function will check the environment variable + `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment + variable is not set, this function will return None. + + :param lmcache_config: The LMCache configuration. + :type lmcache_config: LMCacheEngineConfig + :param vllm_config: The vLLM configuration. + :type vllm_config: VllmConfig + + :return: The initialized LMCache engine + :rtype: LMCacheEngine + """ + if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME): + return curr_engine + + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + + assert isinstance( + lmcache_config, + LMCacheEngineConfig), ("LMCache v1 configuration is should be passed.") + + kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, + model_config.dtype) + + use_mla = mla_enabled(model_config) + if use_mla and (lmcache_config.remote_serde != "naive" + and lmcache_config.remote_serde is not None): + raise ValueError("MLA only works with naive serde mode..") + + # construct kv shape (for mem pool) + num_layer = model_config.get_num_layers(parallel_config) + num_mtp_layers = _calculate_mtp_layers(vllm_config, model_config) + num_layer += num_mtp_layers + chunk_size = lmcache_config.chunk_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, + head_size) + logger.info("use mla: %s, kv shape: %s, num_mtp_layers: %s", use_mla, + kv_shape, num_mtp_layers) + + # Change current device. + num_gpus = torch.cuda.device_count() + local_rank = parallel_config.rank % num_gpus + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + metadata = LMCacheEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + "vllm", + kv_dtype, + kv_shape, + use_mla, + ) + + use_gpu = need_gpu_interm_buffer(lmcache_config) + vllm_gpu_connector: Union[ + VLLMBufferLayerwiseGPUConnector, + VLLMPagedMemGPUConnectorV2, + VLLMPagedMemLayerwiseGPUConnector, + ] + + if use_mla and lmcache_config.use_layerwise: + raise ValueError("layerwise MLA connector is not supported yet") + + # When use_mla is True, num_kv_head is 1 + hidden_dim_size = num_kv_head * head_size + if lmcache_config.use_layerwise: + if lmcache_config.enable_blending: + # Use layerwise connector for blending + vllm_gpu_connector = VLLMBufferLayerwiseGPUConnector( + hidden_dim_size, + num_layer, + use_gpu=use_gpu, + chunk_size=chunk_size, + dtype=kv_dtype, + device=device, + ) + else: + vllm_gpu_connector = VLLMPagedMemLayerwiseGPUConnector( + hidden_dim_size, + num_layer, + use_gpu=use_gpu, + chunk_size=chunk_size, + dtype=kv_dtype, + device=device, + ) + else: + vllm_gpu_connector = VLLMPagedMemGPUConnectorV2( + hidden_dim_size, + num_layer, + use_gpu=use_gpu, + chunk_size=chunk_size, + dtype=kv_dtype, + device=device, + use_mla=use_mla, + ) + tpg = get_tp_group() + engine = LMCacheEngineBuilder.get_or_create( + ENGINE_NAME, + lmcache_config, + metadata, + vllm_gpu_connector, + tpg.broadcast, + tpg.broadcast_object, + ) + + return engine + + +@dataclass +class LMCacheConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] = field(default_factory=list) + lookup_requests_in_step: list[str] = field(default_factory=list) + + @_lmcache_nvtx_annotate + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. 
+ + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +class LMCacheConnectorV1Impl: + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + parent: KVConnectorBase_V1, + ): + self._parent = parent + self._vllm_config = vllm_config + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.worker_count = vllm_config.parallel_config.tensor_parallel_size + config = lmcache_get_or_create_config() + assert isinstance(config, LMCacheEngineConfig), ( + "LMCache v1 configuration is should be passed for vLLM v1.") + # Put the leading with "lmcache." and matched configs from + # vllm extra_config to the config + kv_connector_extra_config = ( + vllm_config.kv_transfer_config.kv_connector_extra_config) + if kv_connector_extra_config: + for key, value in kv_connector_extra_config.items(): + if key.startswith("lmcache."): + config_key = key[8:] # Remove "lmcache." prefix + if _validate_and_set_config_value(config, config_key, + value): + logger.info( + "Updated config %s from vLLM " + "extra config: %s", config_key, value) + + self.config = config + + self.async_loading = config.enable_async_loading + self.layerwise_retrievers: list[Generator[Optional[torch.Tensor], None, + None]] = [] + self._stats_monitor = LMCStatsMonitor.GetOrCreate() + if role == KVConnectorRole.SCHEDULER: + # Create lookup client using factory + self.lookup_client = LookupClientFactory.create_lookup_client( + vllm_config, config) + self._unfinished_requests: dict[str, Request] = {} + self._lookup_requests_in_step: list[str] = [] + self.lmcache_engine = None + else: + self.lmcache_engine = _init_lmcache_engine( + config, + vllm_config, + ) + + self.use_layerwise = config.use_layerwise + self.enable_blending = config.enable_blending + + if self.enable_blending: + self.blender = LMCBlenderBuilder.get_or_create( + ENGINE_NAME, + self.lmcache_engine, + self.lmcache_engine.gpu_connector, + config, + ) + + # Create lookup server using factory + assert self.lmcache_engine is not None + self.lookup_server = LookupClientFactory.create_lookup_server( + self.lmcache_engine, vllm_config) + + self.offload_server = ZMQOffloadServer( + self.lmcache_engine, + vllm_config, + get_tensor_model_parallel_rank(), + ) + + # In case of MLA, the lookup server is only created on worker 0 + if self.async_loading and self.lookup_server is not None: + assert isinstance(self.lookup_server, LMCacheAsyncLookupServer) + self.lmcache_engine.post_init( + async_lookup_server=self.lookup_server) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + # request_id -> (vllm cached tokens, lmcache cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + + self.kv_cache_manager: Optional[KVCacheManager] = None + + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", False) + or not config.save_unfull_chunk) + + self._lmcache_chunk_size = config.chunk_size + self._save_decode_cache = config.save_decode_cache + + self.skip_last_n_tokens = \ + vllm_config.kv_transfer_config.get_from_extra_config( + "skip_last_n_tokens", 0) + + self.num_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + self.current_layer = 0 + + self.force_skip_save = bool( + os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False)) + + self._requests_priority: 
dict[str, int] = {} + + # TODO(baoloongmao): Internal api server & plugin framework support + # dp > 1 + if (vllm_config.parallel_config.data_parallel_size_local == 1 + or vllm_config.parallel_config.data_parallel_rank_local == 0): + # Start internal API server if enabled + # The enabled check is in the InternalAPIServer constructor + self.api_server = InternalAPIServer(self) + self.api_server.start() + # Launch plugins + self.plugin_launcher = PluginLauncher( + self.config, + role, + self.worker_count, + -1 if self.lmcache_engine is None # scheduler side + else self.lmcache_engine.metadata.worker_id, + ) + self.plugin_launcher.launch_plugins() + else: + self.api_server = None # type: ignore[assignment] + self.plugin_launcher = None # type: ignore[assignment] + logger.info( + "LMCache initialized for role %s with version %s, " + "vllm version %s, lmcache cache_engine metadata: %s", role, + utils.get_version(), VLLM_VERSION, + getattr(self.lmcache_engine, 'metadata', None)) + + def get_inference_info(self) -> dict: + """Get inference information including vLLM config and related details. + + Returns: + dict: Dictionary containing inference information + """ + # Get vLLM config information + vllm_config = self._vllm_config + + # Use vLLM config's string representation and add specific configs + inference_info = { + "vllm_version": VLLM_VERSION, + "lmcache_version": utils.get_version(), + "vllm_config": str(vllm_config), + "model_config": { + "model": + getattr(vllm_config.model_config, "model", None), + "dtype": + str(getattr(vllm_config.model_config, "dtype", None)), + "max_model_len": + getattr(vllm_config.model_config, "max_model_len", None), + "vocab_size": + getattr(vllm_config.model_config, "vocab_size", None), + "num_layers": + getattr(vllm_config.model_config, "get_num_layers", + lambda _: None)(vllm_config.parallel_config), + "num_attention_heads": + getattr(vllm_config.model_config, "get_num_attention_heads", + lambda _: None)(vllm_config.parallel_config), + "num_kv_heads": + getattr(vllm_config.model_config, "get_num_kv_heads", + lambda _: None)(vllm_config.parallel_config), + "head_size": + getattr(vllm_config.model_config, "get_head_size", + lambda: None)(), + }, + "cache_config": { + "block_size": + getattr(vllm_config.cache_config, "block_size", None), + "cache_dtype": + str(getattr(vllm_config.cache_config, "cache_dtype", None)), + "gpu_memory_utilization": + getattr(vllm_config.cache_config, "gpu_memory_utilization", + None), + "swap_space": + getattr(vllm_config.cache_config, "swap_space", None), + "enable_prefix_caching": + getattr(vllm_config.cache_config, "enable_prefix_caching", + None), + }, + } + + return inference_info + + def get_inference_version(self) -> str: + """Get vLLM version information. 
+ + Returns: + str: vLLM version string + """ + return VLLM_VERSION + + @_lmcache_nvtx_annotate + def _init_kv_caches_from_forward_context( + self, forward_context: "ForwardContext"): + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + if not hasattr(attn_layer, "kv_cache"): + logger.debug("The layer %s does not have kv_cache, skip it", + layer_name) + continue + + if layer_name not in self.kv_caches: + self.kv_caches[layer_name] = attn_layer.kv_cache[ + forward_context.virtual_engine] + + #################### + # Worker side APIs + #################### + + @_lmcache_nvtx_annotate + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + self.current_layer = 0 + + if len(self.kv_caches) == 0: + self._init_kv_caches_from_forward_context(forward_context) + + metadata = self._parent._get_connector_metadata() + assert isinstance(metadata, LMCacheConnectorMetadata) + + assert len(self.kv_caches) > 0 + kvcaches = list(self.kv_caches.values()) + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.debug( + "In connector.start_load_kv, but the attn_metadata is None") + return + + assert self.lmcache_engine is not None + + self.lmcache_engine.post_init(kvcaches=kvcaches) + + self.layerwise_retrievers = [] + + for idx, request in enumerate(metadata.requests): + if request.load_spec is None: + continue + last_idx = idx + + for idx, request in enumerate(metadata.requests): + if request.load_spec is None: + continue + + tokens = request.token_ids + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = request.slot_mapping.cuda() + assert len(tokens) == len(slot_mapping) + + self._stats_monitor.update_interval_vllm_hit_tokens( + request.load_spec.vllm_cached_tokens) + token_mask = torch.ones(len(tokens), dtype=torch.bool) + masked_token_count = (request.load_spec.vllm_cached_tokens // + self._lmcache_chunk_size * + self._lmcache_chunk_size) + token_mask[:masked_token_count] = False + + lmcache_cached_tokens = request.load_spec.lmcache_cached_tokens + if self.use_layerwise: + sync = idx == last_idx + # NOTE(Jiayi): Perform blending before layerwise prefix caching + if self.enable_blending: + # TODO(Jiayi): Need to make prefix caching and blending + # compatible + self.blender.blend( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + ) + else: + layerwise_retriever = self.lmcache_engine.retrieve_layer( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + sync=sync, + ) + # NOTE: retrieve for two layers at the first layer + next(layerwise_retriever) + next(layerwise_retriever) + self.layerwise_retrievers.append(layerwise_retriever) + else: + ret_token_mask = self.lmcache_engine.retrieve( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + request_configs=request.request_configs, + req_id=request.req_id, + ) + + # Check the result + num_retrieved_tokens = ret_token_mask.sum().item() 
+ num_expected_tokens = (lmcache_cached_tokens - + request.load_spec.vllm_cached_tokens) + if num_retrieved_tokens < num_expected_tokens: + logger.error( + "The number of retrieved tokens is less than the " + "expected number of tokens! This should not happen!") + logger.error( + "Num retrieved tokens: %d, num expected tokens: %d", + num_retrieved_tokens, + num_expected_tokens, + ) + + @_lmcache_nvtx_annotate + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + if self.layerwise_retrievers: + logger.debug("Waiting for layer %s to be loaded", + self.current_layer) + + # Wait for the layer to be loaded + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info("Retrieved %s tokens", num_retrieved_tokens) + + return + + @_lmcache_nvtx_annotate + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + assert self.lmcache_engine is not None + + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + if self._parent._connector_metadata is None: + logger.warning( + "In connector.save_kv_layer, but the connector metadata is None" + ) + return + connector_metadata = self._parent._get_connector_metadata() + assert isinstance(connector_metadata, LMCacheConnectorMetadata) + + assert len(self.kv_caches) > 0 + + kvcaches = list(self.kv_caches.values()) + if self.current_layer == 0: + self.layerwise_storers = [] + + is_first = True + + for idx, request in enumerate(connector_metadata.requests): + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + assert isinstance(token_ids, list) + + slot_mapping = request.slot_mapping + assert isinstance(slot_mapping, torch.Tensor) + assert len(slot_mapping) == len(token_ids) + + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = slot_mapping.cuda() + + if self.kv_role == "kv_producer": + skip_leading_tokens = 0 + else: + skip_leading_tokens = save_spec.skip_leading_tokens + + if skip_leading_tokens == len(token_ids): + continue # skip this request + # Align to lmcache chunk size + skip_leading_tokens = (skip_leading_tokens // + self._lmcache_chunk_size * + self._lmcache_chunk_size) + + store_mask = torch.ones(len(token_ids), dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + # TODO (Jiayi): need to make layerwise storing + # compatible with disagg spec + layerwise_storer = self.lmcache_engine.store_layer( + token_ids, + mask=store_mask, + kvcaches=kvcaches, + 
slot_mapping=slot_mapping, + offset=skip_leading_tokens, + sync=is_first, + ) + self.layerwise_storers.append(layerwise_storer) + if is_first: + is_first = False + + for layerwise_storer in self.layerwise_storers: + next(layerwise_storer) + + self.current_layer += 1 + + @_lmcache_nvtx_annotate + def wait_for_save(self): + """Blocking until the KV cache is saved to the connector buffer.""" + + connector_metadata = self._parent._get_connector_metadata() + assert isinstance(connector_metadata, LMCacheConnectorMetadata) + + self.lmcache_engine.lookup_unpin( # type: ignore + connector_metadata.lookup_requests_in_step) + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + for layerwise_storer in self.layerwise_storers: + next(layerwise_storer) + return + + assert len(self.kv_caches) > 0 + kvcaches = list(self.kv_caches.values()) + + assert self.lmcache_engine is not None + + for request in connector_metadata.requests: + save_spec = request.save_spec + if (save_spec is None or + not save_spec.can_save) and self.kv_role != "kv_producer": + continue + + token_ids = request.token_ids + + slot_mapping = request.slot_mapping + assert isinstance(slot_mapping, torch.Tensor) + assert len(slot_mapping) == len(token_ids) + assert save_spec is not None + + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = slot_mapping.cuda() + + skip_leading_tokens = save_spec.skip_leading_tokens + if self.kv_role == "kv_producer": + assert request.disagg_spec is not None + skip_leading_tokens = min( + skip_leading_tokens, + request.disagg_spec.num_transferred_tokens) + + if skip_leading_tokens == len(token_ids): + continue # skip this request + # Align to lmcache chunk size + skip_leading_tokens = (skip_leading_tokens // + self._lmcache_chunk_size * + self._lmcache_chunk_size) + + store_mask = torch.ones(len(token_ids), dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + is_last_prefill = request.is_last_prefill + if is_last_prefill: + if request.disagg_spec: + request.disagg_spec.is_last_prefill = True + else: + token_len = len(token_ids) + aligned_token_len = (token_len // self._lmcache_chunk_size * + self._lmcache_chunk_size) + token_ids = token_ids[:aligned_token_len] + store_mask = store_mask[:aligned_token_len] + slot_mapping = slot_mapping[:aligned_token_len] + + self.lmcache_engine.store( + token_ids, + mask=store_mask, + kvcaches=kvcaches, + slot_mapping=slot_mapping, + offset=skip_leading_tokens, + transfer_spec=request.disagg_spec, + request_configs=request.request_configs, + ) + + # NOTE(Jiayi): We assume all tokens are saved + save_spec.skip_leading_tokens = len(token_ids) + if request.disagg_spec: + request.disagg_spec.num_transferred_tokens = len(token_ids) + + @_lmcache_nvtx_annotate + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + return None, None + + ################### + # Scheduler side APIs + #################### + + @_lmcache_nvtx_annotate + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> Optional[int]: + """ + Check for external KV cache hit. + + Args: + request (Request): the request object. 
+            num_computed_tokens (int): the number of locally
+                computed tokens for this request
+
+        Returns:
+            the number of tokens that can be loaded from the
+            external KV cache beyond what is already computed.
+        """
+        if self.kv_role == "kv_producer" and not hasattr(
+                self.lookup_client, "supports_producer_reuse"):
+            return 0
+
+        self._requests_priority[request.request_id] = request.priority
+
+        token_ids = request.prompt_token_ids
+
+        # If the request has multimodal hashes, apply them to the token ids
+        mm_hashes, mm_positions = extract_mm_features(request)
+        if mm_hashes and mm_positions:
+            # TODO(Jiayi): Optimize this
+            token_ids = torch.tensor(request.prompt_token_ids)
+            apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions)
+            token_ids = token_ids.tolist()
+
+        request_configs = extract_request_configs(request.sampling_params)
+        if self.skip_last_n_tokens > 0:
+            token_ids = token_ids[:-self.skip_last_n_tokens]
+        if self.async_loading:
+            lookup_id = request.request_id
+        else:
+            lookup_id = str(uuid.uuid4())
+
+        self._lookup_requests_in_step.append(lookup_id)
+
+        num_external_hit_tokens = self.lookup_client.lookup(
+            token_ids,
+            lookup_id=lookup_id,
+            request_configs=request_configs,
+        )
+
+        if num_external_hit_tokens is None:
+            logger.info(
+                "Reqid: %s, Total tokens %d, LMCache hit tokens: None.",
+                request.request_id,
+                request.num_tokens,
+            )
+            return None
+
+        # When prompt length is divisible by the block size and all
+        # blocks are cached, we need to recompute the last token.
+        # This will be removed in the future if vLLM's scheduler provides
+        # better support for this case.
+        need_to_allocate = num_external_hit_tokens - num_computed_tokens
+
+        # In the full-prompt-hit case, we need to recompute the last token
+        if num_external_hit_tokens == request.num_tokens:
+            need_to_allocate -= 1
+
+        logger.info(
+            "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, "
+            "need to load: %d",
+            request.request_id,
+            request.num_tokens,
+            num_external_hit_tokens,
+            need_to_allocate,
+        )
+
+        self.load_specs[request.request_id] = LoadSpec(
+            vllm_cached_tokens=num_computed_tokens,
+            lmcache_cached_tokens=num_external_hit_tokens,
+            can_load=False,
+        )
+
+        if need_to_allocate <= 0:
+            return 0
+
+        # TODO: Align to vLLM block size. Should test whether it can be removed
+        # need_to_allocate = need_to_allocate // self._block_size * \
+        #     self._block_size
+
+        return need_to_allocate
+
+    @_lmcache_nvtx_annotate
+    def update_state_after_alloc(self, request: "Request",
+                                 num_external_tokens: int):
+        """
+        Update KVConnector state after temporary buffer alloc.
+
+        For the LMCache connector, mark the request's load spec as loadable
+        if the CacheManager has allocated blocks for us. 
+ """ + + kv_transfer_params = (request.kv_transfer_params if hasattr( + request, "kv_transfer_params") else None) + + if kv_transfer_params is not None and \ + "disagg_spec" in kv_transfer_params: + req_disagg_spec = kv_transfer_params["disagg_spec"] + + receiver_id = req_disagg_spec["receiver_host"] + str( + req_disagg_spec["receiver_init_port"]) + + disagg_spec = DisaggSpec( + req_id=req_disagg_spec["req_id"], + receiver_id=receiver_id, + receiver_host=req_disagg_spec["receiver_host"], + receiver_init_port=req_disagg_spec["receiver_init_port"], + receiver_alloc_port=req_disagg_spec["receiver_alloc_port"], + ) + + tmp_disagg_tracker[request.request_id] = disagg_spec + self._unfinished_requests[request.request_id] = request + + if request.request_id not in self.load_specs: + # No KV tokens from external KV cache, return + return + + if num_external_tokens == 0: + # No need to load anything + self.load_specs[request.request_id].can_load = False + return + + # Only check for non-prompt-hit case + if (self.load_specs[request.request_id].lmcache_cached_tokens + != request.num_tokens): + assert ( + num_external_tokens > 0 and num_external_tokens + == self.load_specs[request.request_id].lmcache_cached_tokens - + self.load_specs[request.request_id].vllm_cached_tokens + ), (f"Mismatch in number of tokens: {num_external_tokens} vs " + f"{self.load_specs[request.request_id].lmcache_cached_tokens} -" + f" {self.load_specs[request.request_id].vllm_cached_tokens}" + f" for request {request.request_id}") + + self.load_specs[request.request_id].can_load = True + + @_lmcache_nvtx_annotate + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """Attach the connector metadata to the request object. + + This function should NOT modify other fields in the scheduler_output + except the `kv_connector_metadata` field. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + + force_skip_save = self.kv_role == "kv_consumer" or self.force_skip_save + + meta = LMCacheConnectorMetadata() + + # set and update lookup requests for unpin + meta.lookup_requests_in_step = self._lookup_requests_in_step + self._lookup_requests_in_step = [] + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id]) + lmcache_cached_tokens = 0 + if load_spec is not None: + lmcache_cached_tokens = load_spec.lmcache_cached_tokens + request_priority = self._requests_priority.pop(request.req_id, 0) + + skip_save = force_skip_save or ( + self.config.priority_limit is not None + and request_priority > self.config.priority_limit) + + request_tracker = RequestTracker.from_new_request( + self.config, + request, + num_tokens_to_compute, + lmcache_cached_tokens, + skip_save, + ) + self._request_trackers[request.req_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=load_spec, + discard_partial_chunks=self._discard_partial_chunks, + save_decode_cache=self._save_decode_cache, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + + # NOTE: For backward compatibility with vllm version < 0.9.2, + # In the latest vllm version, the type of scheduled_cached_reqs has + # changed from list to object `CachedRequestData` + if isinstance(cached_reqs, list): + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + if request := self._unfinished_requests.get(req_id): + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = request.all_token_ids[ + num_current_tokens:num_current_tokens + num_new_tokens] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached") + new_block_ids = cached_reqs.new_block_ids[i] + + request_tracker.update(new_token_ids, new_block_ids) + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=None, + discard_partial_chunks=self._discard_partial_chunks, + save_decode_cache=self._save_decode_cache, + ) + if req_meta is not None: + meta.add_request(req_meta) + + return meta + + @_lmcache_nvtx_annotate + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + params = (request.kv_transfer_params if hasattr( + request, "kv_transfer_params") else None) + return_params = None + + # NOTE: Used to stream back the first token + # for disagg prefill + if params is not None and "ret_first_tok" in params: + return_params = { + 
"first_tok": request._output_token_ids[0], + } + + return False, return_params From 63f70677216aa71074be0c917711c3c9f47f7e97 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 23 Sep 2025 21:24:39 -0700 Subject: [PATCH 02/12] [fix] zip problem Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_integration/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 995717c088a6..f6adbfc0456a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -89,7 +89,7 @@ def apply_mm_hashes_to_token_ids( efficient slice assignments. """ n = token_ids.size(0) - for hash_str, placeholder in zip(mm_hashes, mm_positions, strict=False): + for hash_str, placeholder in zip(mm_hashes, mm_positions): start, length = placeholder.offset, placeholder.length if start >= n: continue From 49c1baadeb3717e815a7c9b9334bbf3a38d19b96 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 23 Sep 2025 21:45:22 -0700 Subject: [PATCH 03/12] fix precommit issues Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_integration/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index f6adbfc0456a..5518daa0f7f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -199,8 +199,7 @@ def extract_mm_features( """ if getattr(request, "mm_features", None): mm_hashes, mm_positions = zip(*((f.identifier, f.mm_position) - for f in request.mm_features), - strict=False) + for f in request.mm_features)) return (list(mm_hashes), list(mm_positions)) elif getattr(request, "mm_hashes", None): if modify: From f37fff6ee59fab74043c74b60304d0842aa55d7c Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 25 Sep 2025 17:33:11 -0700 Subject: [PATCH 04/12] default non native Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 835c7f041df9..895aa7215039 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -29,7 +29,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) use_native = vllm_config.kv_transfer_config.get_from_extra_config( - "use_native", True) + "use_native", False) if use_native: logger.info("Initializing native LMCache connector") cls = LMCacheConnectorUpstreamImpl From ecb0f5b7d3962151880b0284cdbdc8116d84249f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 9 Oct 2025 13:12:50 -0700 Subject: [PATCH 05/12] [add] pre-commit fixes Signed-off-by: ApostaC --- .../kv_connector/v1/lmcache_connector.py | 9 +- .../v1/lmcache_integration/utils.py | 44 +- .../v1/lmcache_integration/vllm_v1_adapter.py | 393 ++++++++++-------- 3 files changed, 251 insertions(+), 195 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 51f12a9a657d..74bec84b27df 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -4,7 +4,8 @@ import torch from lmcache.integration.vllm.vllm_v1_adapter import ( - LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl) + LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl, +) from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -16,7 +17,8 @@ from vllm.v1.core.sched.output import SchedulerOutput from .lmcache_integration.vllm_v1_adapter import ( - LMCacheConnectorV1Impl as LMCacheConnectorUpstreamImpl) # yapf: disable + LMCacheConnectorV1Impl as LMCacheConnectorUpstreamImpl, # yapf: disable +) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -31,7 +33,8 @@ class LMCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) use_native = vllm_config.kv_transfer_config.get_from_extra_config( - "use_native", False) + "use_native", False + ) if use_native: logger.info("Initializing native LMCache connector") cls = LMCacheConnectorUpstreamImpl diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 5518daa0f7f4..35f8148a2c5d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -12,6 +12,7 @@ # Third Party import torch + # First Party from lmcache.config import LMCacheEngineConfig as Config from lmcache.logging import init_logger @@ -49,7 +50,8 @@ def lmcache_get_or_create_config() -> Union[Config, V1Config]: "Detected LMCACHE_USE_EXPERIMENTAL is set to False. " "Using legacy configuration is deprecated and will " "be remove soon! Please set LMCACHE_USE_EXPERIMENTAL " - "to True.") + "to True." + ) LMCacheEngineConfig = Config # type: ignore[assignment] else: LMCacheEngineConfig = V1Config # type: ignore[assignment] @@ -57,16 +59,17 @@ def lmcache_get_or_create_config() -> Union[Config, V1Config]: if "LMCACHE_CONFIG_FILE" not in os.environ: logger.warning( "No LMCache configuration file is set. Trying to read" - " configurations from the environment variables.") + " configurations from the environment variables." 
+ ) logger.warning( "You can set the configuration file through " - "the environment variable: LMCACHE_CONFIG_FILE") + "the environment variable: LMCACHE_CONFIG_FILE" + ) _config_instance = LMCacheEngineConfig.from_env() else: config_file = os.environ["LMCACHE_CONFIG_FILE"] logger.info("Loading LMCache config file %s", config_file) - _config_instance = LMCacheEngineConfig.from_file( - config_file) + _config_instance = LMCacheEngineConfig.from_file(config_file) # Update config from environment variables _config_instance.update_config_from_env() return _config_instance @@ -99,15 +102,16 @@ def apply_mm_hashes_to_token_ids( def mla_enabled(model_config: "ModelConfig") -> bool: - return (hasattr(model_config, "use_mla") - and isinstance(model_config.use_mla, bool) - and model_config.use_mla) + return ( + hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla + ) -def create_lmcache_metadata(vllm_config=None, - model_config=None, - parallel_config=None, - cache_config=None): +def create_lmcache_metadata( + vllm_config=None, model_config=None, parallel_config=None, cache_config=None +): """ Create LMCacheEngineMetadata from vLLM configuration. @@ -152,8 +156,7 @@ def create_lmcache_metadata(vllm_config=None, chunk_size = config.chunk_size num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg) head_size = model_cfg.get_head_size() - kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, - head_size) + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size) # Create metadata metadata = LMCacheEngineMetadata( @@ -170,8 +173,8 @@ def create_lmcache_metadata(vllm_config=None, def extract_mm_features( - request: "Request", - modify: bool = False) -> tuple[list[str], list["PlaceholderRange"]]: + request: "Request", modify: bool = False +) -> tuple[list[str], list["PlaceholderRange"]]: """ Normalize multimodal information from a Request into parallel lists. @@ -180,8 +183,8 @@ def extract_mm_features( `.mm_position`), or 2) legacy fields `request.mm_hashes` and `request.mm_positions`. - It returns two equally sized lists: the multimodal hash identifiers and - their corresponding positions. If the request contains no multimodal info, + It returns two equally sized lists: the multimodal hash identifiers and + their corresponding positions. If the request contains no multimodal info, it returns `([], [])`. Args: @@ -198,8 +201,9 @@ def extract_mm_features( May be `([], [])` when no multimodal data is present. 
""" if getattr(request, "mm_features", None): - mm_hashes, mm_positions = zip(*((f.identifier, f.mm_position) - for f in request.mm_features)) + mm_hashes, mm_positions = zip( + *((f.identifier, f.mm_position) for f in request.mm_features) + ) return (list(mm_hashes), list(mm_positions)) elif getattr(request, "mm_hashes", None): if modify: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 57d5eae3ced9..e9444c82cfaf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -15,32 +15,40 @@ from lmcache.utils import _lmcache_nvtx_annotate from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder from lmcache.v1.compute.blend import LMCBlenderBuilder -from lmcache.v1.config import (LMCacheEngineConfig, - _validate_and_set_config_value) -from lmcache.v1.gpu_connector import (VLLMBufferLayerwiseGPUConnector, - VLLMPagedMemGPUConnectorV2, - VLLMPagedMemLayerwiseGPUConnector) +from lmcache.v1.config import LMCacheEngineConfig, _validate_and_set_config_value +from lmcache.v1.gpu_connector import ( + VLLMBufferLayerwiseGPUConnector, + VLLMPagedMemGPUConnectorV2, + VLLMPagedMemLayerwiseGPUConnector, +) from lmcache.v1.internal_api_server.api_server import InternalAPIServer from lmcache.v1.lookup_client import LookupClientFactory from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( - LMCacheAsyncLookupServer) + LMCacheAsyncLookupServer, +) from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer from lmcache.v1.plugin.plugin_launcher import PluginLauncher # Third Party from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, - get_tp_group) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group from vllm.sampling_params import SamplingParams from vllm.utils import cdiv, get_kv_cache_torch_dtype from vllm.v1.core.sched.output import SchedulerOutput from vllm.version import __version__ as VLLM_VERSION -from .utils import (ENGINE_NAME, apply_mm_hashes_to_token_ids, - extract_mm_features, lmcache_get_or_create_config, - mla_enabled) +from .utils import ( + ENGINE_NAME, + apply_mm_hashes_to_token_ids, + extract_mm_features, + lmcache_get_or_create_config, + mla_enabled, +) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -87,10 +95,11 @@ class DisaggSpec: def extract_request_configs(sampling_params: SamplingParams) -> Optional[dict]: request_configs = None - if sampling_params.extra_args is not None and \ - "kv_transfer_params" in sampling_params.extra_args: - kv_transfer_params = sampling_params.extra_args.get( - "kv_transfer_params") + if ( + sampling_params.extra_args is not None + and "kv_transfer_params" in sampling_params.extra_args + ): + kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params") for k, v in kv_transfer_params.items(): if k.startswith("lmcache."): if request_configs is None: @@ -186,8 +195,7 @@ def from_new_request( return RequestTracker( req_id=new_request.req_id, prompt_len=len(new_request.prompt_token_ids), - 
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. - copy(), + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(), allocated_block_ids=unfolded_block_ids, num_saved_tokens=lmcache_cached_tokens, disagg_spec=disagg_spec, @@ -219,8 +227,7 @@ def update( elif isinstance(new_block_ids, list): pass else: - raise ValueError( - f"Unsupported new_block_ids type {type(new_block_ids)}") + raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") self.allocated_block_ids.extend(new_block_ids) # When a request is scheduled again, and the number of new tokens @@ -288,19 +295,19 @@ def from_request_tracker( skip_leading_tokens = tracker.num_saved_tokens chunk_boundary = ( - cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * - lmcache_chunk_size) + cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * lmcache_chunk_size + ) # NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a # trqansfer. Check if request_configs has lmcache.skip_save set to True - request_skip = (tracker.request_configs - or {}).get("lmcache.skip_save", False) + request_skip = (tracker.request_configs or {}).get("lmcache.skip_save", False) skip_save = tracker.disagg_spec is None and ( - tracker.skip_save or - (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary) + tracker.skip_save + or (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary) or (tracker.is_decode_phase and not save_decode_cache) - or request_skip) + or request_skip + ) if skip_save and load_spec is None: return None @@ -311,9 +318,11 @@ def from_request_tracker( # NOTE(vladnosiv): for the input_token_len chunk prefill, # we are required to discard partial chunks, # as new tokens will be added in the next iteration. - num_tokens_to_save = ((input_token_len // lmcache_chunk_size * - lmcache_chunk_size) if not is_last_prefill - or discard_partial_chunks else input_token_len) + num_tokens_to_save = ( + (input_token_len // lmcache_chunk_size * lmcache_chunk_size) + if not is_last_prefill or discard_partial_chunks + else input_token_len + ) # If we need to save, update the number of saved tokens if not skip_save: @@ -328,9 +337,11 @@ def from_request_tracker( # TODO: Optimize this token_ids_tensor = torch.tensor(token_ids) assert tracker.mm_positions is not None, ( - "tracker got mm_hashes but no mm_positions") - apply_mm_hashes_to_token_ids(token_ids, tracker.mm_hashes, - tracker.mm_positions) + "tracker got mm_hashes but no mm_positions" + ) + apply_mm_hashes_to_token_ids( + token_ids, tracker.mm_hashes, tracker.mm_positions + ) token_ids = token_ids_tensor.tolist() num_blocks = len(tracker.allocated_block_ids) @@ -338,7 +349,8 @@ def from_request_tracker( if len(token_ids) > num_blocks * block_size: logger.error( "The number of tokens is more than the number of blocks." - "Something might be wrong in scheduling logic!") + "Something might be wrong in scheduling logic!" 
+ ) logger.error( "Num tokens: %d, num blocks: %d, block size: %d", len(token_ids), @@ -348,10 +360,12 @@ def from_request_tracker( block_ids = torch.tensor(tracker.allocated_block_ids, dtype=torch.long) block_offsets = torch.arange(0, block_size, dtype=torch.long) - slot_mapping = (block_offsets.reshape( - (1, block_size)) + block_ids.reshape((num_blocks, 1)) * block_size) + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids.reshape((num_blocks, 1)) * block_size + ) - slot_mapping = slot_mapping.flatten()[:len(token_ids)] + slot_mapping = slot_mapping.flatten()[: len(token_ids)] assert slot_mapping.dtype == torch.long # TODO: this could be removed # For load operation: check whether the request is scheduled to load @@ -384,12 +398,14 @@ def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig): def _calculate_mtp_layers(vllm_config, model_config): num_mtp_layers = 0 if vllm_config is not None and vllm_config.speculative_config is not None: - logger.info("vllm_config.speculative_config: %s", - vllm_config.speculative_config) + logger.info( + "vllm_config.speculative_config: %s", vllm_config.speculative_config + ) # TODO(baoloongmao): Support other MTP methods if vllm_config.speculative_config.method == "deepseek_mtp": - num_mtp_layers = getattr(model_config.hf_config, - "num_nextn_predict_layers", 0) + num_mtp_layers = getattr( + model_config.hf_config, "num_nextn_predict_layers", 0 + ) return num_mtp_layers @@ -417,16 +433,17 @@ def _init_lmcache_engine( parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config - assert isinstance( - lmcache_config, - LMCacheEngineConfig), ("LMCache v1 configuration is should be passed.") + assert isinstance(lmcache_config, LMCacheEngineConfig), ( + "LMCache v1 configuration is should be passed." + ) - kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, - model_config.dtype) + kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype) use_mla = mla_enabled(model_config) - if use_mla and (lmcache_config.remote_serde != "naive" - and lmcache_config.remote_serde is not None): + if use_mla and ( + lmcache_config.remote_serde != "naive" + and lmcache_config.remote_serde is not None + ): raise ValueError("MLA only works with naive serde mode..") # construct kv shape (for mem pool) @@ -436,10 +453,13 @@ def _init_lmcache_engine( chunk_size = lmcache_config.chunk_size num_kv_head = model_config.get_num_kv_heads(parallel_config) head_size = model_config.get_head_size() - kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, - head_size) - logger.info("use mla: %s, kv shape: %s, num_mtp_layers: %s", use_mla, - kv_shape, num_mtp_layers) + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size) + logger.info( + "use mla: %s, kv shape: %s, num_mtp_layers: %s", + use_mla, + kv_shape, + num_mtp_layers, + ) # Change current device. num_gpus = torch.cuda.device_count() @@ -527,7 +547,6 @@ def add_request(self, req_meta: ReqMeta) -> None: class LMCacheConnectorV1Impl: - def __init__( self, vllm_config: "VllmConfig", @@ -540,31 +559,36 @@ def __init__( self.worker_count = vllm_config.parallel_config.tensor_parallel_size config = lmcache_get_or_create_config() assert isinstance(config, LMCacheEngineConfig), ( - "LMCache v1 configuration is should be passed for vLLM v1.") + "LMCache v1 configuration is should be passed for vLLM v1." + ) # Put the leading with "lmcache." 
and matched configs from # vllm extra_config to the config kv_connector_extra_config = ( - vllm_config.kv_transfer_config.kv_connector_extra_config) + vllm_config.kv_transfer_config.kv_connector_extra_config + ) if kv_connector_extra_config: for key, value in kv_connector_extra_config.items(): if key.startswith("lmcache."): config_key = key[8:] # Remove "lmcache." prefix - if _validate_and_set_config_value(config, config_key, - value): + if _validate_and_set_config_value(config, config_key, value): logger.info( - "Updated config %s from vLLM " - "extra config: %s", config_key, value) + "Updated config %s from vLLM extra config: %s", + config_key, + value, + ) self.config = config self.async_loading = config.enable_async_loading - self.layerwise_retrievers: list[Generator[Optional[torch.Tensor], None, - None]] = [] + self.layerwise_retrievers: list[ + Generator[Optional[torch.Tensor], None, None] + ] = [] self._stats_monitor = LMCStatsMonitor.GetOrCreate() if role == KVConnectorRole.SCHEDULER: # Create lookup client using factory self.lookup_client = LookupClientFactory.create_lookup_client( - vllm_config, config) + vllm_config, config + ) self._unfinished_requests: dict[str, Request] = {} self._lookup_requests_in_step: list[str] = [] self.lmcache_engine = None @@ -588,7 +612,8 @@ def __init__( # Create lookup server using factory assert self.lmcache_engine is not None self.lookup_server = LookupClientFactory.create_lookup_server( - self.lmcache_engine, vllm_config) + self.lmcache_engine, vllm_config + ) self.offload_server = ZMQOffloadServer( self.lmcache_engine, @@ -599,8 +624,7 @@ def __init__( # In case of MLA, the lookup server is only created on worker 0 if self.async_loading and self.lookup_server is not None: assert isinstance(self.lookup_server, LMCacheAsyncLookupServer) - self.lmcache_engine.post_init( - async_lookup_server=self.lookup_server) + self.lmcache_engine.post_init(async_lookup_server=self.lookup_server) self.kv_caches: dict[str, torch.Tensor] = {} @@ -617,29 +641,33 @@ def __init__( # Whether to discard partial chunks self._discard_partial_chunks = ( vllm_config.kv_transfer_config.get_from_extra_config( - "discard_partial_chunks", False) - or not config.save_unfull_chunk) + "discard_partial_chunks", False + ) + or not config.save_unfull_chunk + ) self._lmcache_chunk_size = config.chunk_size self._save_decode_cache = config.save_decode_cache - self.skip_last_n_tokens = \ - vllm_config.kv_transfer_config.get_from_extra_config( - "skip_last_n_tokens", 0) + self.skip_last_n_tokens = vllm_config.kv_transfer_config.get_from_extra_config( + "skip_last_n_tokens", 0 + ) self.num_layers = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) self.current_layer = 0 - self.force_skip_save = bool( - os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False)) + self.force_skip_save = bool(os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False)) self._requests_priority: dict[str, int] = {} # TODO(baoloongmao): Internal api server & plugin framework support # dp > 1 - if (vllm_config.parallel_config.data_parallel_size_local == 1 - or vllm_config.parallel_config.data_parallel_rank_local == 0): + if ( + vllm_config.parallel_config.data_parallel_size_local == 1 + or vllm_config.parallel_config.data_parallel_rank_local == 0 + ): # Start internal API server if enabled # The enabled check is in the InternalAPIServer constructor self.api_server = InternalAPIServer(self) @@ -649,7 +677,8 @@ def __init__( self.config, role, self.worker_count, - -1 if 
self.lmcache_engine is None # scheduler side + -1 + if self.lmcache_engine is None # scheduler side else self.lmcache_engine.metadata.worker_id, ) self.plugin_launcher.launch_plugins() @@ -658,9 +687,12 @@ def __init__( self.plugin_launcher = None # type: ignore[assignment] logger.info( "LMCache initialized for role %s with version %s, " - "vllm version %s, lmcache cache_engine metadata: %s", role, - utils.get_version(), VLLM_VERSION, - getattr(self.lmcache_engine, 'metadata', None)) + "vllm version %s, lmcache cache_engine metadata: %s", + role, + utils.get_version(), + VLLM_VERSION, + getattr(self.lmcache_engine, "metadata", None), + ) def get_inference_info(self) -> dict: """Get inference information including vLLM config and related details. @@ -677,40 +709,37 @@ def get_inference_info(self) -> dict: "lmcache_version": utils.get_version(), "vllm_config": str(vllm_config), "model_config": { - "model": - getattr(vllm_config.model_config, "model", None), - "dtype": - str(getattr(vllm_config.model_config, "dtype", None)), - "max_model_len": - getattr(vllm_config.model_config, "max_model_len", None), - "vocab_size": - getattr(vllm_config.model_config, "vocab_size", None), - "num_layers": - getattr(vllm_config.model_config, "get_num_layers", - lambda _: None)(vllm_config.parallel_config), - "num_attention_heads": - getattr(vllm_config.model_config, "get_num_attention_heads", - lambda _: None)(vllm_config.parallel_config), - "num_kv_heads": - getattr(vllm_config.model_config, "get_num_kv_heads", - lambda _: None)(vllm_config.parallel_config), - "head_size": - getattr(vllm_config.model_config, "get_head_size", - lambda: None)(), + "model": getattr(vllm_config.model_config, "model", None), + "dtype": str(getattr(vllm_config.model_config, "dtype", None)), + "max_model_len": getattr( + vllm_config.model_config, "max_model_len", None + ), + "vocab_size": getattr(vllm_config.model_config, "vocab_size", None), + "num_layers": getattr( + vllm_config.model_config, "get_num_layers", lambda _: None + )(vllm_config.parallel_config), + "num_attention_heads": getattr( + vllm_config.model_config, "get_num_attention_heads", lambda _: None + )(vllm_config.parallel_config), + "num_kv_heads": getattr( + vllm_config.model_config, "get_num_kv_heads", lambda _: None + )(vllm_config.parallel_config), + "head_size": getattr( + vllm_config.model_config, "get_head_size", lambda: None + )(), }, "cache_config": { - "block_size": - getattr(vllm_config.cache_config, "block_size", None), - "cache_dtype": - str(getattr(vllm_config.cache_config, "cache_dtype", None)), - "gpu_memory_utilization": - getattr(vllm_config.cache_config, "gpu_memory_utilization", - None), - "swap_space": - getattr(vllm_config.cache_config, "swap_space", None), - "enable_prefix_caching": - getattr(vllm_config.cache_config, "enable_prefix_caching", - None), + "block_size": getattr(vllm_config.cache_config, "block_size", None), + "cache_dtype": str( + getattr(vllm_config.cache_config, "cache_dtype", None) + ), + "gpu_memory_utilization": getattr( + vllm_config.cache_config, "gpu_memory_utilization", None + ), + "swap_space": getattr(vllm_config.cache_config, "swap_space", None), + "enable_prefix_caching": getattr( + vllm_config.cache_config, "enable_prefix_caching", None + ), }, } @@ -725,26 +754,24 @@ def get_inference_version(self) -> str: return VLLM_VERSION @_lmcache_nvtx_annotate - def _init_kv_caches_from_forward_context( - self, forward_context: "ForwardContext"): + def _init_kv_caches_from_forward_context(self, forward_context: 
"ForwardContext"): for layer_name in forward_context.no_compile_layers: attn_layer = forward_context.no_compile_layers[layer_name] if not hasattr(attn_layer, "kv_cache"): - logger.debug("The layer %s does not have kv_cache, skip it", - layer_name) + logger.debug("The layer %s does not have kv_cache, skip it", layer_name) continue if layer_name not in self.kv_caches: self.kv_caches[layer_name] = attn_layer.kv_cache[ - forward_context.virtual_engine] + forward_context.virtual_engine + ] #################### # Worker side APIs #################### @_lmcache_nvtx_annotate - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -769,8 +796,7 @@ def start_load_kv(self, forward_context: "ForwardContext", attn_metadata = forward_context.attn_metadata if attn_metadata is None: - logger.debug( - "In connector.start_load_kv, but the attn_metadata is None") + logger.debug("In connector.start_load_kv, but the attn_metadata is None") return assert self.lmcache_engine is not None @@ -794,11 +820,14 @@ def start_load_kv(self, forward_context: "ForwardContext", assert len(tokens) == len(slot_mapping) self._stats_monitor.update_interval_vllm_hit_tokens( - request.load_spec.vllm_cached_tokens) + request.load_spec.vllm_cached_tokens + ) token_mask = torch.ones(len(tokens), dtype=torch.bool) - masked_token_count = (request.load_spec.vllm_cached_tokens // - self._lmcache_chunk_size * - self._lmcache_chunk_size) + masked_token_count = ( + request.load_spec.vllm_cached_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) token_mask[:masked_token_count] = False lmcache_cached_tokens = request.load_spec.lmcache_cached_tokens @@ -838,12 +867,14 @@ def start_load_kv(self, forward_context: "ForwardContext", # Check the result num_retrieved_tokens = ret_token_mask.sum().item() - num_expected_tokens = (lmcache_cached_tokens - - request.load_spec.vllm_cached_tokens) + num_expected_tokens = ( + lmcache_cached_tokens - request.load_spec.vllm_cached_tokens + ) if num_retrieved_tokens < num_expected_tokens: logger.error( "The number of retrieved tokens is less than the " - "expected number of tokens! This should not happen!") + "expected number of tokens! This should not happen!" 
+ ) logger.error( "Num retrieved tokens: %d, num expected tokens: %d", num_retrieved_tokens, @@ -861,8 +892,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: layer_name: the name of that layer """ if self.layerwise_retrievers: - logger.debug("Waiting for layer %s to be loaded", - self.current_layer) + logger.debug("Waiting for layer %s to be loaded", self.current_layer) # Wait for the layer to be loaded for layerwise_retriever in self.layerwise_retrievers: @@ -940,9 +970,11 @@ def save_kv_layer( if skip_leading_tokens == len(token_ids): continue # skip this request # Align to lmcache chunk size - skip_leading_tokens = (skip_leading_tokens // - self._lmcache_chunk_size * - self._lmcache_chunk_size) + skip_leading_tokens = ( + skip_leading_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) store_mask = torch.ones(len(token_ids), dtype=torch.bool) store_mask[:skip_leading_tokens] = False @@ -983,7 +1015,8 @@ def wait_for_save(self): assert isinstance(connector_metadata, LMCacheConnectorMetadata) self.lmcache_engine.lookup_unpin( # type: ignore - connector_metadata.lookup_requests_in_step) + connector_metadata.lookup_requests_in_step + ) if self.kv_role == "kv_consumer": # Don't do save if the role is kv_consumer @@ -1001,8 +1034,9 @@ def wait_for_save(self): for request in connector_metadata.requests: save_spec = request.save_spec - if (save_spec is None or - not save_spec.can_save) and self.kv_role != "kv_producer": + if ( + save_spec is None or not save_spec.can_save + ) and self.kv_role != "kv_producer": continue token_ids = request.token_ids @@ -1019,15 +1053,17 @@ def wait_for_save(self): if self.kv_role == "kv_producer": assert request.disagg_spec is not None skip_leading_tokens = min( - skip_leading_tokens, - request.disagg_spec.num_transferred_tokens) + skip_leading_tokens, request.disagg_spec.num_transferred_tokens + ) if skip_leading_tokens == len(token_ids): continue # skip this request # Align to lmcache chunk size - skip_leading_tokens = (skip_leading_tokens // - self._lmcache_chunk_size * - self._lmcache_chunk_size) + skip_leading_tokens = ( + skip_leading_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) store_mask = torch.ones(len(token_ids), dtype=torch.bool) store_mask[:skip_leading_tokens] = False @@ -1047,8 +1083,9 @@ def wait_for_save(self): request.disagg_spec.is_last_prefill = True else: token_len = len(token_ids) - aligned_token_len = (token_len // self._lmcache_chunk_size * - self._lmcache_chunk_size) + aligned_token_len = ( + token_len // self._lmcache_chunk_size * self._lmcache_chunk_size + ) token_ids = token_ids[:aligned_token_len] store_mask = store_mask[:aligned_token_len] slot_mapping = slot_mapping[:aligned_token_len] @@ -1097,7 +1134,8 @@ def get_num_new_matched_tokens( external KV cache beyond what is already computed. 
""" if self.kv_role == "kv_producer" and not hasattr( - self.lookup_client, "supports_producer_reuse"): + self.lookup_client, "supports_producer_reuse" + ): return 0 self._requests_priority[request.request_id] = request.priority @@ -1114,11 +1152,8 @@ def get_num_new_matched_tokens( request_configs = extract_request_configs(request.sampling_params) if self.skip_last_n_tokens > 0: - token_ids = token_ids[:-self.skip_last_n_tokens] - if self.async_loading: - lookup_id = request.request_id - else: - lookup_id = str(uuid.uuid4()) + token_ids = token_ids[: -self.skip_last_n_tokens] + lookup_id = request.request_id if self.async_loading else str(uuid.uuid4()) self._lookup_requests_in_step.append(lookup_id) @@ -1147,8 +1182,7 @@ def get_num_new_matched_tokens( need_to_allocate -= 1 logger.info( - "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, " - "need to load: %d", + "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d", request.request_id, request.num_tokens, num_external_hit_tokens, @@ -1171,8 +1205,7 @@ def get_num_new_matched_tokens( return need_to_allocate @_lmcache_nvtx_annotate - def update_state_after_alloc(self, request: "Request", - num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", num_external_tokens: int): """ Update KVConnector state after temporary buffer alloc. @@ -1180,15 +1213,18 @@ def update_state_after_alloc(self, request: "Request", if the CacheManager this allocated blocks for us. """ - kv_transfer_params = (request.kv_transfer_params if hasattr( - request, "kv_transfer_params") else None) + kv_transfer_params = ( + request.kv_transfer_params + if hasattr(request, "kv_transfer_params") + else None + ) - if kv_transfer_params is not None and \ - "disagg_spec" in kv_transfer_params: + if kv_transfer_params is not None and "disagg_spec" in kv_transfer_params: req_disagg_spec = kv_transfer_params["disagg_spec"] receiver_id = req_disagg_spec["receiver_host"] + str( - req_disagg_spec["receiver_init_port"]) + req_disagg_spec["receiver_init_port"] + ) disagg_spec = DisaggSpec( req_id=req_disagg_spec["req_id"], @@ -1211,22 +1247,28 @@ def update_state_after_alloc(self, request: "Request", return # Only check for non-prompt-hit case - if (self.load_specs[request.request_id].lmcache_cached_tokens - != request.num_tokens): + if ( + self.load_specs[request.request_id].lmcache_cached_tokens + != request.num_tokens + ): assert ( - num_external_tokens > 0 and num_external_tokens - == self.load_specs[request.request_id].lmcache_cached_tokens - - self.load_specs[request.request_id].vllm_cached_tokens - ), (f"Mismatch in number of tokens: {num_external_tokens} vs " + num_external_tokens > 0 + and num_external_tokens + == self.load_specs[request.request_id].lmcache_cached_tokens + - self.load_specs[request.request_id].vllm_cached_tokens + ), ( + f"Mismatch in number of tokens: {num_external_tokens} vs " f"{self.load_specs[request.request_id].lmcache_cached_tokens} -" f" {self.load_specs[request.request_id].vllm_cached_tokens}" - f" for request {request.request_id}") + f" for request {request.request_id}" + ) self.load_specs[request.request_id].can_load = True @_lmcache_nvtx_annotate def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: """Attach the connector metadata to the request object. 
This function should NOT modify other fields in the scheduler_output @@ -1253,8 +1295,9 @@ def build_connector_meta( # Right now, we only load KV for new requests load_spec = self.load_specs.pop(request.req_id, None) num_tokens_to_compute = ( - request.num_computed_tokens + - scheduler_output.num_scheduled_tokens[request.req_id]) + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id] + ) lmcache_cached_tokens = 0 if load_spec is not None: lmcache_cached_tokens = load_spec.lmcache_cached_tokens @@ -1262,7 +1305,8 @@ def build_connector_meta( skip_save = force_skip_save or ( self.config.priority_limit is not None - and request_priority > self.config.priority_limit) + and request_priority > self.config.priority_limit + ) request_tracker = RequestTracker.from_new_request( self.config, @@ -1311,11 +1355,13 @@ def build_connector_meta( if request := self._unfinished_requests.get(req_id): num_current_tokens = len(request_tracker.token_ids) new_token_ids = request.all_token_ids[ - num_current_tokens:num_current_tokens + num_new_tokens] + num_current_tokens : num_current_tokens + num_new_tokens + ] else: raise ValueError( f"Request {req_id} is not in _unfinished_requests, " - f"but it is scheduled to be cached") + f"but it is scheduled to be cached" + ) new_block_ids = cached_reqs.new_block_ids[i] request_tracker.update(new_token_ids, new_block_ids) @@ -1339,8 +1385,11 @@ def request_finished( request: "Request", block_ids: list[int], ) -> tuple[bool, Optional[dict[str, Any]]]: - params = (request.kv_transfer_params if hasattr( - request, "kv_transfer_params") else None) + params = ( + request.kv_transfer_params + if hasattr(request, "kv_transfer_params") + else None + ) return_params = None # NOTE: Used to stream back the first token From 630bba9b0d7f269d4cb312d2dca3d5125999a947 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 9 Oct 2025 13:22:29 -0700 Subject: [PATCH 06/12] fix suggested changes Signed-off-by: ApostaC --- .../kv_connector/v1/lmcache_connector.py | 9 ++++----- .../kv_connector/v1/lmcache_integration/utils.py | 5 +++++ .../v1/lmcache_integration/vllm_v1_adapter.py | 16 +++++++--------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 74bec84b27df..e7f7625b6255 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -13,13 +13,12 @@ KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import ( + vllm_v1_adapter as _adapter, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput -from .lmcache_integration.vllm_v1_adapter import ( - LMCacheConnectorV1Impl as LMCacheConnectorUpstreamImpl, # yapf: disable -) - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext @@ -37,7 +36,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ) if use_native: logger.info("Initializing native LMCache connector") - cls = LMCacheConnectorUpstreamImpl + cls = _adapter.LMCacheConnectorV1Impl else: logger.info("Initializing latest dev LMCache connector") cls = LMCacheConnectorLatestImpl diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py 
b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 35f8148a2c5d..6c6104125e88 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -141,6 +141,11 @@ def create_lmcache_metadata( parallel_cfg = vllm_config.parallel_config cache_cfg = vllm_config.cache_config else: + if model_config is None or parallel_config is None or cache_config is None: + raise ValueError( + "Either vllm_config must be provided, or all of " + "model_config, parallel_config, and cache_config must be provided." + ) model_cfg = model_config parallel_cfg = parallel_config cache_cfg = cache_config diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index e9444c82cfaf..2f5350209f7d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -29,26 +29,24 @@ from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer from lmcache.v1.plugin.plugin_launcher import PluginLauncher -# Third Party from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group -from vllm.sampling_params import SamplingParams -from vllm.utils import cdiv, get_kv_cache_torch_dtype -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.version import __version__ as VLLM_VERSION - -from .utils import ( +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( ENGINE_NAME, apply_mm_hashes_to_token_ids, extract_mm_features, lmcache_get_or_create_config, mla_enabled, ) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group +from vllm.sampling_params import SamplingParams +from vllm.utils import cdiv, get_kv_cache_torch_dtype +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.version import __version__ as VLLM_VERSION if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -340,7 +338,7 @@ def from_request_tracker( "tracker got mm_hashes but no mm_positions" ) apply_mm_hashes_to_token_ids( - token_ids, tracker.mm_hashes, tracker.mm_positions + token_ids_tensor, tracker.mm_hashes, tracker.mm_positions ) token_ids = token_ids_tensor.tolist() From 07282f5c8d1b5605b56f8def3f7f89300b512b6c Mon Sep 17 00:00:00 2001 From: ApostaC Date: Thu, 9 Oct 2025 13:27:36 -0700 Subject: [PATCH 07/12] clean up the useless codes and fix import orders Signed-off-by: ApostaC --- .../kv_connector/v1/lmcache_integration/utils.py | 13 +++++-------- .../v1/lmcache_integration/vllm_v1_adapter.py | 14 +------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 6c6104125e88..093f783e2b08 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -5,19 +5,16 @@ import threading from typing import TYPE_CHECKING, Union -if TYPE_CHECKING: - from vllm.config import ModelConfig - from vllm.multimodal.inputs 
import PlaceholderRange - from vllm.v1.request import Request - -# Third Party import torch - -# First Party from lmcache.config import LMCacheEngineConfig as Config from lmcache.logging import init_logger from lmcache.v1.config import LMCacheEngineConfig as V1Config +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.request import Request + logger = init_logger(__name__) ENGINE_NAME = "vllm-instance" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 2f5350209f7d..1ac4bf793d44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -119,8 +119,6 @@ class RequestTracker: # The block ids that has been allocated so far # NOTE: allocated blocks could be more than the number of tokens - # FIXME: need to check whether the block ids will be changed after - # preemption allocated_block_ids: list[int] # The number of tokens that has been saved @@ -177,10 +175,6 @@ def from_new_request( # (https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/ # sched/scheduler.py#L943), # only one KVCacheGroup is supported in connector for now. - - # TODO: Please support multiple KVCacheGroup in connector. - # NOTE: Also, `update` method in RequestTracker should be - # updated accordingly. unfolded_block_ids = new_request.block_ids[0].copy() # NOTE: Initialized in `update_state_after_alloc` @@ -230,7 +224,6 @@ def update( # When a request is scheduled again, and the number of new tokens # is 1 (excluding chunked prefill), the request is in decode phase. - # TODO: Need to further exclude the case of chunked prefill with 1 token if len(new_token_ids) == 1: self.is_decode_phase = True @@ -332,7 +325,6 @@ def from_request_tracker( # If the request has multimodal hashes, apply them to the token ids if tracker.mm_hashes: - # TODO: Optimize this token_ids_tensor = torch.tensor(token_ids) assert tracker.mm_positions is not None, ( "tracker got mm_hashes but no mm_positions" @@ -364,7 +356,7 @@ def from_request_tracker( ) slot_mapping = slot_mapping.flatten()[: len(token_ids)] - assert slot_mapping.dtype == torch.long # TODO: this could be removed + assert slot_mapping.dtype == torch.long # For load operation: check whether the request is scheduled to load if load_spec is not None and load_spec.can_load: @@ -1196,10 +1188,6 @@ def get_num_new_matched_tokens( if need_to_allocate <= 0: return 0 - # TODO: Align to vLLM block size. 
Should test whether it can be removed - # need_to_allocate = need_to_allocate // self._block_size * \ - # self._block_size - return need_to_allocate @_lmcache_nvtx_annotate From 5133bff22e0f8339b9ebc285046c18606a5099ca Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 24 Oct 2025 12:36:44 -0700 Subject: [PATCH 08/12] Update docstrings Signed-off-by: ApostaC --- .../v1/lmcache_integration/utils.py | 22 +++++---- .../v1/lmcache_integration/vllm_v1_adapter.py | 49 +++++++++---------- 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 093f783e2b08..f8e6a6963785 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -3,7 +3,7 @@ # Standard import os import threading -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch from lmcache.config import LMCacheEngineConfig as Config @@ -19,7 +19,7 @@ ENGINE_NAME = "vllm-instance" # Thread-safe singleton storage -_config_instance: Union[Config, V1Config, None] = None +_config_instance: Config | V1Config | None = None _config_lock = threading.Lock() @@ -28,7 +28,7 @@ def is_false(value: str) -> bool: return value.lower() in ("false", "0", "no", "n", "off") -def lmcache_get_or_create_config() -> Union[Config, V1Config]: +def lmcache_get_or_create_config() -> Config | V1Config: """Get the LMCache configuration from the environment variable `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this function will return the default configuration. @@ -116,14 +116,18 @@ def create_lmcache_metadata( across multiple files. 
Args: - vllm_config: vLLM configuration object containing model, parallel, and - cache configs (alternative to individual config parameters) - model_config: Model configuration (alternative to vllm_config) - parallel_config: Parallel configuration (alternative to vllm_config) - cache_config: Cache configuration (alternative to vllm_config) + vllm_config (VllmConfig): vLLM configuration object containing model, + parallel, and cache configs (alternative to + individual config parameters) + model_config (ModelConfig): Model configuration (alternative to + vllm_config) + parallel_config (ParallelConfig): Parallel configuration (alternative + to vllm_config) + cache_config (CacheConfig): Cache configuration (alternative to + vllm_config) Returns: - tuple: (LMCacheEngineMetadata, LMCacheEngineConfig) + tuple[LMCacheEngineMetadata, LMCacheEngineConfig] """ # Third Party # First Party diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 1ac4bf793d44..fadb99768711 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from lmcache import utils @@ -91,7 +91,7 @@ class DisaggSpec: tmp_disagg_tracker: dict[str, DisaggSpec] = {} -def extract_request_configs(sampling_params: SamplingParams) -> Optional[dict]: +def extract_request_configs(sampling_params: SamplingParams) -> dict | None: request_configs = None if ( sampling_params.extra_args is not None @@ -125,14 +125,14 @@ class RequestTracker: num_saved_tokens: int = 0 # Disagg spec for the request - disagg_spec: Optional[DisaggSpec] = None + disagg_spec: DisaggSpec | None = None # Multimodal hashes and positions - mm_hashes: Optional[list[str]] = None - mm_positions: Optional[list["PlaceholderRange"]] = None + mm_hashes: list[str] | None = None + mm_positions: list["PlaceholderRange"] | None = None # The configs of the request, includes tags and other configs - request_configs: Optional[dict] = None + request_configs: dict | None = None # Whether the request is in decode phase is_decode_phase = False @@ -159,7 +159,6 @@ def from_new_request( local cache hit) and new tokens that will be scheduled. lmcache_cached_tokens (int): the number of tokens that are cached in LMCache. - request_priority (int): the priority of the request skip_save (bool): whether the request cache should be saved """ # vLLM 0.9.0 update: request.block_ids changed from list[int] to @@ -200,7 +199,7 @@ def from_new_request( def update( self, new_token_ids: list[int], - new_block_ids: Union[Optional[tuple[list[int], ...]], list[int]], + new_block_ids: tuple[list[int], ...] 
| None | list[int], ) -> None: """Update the request tracker when a running request is scheduled again @@ -241,20 +240,20 @@ class ReqMeta: is_last_prefill: bool = False # Skip save or not - save_spec: Optional[SaveSpec] = None + save_spec: SaveSpec | None = None # load_spec - load_spec: Optional[LoadSpec] = None + load_spec: LoadSpec | None = None # disagg spec - disagg_spec: Optional[DisaggSpec] = None + disagg_spec: DisaggSpec | None = None # the configs of the request - request_configs: Optional[dict] = None + request_configs: dict | None = None @staticmethod def from_request_tracker( tracker: RequestTracker, block_size: int, lmcache_chunk_size: int = 256, - load_spec: Optional[LoadSpec] = None, + load_spec: LoadSpec | None = None, discard_partial_chunks: bool = True, save_decode_cache: bool = False, ) -> Optional["ReqMeta"]: @@ -467,11 +466,11 @@ def _init_lmcache_engine( ) use_gpu = need_gpu_interm_buffer(lmcache_config) - vllm_gpu_connector: Union[ - VLLMBufferLayerwiseGPUConnector, - VLLMPagedMemGPUConnectorV2, - VLLMPagedMemLayerwiseGPUConnector, - ] + vllm_gpu_connector: ( + VLLMBufferLayerwiseGPUConnector + | VLLMPagedMemGPUConnectorV2 + | VLLMPagedMemLayerwiseGPUConnector + ) if use_mla and lmcache_config.use_layerwise: raise ValueError("layerwise MLA connector is not supported yet") @@ -570,9 +569,7 @@ def __init__( self.config = config self.async_loading = config.enable_async_loading - self.layerwise_retrievers: list[ - Generator[Optional[torch.Tensor], None, None] - ] = [] + self.layerwise_retrievers: list[Generator[torch.Tensor | None, None, None]] = [] self._stats_monitor = LMCStatsMonitor.GetOrCreate() if role == KVConnectorRole.SCHEDULER: # Create lookup client using factory @@ -623,7 +620,7 @@ def __init__( # request_id -> (vllm cached tokens, lmcache cached tokens) self.load_specs: dict[str, LoadSpec] = {} - self.kv_cache_manager: Optional[KVCacheManager] = None + self.kv_cache_manager: KVCacheManager | None = None # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} @@ -767,7 +764,6 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: Args: forward_context (ForwardContext): the forward context. - **kwargs: additional arguments for the load operation Note: The number of elements in kv_caches and layer_names should be @@ -911,7 +907,6 @@ def save_kv_layer( kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. - **kwargs: additional arguments for the save operation. """ assert self.lmcache_engine is not None @@ -1098,7 +1093,7 @@ def wait_for_save(self): @_lmcache_nvtx_annotate def get_finished( self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: return None, None ################### @@ -1110,7 +1105,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> Optional[int]: + ) -> int | None: """ Check for external KV cache hit. 
@@ -1370,7 +1365,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: params = ( request.kv_transfer_params if hasattr(request, "kv_transfer_params") From b1b26a2e030cc5cc421e1d8b2a79d9075a6114db Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 24 Oct 2025 13:03:15 -0700 Subject: [PATCH 09/12] fixing mypy Signed-off-by: ApostaC --- .../kv_connector/v1/lmcache_connector.py | 1 + .../v1/lmcache_integration/utils.py | 12 ++++++--- .../v1/lmcache_integration/vllm_v1_adapter.py | 27 ++++++++++++++----- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index ce3f6234ee08..a5240adab438 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -31,6 +31,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) + assert vllm_config.kv_transfer_config is not None use_native = vllm_config.kv_transfer_config.get_from_extra_config( "use_native", False ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index f8e6a6963785..68e01086b908 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -3,7 +3,7 @@ # Standard import os import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union import torch from lmcache.config import LMCacheEngineConfig as Config @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.sched.output import NewRequestData from vllm.v1.request import Request logger = init_logger(__name__) @@ -179,7 +180,7 @@ def create_lmcache_metadata( def extract_mm_features( - request: "Request", modify: bool = False + request: Union["Request", "NewRequestData"], modify: bool = False ) -> tuple[list[str], list["PlaceholderRange"]]: """ Normalize multimodal information from a Request into parallel lists. 
@@ -213,8 +214,11 @@ def extract_mm_features( return (list(mm_hashes), list(mm_positions)) elif getattr(request, "mm_hashes", None): if modify: - return (request.mm_hashes.copy(), request.mm_positions.copy()) + return ( + request.mm_hashes.copy(), # type: ignore + request.mm_positions.copy(), + ) # type: ignore else: - return (request.mm_hashes, request.mm_positions) + return (request.mm_hashes, request.mm_positions) # type: ignore else: return ([], []) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index fadb99768711..281c5335458f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -98,6 +98,9 @@ def extract_request_configs(sampling_params: SamplingParams) -> dict | None: and "kv_transfer_params" in sampling_params.extra_args ): kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params") + if kv_transfer_params is None: + return None + assert isinstance(kv_transfer_params, dict) for k, v in kv_transfer_params.items(): if k.startswith("lmcache."): if request_configs is None: @@ -179,10 +182,14 @@ def from_new_request( # NOTE: Initialized in `update_state_after_alloc` disagg_spec = tmp_disagg_tracker.pop(new_request.req_id, None) - request_configs = extract_request_configs(new_request.sampling_params) + if new_request.sampling_params: + request_configs = extract_request_configs(new_request.sampling_params) + else: + request_configs = None mm_hashes, mm_positions = extract_mm_features(new_request, modify=True) + assert new_request.prompt_token_ids is not None return RequestTracker( req_id=new_request.req_id, prompt_len=len(new_request.prompt_token_ids), @@ -542,6 +549,7 @@ def __init__( role: KVConnectorRole, parent: KVConnectorBase_V1, ): + assert vllm_config.kv_transfer_config is not None self._parent = parent self._vllm_config = vllm_config self.kv_role = vllm_config.kv_transfer_config.kv_role @@ -1126,16 +1134,21 @@ def get_num_new_matched_tokens( self._requests_priority[request.request_id] = request.priority token_ids = request.prompt_token_ids + assert token_ids is not None # If the request has multimodal hashes, apply them to the token ids mm_hashes, mm_positions = extract_mm_features(request) if mm_hashes and mm_positions: # TODO(Jiayi): Optimize this - token_ids = torch.tensor(request.prompt_token_ids) - apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions) - token_ids = token_ids.tolist() + token_ids_tensor = torch.tensor(request.prompt_token_ids) + apply_mm_hashes_to_token_ids(token_ids_tensor, mm_hashes, mm_positions) + token_ids = token_ids_tensor.tolist() + + if request.sampling_params: + request_configs = extract_request_configs(request.sampling_params) + else: + request_configs = None - request_configs = extract_request_configs(request.sampling_params) if self.skip_last_n_tokens > 0: token_ids = token_ids[: -self.skip_last_n_tokens] lookup_id = request.request_id if self.async_loading else str(uuid.uuid4()) @@ -1333,9 +1346,9 @@ def build_connector_meta( for i, req_id in enumerate(cached_reqs.req_ids): request_tracker = self._request_trackers[req_id] num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] - if request := self._unfinished_requests.get(req_id): + if cached_request := self._unfinished_requests.get(req_id): num_current_tokens = len(request_tracker.token_ids) - 
new_token_ids = request.all_token_ids[ + new_token_ids = cached_request.all_token_ids[ num_current_tokens : num_current_tokens + num_new_tokens ] else: From b94d9c70ba3b49f6c36b4478b3a962f3a5db363f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 24 Oct 2025 13:10:42 -0700 Subject: [PATCH 10/12] fixing mypy Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_integration/utils.py | 4 ++-- .../kv_connector/v1/lmcache_integration/vllm_v1_adapter.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 68e01086b908..5ac132c8c895 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -216,8 +216,8 @@ def extract_mm_features( if modify: return ( request.mm_hashes.copy(), # type: ignore - request.mm_positions.copy(), - ) # type: ignore + request.mm_positions.copy(), # type: ignore + ) else: return (request.mm_hashes, request.mm_positions) # type: ignore else: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 281c5335458f..1f42b598bc9c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -1134,7 +1134,6 @@ def get_num_new_matched_tokens( self._requests_priority[request.request_id] = request.priority token_ids = request.prompt_token_ids - assert token_ids is not None # If the request has multimodal hashes, apply them to the token ids mm_hashes, mm_positions = extract_mm_features(request) @@ -1150,6 +1149,7 @@ def get_num_new_matched_tokens( request_configs = None if self.skip_last_n_tokens > 0: + assert token_ids is not None token_ids = token_ids[: -self.skip_last_n_tokens] lookup_id = request.request_id if self.async_loading else str(uuid.uuid4()) From 27c719594bd5e05f86658c2d493a39607b050e75 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 24 Oct 2025 14:20:45 -0700 Subject: [PATCH 11/12] fix docstring Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_integration/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 5ac132c8c895..9705d45e26ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -128,7 +128,8 @@ def create_lmcache_metadata( vllm_config) Returns: - tuple[LMCacheEngineMetadata, LMCacheEngineConfig] + tuple[LMCacheEngineMetadata, LMCacheEngineConfig]: a tuple of + LMCacheEngineConfig and LMCacheEngineMetadata """ # Third Party # First Party From 153f4e02d9f51b1d2601630a6f7904c65cbc9ba3 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 24 Oct 2025 14:49:41 -0700 Subject: [PATCH 12/12] fix docstring Signed-off-by: ApostaC --- .../kv_transfer/kv_connector/v1/lmcache_integration/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py index 
9705d45e26ba..e0282c155248 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -126,10 +126,6 @@ def create_lmcache_metadata( to vllm_config) cache_config (CacheConfig): Cache configuration (alternative to vllm_config) - - Returns: - tuple[LMCacheEngineMetadata, LMCacheEngineConfig]: a tuple of - LMCacheEngineConfig and LMCacheEngineMetadata """ # Third Party # First Party
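
A minimal usage sketch for the connector selection introduced in this series, assuming the lmcache package is installed alongside this branch of vLLM: the "use_native" key in kv_connector_extra_config picks between the in-tree adapter added here and the adapter shipped with the lmcache package, and keys prefixed with "lmcache." are stripped of the prefix and forwarded to the LMCacheEngineConfig. The model name and the chunk-size value below are placeholders; LMCache settings may equally come from LMCACHE_CONFIG_FILE or environment variables, as handled by lmcache_get_or_create_config.

    # Illustrative sketch only; values marked as placeholders are assumptions.
    from vllm import LLM
    from vllm.config import KVTransferConfig

    ktc = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
        kv_connector_extra_config={
            # True selects the in-tree adapter migrated in this series;
            # the default (False) keeps the adapter from the pip-installed
            # lmcache package.
            "use_native": True,
            # "lmcache."-prefixed keys are forwarded to LMCacheEngineConfig
            # by the connector; 256 is a placeholder chunk size.
            "lmcache.chunk_size": 256,
        },
    )

    # "facebook/opt-125m" is a placeholder; any supported model works.
    llm = LLM(model="facebook/opt-125m", kv_transfer_config=ktc)
    out = llm.generate("Hello, my name is")
    print(out[0].outputs[0].text)

Note that patch 04 flips the default of "use_native" to False, so the pip-installed adapter remains in use unless the key is set explicitly.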