[Feature][Attention][PCP] Support PCP (Prefill Context Parallel) with MLA #28988
base: main
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request adds support for Prefill Context Parallelism (PCP) to the Multi-head Latent Attention (MLA) backend. The changes are extensive, involving refactoring of context parallelism logic to be more generic, adding new metadata and utility functions for PCP, and implementing the PCP attention logic based on the Dual-Chunk-Swap strategy.
My review has identified a critical issue in the attention correction logic where DCP and PCP corrections are applied in the wrong order, which will lead to incorrect results. I have also pointed out a significant performance issue related to nested communication calls that should be optimized. Overall, the PR is a good step towards enabling PCP, but these critical issues need to be addressed.
      cur_allgather_kvcache.copy_(
-         get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
+         get_pcp_group().all_gather(
+             get_dcp_group().all_gather(local_gathered_kvcache, dim=0),
+             dim=0,
+         )
      )
The nested all_gather calls, first over the DCP group and then over the PCP group, are inefficient as they introduce extra communication overhead and synchronization points. This should be optimized into a single all_gather operation.
To achieve this, a new communication group that combines the ranks from both DCP and PCP should be created during initialization. Then, a single all_gather can be performed over this combined "context parallel" (CP) group. This will be more performant. The TODO comment already acknowledges this, and this comment serves to emphasize its importance for performance.
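As a rough illustration of this suggestion, a minimal sketch assuming a combined CP group coordinator were created during distributed init; get_cp_group() is a hypothetical helper, not existing vLLM API:

# Hypothetical sketch only: assumes a single GroupCoordinator spanning both the
# DCP and PCP ranks is built at init time and exposed via an assumed
# get_cp_group() helper, so one all_gather replaces the nested pair above.
cur_allgather_kvcache.copy_(
    get_cp_group().all_gather(local_gathered_kvcache, dim=0)
)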
💡 Codex Review
https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L212-L215
Importing undefined get_pcp_group
Lines 212‑215 import get_pcp_group from vllm.distributed.parallel_state, but that module still only exposes get_dcp_group (the commit merely introduced a _CP variable without any getter). Importing common.py will therefore immediately raise ImportError: cannot import name 'get_pcp_group', so none of the new PCP code paths can even be instantiated.
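For reference, a minimal sketch of the getter this import expects, mirroring the existing get_dcp_group pattern (the _PCP variable name and assertion message are assumptions):

# Sketch for vllm/distributed/parallel_state.py, where GroupCoordinator is
# defined; the module-level variable name is an assumption.
_PCP: GroupCoordinator | None = None


def get_pcp_group() -> GroupCoordinator:
    assert _PCP is not None, "prefill context parallel group is not initialized"
    return _PCP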
https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/flashattn_mla.py#L79-L86
FlashAttn builder now passes nonexistent kwarg
The call to super().__init__(…, supports_cp_with_varlen=True) in FlashAttnMLAMetadataBuilder.__init__ (lines 79‑86) will raise TypeError: __init__() got an unexpected keyword argument 'supports_cp_with_varlen' because MLACommonMetadataBuilder.__init__ still only accepts supports_dcp_with_varlen. This prevents the FlashAttn MLA backend from constructing at all.
https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L572-L574
Referencing cp_kv_cache_interleave_size attribute that does not exist
Lines 572‑574 now read self.cp_local_block_size = parallel_config.cp_kv_cache_interleave_size, but ParallelConfig (vllm/config/parallel.py) defines only dcp_kv_cache_interleave_size. As soon as MLACommonMetadataBuilder is constructed this access raises AttributeError: 'ParallelConfig' object has no attribute 'cp_kv_cache_interleave_size', so the MLA backend cannot even initialize.
https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/utils.py#L1118-L1124
New utils annotation causes NameError on import
The new helper pcp_kv_allgather_and_restore (lines 1118‑1124) annotates pcp_group: GroupCoordinator, but GroupCoordinator is only imported inside the TYPE_CHECKING block and there is no from __future__ import annotations. When Python evaluates these annotations at import time it looks up GroupCoordinator, fails to find the name, and raises NameError, breaking vllm.v1.attention.backends.utils for every runtime import.
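A minimal sketch of the usual remedies for this pattern (an assumed fix, not the PR's actual change):

# Either defer annotation evaluation for the whole module...
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.distributed.parallel_state import GroupCoordinator


# ...or quote the forward reference so only type checkers resolve it.
def pcp_kv_allgather_and_restore(pcp_group: "GroupCoordinator") -> None:
    # Simplified signature for illustration; the real helper takes more arguments.
    ...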
Force-pushed from 2d9034a to 8ac9843 (compare)
Force-pushed from 8ac9843 to 5e79da7 (compare)
This pull request has merge conflicts that must be resolved before it can be merged.
LucasWilkinson left a comment
some very initial comments; will review more thoroughly tmrw
vllm/v1/worker/gpu_model_runner.py
Outdated
-    logits_indices = query_start_loc[1:] - 1
+    if self.pcp_world_size > 1:
+        logits_indices = (
+            torch.from_numpy(cu_num_tokens) * self.pcp_world_size
+            - self.pcp_manager.num_pcp_pads_cpu_tensor[:num_reqs]
+            - 1
+        )
+    else:
+        logits_indices = query_start_loc[1:] - 1
why repeated logits_indices = query_start_loc[1:] - 1?
Can we do something like the following? The ultimate goal is to lower the visual presence and cognitive load, so people can easily ignore the PCP stuff when reading gpu_model_runner if they don't care about PCP:
logits_indices = query_start_loc[1:] - 1
if self.pcp_world_size > 1:
    logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens, num_reqs)
👍 Modified as suggested. I also refactored the other PCP logic into functions in the same way, including get_restore_hidden_states, get_discard_request_mask, and get_padded_slot_mapping.
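For context, a minimal sketch of what one of these helpers could look like, lifted directly from the inline formula in the diff above (the PCPManager class name and attributes are assumptions about the refactor):

import numpy as np
import torch


class PCPManager:
    # Assumed container for PCP bookkeeping; only the pieces used below are shown.
    def __init__(self, pcp_world_size: int, num_pcp_pads_cpu_tensor: torch.Tensor):
        self.pcp_world_size = pcp_world_size
        self.num_pcp_pads_cpu_tensor = num_pcp_pads_cpu_tensor

    def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int) -> torch.Tensor:
        # Same formula as the inline version in the diff: the last token of each
        # request sits at (cumulative tokens * pcp_world_size) minus that
        # request's PCP padding, minus one.
        return (
            torch.from_numpy(cu_num_tokens) * self.pcp_world_size
            - self.num_pcp_pads_cpu_tensor[:num_reqs]
            - 1
        )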
vllm/v1/worker/gpu_model_runner.py
Outdated
     # create a dummy block table and slot mapping for them.
     blk_table_tensor = torch.zeros(
-        (num_reqs_padded, 1),
+        (num_tokens_padded, 1),
nit: rebase error?
Yes, sorry for the mistake, fixed it.
else:
    self.discard_request_mask.np[:num_reqs] = (
        self.seq_lens.np[:num_reqs] < num_tokens_np
    )
nit similar to: https://github.com/vllm-project/vllm/pull/28988/files#r2587693784
can we do:
if self.pcp_world_size > 1:
    self.discard_request_mask.np[:num_reqs] = self.pcp_manager.get_discard_request_mask(...)
else:
    self.discard_request_mask.np[:num_reqs] = (
        self.seq_lens.np[:num_reqs] < num_tokens_np
    )
Modified as suggested.
-    blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+    if self.pcp_world_size == 1:
+        slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+        blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
won't this cause issues for decode batches with full-cudagraphs (we should try to get FULL_AND_PIECEWISE turned on for DCP)
Since the PCP logic is incompatible with padded attention for now, I've added an if condition here to ensure that these two lines of code are executed only when PCP is disabled. I don't think this will have any impact on DCP.
Force-pushed from b8d3dea to 5a51e85 (compare)
This pull request has merge conflicts that must be resolved before it can be merged.
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com> Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
…ckend. Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
    2 * pcp_size,
    return_head=True,
)
kv_tail_indices, _ = get_pcp_part_indices(
shouldn't this be _, kv_tail_indices?
No, return_head is set to True here. For the query, the required KV starts from the first token, so we always need to return the indices starting from the head.
Force-pushed from 5a51e85 to 92d8766 (compare)
Hi @FENP, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch. For future commits, pre-commit will run these checks automatically.
Force-pushed from 92d8766 to 45c5c7c (compare)
Hi @FENP, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch. For future commits, pre-commit will run these checks automatically.
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Force-pushed from 45c5c7c to 4b40860 (compare)
LucasWilkinson left a comment
cu_seqlens_q=prefill.query_start_loc // 2,
cu_seqlens_k=prefill.query_start_loc // 2 * (self.pcp_rank + 1),
max_seqlen_q=prefill.max_query_len // 2,
max_seqlen_k=prefill.max_query_len // 2 * (self.pcp_rank + 1),
assert pcp_metadata is not None
output_head, lse_head = self._flash_attn_varlen_diff_headdims(
    q=torch.index_select(q, 0, pcp_metadata.query_head_indices),
    k=torch.index_select(k, 0, pcp_metadata.kv_head_indices),
could we replace the k/v index selects by instead treating pcp_metadata.kv_head_indices as the block table for a page_size==1 kv-cache (and pass k/v directly)? FA3 supports page size 1
👍 It's worth a try
Hi @LucasWilkinson, I've thought about this issue a bit more.
Could we consider using a single Triton kernel to perform the index select for qkv all at once? Compared to treating pcp_metadata.kv_head_indices as a block table, this approach could avoid the additional index select call for q. In addition, I think this approach would also make it easier to migrate to other attention backends.
makes sense to me; assuming the triton launch overheads are reasonable since we unfortunately don't have prefill CGs to hide this overhead
Thanks for the feedback! I'm currently working on this change — will push the update soon.
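For reference, a rough sketch of the fused-gather direction (shapes and names are assumptions, not the PR's actual kernel): a single Triton launch gathers the K and V rows selected by pcp_metadata.kv_head_indices instead of separate torch.index_select calls, and q could be folded in the same way.

import torch
import triton
import triton.language as tl


@triton.jit
def _gather_kv(k_ptr, v_ptr, idx_ptr, k_out_ptr, v_out_ptr,
               k_row, v_row, BLOCK_K: tl.constexpr, BLOCK_V: tl.constexpr):
    # One program per gathered token: copy k[src] and v[src] into output row `row`.
    row = tl.program_id(0)
    src = tl.load(idx_ptr + row)

    offs_k = tl.arange(0, BLOCK_K)
    mask_k = offs_k < k_row
    k_vals = tl.load(k_ptr + src * k_row + offs_k, mask=mask_k, other=0.0)
    tl.store(k_out_ptr + row * k_row + offs_k, k_vals, mask=mask_k)

    offs_v = tl.arange(0, BLOCK_V)
    mask_v = offs_v < v_row
    v_vals = tl.load(v_ptr + src * v_row + offs_v, mask=mask_v, other=0.0)
    tl.store(v_out_ptr + row * v_row + offs_v, v_vals, mask=mask_v)


def gather_kv(k: torch.Tensor, v: torch.Tensor, kv_indices: torch.Tensor):
    # Flatten trailing (num_heads, head_dim) so each token is one contiguous row.
    k2 = k.reshape(k.shape[0], -1).contiguous()
    v2 = v.reshape(v.shape[0], -1).contiguous()
    n = kv_indices.numel()
    k_out = torch.empty((n, k2.shape[1]), dtype=k.dtype, device=k.device)
    v_out = torch.empty((n, v2.shape[1]), dtype=v.dtype, device=v.device)
    _gather_kv[(n,)](
        k2, v2, kv_indices, k_out, v_out,
        k2.shape[1], v2.shape[1],
        BLOCK_K=triton.next_power_of_2(k2.shape[1]),
        BLOCK_V=triton.next_power_of_2(v2.shape[1]),
    )
    return k_out.reshape(n, *k.shape[1:]), v_out.reshape(n, *v.shape[1:])

As noted above, prefill has no CUDA graphs to hide launch overhead, so the real kernel would want to fold the q gather into the same launch as well.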
This pull request has merge conflicts that must be resolved before it can be merged.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
if self.pcp_world_size > 1:
    logits_indices = self.pcp_manager.get_logits_indices(
logits_indices here seems to be a cpu_tensor; maybe it should be a device_tensor?
Makes sense.
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = (
    all_positions.argsort()
Right now, all_positions is a CPU tensor; what do you think about putting the tensor on the device and then sorting?
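A small sketch of that suggestion (function and argument names are assumed): move the concatenated positions to the GPU so the argsort runs on the device rather than through NumPy on the host.

import numpy as np
import torch


def restore_index_on_device(all_positions: np.ndarray, device: torch.device) -> torch.Tensor:
    # all_positions is the concatenated position array from the diff above;
    # sorting on the device avoids the host-side argsort and a later H2D copy.
    positions_gpu = torch.from_numpy(all_positions).to(device, non_blocking=True)
    return torch.argsort(positions_gpu)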
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)

if self.pcp_world_size > 1:
    slot_mapping = self.pcp_manager.get_padded_slot_mapping(
The address of slot_mapping here seems to have changed; I think it may influence cudagraph in the future. Maybe we could keep the address fixed?
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[
    : num_tokens * self.pcp_world_size
]
pcp_padded_slot_mapping.fill_(-1)
Could we delete this operation and initialize the tensor with -1 instead?
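For illustration, a sketch of the suggested alternative (shape, dtype, and device are assumptions): allocate the padded slot mapping once, pre-filled with -1, rather than calling fill_(-1) here.

import torch

# Assumptions for the sketch: an upper bound on padded tokens and the target device.
max_num_padded_tokens = 8192
device = torch.device("cuda")

# One-time allocation pre-filled with -1, as suggested above; whether the
# per-step fill_(-1) can then be dropped depends on how stale entries from
# previous steps are handled.
pcp_padded_slot_mapping = torch.full(
    (max_num_padded_tokens,), -1, dtype=torch.int64, device=device
)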
Hi @FENP, can you resolve the conflict? I want to test it locally.


Purpose
Ref to issue #25749. Enable PCP for MLA models.
This PR mainly includes the following changes:
- vllm/v1/worker/gpu_model_runner.py for PCP partitioning logic for tokens
- vllm/v1/attention/backends/mla/common.py to adapt the MLA backend to PCP
- vllm/v1/attention/backends/utils.py

Test Plan
Test Result
Future work
These items will be tackled in follow-up PRs; community contributions are warmly welcomed.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.