diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 2498f0be..e75f5d25 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -156,8 +156,9 @@ def benchmark(): ) @click.option( "--rate", - type=float, - multiple=True, + type=str, + callback=cli_tools.parse_list_floats, + multiple=False, default=BenchmarkGenerativeTextArgs.get_default("rate"), help=( "Benchmark rate(s) to test. Meaning depends on profile: " @@ -383,7 +384,7 @@ def run(**kwargs): kwargs.get("data_args"), default=[], simplify_single=False ) kwargs["rate"] = cli_tools.format_list_arg( - kwargs.get("rate"), default=None, simplify_single=True + kwargs.get("rate"), default=None, simplify_single=False ) disable_console_outputs = kwargs.pop("disable_console_outputs", False) diff --git a/src/guidellm/utils/cli.py b/src/guidellm/utils/cli.py index a75c37a8..c4783f65 100644 --- a/src/guidellm/utils/cli.py +++ b/src/guidellm/utils/cli.py @@ -3,9 +3,34 @@ import click -__all__ = ["Union", "format_list_arg", "parse_json", "set_if_not_default"] +__all__ = [ + "Union", + "format_list_arg", + "parse_json", + "parse_list_floats", + "set_if_not_default", +] +def parse_list_floats(ctx, param, value): # noqa: ARG001 + """ + Callback to parse a comma-separated string into a list of floats. + """ + # This callback only runs if the --rate option is provided by the user. + # If it's not, 'value' will be None, and Click will use the 'default'. + if value is None: + return None # Keep the default + + try: + # Split by comma, strip any whitespace, and convert to float + return [float(item.strip()) for item in value.split(",")] + except ValueError as e: + # Raise a Click error if any part isn't a valid float + raise click.BadParameter( + f"Value '{value}' is not a valid comma-separated list " + f"of floats/ints. Error: {e}" + ) from e + def parse_json(ctx, param, value): # noqa: ARG001 if value is None or value == [None]: return None