2 changes: 1 addition & 1 deletion pyproject.toml
@@ -167,7 +167,7 @@ ignore_missing_imports = true
target-version = "py310"
line-length = 88
indent-width = 4
exclude = ["build", "dist", "env", ".venv"]
exclude = ["build", "dist", "env", ".venv*"]

[tool.ruff.format]
quote-style = "double"
7 changes: 3 additions & 4 deletions setup.py
@@ -1,7 +1,6 @@
import os
import re
from pathlib import Path
from typing import Optional, Union

from packaging.version import Version
from setuptools import setup
@@ -11,7 +10,7 @@
TAG_VERSION_PATTERN = re.compile(r"^v(\d+\.\d+\.\d+)$")


def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]:
def get_last_version_diff() -> tuple[Version, str | None, int | None]:
"""
Get the last version, last tag, and the number of commits since the last tag.
If no tags are found, return the last release version and None for the tag/commits.
@@ -38,8 +37,8 @@ def get_last_version_diff() -> tuple[Version, Optional[str], Optional[int]]:


def get_next_version(
build_type: str, build_iteration: Optional[Union[str, int]]
) -> tuple[Version, Optional[str], int]:
build_type: str, build_iteration: str | int | None
) -> tuple[Version, str | None, int]:
"""
Get the next version based on the build type and iteration.
- build_type == release: take the last version and add a post if build iteration
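For reference, a minimal sketch of the typing migration this PR applies in setup.py and the other files below: PEP 604 union syntax, available on Python 3.10+ (consistent with the ruff target-version = "py310" shown above). The function below is illustrative only, not setup.py's actual helper.

from typing import Optional, Union  # old-style imports, no longer needed after this PR

def next_version_old(build_iteration: Optional[Union[str, int]]) -> Optional[str]:
    """Old-style annotations using typing.Optional / typing.Union."""
    return str(build_iteration) if build_iteration is not None else None

def next_version_new(build_iteration: str | int | None) -> str | None:
    """Equivalent signature using built-in | unions (PEP 604)."""
    return str(build_iteration) if build_iteration is not None else None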
7 changes: 3 additions & 4 deletions src/guidellm/__main__.py
@@ -28,7 +28,7 @@
import asyncio
import codecs
from pathlib import Path
from typing import Annotated, Union
from typing import Annotated

import click
from pydantic import ValidationError
@@ -78,9 +78,8 @@
"run",
]

STRATEGY_PROFILE_CHOICES: Annotated[
list[str], "Available strategy and profile choices for benchmark execution types"
] = list(get_literal_vals(Union[ProfileType, StrategyType]))
# Available strategy and profile choices for benchmark execution types
STRATEGY_PROFILE_CHOICES: list[str] = list(get_literal_vals(ProfileType | StrategyType))


def decode_escaped_str(_ctx, _param, value):
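get_literal_vals is a guidellm helper whose implementation isn't shown in this diff; as a hedged sketch, a union of Literal aliases built with | can be introspected with typing.get_args along these lines (placeholder aliases and values, not the real ProfileType/StrategyType).

from typing import Literal, get_args, get_origin

ProfileType = Literal["sweep", "concurrent"]        # placeholder values
StrategyType = Literal["synchronous", "poisson"]    # placeholder values

def literal_vals(tp) -> list[str]:
    """Flatten the literal values out of a Literal alias or a union of them."""
    if get_origin(tp) is Literal:
        return list(get_args(tp))
    # get_args() yields the union's members for both typing.Union and X | Y unions
    return [val for member in get_args(tp) for val in literal_vals(member)]

print(literal_vals(ProfileType | StrategyType))
# e.g. ['sweep', 'concurrent', 'synchronous', 'poisson']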
28 changes: 14 additions & 14 deletions src/guidellm/backends/objects.py
@@ -7,7 +7,7 @@
"""

import uuid
from typing import Any, Literal, Optional
from typing import Any, Literal

from pydantic import Field

@@ -73,48 +73,48 @@ class GenerationResponse(StandardBaseModel):
request_args: dict[str, Any] = Field(
description="Arguments passed to the backend for this request."
)
value: Optional[str] = Field(
value: str | None = Field(
default=None,
description="Complete generated text content. None for streaming responses.",
)
delta: Optional[str] = Field(
delta: str | None = Field(
default=None, description="Incremental text content for streaming responses."
)
iterations: int = Field(
default=0, description="Number of generation iterations completed."
)
request_prompt_tokens: Optional[int] = Field(
request_prompt_tokens: int | None = Field(
default=None, description="Token count from the original request prompt."
)
request_output_tokens: Optional[int] = Field(
request_output_tokens: int | None = Field(
default=None,
description="Expected output token count from the original request.",
)
response_prompt_tokens: Optional[int] = Field(
response_prompt_tokens: int | None = Field(
default=None, description="Actual prompt token count reported by the backend."
)
response_output_tokens: Optional[int] = Field(
response_output_tokens: int | None = Field(
default=None, description="Actual output token count reported by the backend."
)

@property
def prompt_tokens(self) -> Optional[int]:
def prompt_tokens(self) -> int | None:
"""
:return: The number of prompt tokens used in the request
(response_prompt_tokens if available, otherwise request_prompt_tokens).
"""
return self.response_prompt_tokens or self.request_prompt_tokens

@property
def output_tokens(self) -> Optional[int]:
def output_tokens(self) -> int | None:
"""
:return: The number of output tokens generated in the response
(response_output_tokens if available, otherwise request_output_tokens).
"""
return self.response_output_tokens or self.request_output_tokens

@property
def total_tokens(self) -> Optional[int]:
def total_tokens(self) -> int | None:
"""
:return: The total number of tokens used in the request and response.
Sum of prompt_tokens and output_tokens.
@@ -125,15 +125,15 @@ def total_tokens(self) -> Optional[int]:

def preferred_prompt_tokens(
self, preferred_source: Literal["request", "response"]
) -> Optional[int]:
) -> int | None:
if preferred_source == "request":
return self.request_prompt_tokens or self.response_prompt_tokens
else:
return self.response_prompt_tokens or self.request_prompt_tokens

def preferred_output_tokens(
self, preferred_source: Literal["request", "response"]
) -> Optional[int]:
) -> int | None:
if preferred_source == "request":
return self.request_output_tokens or self.response_output_tokens
else:
@@ -146,11 +146,11 @@ class GenerationRequestTimings(MeasuredRequestTimings):
"""Timing model for tracking generation request lifecycle events."""

timings_type: Literal["generation_request_timings"] = "generation_request_timings"
first_iteration: Optional[float] = Field(
first_iteration: float | None = Field(
default=None,
description="Unix timestamp when the first generation iteration began.",
)
last_iteration: Optional[float] = Field(
last_iteration: float | None = Field(
default=None,
description="Unix timestamp when the last generation iteration completed.",
)
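The objects.py changes rely on Pydantic v2 accepting PEP 604 unions directly in field annotations, so Optional[...] is no longer needed. A minimal standalone sketch (not guidellm code) of the migrated field style:

from pydantic import BaseModel, Field

class TimingsSketch(BaseModel):
    """Toy model mirroring the `float | None` fields above."""

    first_iteration: float | None = Field(
        default=None, description="Unix timestamp when the first iteration began."
    )
    last_iteration: float | None = Field(
        default=None, description="Unix timestamp when the last iteration completed."
    )

print(TimingsSketch().first_iteration)  # None
print(TimingsSketch(last_iteration=1700000000.0).last_iteration)  # 1700000000.0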
78 changes: 36 additions & 42 deletions src/guidellm/backends/openai.py
@@ -17,7 +17,7 @@
import time
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Any, ClassVar, Optional, Union
from typing import Any, ClassVar

import httpx
from PIL import Image
@@ -33,13 +33,15 @@

__all__ = ["OpenAIHTTPBackend", "UsageStats"]

ContentT = str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any


@dataclasses.dataclass
class UsageStats:
"""Token usage statistics for generation requests."""

prompt_tokens: Optional[int] = None
output_tokens: Optional[int] = None
prompt_tokens: int | None = None
output_tokens: int | None = None


@Backend.register("openai_http")
@@ -78,19 +80,19 @@ class OpenAIHTTPBackend(Backend):
def __init__(
self,
target: str,
model: Optional[str] = None,
api_key: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
model: str | None = None,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float = 60.0,
http2: bool = True,
follow_redirects: bool = True,
max_output_tokens: Optional[int] = None,
max_output_tokens: int | None = None,
stream_response: bool = True,
extra_query: Optional[dict] = None,
extra_body: Optional[dict] = None,
remove_from_body: Optional[list[str]] = None,
headers: Optional[dict] = None,
extra_query: dict | None = None,
extra_body: dict | None = None,
remove_from_body: list[str] | None = None,
headers: dict | None = None,
verify: bool = False,
):
"""
@@ -137,7 +139,7 @@ def __init__(

# Runtime state
self._in_process = False
self._async_client: Optional[httpx.AsyncClient] = None
self._async_client: httpx.AsyncClient | None = None

@property
def info(self) -> dict[str, Any]:
@@ -264,7 +266,7 @@ async def available_models(self) -> list[str]:

return [item["id"] for item in response.json()["data"]]

async def default_model(self) -> Optional[str]:
async def default_model(self) -> str | None:
"""
Get the default model for this backend.

@@ -280,7 +282,7 @@ async def resolve(
self,
request: GenerationRequest,
request_info: ScheduledRequestInfo,
history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
history: list[tuple[GenerationRequest, GenerationResponse]] | None = None,
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
"""
Process a generation request and yield progressive responses.
@@ -363,12 +365,12 @@ async def resolve(

async def text_completions(
self,
prompt: Union[str, list[str]],
request_id: Optional[str], # noqa: ARG002
output_token_count: Optional[int] = None,
prompt: str | list[str],
request_id: str | None, # noqa: ARG002
output_token_count: int | None = None,
stream_response: bool = True,
**kwargs,
) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]:
) -> AsyncIterator[tuple[str | None, UsageStats | None]]:
"""
Generate text completions using the /v1/completions endpoint.

@@ -431,17 +433,13 @@ async def text_completions(

async def chat_completions(
self,
content: Union[
str,
list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]],
Any,
],
request_id: Optional[str] = None, # noqa: ARG002
output_token_count: Optional[int] = None,
content: ContentT,
request_id: str | None = None, # noqa: ARG002
output_token_count: int | None = None,
raw_content: bool = False,
stream_response: bool = True,
**kwargs,
) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]:
) -> AsyncIterator[tuple[str | None, UsageStats | None]]:
"""
Generate chat completions using the /v1/chat/completions endpoint.

@@ -502,10 +500,10 @@ async def chat_completions(

def _build_headers(
self,
api_key: Optional[str],
organization: Optional[str],
project: Optional[str],
user_headers: Optional[dict],
api_key: str | None,
organization: str | None,
project: str | None,
user_headers: dict | None,
) -> dict[str, str]:
headers = {}

@@ -541,11 +539,7 @@ def _get_params(self, endpoint_type: str) -> dict[str, str]:

def _get_chat_messages(
self,
content: Union[
str,
list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]],
Any,
],
content: ContentT,
) -> list[dict[str, Any]]:
if isinstance(content, str):
return [{"role": "user", "content": content}]
@@ -559,15 +553,15 @@ def _get_chat_messages(
resolved_content.append(item)
elif isinstance(item, str):
resolved_content.append({"type": "text", "text": item})
elif isinstance(item, (Image.Image, Path)):
elif isinstance(item, Image.Image | Path):
resolved_content.append(self._get_chat_message_media_item(item))
else:
raise ValueError(f"Unsupported content item type: {type(item)}")

return [{"role": "user", "content": resolved_content}]

def _get_chat_message_media_item(
self, item: Union[Path, Image.Image]
self, item: Path | Image.Image
) -> dict[str, Any]:
if isinstance(item, Image.Image):
encoded = base64.b64encode(item.tobytes()).decode("utf-8")
@@ -597,8 +591,8 @@ def _get_chat_message_media_item(
def _get_body(
self,
endpoint_type: str,
request_kwargs: Optional[dict[str, Any]],
max_output_tokens: Optional[int] = None,
request_kwargs: dict[str, Any] | None,
max_output_tokens: int | None = None,
**kwargs,
) -> dict[str, Any]:
# Start with endpoint-specific extra body parameters
@@ -628,7 +622,7 @@ def _get_body(

return {key: val for key, val in body.items() if val is not None}

def _get_completions_text_content(self, data: dict) -> Optional[str]:
def _get_completions_text_content(self, data: dict) -> str | None:
if not data.get("choices"):
return None

@@ -639,7 +633,7 @@ def _get_completions_text_content(self, data: dict) -> Optional[str]:
or choice.get("message", {}).get("content")
)

def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]:
def _get_completions_usage_stats(self, data: dict) -> UsageStats | None:
if not data.get("usage"):
return None

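Two Python 3.10+ features the openai.py changes lean on: a module-level alias (ContentT) can replace the long inline Union, and isinstance() accepts | unions of runtime types, as in isinstance(item, Image.Image | Path). An illustrative sketch with trimmed-down stand-ins, not guidellm's actual types:

from pathlib import Path

ContentItem = str | dict[str, str] | Path   # simplified stand-in for ContentT

def describe(item: ContentItem) -> str:
    """Loosely mirrors how _get_chat_messages branches on item type."""
    if isinstance(item, str | Path):         # equivalent to isinstance(item, (str, Path))
        return f"text or file: {item}"
    if isinstance(item, dict):
        return f"structured item with keys: {sorted(item)}"
    raise ValueError(f"Unsupported content item type: {type(item)}")

print(describe("hello"))
print(describe({"type": "text", "text": "hello"}))
print(describe(Path("prompt.txt")))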