diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index bf9e87198bcf..0a297479bcc0 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,6 +11,7 @@ - HuggingFace - VisionArena """ +import argparse import ast import base64 import io @@ -1019,6 +1020,25 @@ def sample( return samples +class _ValidateDatasetArgs(argparse.Action): + """Argparse action to validate dataset name and path compatibility.""" + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + # Get current values of both dataset_name and dataset_path + dataset_name = getattr(namespace, 'dataset_name', 'random') + dataset_path = getattr(namespace, 'dataset_path', None) + + # Validate the combination + if dataset_name == "random" and dataset_path is not None: + parser.error( + "Cannot use 'random' dataset with --dataset-path. " + "Please specify the appropriate --dataset-name (e.g., " + "'sharegpt', 'custom', 'sonnet') for your dataset file: " + f"{dataset_path}" + ) + + def add_dataset_parser(parser: FlexibleArgumentParser): parser.add_argument("--seed", type=int, default=0) parser.add_argument( @@ -1031,6 +1051,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", + action=_ValidateDatasetArgs, choices=[ "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", "custom", "prefix_repetition", "spec_bench" @@ -1046,6 +1067,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-path", type=str, default=None, + action=_ValidateDatasetArgs, help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.", )