From 9552bcea1e022275dd202d6a0031f7117734a535 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Tue, 25 Nov 2025 17:19:04 -0500
Subject: [PATCH 1/3] eliminate cpu access in FA metadata building

Signed-off-by: Matthew Bonanni
---
 vllm/v1/attention/backends/flash_attn.py | 25 ++++++++++++------------
 vllm/v1/attention/backends/utils.py      |  6 ++++--
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a9a4af5ac118..dd5879d2a96e 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -328,7 +328,6 @@ def build(
         max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
         causal = common_attn_metadata.causal
@@ -401,20 +400,21 @@ def schedule(
             prefix_scheduler_metadata = None

         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = (
-                common_attn_metadata.query_start_loc_cpu[1:]
-                - common_attn_metadata.query_start_loc_cpu[:-1]
-            )
-            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
+            dcp_context_kv_lens = seq_lens - query_kv_lens

-            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
-                dcp_context_kv_lens_cpu,
+            dcp_context_kv_lens = get_dcp_local_seq_lens(
+                dcp_context_kv_lens,
                 self.dcp_world_size,
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
-            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+            # Use upper bound to avoid GPU->CPU sync
+            # Context length = seq_len - query_len, so max context ≤ max_seq_len
+            # This is conservative but eliminates sync point while maintaining
+            # correctness. The actual sequence lengths are passed via
+            # dcp_context_kv_lens (GPU tensor).
+            max_dcp_context_kv_len = max_seq_len

             scheduler_metadata = schedule(
                 batch_size=num_reqs,
@@ -431,9 +431,8 @@
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
-                self.device, non_blocking=True
-            )
+            # Use GPU tensor directly - no CPU sync needed
+            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index cebfe8a3ff04..ab757d2a5749 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -1092,12 +1092,14 @@ def get_dcp_local_seq_lens(
     num_requests = seq_lens.size(0)
     if dcp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_size, dtype=torch.int32)
+            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
     else:
-        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+        rank_offsets = torch.tensor(
+            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
+        )
     seq_lens_tiled = (
         seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
     )

From 3faf1e075598dfba019d2cd008d9170bb1c114d0 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Tue, 25 Nov 2025 17:34:33 -0500
Subject: [PATCH 2/3] tighter upper bound

Signed-off-by: Matthew Bonanni
---
 vllm/v1/attention/backends/flash_attn.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index dd5879d2a96e..f66cda64f3f8 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -409,12 +409,11 @@ def schedule(
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            # Use upper bound to avoid GPU->CPU sync
-            # Context length = seq_len - query_len, so max context ≤ max_seq_len
-            # This is conservative but eliminates sync point while maintaining
-            # correctness. The actual sequence lengths are passed via
-            # dcp_context_kv_lens (GPU tensor).
-            max_dcp_context_kv_len = max_seq_len
+            # After DCP distribution, each rank gets at most ceil(L / N) elements
+            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+            max_dcp_context_kv_len = (
+                max_seq_len + self.dcp_world_size - 1
+            ) // self.dcp_world_size

             scheduler_metadata = schedule(
                 batch_size=num_reqs,

From 78829f4e798918823ad5ad2f146db358c7b7a975 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Tue, 25 Nov 2025 17:39:03 -0500
Subject: [PATCH 3/3] fix upper bound logic

Signed-off-by: Matthew Bonanni
---
 vllm/v1/attention/backends/flash_attn.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f66cda64f3f8..353a738a2693 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -409,11 +409,14 @@ def schedule(
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            # After DCP distribution, each rank gets at most ceil(L / N) elements
+            # After DCP distribution, the maximum number of tokens for any rank is
+            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
+            # and I is cp_kv_cache_interleave_size.
             # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
             max_dcp_context_kv_len = (
-                max_seq_len + self.dcp_world_size - 1
-            ) // self.dcp_world_size
+                (max_seq_len + num_partitions - 1) // num_partitions
+            ) * self.cp_kv_cache_interleave_size

             scheduler_metadata = schedule(
                 batch_size=num_reqs,
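
Reviewer note (not part of the patches): a standalone sketch of why the PATCH 3/3 bound
is safe, assuming context KV tokens are dealt to DCP ranks round-robin in chunks of
cp_kv_cache_interleave_size tokens, the layout that get_dcp_local_seq_lens models. The
helper names below are illustrative only, not vLLM APIs.

def upper_bound(max_seq_len: int, world_size: int, interleave: int) -> int:
    # ceil(L / (N * I)) * I, exactly as computed in PATCH 3/3.
    num_partitions = world_size * interleave
    return ((max_seq_len + num_partitions - 1) // num_partitions) * interleave

def busiest_rank_tokens(seq_len: int, world_size: int, interleave: int) -> int:
    # Brute force: deal tokens in interleave-sized chunks, round-robin over ranks,
    # and report the largest per-rank count.
    counts = [0] * world_size
    for token in range(seq_len):
        counts[(token // interleave) % world_size] += 1
    return max(counts)

if __name__ == "__main__":
    for seq_len in (1, 7, 63, 64, 65, 1000, 4096):
        for world_size in (1, 2, 4, 8):
            for interleave in (1, 16, 64):
                assert busiest_rank_tokens(seq_len, world_size, interleave) <= upper_bound(
                    seq_len, world_size, interleave
                )
    print("ceil(L / (N * I)) * I bounds the busiest rank in all sampled cases")

The interleave factor is why the PATCH 2/3 bound alone was not enough: with L=65, N=4,
I=64, the busiest rank holds 64 tokens while ceil(L / N) = 17 would under-allocate,
whereas the PATCH 3/3 bound gives ceil(65 / 256) * 64 = 64.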