diff --git a/pyproject.toml b/pyproject.toml
index 935587d0..f1624d3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -167,7 +167,7 @@ ignore_missing_imports = true
target-version = "py310"
line-length = 88
indent-width = 4
-exclude = ["build", "dist", "env", ".venv"]
+exclude = ["build", "dist", "env", ".venv*"]
[tool.ruff.format]
quote-style = "double"
diff --git a/setup.py b/setup.py
index 623bad28..d3b92889 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
import os
import re
from pathlib import Path
-from typing import Optional, Union
from packaging.version import Version
from setuptools import setup
@@ -11,7 +10,7 @@
TAG_VERSION_PATTERN = re.compile(r"^v(\d+\.\d+\.\d+)$")
-def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]:
+def get_last_version_diff() -> tuple[Version, str | None, int | None]:
"""
Get the last version, last tag, and the number of commits since the last tag.
If no tags are found, return the last release version and None for the tag/commits.
@@ -38,8 +37,8 @@ def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]:
def get_next_version(
- build_type: str, build_iteration: Optional[Union[str, int]]
-) -> tuple[Version, Optional[str], int]:
+ build_type: str, build_iteration: str | int | None
+) -> tuple[Version, str | None, int]:
"""
Get the next version based on the build type and iteration.
- build_type == release: take the last version and add a post if build iteration
diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py
index 0a035551..dbc8e1da 100644
--- a/src/guidellm/__main__.py
+++ b/src/guidellm/__main__.py
@@ -28,7 +28,7 @@
import asyncio
import codecs
from pathlib import Path
-from typing import Annotated, Union
+from typing import Annotated
import click
from pydantic import ValidationError
@@ -78,9 +78,8 @@
"run",
]
-STRATEGY_PROFILE_CHOICES: Annotated[
- list[str], "Available strategy and profile choices for benchmark execution types"
-] = list(get_literal_vals(Union[ProfileType, StrategyType]))
+# Available strategy and profile choices for benchmark execution types
+STRATEGY_PROFILE_CHOICES: list[str] = list(get_literal_vals(ProfileType | StrategyType))
def decode_escaped_str(_ctx, _param, value):
diff --git a/src/guidellm/backends/objects.py b/src/guidellm/backends/objects.py
index 05280940..001aeb70 100644
--- a/src/guidellm/backends/objects.py
+++ b/src/guidellm/backends/objects.py
@@ -7,7 +7,7 @@
"""
import uuid
-from typing import Any, Literal, Optional
+from typing import Any, Literal
from pydantic import Field
@@ -73,32 +73,32 @@ class GenerationResponse(StandardBaseModel):
request_args: dict[str, Any] = Field(
description="Arguments passed to the backend for this request."
)
- value: Optional[str] = Field(
+ value: str | None = Field(
default=None,
description="Complete generated text content. None for streaming responses.",
)
- delta: Optional[str] = Field(
+ delta: str | None = Field(
default=None, description="Incremental text content for streaming responses."
)
iterations: int = Field(
default=0, description="Number of generation iterations completed."
)
- request_prompt_tokens: Optional[int] = Field(
+ request_prompt_tokens: int | None = Field(
default=None, description="Token count from the original request prompt."
)
- request_output_tokens: Optional[int] = Field(
+ request_output_tokens: int | None = Field(
default=None,
description="Expected output token count from the original request.",
)
- response_prompt_tokens: Optional[int] = Field(
+ response_prompt_tokens: int | None = Field(
default=None, description="Actual prompt token count reported by the backend."
)
- response_output_tokens: Optional[int] = Field(
+ response_output_tokens: int | None = Field(
default=None, description="Actual output token count reported by the backend."
)
@property
- def prompt_tokens(self) -> Optional[int]:
+ def prompt_tokens(self) -> int | None:
"""
:return: The number of prompt tokens used in the request
(response_prompt_tokens if available, otherwise request_prompt_tokens).
@@ -106,7 +106,7 @@ def prompt_tokens(self) -> Optional[int]:
return self.response_prompt_tokens or self.request_prompt_tokens
@property
- def output_tokens(self) -> Optional[int]:
+ def output_tokens(self) -> int | None:
"""
:return: The number of output tokens generated in the response
(response_output_tokens if available, otherwise request_output_tokens).
@@ -114,7 +114,7 @@ def output_tokens(self) -> Optional[int]:
return self.response_output_tokens or self.request_output_tokens
@property
- def total_tokens(self) -> Optional[int]:
+ def total_tokens(self) -> int | None:
"""
:return: The total number of tokens used in the request and response.
Sum of prompt_tokens and output_tokens.
@@ -125,7 +125,7 @@ def total_tokens(self) -> Optional[int]:
def preferred_prompt_tokens(
self, preferred_source: Literal["request", "response"]
- ) -> Optional[int]:
+ ) -> int | None:
if preferred_source == "request":
return self.request_prompt_tokens or self.response_prompt_tokens
else:
@@ -133,7 +133,7 @@ def preferred_prompt_tokens(
def preferred_output_tokens(
self, preferred_source: Literal["request", "response"]
- ) -> Optional[int]:
+ ) -> int | None:
if preferred_source == "request":
return self.request_output_tokens or self.response_output_tokens
else:
@@ -146,11 +146,11 @@ class GenerationRequestTimings(MeasuredRequestTimings):
"""Timing model for tracking generation request lifecycle events."""
timings_type: Literal["generation_request_timings"] = "generation_request_timings"
- first_iteration: Optional[float] = Field(
+ first_iteration: float | None = Field(
default=None,
description="Unix timestamp when the first generation iteration began.",
)
- last_iteration: Optional[float] = Field(
+ last_iteration: float | None = Field(
default=None,
description="Unix timestamp when the last generation iteration completed.",
)
diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py
index ce83076f..c8eb70f3 100644
--- a/src/guidellm/backends/openai.py
+++ b/src/guidellm/backends/openai.py
@@ -17,7 +17,7 @@
import time
from collections.abc import AsyncIterator
from pathlib import Path
-from typing import Any, ClassVar, Optional, Union
+from typing import Any, ClassVar
import httpx
from PIL import Image
@@ -33,13 +33,15 @@
__all__ = ["OpenAIHTTPBackend", "UsageStats"]
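+# Chat content accepted by chat_completions: plain text, a list of mixed parts
+# (text, OpenAI-style content dicts, image paths, or PIL images), or raw content.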
+ContentT = str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any
+
@dataclasses.dataclass
class UsageStats:
"""Token usage statistics for generation requests."""
- prompt_tokens: Optional[int] = None
- output_tokens: Optional[int] = None
+ prompt_tokens: int | None = None
+ output_tokens: int | None = None
@Backend.register("openai_http")
@@ -78,19 +80,19 @@ class OpenAIHTTPBackend(Backend):
def __init__(
self,
target: str,
- model: Optional[str] = None,
- api_key: Optional[str] = None,
- organization: Optional[str] = None,
- project: Optional[str] = None,
+ model: str | None = None,
+ api_key: str | None = None,
+ organization: str | None = None,
+ project: str | None = None,
timeout: float = 60.0,
http2: bool = True,
follow_redirects: bool = True,
- max_output_tokens: Optional[int] = None,
+ max_output_tokens: int | None = None,
stream_response: bool = True,
- extra_query: Optional[dict] = None,
- extra_body: Optional[dict] = None,
- remove_from_body: Optional[list[str]] = None,
- headers: Optional[dict] = None,
+ extra_query: dict | None = None,
+ extra_body: dict | None = None,
+ remove_from_body: list[str] | None = None,
+ headers: dict | None = None,
verify: bool = False,
):
"""
@@ -137,7 +139,7 @@ def __init__(
# Runtime state
self._in_process = False
- self._async_client: Optional[httpx.AsyncClient] = None
+ self._async_client: httpx.AsyncClient | None = None
@property
def info(self) -> dict[str, Any]:
@@ -264,7 +266,7 @@ async def available_models(self) -> list[str]:
return [item["id"] for item in response.json()["data"]]
- async def default_model(self) -> Optional[str]:
+ async def default_model(self) -> str | None:
"""
Get the default model for this backend.
@@ -280,7 +282,7 @@ async def resolve(
self,
request: GenerationRequest,
request_info: ScheduledRequestInfo,
- history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
+ history: list[tuple[GenerationRequest, GenerationResponse]] | None = None,
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
"""
Process a generation request and yield progressive responses.
@@ -363,12 +365,12 @@ async def resolve(
async def text_completions(
self,
- prompt: Union[str, list[str]],
- request_id: Optional[str], # noqa: ARG002
- output_token_count: Optional[int] = None,
+ prompt: str | list[str],
+ request_id: str | None, # noqa: ARG002
+ output_token_count: int | None = None,
stream_response: bool = True,
**kwargs,
- ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]:
+ ) -> AsyncIterator[tuple[str | None, UsageStats | None]]:
"""
Generate text completions using the /v1/completions endpoint.
@@ -431,17 +433,13 @@ async def text_completions(
async def chat_completions(
self,
- content: Union[
- str,
- list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]],
- Any,
- ],
- request_id: Optional[str] = None, # noqa: ARG002
- output_token_count: Optional[int] = None,
+ content: ContentT,
+ request_id: str | None = None, # noqa: ARG002
+ output_token_count: int | None = None,
raw_content: bool = False,
stream_response: bool = True,
**kwargs,
- ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]:
+ ) -> AsyncIterator[tuple[str | None, UsageStats | None]]:
"""
Generate chat completions using the /v1/chat/completions endpoint.
@@ -502,10 +500,10 @@ async def chat_completions(
def _build_headers(
self,
- api_key: Optional[str],
- organization: Optional[str],
- project: Optional[str],
- user_headers: Optional[dict],
+ api_key: str | None,
+ organization: str | None,
+ project: str | None,
+ user_headers: dict | None,
) -> dict[str, str]:
headers = {}
@@ -541,11 +539,7 @@ def _get_params(self, endpoint_type: str) -> dict[str, str]:
def _get_chat_messages(
self,
- content: Union[
- str,
- list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]],
- Any,
- ],
+ content: ContentT,
) -> list[dict[str, Any]]:
if isinstance(content, str):
return [{"role": "user", "content": content}]
@@ -559,7 +553,7 @@ def _get_chat_messages(
resolved_content.append(item)
elif isinstance(item, str):
resolved_content.append({"type": "text", "text": item})
- elif isinstance(item, (Image.Image, Path)):
+ elif isinstance(item, Image.Image | Path):
resolved_content.append(self._get_chat_message_media_item(item))
else:
raise ValueError(f"Unsupported content item type: {type(item)}")
@@ -567,7 +561,7 @@ def _get_chat_messages(
return [{"role": "user", "content": resolved_content}]
def _get_chat_message_media_item(
- self, item: Union[Path, Image.Image]
+ self, item: Path | Image.Image
) -> dict[str, Any]:
if isinstance(item, Image.Image):
encoded = base64.b64encode(item.tobytes()).decode("utf-8")
@@ -597,8 +591,8 @@ def _get_chat_message_media_item(
def _get_body(
self,
endpoint_type: str,
- request_kwargs: Optional[dict[str, Any]],
- max_output_tokens: Optional[int] = None,
+ request_kwargs: dict[str, Any] | None,
+ max_output_tokens: int | None = None,
**kwargs,
) -> dict[str, Any]:
# Start with endpoint-specific extra body parameters
@@ -628,7 +622,7 @@ def _get_body(
return {key: val for key, val in body.items() if val is not None}
- def _get_completions_text_content(self, data: dict) -> Optional[str]:
+ def _get_completions_text_content(self, data: dict) -> str | None:
if not data.get("choices"):
return None
@@ -639,7 +633,7 @@ def _get_completions_text_content(self, data: dict) -> Optional[str]:
or choice.get("message", {}).get("content")
)
- def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]:
+ def _get_completions_usage_stats(self, data: dict) -> UsageStats | None:
if not data.get("usage"):
return None
diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py
index e965c482..b33a7b14 100644
--- a/src/guidellm/benchmark/aggregator.py
+++ b/src/guidellm/benchmark/aggregator.py
@@ -267,7 +267,7 @@ def resolve(
resolved = {}
for key, val in aggregators.items():
- if isinstance(val, (Aggregator, CompilableAggregator)):
+ if isinstance(val, Aggregator | CompilableAggregator):
resolved[key] = val
else:
aggregator_class = cls.get_registered_object(key)
@@ -975,7 +975,7 @@ def _calculate_requests_per_second(
filtered_statuses = []
filtered_times = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(
safe_getattr(request.scheduler_info.request_timings, "request_start"),
safe_getattr(request.scheduler_info.request_timings, "request_end"),
@@ -1005,7 +1005,7 @@ def _calculate_request_concurrency(
filtered_statuses = []
filtered_times = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(
safe_getattr(request.scheduler_info.request_timings, "request_start"),
safe_getattr(request.scheduler_info.request_timings, "request_end"),
@@ -1035,7 +1035,7 @@ def _calculate_request_latency(
filtered_statuses = []
filtered_values = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.request_latency):
continue
@@ -1056,7 +1056,7 @@ def _calculate_prompt_token_count(
filtered_statuses = []
filtered_values = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.prompt_tokens):
continue
@@ -1077,7 +1077,7 @@ def _calculate_output_token_count(
filtered_statuses = []
filtered_values = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.output_tokens):
continue
@@ -1098,7 +1098,7 @@ def _calculate_total_token_count(
filtered_statuses = []
filtered_values = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.total_tokens):
continue
@@ -1119,7 +1119,7 @@ def _calculate_time_to_first_token_ms(
filtered_statuses = []
filtered_values = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.time_to_first_token_ms):
continue
@@ -1141,7 +1141,7 @@ def _calculate_time_per_output_token_ms(
filtered_values = []
filtered_weights = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.time_to_first_token_ms):
continue
@@ -1174,7 +1174,7 @@ def _calculate_inter_token_latency_ms(
filtered_values = []
filtered_weights = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.inter_token_latency_ms):
continue
@@ -1199,7 +1199,7 @@ def _calculate_output_tokens_per_second(
filtered_first_iter_times = []
filtered_iter_counts = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.output_tokens_per_second):
continue
@@ -1234,7 +1234,7 @@ def _calculate_tokens_per_second(
filtered_iter_counts = []
filtered_first_iter_counts = []
- for status, request in zip(statuses, requests):
+ for status, request in zip(statuses, requests, strict=False):
if not all_defined(request.tokens_per_second):
continue
diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py
index 5f05065a..99410e4c 100644
--- a/src/guidellm/benchmark/benchmarker.py
+++ b/src/guidellm/benchmark/benchmarker.py
@@ -228,12 +228,12 @@ def _combine(
existing: dict[str, Any] | StandardBaseDict,
addition: dict[str, Any] | StandardBaseDict,
) -> dict[str, Any] | StandardBaseDict:
- if not isinstance(existing, (dict, StandardBaseDict)):
+ if not isinstance(existing, dict | StandardBaseDict):
raise ValueError(
f"Existing value {existing} (type: {type(existing).__name__}) "
f"is not a valid type for merging."
)
- if not isinstance(addition, (dict, StandardBaseDict)):
+ if not isinstance(addition, dict | StandardBaseDict):
raise ValueError(
f"Addition value {addition} (type: {type(addition).__name__}) "
f"is not a valid type for merging."
diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py
index 56775dac..cacadc94 100644
--- a/src/guidellm/benchmark/output.py
+++ b/src/guidellm/benchmark/output.py
@@ -34,10 +34,11 @@
DistributionSummary,
RegistryMixin,
StatusDistributionSummary,
+ camelize_str,
+ recursive_key_update,
safe_format_timestamp,
split_text_list_by_length,
)
-from guidellm.utils import recursive_key_update, camelize_str
__all__ = [
"GenerativeBenchmarkerCSV",
@@ -90,7 +91,7 @@ def resolve(
if not output_formats:
return {}
- if isinstance(output_formats, (list, tuple)):
+ if isinstance(output_formats, list | tuple):
# support list of output keys: ["csv", "json"]
# support list of files: ["path/to/file.json", "path/to/file.csv"]
formats_list = output_formats
@@ -369,7 +370,7 @@ def _print_line(
f"Value and style length mismatch: {len(value)} vs {len(style)}"
)
- for val, sty in zip(value, style):
+ for val, sty in zip(value, style, strict=False):
text.append(val, style=sty)
self.console.print(Padding.indent(text, indent))
@@ -568,8 +569,8 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path:
benchmark_values: list[str | float | list[float]] = []
# Add basic run description info
- desc_headers, desc_values = (
- self._get_benchmark_desc_headers_and_values(benchmark)
+ desc_headers, desc_values = self._get_benchmark_desc_headers_and_values(
+ benchmark
)
benchmark_headers.extend(desc_headers)
benchmark_values.extend(desc_values)
@@ -680,7 +681,8 @@ def _get_benchmark_status_metrics_stats(
return headers, values
def _get_benchmark_extras_headers_and_values(
- self, benchmark: GenerativeBenchmark,
+ self,
+ benchmark: GenerativeBenchmark,
) -> tuple[list[str], list[str]]:
headers = ["Profile", "Backend", "Generator Data"]
values: list[str] = [
@@ -733,9 +735,7 @@ async def finalize(self, report: GenerativeBenchmarksReport) -> Path:
ui_api_data = {}
for k, v in camel_data.items():
placeholder_key = f"window.{k} = {{}};"
- replacement_value = (
- f"window.{k} = {json.dumps(v, indent=2)};\n"
- )
+ replacement_value = f"window.{k} = {json.dumps(v, indent=2)};\n"
ui_api_data[placeholder_key] = replacement_value
create_report(ui_api_data, output_path)
diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py
index 3ff8d0e0..87a9a2be 100644
--- a/src/guidellm/benchmark/profile.py
+++ b/src/guidellm/benchmark/profile.py
@@ -679,7 +679,10 @@ def next_strategy(
prev_benchmark.metrics.requests_per_second.successful.mean
)
if self.synchronous_rate <= 0 and self.throughput_rate <= 0:
- raise RuntimeError("Invalid rates in sweep; aborting. Were there any successful requests?")
+ raise RuntimeError(
+ "Invalid rates in sweep; aborting. "
+ "Were there any successful requests?"
+ )
self.measured_rates = list(
np.linspace(
self.synchronous_rate,
diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py
index b53ef424..73a9a050 100644
--- a/src/guidellm/benchmark/scenario.py
+++ b/src/guidellm/benchmark/scenario.py
@@ -1,10 +1,11 @@
from __future__ import annotations
import json
+from collections.abc import Callable
from functools import cache, wraps
from inspect import Parameter, signature
from pathlib import Path
-from typing import Annotated, Any, Callable, Literal, TypeVar
+from typing import Annotated, Any, Literal, TypeVar
import yaml
from loguru import logger
@@ -38,7 +39,7 @@ def parse_float_list(value: str | float | list[float]) -> list[float]:
or convert single float list of one or pass float
list through.
"""
- if isinstance(value, (int, float)):
+ if isinstance(value, int | float):
return [value]
elif isinstance(value, list):
return value
diff --git a/src/guidellm/dataset/creator.py b/src/guidellm/dataset/creator.py
index a74ec8c0..fe712c23 100644
--- a/src/guidellm/dataset/creator.py
+++ b/src/guidellm/dataset/creator.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from transformers import PreTrainedTokenizerBase # type: ignore[import]
@@ -80,12 +80,12 @@ class DatasetCreator(ABC):
def create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
random_seed: int = 42,
- split_pref_order: Optional[list[str]] = None,
- ) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]:
+ split_pref_order: list[str] | None = None,
+ ) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]:
if not cls.is_supported(data, data_args):
raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ")
@@ -95,10 +95,10 @@ def create(
data, data_args, processor, processor_args, random_seed
)
- if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+ if isinstance(dataset, DatasetDict | IterableDatasetDict):
dataset = cls.extract_dataset_split(dataset, split, split_pref_order)
- if not isinstance(dataset, (Dataset, IterableDataset)):
+ if not isinstance(dataset, Dataset | IterableDataset):
raise ValueError(
f"Unsupported data type: {type(dataset)} given for {dataset}."
)
@@ -106,7 +106,7 @@ def create(
return dataset, column_mappings
@classmethod
- def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str:
+ def extract_args_split(cls, data_args: dict[str, Any] | None) -> str:
split = "auto"
if data_args and "split" in data_args:
@@ -118,7 +118,7 @@ def extract_args_split(cls, data_args: Optional[dict[str, Any]]) -> str:
@classmethod
def extract_args_column_mappings(
cls,
- data_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
) -> dict[ColumnInputTypes, str]:
columns: dict[ColumnInputTypes, str] = {}
@@ -143,12 +143,12 @@ def extract_args_column_mappings(
@classmethod
def extract_dataset_name(
- cls, dataset: Union[Dataset, IterableDataset, DatasetDict, IterableDatasetDict]
- ) -> Optional[str]:
- if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+ cls, dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict
+ ) -> str | None:
+ if isinstance(dataset, DatasetDict | IterableDatasetDict):
dataset = dataset[list(dataset.keys())[0]]
- if isinstance(dataset, (Dataset, IterableDataset)):
+ if isinstance(dataset, Dataset | IterableDataset):
if not hasattr(dataset, "info") or not hasattr(
dataset.info, "dataset_name"
):
@@ -161,11 +161,11 @@ def extract_dataset_name(
@classmethod
def extract_dataset_split(
cls,
- dataset: Union[DatasetDict, IterableDatasetDict],
- specified_split: Union[Literal["auto"], str] = "auto",
- split_pref_order: Optional[Union[Literal["auto"], list[str]]] = "auto",
- ) -> Union[Dataset, IterableDataset]:
- if not isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+ dataset: DatasetDict | IterableDatasetDict,
+ specified_split: Literal["auto"] | str = "auto",
+ split_pref_order: Literal["auto"] | list[str] | None = "auto",
+ ) -> Dataset | IterableDataset:
+ if not isinstance(dataset, DatasetDict | IterableDatasetDict):
raise ValueError(
f"Unsupported data type: {type(dataset)} given for {dataset}."
)
@@ -199,15 +199,15 @@ def extract_dataset_split(
@classmethod
@abstractmethod
- def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: ...
+ def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: ...
@classmethod
@abstractmethod
def handle_create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
random_seed: int,
- ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: ...
+ ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: ...
diff --git a/src/guidellm/dataset/entrypoints.py b/src/guidellm/dataset/entrypoints.py
index cf689956..1da2222a 100644
--- a/src/guidellm/dataset/entrypoints.py
+++ b/src/guidellm/dataset/entrypoints.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
from datasets import Dataset, IterableDataset
from transformers import PreTrainedTokenizerBase # type: ignore[import]
@@ -15,12 +15,12 @@
def load_dataset(
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
random_seed: int = 42,
- split_pref_order: Optional[list[str]] = None,
-) -> tuple[Union[Dataset, IterableDataset], dict[ColumnInputTypes, str]]:
+ split_pref_order: list[str] | None = None,
+) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]:
creators = [
InMemoryDatasetCreator,
SyntheticDatasetCreator,
diff --git a/src/guidellm/dataset/file.py b/src/guidellm/dataset/file.py
index 5d6df1d9..718cb46f 100644
--- a/src/guidellm/dataset/file.py
+++ b/src/guidellm/dataset/file.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
import pandas as pd # type: ignore[import]
from datasets import (
@@ -30,8 +30,8 @@ class FileDatasetCreator(DatasetCreator):
}
@classmethod
- def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003
- if isinstance(data, (str, Path)) and (path := Path(data)).exists():
+ def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003
+ if isinstance(data, str | Path) and (path := Path(data)).exists():
# local folder or py file, assume supported
return path.suffix.lower() in cls.SUPPORTED_TYPES
@@ -41,12 +41,12 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool:
def handle_create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003
- processor_args: Optional[dict[str, Any]], # noqa: ARG003
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003
+ processor_args: dict[str, Any] | None, # noqa: ARG003
random_seed: int, # noqa: ARG003
- ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
- if not isinstance(data, (str, Path)):
+ ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
+ if not isinstance(data, str | Path):
raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ")
path = Path(data)
@@ -63,8 +63,8 @@ def handle_create(
@classmethod
def load_dataset(
- cls, path: Path, data_args: Optional[dict[str, Any]]
- ) -> Union[Dataset, IterableDataset]:
+ cls, path: Path, data_args: dict[str, Any] | None
+ ) -> Dataset | IterableDataset:
if path.suffix.lower() in {".txt", ".text"}:
with path.open("r") as file:
items = file.readlines()
diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py
index 7f91facd..d1be46c1 100644
--- a/src/guidellm/dataset/hf_datasets.py
+++ b/src/guidellm/dataset/hf_datasets.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
from datasets import (
Dataset,
@@ -18,18 +18,18 @@
class HFDatasetsCreator(DatasetCreator):
@classmethod
- def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003
+ def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003
if isinstance(
- data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
+ data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict
):
# base type is supported
return True
- if isinstance(data, (str, Path)) and (path := Path(data)).exists():
+ if isinstance(data, str | Path) and (path := Path(data)).exists():
# local folder or py file, assume supported
return path.is_dir() or path.suffix == ".py"
- if isinstance(data, (str, Path)):
+ if isinstance(data, str | Path):
try:
# try to load dataset
return get_dataset_config_info(data) is not None
@@ -42,12 +42,12 @@ def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool:
def handle_create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003
- processor_args: Optional[dict[str, Any]], # noqa: ARG003
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003
+ processor_args: dict[str, Any] | None, # noqa: ARG003
random_seed: int, # noqa: ARG003
- ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
- if isinstance(data, (str, Path)):
+ ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
+ if isinstance(data, str | Path):
data = load_dataset(data, **(data_args or {}))
elif data_args:
raise ValueError(
@@ -55,7 +55,7 @@ def handle_create(
)
if isinstance(
- data, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
+ data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict
):
return data
diff --git a/src/guidellm/dataset/in_memory.py b/src/guidellm/dataset/in_memory.py
index af84f658..0461948c 100644
--- a/src/guidellm/dataset/in_memory.py
+++ b/src/guidellm/dataset/in_memory.py
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
from datasets import (
Dataset,
@@ -17,18 +17,18 @@
class InMemoryDatasetCreator(DatasetCreator):
@classmethod
- def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003
+ def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003
return isinstance(data, Iterable) and not isinstance(data, str)
@classmethod
def handle_create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]], # noqa: ARG003
- processor_args: Optional[dict[str, Any]], # noqa: ARG003
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003
+ processor_args: dict[str, Any] | None, # noqa: ARG003
random_seed: int, # noqa: ARG003
- ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
+ ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
if not isinstance(data, Iterable):
raise TypeError(
f"Unsupported data format. Expected Iterable[Any], got {type(data)}"
diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py
index 8c30f0f7..8a1626fe 100644
--- a/src/guidellm/dataset/synthetic.py
+++ b/src/guidellm/dataset/synthetic.py
@@ -3,7 +3,7 @@
from collections.abc import Iterable, Iterator
from itertools import cycle
from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal
import yaml
from datasets import (
@@ -35,17 +35,17 @@ class SyntheticDatasetConfig(BaseModel):
description="The average number of text tokens generated for prompts.",
gt=0,
)
- prompt_tokens_stdev: Optional[int] = Field(
+ prompt_tokens_stdev: int | None = Field(
description="The standard deviation of the tokens generated for prompts.",
gt=0,
default=None,
)
- prompt_tokens_min: Optional[int] = Field(
+ prompt_tokens_min: int | None = Field(
description="The minimum number of text tokens generated for prompts.",
gt=0,
default=None,
)
- prompt_tokens_max: Optional[int] = Field(
+ prompt_tokens_max: int | None = Field(
description="The maximum number of text tokens generated for prompts.",
gt=0,
default=None,
@@ -54,17 +54,17 @@ class SyntheticDatasetConfig(BaseModel):
description="The average number of text tokens generated for outputs.",
gt=0,
)
- output_tokens_stdev: Optional[int] = Field(
+ output_tokens_stdev: int | None = Field(
description="The standard deviation of the tokens generated for outputs.",
gt=0,
default=None,
)
- output_tokens_min: Optional[int] = Field(
+ output_tokens_min: int | None = Field(
description="The minimum number of text tokens generated for outputs.",
gt=0,
default=None,
)
- output_tokens_max: Optional[int] = Field(
+ output_tokens_max: int | None = Field(
description="The maximum number of text tokens generated for outputs.",
gt=0,
default=None,
@@ -80,7 +80,7 @@ class SyntheticDatasetConfig(BaseModel):
)
@staticmethod
- def parse_str(data: Union[str, Path]) -> "SyntheticDatasetConfig":
+ def parse_str(data: str | Path) -> "SyntheticDatasetConfig":
if (
isinstance(data, Path)
or data.strip().endswith(".config")
@@ -117,7 +117,7 @@ def parse_key_value_pairs(data: str) -> "SyntheticDatasetConfig":
return SyntheticDatasetConfig(**config_dict) # type: ignore[arg-type]
@staticmethod
- def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
+ def parse_config_file(data: str | Path) -> "SyntheticDatasetConfig":
with Path(data).open("r") as file:
config_dict = yaml.safe_load(file)
@@ -128,7 +128,7 @@ class SyntheticTextItemsGenerator(
Iterable[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
- Union[str, int],
+ str | int,
]
]
):
@@ -150,7 +150,7 @@ def __iter__(
) -> Iterator[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
- Union[str, int],
+ str | int,
]
]:
prompt_tokens_sampler = IntegerRangeSampler(
@@ -177,7 +177,7 @@ def __iter__(
for _, prompt_tokens, output_tokens in zip(
range(self.config.samples),
prompt_tokens_sampler,
- output_tokens_sampler,
+                output_tokens_sampler,
+                strict=False,
):
start_index = rand.randint(0, len(self.text_creator.words))
prompt_text = self.processor.decode(
@@ -194,7 +194,7 @@ def __iter__(
}
def _create_prompt(
- self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
+ self, prompt_tokens: int, start_index: int, unique_prefix: int | None = None
) -> list[int]:
if prompt_tokens <= 0:
return []
@@ -224,7 +224,7 @@ class SyntheticDatasetCreator(DatasetCreator):
def is_supported(
cls,
data: Any,
- data_args: Optional[dict[str, Any]], # noqa: ARG003
+ data_args: dict[str, Any] | None, # noqa: ARG003
) -> bool:
if (
isinstance(data, Path)
@@ -248,11 +248,11 @@ def is_supported(
def handle_create(
cls,
data: Any,
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
random_seed: int,
- ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
+ ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
processor = check_load_processor(
processor,
processor_args,
@@ -270,7 +270,7 @@ def handle_create(
@classmethod
def extract_args_column_mappings(
cls,
- data_args: Optional[dict[str, Any]],
+ data_args: dict[str, Any] | None,
) -> dict[ColumnInputTypes, str]:
data_args_columns = super().extract_args_column_mappings(data_args)
diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py
index 70259bad..da3464f9 100644
--- a/src/guidellm/logger.py
+++ b/src/guidellm/logger.py
@@ -72,7 +72,7 @@ def configure_logger(config: LoggingSettings = settings.logging):
sys.stdout,
level=config.console_log_level.upper(),
format="{time:YY-MM-DD HH:mm:ss}|{level: <8} \
- |{name}:{function}:{line} - {message}"
+ |{name}:{function}:{line} - {message}",
)
if config.log_file or config.log_file_level:
diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py
index a94b8a14..b02efec5 100644
--- a/src/guidellm/preprocess/dataset.py
+++ b/src/guidellm/preprocess/dataset.py
@@ -1,9 +1,9 @@
import json
import os
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
from enum import Enum
from pathlib import Path
-from typing import Any, Callable, Optional, Union
+from typing import Any
import yaml
from datasets import Dataset
@@ -32,7 +32,7 @@ def handle_ignore_strategy(
min_prompt_tokens: int,
tokenizer: PreTrainedTokenizerBase,
**_kwargs,
-) -> Optional[str]:
+) -> str | None:
"""
Ignores prompts that are shorter than the required minimum token length.
@@ -56,7 +56,7 @@ def handle_concatenate_strategy(
tokenizer: PreTrainedTokenizerBase,
concat_delimiter: str,
**_kwargs,
-) -> Optional[str]:
+) -> str | None:
"""
Concatenates prompts until the minimum token requirement is met.
@@ -117,7 +117,7 @@ def handle_error_strategy(
min_prompt_tokens: int,
tokenizer: PreTrainedTokenizerBase,
**_kwargs,
-) -> Optional[str]:
+) -> str | None:
"""
Raises an error if the prompt is too short.
@@ -150,24 +150,24 @@ class TokensConfig(BaseModel):
description="The average number of tokens.",
gt=0,
)
- stdev: Optional[int] = Field(
+ stdev: int | None = Field(
description="The standard deviation of the tokens.",
gt=0,
default=None,
)
- min: Optional[int] = Field(
+ min: int | None = Field(
description="The minimum number of tokens.",
gt=0,
default=None,
)
- max: Optional[int] = Field(
+ max: int | None = Field(
description="The maximum number of tokens.",
gt=0,
default=None,
)
@staticmethod
- def parse_str(data: Union[str, Path]) -> "TokensConfig":
+ def parse_str(data: str | Path) -> "TokensConfig":
"""
Parses a string or path into a TokensConfig object. Supports:
- JSON string
@@ -215,14 +215,14 @@ def parse_key_value_pairs(data: str) -> "TokensConfig":
return TokensConfig(**config_dict) # type: ignore[arg-type]
@staticmethod
- def parse_config_file(data: Union[str, Path]) -> "TokensConfig":
+ def parse_config_file(data: str | Path) -> "TokensConfig":
with Path(data).open("r") as file:
config_dict = yaml.safe_load(file)
return TokensConfig(**config_dict)
-def _validate_output_suffix(output_path: Union[str, Path]) -> None:
+def _validate_output_suffix(output_path: str | Path) -> None:
output_path = Path(output_path)
suffix = output_path.suffix.lower()
if suffix not in SUPPORTED_TYPES:
@@ -233,18 +233,18 @@ def _validate_output_suffix(output_path: Union[str, Path]) -> None:
def process_dataset(
- data: Union[str, Path],
- output_path: Union[str, Path],
- processor: Union[str, Path, PreTrainedTokenizerBase],
- prompt_tokens: Union[str, Path],
- output_tokens: Union[str, Path],
- processor_args: Optional[dict[str, Any]] = None,
- data_args: Optional[dict[str, Any]] = None,
+ data: str | Path,
+ output_path: str | Path,
+ processor: str | Path | PreTrainedTokenizerBase,
+ prompt_tokens: str | Path,
+ output_tokens: str | Path,
+ processor_args: dict[str, Any] | None = None,
+ data_args: dict[str, Any] | None = None,
short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
- pad_char: Optional[str] = None,
- concat_delimiter: Optional[str] = None,
+ pad_char: str | None = None,
+ concat_delimiter: str | None = None,
push_to_hub: bool = False,
- hub_dataset_id: Optional[str] = None,
+ hub_dataset_id: str | None = None,
random_seed: int = 42,
) -> None:
"""
@@ -354,7 +354,7 @@ def process_dataset(
def push_dataset_to_hub(
- hub_dataset_id: Optional[str],
+ hub_dataset_id: str | None,
processed_dataset: Dataset,
) -> None:
"""
diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py
index c1e8f13f..ff2863b4 100644
--- a/src/guidellm/presentation/data_models.py
+++ b/src/guidellm/presentation/data_models.py
@@ -1,7 +1,7 @@
import random
from collections import defaultdict
from math import ceil
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
from pydantic import BaseModel, computed_field
@@ -12,14 +12,14 @@
class Bucket(BaseModel):
- value: Union[float, int]
+ value: float | int
count: int
@staticmethod
def from_data(
- data: Union[list[float], list[int]],
- bucket_width: Optional[float] = None,
- n_buckets: Optional[int] = None,
+ data: list[float] | list[int],
+ bucket_width: float | None = None,
+ n_buckets: int | None = None,
) -> tuple[list["Bucket"], float]:
if not data:
return [], 1.0
@@ -35,7 +35,7 @@ def from_data(
else:
n_buckets = ceil(range_v / bucket_width)
- bucket_counts: defaultdict[Union[float, int], int] = defaultdict(int)
+ bucket_counts: defaultdict[float | int, int] = defaultdict(int)
for val in data:
idx = int((val - min_v) // bucket_width)
if idx >= n_buckets:
@@ -80,7 +80,7 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]):
class Distribution(BaseModel):
- statistics: Optional[DistributionSummary] = None
+ statistics: DistributionSummary | None = None
buckets: list[Bucket]
bucket_width: float
@@ -190,7 +190,7 @@ class TabularDistributionSummary(DistributionSummary):
"""
@computed_field
- def percentile_rows(self) -> list[dict[str, Union[str, float]]]:
+ def percentile_rows(self) -> list[dict[str, str | float]]:
rows = [
{"percentile": name, "value": value}
for name, value in self.percentiles.model_dump().items()
diff --git a/src/guidellm/presentation/injector.py b/src/guidellm/presentation/injector.py
index bb1fd684..1e78080e 100644
--- a/src/guidellm/presentation/injector.py
+++ b/src/guidellm/presentation/injector.py
@@ -1,6 +1,5 @@
import re
from pathlib import Path
-from typing import Union
from loguru import logger
@@ -8,7 +7,7 @@
from guidellm.utils.text import load_text
-def create_report(js_data: dict, output_path: Union[str, Path]) -> Path:
+def create_report(js_data: dict, output_path: str | Path) -> Path:
"""
Creates a report from the dictionary and saves it to the output path.
diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py
index 607a7455..ac34131e 100644
--- a/src/guidellm/request/loader.py
+++ b/src/guidellm/request/loader.py
@@ -4,8 +4,6 @@
from typing import (
Any,
Literal,
- Optional,
- Union,
)
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
@@ -43,9 +41,9 @@ def description(self) -> RequestLoaderDescription: ...
class GenerativeRequestLoaderDescription(RequestLoaderDescription):
type_: Literal["generative_request_loader"] = "generative_request_loader" # type: ignore[assignment]
data: str
- data_args: Optional[dict[str, Any]]
+ data_args: dict[str, Any] | None
processor: str
- processor_args: Optional[dict[str, Any]]
+ processor_args: dict[str, Any] | None
class GenerativeRequestLoader(RequestLoader):
@@ -69,18 +67,11 @@ class GenerativeRequestLoader(RequestLoader):
def __init__(
self,
- data: Union[
- str,
- Path,
- Iterable[Union[str, dict[str, Any]]],
- Dataset,
- DatasetDict,
- IterableDataset,
- IterableDatasetDict,
- ],
- data_args: Optional[dict[str, Any]],
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+        data: (
+            str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict
+            | IterableDataset | IterableDatasetDict
+        ),
+ data_args: dict[str, Any] | None,
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
shuffle: bool = True,
iter_type: Literal["finite", "infinite"] = "finite",
random_seed: int = 42,
@@ -202,7 +193,7 @@ def _extract_text_column(self) -> str:
"'data_args' dictionary."
)
- def _extract_prompt_tokens_count_column(self) -> Optional[str]:
+ def _extract_prompt_tokens_count_column(self) -> str | None:
column_names = self._dataset_columns()
if column_names and "prompt_tokens_count" in column_names:
@@ -213,7 +204,7 @@ def _extract_prompt_tokens_count_column(self) -> Optional[str]:
return None
- def _extract_output_tokens_count_column(self) -> Optional[str]:
+ def _extract_output_tokens_count_column(self) -> str | None:
column_names = self._dataset_columns()
if column_names and "output_tokens_count" in column_names:
@@ -224,7 +215,7 @@ def _extract_output_tokens_count_column(self) -> Optional[str]:
return None
- def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]]:
+ def _dataset_columns(self, err_msg: str | None = None) -> list[str] | None:
try:
column_names = self.dataset.column_names
@@ -240,7 +231,7 @@ def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]]
def _get_dataset_iter(
self, scope_create_count: int
- ) -> Optional[Iterator[dict[str, Any]]]:
+ ) -> Iterator[dict[str, Any]] | None:
if scope_create_count > 0 and self.iter_type != "infinite":
return None
diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py
index bf4e59fb..83dc40f1 100644
--- a/src/guidellm/request/request.py
+++ b/src/guidellm/request/request.py
@@ -1,5 +1,5 @@
import uuid
-from typing import Any, Literal, Optional
+from typing import Any, Literal
from pydantic import Field
@@ -33,7 +33,7 @@ class GenerationRequest(StandardBaseModel):
of output tokens. Used for controlling the behavior of the backend.
"""
- request_id: Optional[str] = Field(
+ request_id: str | None = Field(
default_factory=lambda: str(uuid.uuid4()),
description="The unique identifier for the request.",
)
diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py
index c724a74a..c974225a 100644
--- a/src/guidellm/scheduler/constraints.py
+++ b/src/guidellm/scheduler/constraints.py
@@ -450,7 +450,7 @@ def __call__(
current_index = max(0, self.current_index)
max_num = (
self.max_num
- if isinstance(self.max_num, (int, float))
+ if isinstance(self.max_num, int | float)
else self.max_num[min(current_index, len(self.max_num) - 1)]
)
@@ -489,7 +489,7 @@ def _validate_max_num(
raise ValueError(
f"max_num must be set and truthful, received {value} ({val} failed)"
)
- if not isinstance(val, (int, float)) or val <= 0:
+ if not isinstance(val, int | float) or val <= 0:
raise ValueError(
f"max_num must be a positive num, received {value} ({val} failed)"
)
@@ -568,7 +568,7 @@ def __call__(
current_index = max(0, self.current_index)
max_duration = (
self.max_duration
- if isinstance(self.max_duration, (int, float))
+ if isinstance(self.max_duration, int | float)
else self.max_duration[min(current_index, len(self.max_duration) - 1)]
)
@@ -607,7 +607,7 @@ def _validate_max_duration(
"max_duration must be set and truthful, "
f"received {value} ({val} failed)"
)
- if not isinstance(val, (int, float)) or val <= 0:
+ if not isinstance(val, int | float) or val <= 0:
raise ValueError(
"max_duration must be a positive num,"
f"received {value} ({val} failed)"
@@ -682,7 +682,7 @@ def __call__(
current_index = max(0, self.current_index)
max_errors = (
self.max_errors
- if isinstance(self.max_errors, (int, float))
+ if isinstance(self.max_errors, int | float)
else self.max_errors[min(current_index, len(self.max_errors) - 1)]
)
errors_exceeded = state.errored_requests >= max_errors
@@ -710,7 +710,7 @@ def _validate_max_errors(
"max_errors must be set and truthful, "
f"received {value} ({val} failed)"
)
- if not isinstance(val, (int, float)) or val <= 0:
+ if not isinstance(val, int | float) or val <= 0:
raise ValueError(
f"max_errors must be a positive num,received {value} ({val} failed)"
)
@@ -799,7 +799,7 @@ def __call__(
current_index = max(0, self.current_index)
max_error_rate = (
self.max_error_rate
- if isinstance(self.max_error_rate, (int, float))
+ if isinstance(self.max_error_rate, int | float)
else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
)
@@ -850,7 +850,7 @@ def _validate_max_error_rate(
"max_error_rate must be set and truthful, "
f"received {value} ({val} failed)"
)
- if not isinstance(val, (int, float)) or val <= 0 or val >= 1:
+ if not isinstance(val, int | float) or val <= 0 or val >= 1:
raise ValueError(
"max_error_rate must be a number between 0 and 1,"
f"received {value} ({val} failed)"
@@ -940,7 +940,7 @@ def __call__(
current_index = max(0, self.current_index)
max_error_rate = (
self.max_error_rate
- if isinstance(self.max_error_rate, (int, float))
+ if isinstance(self.max_error_rate, int | float)
else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
)
@@ -982,7 +982,7 @@ def _validate_max_error_rate(
"max_error_rate must be set and truthful, "
f"received {value} ({val} failed)"
)
- if not isinstance(val, (int, float)) or val <= 0 or val >= 1:
+ if not isinstance(val, int | float) or val <= 0 or val >= 1:
raise ValueError(
"max_error_rate must be a number between 0 and 1,"
f"received {value} ({val} failed)"
diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py
index 21d30ec8..e2583987 100644
--- a/src/guidellm/scheduler/objects.py
+++ b/src/guidellm/scheduler/objects.py
@@ -19,7 +19,6 @@
Literal,
Protocol,
TypeVar,
- Union,
runtime_checkable,
)
@@ -56,10 +55,7 @@
MultiTurnRequestT = TypeAliasType(
"MultiTurnRequestT",
- Union[
- list[Union[RequestT, tuple[RequestT, float]]],
- tuple[Union[RequestT, tuple[RequestT, float]]],
- ],
+ list[RequestT | tuple[RequestT, float]] | tuple[RequestT | tuple[RequestT, float]],
type_params=(RequestT,),
)
"""Multi-turn request structure supporting conversation history with optional delays."""
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 5f2fb74b..104ab418 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -310,7 +310,7 @@ async def _process_next_request(self):
# Pull request from the queue
request, request_info = await self.messaging.get()
- if isinstance(request, (list, tuple)):
+ if isinstance(request, list | tuple):
raise NotImplementedError("Multi-turn requests are not yet supported")
# Calculate targeted start and set pending state for request
diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py
index 6823fb77..916d6633 100644
--- a/src/guidellm/utils/encoding.py
+++ b/src/guidellm/utils/encoding.py
@@ -12,7 +12,7 @@
import json
from collections.abc import Mapping
-from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, cast
+from typing import Any, ClassVar, Generic, Literal, TypeVar, cast
try:
import msgpack # type: ignore[import-untyped] # Optional dependency
@@ -45,7 +45,6 @@
HAS_ORJSON = False
from pydantic import BaseModel
-from typing_extensions import TypeAlias
__all__ = [
"Encoder",
@@ -60,14 +59,10 @@
ObjT = TypeVar("ObjT")
MsgT = TypeVar("MsgT")
-SerializationTypesAlias: TypeAlias = Annotated[
- Optional[Literal["dict", "sequence"]],
- "Type alias for available serialization strategies",
-]
-EncodingTypesAlias: TypeAlias = Annotated[
- Optional[Literal["msgpack", "msgspec"]],
- "Type alias for available binary encoding formats",
-]
+# Type alias for available serialization strategies
+SerializationTypesAlias = Literal["dict", "sequence"] | None
+# "Type alias for available binary encoding formats"
+EncodingTypesAlias = Literal["msgpack", "msgspec"]
class MessageEncoding(Generic[ObjT, MsgT]):
@@ -405,7 +400,7 @@ def to_dict(self, obj: Any) -> Any:
if isinstance(obj, BaseModel):
return self.to_dict_pydantic(obj)
- if isinstance(obj, (list, tuple)) and any(
+ if isinstance(obj, list | tuple) and any(
isinstance(item, BaseModel) for item in obj
):
return [
@@ -432,7 +427,7 @@ def from_dict(self, data: Any) -> Any:
:param data: Dictionary representation possibly containing type metadata
:return: Reconstructed object with proper types restored
"""
- if isinstance(data, (list, tuple)):
+ if isinstance(data, list | tuple):
return [
self.from_dict_pydantic(item)
if isinstance(item, dict) and "*PYD*" in item
@@ -493,7 +488,7 @@ def to_sequence(self, obj: Any) -> str | Any:
if isinstance(obj, BaseModel):
payload_type = "pydantic"
payload = self.to_sequence_pydantic(obj)
- elif isinstance(obj, (list, tuple)) and any(
+ elif isinstance(obj, list | tuple) and any(
isinstance(item, BaseModel) for item in obj
):
payload_type = "collection_sequence"
@@ -694,33 +689,36 @@ def pack_next_sequence( # noqa: C901, PLR0912
length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1,
byteorder="big",
)
- if type_ == "pydantic":
- payload_type = b"P"
- elif type_ == "python":
- payload_type = b"p"
- elif type_ == "collection_tuple":
- payload_type = b"T"
- elif type_ == "collection_sequence":
- payload_type = b"S"
- elif type_ == "collection_mapping":
- payload_type = b"M"
- else:
- raise ValueError(f"Unknown type for packing: {type_}")
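+            # Map the payload type to its one-character bytes tag used in the packed frame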
+ match type_:
+ case "pydantic":
+ payload_type = b"P"
+ case "python":
+ payload_type = b"p"
+ case "collection_tuple":
+ payload_type = b"T"
+ case "collection_sequence":
+ payload_type = b"S"
+ case "collection_mapping":
+ payload_type = b"M"
+ case _:
+ raise ValueError(f"Unknown type for packing: {type_}")
delimiter = b"|"
else:
payload_len_output = str(payload_len)
- if type_ == "pydantic":
- payload_type = "P"
- elif type_ == "python":
- payload_type = "p"
- elif type_ == "collection_tuple":
- payload_type = "T"
- elif type_ == "collection_sequence":
- payload_type = "S"
- elif type_ == "collection_mapping":
- payload_type = "M"
- else:
- raise ValueError(f"Unknown type for packing: {type_}")
+
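+            # Same mapping, emitted as str tags for the text (non-binary) framing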
+ match type_:
+ case "pydantic":
+ payload_type = "P"
+ case "python":
+ payload_type = "p"
+ case "collection_tuple":
+ payload_type = "T"
+ case "collection_sequence":
+ payload_type = "S"
+ case "collection_mapping":
+ payload_type = "M"
+ case _:
+ raise ValueError(f"Unknown type for packing: {type_}")
delimiter = "|"
# Type ignores because types are enforced at runtime
diff --git a/src/guidellm/utils/hf_datasets.py b/src/guidellm/utils/hf_datasets.py
index 73e55ebc..86f04485 100644
--- a/src/guidellm/utils/hf_datasets.py
+++ b/src/guidellm/utils/hf_datasets.py
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import Union
from datasets import Dataset
@@ -11,7 +10,7 @@
}
-def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None:
+def save_dataset_to_file(dataset: Dataset, output_path: str | Path) -> None:
"""
Saves a HuggingFace Dataset to file in a supported format.
diff --git a/src/guidellm/utils/hf_transformers.py b/src/guidellm/utils/hf_transformers.py
index 1f2aa1b5..636988c3 100644
--- a/src/guidellm/utils/hf_transformers.py
+++ b/src/guidellm/utils/hf_transformers.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
from transformers import AutoTokenizer, PreTrainedTokenizerBase # type: ignore[import]
@@ -9,15 +9,15 @@
def check_load_processor(
- processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
- processor_args: Optional[dict[str, Any]],
+ processor: str | Path | PreTrainedTokenizerBase | None,
+ processor_args: dict[str, Any] | None,
error_msg: str,
) -> PreTrainedTokenizerBase:
if processor is None:
raise ValueError(f"Processor/Tokenizer is required for {error_msg}.")
try:
- if isinstance(processor, (str, Path)):
+ if isinstance(processor, str | Path):
loaded = AutoTokenizer.from_pretrained(
processor,
**(processor_args or {}),
diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py
index 9311259d..4dce576d 100644
--- a/src/guidellm/utils/messaging.py
+++ b/src/guidellm/utils/messaging.py
@@ -16,13 +16,13 @@
import threading
import time
from abc import ABC, abstractmethod
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
from multiprocessing.connection import Connection
from multiprocessing.context import BaseContext
from multiprocessing.managers import SyncManager
from multiprocessing.synchronize import Event as ProcessingEvent
from threading import Event as ThreadingEvent
-from typing import Any, Callable, Generic, Protocol, TypeVar, cast
+from typing import Any, Generic, Protocol, TypeVar, cast
import culsans
from pydantic import BaseModel
@@ -420,7 +420,7 @@ def _create_check_stop_callable(
stop_events = tuple(
item
for item in stop_criteria or []
- if isinstance(item, (ThreadingEvent, ProcessingEvent))
+ if isinstance(item, ThreadingEvent | ProcessingEvent)
)
stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item))
diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py
index b001ff2d..7cf28d00 100644
--- a/src/guidellm/utils/mixins.py
+++ b/src/guidellm/utils/mixins.py
@@ -91,7 +91,7 @@ def create_info_dict(cls, obj: Any) -> dict[str, Any]:
"attributes": (
{
key: val
- if isinstance(val, (str, int, float, bool, list, dict))
+ if isinstance(val, str | int | float | bool | list | dict)
else repr(val)
for key, val in obj.__dict__.items()
if not key.startswith("_")
diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py
index 55816ef1..05f5ad81 100644
--- a/src/guidellm/utils/pydantic_utils.py
+++ b/src/guidellm/utils/pydantic_utils.py
@@ -11,11 +11,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Generic, TypeVar, cast
+from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
-from typing_extensions import get_args, get_origin
from guidellm.utils.registry import RegistryMixin
diff --git a/src/guidellm/utils/random.py b/src/guidellm/utils/random.py
index ceef20b9..6c8f396d 100644
--- a/src/guidellm/utils/random.py
+++ b/src/guidellm/utils/random.py
@@ -1,6 +1,5 @@
import random
from collections.abc import Iterator
-from typing import Optional
__all__ = ["IntegerRangeSampler"]
@@ -9,9 +8,9 @@ class IntegerRangeSampler:
def __init__(
self,
average: int,
- variance: Optional[int],
- min_value: Optional[int],
- max_value: Optional[int],
+ variance: int | None,
+ min_value: int | None,
+ max_value: int | None,
random_seed: int,
):
self.average = average
diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py
index e6f1b657..e4727cbd 100644
--- a/src/guidellm/utils/registry.py
+++ b/src/guidellm/utils/registry.py
@@ -10,7 +10,8 @@
from __future__ import annotations
-from typing import Any, Callable, ClassVar, Generic, TypeVar, cast
+from collections.abc import Callable
+from typing import Any, ClassVar, Generic, TypeVar, cast
from guidellm.utils.auto_importer import AutoImporterMixin
@@ -103,7 +104,7 @@ def register_decorator(
if name is None:
name = obj.__name__
- elif not isinstance(name, (str, list)):
+ elif not isinstance(name, str | list):
raise ValueError(
"RegistryMixin.register_decorator name must be a string or "
f"an iterable of strings. Got {name}."
diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py
index acd9d4f1..f71a2c24 100644
--- a/src/guidellm/utils/statistics.py
+++ b/src/guidellm/utils/statistics.py
@@ -149,7 +149,7 @@ def from_distribution_function(
in the output
:return: DistributionSummary instance with calculated statistical metrics
"""
- values, weights = zip(*distribution) if distribution else ([], [])
+ values, weights = zip(*distribution, strict=True) if distribution else ([], [])
values = np.array(values) # type: ignore[assignment]
weights = np.array(weights) # type: ignore[assignment]
@@ -247,7 +247,7 @@ def from_values(
)
return DistributionSummary.from_distribution_function(
- distribution=list(zip(values, weights)),
+ distribution=list(zip(values, weights, strict=True)),
include_cdf=include_cdf,
)
@@ -389,7 +389,7 @@ def from_iterable_request_times(
events[global_end] = 0
for (_, end), first_iter, first_iter_count, total_count in zip(
- requests, first_iter_times, first_iter_counts, iter_counts
+ requests, first_iter_times, first_iter_counts, iter_counts, strict=True
):
events[first_iter] += first_iter_count
@@ -499,36 +499,36 @@ def from_values(
)
_, successful_values, successful_weights = (
- zip(*successful)
+ zip(*successful, strict=True)
if (
successful := list(
filter(
lambda val: val[0] == "successful",
- zip(value_types, values, weights),
+ zip(value_types, values, weights, strict=True),
)
)
)
else ([], [], [])
)
_, incomplete_values, incomplete_weights = (
- zip(*incomplete)
+ zip(*incomplete, strict=True)
if (
incomplete := list(
filter(
lambda val: val[0] == "incomplete",
- zip(value_types, values, weights),
+ zip(value_types, values, weights, strict=True),
)
)
)
else ([], [], [])
)
_, errored_values, errored_weights = (
- zip(*errored)
+ zip(*errored, strict=True)
if (
errored := list(
filter(
lambda val: val[0] == "error",
- zip(value_types, values, weights),
+ zip(value_types, values, weights, strict=True),
)
)
)
@@ -604,36 +604,36 @@ def from_request_times(
)
_, successful_requests = (
- zip(*successful)
+ zip(*successful, strict=True)
if (
successful := list(
filter(
lambda val: val[0] == "successful",
- zip(request_types, requests),
+ zip(request_types, requests, strict=True),
)
)
)
else ([], [])
)
_, incomplete_requests = (
- zip(*incomplete)
+ zip(*incomplete, strict=True)
if (
incomplete := list(
filter(
lambda val: val[0] == "incomplete",
- zip(request_types, requests),
+ zip(request_types, requests, strict=True),
)
)
)
else ([], [])
)
_, errored_requests = (
- zip(*errored)
+ zip(*errored, strict=True)
if (
errored := list(
filter(
lambda val: val[0] == "error",
- zip(request_types, requests),
+ zip(request_types, requests, strict=True),
)
)
)
@@ -734,7 +734,7 @@ def from_iterable_request_times(
successful_iter_counts,
successful_first_iter_counts,
) = (
- zip(*successful)
+ zip(*successful, strict=True)
if (
successful := list(
filter(
@@ -745,6 +745,7 @@ def from_iterable_request_times(
first_iter_times,
iter_counts,
first_iter_counts,
+ strict=True,
),
)
)
@@ -758,7 +759,7 @@ def from_iterable_request_times(
incomplete_iter_counts,
incomplete_first_iter_counts,
) = (
- zip(*incomplete)
+ zip(*incomplete, strict=True)
if (
incomplete := list(
filter(
@@ -769,6 +770,7 @@ def from_iterable_request_times(
first_iter_times,
iter_counts,
first_iter_counts,
+ strict=True,
),
)
)
@@ -782,7 +784,7 @@ def from_iterable_request_times(
errored_iter_counts,
errored_first_iter_counts,
) = (
- zip(*errored)
+ zip(*errored, strict=True)
if (
errored := list(
filter(
@@ -793,6 +795,7 @@ def from_iterable_request_times(
first_iter_times,
iter_counts,
first_iter_counts,
+ strict=True,
),
)
)
@@ -904,7 +907,7 @@ def __add__(self, value: Any) -> float:
:return: Updated mean after adding the value
:raises ValueError: If value is not numeric (int or float)
"""
- if not isinstance(value, (int, float)):
+ if not isinstance(value, int | float):
raise ValueError(
f"Value must be an int or float, got {type(value)} instead.",
)
@@ -921,7 +924,7 @@ def __iadd__(self, value: Any) -> RunningStats:
:return: Self reference for method chaining
:raises ValueError: If value is not numeric (int or float)
"""
- if not isinstance(value, (int, float)):
+ if not isinstance(value, int | float):
raise ValueError(
f"Value must be an int or float, got {type(value)} instead.",
)
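For context on the repeated `strict=True` additions in this file: since Python 3.10, `zip(..., strict=True)` raises `ValueError` on a length mismatch instead of silently truncating to the shortest iterable. A short sketch with made-up data:

    values = [1.0, 2.0, 3.0]
    weights = [0.5, 0.25]  # deliberately one element short

    # Default zip silently drops the unmatched trailing element.
    assert list(zip(values, weights)) == [(1.0, 0.5), (2.0, 0.25)]

    # strict=True surfaces the mismatch as an error instead.
    try:
        list(zip(values, weights, strict=True))
    except ValueError:
        print("length mismatch detected")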
diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py
index 64c14e94..d37daec2 100644
--- a/src/guidellm/utils/synchronous.py
+++ b/src/guidellm/utils/synchronous.py
@@ -16,9 +16,6 @@
from multiprocessing.synchronize import Event as ProcessingEvent
from threading import Barrier as ThreadingBarrier
from threading import Event as ThreadingEvent
-from typing import Annotated, Union
-
-from typing_extensions import TypeAlias
__all__ = [
"SyncObjectTypesAlias",
@@ -28,10 +25,10 @@
]
-SyncObjectTypesAlias: TypeAlias = Annotated[
- Union[ThreadingEvent, ProcessingEvent, ThreadingBarrier, ProcessingBarrier],
- "Type alias for threading and multiprocessing synchronization object types",
-]
+# Type alias for threading and multiprocessing synchronization object types
+SyncObjectTypesAlias = (
+ ThreadingEvent | ProcessingEvent | ThreadingBarrier | ProcessingBarrier
+)
async def wait_for_sync_event(
@@ -146,7 +143,7 @@ async def wait_for_sync_objects(
tasks = [
asyncio.create_task(
wait_for_sync_barrier(obj, poll_interval)
- if isinstance(obj, (ThreadingBarrier, ProcessingBarrier))
+ if isinstance(obj, ThreadingBarrier | ProcessingBarrier)
else wait_for_sync_event(obj, poll_interval)
)
for obj in objects
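The rewritten alias above is now a plain runtime union: on Python 3.10+, `A | B` evaluates to a `types.UnionType` whose members can be read back with `typing.get_args`, which is what the updated test further below relies on. A minimal sketch using only standard-library types:

    from threading import Barrier, Event
    from types import UnionType
    from typing import get_args

    SyncTypes = Event | Barrier  # runtime union object, not a string annotation

    assert isinstance(SyncTypes, UnionType)
    assert set(get_args(SyncTypes)) == {Event, Barrier}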
diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py
index 8146ea1e..8d3580ef 100644
--- a/src/guidellm/utils/typing.py
+++ b/src/guidellm/utils/typing.py
@@ -1,14 +1,9 @@
from __future__ import annotations
from collections.abc import Iterator
+from types import UnionType
from typing import Annotated, Literal, Union, get_args, get_origin
-# Backwards compatibility for Python <3.10
-try:
- from types import UnionType # type: ignore[attr-defined]
-except ImportError:
- UnionType = Union
-
# Backwards compatibility for Python <3.12
try:
from typing import TypeAliasType # type: ignore[attr-defined]
diff --git a/tests/integration/scheduler/test_scheduler.py b/tests/integration/scheduler/test_scheduler.py
index 51abf59b..65bff95f 100644
--- a/tests/integration/scheduler/test_scheduler.py
+++ b/tests/integration/scheduler/test_scheduler.py
@@ -167,7 +167,7 @@ def _request_indices():
_request_indices(),
received_updates.keys(),
received_updates.values(),
- received_responses,
+ received_responses, strict=False,
):
assert req == f"req_{index}"
assert resp in (f"response_for_{req}", f"mock_error_for_{req}")
diff --git a/tests/unit/benchmark/test_output.py b/tests/unit/benchmark/test_output.py
index 6763d978..6310da88 100644
--- a/tests/unit/benchmark/test_output.py
+++ b/tests/unit/benchmark/test_output.py
@@ -10,7 +10,10 @@
from guidellm.benchmark import (
GenerativeBenchmarksReport,
)
-from guidellm.benchmark.output import GenerativeBenchmarkerConsole, GenerativeBenchmarkerCSV
+from guidellm.benchmark.output import (
+ GenerativeBenchmarkerConsole,
+ GenerativeBenchmarkerCSV,
+)
from tests.unit.mock_benchmark import mock_generative_benchmark
@@ -80,6 +83,7 @@ def test_file_yaml():
mock_path.unlink()
+
@pytest.mark.asyncio
async def test_file_csv():
mock_benchmark = mock_generative_benchmark()
@@ -89,7 +93,7 @@ async def test_file_csv():
csv_benchmarker = GenerativeBenchmarkerCSV(output_path=mock_path)
await csv_benchmarker.finalize(report)
- with mock_path.open("r") as file:
+ with mock_path.open("r") as file: # noqa: ASYNC230 # This is a test.
reader = csv.reader(file)
headers = next(reader)
rows = list(reader)
@@ -105,7 +109,8 @@ def test_console_benchmarks_profile_str():
console = GenerativeBenchmarkerConsole()
mock_benchmark = mock_generative_benchmark()
assert (
- console._get_profile_str(mock_benchmark) == "type=synchronous, strategies=['synchronous']"
+ console._get_profile_str(mock_benchmark)
+ == "type=synchronous, strategies=['synchronous']"
)
diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py
index e3110fa3..544634c8 100644
--- a/tests/unit/dataset/test_synthetic.py
+++ b/tests/unit/dataset/test_synthetic.py
@@ -530,7 +530,7 @@ def mock_sampler_side_effect(*args, **kwargs):
# Results should be identical with same seed
assert len(items1) == len(items2)
- for item1, item2 in zip(items1, items2):
+ for item1, item2 in zip(items1, items2, strict=False):
assert item1["prompt"] == item2["prompt"]
assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"]
assert item1["output_tokens_count"] == item2["output_tokens_count"]
diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py
index 5ac069a8..3b7237e0 100644
--- a/tests/unit/mock_backend.py
+++ b/tests/unit/mock_backend.py
@@ -6,7 +6,7 @@
import random
import time
from collections.abc import AsyncIterator
-from typing import Any, Optional
+from typing import Any
from lorem.text import TextLorem
@@ -32,7 +32,7 @@ def __init__(
self,
target: str = "mock-target",
model: str = "mock-model",
- iter_delay: Optional[float] = None,
+ iter_delay: float | None = None,
):
"""
Initialize mock backend.
@@ -53,7 +53,7 @@ def target(self) -> str:
return self._target
@property
- def model(self) -> Optional[str]:
+ def model(self) -> str | None:
"""Model name for the mock backend."""
return self._model
@@ -87,7 +87,7 @@ async def validate(self) -> None:
if not self._in_process:
raise RuntimeError("Backend not started up for process")
- async def default_model(self) -> Optional[str]:
+ async def default_model(self) -> str | None:
"""
Return the default model for the mock backend.
"""
@@ -97,7 +97,7 @@ async def resolve(
self,
request: GenerationRequest,
request_info: ScheduledRequestInfo,
- history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
+ history: list[tuple[GenerationRequest, GenerationResponse]] | None = None,
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
"""
Process a generation request and yield progressive responses.
@@ -170,7 +170,7 @@ def _estimate_prompt_tokens(content: str) -> int:
return len(str(content).split())
@staticmethod
- def _get_tokens(token_count: Optional[int] = None) -> list[str]:
+ def _get_tokens(token_count: int | None = None) -> list[str]:
"""
Generate mock tokens for response.
"""
diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py
index cdf4375a..d7bfe7c9 100644
--- a/tests/unit/mock_benchmark.py
+++ b/tests/unit/mock_benchmark.py
@@ -1,4 +1,5 @@
"""Mock benchmark objects for unit testing."""
+
from guidellm.backends import GenerationRequestTimings
from guidellm.benchmark import (
BenchmarkSchedulerStats,
diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py
index 008103c3..ba712fb6 100644
--- a/tests/unit/mock_server/test_server.py
+++ b/tests/unit/mock_server/test_server.py
@@ -162,7 +162,7 @@ async def test_health_endpoint(self, mock_server_instance):
assert "status" in data
assert data["status"] == "healthy"
assert "timestamp" in data
- assert isinstance(data["timestamp"], (int, float))
+ assert isinstance(data["timestamp"], int | float)
@pytest.mark.smoke
@pytest.mark.asyncio
diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py
index fc5610fd..2fc4c86f 100644
--- a/tests/unit/scheduler/test_objects.py
+++ b/tests/unit/scheduler/test_objects.py
@@ -3,6 +3,7 @@
import inspect
import typing
from collections.abc import AsyncIterator
+from types import UnionType
from typing import Any, Literal, Optional, TypeVar, Union
import pytest
@@ -62,8 +63,7 @@ def test_multi_turn_request_t():
assert MultiTurnRequestT.__name__ == "MultiTurnRequestT"
value = MultiTurnRequestT.__value__
- assert hasattr(value, "__origin__")
- assert value.__origin__ is Union
+ assert isinstance(value, UnionType)
type_params = getattr(MultiTurnRequestT, "__type_params__", ())
assert len(type_params) == 1
@@ -340,7 +340,7 @@ def test_class_signatures(self):
for key in self.CHECK_KEYS:
assert key in fields
field_info = fields[key]
- assert field_info.annotation in (Union[float, None], Optional[float])
+ assert field_info.annotation in (Union[float, None], Optional[float]) # noqa: UP007
assert field_info.default is None
@pytest.mark.smoke
@@ -453,7 +453,7 @@ def test_class_signatures(self):
for key in self.CHECK_KEYS:
assert key in fields
field_info = fields[key]
- assert field_info.annotation in (Union[float, None], Optional[float])
+ assert field_info.annotation in (Union[float, None], Optional[float]) # noqa: UP007
assert field_info.default is None
@pytest.mark.smoke
@@ -704,11 +704,11 @@ def test_marshalling(self, valid_instances):
else:
assert original_value is None or isinstance(
original_value,
- (RequestSchedulerTimings, MeasuredRequestTimings),
+ RequestSchedulerTimings | MeasuredRequestTimings,
)
assert reconstructed_value is None or isinstance(
reconstructed_value,
- (RequestSchedulerTimings, MeasuredRequestTimings),
+ RequestSchedulerTimings | MeasuredRequestTimings,
)
else:
assert original_value == reconstructed_value
diff --git a/tests/unit/scheduler/test_strategies.py b/tests/unit/scheduler/test_strategies.py
index 67a2d77d..143a3130 100644
--- a/tests/unit/scheduler/test_strategies.py
+++ b/tests/unit/scheduler/test_strategies.py
@@ -225,7 +225,7 @@ def test_lifecycle(
for index in range(max(5, startup_requests + 2)):
offset = instance.next_offset()
- assert isinstance(offset, (int, float))
+ assert isinstance(offset, int | float)
if index < startup_requests:
expected_offset = initial_offset + (index + 1) * startup_delay
diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py
index cc4600cf..5664bcb0 100644
--- a/tests/unit/utils/test_encoding.py
+++ b/tests/unit/utils/test_encoding.py
@@ -476,7 +476,7 @@ def test_to_from_sequence_collections(self, collection):
seq = inst.to_sequence(collection)
out = inst.from_sequence(seq)
assert len(out) == len(collection)
- assert all(a == b for a, b in zip(out, list(collection)))
+ assert all(a == b for a, b in zip(out, list(collection), strict=False))
@pytest.mark.sanity
def test_to_from_sequence_mapping(self):
diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py
index 1a9ea2c9..7acd5b4a 100644
--- a/tests/unit/utils/test_synchronous.py
+++ b/tests/unit/utils/test_synchronous.py
@@ -6,7 +6,7 @@
from functools import wraps
from multiprocessing.synchronize import Barrier as ProcessingBarrier
from multiprocessing.synchronize import Event as ProcessingEvent
-from typing import Union
+from typing import get_args
import pytest
@@ -32,17 +32,25 @@ async def new_func(*args, **kwargs):
def test_sync_object_types_alias():
- """Test that SyncObjectTypesAlias is defined correctly as a type alias."""
- assert hasattr(SyncObjectTypesAlias, "__origin__")
- if hasattr(SyncObjectTypesAlias, "__args__"):
- actual_type = SyncObjectTypesAlias.__args__[0]
- assert hasattr(actual_type, "__origin__")
- assert actual_type.__origin__ is Union
- union_args = actual_type.__args__
- assert threading.Event in union_args
- assert ProcessingEvent in union_args
- assert threading.Barrier in union_args
- assert ProcessingBarrier in union_args
+ """
+ Test that SyncObjectTypesAlias is defined correctly as a type alias.
+
+ ### WRITTEN BY AI ###
+ """
+ # Get the actual types from the union alias
+ actual_types = get_args(SyncObjectTypesAlias)
+
+ # Define the set of expected types
+ expected_types = {
+ threading.Event,
+ ProcessingEvent,
+ threading.Barrier,
+ ProcessingBarrier,
+ }
+
+ # Assert that the actual union members match the expected set exactly;
+ # comparing as sets ignores ordering.
+ assert set(actual_types) == expected_types
class TestWaitForSyncEvent:
@@ -226,7 +234,7 @@ async def test_invocation(self, objects_types, expected_result):
async def set_target():
await asyncio.sleep(0.01)
obj = objects[expected_result]
- if isinstance(obj, (threading.Event, ProcessingEvent)):
+ if isinstance(obj, threading.Event | ProcessingEvent):
obj.set()
else:
await asyncio.to_thread(obj.wait)
diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py
index fafa8765..1e31ef8e 100644
--- a/tests/unit/utils/test_typing.py
+++ b/tests/unit/utils/test_typing.py
@@ -2,10 +2,9 @@
Test suite for the typing utilities module.
"""
-from typing import Annotated, Literal, Union
+from typing import Annotated, Literal, TypeAlias
import pytest
-from typing_extensions import TypeAlias
from guidellm.utils.typing import get_literal_vals
@@ -15,7 +14,7 @@
Literal["synchronous", "concurrent", "throughput", "constant", "poisson"],
"Valid strategy type identifiers for scheduling request patterns",
]
-StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType]
+StrategyProfileType: TypeAlias = LocalStrategyType | LocalProfileType
class TestGetLiteralVals:
@@ -54,7 +53,7 @@ def test_inline_union_type(self):
### WRITTEN BY AI ###
"""
- result = get_literal_vals(Union[LocalProfileType, LocalStrategyType])
+ result = get_literal_vals(LocalProfileType | LocalStrategyType)
expected = frozenset(
{
"synchronous",
@@ -118,6 +117,6 @@ def test_literal_union(self):
### WRITTEN BY AI ###
"""
- result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]])
+ result = get_literal_vals(Literal["test", "test2"] | Literal["test3"])
expected = frozenset({"test", "test2", "test3"})
assert result == expected
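Closing note on the `get_literal_vals` tests above: `Literal[...]` forms also compose with `|`, producing a `typing.Union` whose literal values can be recovered with `typing.get_args`. The project helper itself is not reproduced here; this is only a hedged sketch of the underlying mechanics:

    from typing import Literal, get_args

    Color = Literal["red", "green"] | Literal["blue"]  # illustrative alias

    # get_args on the union yields the Literal members; get_args on each
    # Literal member yields its string values.
    values = {val for lit in get_args(Color) for val in get_args(lit)}
    assert values == {"red", "green", "blue"}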