27 changes: 14 additions & 13 deletions vllm/v1/attention/backends/flash_attn.py
@@ -328,7 +328,6 @@ def build(
         max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
         causal = common_attn_metadata.causal
@@ -401,20 +400,23 @@ def schedule(
         prefix_scheduler_metadata = None
 
         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = (
-                common_attn_metadata.query_start_loc_cpu[1:]
-                - common_attn_metadata.query_start_loc_cpu[:-1]
-            )
-            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
+            dcp_context_kv_lens = seq_lens - query_kv_lens
 
-            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
-                dcp_context_kv_lens_cpu,
+            dcp_context_kv_lens = get_dcp_local_seq_lens(
+                dcp_context_kv_lens,
                 self.dcp_world_size,
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
-            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+            # After DCP distribution, the maximum number of tokens for any rank is
+            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
+            # and I is cp_kv_cache_interleave_size.
+            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
+            max_dcp_context_kv_len = (
+                (max_seq_len + num_partitions - 1) // num_partitions
+            ) * self.cp_kv_cache_interleave_size
 
             scheduler_metadata = schedule(
                 batch_size=num_reqs,
@@ -431,9 +433,8 @@ def schedule(
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
-                self.device, non_blocking=True
-            )
+            # Use GPU tensor directly - no CPU sync needed
+            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
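The new `max_dcp_context_kv_len` replaces a data-dependent `.max().item()` (which stalls the CPU on the GPU) with a static bound computed from `max_seq_len` alone. A small self-check of the arithmetic, assuming the block round-robin layout implied by `cp_kv_cache_interleave_size`; the helper below is illustrative only, not vLLM code:

```python
def per_rank_kv_len(seq_len: int, dcp_size: int, interleave: int, rank: int) -> int:
    """Tokens held by `rank` for one sequence under block round-robin interleaving.

    Assumption (for illustration): token t belongs to block t // interleave,
    and block b is assigned to rank b % dcp_size.
    """
    return sum(1 for t in range(seq_len) if (t // interleave) % dcp_size == rank)


N, I = 4, 16  # dcp_world_size, cp_kv_cache_interleave_size
for L in (1, 63, 64, 65, 1000):  # a few context lengths, including edge cases
    static_bound = ((L + N * I - 1) // (N * I)) * I  # ceil(L / (N * I)) * I
    actual_max = max(per_rank_kv_len(L, N, I, r) for r in range(N))
    assert actual_max <= static_bound
    print(f"L={L:4d}  per-rank max={actual_max:3d}  static bound={static_bound:3d}")
```

Under this layout the bound is exact whenever the longest context is a multiple of N * I, which is the over-allocation trade-off the added comment refers to.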
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/utils.py
@@ -1092,12 +1092,14 @@ def get_dcp_local_seq_lens(
     num_requests = seq_lens.size(0)
     if dcp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_size, dtype=torch.int32)
+            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
     else:
-        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+        rank_offsets = torch.tensor(
+            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
+        )
     seq_lens_tiled = (
         seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
     )
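Both fixes keep `rank_offsets` on `seq_lens.device`; the old `torch.arange(...)` and `torch.Tensor(...)` calls allocated on the default (CPU) device, so the elementwise math that follows would raise a device-mismatch error for a CUDA `seq_lens`, which is why flash_attn.py previously had to feed the helper `seq_lens_cpu`. A minimal sketch of the pattern, assuming a plain even split across ranks; `local_lens` is a hypothetical stand-in, not the real helper, which also honors `cp_kv_cache_interleave_size` and an optional `dcp_rank`:

```python
import torch


def local_lens(seq_lens: torch.Tensor, dcp_size: int) -> torch.Tensor:
    """Per-rank KV lengths for an even split (illustration only)."""
    num_requests = seq_lens.size(0)
    # Creating the index tensor on seq_lens.device is the point of the fix:
    # a CPU-resident rank_offsets would raise a device-mismatch error below
    # whenever seq_lens lives on CUDA.
    rank_offsets = (
        torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
        .unsqueeze(0)
        .repeat(num_requests, 1)
    )  # shape [num_requests, dcp_size]
    seq_lens_tiled = seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, dcp_size)
    # Even split: rank r gets ceil((len - r) / dcp_size) tokens.
    return (seq_lens_tiled + dcp_size - 1 - rank_offsets) // dcp_size


device = "cuda" if torch.cuda.is_available() else "cpu"
print(local_lens(torch.tensor([10, 7, 3], device=device), dcp_size=4))
# tensor([[3, 3, 2, 2],
#         [2, 2, 2, 1],
#         [1, 1, 1, 0]], ...)
```

With everything created on `seq_lens.device`, the flash_attn.py caller can pass its GPU `seq_lens` directly, which is what allows it to drop the `seq_lens_cpu` copy above.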