From 87ba006c78d237012ebc253dceb8e4c7e20fc284 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 9 Oct 2025 18:11:19 -0400 Subject: [PATCH 1/4] Fixed quality errors Many of the quality errors are due to using the older union style, and have appeared due to the upgrade of the minimum Python version from 3.9 to 3.10 Signed-off-by: Jared O'Connell --- src/guidellm/backends/openai.py | 2 +- src/guidellm/benchmark/aggregator.py | 2 +- src/guidellm/benchmark/benchmarker.py | 4 +- src/guidellm/benchmark/output.py | 2 +- src/guidellm/benchmark/scenario.py | 2 +- src/guidellm/dataset/creator.py | 10 ++-- src/guidellm/dataset/file.py | 4 +- src/guidellm/dataset/hf_datasets.py | 6 +-- src/guidellm/scheduler/constraints.py | 20 ++++---- src/guidellm/scheduler/worker.py | 2 +- src/guidellm/utils/encoding.py | 72 +++++++++++++-------------- src/guidellm/utils/hf_datasets.py | 3 +- src/guidellm/utils/hf_transformers.py | 8 +-- src/guidellm/utils/messaging.py | 6 +-- src/guidellm/utils/mixins.py | 2 +- src/guidellm/utils/pydantic_utils.py | 3 +- src/guidellm/utils/random.py | 7 ++- src/guidellm/utils/registry.py | 5 +- src/guidellm/utils/statistics.py | 44 ++++++++-------- src/guidellm/utils/synchronous.py | 13 ++--- src/guidellm/utils/typing.py | 7 +-- 21 files changed, 108 insertions(+), 116 deletions(-) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index ce83076f..fd14ee65 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -559,7 +559,7 @@ def _get_chat_messages( resolved_content.append(item) elif isinstance(item, str): resolved_content.append({"type": "text", "text": item}) - elif isinstance(item, (Image.Image, Path)): + elif isinstance(item, Image.Image | Path): resolved_content.append(self._get_chat_message_media_item(item)) else: raise ValueError(f"Unsupported content item type: {type(item)}") diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index e965c482..be70276b 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -267,7 +267,7 @@ def resolve( resolved = {} for key, val in aggregators.items(): - if isinstance(val, (Aggregator, CompilableAggregator)): + if isinstance(val, Aggregator | CompilableAggregator): resolved[key] = val else: aggregator_class = cls.get_registered_object(key) diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 5f05065a..99410e4c 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -228,12 +228,12 @@ def _combine( existing: dict[str, Any] | StandardBaseDict, addition: dict[str, Any] | StandardBaseDict, ) -> dict[str, Any] | StandardBaseDict: - if not isinstance(existing, (dict, StandardBaseDict)): + if not isinstance(existing, dict | StandardBaseDict): raise ValueError( f"Existing value {existing} (type: {type(existing).__name__}) " f"is not a valid type for merging." ) - if not isinstance(addition, (dict, StandardBaseDict)): + if not isinstance(addition, dict | StandardBaseDict): raise ValueError( f"Addition value {addition} (type: {type(addition).__name__}) " f"is not a valid type for merging." 
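For context on the commit message above: the cleanup replaces the typing-module union style (`Optional[X]`, `Union[X, Y]`, tuple arguments to `isinstance`) with PEP 604 unions, which the raised Python 3.10 minimum makes available. A minimal sketch of the before/after pattern (the `load` helper is hypothetical and not part of this patch):

```python
from pathlib import Path

# Before (Python 3.9 style, as removed throughout this patch):
#   from typing import Optional, Union
#   def load(path: Union[str, Path], limit: Optional[int] = None) -> Optional[str]: ...
#   if isinstance(path, (str, Path)): ...

# After (Python 3.10+ style, as introduced throughout this patch):
def load(path: str | Path, limit: int | None = None) -> str | None:
    """Hypothetical helper; only the annotation and isinstance style matter."""
    # PEP 604 unions are also accepted directly by isinstance since 3.10
    if isinstance(path, str | Path):
        return str(path)[:limit]
    return None
```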
diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 56775dac..c4e8fb0f 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -90,7 +90,7 @@ def resolve( if not output_formats: return {} - if isinstance(output_formats, (list, tuple)): + if isinstance(output_formats, list | tuple): # support list of output keys: ["csv", "json"] # support list of files: ["path/to/file.json", "path/to/file.csv"] formats_list = output_formats diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index b53ef424..5299616f 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -38,7 +38,7 @@ def parse_float_list(value: str | float | list[float]) -> list[float]: or convert single float list of one or pass float list through. """ - if isinstance(value, (int, float)): + if isinstance(value, int | float): return [value] elif isinstance(value, list): return value diff --git a/src/guidellm/dataset/creator.py b/src/guidellm/dataset/creator.py index a74ec8c0..b95f4c50 100644 --- a/src/guidellm/dataset/creator.py +++ b/src/guidellm/dataset/creator.py @@ -95,10 +95,10 @@ def create( data, data_args, processor, processor_args, random_seed ) - if isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if isinstance(dataset, DatasetDict | IterableDatasetDict): dataset = cls.extract_dataset_split(dataset, split, split_pref_order) - if not isinstance(dataset, (Dataset, IterableDataset)): + if not isinstance(dataset, Dataset | IterableDataset): raise ValueError( f"Unsupported data type: {type(dataset)} given for {dataset}." ) @@ -145,10 +145,10 @@ def extract_args_column_mappings( def extract_dataset_name( cls, dataset: Union[Dataset, IterableDataset, DatasetDict, IterableDatasetDict] ) -> Optional[str]: - if isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if isinstance(dataset, DatasetDict | IterableDatasetDict): dataset = dataset[list(dataset.keys())[0]] - if isinstance(dataset, (Dataset, IterableDataset)): + if isinstance(dataset, Dataset | IterableDataset): if not hasattr(dataset, "info") or not hasattr( dataset.info, "dataset_name" ): @@ -165,7 +165,7 @@ def extract_dataset_split( specified_split: Union[Literal["auto"], str] = "auto", split_pref_order: Optional[Union[Literal["auto"], list[str]]] = "auto", ) -> Union[Dataset, IterableDataset]: - if not isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if not isinstance(dataset, DatasetDict | IterableDatasetDict): raise ValueError( f"Unsupported data type: {type(dataset)} given for {dataset}." ) diff --git a/src/guidellm/dataset/file.py b/src/guidellm/dataset/file.py index 5d6df1d9..455ef580 100644 --- a/src/guidellm/dataset/file.py +++ b/src/guidellm/dataset/file.py @@ -31,7 +31,7 @@ class FileDatasetCreator(DatasetCreator): @classmethod def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 - if isinstance(data, (str, Path)) and (path := Path(data)).exists(): + if isinstance(data, str | Path) and (path := Path(data)).exists(): # local folder or py file, assume supported return path.suffix.lower() in cls.SUPPORTED_TYPES @@ -46,7 +46,7 @@ def handle_create( processor_args: Optional[dict[str, Any]], # noqa: ARG003 random_seed: int, # noqa: ARG003 ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: - if not isinstance(data, (str, Path)): + if not isinstance(data, str | Path): raise ValueError(f"Unsupported data type: {type(data)} given for {data}. 
") path = Path(data) diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py index 7f91facd..56c79936 100644 --- a/src/guidellm/dataset/hf_datasets.py +++ b/src/guidellm/dataset/hf_datasets.py @@ -25,11 +25,11 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # base type is supported return True - if isinstance(data, (str, Path)) and (path := Path(data)).exists(): + if isinstance(data, str | Path) and (path := Path(data)).exists(): # local folder or py file, assume supported return path.is_dir() or path.suffix == ".py" - if isinstance(data, (str, Path)): + if isinstance(data, str | Path): try: # try to load dataset return get_dataset_config_info(data) is not None @@ -47,7 +47,7 @@ def handle_create( processor_args: Optional[dict[str, Any]], # noqa: ARG003 random_seed: int, # noqa: ARG003 ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: - if isinstance(data, (str, Path)): + if isinstance(data, str | Path): data = load_dataset(data, **(data_args or {})) elif data_args: raise ValueError( diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py index c724a74a..c974225a 100644 --- a/src/guidellm/scheduler/constraints.py +++ b/src/guidellm/scheduler/constraints.py @@ -450,7 +450,7 @@ def __call__( current_index = max(0, self.current_index) max_num = ( self.max_num - if isinstance(self.max_num, (int, float)) + if isinstance(self.max_num, int | float) else self.max_num[min(current_index, len(self.max_num) - 1)] ) @@ -489,7 +489,7 @@ def _validate_max_num( raise ValueError( f"max_num must be set and truthful, received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( f"max_num must be a positive num, received {value} ({val} failed)" ) @@ -568,7 +568,7 @@ def __call__( current_index = max(0, self.current_index) max_duration = ( self.max_duration - if isinstance(self.max_duration, (int, float)) + if isinstance(self.max_duration, int | float) else self.max_duration[min(current_index, len(self.max_duration) - 1)] ) @@ -607,7 +607,7 @@ def _validate_max_duration( "max_duration must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( "max_duration must be a positive num," f"received {value} ({val} failed)" @@ -682,7 +682,7 @@ def __call__( current_index = max(0, self.current_index) max_errors = ( self.max_errors - if isinstance(self.max_errors, (int, float)) + if isinstance(self.max_errors, int | float) else self.max_errors[min(current_index, len(self.max_errors) - 1)] ) errors_exceeded = state.errored_requests >= max_errors @@ -710,7 +710,7 @@ def _validate_max_errors( "max_errors must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( f"max_errors must be a positive num,received {value} ({val} failed)" ) @@ -799,7 +799,7 @@ def __call__( current_index = max(0, self.current_index) max_error_rate = ( self.max_error_rate - if isinstance(self.max_error_rate, (int, float)) + if isinstance(self.max_error_rate, int | float) else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] ) @@ -850,7 +850,7 @@ def _validate_max_error_rate( "max_error_rate must be set and truthful, " f"received {value} ({val} failed)" ) - if not 
isinstance(val, (int, float)) or val <= 0 or val >= 1: + if not isinstance(val, int | float) or val <= 0 or val >= 1: raise ValueError( "max_error_rate must be a number between 0 and 1," f"received {value} ({val} failed)" @@ -940,7 +940,7 @@ def __call__( current_index = max(0, self.current_index) max_error_rate = ( self.max_error_rate - if isinstance(self.max_error_rate, (int, float)) + if isinstance(self.max_error_rate, int | float) else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] ) @@ -982,7 +982,7 @@ def _validate_max_error_rate( "max_error_rate must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0 or val >= 1: + if not isinstance(val, int | float) or val <= 0 or val >= 1: raise ValueError( "max_error_rate must be a number between 0 and 1," f"received {value} ({val} failed)" diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 5f2fb74b..104ab418 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -310,7 +310,7 @@ async def _process_next_request(self): # Pull request from the queue request, request_info = await self.messaging.get() - if isinstance(request, (list, tuple)): + if isinstance(request, list | tuple): raise NotImplementedError("Multi-turn requests are not yet supported") # Calculate targeted start and set pending state for request diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index 6823fb77..916d6633 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -12,7 +12,7 @@ import json from collections.abc import Mapping -from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, cast +from typing import Any, ClassVar, Generic, Literal, TypeVar, cast try: import msgpack # type: ignore[import-untyped] # Optional dependency @@ -45,7 +45,6 @@ HAS_ORJSON = False from pydantic import BaseModel -from typing_extensions import TypeAlias __all__ = [ "Encoder", @@ -60,14 +59,10 @@ ObjT = TypeVar("ObjT") MsgT = TypeVar("MsgT") -SerializationTypesAlias: TypeAlias = Annotated[ - Optional[Literal["dict", "sequence"]], - "Type alias for available serialization strategies", -] -EncodingTypesAlias: TypeAlias = Annotated[ - Optional[Literal["msgpack", "msgspec"]], - "Type alias for available binary encoding formats", -] +# Type alias for available serialization strategies +SerializationTypesAlias = Literal["dict", "sequence"] | None +# "Type alias for available binary encoding formats" +EncodingTypesAlias = Literal["msgpack", "msgspec"] class MessageEncoding(Generic[ObjT, MsgT]): @@ -405,7 +400,7 @@ def to_dict(self, obj: Any) -> Any: if isinstance(obj, BaseModel): return self.to_dict_pydantic(obj) - if isinstance(obj, (list, tuple)) and any( + if isinstance(obj, list | tuple) and any( isinstance(item, BaseModel) for item in obj ): return [ @@ -432,7 +427,7 @@ def from_dict(self, data: Any) -> Any: :param data: Dictionary representation possibly containing type metadata :return: Reconstructed object with proper types restored """ - if isinstance(data, (list, tuple)): + if isinstance(data, list | tuple): return [ self.from_dict_pydantic(item) if isinstance(item, dict) and "*PYD*" in item @@ -493,7 +488,7 @@ def to_sequence(self, obj: Any) -> str | Any: if isinstance(obj, BaseModel): payload_type = "pydantic" payload = self.to_sequence_pydantic(obj) - elif isinstance(obj, (list, tuple)) and any( + elif isinstance(obj, list | tuple) and any( isinstance(item, 
BaseModel) for item in obj ): payload_type = "collection_sequence" @@ -694,33 +689,36 @@ def pack_next_sequence( # noqa: C901, PLR0912 length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, byteorder="big", ) - if type_ == "pydantic": - payload_type = b"P" - elif type_ == "python": - payload_type = b"p" - elif type_ == "collection_tuple": - payload_type = b"T" - elif type_ == "collection_sequence": - payload_type = b"S" - elif type_ == "collection_mapping": - payload_type = b"M" - else: - raise ValueError(f"Unknown type for packing: {type_}") + match type_: + case "pydantic": + payload_type = b"P" + case "python": + payload_type = b"p" + case "collection_tuple": + payload_type = b"T" + case "collection_sequence": + payload_type = b"S" + case "collection_mapping": + payload_type = b"M" + case _: + raise ValueError(f"Unknown type for packing: {type_}") delimiter = b"|" else: payload_len_output = str(payload_len) - if type_ == "pydantic": - payload_type = "P" - elif type_ == "python": - payload_type = "p" - elif type_ == "collection_tuple": - payload_type = "T" - elif type_ == "collection_sequence": - payload_type = "S" - elif type_ == "collection_mapping": - payload_type = "M" - else: - raise ValueError(f"Unknown type for packing: {type_}") + + match type_: + case "pydantic": + payload_type = "P" + case "python": + payload_type = "p" + case "collection_tuple": + payload_type = "T" + case "collection_sequence": + payload_type = "S" + case "collection_mapping": + payload_type = "M" + case _: + raise ValueError(f"Unknown type for packing: {type_}") delimiter = "|" # Type ignores because types are enforced at runtime diff --git a/src/guidellm/utils/hf_datasets.py b/src/guidellm/utils/hf_datasets.py index 73e55ebc..86f04485 100644 --- a/src/guidellm/utils/hf_datasets.py +++ b/src/guidellm/utils/hf_datasets.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Union from datasets import Dataset @@ -11,7 +10,7 @@ } -def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None: +def save_dataset_to_file(dataset: Dataset, output_path: str | Path) -> None: """ Saves a HuggingFace Dataset to file in a supported format. 
diff --git a/src/guidellm/utils/hf_transformers.py b/src/guidellm/utils/hf_transformers.py index 1f2aa1b5..636988c3 100644 --- a/src/guidellm/utils/hf_transformers.py +++ b/src/guidellm/utils/hf_transformers.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from transformers import AutoTokenizer, PreTrainedTokenizerBase # type: ignore[import] @@ -9,15 +9,15 @@ def check_load_processor( - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, error_msg: str, ) -> PreTrainedTokenizerBase: if processor is None: raise ValueError(f"Processor/Tokenizer is required for {error_msg}.") try: - if isinstance(processor, (str, Path)): + if isinstance(processor, str | Path): loaded = AutoTokenizer.from_pretrained( processor, **(processor_args or {}), diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py index 9311259d..4dce576d 100644 --- a/src/guidellm/utils/messaging.py +++ b/src/guidellm/utils/messaging.py @@ -16,13 +16,13 @@ import threading import time from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from multiprocessing.connection import Connection from multiprocessing.context import BaseContext from multiprocessing.managers import SyncManager from multiprocessing.synchronize import Event as ProcessingEvent from threading import Event as ThreadingEvent -from typing import Any, Callable, Generic, Protocol, TypeVar, cast +from typing import Any, Generic, Protocol, TypeVar, cast import culsans from pydantic import BaseModel @@ -420,7 +420,7 @@ def _create_check_stop_callable( stop_events = tuple( item for item in stop_criteria or [] - if isinstance(item, (ThreadingEvent, ProcessingEvent)) + if isinstance(item, ThreadingEvent | ProcessingEvent) ) stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item)) diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py index b001ff2d..7cf28d00 100644 --- a/src/guidellm/utils/mixins.py +++ b/src/guidellm/utils/mixins.py @@ -91,7 +91,7 @@ def create_info_dict(cls, obj: Any) -> dict[str, Any]: "attributes": ( { key: val - if isinstance(val, (str, int, float, bool, list, dict)) + if isinstance(val, str | int | float | bool | list | dict) else repr(val) for key, val in obj.__dict__.items() if not key.startswith("_") diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py index 55816ef1..05f5ad81 100644 --- a/src/guidellm/utils/pydantic_utils.py +++ b/src/guidellm/utils/pydantic_utils.py @@ -11,11 +11,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, TypeVar, cast +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from typing_extensions import get_args, get_origin from guidellm.utils.registry import RegistryMixin diff --git a/src/guidellm/utils/random.py b/src/guidellm/utils/random.py index ceef20b9..6c8f396d 100644 --- a/src/guidellm/utils/random.py +++ b/src/guidellm/utils/random.py @@ -1,6 +1,5 @@ import random from collections.abc import Iterator -from typing import Optional __all__ = ["IntegerRangeSampler"] @@ -9,9 +8,9 @@ class IntegerRangeSampler: def __init__( self, 
average: int, - variance: Optional[int], - min_value: Optional[int], - max_value: Optional[int], + variance: int | None, + min_value: int | None, + max_value: int | None, random_seed: int, ): self.average = average diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py index e6f1b657..e4727cbd 100644 --- a/src/guidellm/utils/registry.py +++ b/src/guidellm/utils/registry.py @@ -10,7 +10,8 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Generic, TypeVar, cast +from collections.abc import Callable +from typing import Any, ClassVar, Generic, TypeVar, cast from guidellm.utils.auto_importer import AutoImporterMixin @@ -103,7 +104,7 @@ def register_decorator( if name is None: name = obj.__name__ - elif not isinstance(name, (str, list)): + elif not isinstance(name, str | list): raise ValueError( "RegistryMixin.register_decorator name must be a string or " f"an iterable of strings. Got {name}." diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py index acd9d4f1..04484c2c 100644 --- a/src/guidellm/utils/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -149,7 +149,7 @@ def from_distribution_function( in the output :return: DistributionSummary instance with calculated statistical metrics """ - values, weights = zip(*distribution) if distribution else ([], []) + values, weights = zip(*distribution, strict=True) if distribution else ([], []) values = np.array(values) # type: ignore[assignment] weights = np.array(weights) # type: ignore[assignment] @@ -247,7 +247,7 @@ def from_values( ) return DistributionSummary.from_distribution_function( - distribution=list(zip(values, weights)), + distribution=list(zip(values, weights, strict=True)), include_cdf=include_cdf, ) @@ -389,7 +389,8 @@ def from_iterable_request_times( events[global_end] = 0 for (_, end), first_iter, first_iter_count, total_count in zip( - requests, first_iter_times, first_iter_counts, iter_counts + requests, first_iter_times, first_iter_counts, iter_counts, + strict=True ): events[first_iter] += first_iter_count @@ -499,36 +500,36 @@ def from_values( ) _, successful_values, successful_weights = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( lambda val: val[0] == "successful", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) else ([], [], []) ) _, incomplete_values, incomplete_weights = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( lambda val: val[0] == "incomplete", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) else ([], [], []) ) _, errored_values, errored_weights = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( filter( lambda val: val[0] == "error", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) @@ -604,36 +605,36 @@ def from_request_times( ) _, successful_requests = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( lambda val: val[0] == "successful", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) else ([], []) ) _, incomplete_requests = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( lambda val: val[0] == "incomplete", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) else ([], []) ) _, errored_requests = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( 
filter( lambda val: val[0] == "error", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) @@ -734,7 +735,7 @@ def from_iterable_request_times( successful_iter_counts, successful_first_iter_counts, ) = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( @@ -745,6 +746,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -758,7 +760,7 @@ def from_iterable_request_times( incomplete_iter_counts, incomplete_first_iter_counts, ) = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( @@ -769,6 +771,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -782,7 +785,7 @@ def from_iterable_request_times( errored_iter_counts, errored_first_iter_counts, ) = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( filter( @@ -793,6 +796,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -904,7 +908,7 @@ def __add__(self, value: Any) -> float: :return: Updated mean after adding the value :raises ValueError: If value is not numeric (int or float) """ - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): raise ValueError( f"Value must be an int or float, got {type(value)} instead.", ) @@ -921,7 +925,7 @@ def __iadd__(self, value: Any) -> RunningStats: :return: Self reference for method chaining :raises ValueError: If value is not numeric (int or float) """ - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): raise ValueError( f"Value must be an int or float, got {type(value)} instead.", ) diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py index 64c14e94..d37daec2 100644 --- a/src/guidellm/utils/synchronous.py +++ b/src/guidellm/utils/synchronous.py @@ -16,9 +16,6 @@ from multiprocessing.synchronize import Event as ProcessingEvent from threading import Barrier as ThreadingBarrier from threading import Event as ThreadingEvent -from typing import Annotated, Union - -from typing_extensions import TypeAlias __all__ = [ "SyncObjectTypesAlias", @@ -28,10 +25,10 @@ ] -SyncObjectTypesAlias: TypeAlias = Annotated[ - Union[ThreadingEvent, ProcessingEvent, ThreadingBarrier, ProcessingBarrier], - "Type alias for threading and multiprocessing synchronization object types", -] +# Type alias for threading and multiprocessing synchronization object types +SyncObjectTypesAlias = ( + ThreadingEvent | ProcessingEvent | ThreadingBarrier | ProcessingBarrier +) async def wait_for_sync_event( @@ -146,7 +143,7 @@ async def wait_for_sync_objects( tasks = [ asyncio.create_task( wait_for_sync_barrier(obj, poll_interval) - if isinstance(obj, (ThreadingBarrier, ProcessingBarrier)) + if isinstance(obj, ThreadingBarrier | ProcessingBarrier) else wait_for_sync_event(obj, poll_interval) ) for obj in objects diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py index 8146ea1e..8d3580ef 100644 --- a/src/guidellm/utils/typing.py +++ b/src/guidellm/utils/typing.py @@ -1,14 +1,9 @@ from __future__ import annotations from collections.abc import Iterator +from types import UnionType from typing import Annotated, Literal, Union, get_args, get_origin -# Backwards compatibility for Python <3.10 -try: - from types import UnionType # type: ignore[attr-defined] -except ImportError: - UnionType = Union - # Backwards compatibility for Python <3.12 try: from 
typing import TypeAliasType # type: ignore[attr-defined] From 1bd8846a10b58b1b3fdce55925335621e32d0c00 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 9 Oct 2025 18:14:33 -0400 Subject: [PATCH 2/4] Run auto-formatter Signed-off-by: Jared O'Connell --- setup.py | 7 +- src/guidellm/backends/objects.py | 28 +++---- src/guidellm/backends/openai.py | 74 +++++++++---------- src/guidellm/benchmark/aggregator.py | 22 +++--- src/guidellm/benchmark/output.py | 16 ++-- src/guidellm/benchmark/profile.py | 4 +- src/guidellm/benchmark/scenario.py | 3 +- src/guidellm/dataset/creator.py | 38 +++++----- src/guidellm/dataset/entrypoints.py | 12 +-- src/guidellm/dataset/file.py | 16 ++-- src/guidellm/dataset/hf_datasets.py | 12 +-- src/guidellm/dataset/in_memory.py | 12 +-- src/guidellm/dataset/synthetic.py | 38 +++++----- src/guidellm/logger.py | 2 +- src/guidellm/preprocess/dataset.py | 44 +++++------ src/guidellm/presentation/data_models.py | 16 ++-- src/guidellm/presentation/injector.py | 3 +- src/guidellm/request/loader.py | 30 +++----- src/guidellm/request/request.py | 4 +- src/guidellm/scheduler/objects.py | 6 +- src/guidellm/utils/statistics.py | 3 +- tests/integration/scheduler/test_scheduler.py | 2 +- tests/unit/benchmark/test_output.py | 9 ++- tests/unit/dataset/test_synthetic.py | 2 +- tests/unit/mock_backend.py | 12 +-- tests/unit/mock_benchmark.py | 1 + tests/unit/utils/test_encoding.py | 2 +- tests/unit/utils/test_typing.py | 5 +- 28 files changed, 203 insertions(+), 220 deletions(-) diff --git a/setup.py b/setup.py index 623bad28..d3b92889 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ import os import re from pathlib import Path -from typing import Optional, Union from packaging.version import Version from setuptools import setup @@ -11,7 +10,7 @@ TAG_VERSION_PATTERN = re.compile(r"^v(\d+\.\d+\.\d+)$") -def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]: +def get_last_version_diff() -> tuple[Version, str | None, int | None]: """ Get the last version, last tag, and the number of commits since the last tag. If no tags are found, return the last release version and None for the tag/commits. @@ -38,8 +37,8 @@ def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]: def get_next_version( - build_type: str, build_iteration: Optional[Union[str, int]] -) -> tuple[Version, Optional[str], int]: + build_type: str, build_iteration: str | int | None +) -> tuple[Version, str | None, int]: """ Get the next version based on the build type and iteration. - build_type == release: take the last version and add a post if build iteration diff --git a/src/guidellm/backends/objects.py b/src/guidellm/backends/objects.py index 05280940..001aeb70 100644 --- a/src/guidellm/backends/objects.py +++ b/src/guidellm/backends/objects.py @@ -7,7 +7,7 @@ """ import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import Field @@ -73,32 +73,32 @@ class GenerationResponse(StandardBaseModel): request_args: dict[str, Any] = Field( description="Arguments passed to the backend for this request." ) - value: Optional[str] = Field( + value: str | None = Field( default=None, description="Complete generated text content. None for streaming responses.", ) - delta: Optional[str] = Field( + delta: str | None = Field( default=None, description="Incremental text content for streaming responses." ) iterations: int = Field( default=0, description="Number of generation iterations completed." 
) - request_prompt_tokens: Optional[int] = Field( + request_prompt_tokens: int | None = Field( default=None, description="Token count from the original request prompt." ) - request_output_tokens: Optional[int] = Field( + request_output_tokens: int | None = Field( default=None, description="Expected output token count from the original request.", ) - response_prompt_tokens: Optional[int] = Field( + response_prompt_tokens: int | None = Field( default=None, description="Actual prompt token count reported by the backend." ) - response_output_tokens: Optional[int] = Field( + response_output_tokens: int | None = Field( default=None, description="Actual output token count reported by the backend." ) @property - def prompt_tokens(self) -> Optional[int]: + def prompt_tokens(self) -> int | None: """ :return: The number of prompt tokens used in the request (response_prompt_tokens if available, otherwise request_prompt_tokens). @@ -106,7 +106,7 @@ def prompt_tokens(self) -> Optional[int]: return self.response_prompt_tokens or self.request_prompt_tokens @property - def output_tokens(self) -> Optional[int]: + def output_tokens(self) -> int | None: """ :return: The number of output tokens generated in the response (response_output_tokens if available, otherwise request_output_tokens). @@ -114,7 +114,7 @@ def output_tokens(self) -> Optional[int]: return self.response_output_tokens or self.request_output_tokens @property - def total_tokens(self) -> Optional[int]: + def total_tokens(self) -> int | None: """ :return: The total number of tokens used in the request and response. Sum of prompt_tokens and output_tokens. @@ -125,7 +125,7 @@ def total_tokens(self) -> Optional[int]: def preferred_prompt_tokens( self, preferred_source: Literal["request", "response"] - ) -> Optional[int]: + ) -> int | None: if preferred_source == "request": return self.request_prompt_tokens or self.response_prompt_tokens else: @@ -133,7 +133,7 @@ def preferred_prompt_tokens( def preferred_output_tokens( self, preferred_source: Literal["request", "response"] - ) -> Optional[int]: + ) -> int | None: if preferred_source == "request": return self.request_output_tokens or self.response_output_tokens else: @@ -146,11 +146,11 @@ class GenerationRequestTimings(MeasuredRequestTimings): """Timing model for tracking generation request lifecycle events.""" timings_type: Literal["generation_request_timings"] = "generation_request_timings" - first_iteration: Optional[float] = Field( + first_iteration: float | None = Field( default=None, description="Unix timestamp when the first generation iteration began.", ) - last_iteration: Optional[float] = Field( + last_iteration: float | None = Field( default=None, description="Unix timestamp when the last generation iteration completed.", ) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index fd14ee65..fd539063 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -17,7 +17,7 @@ import time from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar import httpx from PIL import Image @@ -38,8 +38,8 @@ class UsageStats: """Token usage statistics for generation requests.""" - prompt_tokens: Optional[int] = None - output_tokens: Optional[int] = None + prompt_tokens: int | None = None + output_tokens: int | None = None @Backend.register("openai_http") @@ -78,19 +78,19 @@ class OpenAIHTTPBackend(Backend): def __init__( self, target: str, - model: 
Optional[str] = None, - api_key: Optional[str] = None, - organization: Optional[str] = None, - project: Optional[str] = None, + model: str | None = None, + api_key: str | None = None, + organization: str | None = None, + project: str | None = None, timeout: float = 60.0, http2: bool = True, follow_redirects: bool = True, - max_output_tokens: Optional[int] = None, + max_output_tokens: int | None = None, stream_response: bool = True, - extra_query: Optional[dict] = None, - extra_body: Optional[dict] = None, - remove_from_body: Optional[list[str]] = None, - headers: Optional[dict] = None, + extra_query: dict | None = None, + extra_body: dict | None = None, + remove_from_body: list[str] | None = None, + headers: dict | None = None, verify: bool = False, ): """ @@ -137,7 +137,7 @@ def __init__( # Runtime state self._in_process = False - self._async_client: Optional[httpx.AsyncClient] = None + self._async_client: httpx.AsyncClient | None = None @property def info(self) -> dict[str, Any]: @@ -264,7 +264,7 @@ async def available_models(self) -> list[str]: return [item["id"] for item in response.json()["data"]] - async def default_model(self) -> Optional[str]: + async def default_model(self) -> str | None: """ Get the default model for this backend. @@ -280,7 +280,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: list[tuple[GenerationRequest, GenerationResponse]] | None = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -363,12 +363,12 @@ async def resolve( async def text_completions( self, - prompt: Union[str, list[str]], - request_id: Optional[str], # noqa: ARG002 - output_token_count: Optional[int] = None, + prompt: str | list[str], + request_id: str | None, # noqa: ARG002 + output_token_count: int | None = None, stream_response: bool = True, **kwargs, - ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: """ Generate text completions using the /v1/completions endpoint. @@ -431,17 +431,13 @@ async def text_completions( async def chat_completions( self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, # noqa: ARG002 - output_token_count: Optional[int] = None, + content: str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any, + request_id: str | None = None, # noqa: ARG002 + output_token_count: int | None = None, raw_content: bool = False, stream_response: bool = True, **kwargs, - ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: """ Generate chat completions using the /v1/chat/completions endpoint. 
@@ -502,10 +498,10 @@ async def chat_completions( def _build_headers( self, - api_key: Optional[str], - organization: Optional[str], - project: Optional[str], - user_headers: Optional[dict], + api_key: str | None, + organization: str | None, + project: str | None, + user_headers: dict | None, ) -> dict[str, str]: headers = {} @@ -541,11 +537,7 @@ def _get_params(self, endpoint_type: str) -> dict[str, str]: def _get_chat_messages( self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], + content: str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any, ) -> list[dict[str, Any]]: if isinstance(content, str): return [{"role": "user", "content": content}] @@ -567,7 +559,7 @@ def _get_chat_messages( return [{"role": "user", "content": resolved_content}] def _get_chat_message_media_item( - self, item: Union[Path, Image.Image] + self, item: Path | Image.Image ) -> dict[str, Any]: if isinstance(item, Image.Image): encoded = base64.b64encode(item.tobytes()).decode("utf-8") @@ -597,8 +589,8 @@ def _get_chat_message_media_item( def _get_body( self, endpoint_type: str, - request_kwargs: Optional[dict[str, Any]], - max_output_tokens: Optional[int] = None, + request_kwargs: dict[str, Any] | None, + max_output_tokens: int | None = None, **kwargs, ) -> dict[str, Any]: # Start with endpoint-specific extra body parameters @@ -628,7 +620,7 @@ def _get_body( return {key: val for key, val in body.items() if val is not None} - def _get_completions_text_content(self, data: dict) -> Optional[str]: + def _get_completions_text_content(self, data: dict) -> str | None: if not data.get("choices"): return None @@ -639,7 +631,7 @@ def _get_completions_text_content(self, data: dict) -> Optional[str]: or choice.get("message", {}).get("content") ) - def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]: + def _get_completions_usage_stats(self, data: dict) -> UsageStats | None: if not data.get("usage"): return None diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index be70276b..b33a7b14 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -975,7 +975,7 @@ def _calculate_requests_per_second( filtered_statuses = [] filtered_times = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined( safe_getattr(request.scheduler_info.request_timings, "request_start"), safe_getattr(request.scheduler_info.request_timings, "request_end"), @@ -1005,7 +1005,7 @@ def _calculate_request_concurrency( filtered_statuses = [] filtered_times = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined( safe_getattr(request.scheduler_info.request_timings, "request_start"), safe_getattr(request.scheduler_info.request_timings, "request_end"), @@ -1035,7 +1035,7 @@ def _calculate_request_latency( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.request_latency): continue @@ -1056,7 +1056,7 @@ def _calculate_prompt_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.prompt_tokens): continue @@ -1077,7 +1077,7 @@ def 
_calculate_output_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.output_tokens): continue @@ -1098,7 +1098,7 @@ def _calculate_total_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.total_tokens): continue @@ -1119,7 +1119,7 @@ def _calculate_time_to_first_token_ms( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.time_to_first_token_ms): continue @@ -1141,7 +1141,7 @@ def _calculate_time_per_output_token_ms( filtered_values = [] filtered_weights = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.time_to_first_token_ms): continue @@ -1174,7 +1174,7 @@ def _calculate_inter_token_latency_ms( filtered_values = [] filtered_weights = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.inter_token_latency_ms): continue @@ -1199,7 +1199,7 @@ def _calculate_output_tokens_per_second( filtered_first_iter_times = [] filtered_iter_counts = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.output_tokens_per_second): continue @@ -1234,7 +1234,7 @@ def _calculate_tokens_per_second( filtered_iter_counts = [] filtered_first_iter_counts = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.tokens_per_second): continue diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index c4e8fb0f..cacadc94 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -34,10 +34,11 @@ DistributionSummary, RegistryMixin, StatusDistributionSummary, + camelize_str, + recursive_key_update, safe_format_timestamp, split_text_list_by_length, ) -from guidellm.utils import recursive_key_update, camelize_str __all__ = [ "GenerativeBenchmarkerCSV", @@ -369,7 +370,7 @@ def _print_line( f"Value and style length mismatch: {len(value)} vs {len(style)}" ) - for val, sty in zip(value, style): + for val, sty in zip(value, style, strict=False): text.append(val, style=sty) self.console.print(Padding.indent(text, indent)) @@ -568,8 +569,8 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path: benchmark_values: list[str | float | list[float]] = [] # Add basic run description info - desc_headers, desc_values = ( - self._get_benchmark_desc_headers_and_values(benchmark) + desc_headers, desc_values = self._get_benchmark_desc_headers_and_values( + benchmark ) benchmark_headers.extend(desc_headers) benchmark_values.extend(desc_values) @@ -680,7 +681,8 @@ def _get_benchmark_status_metrics_stats( return headers, values def _get_benchmark_extras_headers_and_values( - self, benchmark: GenerativeBenchmark, + self, + benchmark: GenerativeBenchmark, ) -> tuple[list[str], list[str]]: headers = ["Profile", "Backend", "Generator Data"] values: list[str] = [ @@ -733,9 +735,7 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path: ui_api_data = {} for k, v in camel_data.items(): 
placeholder_key = f"window.{k} = {{}};" - replacement_value = ( - f"window.{k} = {json.dumps(v, indent=2)};\n" - ) + replacement_value = f"window.{k} = {json.dumps(v, indent=2)};\n" ui_api_data[placeholder_key] = replacement_value create_report(ui_api_data, output_path) diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 3ff8d0e0..ec4fa839 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -679,7 +679,9 @@ def next_strategy( prev_benchmark.metrics.requests_per_second.successful.mean ) if self.synchronous_rate <= 0 and self.throughput_rate <= 0: - raise RuntimeError("Invalid rates in sweep; aborting. Were there any successful requests?") + raise RuntimeError( + "Invalid rates in sweep; aborting. Were there any successful requests?" + ) self.measured_rates = list( np.linspace( self.synchronous_rate, diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index 5299616f..73a9a050 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from collections.abc import Callable from functools import cache, wraps from inspect import Parameter, signature from pathlib import Path -from typing import Annotated, Any, Callable, Literal, TypeVar +from typing import Annotated, Any, Literal, TypeVar import yaml from loguru import logger diff --git a/src/guidellm/dataset/creator.py b/src/guidellm/dataset/creator.py index b95f4c50..fe712c23 100644 --- a/src/guidellm/dataset/creator.py +++ b/src/guidellm/dataset/creator.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase # type: ignore[import] @@ -80,12 +80,12 @@ class DatasetCreator(ABC): def create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int = 42, - split_pref_order: Optional[list[str]] = None, - ) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]: + split_pref_order: list[str] | None = None, + ) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: if not cls.is_supported(data, data_args): raise ValueError(f"Unsupported data type: {type(data)} given for {data}. 
") @@ -106,7 +106,7 @@ def create( return dataset, column_mappings @classmethod - def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str: + def extract_args_split(cls, data_args: dict[str, Any] | None) -> str: split = "auto" if data_args and "split" in data_args: @@ -118,7 +118,7 @@ def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str: @classmethod def extract_args_column_mappings( cls, - data_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, ) -> dict[ColumnInputTypes, str]: columns: dict[ColumnInputTypes, str] = {} @@ -143,8 +143,8 @@ def extract_args_column_mappings( @classmethod def extract_dataset_name( - cls, dataset: Union[Dataset, IterableDataset, DatasetDict, IterableDatasetDict] - ) -> Optional[str]: + cls, dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict + ) -> str | None: if isinstance(dataset, DatasetDict | IterableDatasetDict): dataset = dataset[list(dataset.keys())[0]] @@ -161,10 +161,10 @@ def extract_dataset_name( @classmethod def extract_dataset_split( cls, - dataset: Union[DatasetDict, IterableDatasetDict], - specified_split: Union[Literal["auto"], str] = "auto", - split_pref_order: Optional[Union[Literal["auto"], list[str]]] = "auto", - ) -> Union[Dataset, IterableDataset]: + dataset: DatasetDict | IterableDatasetDict, + specified_split: Literal["auto"] | str = "auto", + split_pref_order: Literal["auto"] | list[str] | None = "auto", + ) -> Dataset | IterableDataset: if not isinstance(dataset, DatasetDict | IterableDatasetDict): raise ValueError( f"Unsupported data type: {type(dataset)} given for {dataset}." @@ -199,15 +199,15 @@ def extract_dataset_split( @classmethod @abstractmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: ... + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: ... @classmethod @abstractmethod def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int, - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: ... + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: ... 
diff --git a/src/guidellm/dataset/entrypoints.py b/src/guidellm/dataset/entrypoints.py index cf689956..1da2222a 100644 --- a/src/guidellm/dataset/entrypoints.py +++ b/src/guidellm/dataset/entrypoints.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import Dataset, IterableDataset from transformers import PreTrainedTokenizerBase # type: ignore[import] @@ -15,12 +15,12 @@ def load_dataset( data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int = 42, - split_pref_order: Optional[list[str]] = None, -) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]: + split_pref_order: list[str] | None = None, +) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: creators = [ InMemoryDatasetCreator, SyntheticDatasetCreator, diff --git a/src/guidellm/dataset/file.py b/src/guidellm/dataset/file.py index 455ef580..718cb46f 100644 --- a/src/guidellm/dataset/file.py +++ b/src/guidellm/dataset/file.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import pandas as pd # type: ignore[import] from datasets import ( @@ -30,7 +30,7 @@ class FileDatasetCreator(DatasetCreator): } @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 if isinstance(data, str | Path) and (path := Path(data)).exists(): # local folder or py file, assume supported return path.suffix.lower() in cls.SUPPORTED_TYPES @@ -41,11 +41,11 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: if not isinstance(data, str | Path): raise ValueError(f"Unsupported data type: {type(data)} given for {data}. 
") @@ -63,8 +63,8 @@ def handle_create( @classmethod def load_dataset( - cls, path: Path, data_args: Optional[dict[str, Any]] - ) -> Union[Dataset, IterableDataset]: + cls, path: Path, data_args: dict[str, Any] | None + ) -> Dataset | IterableDataset: if path.suffix.lower() in {".txt", ".text"}: with path.open("r") as file: items = file.readlines() diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py index 56c79936..bd8d8c23 100644 --- a/src/guidellm/dataset/hf_datasets.py +++ b/src/guidellm/dataset/hf_datasets.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import ( Dataset, @@ -18,7 +18,7 @@ class HFDatasetsCreator(DatasetCreator): @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 if isinstance( data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict) ): @@ -42,11 +42,11 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: if isinstance(data, str | Path): data = load_dataset(data, **(data_args or {})) elif data_args: diff --git a/src/guidellm/dataset/in_memory.py b/src/guidellm/dataset/in_memory.py index af84f658..0461948c 100644 --- a/src/guidellm/dataset/in_memory.py +++ b/src/guidellm/dataset/in_memory.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import ( Dataset, @@ -17,18 +17,18 @@ class InMemoryDatasetCreator(DatasetCreator): @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 return isinstance(data, Iterable) and not isinstance(data, str) @classmethod def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: if not isinstance(data, Iterable): raise TypeError( f"Unsupported data format. 
Expected Iterable[Any], got {type(data)}" diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 8c30f0f7..8a1626fe 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator from itertools import cycle from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import yaml from datasets import ( @@ -35,17 +35,17 @@ class SyntheticDatasetConfig(BaseModel): description="The average number of text tokens generated for prompts.", gt=0, ) - prompt_tokens_stdev: Optional[int] = Field( + prompt_tokens_stdev: int | None = Field( description="The standard deviation of the tokens generated for prompts.", gt=0, default=None, ) - prompt_tokens_min: Optional[int] = Field( + prompt_tokens_min: int | None = Field( description="The minimum number of text tokens generated for prompts.", gt=0, default=None, ) - prompt_tokens_max: Optional[int] = Field( + prompt_tokens_max: int | None = Field( description="The maximum number of text tokens generated for prompts.", gt=0, default=None, @@ -54,17 +54,17 @@ class SyntheticDatasetConfig(BaseModel): description="The average number of text tokens generated for outputs.", gt=0, ) - output_tokens_stdev: Optional[int] = Field( + output_tokens_stdev: int | None = Field( description="The standard deviation of the tokens generated for outputs.", gt=0, default=None, ) - output_tokens_min: Optional[int] = Field( + output_tokens_min: int | None = Field( description="The minimum number of text tokens generated for outputs.", gt=0, default=None, ) - output_tokens_max: Optional[int] = Field( + output_tokens_max: int | None = Field( description="The maximum number of text tokens generated for outputs.", gt=0, default=None, @@ -80,7 +80,7 @@ class SyntheticDatasetConfig(BaseModel): ) @staticmethod - def parse_str(data: Union[str, Path]) -> "SyntheticDatasetConfig": + def parse_str(data: str | Path) -> "SyntheticDatasetConfig": if ( isinstance(data, Path) or data.strip().endswith(".config") @@ -117,7 +117,7 @@ def parse_key_value_pairs(data: str) -> "SyntheticDatasetConfig": return SyntheticDatasetConfig(**config_dict) # type: ignore[arg-type] @staticmethod - def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig": + def parse_config_file(data: str | Path) -> "SyntheticDatasetConfig": with Path(data).open("r") as file: config_dict = yaml.safe_load(file) @@ -128,7 +128,7 @@ class SyntheticTextItemsGenerator( Iterable[ dict[ Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], + str | int, ] ] ): @@ -150,7 +150,7 @@ def __iter__( ) -> Iterator[ dict[ Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], + str | int, ] ]: prompt_tokens_sampler = IntegerRangeSampler( @@ -177,7 +177,7 @@ def __iter__( for _, prompt_tokens, output_tokens in zip( range(self.config.samples), prompt_tokens_sampler, - output_tokens_sampler, + output_tokens_sampler, strict=False, ): start_index = rand.randint(0, len(self.text_creator.words)) prompt_text = self.processor.decode( @@ -194,7 +194,7 @@ def __iter__( } def _create_prompt( - self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None + self, prompt_tokens: int, start_index: int, unique_prefix: int | None = None ) -> list[int]: if prompt_tokens <= 0: return [] @@ -224,7 +224,7 @@ class SyntheticDatasetCreator(DatasetCreator): def is_supported( cls, data: Any, - data_args: 
Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, # noqa: ARG003 ) -> bool: if ( isinstance(data, Path) @@ -248,11 +248,11 @@ def is_supported( def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int, - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: processor = check_load_processor( processor, processor_args, @@ -270,7 +270,7 @@ def handle_create( @classmethod def extract_args_column_mappings( cls, - data_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, ) -> dict[ColumnInputTypes, str]: data_args_columns = super().extract_args_column_mappings(data_args) diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index 70259bad..da3464f9 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -72,7 +72,7 @@ def configure_logger(config: LoggingSettings = settings.logging): sys.stdout, level=config.console_log_level.upper(), format="{time:YY-MM-DD HH:mm:ss}|{level: <8} \ - |{name}:{function}:{line} - {message}" + |{name}:{function}:{line} - {message}", ) if config.log_file or config.log_file_level: diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py index a94b8a14..b02efec5 100644 --- a/src/guidellm/preprocess/dataset.py +++ b/src/guidellm/preprocess/dataset.py @@ -1,9 +1,9 @@ import json import os -from collections.abc import Iterator +from collections.abc import Callable, Iterator from enum import Enum from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any import yaml from datasets import Dataset @@ -32,7 +32,7 @@ def handle_ignore_strategy( min_prompt_tokens: int, tokenizer: PreTrainedTokenizerBase, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Ignores prompts that are shorter than the required minimum token length. @@ -56,7 +56,7 @@ def handle_concatenate_strategy( tokenizer: PreTrainedTokenizerBase, concat_delimiter: str, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Concatenates prompts until the minimum token requirement is met. @@ -117,7 +117,7 @@ def handle_error_strategy( min_prompt_tokens: int, tokenizer: PreTrainedTokenizerBase, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Raises an error if the prompt is too short. @@ -150,24 +150,24 @@ class TokensConfig(BaseModel): description="The average number of tokens.", gt=0, ) - stdev: Optional[int] = Field( + stdev: int | None = Field( description="The standard deviation of the tokens.", gt=0, default=None, ) - min: Optional[int] = Field( + min: int | None = Field( description="The minimum number of tokens.", gt=0, default=None, ) - max: Optional[int] = Field( + max: int | None = Field( description="The maximum number of tokens.", gt=0, default=None, ) @staticmethod - def parse_str(data: Union[str, Path]) -> "TokensConfig": + def parse_str(data: str | Path) -> "TokensConfig": """ Parses a string or path into a TokensConfig object. 
Supports: - JSON string @@ -215,14 +215,14 @@ def parse_key_value_pairs(data: str) -> "TokensConfig": return TokensConfig(**config_dict) # type: ignore[arg-type] @staticmethod - def parse_config_file(data: Union[str, Path]) -> "TokensConfig": + def parse_config_file(data: str | Path) -> "TokensConfig": with Path(data).open("r") as file: config_dict = yaml.safe_load(file) return TokensConfig(**config_dict) -def _validate_output_suffix(output_path: Union[str, Path]) -> None: +def _validate_output_suffix(output_path: str | Path) -> None: output_path = Path(output_path) suffix = output_path.suffix.lower() if suffix not in SUPPORTED_TYPES: @@ -233,18 +233,18 @@ def _validate_output_suffix(output_path: Union[str, Path]) -> None: def process_dataset( - data: Union[str, Path], - output_path: Union[str, Path], - processor: Union[str, Path, PreTrainedTokenizerBase], - prompt_tokens: Union[str, Path], - output_tokens: Union[str, Path], - processor_args: Optional[dict[str, Any]] = None, - data_args: Optional[dict[str, Any]] = None, + data: str | Path, + output_path: str | Path, + processor: str | Path | PreTrainedTokenizerBase, + prompt_tokens: str | Path, + output_tokens: str | Path, + processor_args: dict[str, Any] | None = None, + data_args: dict[str, Any] | None = None, short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE, - pad_char: Optional[str] = None, - concat_delimiter: Optional[str] = None, + pad_char: str | None = None, + concat_delimiter: str | None = None, push_to_hub: bool = False, - hub_dataset_id: Optional[str] = None, + hub_dataset_id: str | None = None, random_seed: int = 42, ) -> None: """ @@ -354,7 +354,7 @@ def process_dataset( def push_dataset_to_hub( - hub_dataset_id: Optional[str], + hub_dataset_id: str | None, processed_dataset: Dataset, ) -> None: """ diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py index c1e8f13f..ff2863b4 100644 --- a/src/guidellm/presentation/data_models.py +++ b/src/guidellm/presentation/data_models.py @@ -1,7 +1,7 @@ import random from collections import defaultdict from math import ceil -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pydantic import BaseModel, computed_field @@ -12,14 +12,14 @@ class Bucket(BaseModel): - value: Union[float, int] + value: float | int count: int @staticmethod def from_data( - data: Union[list[float], list[int]], - bucket_width: Optional[float] = None, - n_buckets: Optional[int] = None, + data: list[float] | list[int], + bucket_width: float | None = None, + n_buckets: int | None = None, ) -> tuple[list["Bucket"], float]: if not data: return [], 1.0 @@ -35,7 +35,7 @@ def from_data( else: n_buckets = ceil(range_v / bucket_width) - bucket_counts: defaultdict[Union[float, int], int] = defaultdict(int) + bucket_counts: defaultdict[float | int, int] = defaultdict(int) for val in data: idx = int((val - min_v) // bucket_width) if idx >= n_buckets: @@ -80,7 +80,7 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]): class Distribution(BaseModel): - statistics: Optional[DistributionSummary] = None + statistics: DistributionSummary | None = None buckets: list[Bucket] bucket_width: float @@ -190,7 +190,7 @@ class TabularDistributionSummary(DistributionSummary): """ @computed_field - def percentile_rows(self) -> list[dict[str, Union[str, float]]]: + def percentile_rows(self) -> list[dict[str, str | float]]: rows = [ {"percentile": name, "value": value} for name, value in self.percentiles.model_dump().items() 
diff --git a/src/guidellm/presentation/injector.py b/src/guidellm/presentation/injector.py index bb1fd684..1e78080e 100644 --- a/src/guidellm/presentation/injector.py +++ b/src/guidellm/presentation/injector.py @@ -1,6 +1,5 @@ import re from pathlib import Path -from typing import Union from loguru import logger @@ -8,7 +7,7 @@ from guidellm.utils.text import load_text -def create_report(js_data: dict, output_path: Union[str, Path]) -> Path: +def create_report(js_data: dict, output_path: str | Path) -> Path: """ Creates a report from the dictionary and saves it to the output path. diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 607a7455..e4a6934e 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -4,8 +4,6 @@ from typing import ( Any, Literal, - Optional, - Union, ) from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict @@ -43,9 +41,9 @@ def description(self) -> RequestLoaderDescription: ... class GenerativeRequestLoaderDescription(RequestLoaderDescription): type_: Literal["generative_request_loader"] = "generative_request_loader" # type: ignore[assignment] data: str - data_args: Optional[dict[str, Any]] + data_args: dict[str, Any] | None processor: str - processor_args: Optional[dict[str, Any]] + processor_args: dict[str, Any] | None class GenerativeRequestLoader(RequestLoader): @@ -69,18 +67,10 @@ class GenerativeRequestLoader(RequestLoader): def __init__( self, - data: Union[ - str, - Path, - Iterable[Union[str, dict[str, Any]]], - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, - ], - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data: str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict | IterableDataset | IterableDatasetDict, + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, shuffle: bool = True, iter_type: Literal["finite", "infinite"] = "finite", random_seed: int = 42, @@ -202,7 +192,7 @@ def _extract_text_column(self) -> str: "'data_args' dictionary." 
) - def _extract_prompt_tokens_count_column(self) -> Optional[str]: + def _extract_prompt_tokens_count_column(self) -> str | None: column_names = self._dataset_columns() if column_names and "prompt_tokens_count" in column_names: @@ -213,7 +203,7 @@ def _extract_prompt_tokens_count_column(self) -> Optional[str]: return None - def _extract_output_tokens_count_column(self) -> Optional[str]: + def _extract_output_tokens_count_column(self) -> str | None: column_names = self._dataset_columns() if column_names and "output_tokens_count" in column_names: @@ -224,7 +214,7 @@ def _extract_output_tokens_count_column(self) -> Optional[str]: return None - def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]]: + def _dataset_columns(self, err_msg: str | None = None) -> list[str] | None: try: column_names = self.dataset.column_names @@ -240,7 +230,7 @@ def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]] def _get_dataset_iter( self, scope_create_count: int - ) -> Optional[Iterator[dict[str, Any]]]: + ) -> Iterator[dict[str, Any]] | None: if scope_create_count > 0 and self.iter_type != "infinite": return None diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index bf4e59fb..83dc40f1 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import Field @@ -33,7 +33,7 @@ class GenerationRequest(StandardBaseModel): of output tokens. Used for controlling the behavior of the backend. """ - request_id: Optional[str] = Field( + request_id: str | None = Field( default_factory=lambda: str(uuid.uuid4()), description="The unique identifier for the request.", ) diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index 21d30ec8..e2583987 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -19,7 +19,6 @@ Literal, Protocol, TypeVar, - Union, runtime_checkable, ) @@ -56,10 +55,7 @@ MultiTurnRequestT = TypeAliasType( "MultiTurnRequestT", - Union[ - list[Union[RequestT, tuple[RequestT, float]]], - tuple[Union[RequestT, tuple[RequestT, float]]], - ], + list[RequestT | tuple[RequestT, float]] | tuple[RequestT | tuple[RequestT, float]], type_params=(RequestT,), ) """Multi-turn request structure supporting conversation history with optional delays.""" diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py index 04484c2c..f71a2c24 100644 --- a/src/guidellm/utils/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -389,8 +389,7 @@ def from_iterable_request_times( events[global_end] = 0 for (_, end), first_iter, first_iter_count, total_count in zip( - requests, first_iter_times, first_iter_counts, iter_counts, - strict=True + requests, first_iter_times, first_iter_counts, iter_counts, strict=True ): events[first_iter] += first_iter_count diff --git a/tests/integration/scheduler/test_scheduler.py b/tests/integration/scheduler/test_scheduler.py index 51abf59b..65bff95f 100644 --- a/tests/integration/scheduler/test_scheduler.py +++ b/tests/integration/scheduler/test_scheduler.py @@ -167,7 +167,7 @@ def _request_indices(): _request_indices(), received_updates.keys(), received_updates.values(), - received_responses, + received_responses, strict=False, ): assert req == f"req_{index}" assert resp in (f"response_for_{req}", f"mock_error_for_{req}") diff --git a/tests/unit/benchmark/test_output.py 
b/tests/unit/benchmark/test_output.py index 6763d978..67e65e2e 100644 --- a/tests/unit/benchmark/test_output.py +++ b/tests/unit/benchmark/test_output.py @@ -10,7 +10,10 @@ from guidellm.benchmark import ( GenerativeBenchmarksReport, ) -from guidellm.benchmark.output import GenerativeBenchmarkerConsole, GenerativeBenchmarkerCSV +from guidellm.benchmark.output import ( + GenerativeBenchmarkerConsole, + GenerativeBenchmarkerCSV, +) from tests.unit.mock_benchmark import mock_generative_benchmark @@ -80,6 +83,7 @@ def test_file_yaml(): mock_path.unlink() + @pytest.mark.asyncio async def test_file_csv(): mock_benchmark = mock_generative_benchmark() @@ -105,7 +109,8 @@ def test_console_benchmarks_profile_str(): console = GenerativeBenchmarkerConsole() mock_benchmark = mock_generative_benchmark() assert ( - console._get_profile_str(mock_benchmark) == "type=synchronous, strategies=['synchronous']" + console._get_profile_str(mock_benchmark) + == "type=synchronous, strategies=['synchronous']" ) diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py index e3110fa3..544634c8 100644 --- a/tests/unit/dataset/test_synthetic.py +++ b/tests/unit/dataset/test_synthetic.py @@ -530,7 +530,7 @@ def mock_sampler_side_effect(*args, **kwargs): # Results should be identical with same seed assert len(items1) == len(items2) - for item1, item2 in zip(items1, items2): + for item1, item2 in zip(items1, items2, strict=False): assert item1["prompt"] == item2["prompt"] assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"] assert item1["output_tokens_count"] == item2["output_tokens_count"] diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 5ac069a8..3b7237e0 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -6,7 +6,7 @@ import random import time from collections.abc import AsyncIterator -from typing import Any, Optional +from typing import Any from lorem.text import TextLorem @@ -32,7 +32,7 @@ def __init__( self, target: str = "mock-target", model: str = "mock-model", - iter_delay: Optional[float] = None, + iter_delay: float | None = None, ): """ Initialize mock backend. @@ -53,7 +53,7 @@ def target(self) -> str: return self._target @property - def model(self) -> Optional[str]: + def model(self) -> str | None: """Model name for the mock backend.""" return self._model @@ -87,7 +87,7 @@ async def validate(self) -> None: if not self._in_process: raise RuntimeError("Backend not started up for process") - async def default_model(self) -> Optional[str]: + async def default_model(self) -> str | None: """ Return the default model for the mock backend. """ @@ -97,7 +97,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: list[tuple[GenerationRequest, GenerationResponse]] | None = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -170,7 +170,7 @@ def _estimate_prompt_tokens(content: str) -> int: return len(str(content).split()) @staticmethod - def _get_tokens(token_count: Optional[int] = None) -> list[str]: + def _get_tokens(token_count: int | None = None) -> list[str]: """ Generate mock tokens for response. 
""" diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index cdf4375a..d7bfe7c9 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -1,4 +1,5 @@ """Mock benchmark objects for unit testing.""" + from guidellm.backends import GenerationRequestTimings from guidellm.benchmark import ( BenchmarkSchedulerStats, diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index cc4600cf..5664bcb0 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -476,7 +476,7 @@ def test_to_from_sequence_collections(self, collection): seq = inst.to_sequence(collection) out = inst.from_sequence(seq) assert len(out) == len(collection) - assert all(a == b for a, b in zip(out, list(collection))) + assert all(a == b for a, b in zip(out, list(collection), strict=False)) @pytest.mark.sanity def test_to_from_sequence_mapping(self): diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py index fafa8765..009473f5 100644 --- a/tests/unit/utils/test_typing.py +++ b/tests/unit/utils/test_typing.py @@ -2,10 +2,9 @@ Test suite for the typing utilities module. """ -from typing import Annotated, Literal, Union +from typing import Annotated, Literal, TypeAlias, Union import pytest -from typing_extensions import TypeAlias from guidellm.utils.typing import get_literal_vals @@ -15,7 +14,7 @@ Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], "Valid strategy type identifiers for scheduling request patterns", ] -StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType] +StrategyProfileType: TypeAlias = LocalStrategyType | LocalProfileType class TestGetLiteralVals: From 1e8974c2f60fe340cacdb1beaef0957c46e47398 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 9 Oct 2025 20:10:42 -0400 Subject: [PATCH 3/4] Fix remaining ruff errors Signed-off-by: Jared O'Connell --- pyproject.toml | 2 +- src/guidellm/__main__.py | 7 +++---- src/guidellm/backends/openai.py | 6 ++++-- src/guidellm/benchmark/profile.py | 3 ++- src/guidellm/dataset/hf_datasets.py | 4 ++-- src/guidellm/request/loader.py | 3 ++- tests/unit/benchmark/test_output.py | 2 +- tests/unit/mock_server/test_server.py | 2 +- tests/unit/scheduler/test_objects.py | 8 ++++---- tests/unit/scheduler/test_strategies.py | 2 +- tests/unit/utils/test_synchronous.py | 2 +- tests/unit/utils/test_typing.py | 6 +++--- 12 files changed, 25 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 935587d0..f1624d3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,7 @@ ignore_missing_imports = true target-version = "py310" line-length = 88 indent-width = 4 -exclude = ["build", "dist", "env", ".venv"] +exclude = ["build", "dist", "env", ".venv*"] [tool.ruff.format] quote-style = "double" diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 0a035551..dbc8e1da 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -28,7 +28,7 @@ import asyncio import codecs from pathlib import Path -from typing import Annotated, Union +from typing import Annotated import click from pydantic import ValidationError @@ -78,9 +78,8 @@ "run", ] -STRATEGY_PROFILE_CHOICES: Annotated[ - list[str], "Available strategy and profile choices for benchmark execution types" -] = list(get_literal_vals(Union[ProfileType, StrategyType])) +# Available strategy and profile choices for benchmark execution types +STRATEGY_PROFILE_CHOICES: list[str] = list(get_literal_vals(ProfileType | 
From 1e8974c2f60fe340cacdb1beaef0957c46e47398 Mon Sep 17 00:00:00 2001
From: Jared O'Connell
Date: Thu, 9 Oct 2025 20:10:42 -0400
Subject: [PATCH 3/4] Fix remaining ruff errors

Signed-off-by: Jared O'Connell
---
 pyproject.toml                          | 2 +-
 src/guidellm/__main__.py                | 7 +++----
 src/guidellm/backends/openai.py         | 6 ++++--
 src/guidellm/benchmark/profile.py       | 3 ++-
 src/guidellm/dataset/hf_datasets.py     | 4 ++--
 src/guidellm/request/loader.py          | 3 ++-
 tests/unit/benchmark/test_output.py     | 2 +-
 tests/unit/mock_server/test_server.py   | 2 +-
 tests/unit/scheduler/test_objects.py    | 8 ++++----
 tests/unit/scheduler/test_strategies.py | 2 +-
 tests/unit/utils/test_synchronous.py    | 2 +-
 tests/unit/utils/test_typing.py         | 6 +++---
 12 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 935587d0..f1624d3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -167,7 +167,7 @@ ignore_missing_imports = true
 target-version = "py310"
 line-length = 88
 indent-width = 4
-exclude = ["build", "dist", "env", ".venv"]
+exclude = ["build", "dist", "env", ".venv*"]

 [tool.ruff.format]
 quote-style = "double"
diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py
index 0a035551..dbc8e1da 100644
--- a/src/guidellm/__main__.py
+++ b/src/guidellm/__main__.py
@@ -28,7 +28,7 @@
 import asyncio
 import codecs
 from pathlib import Path
-from typing import Annotated, Union
+from typing import Annotated

 import click
 from pydantic import ValidationError
@@ -78,9 +78,8 @@
     "run",
 ]

-STRATEGY_PROFILE_CHOICES: Annotated[
-    list[str], "Available strategy and profile choices for benchmark execution types"
-] = list(get_literal_vals(Union[ProfileType, StrategyType]))
+# Available strategy and profile choices for benchmark execution types
+STRATEGY_PROFILE_CHOICES: list[str] = list(get_literal_vals(ProfileType | StrategyType))


 def decode_escaped_str(_ctx, _param, value):
diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py
index fd539063..c8eb70f3 100644
--- a/src/guidellm/backends/openai.py
+++ b/src/guidellm/backends/openai.py
@@ -33,6 +33,8 @@

 __all__ = ["OpenAIHTTPBackend", "UsageStats"]

+ContentT = str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any
+

 @dataclasses.dataclass
 class UsageStats:
@@ -431,7 +433,7 @@ async def text_completions(

     async def chat_completions(
         self,
-        content: str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any,
+        content: ContentT,
         request_id: str | None = None,  # noqa: ARG002
         output_token_count: int | None = None,
         raw_content: bool = False,
@@ -537,7 +539,7 @@ def _get_params(self, endpoint_type: str) -> dict[str, str]:

     def _get_chat_messages(
         self,
-        content: str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any,
+        content: ContentT,
     ) -> list[dict[str, Any]]:
         if isinstance(content, str):
             return [{"role": "user", "content": content}]
diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py
index ec4fa839..87a9a2be 100644
--- a/src/guidellm/benchmark/profile.py
+++ b/src/guidellm/benchmark/profile.py
@@ -680,7 +680,8 @@ def next_strategy(
             )
         if self.synchronous_rate <= 0 and self.throughput_rate <= 0:
             raise RuntimeError(
-                "Invalid rates in sweep; aborting. Were there any successful requests?"
+                "Invalid rates in sweep; aborting. "
+                "Were there any successful requests?"
             )
         self.measured_rates = list(
             np.linspace(
diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py
index bd8d8c23..d1be46c1 100644
--- a/src/guidellm/dataset/hf_datasets.py
+++ b/src/guidellm/dataset/hf_datasets.py
@@ -20,7 +20,7 @@ class HFDatasetsCreator(DatasetCreator):
     @classmethod
     def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool:  # noqa: ARG003
         if isinstance(
-            data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
+            data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict
         ):
             # base type is supported
             return True
@@ -55,7 +55,7 @@ def handle_create(
         )

         if isinstance(
-            data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
+            data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict
         ):
             return data

diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py
index e4a6934e..ac34131e 100644
--- a/src/guidellm/request/loader.py
+++ b/src/guidellm/request/loader.py
@@ -67,7 +67,8 @@ class GenerativeRequestLoader(RequestLoader):

     def __init__(
         self,
-        data: str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict | IterableDataset | IterableDatasetDict,
+        data: str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict | \
+        IterableDataset | IterableDatasetDict,
         data_args: dict[str, Any] | None,
         processor: str | Path | PreTrainedTokenizerBase | None,
         processor_args: dict[str, Any] | None,
diff --git a/tests/unit/benchmark/test_output.py b/tests/unit/benchmark/test_output.py
index 67e65e2e..6310da88 100644
--- a/tests/unit/benchmark/test_output.py
+++ b/tests/unit/benchmark/test_output.py
@@ -93,7 +93,7 @@ async def test_file_csv():
     csv_benchmarker = GenerativeBenchmarkerCSV(output_path=mock_path)
     await csv_benchmarker.finalize(report)

-    with mock_path.open("r") as file:
+    with mock_path.open("r") as file:  # noqa: ASYNC230 # This is a test.
         reader = csv.reader(file)
         headers = next(reader)
         rows = list(reader)
diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py
index 008103c3..ba712fb6 100644
--- a/tests/unit/mock_server/test_server.py
+++ b/tests/unit/mock_server/test_server.py
@@ -162,7 +162,7 @@ async def test_health_endpoint(self, mock_server_instance):
         assert "status" in data
         assert data["status"] == "healthy"
         assert "timestamp" in data
-        assert isinstance(data["timestamp"], (int, float))
+        assert isinstance(data["timestamp"], int | float)

     @pytest.mark.smoke
     @pytest.mark.asyncio
diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py
index fc5610fd..2e0374e4 100644
--- a/tests/unit/scheduler/test_objects.py
+++ b/tests/unit/scheduler/test_objects.py
@@ -340,7 +340,7 @@ def test_class_signatures(self):
         for key in self.CHECK_KEYS:
             assert key in fields
             field_info = fields[key]
-            assert field_info.annotation in (Union[float, None], Optional[float])
+            assert field_info.annotation in (Union[float, None], Optional[float])  # noqa: UP007
             assert field_info.default is None

     @pytest.mark.smoke
@@ -453,7 +453,7 @@ def test_class_signatures(self):
         for key in self.CHECK_KEYS:
             assert key in fields
             field_info = fields[key]
-            assert field_info.annotation in (Union[float, None], Optional[float])
+            assert field_info.annotation in (Union[float, None], Optional[float])  # noqa: UP007
             assert field_info.default is None

     @pytest.mark.smoke
@@ -704,11 +704,11 @@ def test_marshalling(self, valid_instances):
             else:
                 assert original_value is None or isinstance(
                     original_value,
-                    (RequestSchedulerTimings, MeasuredRequestTimings),
+                    RequestSchedulerTimings | MeasuredRequestTimings,
                 )
                 assert reconstructed_value is None or isinstance(
                     reconstructed_value,
-                    (RequestSchedulerTimings, MeasuredRequestTimings),
+                    RequestSchedulerTimings | MeasuredRequestTimings,
                 )
         else:
             assert original_value == reconstructed_value
diff --git a/tests/unit/scheduler/test_strategies.py b/tests/unit/scheduler/test_strategies.py
index 67a2d77d..143a3130 100644
--- a/tests/unit/scheduler/test_strategies.py
+++ b/tests/unit/scheduler/test_strategies.py
@@ -225,7 +225,7 @@ def test_lifecycle(

         for index in range(max(5, startup_requests + 2)):
             offset = instance.next_offset()
-            assert isinstance(offset, (int, float))
+            assert isinstance(offset, int | float)

             if index < startup_requests:
                 expected_offset = initial_offset + (index + 1) * startup_delay
diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py
index 1a9ea2c9..620ba3fa 100644
--- a/tests/unit/utils/test_synchronous.py
+++ b/tests/unit/utils/test_synchronous.py
@@ -226,7 +226,7 @@ async def test_invocation(self, objects_types, expected_result):
         async def set_target():
             await asyncio.sleep(0.01)
             obj = objects[expected_result]
-            if isinstance(obj, (threading.Event, ProcessingEvent)):
+            if isinstance(obj, threading.Event | ProcessingEvent):
                 obj.set()
             else:
                 await asyncio.to_thread(obj.wait)
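The remaining test updates lean on how `|` unions behave at runtime: a union of plain classes is a `types.UnionType` instance rather than a `typing.Union`, but `typing.get_args` reads members from either spelling, so helpers like `get_literal_vals` keep working after the rewrite. A minimal illustrative sketch (not part of the patch; assumes Python >= 3.10):

```python
import threading
from types import UnionType
from typing import Literal, Union, get_args

# `A | B` between runtime classes produces a types.UnionType instance,
# which is why PATCH 4 replaces the `__origin__ is Union` checks with
# isinstance(value, UnionType).
events = threading.Event | threading.Barrier
assert isinstance(events, UnionType)
assert get_args(events) == (threading.Event, threading.Barrier)

# get_args() extracts members from both spellings, so literal values can
# still be collected after rewriting Union[...] with the `|` operator.
assert get_args(Literal["a", "b"] | Literal["c"]) == get_args(
    Union[Literal["a", "b"], Literal["c"]]
)
```
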
""" -from typing import Annotated, Literal, TypeAlias, Union +from typing import Annotated, Literal, TypeAlias import pytest @@ -53,7 +53,7 @@ def test_inline_union_type(self): ### WRITTEN BY AI ### """ - result = get_literal_vals(Union[LocalProfileType, LocalStrategyType]) + result = get_literal_vals(LocalProfileType | LocalStrategyType) expected = frozenset( { "synchronous", @@ -117,6 +117,6 @@ def test_literal_union(self): ### WRITTEN BY AI ### """ - result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]]) + result = get_literal_vals(Literal["test", "test2"] | Literal["test3"]) expected = frozenset({"test", "test2", "test3"}) assert result == expected From d0dad5aa7f752e88d250284bb48cb2f8ed78d5ff Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Fri, 10 Oct 2025 12:55:32 -0400 Subject: [PATCH 4/4] Fix unit tests Signed-off-by: Jared O'Connell --- tests/unit/scheduler/test_objects.py | 4 ++-- tests/unit/utils/test_synchronous.py | 32 +++++++++++++++++----------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py index 2e0374e4..2fc4c86f 100644 --- a/tests/unit/scheduler/test_objects.py +++ b/tests/unit/scheduler/test_objects.py @@ -3,6 +3,7 @@ import inspect import typing from collections.abc import AsyncIterator +from types import UnionType from typing import Any, Literal, Optional, TypeVar, Union import pytest @@ -62,8 +63,7 @@ def test_multi_turn_request_t(): assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" value = MultiTurnRequestT.__value__ - assert hasattr(value, "__origin__") - assert value.__origin__ is Union + assert isinstance(value, UnionType) type_params = getattr(MultiTurnRequestT, "__type_params__", ()) assert len(type_params) == 1 diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py index 620ba3fa..7acd5b4a 100644 --- a/tests/unit/utils/test_synchronous.py +++ b/tests/unit/utils/test_synchronous.py @@ -6,7 +6,7 @@ from functools import wraps from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from typing import Union +from typing import get_args import pytest @@ -32,17 +32,25 @@ async def new_func(*args, **kwargs): def test_sync_object_types_alias(): - """Test that SyncObjectTypesAlias is defined correctly as a type alias.""" - assert hasattr(SyncObjectTypesAlias, "__origin__") - if hasattr(SyncObjectTypesAlias, "__args__"): - actual_type = SyncObjectTypesAlias.__args__[0] - assert hasattr(actual_type, "__origin__") - assert actual_type.__origin__ is Union - union_args = actual_type.__args__ - assert threading.Event in union_args - assert ProcessingEvent in union_args - assert threading.Barrier in union_args - assert ProcessingBarrier in union_args + """ + Test that SyncObjectTypesAlias is defined correctly as a type alias. + + ## WRITTEN BY AI ## + """ + # Get the actual types from the union alias + actual_types = get_args(SyncObjectTypesAlias) + + # Define the set of expected types + expected_types = { + threading.Event, + ProcessingEvent, + threading.Barrier, + ProcessingBarrier, + } + + # Assert that the set of actual types matches the expected set. + # Using a set comparison is robust as it ignores the order. + assert set(actual_types) == expected_types class TestWaitForSyncEvent: