13 changes: 7 additions & 6 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -255,8 +255,9 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self.slot_size_bytes = 4096
self.block_len = self.slot_size_bytes * self.block_size
slot_size_bytes = 4096
self.slot_size_per_layer = [slot_size_bytes]
self.block_len_per_layer = [slot_size_bytes * self.block_size]
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks

@@ -268,7 +269,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
@@ -485,8 +486,8 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
worker = connector.connector_worker

# Minimal local registration params used by add_remote_agent
worker.slot_size_bytes = 4096
worker.block_len = worker.slot_size_bytes * worker.block_size
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

@@ -498,7 +499,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=worker.block_len,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
)
95 changes: 59 additions & 36 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -87,7 +87,7 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_len: int
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
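
To make the metadata change concrete, here is a hedged sketch of a handshake payload after this PR, using a plain dict as a stand-in (only the fields visible in this hunk; values are hypothetical). `block_lens` carries one entry per registered KV region instead of the previous scalar `block_len`, which is what allows per-layer block sizes under MLA.

```python
# Hypothetical stand-in for the NixlAgentMetadata sent during the handshake.
remote_meta = {
    "agent_metadata": b"<opaque NIXL agent blob>",
    "kv_caches_base_addr": [0x7F00_0000_0000, 0x7F00_0100_0000],
    "num_blocks": 1024,
    "block_lens": [32768, 32768],      # one entry per region; was `block_len: int`
    "attn_backend_name": "FLASH_ATTN",
    "kv_cache_layout": "HND",
}
```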

@@ -772,6 +772,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
split_k_and_v = not (self.use_mla or self._use_pallas
or self._use_flashinfer)
tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
@@ -789,10 +792,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]

assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
assert cache.shape[0] == self.num_blocks, \
"All kv cache tensors must have the same number of blocks"

self.block_len_per_layer.append(curr_tensor_size_bytes //
self.num_blocks)
self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
self.block_size)

if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
(base_addr, curr_tensor_size_bytes, self.tp_rank, ""))

logger.debug("Different block lengths collected: %s",
set(self.block_len_per_layer))
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0

self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
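
As a self-contained sketch of the bookkeeping introduced above (hypothetical tensor shapes, standalone PyTorch rather than the connector itself): each layer contributes one entry to `block_len_per_layer` and `slot_size_per_layer`, derived from its own cache tensor, so layers with different KV footprints no longer have to share a single `block_len`.

```python
import torch

block_size = 16
# Hypothetical per-layer KV caches: a regular attention layer and a smaller
# MLA-style layer, so the derived block lengths differ.
kv_caches = {
    "layers.0": torch.zeros(128, block_size, 8, 128, dtype=torch.bfloat16),
    "layers.1": torch.zeros(128, block_size, 1, 576, dtype=torch.bfloat16),
}

block_len_per_layer: list[int] = []
slot_size_per_layer: list[int] = []
for cache in kv_caches.values():
    num_blocks = cache.shape[0]
    tensor_size_bytes = cache.numel() * cache.element_size()
    # Bytes occupied by one block of this layer's cache ...
    block_len_per_layer.append(tensor_size_bytes // num_blocks)
    # ... and by one token slot within that block.
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)

print(block_len_per_layer)  # [32768, 18432]
print(slot_size_per_layer)  # [2048, 1152]
```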
@@ -805,16 +823,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
logger.debug("Done registering descs")
self._registered_descs.append(descs)

assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2

# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
@@ -824,17 +838,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2

kv_block_len = self.get_backend_aware_kv_block_len()
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in seen_base_addresses:
for i, base_addr in enumerate(seen_base_addresses):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
@@ -844,7 +858,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
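
A condensed sketch of the local descriptor layout this loop builds (hypothetical base addresses, two layers, FlashInfer enabled): every layer now uses its own block length for the address stride, and the V half of a joint K/V block is registered one K/V element past the K half.

```python
# Hypothetical local registration: one (addr, len, device_id) triple per block
# and per layer for K, plus one per block and per layer for V when FlashInfer
# stores K and V jointly in the same region.
tp_rank = 0
num_blocks = 2
base_addresses = [0x10_0000, 0x20_0000]   # hypothetical, one per layer
block_len_per_layer = [32768, 18432]
use_flashinfer = True

blocks_data = []
for i, base_addr in enumerate(base_addresses):
    kv_block_len = (block_len_per_layer[i] // 2
                    if use_flashinfer else block_len_per_layer[i])
    for block_id in range(num_blocks):
        addr = base_addr + block_id * block_len_per_layer[i]
        blocks_data.append((addr, kv_block_len, tp_rank))                     # K part
    if use_flashinfer:
        for block_id in range(num_blocks):
            addr = base_addr + block_id * block_len_per_layer[i]
            blocks_data.append((addr + kv_block_len, kv_block_len, tp_rank))  # V part

assert len(blocks_data) == len(base_addresses) * num_blocks * 2
```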
@@ -884,7 +898,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout)
ready_event = threading.Event()
@@ -909,7 +923,7 @@ def add_remote_agent(self,
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP workers share the xfer from a single TP worker.

Here's an example:
Here's an example (non-MLA case):

rank_offset p_remote_tp_rank
(kv split no)
@@ -965,14 +979,20 @@ def add_remote_agent(self,
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1

remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or is_kv_replicated:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
"KV cache sizes must match between P and D when replicated"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0])
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, \
"All remote layers must have the same block size"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio)
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
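
A worked example of the block-size arithmetic above (hypothetical numbers, non-MLA, homogeneous layers, no FlashInfer): with a decode side at TP=2 reading from a prefill side at TP=1, `tp_ratio = 2`, so each remote block packs twice the KV heads of a local one.

```python
tp_ratio = 2
block_size = 16                            # local page size in tokens
slot_size_per_layer = [2048, 2048]         # local bytes per token slot, per layer
block_len_per_layer = [s * block_size for s in slot_size_per_layer]

block_lens = [65536, 65536]                # remote values from the handshake
remote_block_len = block_lens[0]
assert all(bl == remote_block_len for bl in block_lens), \
    "All remote layers must have the same block size"

remote_block_size = remote_block_len // (slot_size_per_layer[0] * tp_ratio)
assert remote_block_size == block_size                         # page sizes agree
assert remote_block_len == block_len_per_layer[0] * tp_ratio   # heads scale with tp_ratio
```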
@@ -983,14 +1003,14 @@ def add_remote_agent(self,
raise ValueError(
"Heterogeneous TP is not supported on XPU")

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)

assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}")

# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@@ -1005,13 +1025,16 @@ def add_remote_agent(self,
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
kv_block_len = self.get_backend_aware_kv_block_len()
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0

assert len(nixl_agent_meta.kv_caches_base_addr) == len(
self.block_len_per_layer)
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
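
And a small numeric sketch of the remote address selection (continuing the hypothetical TP=2-over-TP=1 case, non-MLA, no FlashInfer): each local rank registers only its head chunk of every remote block, offset by its position within the `tp_ratio` group.

```python
tp_ratio = 2
remote_block_len = 65536                     # nixl_agent_meta.block_lens[i]
kv_block_len = remote_block_len // tp_ratio  # == local block_len_per_layer[i]
base_addr = 0x7F00_0000_0000                 # hypothetical remote base address
num_remote_blocks = 4

for tp_rank in range(tp_ratio):
    rank_offset = tp_rank % tp_ratio * kv_block_len
    addrs = [base_addr + block_id * remote_block_len + rank_offset
             for block_id in range(num_remote_blocks)]
    print(tp_rank, [hex(a) for a in addrs])
# Rank 0 reads the first 32 KiB of every remote block, rank 1 the second 32 KiB.
```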
@@ -1022,9 +1045,9 @@ def add_remote_agent(self,
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

logger.debug(
@@ -1351,7 +1374,7 @@ def _get_block_descs_ids(self,
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()

def get_backend_aware_kv_block_len(self):
def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).

@@ -1362,9 +1385,9 @@ def get_backend_aware_kv_block_len(self):
"""
if self._use_flashinfer:
# For indexing only half (either just the K or V part).
block_len = self.block_len // 2
block_len = self.block_len_per_layer[layer_idx] // 2
else:
block_len = self.block_len
block_len = self.block_len_per_layer[layer_idx]
return block_len

def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
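
A standalone sketch of the per-layer selection this signature change enables (hypothetical values): with FlashInfer, K and V live in one joint region, so a single K (or V) element is half of that layer's block length.

```python
block_len_per_layer = [32768, 18432]   # hypothetical per-layer block lengths
use_flashinfer = True

def kv_block_len(layer_idx: int) -> int:
    full = block_len_per_layer[layer_idx]
    return full // 2 if use_flashinfer else full

assert [kv_block_len(0), kv_block_len(1)] == [16384, 9216]
```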
7 changes: 5 additions & 2 deletions vllm/v1/core/kv_cache_utils.py
@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model
"""

if is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(
kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type(
kv_cache_spec):
return

logger.warning(
@@ -1141,7 +1143,8 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
attention_chunk_size=spec.attention_chunk_size,
)

if not is_kv_cache_spec_uniform(kv_cache_spec):
if not (is_kv_cache_spec_uniform(kv_cache_spec)
or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
