Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 213 additions & 11 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
import time
import uuid
from collections import defaultdict
from unittest.mock import patch
from typing import Any
from unittest.mock import MagicMock, patch

import msgspec
import pytest
import ray
import torch

from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats,
Expand All @@ -29,7 +32,9 @@
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlHandshakePayload,
NixlKVConnectorStats,
compute_nixl_compatibility_hash,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown,
Expand Down Expand Up @@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init):
}
prefill_connector.register_kv_caches(kv_caches)

# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()

# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)

# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
scheduler_connector.set_xfer_handshake_metadata({0: metadata})

# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
Expand Down Expand Up @@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init):
)

received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]

# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
Expand Down Expand Up @@ -403,7 +414,6 @@ def _nixl_handshake(
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
Expand Down Expand Up @@ -651,7 +661,6 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
)
Expand Down Expand Up @@ -706,7 +715,6 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
block_size=worker.block_size,
)
Expand Down Expand Up @@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance

# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"

# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False

Expand Down Expand Up @@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
# ensure request appears in get_finished
_, done_recving = connector.get_finished(finished_req_ids=set())
assert request_id in done_recving


@pytest.mark.parametrize(
    "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
    [
        ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
        ("nixl_connector_version", {}, {"connector_version": 37}, True, True),
        ("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
        ("dtype", {"dtype": "bfloat16"}, {}, True, True),
        ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
        ("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True),
        (
            "num_hidden_layers",
            {"hf_overrides": {"num_hidden_layers": 24}},
            {},
            True,
            True,
        ),
        ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
        ("block_size", {"block_size": 8}, {}, False, True),
        ("matching_config", {}, {}, False, True),
        ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_compatibility_hash_validation(
    dist_init,
    mismatch_type,
    config_overrides,
    version_override,
    should_fail,
    enforce_handshake_compat,
):
    """Verify NIXL compatibility-hash checking during the handshake.

    Parameters:
        mismatch_type: human-readable label for the scenario under test
        config_overrides: config fields to override on the remote instance
        version_override: version dict to fake, e.g. {"vllm_version": "0.6.1"}
        should_fail: whether the handshake is expected to raise
        enforce_handshake_compat: whether compatibility checking is enforced
    """
    # Local (decode-side) connector, with enforcement toggled per scenario.
    local_cfg = create_vllm_config(
        model="facebook/opt-125m",
        block_size=16,
        kv_connector_extra_config={
            "enforce_handshake_compat": enforce_handshake_compat
        },
    )
    connector = NixlConnector(local_cfg, KVConnectorRole.WORKER)
    worker = connector.connector_worker

    # Remote (prefill-side) config: baseline plus the per-case overrides.
    remote_cfg_kwargs: dict[str, Any] = {
        "model": "facebook/opt-125m",
        "block_size": 16,
        **config_overrides,
    }
    remote_cfg = create_vllm_config(**remote_cfg_kwargs)

    # Compute the remote hash under a (possibly) faked version so that
    # version mismatches are reflected in the hash.
    if "vllm_version" in version_override:
        version_ctx: Any = patch(
            "vllm.__version__", version_override["vllm_version"]
        )
    elif "connector_version" in version_override:
        version_ctx = patch.object(
            nixl_connector,
            "NIXL_CONNECTOR_VERSION",
            version_override["connector_version"],
        )
    else:
        version_ctx = contextlib.nullcontext()
    with version_ctx:
        remote_hash = compute_nixl_compatibility_hash(
            remote_cfg, worker.backend_name
        )

    # Build the payload the prefill side would send over the side channel.
    remote_block_size = config_overrides.get("block_size", 16)
    agent_md = NixlAgentMetadata(
        engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        agent_metadata=FakeNixlWrapper.AGENT_METADATA,
        kv_caches_base_addr=[0],
        device_id=0,
        num_blocks=1,
        block_lens=[4096 * remote_block_size],  # slot_size * block_size
        kv_cache_layout="HND",
        block_size=remote_block_size,
    )
    payload = NixlHandshakePayload(
        compatibility_hash=remote_hash,
        agent_metadata_bytes=msgspec.msgpack.encode(agent_md),
    )

    # Fake the ZMQ round-trip: recv() hands back our encoded payload.
    fake_socket = MagicMock()
    fake_socket.recv.return_value = msgspec.msgpack.encode(payload)

    def run_handshake():
        # Single place that drives the handshake under test.
        return worker._nixl_handshake(
            host="localhost",
            port=1234,
            remote_tp_size=1,
            expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        )

    # Stub out real NIXL registration and the ZMQ context.
    with (
        patch.object(worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as fake_zmq_ctx,
    ):
        fake_zmq_ctx.return_value.__enter__.return_value = fake_socket

        if should_fail:
            with pytest.raises(RuntimeError, match="compatibility hash mismatch"):
                run_handshake()
        else:
            agents = run_handshake()
            # Handshake succeeded: exactly one remote agent (TP size 1).
            assert isinstance(agents, dict)
            assert len(agents) == 1


@pytest.mark.parametrize(
    "error_scenario",
    [
        "handshake_decode_error",
        "handshake_validation_error",
        "metadata_decode_error",
        "metadata_validation_error",
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
    """Ensure msgspec decode failures surface as errors during the handshake.

    Exercises DecodeError and ValidationError for both decoders used on the
    receiving side:
    - the NixlHandshakePayload decoder (outer envelope)
    - the NixlAgentMetadata decoder (inner payload)
    """
    cfg = create_vllm_config(
        model="facebook/opt-125m",
        block_size=16,
    )
    connector = NixlConnector(cfg, KVConnectorRole.WORKER)
    worker = connector.connector_worker

    # Build the corrupted wire bytes for each scenario.
    if error_scenario == "handshake_decode_error":
        # Outer envelope is not msgpack at all.
        wire_bytes = b"this is not valid msgpack data"
    elif error_scenario == "handshake_validation_error":
        # Outer envelope decodes but does not match NixlHandshakePayload.
        wire_bytes = msgspec.msgpack.encode({"wrong_field": "value"})
    elif error_scenario == "metadata_decode_error":
        # Valid envelope wrapping inner bytes that are not msgpack.
        wire_bytes = msgspec.msgpack.encode(
            NixlHandshakePayload(
                compatibility_hash=worker.compat_hash,
                agent_metadata_bytes=b"invalid msgpack for metadata",
            )
        )
    elif error_scenario == "metadata_validation_error":
        # Valid envelope wrapping msgpack that is not NixlAgentMetadata.
        wire_bytes = msgspec.msgpack.encode(
            NixlHandshakePayload(
                compatibility_hash=worker.compat_hash,
                agent_metadata_bytes=msgspec.msgpack.encode({"missing": "fields"}),
            )
        )
    else:
        raise AssertionError(f"{error_scenario} not a valid scenario")

    # Fake the ZMQ round-trip: recv() hands back the corrupted bytes.
    fake_socket = MagicMock()
    fake_socket.recv.return_value = wire_bytes
    with (
        patch.object(worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as fake_zmq_ctx,
    ):
        fake_zmq_ctx.return_value.__enter__.return_value = fake_socket

        # Every corruption flavor must be rewrapped as a RuntimeError.
        with pytest.raises(RuntimeError):
            worker._nixl_handshake(
                host="localhost",
                port=1234,
                remote_tp_size=1,
                expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            )
10 changes: 8 additions & 2 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def create_vllm_config(
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
kv_connector_extra_config: dict[str, Any] | None = None,
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
Expand All @@ -101,21 +105,23 @@ def create_vllm_config(
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
dtype=dtype,
seed=42,
hf_overrides=hf_overrides or {},
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
cache_dtype=cache_dtype,
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
return VllmConfig(
scheduler_config=scheduler_config,
Expand Down
Loading