
Conversation

@FENP FENP (Contributor) commented Nov 19, 2025

Purpose

Refs issue #25749. Enable prefill context parallelism (PCP) for MLA models.
This PR mainly includes the following changes:

  • Modified vllm/v1/worker/gpu_model_runner.py to add the PCP token-partitioning logic
  • Modified vllm/v1/attention/backends/mla/common.py to adapt the MLA backend to PCP
  • Added the utility functions required by PCP to vllm/v1/attention/backends/utils.py
  • Renamed variables and functions shared by both PCP and DCP

Test Plan

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat --gpu-memory-utilization 0.9 --tensor-parallel-size 1 --prefill-context-parallel-size 2

Test Result

  • PCP1 (Baseline)
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6277|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6179|±  |0.0095|
  • PCP2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6259|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6164|±  |0.0095|
  • PCP4
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6308|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6236|±  |0.0094|
  • PCP8
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6240|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6168|±  |0.0095|

Future work

These items will be tackled in follow-up PRs; community contributions are warmly welcomed.

  • Support piecewise CUDA graph
  • Support chunked prefill and prefix caching
  • Add CI tests
  • Accuracy tests


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @FENP.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 19, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for Prefill Context Parallelism (PCP) to the Multi-head Latent Attention (MLA) backend. The changes are extensive, involving refactoring of the context parallelism logic to be more generic, adding new metadata and utility functions for PCP, and implementing the PCP attention logic based on the Dual-Chunk-Swap strategy.

My review has identified a critical issue in the attention correction logic where DCP and PCP corrections are applied in the wrong order, which will lead to incorrect results. I have also pointed out a significant performance issue related to nested communication calls that should be optimized. Overall, the PR is a good step towards enabling PCP, but these critical issues need to be addressed.

Comment on lines 1828 to 1917
  cur_allgather_kvcache.copy_(
-     get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
+     get_pcp_group().all_gather(
+         get_dcp_group().all_gather(local_gathered_kvcache, dim=0),
+         dim=0,
+     )
  )
Contributor
Severity: high

The nested all_gather calls, first over the DCP group and then over the PCP group, are inefficient as they introduce extra communication overhead and synchronization points. This should be optimized into a single all_gather operation.

To achieve this, a new communication group that combines the ranks from both DCP and PCP should be created during initialization. Then, a single all_gather can be performed over this combined "context parallel" (CP) group. This will be more performant. The TODO comment already acknowledges this, and this comment serves to emphasize its importance for performance.
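For illustration, the single-collective version could look roughly like the sketch below; get_cp_group and allgather_kv_over_cp are hypothetical names, not existing vLLM API, and the combined group's rank ordering must match the layout the downstream restore logic expects.

import torch

from vllm.distributed.parallel_state import GroupCoordinator


def allgather_kv_over_cp(
    cp_group: GroupCoordinator,  # hypothetical combined DCP x PCP group
    local_gathered_kvcache: torch.Tensor,
    cur_allgather_kvcache: torch.Tensor,
) -> None:
    # One collective over the combined context-parallel group replaces the
    # nested DCP-then-PCP all_gather; the rank order of the combined group
    # must reproduce the concatenation order the restore/indexing code assumes.
    cur_allgather_kvcache.copy_(cp_group.all_gather(local_gathered_kvcache, dim=0))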

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment
💡 Codex Review

https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L212-L215
P1: Importing undefined get_pcp_group

Lines 212‑215 import get_pcp_group from vllm.distributed.parallel_state, but that module still only exposes get_dcp_group (the commit merely introduced a _CP variable without any getter). Importing common.py will therefore immediately raise ImportError: cannot import name 'get_pcp_group', so none of the new PCP code paths can even be instantiated.
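For reference, the missing accessor would presumably mirror the get_dcp_group() pattern parallel_state already uses for other groups; the following is a sketch based on this finding, not the actual code in the commit (inside parallel_state itself GroupCoordinator needs no import).

from typing import Optional

from vllm.distributed.parallel_state import GroupCoordinator

_PCP: Optional[GroupCoordinator] = None  # set when the PCP group is initialized (assumed name)


def get_pcp_group() -> GroupCoordinator:
    # Fails loudly if prefill context parallelism was never initialized.
    assert _PCP is not None, "prefill context parallel group is not initialized"
    return _PCP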


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/flashattn_mla.py#L79-L86
P1: FlashAttn builder now passes nonexistent kwarg

The call to super().__init__(…, supports_cp_with_varlen=True) in FlashAttnMLAMetadataBuilder.__init__ (lines 79‑86) will raise TypeError: __init__() got an unexpected keyword argument 'supports_cp_with_varlen' because MLACommonMetadataBuilder.__init__ still only accepts supports_dcp_with_varlen. This prevents the FlashAttn MLA backend from constructing at all.


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L572-L574
P1: Referencing cp_kv_cache_interleave_size attribute that does not exist

Lines 572‑574 now read self.cp_local_block_size = parallel_config.cp_kv_cache_interleave_size, but ParallelConfig (vllm/config/parallel.py) defines only dcp_kv_cache_interleave_size. As soon as MLACommonMetadataBuilder is constructed this access raises AttributeError: 'ParallelConfig' object has no attribute 'cp_kv_cache_interleave_size', so the MLA backend cannot even initialize.


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/utils.py#L1118-L1124
P1: New utils annotation causes NameError on import

The new helper pcp_kv_allgather_and_restore (lines 1118‑1124) annotates pcp_group: GroupCoordinator, but GroupCoordinator is only imported inside the TYPE_CHECKING block and there is no from __future__ import annotations. When Python evaluates these annotations at import time it looks up GroupCoordinator, fails to find the name, and raises NameError, breaking vllm.v1.attention.backends.utils for every runtime import.
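A minimal sketch of one way to fix this, assuming the annotation should stay evaluation-free at runtime: enable postponed evaluation of annotations so the TYPE_CHECKING-only import is never resolved on import.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed by type checkers; never imported at runtime.
    from vllm.distributed.parallel_state import GroupCoordinator


def pcp_kv_allgather_and_restore(pcp_group: GroupCoordinator) -> None:
    # Signature trimmed to the annotated parameter under discussion;
    # the real helper takes additional arguments.
    ...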


@mergify mergify bot removed the needs-rebase label Nov 19, 2025
@FENP FENP force-pushed the prefill-context-parallel-mla branch 3 times, most recently from 2d9034a to 8ac9843 on November 20, 2025 07:10
@FENP FENP requested review from mgoin and tjtanaa as code owners November 20, 2025 07:10
@mergify mergify bot added nvidia rocm Related to AMD ROCm labels Nov 20, 2025
@FENP FENP force-pushed the prefill-context-parallel-mla branch from 8ac9843 to 5e79da7 on November 20, 2025 07:38
mergify bot commented Nov 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @FENP.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@LucasWilkinson LucasWilkinson (Collaborator) left a comment

Some very initial comments; will review more thoroughly tomorrow.

Comment on lines 1512 to 1520
- logits_indices = query_start_loc[1:] - 1
+ if self.pcp_world_size > 1:
+     logits_indices = (
+         torch.from_numpy(cu_num_tokens) * self.pcp_world_size
+         - self.pcp_manager.num_pcp_pads_cpu_tensor[:num_reqs]
+         - 1
+     )
+ else:
+     logits_indices = query_start_loc[1:] - 1
Collaborator
Why the repeated logits_indices = query_start_loc[1:] - 1?

Can we do something like the following? The ultimate goal is to lower the visual presence and cognitive load, so people can easily ignore the PCP stuff when reading gpu_model_runner if they don't care about PCP:

logits_indices = query_start_loc[1:] - 1
if self.pcp_world_size > 1:
     logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens, num_reqs)
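For reference, a rough sketch of what such a pcp_manager helper could look like, lifted from the inline arithmetic above (the PCPManager class and attribute names here are assumptions, not the PR's final code):

import numpy as np
import torch


class PCPManager:
    """Hypothetical holder for PCP bookkeeping used by gpu_model_runner."""

    def __init__(self, pcp_world_size: int, num_pcp_pads_cpu_tensor: torch.Tensor):
        self.pcp_world_size = pcp_world_size
        self.num_pcp_pads_cpu_tensor = num_pcp_pads_cpu_tensor

    def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int) -> torch.Tensor:
        # Index of the last real (unpadded) token of each request in the
        # all-gathered sequence: scale the per-rank cumulative token counts by
        # the PCP world size, drop the PCP padding, and apply the usual -1.
        return (
            torch.from_numpy(cu_num_tokens) * self.pcp_world_size
            - self.num_pcp_pads_cpu_tensor[:num_reqs]
            - 1
        )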

@FENP FENP (Contributor, Author) commented Dec 4, 2025
👍 Modified as suggested. I also refactored the other PCP logic into pcp_manager helpers in the same way, including get_restore_hidden_states, get_discard_request_mask, and get_padded_slot_mapping.

  # create a dummy block table and slot mapping for them.
  blk_table_tensor = torch.zeros(
-     (num_reqs_padded, 1),
+     (num_tokens_padded, 1),
Collaborator
nit: rebase error?

@FENP FENP (Contributor, Author) commented Dec 4, 2025
Yes, sorry for the mistake, fixed it.

else:
    self.discard_request_mask.np[:num_reqs] = (
        self.seq_lens.np[:num_reqs] < num_tokens_np
    )
Collaborator
nit similar to: https://github.com/vllm-project/vllm/pull/28988/files#r2587693784

can we do:

if self.pcp_world_size > 1:
      self.discard_request_mask.np[:num_reqs] = self.pcp_manager.get_discard_request_mask(...)
else:
      self.discard_request_mask.np[:num_reqs] = (
         self.seq_lens.np[:num_reqs] < num_tokens_np
      )

@FENP FENP (Contributor, Author)

Modified as suggested.

blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
if self.pcp_world_size == 1:
    slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
    blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
Collaborator
Won't this cause issues for decode batches with full CUDA graphs? (We should try to get FULL_AND_PIECEWISE turned on for DCP.)

@FENP FENP (Contributor, Author) commented Dec 4, 2025
Since the PCP logic is currently incompatible with padded attention, I've added an if condition here so that these two lines run only when PCP is disabled. I don't think this will have any impact on DCP.

@FENP FENP force-pushed the prefill-context-parallel-mla branch 3 times, most recently from b8d3dea to 5a51e85 on December 4, 2025 09:07
mergify bot commented Dec 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @FENP.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 5, 2025
FENP and others added 7 commits December 8, 2025 10:44
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
…ckend.

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
2 * pcp_size,
return_head=True,
)
kv_tail_indices, _ = get_pcp_part_indices(
Collaborator
shouldn't this be _, kv_tail_indices?

@FENP FENP (Contributor, Author)
No, return_head is set to True here. For the query, the required KV starts from the first token, so we always need to return the indices starting from the head.

@FENP FENP force-pushed the prefill-context-parallel-mla branch from 5a51e85 to 92d8766 on December 8, 2025 09:30
@mergify mergify bot removed the needs-rebase label Dec 8, 2025
mergify bot commented Dec 8, 2025

Hi @FENP, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@FENP FENP force-pushed the prefill-context-parallel-mla branch from 92d8766 to 45c5c7c on December 8, 2025 10:45
mergify bot commented Dec 8, 2025

Hi @FENP, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
@FENP FENP force-pushed the prefill-context-parallel-mla branch from 45c5c7c to 4b40860 on December 8, 2025 11:31
@LucasWilkinson LucasWilkinson (Collaborator) left a comment

I ran some profiles and I think there are way too many small torch ops in the current implementation (e.g. index_select), leading to excessive CPU overhead. I know this is targeting very long prefills (this was tested with 16k input), but I think we should try to optimize this more before landing.

[profiler screenshot]

cu_seqlens_q=prefill.query_start_loc // 2,
cu_seqlens_k=prefill.query_start_loc // 2 * (self.pcp_rank + 1),
max_seqlen_q=prefill.max_query_len // 2,
max_seqlen_k=prefill.max_query_len // 2 * (self.pcp_rank + 1),
Collaborator
Can we move the // 2 floor divs off the hot path? I'm seeing them show up in the profiles.

[profiler screenshot]
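For illustration, one option (an assumption about where the computation could live, not the PR's implementation) is to derive the halved sequence-length views once in the metadata builder so the per-layer attention call only reads cached values:

from dataclasses import dataclass

import torch


@dataclass
class PCPPrefillSeqLens:
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    max_seqlen_q: int
    max_seqlen_k: int


def build_pcp_prefill_seqlens(
    query_start_loc: torch.Tensor, max_query_len: int, pcp_rank: int
) -> PCPPrefillSeqLens:
    # Computed once per batch in the metadata builder instead of per layer.
    half_loc = query_start_loc // 2
    return PCPPrefillSeqLens(
        cu_seqlens_q=half_loc,
        cu_seqlens_k=half_loc * (pcp_rank + 1),
        max_seqlen_q=max_query_len // 2,
        max_seqlen_k=max_query_len // 2 * (pcp_rank + 1),
    )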

assert pcp_metadata is not None
output_head, lse_head = self._flash_attn_varlen_diff_headdims(
    q=torch.index_select(q, 0, pcp_metadata.query_head_indices),
    k=torch.index_select(k, 0, pcp_metadata.kv_head_indices),
Collaborator
Could we replace the k/v index selects by instead treating pcp_metadata.kv_head_indices as the block table for a page_size==1 KV cache (and passing k/v directly)? FA3 supports page size 1.

@FENP FENP (Contributor, Author) commented Dec 9, 2025
👍 It's worth a try

@FENP FENP (Contributor, Author)

Hi @LucasWilkinson, I've thought about this issue a bit more.
Could we consider using a single Triton kernel to perform the index select for qkv all at once? Compared to treating pcp_metadata.kv_head_indices as a block table, this approach could avoid the additional index select call for q. In addition, I think this approach would also make it easier to migrate to other attention backends.
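For illustration, a minimal Triton row-gather kernel along these lines might look like the sketch below (an assumption, not the kernel that will land); fusing the q and k/v gathers into a single launch would add a second grid axis or pointer set.

import torch
import triton
import triton.language as tl


@triton.jit
def _gather_rows_kernel(src_ptr, idx_ptr, out_ptr, row_size, BLOCK: tl.constexpr):
    # One program per output row: out[row, :] = src[idx[row], :]
    row = tl.program_id(0)
    src_row = tl.load(idx_ptr + row)
    offs = tl.arange(0, BLOCK)
    mask = offs < row_size
    vals = tl.load(src_ptr + src_row * row_size + offs, mask=mask)
    tl.store(out_ptr + row * row_size + offs, vals, mask=mask)


def gather_rows(src: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Gathers src[indices] along dim 0 in a single kernel launch, replacing
    # torch.index_select; trailing dims are flattened to one row per token.
    src2d = src.reshape(src.shape[0], -1).contiguous()
    out = torch.empty(
        (indices.numel(), src2d.shape[1]), dtype=src.dtype, device=src.device
    )
    BLOCK = triton.next_power_of_2(src2d.shape[1])
    _gather_rows_kernel[(indices.numel(),)](
        src2d, indices, out, src2d.shape[1], BLOCK=BLOCK
    )
    return out.view(indices.numel(), *src.shape[1:])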

Collaborator

Makes sense to me, assuming the Triton launch overheads are reasonable, since we unfortunately don't have prefill CUDA graphs to hide this overhead.

@FENP FENP (Contributor, Author)

Thanks for the feedback! I'm currently working on this change — will push the update soon.

mergify bot commented Dec 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @FENP.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 9, 2025
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
if self.pcp_world_size > 1:
    logits_indices = self.pcp_manager.get_logits_indices(
Contributor

logits_indices here seems to be a CPU tensor; maybe it should be a device tensor?

@FENP FENP (Contributor, Author)
Makes sense.

@weiguihua2 weiguihua2 mentioned this pull request Dec 18, 2025
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = (
all_positions.argsort()
Contributor
Right now all_positions is a CPU array; what do you think about moving it to the device and then sorting?

blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)

if self.pcp_world_size > 1:
    slot_mapping = self.pcp_manager.get_padded_slot_mapping(
Contributor
The address of slot_mapping seems to change here; I think that may affect CUDA graphs in the future. Maybe we could keep the address fixed?
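As an illustrative sketch of that idea (buffer name and capacity are assumptions, not the PR's code), the padded slot mapping could be written into a preallocated persistent buffer instead of rebinding slot_mapping to a new tensor:

import torch

MAX_NUM_PADDED_TOKENS = 8192  # assumed capacity for illustration

# Allocated once, so CUDA graphs capture a fixed address.
persistent_slot_mapping = torch.full((MAX_NUM_PADDED_TOKENS,), -1, dtype=torch.int64)


def update_slot_mapping(padded_slots: torch.Tensor) -> None:
    # Copy the freshly computed padded slot mapping into the persistent
    # buffer; the unused tail stays at -1 (the "no slot" sentinel).
    n = padded_slots.shape[0]
    persistent_slot_mapping[:n].copy_(padded_slots)
    persistent_slot_mapping[n:].fill_(-1)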

cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[
    : num_tokens * self.pcp_world_size
]
pcp_padded_slot_mapping.fill_(-1)
Contributor
Could we delete this operation and instead initialize the tensor as full of -1?

@chaunceyjiang (Collaborator)

Hi @FENP, can you resolve the conflict? I want to test it locally.
