From d378af83adc04272460d6f2671bcf3d91e68265b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 23 Sep 2025 19:42:27 +0000 Subject: [PATCH 1/3] init: nixl support Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 95 ++++++++++++------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1c7569515dec..55d87ea994b5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -87,7 +87,7 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int - block_len: int + block_lens: list[int] attn_backend_name: str kv_cache_layout: str @@ -772,6 +772,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) tensor_size_bytes = None + # Enable different block lengths for different layers when MLA is used. + self.block_len_per_layer = list[int]() + self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): cache_list = cache_or_caches if split_k_and_v else [ cache_or_caches @@ -789,10 +792,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): tensor_size_bytes = curr_tensor_size_bytes self.num_blocks = cache.shape[0] - assert tensor_size_bytes == curr_tensor_size_bytes, \ - "All kv cache tensors must have the same size" + assert cache.shape[0] == self.num_blocks, \ + "All kv cache tensors must have the same number of blocks" + + self.block_len_per_layer.append(curr_tensor_size_bytes // + self.num_blocks) + self.slot_size_per_layer.append(self.block_len_per_layer[-1] // + self.block_size) + + if not self.use_mla: + # Different kv cache shape is not supported by HeteroTP + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" caches_data.append( - (base_addr, tensor_size_bytes, self.tp_rank, "")) + (base_addr, curr_tensor_size_bytes, self.tp_rank, "")) + + logger.debug("Different block lengths collected: %s", + set(self.block_len_per_layer)) + assert len(self.block_len_per_layer) == len(seen_base_addresses) + assert self.num_blocks != 0 self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) @@ -805,16 +823,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("Done registering descs") self._registered_descs.append(descs) - assert tensor_size_bytes is not None - assert self.num_blocks != 0 - assert tensor_size_bytes % self.num_blocks == 0 - self.block_len = tensor_size_bytes // self.num_blocks - self.slot_size_bytes = self.block_len // self.block_size self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks if self._use_flashinfer: - assert self.slot_size_bytes % 2 == 0 - self.slot_size_bytes /= 2 + for i in range(len(self.slot_size_per_layer)): + assert self.slot_size_per_layer[i] % 2 == 0 + self.slot_size_per_layer[i] //= 2 # NOTE (NickLucche) When FlashInfer is used, memory is registered # with joint KV for each block. This minimizes the overhead in @@ -824,17 +838,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # of 'virtual' regions here and halve `block_len` below. 
self.num_regions *= 2 - kv_block_len = self.get_backend_aware_kv_block_len() # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in seen_base_addresses: + for i, base_addr in enumerate(seen_base_addresses): + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to # select agent_meta.num_blocks instead of self.num_blocks for # local descr, and that makes handling regular flow less clean. for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + block_offset = block_id * self.block_len_per_layer[i] addr = base_addr + block_offset # (addr, len, device id) blocks_data.append((addr, kv_block_len, self.tp_rank)) @@ -844,7 +858,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + block_offset = block_id * self.block_len_per_layer[i] addr = base_addr + block_offset # Register addresses for V cache (K registered first). v_addr = addr + kv_block_len @@ -884,7 +898,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, - block_len=self.block_len, + block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout) ready_event = threading.Event() @@ -909,7 +923,7 @@ def add_remote_agent(self, The latter, assuming D.world_size > P.world_size, requires that two or more local TP worker share the xfer from a single TP worker. - Here's an example: + Here's an example (non-MLA case): rank_offset p_remote_tp_rank (kv split no) @@ -965,14 +979,20 @@ def add_remote_agent(self, total_num_kv_heads = self.model_config.get_total_num_kv_heads() is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or is_kv_replicated: - # With MLA the only difference is in the number of blocks. - remote_block_size = nixl_agent_meta.block_len // ( - self.slot_size_bytes) - assert self.block_len == nixl_agent_meta.block_len + # With replicated KV cache, only the number of blocks can differ. + assert self.block_len_per_layer == nixl_agent_meta.block_lens, \ + "KV cache sizes must match between P and D when replicated" + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0]) else: - remote_block_size = nixl_agent_meta.block_len // ( - self.slot_size_bytes * tp_ratio) + # When MLA is not used, this is a list of the same block length + for block_len in nixl_agent_meta.block_lens: + assert block_len == remote_block_len, \ + "All remote layers must have the same block size" + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio) if self._use_flashinfer: # With flashinfer, KV are sent in the same message. 
remote_block_size //= 2 @@ -983,14 +1003,14 @@ def add_remote_agent(self, raise ValueError( "Heterogeneous TP is not supported on XPU") - assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( + assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ) assert self.block_size == remote_block_size, ( - "Remote P worker with different block size is not supported " - f"{self.block_size=} {remote_block_size=}") + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}") # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -1005,13 +1025,16 @@ def add_remote_agent(self, # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - kv_block_len = self.get_backend_aware_kv_block_len() - rank_offset = self.tp_rank % tp_ratio * kv_block_len \ - if not (self.use_mla or is_kv_replicated) else 0 + + assert len(nixl_agent_meta.kv_caches_base_addr) == len( + self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. - for base_addr in nixl_agent_meta.kv_caches_base_addr: + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + rank_offset = self.tp_rank % tp_ratio * kv_block_len \ + if not (self.use_mla or is_kv_replicated) else 0 for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_len + block_offset = block_id * nixl_agent_meta.block_lens[i] # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to # self.block_len == remote_block_len//tp_ratio bytes. @@ -1022,9 +1045,9 @@ def add_remote_agent(self, if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_len + block_offset = block_id * nixl_agent_meta.block_lens[i] addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_len // 2 + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) logger.debug( @@ -1351,7 +1374,7 @@ def _get_block_descs_ids(self, descs_ids = region_ids * num_blocks + block_ids return descs_ids.flatten() - def get_backend_aware_kv_block_len(self): + def get_backend_aware_kv_block_len(self, layer_idx: int): """ Get the block length for one K/V element (K and V have the same size). @@ -1362,9 +1385,9 @@ def get_backend_aware_kv_block_len(self): """ if self._use_flashinfer: # For indexing only half (either just the K or V part). 
- block_len = self.block_len // 2 + block_len = self.block_len_per_layer[layer_idx] // 2 else: - block_len = self.block_len + block_len = self.block_len_per_layer[layer_idx] return block_len def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: From 5d90305e700d6680270f14a723ef6d0fcd887a66 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 25 Sep 2025 00:28:17 +0000 Subject: [PATCH 2/3] fix unify kv cache spec Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 47a41322c423..55cc7ea5a265 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): kv_cache_spec: The kv cache spec of each attention layer in the model """ - if is_kv_cache_spec_uniform(kv_cache_spec): + if is_kv_cache_spec_uniform( + kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( + kv_cache_spec): return logger.warning( @@ -1141,7 +1143,8 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not is_kv_cache_spec_uniform(kv_cache_spec): + if not (is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): raise ValueError("Hybrid KV cache manager is disabled but failed to " "convert the KV cache specs to one unified type.") From 7f9a3533d4e5db7d20e7664ce88f2b4d06ce016c Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 30 Sep 2025 10:37:56 +0000 Subject: [PATCH 3/3] update tests to use *per_layer values Signed-off-by: NickLucche --- tests/v1/kv_connector/unit/test_nixl_connector.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6b4bd29f18a5..578bf02eb519 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -255,8 +255,9 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by # gpu_model_runner. Here we just hardcode some dummy values. - self.slot_size_bytes = 4096 - self.block_len = self.slot_size_bytes * self.block_size + slot_size_bytes = 4096 + self.slot_size_per_layer = [slot_size_bytes] + self.block_len_per_layer = [slot_size_bytes * self.block_size] self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks @@ -268,7 +269,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], num_blocks=1, - block_len=self.block_len, + block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. 
@@ -485,8 +486,8 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): worker = connector.connector_worker # Minimal local registration params used by add_remote_agent - worker.slot_size_bytes = 4096 - worker.block_len = worker.slot_size_bytes * worker.block_size + worker.slot_size_per_layer = [4096] + worker.block_len_per_layer = [4096 * worker.block_size] worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks @@ -498,7 +499,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], num_blocks=1, - block_len=worker.block_len, + block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, )
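
Note on the per-layer bookkeeping introduced in nixl_connector.py: register_kv_caches now records one block length and one slot size per layer instead of a single block_len/slot_size_bytes. The standalone sketch below is illustrative only (it is not part of the patch; the tensor shapes, the use_flashinfer flag, and the helper name are assumptions, and it skips the split K/V base-address handling) but derives the same quantities from a dict of KV cache tensors:

    import torch

    def per_layer_block_lens(kv_caches: dict[str, torch.Tensor],
                             block_size: int,
                             use_flashinfer: bool = False):
        """Derive per-layer block/slot byte sizes the way register_kv_caches
        does: block_len is the byte footprint of one block of a layer's cache,
        slot_size is the footprint of one token slot (halved for FlashInfer,
        where K and V share a single joint region per block)."""
        block_len_per_layer: list[int] = []
        slot_size_per_layer: list[int] = []
        num_blocks = None
        for cache in kv_caches.values():
            size_bytes = cache.numel() * cache.element_size()
            if num_blocks is None:
                num_blocks = cache.shape[0]
            assert cache.shape[0] == num_blocks, \
                "All kv cache tensors must have the same number of blocks"
            block_len_per_layer.append(size_bytes // num_blocks)
            slot_size_per_layer.append(block_len_per_layer[-1] // block_size)
        if use_flashinfer:
            # Joint KV per block: one slot spans both the K and V halves.
            slot_size_per_layer = [s // 2 for s in slot_size_per_layer]
        return block_len_per_layer, slot_size_per_layer, num_blocks

    # Two layers with different per-block footprints (the case the patch
    # enables, e.g. with MLA), 4 blocks of 16 slots, fp16 entries.
    caches = {
        "layers.0": torch.empty(4, 16, 8, 64, dtype=torch.float16),
        "layers.1": torch.empty(4, 16, 4, 64, dtype=torch.float16),
    }
    print(per_layer_block_lens(caches, block_size=16))
    # -> ([16384, 8192], [1024, 512], 4)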
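
Note on the heterogeneous-TP path touched in add_remote_agent: with tp_ratio = D_tp // P_tp, each decode rank registers only its slice of KV heads inside every remote block, at rank_offset = (tp_rank % tp_ratio) * kv_block_len, where kv_block_len now comes from get_backend_aware_kv_block_len(layer_idx=i). The sketch below shows only the address arithmetic for the non-MLA, HND-layout case; the helper name and the byte sizes are made up for illustration:

    def remote_block_descs(base_addr: int, num_blocks: int,
                           remote_block_len: int, local_kv_block_len: int,
                           tp_rank: int, tp_ratio: int):
        """(addr, length) pairs one decode rank would register for a single
        remote layer: a contiguous chunk of heads inside each remote block."""
        rank_offset = (tp_rank % tp_ratio) * local_kv_block_len
        for block_id in range(num_blocks):
            addr = base_addr + block_id * remote_block_len + rank_offset
            yield addr, local_kv_block_len

    # Prefill TP=1, decode TP=2 -> tp_ratio = 2. A remote block holding all
    # KV heads is 32 KiB here, so each decode rank maps its own 16 KiB half.
    for rank in (0, 1):
        print(rank, list(remote_block_descs(0x1000, 2, 32 * 1024, 16 * 1024,
                                            rank, 2)))
    # rank 0 -> [(4096, 16384), (36864, 16384)]
    # rank 1 -> [(20480, 16384), (53248, 16384)]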