Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,9 @@ def benchmark():
)
@click.option(
"--rate",
type=float,
multiple=True,
type=str,
callback=cli_tools.parse_list_floats,
multiple=False,
default=BenchmarkGenerativeTextArgs.get_default("rate"),
help=(
"Benchmark rate(s) to test. Meaning depends on profile: "
Expand Down Expand Up @@ -383,7 +384,7 @@ def run(**kwargs):
kwargs.get("data_args"), default=[], simplify_single=False
)
kwargs["rate"] = cli_tools.format_list_arg(
kwargs.get("rate"), default=None, simplify_single=True
kwargs.get("rate"), default=None, simplify_single=False
)

disable_console_outputs = kwargs.pop("disable_console_outputs", False)
Expand Down
27 changes: 26 additions & 1 deletion src/guidellm/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,34 @@

import click

__all__ = ["Union", "format_list_arg", "parse_json", "set_if_not_default"]
__all__ = [
"Union",
"format_list_arg",
"parse_json",
"parse_list_floats",
"set_if_not_default",
]


def parse_list_floats(ctx, param, value): # noqa: ARG001
"""
Callback to parse a comma-separated string into a list of floats.
"""
# This callback only runs if the --rate option is provided by the user.
# If it's not, 'value' will be None, and Click will use the 'default'.
if value is None:
return None # Keep the default

try:
# Split by comma, strip any whitespace, and convert to float
return [float(item.strip()) for item in value.split(",")]
except ValueError as e:
# Raise a Click error if any part isn't a valid float
raise click.BadParameter(
f"Value '{value}' is not a valid comma-separated list "
f"of floats/ints. Error: {e}"
) from e

def parse_json(ctx, param, value): # noqa: ARG001
if value is None or value == [None]:
return None
Expand Down
Loading