vllm/benchmarks/throughput.py (33 changes: 27 additions & 6 deletions)
@@ -358,7 +358,23 @@ def get_requests(args, tokenizer):
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
requests = filter_requests_for_dp(requests, args.data_parallel_size)
return requests


+def filter_requests_for_dp(requests, data_parallel_size):
+    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
+    # works for external launcher mode. Should be cleaned up and deprecated
+    # in the future with a better vLLM distributed process design.
+    if data_parallel_size == 1:
+        return requests
+
+    global_rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    data_parallel_rank = global_rank // (world_size // data_parallel_size)
Comment on lines +373 to +375 (Contributor, severity: critical):
The calculation of data_parallel_rank is susceptible to a ZeroDivisionError if data_parallel_size is greater than world_size. Additionally, if world_size is not divisible by data_parallel_size, the calculated data_parallel_rank can be incorrect for some ranks, potentially leading to out-of-bounds errors or incorrect behavior. It's crucial to add validation for these conditions to ensure the benchmark runs robustly.

Suggested change
-    global_rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-    data_parallel_rank = global_rank // (world_size // data_parallel_size)
+    global_rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    if data_parallel_size > world_size:
+        raise ValueError(
+            f"data_parallel_size ({data_parallel_size}) cannot be larger than "
+            f"world_size ({world_size}).")
+    if world_size % data_parallel_size != 0:
+        raise ValueError(
+            f"world_size ({world_size}) must be divisible by "
+            f"data_parallel_size ({data_parallel_size}).")
+    model_parallel_size = world_size // data_parallel_size
+    data_parallel_rank = global_rank // model_parallel_size
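
To make the two failure modes concrete, here is a small standalone sketch (not part of the PR) using hypothetical `WORLD_SIZE` values:

```python
# Hypothetical values illustrating the reviewer's two failure modes.
world_size, data_parallel_size = 4, 8
print(world_size // data_parallel_size)  # 0 -> the original expression
# global_rank // (world_size // data_parallel_size) raises ZeroDivisionError.

world_size, data_parallel_size = 6, 4
print([rank // (world_size // data_parallel_size) for rank in range(world_size)])
# [0, 1, 2, 3, 4, 5] -> "DP ranks" 4 and 5 exceed data_parallel_size - 1,
# so the filter (i % 4 == dp_rank) on those ranks matches nothing and they
# silently receive zero requests.
```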

+    return [r for i, r in enumerate(requests)
+            if i % data_parallel_size == data_parallel_rank]
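
For context on what the new helper does: under torchrun, every process receives `RANK` and `WORLD_SIZE` environment variables, and the helper strides the request list so that each data-parallel group benchmarks a disjoint slice while ranks within the same group (e.g., tensor-parallel peers) see identical requests. A minimal standalone sketch with hypothetical values, assuming `WORLD_SIZE` is divisible by `data_parallel_size`:

```python
# Sketch of the sharding in filter_requests_for_dp (not part of the diff).
# torchrun sets RANK/WORLD_SIZE per process; here we just loop over ranks.
requests = list(range(8))   # stand-ins for sampled requests
world_size = 4              # e.g., torchrun --nproc-per-node=4
data_parallel_size = 2      # two DP groups of 2 ranks each

for global_rank in range(world_size):
    dp_rank = global_rank // (world_size // data_parallel_size)
    shard = [r for i, r in enumerate(requests)
             if i % data_parallel_size == dp_rank]
    print(f"rank {global_rank}: dp_rank={dp_rank}, requests={shard}")
# rank 0: dp_rank=0, requests=[0, 2, 4, 6]
# rank 1: dp_rank=0, requests=[0, 2, 4, 6]
# rank 2: dp_rank=1, requests=[1, 3, 5, 7]
# rank 3: dp_rank=1, requests=[1, 3, 5, 7]
```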


 def validate_args(args):

@@ -453,12 +469,17 @@ def validate_args(args):
     if args.backend == "mii" and args.tokenizer != args.model:
         raise ValueError(
             "Tokenizer must be the same as the model for MII backend.")

-    # --data-parallel is not supported currently.
-    # https://github.com/vllm-project/vllm/issues/16222
-    if args.data_parallel_size > 1:
+    if args.data_parallel_size > 1 and (
+            args.distributed_executor_backend != "external_launcher"
+            or args.async_engine):
+        # --data-parallel is not fully supported yet.
+        # Old issue: https://github.com/vllm-project/vllm/issues/16222
+        # Currently we only support data parallel with external launcher
+        # mode (i.e., launched with torchrun).
         raise ValueError(
-            "Data parallel is not supported in offline benchmark, "
+            "Data parallel is only supported with external launcher mode "
+            "with synchronous engine in offline benchmark, "
             "please use benchmark serving instead"
         )
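
A quick standalone check of the new guard (a sketch, not part of the PR; `check_dp` is a hypothetical helper and `SimpleNamespace` stands in for the parsed args):

```python
# Sketch: exercising the condition added to validate_args() above.
from types import SimpleNamespace

def check_dp(args):
    if args.data_parallel_size > 1 and (
            args.distributed_executor_backend != "external_launcher"
            or args.async_engine):
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead")

# Accepted: DP under a torchrun-style external launcher, synchronous engine.
check_dp(SimpleNamespace(data_parallel_size=2,
                         distributed_executor_backend="external_launcher",
                         async_engine=False))

# Rejected: async engine (any non-external-launcher backend fails likewise).
try:
    check_dp(SimpleNamespace(data_parallel_size=2,
                             distributed_executor_backend="external_launcher",
                             async_engine=True))
except ValueError as e:
    print(e)
```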
