Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/celeste/modalities/images/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from celeste.types import ImageContent

from ...client import ImagesClient
from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage
from ...io import ImageFinishReason, ImageInput, ImageOutput
from ...parameters import ImageParameters
from .gemini import GeminiImagesClient
from .imagen import ImagenImagesClient
Expand Down Expand Up @@ -81,7 +81,9 @@ def _build_request(
) -> dict[str, Any]:
return self._strategy._build_request(inputs, **parameters) # type: ignore[union-attr]

def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage:
def _parse_usage(
self, response_data: dict[str, Any]
) -> dict[str, int | float | None]:
return self._strategy._parse_usage(response_data) # type: ignore[union-attr]

def _parse_content(
Expand Down
9 changes: 6 additions & 3 deletions src/celeste/modalities/images/providers/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Any, Unpack

from celeste.artifacts import ImageArtifact
from celeste.core import UsageField
from celeste.mime_types import ImageMimeType
from celeste.parameters import ParameterMapper
from celeste.providers.google.generate_content import config as google_config
from celeste.providers.google.generate_content.client import GoogleGenerateContentClient
from celeste.types import ImageContent

from ...client import ImagesClient
from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage
from ...io import ImageFinishReason, ImageInput, ImageOutput
from ...parameters import ImageParameters
from .parameters import GEMINI_PARAMETER_MAPPERS

Expand Down Expand Up @@ -89,11 +90,13 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
},
}

def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage:
def _parse_usage(
self, response_data: dict[str, Any]
) -> dict[str, int | float | None]:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
candidates = response_data.get("candidates", [])
return ImageUsage(**usage, num_images=len(candidates))
return {**usage, UsageField.NUM_IMAGES: len(candidates)}

def _parse_content(
self,
Expand Down
7 changes: 4 additions & 3 deletions src/celeste/modalities/images/providers/ollama/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ImageChunk,
ImageInput,
ImageOutput,
ImageUsage,
)
from ...parameters import ImageParameters
from ...streaming import ImagesStream
Expand Down Expand Up @@ -75,12 +74,14 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
"""Build request with prompt."""
return {"prompt": inputs.prompt}

def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage:
def _parse_usage(
self, response_data: dict[str, Any]
) -> dict[str, int | float | None]:
"""Parse usage from response.

Ollama image generation doesn't return usage metrics.
"""
return ImageUsage()
return {}

def _parse_content(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ from ...io import (
{Modality}FinishReason,
{Modality}Input,
{Modality}Output,
{Modality}Usage,
)
from ...parameters import {Modality}Parameters
from ...streaming import {Modality}Stream
Expand Down Expand Up @@ -61,11 +60,6 @@ class {Provider}{Modality}Client({Provider}{Api}Mixin, {Modality}Client):
**parameters,
)

def _parse_usage(self, response_data: dict[str, Any]) -> {Modality}Usage:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
return {Modality}Usage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand Down
44 changes: 42 additions & 2 deletions tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ def parameter_mappers(cls) -> list[ParameterMapper]:
def _init_request(self, inputs: _TestInput) -> dict[str, Any]:
return {"prompt": inputs.prompt, "model": self.model.id}

def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
return Usage()
def _parse_usage(
self, response_data: dict[str, Any]
) -> dict[str, int | float | None]:
return {}

def _parse_content( # type: ignore[override]
self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
Expand Down Expand Up @@ -326,3 +328,41 @@ def test_stream_raises_not_supported_for_non_streaming_model(
error_msg = str(exc_info.value)
assert "Streaming not supported" in error_msg
assert "non-streaming-model" in error_msg


class TestGetUsageContract:
"""Test that _get_usage correctly wraps _parse_usage dict into typed Usage."""

async def test_parse_usage_returns_dict_not_typed_object(
self, text_model: Model, api_key: str
) -> None:
"""Regression: _parse_usage must return a dict, not a typed Usage object.

If _parse_usage returns a Usage object, _get_usage will crash with:
'Usage() argument after ** must be a mapping, not Usage'
"""
client = ConcreteModalityClient(
modality=Modality.TEXT,
model=text_model,
provider=text_model.provider,
auth=APIKey(secret=SecretStr(api_key)),
)

raw = client._parse_usage({"some": "data"})
assert isinstance(raw, dict), (
f"_parse_usage must return a dict, got {type(raw).__name__}"
)

async def test_get_usage_wraps_dict_into_typed_usage(
self, text_model: Model, api_key: str
) -> None:
"""_get_usage must convert the raw dict from _parse_usage into typed Usage."""
client = ConcreteModalityClient(
modality=Modality.TEXT,
model=text_model,
provider=text_model.provider,
auth=APIKey(secret=SecretStr(api_key)),
)

usage = client._get_usage({"some": "data"})
assert isinstance(usage, Usage)
Loading