@zhuohan123 zhuohan123 commented Sep 29, 2025

Purpose

Add a throughput benchmark for data parallelism (DP) in external launcher mode. This benchmarks throughput for integration with RL frameworks.

Test Plan

torchrun --nproc-per-node 2 -m vllm.entrypoints.cli.main bench throughput \
  --model NousResearch/Hermes-3-Llama-3.1-8B \
  --dataset-name random \
  --num-prompts 100 --input-len 1024 --output-len 1024 -dp 2 --distributed_executor_backend external_launcher

Test Result

Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:14<00:00,  3.37it/s, est. speed input: 3453.30 toks/s, output: 3453.30 toks/s]
Throughput: 3.34 requests/s, 6835.37 total tokens/s, 3417.69 output tokens/s
Total num prompt tokens:  51200
Total num output tokens:  51200
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:14<00:00,  3.37it/s, est. speed input: 3455.84 toks/s, output: 3455.83 toks/s]
Throughput: 3.34 requests/s, 6837.08 total tokens/s, 3418.54 output tokens/s
Total num prompt tokens:  51200
Total num output tokens:  51200

on 2xH100s.
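
As a sanity check on the figures above, the short sketch below recomputes the per-rank numbers from the log. It assumes that each log block covers one DP rank's 50 prompts (as the `50/50` progress counter suggests) and that "total tokens/s" is prompt plus output tokens divided by elapsed time; both are readings of the output, not confirmed details of the benchmark code.

```python
# Figures copied from the first rank's log lines above.
num_prompts_per_rank = 50
input_len = 1024
output_len = 1024
reported_total_tok_per_s = 6835.37
reported_output_tok_per_s = 3417.69
reported_req_per_s = 3.34

prompt_tokens = num_prompts_per_rank * input_len
output_tokens = num_prompts_per_rank * output_len

# Assumption: total tokens/s = (prompt tokens + output tokens) / elapsed time,
# so the elapsed time can be recovered from the reported rate.
elapsed_s = (prompt_tokens + output_tokens) / reported_total_tok_per_s

implied_req_per_s = num_prompts_per_rank / elapsed_s
implied_output_tok_per_s = output_tokens / elapsed_s

print(prompt_tokens, output_tokens)
print(round(elapsed_s, 2))
print(round(implied_req_per_s, 2))
```

Under these assumptions the token totals match the logged 51200, and the implied elapsed time of roughly 15 seconds agrees with the `00:14` shown in the progress bar.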



Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
@mergify mergify bot added the performance Performance-related issues label Sep 29, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for data parallel throughput benchmarking when using an external launcher like torchrun. It adds a new function to filter and distribute benchmark requests across data parallel ranks and updates argument validation to permit this new mode. My review identifies a critical issue in the new request filtering logic where a lack of validation on world_size and data_parallel_size could lead to a ZeroDivisionError or incorrect rank assignments. I've provided a code suggestion to add the necessary checks and make the implementation more robust.
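
To illustrate what distributing benchmark requests across data parallel ranks can look like, here is a minimal sketch. The helper name `filter_requests_for_dp_rank` and the round-robin split are assumptions made for illustration; the PR's actual function and sharding strategy are not shown in this conversation.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def filter_requests_for_dp_rank(
    requests: Sequence[T],
    data_parallel_rank: int,
    data_parallel_size: int,
) -> List[T]:
    """Hypothetical helper: keep only the requests owned by one DP rank.

    Uses a round-robin split (an assumption, not necessarily what the PR does),
    so rank r keeps requests r, r + dp_size, r + 2 * dp_size, and so on.
    """
    if data_parallel_size < 1:
        raise ValueError("data_parallel_size must be at least 1.")
    if not 0 <= data_parallel_rank < data_parallel_size:
        raise ValueError(
            f"data_parallel_rank ({data_parallel_rank}) must be in "
            f"[0, {data_parallel_size}).")
    return list(requests[data_parallel_rank::data_parallel_size])


# 100 placeholder requests split across 2 DP ranks.
all_requests = list(range(100))
shard0 = filter_requests_for_dp_rank(all_requests, 0, 2)
shard1 = filter_requests_for_dp_rank(all_requests, 1, 2)
print(len(shard0), len(shard1))
```

With 100 prompts and a DP size of 2, each rank keeps 50 requests, which is consistent with the `50/50` counters in the test result.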

Comment on lines +373 to +375
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
data_parallel_rank = global_rank // (world_size // data_parallel_size)
critical

The calculation of data_parallel_rank is susceptible to a ZeroDivisionError if data_parallel_size is greater than world_size. Additionally, if world_size is not divisible by data_parallel_size, the calculated data_parallel_rank can be incorrect for some ranks, potentially leading to out-of-bounds errors or incorrect behavior. It's crucial to add validation for these conditions to ensure the benchmark runs robustly.

Suggested change
- global_rank = int(os.environ["RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
- data_parallel_rank = global_rank // (world_size // data_parallel_size)
+ global_rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ if data_parallel_size > world_size:
+     raise ValueError(
+         f"data_parallel_size ({data_parallel_size}) cannot be larger than "
+         f"world_size ({world_size}).")
+ if world_size % data_parallel_size != 0:
+     raise ValueError(
+         f"world_size ({world_size}) must be divisible by "
+         f"data_parallel_size ({data_parallel_size}).")
+ model_parallel_size = world_size // data_parallel_size
+ data_parallel_rank = global_rank // model_parallel_size
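
The two failure modes described in this comment can be reproduced with plain integers. The sketch below wraps the suggested checks in a standalone function; the name `compute_data_parallel_rank` is hypothetical, and the rank and world size are passed as arguments instead of being read from the `RANK` and `WORLD_SIZE` environment variables.

```python
def compute_data_parallel_rank(
    global_rank: int, world_size: int, data_parallel_size: int
) -> int:
    """Hypothetical wrapper around the validated rank mapping suggested above."""
    if data_parallel_size > world_size:
        raise ValueError(
            f"data_parallel_size ({data_parallel_size}) cannot be larger than "
            f"world_size ({world_size}).")
    if world_size % data_parallel_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by "
            f"data_parallel_size ({data_parallel_size}).")
    model_parallel_size = world_size // data_parallel_size
    return global_rank // model_parallel_size


# Valid layout: 4 processes, DP size 2, so 2 processes per DP rank.
valid_mapping = [compute_data_parallel_rank(r, 4, 2) for r in range(4)]
print(valid_mapping)

# Unchecked formula with world_size=5, data_parallel_size=2:
# the last process lands on DP rank 2, outside the valid range [0, 2).
unchecked_mapping = [r // (5 // 2) for r in range(5)]
print(unchecked_mapping)

# Unchecked formula with data_parallel_size > world_size divides by zero,
# because world_size // data_parallel_size evaluates to 0.
try:
    0 // (2 // 4)
except ZeroDivisionError as exc:
    print("unchecked:", type(exc).__name__)
```

With the checks in place, both bad configurations fail early with a `ValueError` that names the offending sizes instead of producing a wrong rank or a bare `ZeroDivisionError`.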

@zhuohan123 zhuohan123 enabled auto-merge (squash) September 29, 2025 23:58
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 29, 2025
@zhuohan123 zhuohan123 merged commit d3bd171 into main Sep 30, 2025
49 checks passed
@zhuohan123 zhuohan123 deleted the zhuohan/support-benchmark-throughput-for-external-launcher-dp branch September 30, 2025 01:43
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…5913)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025
…lm-project#25913)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>