diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b7d7a10057b8..ae4125d54190 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -9,8 +9,10 @@ import time import uuid from collections import defaultdict -from unittest.mock import patch +from typing import Any +from unittest.mock import MagicMock, patch +import msgspec import pytest import ray import torch @@ -18,6 +20,7 @@ from vllm import LLM from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiKVConnectorStats, @@ -29,7 +32,9 @@ NixlConnectorMetadata, NixlConnectorScheduler, NixlConnectorWorker, + NixlHandshakePayload, NixlKVConnectorStats, + compute_nixl_compatibility_hash, ) from vllm.distributed.kv_transfer.kv_transfer_state import ( ensure_kv_transfer_shutdown, @@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init): } prefill_connector.register_kv_caches(kv_caches) - # Simulate EngineCore initialization that would - # gather connector metadata from all workers, the scheduler connector - # expects metadata to be in dict[int, KVConnectorHandshakeMetadata], - # where the first key is the dp_rank, the second key is the tp_rank. 
- metadata = {0: prefill_connector.get_handshake_metadata()} + # Simulate EngineCore initialization that would gather connector + # metadata from all workers + metadata = prefill_connector.get_handshake_metadata() + + # metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes) + + # The scheduler connector expects metadata to be in + # dict[int, KVConnectorHandshakeMetadata], where the first key is + # the dp_rank, the second key is the tp_rank. scheduler_connector = scheduler.get_kv_connector() - scheduler_connector.set_xfer_handshake_metadata(metadata) + scheduler_connector.set_xfer_handshake_metadata({0: metadata}) # Simulate a request that finishes prefill, which returns # corresponding NixlConnectorMetadata for decode instance. @@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init): ) received_metadata = mock_add_remote_agent.call_args.args + assert received_metadata[0] == expected_agent_metadata assert received_metadata[1] == 0 # remote_tp_rank assert received_metadata[2] == 1 # remote_tp_size - assert metadata[0] == received_metadata[0] # Need to shutdown the background thread to release NIXL side channel port scheduler_connector.shutdown() @@ -403,7 +414,6 @@ def _nixl_handshake( device_id=0, num_blocks=1, block_lens=self.block_len_per_layer, - attn_backend_name=self.backend_name, # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. 
kv_cache_layout="HND", @@ -651,7 +661,6 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): device_id=0, num_blocks=1, block_lens=worker.block_len_per_layer, - attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, block_size=worker.block_size, ) @@ -706,7 +715,6 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( num_blocks=1, # prefill TP=1, decode TP=2, remote block_lens is double to local block_lens=[i * 2 for i in worker.block_len_per_layer], - attn_backend_name=worker.backend_name, kv_cache_layout="HND", block_size=worker.block_size, ) @@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): mock_wrapper_instance = mock_nixl_wrapper.return_value connector.connector_worker.nixl_wrapper = mock_wrapper_instance + # Appease NixlHandshakePayload encoding with some bytes + mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata" + # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False @@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init): # ensure request appears in get_finished _, done_recving = connector.get_finished(finished_req_ids=set()) assert request_id in done_recving + + +@pytest.mark.parametrize( + "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat", + [ + ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True), + ("nixl_connector_version", {}, {"connector_version": 37}, True, True), + ("model_name", {"model": "facebook/opt-350m"}, {}, True, True), + ("dtype", {"dtype": "bfloat16"}, {}, True, True), + ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True), + ("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True), + ( + "num_hidden_layers", + {"hf_overrides": {"num_hidden_layers": 24}}, + {}, + True, + True, + ), + ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, 
True), + ("block_size", {"block_size": 8}, {}, False, True), + ("matching_config", {}, {}, False, True), + ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False), + ], +) +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_compatibility_hash_validation( + dist_init, + mismatch_type, + config_overrides, + version_override, + should_fail, + enforce_handshake_compat, +): + """ + Test NIXL compatibility hash validation during handshake. + + Parameters: + mismatch_type: description of what is being tested + config_overrides: dict of config to override for the remote instance + version_override: version dict e.g. {"vllm_version": "0.6.1"} + should_fail: whether the handshake should fail + enforce_handshake_compat: whether to enforce compatibility checking + """ + local_vllm_config = create_vllm_config( + model="facebook/opt-125m", + block_size=16, + kv_connector_extra_config={ + "enforce_handshake_compat": enforce_handshake_compat + }, + ) + decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_worker = decode_connector.connector_worker + + remote_config_params: dict[str, Any] = { + "model": "facebook/opt-125m", + "block_size": 16, + **config_overrides, + } + remote_vllm_config = create_vllm_config(**remote_config_params) + + with contextlib.ExitStack() as stack: + if "vllm_version" in version_override: + stack.enter_context( + patch("vllm.__version__", version_override["vllm_version"]) + ) + elif "connector_version" in version_override: + stack.enter_context( + patch.object( + nixl_connector, + "NIXL_CONNECTOR_VERSION", + version_override["connector_version"], + ) + ) + remote_hash = compute_nixl_compatibility_hash( + remote_vllm_config, decode_worker.backend_name + ) + + prefill_block_size = config_overrides.get("block_size", 16) + prefill_metadata = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + 
agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=0, + num_blocks=1, + block_lens=[4096 * prefill_block_size], # slot_size * block_size + kv_cache_layout="HND", + block_size=prefill_block_size, + ) + handshake_payload = NixlHandshakePayload( + compatibility_hash=remote_hash, + agent_metadata_bytes=msgspec.msgpack.encode(prefill_metadata), + ) + + # Mock ZMQ socket to return our handshake payload + mock_socket = MagicMock() + mock_socket.recv.return_value = msgspec.msgpack.encode(handshake_payload) + + # Mock add_remote_agent to avoid actual NIXL operations + # Patch zmq_ctx to return our mock socket + with ( + patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), + patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, + ): + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + + if should_fail: + with pytest.raises(RuntimeError, match="compatibility hash mismatch"): + decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) + else: + result = decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) + # Verify handshake returned agent mapping + assert isinstance(result, dict) + assert len(result) == 1 + + +@pytest.mark.parametrize( + "error_scenario", + [ + "handshake_decode_error", + "handshake_validation_error", + "metadata_decode_error", + "metadata_validation_error", + ], +) +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_handshake_decode_errors(dist_init, error_scenario): + """ + Test that msgspec decode errors are properly handled during handshake. 
+ + Tests both DecodeError and ValidationError for both decoders: + - NixlHandshakePayload decoder + - NixlAgentMetadata decoder + """ + local_vllm_config = create_vllm_config( + model="facebook/opt-125m", + block_size=16, + ) + decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_worker = decode_connector.connector_worker + + if error_scenario == "handshake_decode_error": + msg_bytes = b"this is not valid msgpack data" + elif error_scenario == "handshake_validation_error": + msg_bytes = msgspec.msgpack.encode({"wrong_field": "value"}) + elif error_scenario == "metadata_decode_error": + valid_handshake = NixlHandshakePayload( + compatibility_hash=decode_worker.compat_hash, + agent_metadata_bytes=b"invalid msgpack for metadata", + ) + msg_bytes = msgspec.msgpack.encode(valid_handshake) + + elif error_scenario == "metadata_validation_error": + valid_handshake = NixlHandshakePayload( + compatibility_hash=decode_worker.compat_hash, + agent_metadata_bytes=msgspec.msgpack.encode({"missing": "fields"}), + ) + msg_bytes = msgspec.msgpack.encode(valid_handshake) + else: + raise AssertionError(f"{error_scenario} not a valid scenario") + + mock_socket = MagicMock() + mock_socket.recv.return_value = msg_bytes + with ( + patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), + patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, + ): + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + + with pytest.raises(RuntimeError): + decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f35f91bb3adf..039746f9c15f 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -90,6 +90,10 @@ def create_vllm_config( max_model_len: int = 10000, enable_chunked_prefill: bool = True, enable_permute_local_kv: 
bool = False, + kv_connector_extra_config: dict[str, Any] | None = None, + dtype: str = "float16", + cache_dtype: str = "auto", + hf_overrides: dict[str, Any] | None = None, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( @@ -101,21 +105,23 @@ def create_vllm_config( model_config = ModelConfig( model=model, trust_remote_code=True, - dtype="float16", + dtype=dtype, seed=42, + hf_overrides=hf_overrides or {}, ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, - cache_dtype="auto", + cache_dtype=cache_dtype, enable_prefix_caching=True, ) kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", kv_role="kv_both", enable_permute_local_kv=enable_permute_local_kv, + kv_connector_extra_config=kv_connector_extra_config or {}, ) return VllmConfig( scheduler_config=scheduler_config, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 24c8d32dafed..785deabd5e57 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -59,6 +59,21 @@ EngineId = str ReqId = str +# +# NIXL Connector Version +# +# Increment this version whenever there is an incompatible change to: +# - NixlAgentMetadata schema +# - kv_transfer_params schema or semantics +# - NIXL transfer protocol or wire format +# - KV cache memory layout or block organization +# - Any other change that breaks P/D interoperability +# +# Version History: +# 1: Initial version with compatibility checking +# +NIXL_CONNECTOR_VERSION: int = 1 + GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -97,18 +112,95 @@ @dataclass -class NixlAgentMetadata(KVConnectorHandshakeMetadata): +class NixlAgentMetadata: engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] device_id: int num_blocks: int 
block_lens: list[int] - attn_backend_name: str kv_cache_layout: str block_size: int +@dataclass +class NixlHandshakePayload(KVConnectorHandshakeMetadata): + """ + Wrapper for NIXL handshake sent over the wire. + + Enables two-phase decoding for graceful compatibility checking: + 1. Decode NixlHandshakePayload to get compatibility_hash + 2. Compute local hash and compare + 3. Only if hashes match, decode agent_metadata_bytes + + This prevents decoder errors when NixlAgentMetadata schema is + incompatible, allowing graceful failure with clear error message. + """ + + compatibility_hash: str + agent_metadata_bytes: bytes # NixlAgentMetadata encoded + + +def compute_nixl_compatibility_hash( + vllm_config: VllmConfig, attn_backend_name: str +) -> str: + """ + Compute compatibility hash for NIXL KV transfer. + + Hash only the factors that affect whether two NIXL instances can + successfully transfer KV cache data. + + Factors included: + - vLLM version and NIXL connector version + - Model architecture (name, dtype, KV heads, head size, layers) + - KV cache format (dtype) + - Attention backend + + Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout + are validated at runtime in _validate_remote_agent_handshake and are not + included in this hash to support heterogeneous deployments. + + Note: the set of factors is likely to evolve significantly over + time to be more or less permissive. 
+ + Returns: + SHA-256 hex digest + """ + from vllm import __version__ as vllm_version + from vllm.config.utils import hash_factors + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + + factors = { + # Version compatibility + "vllm_version": vllm_version, + "nixl_connector_version": NIXL_CONNECTOR_VERSION, + # Model architecture - affects KV cache shape + "model": model_config.model, + "dtype": str(model_config.dtype), + "num_kv_heads": model_config.get_total_num_kv_heads(), + "head_size": model_config.get_head_size(), + "num_hidden_layers": model_config.get_total_num_hidden_layers(), + # Attention backend and KV cache dtype affect memory layout + "attn_backend_name": attn_backend_name, + "cache_dtype": str(cache_config.cache_dtype), + } + + compat_hash = hash_factors(factors) + logger.info( + "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, " + "cache_dtype=%s, attn_backend=%s)", + compat_hash, + factors["model"], + factors["dtype"], + factors["num_kv_heads"], + factors["cache_dtype"], + attn_backend_name, + ) + return compat_hash + + @dataclass class ReqMeta: local_block_ids: list[int] @@ -396,14 +488,14 @@ def set_xfer_handshake_metadata( encoded_data: dict[int, bytes] = {} encoder = msgspec.msgpack.Encoder() for tp_rank, rank_metadata in metadata.items(): - if not isinstance(rank_metadata, NixlAgentMetadata): + if not isinstance(rank_metadata, NixlHandshakePayload): raise ValueError( - "NixlConnectorScheduler expects NixlAgentMetadata for " + "NixlConnectorScheduler expects NixlHandshakePayload for " "handshake metadata." 
) encoded_data[tp_rank] = encoder.encode(rank_metadata) logger.debug( - "Tp rank %d: encoded NixlAgentMetadata size: %s bytes", + "Tp rank %d: encoded NixlHandshakePayload size: %s bytes", tp_rank, str(len(encoded_data[tp_rank])), ) @@ -916,7 +1008,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._failed_recv_reqs: set[ReqId] = set() # Handshake metadata of this worker for NIXL transfers. - self.xfer_handshake_metadata: NixlAgentMetadata | None = None + self.xfer_handshake_metadata: NixlHandshakePayload | None = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -951,6 +1043,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name + ) + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( + "enforce_handshake_compat", True + ) + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to @@ -999,14 +1098,58 @@ def _nixl_handshake( # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.send(msg) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) + handshake_bytes = sock.recv() + + # Decode handshake payload to get compatibility hash + handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) + try: + handshake_payload = handshake_decoder.decode(handshake_bytes) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + raise RuntimeError( + f"Failed to 
decode NixlHandshakePayload. This likely indicates " + f"an incompatibility between connector versions. Error: {e}" + ) from e + got_metadata_time = time.perf_counter() logger.debug( "NIXL handshake: get metadata took: %s", got_metadata_time - start_time ) + # Check compatibility hash BEFORE decoding agent metadata + if ( + self.enforce_compat_hash + and handshake_payload.compatibility_hash != self.compat_hash + ): + raise RuntimeError( + f"NIXL compatibility hash mismatch. " + f"Local: {self.compat_hash}, " + f"Remote: {handshake_payload.compatibility_hash}. " + f"Prefill and decode instances have incompatible configurations. " + f"This may be due to: different vLLM versions, models, dtypes, " + f"KV cache layouts, attention backends, etc. " + f"Both instances must use identical configurations. " + f"Disable this check using " + f'--kv-transfer-config \'{{"kv_connector_extra_config": ' + f'{{"enforce_handshake_compat": false}}}}\'' + ) + + logger.info( + "NIXL compatibility check passed (hash: %s)", + handshake_payload.compatibility_hash, + ) + + # Decode agent metadata + metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + try: + metadata = metadata_decoder.decode( + handshake_payload.agent_metadata_bytes + ) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + # This should not happen if hash matched + raise RuntimeError( + f"Failed to decode NixlAgentMetadata. Error: {e}" + ) from e + # Ensure engine id matches. if metadata.engine_id != expected_engine_id: raise RuntimeError( @@ -1297,19 +1440,24 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. 
- self.xfer_handshake_metadata = NixlAgentMetadata( + agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], device_id=self.device_id, num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, - attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, ) + # Wrap metadata in payload with hash for defensive decoding + encoder = msgspec.msgpack.Encoder() + self.xfer_handshake_metadata = NixlHandshakePayload( + compatibility_hash=self.compat_hash, + agent_metadata_bytes=encoder.encode(agent_metadata), + ) def register_local_xfer_handler( self, @@ -1524,8 +1672,6 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size - # TODO We may eventually want to skip enforcing the same attn backend. - assert nixl_agent_meta.attn_backend_name == self.backend_name tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(