diff --git a/pyproject.toml b/pyproject.toml index 935587d0..f1624d3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,7 @@ ignore_missing_imports = true target-version = "py310" line-length = 88 indent-width = 4 -exclude = ["build", "dist", "env", ".venv"] +exclude = ["build", "dist", "env", ".venv*"] [tool.ruff.format] quote-style = "double" diff --git a/setup.py b/setup.py index 623bad28..d3b92889 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ import os import re from pathlib import Path -from typing import Optional, Union from packaging.version import Version from setuptools import setup @@ -11,7 +10,7 @@ TAG_VERSION_PATTERN = re.compile(r"^v(\d+\.\d+\.\d+)$") -def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]: +def get_last_version_diff() -> tuple[Version, str | None, int | None]: """ Get the last version, last tag, and the number of commits since the last tag. If no tags are found, return the last release version and None for the tag/commits. @@ -38,8 +37,8 @@ def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]: def get_next_version( - build_type: str, build_iteration: Optional[Union[str, int]] -) -> tuple[Version, Optional[str], int]: + build_type: str, build_iteration: str | int | None +) -> tuple[Version, str | None, int]: """ Get the next version based on the build type and iteration. - build_type == release: take the last version and add a post if build iteration diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 0a035551..dbc8e1da 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -28,7 +28,7 @@ import asyncio import codecs from pathlib import Path -from typing import Annotated, Union +from typing import Annotated import click from pydantic import ValidationError @@ -78,9 +78,8 @@ "run", ] -STRATEGY_PROFILE_CHOICES: Annotated[ - list[str], "Available strategy and profile choices for benchmark execution types" -] = list(get_literal_vals(Union[ProfileType, StrategyType])) +# Available strategy and profile choices for benchmark execution types +STRATEGY_PROFILE_CHOICES: list[str] = list(get_literal_vals(ProfileType | StrategyType)) def decode_escaped_str(_ctx, _param, value): diff --git a/src/guidellm/backends/objects.py b/src/guidellm/backends/objects.py index 05280940..001aeb70 100644 --- a/src/guidellm/backends/objects.py +++ b/src/guidellm/backends/objects.py @@ -7,7 +7,7 @@ """ import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import Field @@ -73,32 +73,32 @@ class GenerationResponse(StandardBaseModel): request_args: dict[str, Any] = Field( description="Arguments passed to the backend for this request." ) - value: Optional[str] = Field( + value: str | None = Field( default=None, description="Complete generated text content. None for streaming responses.", ) - delta: Optional[str] = Field( + delta: str | None = Field( default=None, description="Incremental text content for streaming responses." ) iterations: int = Field( default=0, description="Number of generation iterations completed." ) - request_prompt_tokens: Optional[int] = Field( + request_prompt_tokens: int | None = Field( default=None, description="Token count from the original request prompt." 
) - request_output_tokens: Optional[int] = Field( + request_output_tokens: int | None = Field( default=None, description="Expected output token count from the original request.", ) - response_prompt_tokens: Optional[int] = Field( + response_prompt_tokens: int | None = Field( default=None, description="Actual prompt token count reported by the backend." ) - response_output_tokens: Optional[int] = Field( + response_output_tokens: int | None = Field( default=None, description="Actual output token count reported by the backend." ) @property - def prompt_tokens(self) -> Optional[int]: + def prompt_tokens(self) -> int | None: """ :return: The number of prompt tokens used in the request (response_prompt_tokens if available, otherwise request_prompt_tokens). @@ -106,7 +106,7 @@ def prompt_tokens(self) -> Optional[int]: return self.response_prompt_tokens or self.request_prompt_tokens @property - def output_tokens(self) -> Optional[int]: + def output_tokens(self) -> int | None: """ :return: The number of output tokens generated in the response (response_output_tokens if available, otherwise request_output_tokens). @@ -114,7 +114,7 @@ def output_tokens(self) -> Optional[int]: return self.response_output_tokens or self.request_output_tokens @property - def total_tokens(self) -> Optional[int]: + def total_tokens(self) -> int | None: """ :return: The total number of tokens used in the request and response. Sum of prompt_tokens and output_tokens. @@ -125,7 +125,7 @@ def total_tokens(self) -> Optional[int]: def preferred_prompt_tokens( self, preferred_source: Literal["request", "response"] - ) -> Optional[int]: + ) -> int | None: if preferred_source == "request": return self.request_prompt_tokens or self.response_prompt_tokens else: @@ -133,7 +133,7 @@ def preferred_prompt_tokens( def preferred_output_tokens( self, preferred_source: Literal["request", "response"] - ) -> Optional[int]: + ) -> int | None: if preferred_source == "request": return self.request_output_tokens or self.response_output_tokens else: @@ -146,11 +146,11 @@ class GenerationRequestTimings(MeasuredRequestTimings): """Timing model for tracking generation request lifecycle events.""" timings_type: Literal["generation_request_timings"] = "generation_request_timings" - first_iteration: Optional[float] = Field( + first_iteration: float | None = Field( default=None, description="Unix timestamp when the first generation iteration began.", ) - last_iteration: Optional[float] = Field( + last_iteration: float | None = Field( default=None, description="Unix timestamp when the last generation iteration completed.", ) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index ce83076f..c8eb70f3 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -17,7 +17,7 @@ import time from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar import httpx from PIL import Image @@ -33,13 +33,15 @@ __all__ = ["OpenAIHTTPBackend", "UsageStats"] +ContentT = str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any + @dataclasses.dataclass class UsageStats: """Token usage statistics for generation requests.""" - prompt_tokens: Optional[int] = None - output_tokens: Optional[int] = None + prompt_tokens: int | None = None + output_tokens: int | None = None @Backend.register("openai_http") @@ -78,19 +80,19 @@ class OpenAIHTTPBackend(Backend): def __init__( self, target: str, - model: Optional[str] = 
None, - api_key: Optional[str] = None, - organization: Optional[str] = None, - project: Optional[str] = None, + model: str | None = None, + api_key: str | None = None, + organization: str | None = None, + project: str | None = None, timeout: float = 60.0, http2: bool = True, follow_redirects: bool = True, - max_output_tokens: Optional[int] = None, + max_output_tokens: int | None = None, stream_response: bool = True, - extra_query: Optional[dict] = None, - extra_body: Optional[dict] = None, - remove_from_body: Optional[list[str]] = None, - headers: Optional[dict] = None, + extra_query: dict | None = None, + extra_body: dict | None = None, + remove_from_body: list[str] | None = None, + headers: dict | None = None, verify: bool = False, ): """ @@ -137,7 +139,7 @@ def __init__( # Runtime state self._in_process = False - self._async_client: Optional[httpx.AsyncClient] = None + self._async_client: httpx.AsyncClient | None = None @property def info(self) -> dict[str, Any]: @@ -264,7 +266,7 @@ async def available_models(self) -> list[str]: return [item["id"] for item in response.json()["data"]] - async def default_model(self) -> Optional[str]: + async def default_model(self) -> str | None: """ Get the default model for this backend. @@ -280,7 +282,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: list[tuple[GenerationRequest, GenerationResponse]] | None = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -363,12 +365,12 @@ async def resolve( async def text_completions( self, - prompt: Union[str, list[str]], - request_id: Optional[str], # noqa: ARG002 - output_token_count: Optional[int] = None, + prompt: str | list[str], + request_id: str | None, # noqa: ARG002 + output_token_count: int | None = None, stream_response: bool = True, **kwargs, - ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: """ Generate text completions using the /v1/completions endpoint. @@ -431,17 +433,13 @@ async def text_completions( async def chat_completions( self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, # noqa: ARG002 - output_token_count: Optional[int] = None, + content: ContentT, + request_id: str | None = None, # noqa: ARG002 + output_token_count: int | None = None, raw_content: bool = False, stream_response: bool = True, **kwargs, - ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: """ Generate chat completions using the /v1/chat/completions endpoint. 
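Note: the signature rewrites in this file follow PEP 604, which requires Python 3.10+ when the unions are evaluated at runtime (consistent with the `target-version = "py310"` set in pyproject.toml above). A minimal sketch of the equivalence, using illustrative function names that are not part of guidellm:

```python
from typing import Optional, Union


def old_style(
    prompt: Union[str, list[str]], output_token_count: Optional[int] = None
) -> Optional[str]:
    """Pre-3.10 spelling using typing.Optional and typing.Union."""
    text = prompt if isinstance(prompt, str) else " ".join(prompt)
    return text if output_token_count is None else text[:output_token_count]


def new_style(
    prompt: str | list[str], output_token_count: int | None = None
) -> str | None:
    """Equivalent PEP 604 spelling; requires Python 3.10+ at runtime."""
    text = prompt if isinstance(prompt, str) else " ".join(prompt)
    return text if output_token_count is None else text[:output_token_count]
```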
@@ -502,10 +500,10 @@ async def chat_completions( def _build_headers( self, - api_key: Optional[str], - organization: Optional[str], - project: Optional[str], - user_headers: Optional[dict], + api_key: str | None, + organization: str | None, + project: str | None, + user_headers: dict | None, ) -> dict[str, str]: headers = {} @@ -541,11 +539,7 @@ def _get_params(self, endpoint_type: str) -> dict[str, str]: def _get_chat_messages( self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], + content: ContentT, ) -> list[dict[str, Any]]: if isinstance(content, str): return [{"role": "user", "content": content}] @@ -559,7 +553,7 @@ def _get_chat_messages( resolved_content.append(item) elif isinstance(item, str): resolved_content.append({"type": "text", "text": item}) - elif isinstance(item, (Image.Image, Path)): + elif isinstance(item, Image.Image | Path): resolved_content.append(self._get_chat_message_media_item(item)) else: raise ValueError(f"Unsupported content item type: {type(item)}") @@ -567,7 +561,7 @@ def _get_chat_messages( return [{"role": "user", "content": resolved_content}] def _get_chat_message_media_item( - self, item: Union[Path, Image.Image] + self, item: Path | Image.Image ) -> dict[str, Any]: if isinstance(item, Image.Image): encoded = base64.b64encode(item.tobytes()).decode("utf-8") @@ -597,8 +591,8 @@ def _get_chat_message_media_item( def _get_body( self, endpoint_type: str, - request_kwargs: Optional[dict[str, Any]], - max_output_tokens: Optional[int] = None, + request_kwargs: dict[str, Any] | None, + max_output_tokens: int | None = None, **kwargs, ) -> dict[str, Any]: # Start with endpoint-specific extra body parameters @@ -628,7 +622,7 @@ def _get_body( return {key: val for key, val in body.items() if val is not None} - def _get_completions_text_content(self, data: dict) -> Optional[str]: + def _get_completions_text_content(self, data: dict) -> str | None: if not data.get("choices"): return None @@ -639,7 +633,7 @@ def _get_completions_text_content(self, data: dict) -> Optional[str]: or choice.get("message", {}).get("content") ) - def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]: + def _get_completions_usage_stats(self, data: dict) -> UsageStats | None: if not data.get("usage"): return None diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index e965c482..b33a7b14 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -267,7 +267,7 @@ def resolve( resolved = {} for key, val in aggregators.items(): - if isinstance(val, (Aggregator, CompilableAggregator)): + if isinstance(val, Aggregator | CompilableAggregator): resolved[key] = val else: aggregator_class = cls.get_registered_object(key) @@ -975,7 +975,7 @@ def _calculate_requests_per_second( filtered_statuses = [] filtered_times = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined( safe_getattr(request.scheduler_info.request_timings, "request_start"), safe_getattr(request.scheduler_info.request_timings, "request_end"), @@ -1005,7 +1005,7 @@ def _calculate_request_concurrency( filtered_statuses = [] filtered_times = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined( safe_getattr(request.scheduler_info.request_timings, "request_start"), safe_getattr(request.scheduler_info.request_timings, 
"request_end"), @@ -1035,7 +1035,7 @@ def _calculate_request_latency( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.request_latency): continue @@ -1056,7 +1056,7 @@ def _calculate_prompt_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.prompt_tokens): continue @@ -1077,7 +1077,7 @@ def _calculate_output_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.output_tokens): continue @@ -1098,7 +1098,7 @@ def _calculate_total_token_count( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.total_tokens): continue @@ -1119,7 +1119,7 @@ def _calculate_time_to_first_token_ms( filtered_statuses = [] filtered_values = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.time_to_first_token_ms): continue @@ -1141,7 +1141,7 @@ def _calculate_time_per_output_token_ms( filtered_values = [] filtered_weights = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.time_to_first_token_ms): continue @@ -1174,7 +1174,7 @@ def _calculate_inter_token_latency_ms( filtered_values = [] filtered_weights = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.inter_token_latency_ms): continue @@ -1199,7 +1199,7 @@ def _calculate_output_tokens_per_second( filtered_first_iter_times = [] filtered_iter_counts = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.output_tokens_per_second): continue @@ -1234,7 +1234,7 @@ def _calculate_tokens_per_second( filtered_iter_counts = [] filtered_first_iter_counts = [] - for status, request in zip(statuses, requests): + for status, request in zip(statuses, requests, strict=False): if not all_defined(request.tokens_per_second): continue diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 5f05065a..99410e4c 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -228,12 +228,12 @@ def _combine( existing: dict[str, Any] | StandardBaseDict, addition: dict[str, Any] | StandardBaseDict, ) -> dict[str, Any] | StandardBaseDict: - if not isinstance(existing, (dict, StandardBaseDict)): + if not isinstance(existing, dict | StandardBaseDict): raise ValueError( f"Existing value {existing} (type: {type(existing).__name__}) " f"is not a valid type for merging." ) - if not isinstance(addition, (dict, StandardBaseDict)): + if not isinstance(addition, dict | StandardBaseDict): raise ValueError( f"Addition value {addition} (type: {type(addition).__name__}) " f"is not a valid type for merging." 
diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 56775dac..cacadc94 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -34,10 +34,11 @@ DistributionSummary, RegistryMixin, StatusDistributionSummary, + camelize_str, + recursive_key_update, safe_format_timestamp, split_text_list_by_length, ) -from guidellm.utils import recursive_key_update, camelize_str __all__ = [ "GenerativeBenchmarkerCSV", @@ -90,7 +91,7 @@ def resolve( if not output_formats: return {} - if isinstance(output_formats, (list, tuple)): + if isinstance(output_formats, list | tuple): # support list of output keys: ["csv", "json"] # support list of files: ["path/to/file.json", "path/to/file.csv"] formats_list = output_formats @@ -369,7 +370,7 @@ def _print_line( f"Value and style length mismatch: {len(value)} vs {len(style)}" ) - for val, sty in zip(value, style): + for val, sty in zip(value, style, strict=False): text.append(val, style=sty) self.console.print(Padding.indent(text, indent)) @@ -568,8 +569,8 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path: benchmark_values: list[str | float | list[float]] = [] # Add basic run description info - desc_headers, desc_values = ( - self._get_benchmark_desc_headers_and_values(benchmark) + desc_headers, desc_values = self._get_benchmark_desc_headers_and_values( + benchmark ) benchmark_headers.extend(desc_headers) benchmark_values.extend(desc_values) @@ -680,7 +681,8 @@ def _get_benchmark_status_metrics_stats( return headers, values def _get_benchmark_extras_headers_and_values( - self, benchmark: GenerativeBenchmark, + self, + benchmark: GenerativeBenchmark, ) -> tuple[list[str], list[str]]: headers = ["Profile", "Backend", "Generator Data"] values: list[str] = [ @@ -733,9 +735,7 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path: ui_api_data = {} for k, v in camel_data.items(): placeholder_key = f"window.{k} = {{}};" - replacement_value = ( - f"window.{k} = {json.dumps(v, indent=2)};\n" - ) + replacement_value = f"window.{k} = {json.dumps(v, indent=2)};\n" ui_api_data[placeholder_key] = replacement_value create_report(ui_api_data, output_path) diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 3ff8d0e0..87a9a2be 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -679,7 +679,10 @@ def next_strategy( prev_benchmark.metrics.requests_per_second.successful.mean ) if self.synchronous_rate <= 0 and self.throughput_rate <= 0: - raise RuntimeError("Invalid rates in sweep; aborting. Were there any successful requests?") + raise RuntimeError( + "Invalid rates in sweep; aborting. " + "Were there any successful requests?" + ) self.measured_rates = list( np.linspace( self.synchronous_rate, diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index b53ef424..73a9a050 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from collections.abc import Callable from functools import cache, wraps from inspect import Parameter, signature from pathlib import Path -from typing import Annotated, Any, Callable, Literal, TypeVar +from typing import Annotated, Any, Literal, TypeVar import yaml from loguru import logger @@ -38,7 +39,7 @@ def parse_float_list(value: str | float | list[float]) -> list[float]: or convert single float list of one or pass float list through. 
""" - if isinstance(value, (int, float)): + if isinstance(value, int | float): return [value] elif isinstance(value, list): return value diff --git a/src/guidellm/dataset/creator.py b/src/guidellm/dataset/creator.py index a74ec8c0..fe712c23 100644 --- a/src/guidellm/dataset/creator.py +++ b/src/guidellm/dataset/creator.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase # type: ignore[import] @@ -80,12 +80,12 @@ class DatasetCreator(ABC): def create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int = 42, - split_pref_order: Optional[list[str]] = None, - ) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]: + split_pref_order: list[str] | None = None, + ) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: if not cls.is_supported(data, data_args): raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ") @@ -95,10 +95,10 @@ def create( data, data_args, processor, processor_args, random_seed ) - if isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if isinstance(dataset, DatasetDict | IterableDatasetDict): dataset = cls.extract_dataset_split(dataset, split, split_pref_order) - if not isinstance(dataset, (Dataset, IterableDataset)): + if not isinstance(dataset, Dataset | IterableDataset): raise ValueError( f"Unsupported data type: {type(dataset)} given for {dataset}." 
) @@ -106,7 +106,7 @@ def create( return dataset, column_mappings @classmethod - def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str: + def extract_args_split(cls, data_args: dict[str, Any] | None) -> str: split = "auto" if data_args and "split" in data_args: @@ -118,7 +118,7 @@ def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str: @classmethod def extract_args_column_mappings( cls, - data_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, ) -> dict[ColumnInputTypes, str]: columns: dict[ColumnInputTypes, str] = {} @@ -143,12 +143,12 @@ def extract_args_column_mappings( @classmethod def extract_dataset_name( - cls, dataset: Union[Dataset, IterableDataset, DatasetDict, IterableDatasetDict] - ) -> Optional[str]: - if isinstance(dataset, (DatasetDict, IterableDatasetDict)): + cls, dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict + ) -> str | None: + if isinstance(dataset, DatasetDict | IterableDatasetDict): dataset = dataset[list(dataset.keys())[0]] - if isinstance(dataset, (Dataset, IterableDataset)): + if isinstance(dataset, Dataset | IterableDataset): if not hasattr(dataset, "info") or not hasattr( dataset.info, "dataset_name" ): @@ -161,11 +161,11 @@ def extract_dataset_name( @classmethod def extract_dataset_split( cls, - dataset: Union[DatasetDict, IterableDatasetDict], - specified_split: Union[Literal["auto"], str] = "auto", - split_pref_order: Optional[Union[Literal["auto"], list[str]]] = "auto", - ) -> Union[Dataset, IterableDataset]: - if not isinstance(dataset, (DatasetDict, IterableDatasetDict)): + dataset: DatasetDict | IterableDatasetDict, + specified_split: Literal["auto"] | str = "auto", + split_pref_order: Literal["auto"] | list[str] | None = "auto", + ) -> Dataset | IterableDataset: + if not isinstance(dataset, DatasetDict | IterableDatasetDict): raise ValueError( f"Unsupported data type: {type(dataset)} given for {dataset}." ) @@ -199,15 +199,15 @@ def extract_dataset_split( @classmethod @abstractmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: ... + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: ... @classmethod @abstractmethod def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int, - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: ... + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: ... 
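Note: on Python 3.10+ `isinstance()` accepts `X | Y` unions directly, so the tuple-of-types checks and the union checks above are interchangeable. A quick sketch:

```python
from pathlib import Path

data = Path("dataset.json")

# Equivalent checks; the patch standardizes on the union form.
assert isinstance(data, (str, Path)) == isinstance(data, str | Path)
assert not isinstance(data, int | float)
```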
diff --git a/src/guidellm/dataset/entrypoints.py b/src/guidellm/dataset/entrypoints.py index cf689956..1da2222a 100644 --- a/src/guidellm/dataset/entrypoints.py +++ b/src/guidellm/dataset/entrypoints.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import Dataset, IterableDataset from transformers import PreTrainedTokenizerBase # type: ignore[import] @@ -15,12 +15,12 @@ def load_dataset( data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int = 42, - split_pref_order: Optional[list[str]] = None, -) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]: + split_pref_order: list[str] | None = None, +) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: creators = [ InMemoryDatasetCreator, SyntheticDatasetCreator, diff --git a/src/guidellm/dataset/file.py b/src/guidellm/dataset/file.py index 5d6df1d9..718cb46f 100644 --- a/src/guidellm/dataset/file.py +++ b/src/guidellm/dataset/file.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import pandas as pd # type: ignore[import] from datasets import ( @@ -30,8 +30,8 @@ class FileDatasetCreator(DatasetCreator): } @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 - if isinstance(data, (str, Path)) and (path := Path(data)).exists(): + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 + if isinstance(data, str | Path) and (path := Path(data)).exists(): # local folder or py file, assume supported return path.suffix.lower() in cls.SUPPORTED_TYPES @@ -41,12 +41,12 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: - if not isinstance(data, (str, Path)): + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: + if not isinstance(data, str | Path): raise ValueError(f"Unsupported data type: {type(data)} given for {data}. 
") path = Path(data) @@ -63,8 +63,8 @@ def handle_create( @classmethod def load_dataset( - cls, path: Path, data_args: Optional[dict[str, Any]] - ) -> Union[Dataset, IterableDataset]: + cls, path: Path, data_args: dict[str, Any] | None + ) -> Dataset | IterableDataset: if path.suffix.lower() in {".txt", ".text"}: with path.open("r") as file: items = file.readlines() diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py index 7f91facd..d1be46c1 100644 --- a/src/guidellm/dataset/hf_datasets.py +++ b/src/guidellm/dataset/hf_datasets.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import ( Dataset, @@ -18,18 +18,18 @@ class HFDatasetsCreator(DatasetCreator): @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 if isinstance( - data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict) + data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict ): # base type is supported return True - if isinstance(data, (str, Path)) and (path := Path(data)).exists(): + if isinstance(data, str | Path) and (path := Path(data)).exists(): # local folder or py file, assume supported return path.is_dir() or path.suffix == ".py" - if isinstance(data, (str, Path)): + if isinstance(data, str | Path): try: # try to load dataset return get_dataset_config_info(data) is not None @@ -42,12 +42,12 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: - if isinstance(data, (str, Path)): + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: + if isinstance(data, str | Path): data = load_dataset(data, **(data_args or {})) elif data_args: raise ValueError( @@ -55,7 +55,7 @@ def handle_create( ) if isinstance( - data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict) + data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict ): return data diff --git a/src/guidellm/dataset/in_memory.py b/src/guidellm/dataset/in_memory.py index af84f658..0461948c 100644 --- a/src/guidellm/dataset/in_memory.py +++ b/src/guidellm/dataset/in_memory.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from datasets import ( Dataset, @@ -17,18 +17,18 @@ class InMemoryDatasetCreator(DatasetCreator): @classmethod - def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003 + def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 return isinstance(data, Iterable) and not isinstance(data, str) @classmethod def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003 - processor_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, + processor: str | Path | 
PreTrainedTokenizerBase | None, # noqa: ARG003 + processor_args: dict[str, Any] | None, # noqa: ARG003 random_seed: int, # noqa: ARG003 - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: if not isinstance(data, Iterable): raise TypeError( f"Unsupported data format. Expected Iterable[Any], got {type(data)}" diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 8c30f0f7..8a1626fe 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator from itertools import cycle from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import yaml from datasets import ( @@ -35,17 +35,17 @@ class SyntheticDatasetConfig(BaseModel): description="The average number of text tokens generated for prompts.", gt=0, ) - prompt_tokens_stdev: Optional[int] = Field( + prompt_tokens_stdev: int | None = Field( description="The standard deviation of the tokens generated for prompts.", gt=0, default=None, ) - prompt_tokens_min: Optional[int] = Field( + prompt_tokens_min: int | None = Field( description="The minimum number of text tokens generated for prompts.", gt=0, default=None, ) - prompt_tokens_max: Optional[int] = Field( + prompt_tokens_max: int | None = Field( description="The maximum number of text tokens generated for prompts.", gt=0, default=None, @@ -54,17 +54,17 @@ class SyntheticDatasetConfig(BaseModel): description="The average number of text tokens generated for outputs.", gt=0, ) - output_tokens_stdev: Optional[int] = Field( + output_tokens_stdev: int | None = Field( description="The standard deviation of the tokens generated for outputs.", gt=0, default=None, ) - output_tokens_min: Optional[int] = Field( + output_tokens_min: int | None = Field( description="The minimum number of text tokens generated for outputs.", gt=0, default=None, ) - output_tokens_max: Optional[int] = Field( + output_tokens_max: int | None = Field( description="The maximum number of text tokens generated for outputs.", gt=0, default=None, @@ -80,7 +80,7 @@ class SyntheticDatasetConfig(BaseModel): ) @staticmethod - def parse_str(data: Union[str, Path]) -> "SyntheticDatasetConfig": + def parse_str(data: str | Path) -> "SyntheticDatasetConfig": if ( isinstance(data, Path) or data.strip().endswith(".config") @@ -117,7 +117,7 @@ def parse_key_value_pairs(data: str) -> "SyntheticDatasetConfig": return SyntheticDatasetConfig(**config_dict) # type: ignore[arg-type] @staticmethod - def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig": + def parse_config_file(data: str | Path) -> "SyntheticDatasetConfig": with Path(data).open("r") as file: config_dict = yaml.safe_load(file) @@ -128,7 +128,7 @@ class SyntheticTextItemsGenerator( Iterable[ dict[ Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], + str | int, ] ] ): @@ -150,7 +150,7 @@ def __iter__( ) -> Iterator[ dict[ Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], + str | int, ] ]: prompt_tokens_sampler = IntegerRangeSampler( @@ -177,7 +177,7 @@ def __iter__( for _, prompt_tokens, output_tokens in zip( range(self.config.samples), prompt_tokens_sampler, - output_tokens_sampler, + output_tokens_sampler, strict=False, ): start_index = rand.randint(0, len(self.text_creator.words)) prompt_text = self.processor.decode( @@ -194,7 +194,7 @@ def 
__iter__( } def _create_prompt( - self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None + self, prompt_tokens: int, start_index: int, unique_prefix: int | None = None ) -> list[int]: if prompt_tokens <= 0: return [] @@ -224,7 +224,7 @@ class SyntheticDatasetCreator(DatasetCreator): def is_supported( cls, data: Any, - data_args: Optional[dict[str, Any]], # noqa: ARG003 + data_args: dict[str, Any] | None, # noqa: ARG003 ) -> bool: if ( isinstance(data, Path) @@ -248,11 +248,11 @@ def is_supported( def handle_create( cls, data: Any, - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, random_seed: int, - ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: + ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: processor = check_load_processor( processor, processor_args, @@ -270,7 +270,7 @@ def handle_create( @classmethod def extract_args_column_mappings( cls, - data_args: Optional[dict[str, Any]], + data_args: dict[str, Any] | None, ) -> dict[ColumnInputTypes, str]: data_args_columns = super().extract_args_column_mappings(data_args) diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index 70259bad..da3464f9 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -72,7 +72,7 @@ def configure_logger(config: LoggingSettings = settings.logging): sys.stdout, level=config.console_log_level.upper(), format="{time:YY-MM-DD HH:mm:ss}|{level: <8} \ - |{name}:{function}:{line} - {message}" + |{name}:{function}:{line} - {message}", ) if config.log_file or config.log_file_level: diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py index a94b8a14..b02efec5 100644 --- a/src/guidellm/preprocess/dataset.py +++ b/src/guidellm/preprocess/dataset.py @@ -1,9 +1,9 @@ import json import os -from collections.abc import Iterator +from collections.abc import Callable, Iterator from enum import Enum from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any import yaml from datasets import Dataset @@ -32,7 +32,7 @@ def handle_ignore_strategy( min_prompt_tokens: int, tokenizer: PreTrainedTokenizerBase, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Ignores prompts that are shorter than the required minimum token length. @@ -56,7 +56,7 @@ def handle_concatenate_strategy( tokenizer: PreTrainedTokenizerBase, concat_delimiter: str, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Concatenates prompts until the minimum token requirement is met. @@ -117,7 +117,7 @@ def handle_error_strategy( min_prompt_tokens: int, tokenizer: PreTrainedTokenizerBase, **_kwargs, -) -> Optional[str]: +) -> str | None: """ Raises an error if the prompt is too short. 
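Note: Pydantic v2 resolves an `int | None` annotation the same way as `Optional[int]`, so the `Field` declarations in the config models below keep their defaults and constraints unchanged. A minimal sketch using a hypothetical stand-in model, not the real guidellm class:

```python
from pydantic import BaseModel, Field


class TokensConfigSketch(BaseModel):
    """Hypothetical stand-in for TokensConfig, reduced for illustration."""

    average: int = Field(description="The average number of tokens.", gt=0)
    stdev: int | None = Field(
        description="The standard deviation of the tokens.", gt=0, default=None
    )


print(TokensConfigSketch(average=128))            # stdev defaults to None
print(TokensConfigSketch(average=128, stdev=16))  # gt=0 is still enforced for ints
```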
@@ -150,24 +150,24 @@ class TokensConfig(BaseModel): description="The average number of tokens.", gt=0, ) - stdev: Optional[int] = Field( + stdev: int | None = Field( description="The standard deviation of the tokens.", gt=0, default=None, ) - min: Optional[int] = Field( + min: int | None = Field( description="The minimum number of tokens.", gt=0, default=None, ) - max: Optional[int] = Field( + max: int | None = Field( description="The maximum number of tokens.", gt=0, default=None, ) @staticmethod - def parse_str(data: Union[str, Path]) -> "TokensConfig": + def parse_str(data: str | Path) -> "TokensConfig": """ Parses a string or path into a TokensConfig object. Supports: - JSON string @@ -215,14 +215,14 @@ def parse_key_value_pairs(data: str) -> "TokensConfig": return TokensConfig(**config_dict) # type: ignore[arg-type] @staticmethod - def parse_config_file(data: Union[str, Path]) -> "TokensConfig": + def parse_config_file(data: str | Path) -> "TokensConfig": with Path(data).open("r") as file: config_dict = yaml.safe_load(file) return TokensConfig(**config_dict) -def _validate_output_suffix(output_path: Union[str, Path]) -> None: +def _validate_output_suffix(output_path: str | Path) -> None: output_path = Path(output_path) suffix = output_path.suffix.lower() if suffix not in SUPPORTED_TYPES: @@ -233,18 +233,18 @@ def _validate_output_suffix(output_path: Union[str, Path]) -> None: def process_dataset( - data: Union[str, Path], - output_path: Union[str, Path], - processor: Union[str, Path, PreTrainedTokenizerBase], - prompt_tokens: Union[str, Path], - output_tokens: Union[str, Path], - processor_args: Optional[dict[str, Any]] = None, - data_args: Optional[dict[str, Any]] = None, + data: str | Path, + output_path: str | Path, + processor: str | Path | PreTrainedTokenizerBase, + prompt_tokens: str | Path, + output_tokens: str | Path, + processor_args: dict[str, Any] | None = None, + data_args: dict[str, Any] | None = None, short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE, - pad_char: Optional[str] = None, - concat_delimiter: Optional[str] = None, + pad_char: str | None = None, + concat_delimiter: str | None = None, push_to_hub: bool = False, - hub_dataset_id: Optional[str] = None, + hub_dataset_id: str | None = None, random_seed: int = 42, ) -> None: """ @@ -354,7 +354,7 @@ def process_dataset( def push_dataset_to_hub( - hub_dataset_id: Optional[str], + hub_dataset_id: str | None, processed_dataset: Dataset, ) -> None: """ diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py index c1e8f13f..ff2863b4 100644 --- a/src/guidellm/presentation/data_models.py +++ b/src/guidellm/presentation/data_models.py @@ -1,7 +1,7 @@ import random from collections import defaultdict from math import ceil -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pydantic import BaseModel, computed_field @@ -12,14 +12,14 @@ class Bucket(BaseModel): - value: Union[float, int] + value: float | int count: int @staticmethod def from_data( - data: Union[list[float], list[int]], - bucket_width: Optional[float] = None, - n_buckets: Optional[int] = None, + data: list[float] | list[int], + bucket_width: float | None = None, + n_buckets: int | None = None, ) -> tuple[list["Bucket"], float]: if not data: return [], 1.0 @@ -35,7 +35,7 @@ def from_data( else: n_buckets = ceil(range_v / bucket_width) - bucket_counts: defaultdict[Union[float, int], int] = defaultdict(int) + bucket_counts: defaultdict[float | int, int] = 
defaultdict(int) for val in data: idx = int((val - min_v) // bucket_width) if idx >= n_buckets: @@ -80,7 +80,7 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]): class Distribution(BaseModel): - statistics: Optional[DistributionSummary] = None + statistics: DistributionSummary | None = None buckets: list[Bucket] bucket_width: float @@ -190,7 +190,7 @@ class TabularDistributionSummary(DistributionSummary): """ @computed_field - def percentile_rows(self) -> list[dict[str, Union[str, float]]]: + def percentile_rows(self) -> list[dict[str, str | float]]: rows = [ {"percentile": name, "value": value} for name, value in self.percentiles.model_dump().items() diff --git a/src/guidellm/presentation/injector.py b/src/guidellm/presentation/injector.py index bb1fd684..1e78080e 100644 --- a/src/guidellm/presentation/injector.py +++ b/src/guidellm/presentation/injector.py @@ -1,6 +1,5 @@ import re from pathlib import Path -from typing import Union from loguru import logger @@ -8,7 +7,7 @@ from guidellm.utils.text import load_text -def create_report(js_data: dict, output_path: Union[str, Path]) -> Path: +def create_report(js_data: dict, output_path: str | Path) -> Path: """ Creates a report from the dictionary and saves it to the output path. diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 607a7455..ac34131e 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -4,8 +4,6 @@ from typing import ( Any, Literal, - Optional, - Union, ) from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict @@ -43,9 +41,9 @@ def description(self) -> RequestLoaderDescription: ... class GenerativeRequestLoaderDescription(RequestLoaderDescription): type_: Literal["generative_request_loader"] = "generative_request_loader" # type: ignore[assignment] data: str - data_args: Optional[dict[str, Any]] + data_args: dict[str, Any] | None processor: str - processor_args: Optional[dict[str, Any]] + processor_args: dict[str, Any] | None class GenerativeRequestLoader(RequestLoader): @@ -69,18 +67,11 @@ class GenerativeRequestLoader(RequestLoader): def __init__( self, - data: Union[ - str, - Path, - Iterable[Union[str, dict[str, Any]]], - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, - ], - data_args: Optional[dict[str, Any]], - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + data: str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict | \ + IterableDataset | IterableDatasetDict, + data_args: dict[str, Any] | None, + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, shuffle: bool = True, iter_type: Literal["finite", "infinite"] = "finite", random_seed: int = 42, @@ -202,7 +193,7 @@ def _extract_text_column(self) -> str: "'data_args' dictionary." 
) - def _extract_prompt_tokens_count_column(self) -> Optional[str]: + def _extract_prompt_tokens_count_column(self) -> str | None: column_names = self._dataset_columns() if column_names and "prompt_tokens_count" in column_names: @@ -213,7 +204,7 @@ def _extract_prompt_tokens_count_column(self) -> Optional[str]: return None - def _extract_output_tokens_count_column(self) -> Optional[str]: + def _extract_output_tokens_count_column(self) -> str | None: column_names = self._dataset_columns() if column_names and "output_tokens_count" in column_names: @@ -224,7 +215,7 @@ def _extract_output_tokens_count_column(self) -> Optional[str]: return None - def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]]: + def _dataset_columns(self, err_msg: str | None = None) -> list[str] | None: try: column_names = self.dataset.column_names @@ -240,7 +231,7 @@ def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]] def _get_dataset_iter( self, scope_create_count: int - ) -> Optional[Iterator[dict[str, Any]]]: + ) -> Iterator[dict[str, Any]] | None: if scope_create_count > 0 and self.iter_type != "infinite": return None diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index bf4e59fb..83dc40f1 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import Field @@ -33,7 +33,7 @@ class GenerationRequest(StandardBaseModel): of output tokens. Used for controlling the behavior of the backend. """ - request_id: Optional[str] = Field( + request_id: str | None = Field( default_factory=lambda: str(uuid.uuid4()), description="The unique identifier for the request.", ) diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py index c724a74a..c974225a 100644 --- a/src/guidellm/scheduler/constraints.py +++ b/src/guidellm/scheduler/constraints.py @@ -450,7 +450,7 @@ def __call__( current_index = max(0, self.current_index) max_num = ( self.max_num - if isinstance(self.max_num, (int, float)) + if isinstance(self.max_num, int | float) else self.max_num[min(current_index, len(self.max_num) - 1)] ) @@ -489,7 +489,7 @@ def _validate_max_num( raise ValueError( f"max_num must be set and truthful, received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( f"max_num must be a positive num, received {value} ({val} failed)" ) @@ -568,7 +568,7 @@ def __call__( current_index = max(0, self.current_index) max_duration = ( self.max_duration - if isinstance(self.max_duration, (int, float)) + if isinstance(self.max_duration, int | float) else self.max_duration[min(current_index, len(self.max_duration) - 1)] ) @@ -607,7 +607,7 @@ def _validate_max_duration( "max_duration must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( "max_duration must be a positive num," f"received {value} ({val} failed)" @@ -682,7 +682,7 @@ def __call__( current_index = max(0, self.current_index) max_errors = ( self.max_errors - if isinstance(self.max_errors, (int, float)) + if isinstance(self.max_errors, int | float) else self.max_errors[min(current_index, len(self.max_errors) - 1)] ) errors_exceeded = state.errored_requests >= max_errors @@ -710,7 +710,7 @@ def _validate_max_errors( 
"max_errors must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0: + if not isinstance(val, int | float) or val <= 0: raise ValueError( f"max_errors must be a positive num,received {value} ({val} failed)" ) @@ -799,7 +799,7 @@ def __call__( current_index = max(0, self.current_index) max_error_rate = ( self.max_error_rate - if isinstance(self.max_error_rate, (int, float)) + if isinstance(self.max_error_rate, int | float) else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] ) @@ -850,7 +850,7 @@ def _validate_max_error_rate( "max_error_rate must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0 or val >= 1: + if not isinstance(val, int | float) or val <= 0 or val >= 1: raise ValueError( "max_error_rate must be a number between 0 and 1," f"received {value} ({val} failed)" @@ -940,7 +940,7 @@ def __call__( current_index = max(0, self.current_index) max_error_rate = ( self.max_error_rate - if isinstance(self.max_error_rate, (int, float)) + if isinstance(self.max_error_rate, int | float) else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] ) @@ -982,7 +982,7 @@ def _validate_max_error_rate( "max_error_rate must be set and truthful, " f"received {value} ({val} failed)" ) - if not isinstance(val, (int, float)) or val <= 0 or val >= 1: + if not isinstance(val, int | float) or val <= 0 or val >= 1: raise ValueError( "max_error_rate must be a number between 0 and 1," f"received {value} ({val} failed)" diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index 21d30ec8..e2583987 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -19,7 +19,6 @@ Literal, Protocol, TypeVar, - Union, runtime_checkable, ) @@ -56,10 +55,7 @@ MultiTurnRequestT = TypeAliasType( "MultiTurnRequestT", - Union[ - list[Union[RequestT, tuple[RequestT, float]]], - tuple[Union[RequestT, tuple[RequestT, float]]], - ], + list[RequestT | tuple[RequestT, float]] | tuple[RequestT | tuple[RequestT, float]], type_params=(RequestT,), ) """Multi-turn request structure supporting conversation history with optional delays.""" diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 5f2fb74b..104ab418 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -310,7 +310,7 @@ async def _process_next_request(self): # Pull request from the queue request, request_info = await self.messaging.get() - if isinstance(request, (list, tuple)): + if isinstance(request, list | tuple): raise NotImplementedError("Multi-turn requests are not yet supported") # Calculate targeted start and set pending state for request diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index 6823fb77..916d6633 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -12,7 +12,7 @@ import json from collections.abc import Mapping -from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, cast +from typing import Any, ClassVar, Generic, Literal, TypeVar, cast try: import msgpack # type: ignore[import-untyped] # Optional dependency @@ -45,7 +45,6 @@ HAS_ORJSON = False from pydantic import BaseModel -from typing_extensions import TypeAlias __all__ = [ "Encoder", @@ -60,14 +59,10 @@ ObjT = TypeVar("ObjT") MsgT = TypeVar("MsgT") -SerializationTypesAlias: TypeAlias = Annotated[ - Optional[Literal["dict", "sequence"]], - 
"Type alias for available serialization strategies", -] -EncodingTypesAlias: TypeAlias = Annotated[ - Optional[Literal["msgpack", "msgspec"]], - "Type alias for available binary encoding formats", -] +# Type alias for available serialization strategies +SerializationTypesAlias = Literal["dict", "sequence"] | None +# "Type alias for available binary encoding formats" +EncodingTypesAlias = Literal["msgpack", "msgspec"] class MessageEncoding(Generic[ObjT, MsgT]): @@ -405,7 +400,7 @@ def to_dict(self, obj: Any) -> Any: if isinstance(obj, BaseModel): return self.to_dict_pydantic(obj) - if isinstance(obj, (list, tuple)) and any( + if isinstance(obj, list | tuple) and any( isinstance(item, BaseModel) for item in obj ): return [ @@ -432,7 +427,7 @@ def from_dict(self, data: Any) -> Any: :param data: Dictionary representation possibly containing type metadata :return: Reconstructed object with proper types restored """ - if isinstance(data, (list, tuple)): + if isinstance(data, list | tuple): return [ self.from_dict_pydantic(item) if isinstance(item, dict) and "*PYD*" in item @@ -493,7 +488,7 @@ def to_sequence(self, obj: Any) -> str | Any: if isinstance(obj, BaseModel): payload_type = "pydantic" payload = self.to_sequence_pydantic(obj) - elif isinstance(obj, (list, tuple)) and any( + elif isinstance(obj, list | tuple) and any( isinstance(item, BaseModel) for item in obj ): payload_type = "collection_sequence" @@ -694,33 +689,36 @@ def pack_next_sequence( # noqa: C901, PLR0912 length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, byteorder="big", ) - if type_ == "pydantic": - payload_type = b"P" - elif type_ == "python": - payload_type = b"p" - elif type_ == "collection_tuple": - payload_type = b"T" - elif type_ == "collection_sequence": - payload_type = b"S" - elif type_ == "collection_mapping": - payload_type = b"M" - else: - raise ValueError(f"Unknown type for packing: {type_}") + match type_: + case "pydantic": + payload_type = b"P" + case "python": + payload_type = b"p" + case "collection_tuple": + payload_type = b"T" + case "collection_sequence": + payload_type = b"S" + case "collection_mapping": + payload_type = b"M" + case _: + raise ValueError(f"Unknown type for packing: {type_}") delimiter = b"|" else: payload_len_output = str(payload_len) - if type_ == "pydantic": - payload_type = "P" - elif type_ == "python": - payload_type = "p" - elif type_ == "collection_tuple": - payload_type = "T" - elif type_ == "collection_sequence": - payload_type = "S" - elif type_ == "collection_mapping": - payload_type = "M" - else: - raise ValueError(f"Unknown type for packing: {type_}") + + match type_: + case "pydantic": + payload_type = "P" + case "python": + payload_type = "p" + case "collection_tuple": + payload_type = "T" + case "collection_sequence": + payload_type = "S" + case "collection_mapping": + payload_type = "M" + case _: + raise ValueError(f"Unknown type for packing: {type_}") delimiter = "|" # Type ignores because types are enforced at runtime diff --git a/src/guidellm/utils/hf_datasets.py b/src/guidellm/utils/hf_datasets.py index 73e55ebc..86f04485 100644 --- a/src/guidellm/utils/hf_datasets.py +++ b/src/guidellm/utils/hf_datasets.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Union from datasets import Dataset @@ -11,7 +10,7 @@ } -def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None: +def save_dataset_to_file(dataset: Dataset, output_path: str | Path) -> None: """ Saves a HuggingFace Dataset to file in a supported format. 
diff --git a/src/guidellm/utils/hf_transformers.py b/src/guidellm/utils/hf_transformers.py index 1f2aa1b5..636988c3 100644 --- a/src/guidellm/utils/hf_transformers.py +++ b/src/guidellm/utils/hf_transformers.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from transformers import AutoTokenizer, PreTrainedTokenizerBase # type: ignore[import] @@ -9,15 +9,15 @@ def check_load_processor( - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], - processor_args: Optional[dict[str, Any]], + processor: str | Path | PreTrainedTokenizerBase | None, + processor_args: dict[str, Any] | None, error_msg: str, ) -> PreTrainedTokenizerBase: if processor is None: raise ValueError(f"Processor/Tokenizer is required for {error_msg}.") try: - if isinstance(processor, (str, Path)): + if isinstance(processor, str | Path): loaded = AutoTokenizer.from_pretrained( processor, **(processor_args or {}), diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py index 9311259d..4dce576d 100644 --- a/src/guidellm/utils/messaging.py +++ b/src/guidellm/utils/messaging.py @@ -16,13 +16,13 @@ import threading import time from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from multiprocessing.connection import Connection from multiprocessing.context import BaseContext from multiprocessing.managers import SyncManager from multiprocessing.synchronize import Event as ProcessingEvent from threading import Event as ThreadingEvent -from typing import Any, Callable, Generic, Protocol, TypeVar, cast +from typing import Any, Generic, Protocol, TypeVar, cast import culsans from pydantic import BaseModel @@ -420,7 +420,7 @@ def _create_check_stop_callable( stop_events = tuple( item for item in stop_criteria or [] - if isinstance(item, (ThreadingEvent, ProcessingEvent)) + if isinstance(item, ThreadingEvent | ProcessingEvent) ) stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item)) diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py index b001ff2d..7cf28d00 100644 --- a/src/guidellm/utils/mixins.py +++ b/src/guidellm/utils/mixins.py @@ -91,7 +91,7 @@ def create_info_dict(cls, obj: Any) -> dict[str, Any]: "attributes": ( { key: val - if isinstance(val, (str, int, float, bool, list, dict)) + if isinstance(val, str | int | float | bool | list | dict) else repr(val) for key, val in obj.__dict__.items() if not key.startswith("_") diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py index 55816ef1..05f5ad81 100644 --- a/src/guidellm/utils/pydantic_utils.py +++ b/src/guidellm/utils/pydantic_utils.py @@ -11,11 +11,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, TypeVar, cast +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from typing_extensions import get_args, get_origin from guidellm.utils.registry import RegistryMixin diff --git a/src/guidellm/utils/random.py b/src/guidellm/utils/random.py index ceef20b9..6c8f396d 100644 --- a/src/guidellm/utils/random.py +++ b/src/guidellm/utils/random.py @@ -1,6 +1,5 @@ import random from collections.abc import Iterator -from typing import Optional __all__ = ["IntegerRangeSampler"] @@ -9,9 +8,9 @@ class IntegerRangeSampler: def __init__( self, 
average: int, - variance: Optional[int], - min_value: Optional[int], - max_value: Optional[int], + variance: int | None, + min_value: int | None, + max_value: int | None, random_seed: int, ): self.average = average diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py index e6f1b657..e4727cbd 100644 --- a/src/guidellm/utils/registry.py +++ b/src/guidellm/utils/registry.py @@ -10,7 +10,8 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Generic, TypeVar, cast +from collections.abc import Callable +from typing import Any, ClassVar, Generic, TypeVar, cast from guidellm.utils.auto_importer import AutoImporterMixin @@ -103,7 +104,7 @@ def register_decorator( if name is None: name = obj.__name__ - elif not isinstance(name, (str, list)): + elif not isinstance(name, str | list): raise ValueError( "RegistryMixin.register_decorator name must be a string or " f"an iterable of strings. Got {name}." diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py index acd9d4f1..f71a2c24 100644 --- a/src/guidellm/utils/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -149,7 +149,7 @@ def from_distribution_function( in the output :return: DistributionSummary instance with calculated statistical metrics """ - values, weights = zip(*distribution) if distribution else ([], []) + values, weights = zip(*distribution, strict=True) if distribution else ([], []) values = np.array(values) # type: ignore[assignment] weights = np.array(weights) # type: ignore[assignment] @@ -247,7 +247,7 @@ def from_values( ) return DistributionSummary.from_distribution_function( - distribution=list(zip(values, weights)), + distribution=list(zip(values, weights, strict=True)), include_cdf=include_cdf, ) @@ -389,7 +389,7 @@ def from_iterable_request_times( events[global_end] = 0 for (_, end), first_iter, first_iter_count, total_count in zip( - requests, first_iter_times, first_iter_counts, iter_counts + requests, first_iter_times, first_iter_counts, iter_counts, strict=True ): events[first_iter] += first_iter_count @@ -499,36 +499,36 @@ def from_values( ) _, successful_values, successful_weights = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( lambda val: val[0] == "successful", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) else ([], [], []) ) _, incomplete_values, incomplete_weights = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( lambda val: val[0] == "incomplete", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) else ([], [], []) ) _, errored_values, errored_weights = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( filter( lambda val: val[0] == "error", - zip(value_types, values, weights), + zip(value_types, values, weights, strict=True), ) ) ) @@ -604,36 +604,36 @@ def from_request_times( ) _, successful_requests = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( lambda val: val[0] == "successful", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) else ([], []) ) _, incomplete_requests = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( lambda val: val[0] == "incomplete", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) else ([], []) ) _, errored_requests = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( 
filter( lambda val: val[0] == "error", - zip(request_types, requests), + zip(request_types, requests, strict=True), ) ) ) @@ -734,7 +734,7 @@ def from_iterable_request_times( successful_iter_counts, successful_first_iter_counts, ) = ( - zip(*successful) + zip(*successful, strict=True) if ( successful := list( filter( @@ -745,6 +745,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -758,7 +759,7 @@ def from_iterable_request_times( incomplete_iter_counts, incomplete_first_iter_counts, ) = ( - zip(*incomplete) + zip(*incomplete, strict=True) if ( incomplete := list( filter( @@ -769,6 +770,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -782,7 +784,7 @@ def from_iterable_request_times( errored_iter_counts, errored_first_iter_counts, ) = ( - zip(*errored) + zip(*errored, strict=True) if ( errored := list( filter( @@ -793,6 +795,7 @@ def from_iterable_request_times( first_iter_times, iter_counts, first_iter_counts, + strict=True, ), ) ) @@ -904,7 +907,7 @@ def __add__(self, value: Any) -> float: :return: Updated mean after adding the value :raises ValueError: If value is not numeric (int or float) """ - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): raise ValueError( f"Value must be an int or float, got {type(value)} instead.", ) @@ -921,7 +924,7 @@ def __iadd__(self, value: Any) -> RunningStats: :return: Self reference for method chaining :raises ValueError: If value is not numeric (int or float) """ - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): raise ValueError( f"Value must be an int or float, got {type(value)} instead.", ) diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py index 64c14e94..d37daec2 100644 --- a/src/guidellm/utils/synchronous.py +++ b/src/guidellm/utils/synchronous.py @@ -16,9 +16,6 @@ from multiprocessing.synchronize import Event as ProcessingEvent from threading import Barrier as ThreadingBarrier from threading import Event as ThreadingEvent -from typing import Annotated, Union - -from typing_extensions import TypeAlias __all__ = [ "SyncObjectTypesAlias", @@ -28,10 +25,10 @@ ] -SyncObjectTypesAlias: TypeAlias = Annotated[ - Union[ThreadingEvent, ProcessingEvent, ThreadingBarrier, ProcessingBarrier], - "Type alias for threading and multiprocessing synchronization object types", -] +# Type alias for threading and multiprocessing synchronization object types +SyncObjectTypesAlias = ( + ThreadingEvent | ProcessingEvent | ThreadingBarrier | ProcessingBarrier +) async def wait_for_sync_event( @@ -146,7 +143,7 @@ async def wait_for_sync_objects( tasks = [ asyncio.create_task( wait_for_sync_barrier(obj, poll_interval) - if isinstance(obj, (ThreadingBarrier, ProcessingBarrier)) + if isinstance(obj, ThreadingBarrier | ProcessingBarrier) else wait_for_sync_event(obj, poll_interval) ) for obj in objects diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py index 8146ea1e..8d3580ef 100644 --- a/src/guidellm/utils/typing.py +++ b/src/guidellm/utils/typing.py @@ -1,14 +1,9 @@ from __future__ import annotations from collections.abc import Iterator +from types import UnionType from typing import Annotated, Literal, Union, get_args, get_origin -# Backwards compatibility for Python <3.10 -try: - from types import UnionType # type: ignore[attr-defined] -except ImportError: - UnionType = Union - # Backwards compatibility for Python <3.12 try: from 
typing import TypeAliasType # type: ignore[attr-defined] diff --git a/tests/integration/scheduler/test_scheduler.py b/tests/integration/scheduler/test_scheduler.py index 51abf59b..65bff95f 100644 --- a/tests/integration/scheduler/test_scheduler.py +++ b/tests/integration/scheduler/test_scheduler.py @@ -167,7 +167,7 @@ def _request_indices(): _request_indices(), received_updates.keys(), received_updates.values(), - received_responses, + received_responses, strict=False, ): assert req == f"req_{index}" assert resp in (f"response_for_{req}", f"mock_error_for_{req}") diff --git a/tests/unit/benchmark/test_output.py b/tests/unit/benchmark/test_output.py index 6763d978..6310da88 100644 --- a/tests/unit/benchmark/test_output.py +++ b/tests/unit/benchmark/test_output.py @@ -10,7 +10,10 @@ from guidellm.benchmark import ( GenerativeBenchmarksReport, ) -from guidellm.benchmark.output import GenerativeBenchmarkerConsole, GenerativeBenchmarkerCSV +from guidellm.benchmark.output import ( + GenerativeBenchmarkerConsole, + GenerativeBenchmarkerCSV, +) from tests.unit.mock_benchmark import mock_generative_benchmark @@ -80,6 +83,7 @@ def test_file_yaml(): mock_path.unlink() + @pytest.mark.asyncio async def test_file_csv(): mock_benchmark = mock_generative_benchmark() @@ -89,7 +93,7 @@ async def test_file_csv(): csv_benchmarker = GenerativeBenchmarkerCSV(output_path=mock_path) await csv_benchmarker.finalize(report) - with mock_path.open("r") as file: + with mock_path.open("r") as file: # noqa: ASYNC230 # This is a test. reader = csv.reader(file) headers = next(reader) rows = list(reader) @@ -105,7 +109,8 @@ def test_console_benchmarks_profile_str(): console = GenerativeBenchmarkerConsole() mock_benchmark = mock_generative_benchmark() assert ( - console._get_profile_str(mock_benchmark) == "type=synchronous, strategies=['synchronous']" + console._get_profile_str(mock_benchmark) + == "type=synchronous, strategies=['synchronous']" ) diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py index e3110fa3..544634c8 100644 --- a/tests/unit/dataset/test_synthetic.py +++ b/tests/unit/dataset/test_synthetic.py @@ -530,7 +530,7 @@ def mock_sampler_side_effect(*args, **kwargs): # Results should be identical with same seed assert len(items1) == len(items2) - for item1, item2 in zip(items1, items2): + for item1, item2 in zip(items1, items2, strict=False): assert item1["prompt"] == item2["prompt"] assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"] assert item1["output_tokens_count"] == item2["output_tokens_count"] diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 5ac069a8..3b7237e0 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -6,7 +6,7 @@ import random import time from collections.abc import AsyncIterator -from typing import Any, Optional +from typing import Any from lorem.text import TextLorem @@ -32,7 +32,7 @@ def __init__( self, target: str = "mock-target", model: str = "mock-model", - iter_delay: Optional[float] = None, + iter_delay: float | None = None, ): """ Initialize mock backend. 
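Not part of the changeset: the statistics and test hunks in this region add `strict=` to `zip()` calls. A minimal sketch of the behavioral difference (Python 3.10+), using made-up data:

```python
# zip(strict=True) raises instead of silently truncating when the iterables
# have different lengths; strict=False keeps the old truncating behavior.
values = [1.0, 2.0, 3.0]
weights = [0.2, 0.3, 0.5]

paired = list(zip(values, weights, strict=True))  # OK: equal lengths
assert paired == [(1.0, 0.2), (2.0, 0.3), (3.0, 0.5)]

try:
    list(zip(values, weights[:2], strict=True))
except ValueError as exc:
    # e.g. "zip() argument 2 is shorter than argument 1"
    print(f"length mismatch detected: {exc}")
```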
@@ -53,7 +53,7 @@ def target(self) -> str: return self._target @property - def model(self) -> Optional[str]: + def model(self) -> str | None: """Model name for the mock backend.""" return self._model @@ -87,7 +87,7 @@ async def validate(self) -> None: if not self._in_process: raise RuntimeError("Backend not started up for process") - async def default_model(self) -> Optional[str]: + async def default_model(self) -> str | None: """ Return the default model for the mock backend. """ @@ -97,7 +97,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: list[tuple[GenerationRequest, GenerationResponse]] | None = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -170,7 +170,7 @@ def _estimate_prompt_tokens(content: str) -> int: return len(str(content).split()) @staticmethod - def _get_tokens(token_count: Optional[int] = None) -> list[str]: + def _get_tokens(token_count: int | None = None) -> list[str]: """ Generate mock tokens for response. """ diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index cdf4375a..d7bfe7c9 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -1,4 +1,5 @@ """Mock benchmark objects for unit testing.""" + from guidellm.backends import GenerationRequestTimings from guidellm.benchmark import ( BenchmarkSchedulerStats, diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py index 008103c3..ba712fb6 100644 --- a/tests/unit/mock_server/test_server.py +++ b/tests/unit/mock_server/test_server.py @@ -162,7 +162,7 @@ async def test_health_endpoint(self, mock_server_instance): assert "status" in data assert data["status"] == "healthy" assert "timestamp" in data - assert isinstance(data["timestamp"], (int, float)) + assert isinstance(data["timestamp"], int | float) @pytest.mark.smoke @pytest.mark.asyncio diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py index fc5610fd..2fc4c86f 100644 --- a/tests/unit/scheduler/test_objects.py +++ b/tests/unit/scheduler/test_objects.py @@ -3,6 +3,7 @@ import inspect import typing from collections.abc import AsyncIterator +from types import UnionType from typing import Any, Literal, Optional, TypeVar, Union import pytest @@ -62,8 +63,7 @@ def test_multi_turn_request_t(): assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" value = MultiTurnRequestT.__value__ - assert hasattr(value, "__origin__") - assert value.__origin__ is Union + assert isinstance(value, UnionType) type_params = getattr(MultiTurnRequestT, "__type_params__", ()) assert len(type_params) == 1 @@ -340,7 +340,7 @@ def test_class_signatures(self): for key in self.CHECK_KEYS: assert key in fields field_info = fields[key] - assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.annotation in (Union[float, None], Optional[float]) # noqa: UP007 assert field_info.default is None @pytest.mark.smoke @@ -453,7 +453,7 @@ def test_class_signatures(self): for key in self.CHECK_KEYS: assert key in fields field_info = fields[key] - assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.annotation in (Union[float, None], Optional[float]) # noqa: UP007 assert field_info.default is None @pytest.mark.smoke @@ -704,11 +704,11 @@ def test_marshalling(self, valid_instances): else: assert 
original_value is None or isinstance( original_value, - (RequestSchedulerTimings, MeasuredRequestTimings), + RequestSchedulerTimings | MeasuredRequestTimings, ) assert reconstructed_value is None or isinstance( reconstructed_value, - (RequestSchedulerTimings, MeasuredRequestTimings), + RequestSchedulerTimings | MeasuredRequestTimings, ) else: assert original_value == reconstructed_value diff --git a/tests/unit/scheduler/test_strategies.py b/tests/unit/scheduler/test_strategies.py index 67a2d77d..143a3130 100644 --- a/tests/unit/scheduler/test_strategies.py +++ b/tests/unit/scheduler/test_strategies.py @@ -225,7 +225,7 @@ def test_lifecycle( for index in range(max(5, startup_requests + 2)): offset = instance.next_offset() - assert isinstance(offset, (int, float)) + assert isinstance(offset, int | float) if index < startup_requests: expected_offset = initial_offset + (index + 1) * startup_delay diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index cc4600cf..5664bcb0 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -476,7 +476,7 @@ def test_to_from_sequence_collections(self, collection): seq = inst.to_sequence(collection) out = inst.from_sequence(seq) assert len(out) == len(collection) - assert all(a == b for a, b in zip(out, list(collection))) + assert all(a == b for a, b in zip(out, list(collection), strict=False)) @pytest.mark.sanity def test_to_from_sequence_mapping(self): diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py index 1a9ea2c9..7acd5b4a 100644 --- a/tests/unit/utils/test_synchronous.py +++ b/tests/unit/utils/test_synchronous.py @@ -6,7 +6,7 @@ from functools import wraps from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from typing import Union +from typing import get_args import pytest @@ -32,17 +32,25 @@ async def new_func(*args, **kwargs): def test_sync_object_types_alias(): - """Test that SyncObjectTypesAlias is defined correctly as a type alias.""" - assert hasattr(SyncObjectTypesAlias, "__origin__") - if hasattr(SyncObjectTypesAlias, "__args__"): - actual_type = SyncObjectTypesAlias.__args__[0] - assert hasattr(actual_type, "__origin__") - assert actual_type.__origin__ is Union - union_args = actual_type.__args__ - assert threading.Event in union_args - assert ProcessingEvent in union_args - assert threading.Barrier in union_args - assert ProcessingBarrier in union_args + """ + Test that SyncObjectTypesAlias is defined correctly as a type alias. + + ## WRITTEN BY AI ## + """ + # Get the actual types from the union alias + actual_types = get_args(SyncObjectTypesAlias) + + # Define the set of expected types + expected_types = { + threading.Event, + ProcessingEvent, + threading.Barrier, + ProcessingBarrier, + } + + # Assert that the set of actual types matches the expected set. + # Using a set comparison is robust as it ignores the order. 
+ assert set(actual_types) == expected_types class TestWaitForSyncEvent: @@ -226,7 +234,7 @@ async def test_invocation(self, objects_types, expected_result): async def set_target(): await asyncio.sleep(0.01) obj = objects[expected_result] - if isinstance(obj, (threading.Event, ProcessingEvent)): + if isinstance(obj, threading.Event | ProcessingEvent): obj.set() else: await asyncio.to_thread(obj.wait) diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py index fafa8765..1e31ef8e 100644 --- a/tests/unit/utils/test_typing.py +++ b/tests/unit/utils/test_typing.py @@ -2,10 +2,9 @@ Test suite for the typing utilities module. """ -from typing import Annotated, Literal, Union +from typing import Annotated, Literal, TypeAlias import pytest -from typing_extensions import TypeAlias from guidellm.utils.typing import get_literal_vals @@ -15,7 +14,7 @@ Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], "Valid strategy type identifiers for scheduling request patterns", ] -StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType] +StrategyProfileType: TypeAlias = LocalStrategyType | LocalProfileType class TestGetLiteralVals: @@ -54,7 +53,7 @@ def test_inline_union_type(self): ### WRITTEN BY AI ### """ - result = get_literal_vals(Union[LocalProfileType, LocalStrategyType]) + result = get_literal_vals(LocalProfileType | LocalStrategyType) expected = frozenset( { "synchronous", @@ -118,6 +117,6 @@ def test_literal_union(self): ### WRITTEN BY AI ### """ - result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]]) + result = get_literal_vals(Literal["test", "test2"] | Literal["test3"]) expected = frozenset({"test", "test2", "test3"}) assert result == expected
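Not part of the changeset: the updated typing tests above build unions of `Literal` types with the `|` operator. The sketch below is not guidellm's `get_literal_vals` implementation — just an illustrative helper, assuming only the documented behavior of `typing.get_args`/`get_origin` for both `typing.Union` and `types.UnionType`, to show how literal values can be collected from such unions.

```python
from types import UnionType
from typing import Literal, Union, get_args, get_origin


def literal_values(tp: object) -> frozenset[str]:
    """Collect the string values of (unions of) Literal types. Illustrative only."""
    origin = get_origin(tp)
    if origin in (Union, UnionType):
        # `Literal["a"] | Literal["b"]` produces typing.Union at runtime,
        # while `int | str` produces types.UnionType; handle both.
        return frozenset().union(*(literal_values(arg) for arg in get_args(tp)))
    if origin is Literal:
        return frozenset(get_args(tp))
    raise TypeError(f"unsupported type: {tp!r}")


assert literal_values(Literal["test", "test2"] | Literal["test3"]) == {
    "test",
    "test2",
    "test3",
}
```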