[P/D][Bugfix]: Fix the issue where the remote KVCache cannot be loaded when PP > 1 #19558

Open · wants to merge 1 commit into main

Conversation

@chaunceyjiang (Contributor) commented on Jun 12, 2025:

With P/D disaggregation, the remote KVCache cannot be loaded when PP > 1.

We need to execute self.model_executor.execute_model to start loading the KV cache.
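
A minimal sketch of the idea, assuming the step_with_batch_queue path in vllm/v1/engine/core.py (the exact diff may differ): also dispatch to the model executor when the scheduler output carries KV-connector metadata, even if no new tokens are scheduled.

# Sketch only: the execute condition is widened so that KV-cache-only work
# still reaches the model executor when PP > 1.
if not self.batch_queue.full():
    scheduler_output = self.scheduler.schedule()
    if (scheduler_output.total_num_scheduled_tokens > 0
            or scheduler_output.kv_connector_metadata is not None):
        future = self.model_executor.execute_model(scheduler_output)
        self.batch_queue.put_nowait(
            (future, scheduler_output))  # type: ignore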


@gemini-code-assist (bot) left a comment:

Summary of Changes

Hello @chaunceyjiang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a specific bug related to loading the remote KVCache in a distributed setup (PD disaggregation) when using pipeline parallelism (PP > 1). The core change involves adjusting the scheduling logic to ensure that necessary model execution steps for KV cache operations are triggered under these conditions, resolving the loading failure.

Highlights

  • Bugfix: Remote KVCache Loading: Fixed an issue where the remote KVCache could not be loaded when PD disaggregation is enabled and PP (presumably Pipeline Parallelism) is greater than 1. The previous condition for triggering model execution (total_num_scheduled_tokens > 0) was too restrictive and didn't account for necessary KV cache operations when no new tokens were being scheduled.
  • Scheduling Logic Adjustment: Modified the condition in the step_with_batch_queue function to include checking for scheduler_output.kv_connector_metadata. This ensures that the model executor is called if there is relevant KV connector metadata, even if no new tokens are scheduled, allowing the remote cache loading process to proceed.

The mergify bot added the v1 label on Jun 12, 2025.
@chaunceyjiang changed the title from "[P/D][Bugfix]: Fix the issue where the remote KCCache cannot be loaded when PP > 1" to "[P/D][Bugfix]: Fix the issue where the remote KVCache cannot be loaded when PP > 1" on Jun 12, 2025.
@gemini-code-assist (bot) left a comment:

Code Review

The pull request fixes an issue where remote KV cache could not be loaded when pipeline parallelism (PP > 1) is active and there are no tokens to schedule. The change correctly modifies the condition for model execution to include checks for kv_connector_metadata.

However, a potential side effect is that the logic for determining whether a batch was scheduled (the scheduled_batch variable) was not updated. This might lead the system to immediately process a batch intended for KV cache operations rather than pipelining it as described in the function's documentation, which could affect performance and deviate from the intended execution flow. I've also suggested adding an inline comment to clarify the new condition; see the sketch below.
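
For illustration, keeping scheduled_batch consistent with the widened execute condition might look like the following (the variable names follow step_with_batch_queue; this sketches the reviewer's concern and is not code from the PR):

# Illustrative sketch: count a KV-cache-only batch as "scheduled" so it is
# pipelined through the batch queue rather than drained immediately.
scheduled_batch = (scheduler_output is not None
                   and (scheduler_output.total_num_scheduled_tokens > 0
                        or scheduler_output.kv_connector_metadata is not None))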

@chaunceyjiang (Contributor, Author) commented:

Hi @njhill, @aarnphm, @wseaton, could you help review this PR?

@aarnphm (Collaborator) commented on Jun 17, 2025:

Do you have the command to test this out? I can try to verify with my dev machine.

@chaunceyjiang (Contributor, Author) commented:

> Do you have the command to test this out? I can try to verify with my dev machine.

Since NIXL currently does not support working with PP, I applied this patch: #19591, and then used the following commands.

# node1 10.254.20.31
VLLM_NIXL_SIDE_CHANNEL_HOST=10.254.20.31 VLLM_NIXL_SIDE_CHANNEL_PORT=27777 \
VLLM_ALL2ALL_BACKEND="deepep_low_latency" \
vllm serve /data/deepseek-ai/DeepSeek-R1 -pp 2 -tp 8 --port 8001 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' \
    --gpu-memory-utilization 0.98 --num-gpu-blocks-override 13000

# node2 10.254.20.28
VLLM_NIXL_SIDE_CHANNEL_HOST=10.254.20.28 VLLM_NIXL_SIDE_CHANNEL_PORT=27777 \
VLLM_ALL2ALL_BACKEND="deepep_low_latency" \
vllm serve /data/deepseek-ai/DeepSeek-R1 -pp 2 -tp 8 --port 8002 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' \
    --gpu-memory-utilization 0.98 --num-gpu-blocks-override 13000

Even without applying patch #19591, we can see from the code that when scheduler_output.total_num_scheduled_tokens is 0, the model runner attempts to load the remote KV cache. However, when pp > 1, self.model_executor.execute_model(scheduler_output) is never called at all.

vllm/vllm/v1/engine/core.py

Lines 260 to 265 in 3597b06

if not self.batch_queue.full():
    scheduler_output = self.scheduler.schedule()
    if scheduler_output.total_num_scheduled_tokens > 0:
        future = self.model_executor.execute_model(scheduler_output)
        self.batch_queue.put_nowait(
            (future, scheduler_output))  # type: ignore

@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
    self._update_states(scheduler_output)
    if not scheduler_output.total_num_scheduled_tokens:
        if not has_kv_transfer_group():
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
        return self.kv_connector_no_forward(scheduler_output)

If pp == 1, the remote KV cache is loaded correctly: even when scheduler_output.total_num_scheduled_tokens is 0, self.execute_model(scheduler_output) is still executed.

vllm/vllm/v1/engine/core.py

Lines 230 to 232 in 3597b06

scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(

Commit added: …n PP > 1 (Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>)
@njhill (Member) left a comment:

Thanks @chaunceyjiang, LGTM. Would be good if @ruisearch42 or @comaniac could also confirm this looks ok.

@njhill requested a review from ruisearch42 on Jun 17, 2025 15:56.
@ruisearch42 (Collaborator) commented:
Hi @chaunceyjiang, thanks for the PR. LGTM, but could you measure the performance impact? We might want to have different implementations with and without P/D.

You can take a look at the perf evals in #14585 as a reference.
