diff --git a/src/celeste/modalities/images/providers/google/client.py b/src/celeste/modalities/images/providers/google/client.py index cf4ed5b..5f2d3a8 100644 --- a/src/celeste/modalities/images/providers/google/client.py +++ b/src/celeste/modalities/images/providers/google/client.py @@ -7,7 +7,7 @@ from celeste.types import ImageContent from ...client import ImagesClient -from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage +from ...io import ImageFinishReason, ImageInput, ImageOutput from ...parameters import ImageParameters from .gemini import GeminiImagesClient from .imagen import ImagenImagesClient @@ -81,7 +81,9 @@ def _build_request( ) -> dict[str, Any]: return self._strategy._build_request(inputs, **parameters) # type: ignore[union-attr] - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: return self._strategy._parse_usage(response_data) # type: ignore[union-attr] def _parse_content( diff --git a/src/celeste/modalities/images/providers/google/gemini.py b/src/celeste/modalities/images/providers/google/gemini.py index 098fe7c..51f333d 100644 --- a/src/celeste/modalities/images/providers/google/gemini.py +++ b/src/celeste/modalities/images/providers/google/gemini.py @@ -4,6 +4,7 @@ from typing import Any, Unpack from celeste.artifacts import ImageArtifact +from celeste.core import UsageField from celeste.mime_types import ImageMimeType from celeste.parameters import ParameterMapper from celeste.providers.google.generate_content import config as google_config @@ -11,7 +12,7 @@ from celeste.types import ImageContent from ...client import ImagesClient -from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage +from ...io import ImageFinishReason, ImageInput, ImageOutput from ...parameters import ImageParameters from .parameters import GEMINI_PARAMETER_MAPPERS @@ -89,11 +90,13 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: }, } - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: """Parse usage from response.""" usage = super()._parse_usage(response_data) candidates = response_data.get("candidates", []) - return ImageUsage(**usage, num_images=len(candidates)) + return {**usage, UsageField.NUM_IMAGES: len(candidates)} def _parse_content( self, diff --git a/src/celeste/modalities/images/providers/ollama/client.py b/src/celeste/modalities/images/providers/ollama/client.py index 2d6a7ab..d1bbf26 100644 --- a/src/celeste/modalities/images/providers/ollama/client.py +++ b/src/celeste/modalities/images/providers/ollama/client.py @@ -15,7 +15,6 @@ ImageChunk, ImageInput, ImageOutput, - ImageUsage, ) from ...parameters import ImageParameters from ...streaming import ImagesStream @@ -75,12 +74,14 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: """Build request with prompt.""" return {"prompt": inputs.prompt} - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: """Parse usage from response. Ollama image generation doesn't return usage metrics. """ - return ImageUsage() + return {} def _parse_content( self, diff --git a/templates/modalities/{modality_slug}/providers/{provider_slug}/client.py.template b/templates/modalities/{modality_slug}/providers/{provider_slug}/client.py.template index 109701c..32bd996 100644 --- a/templates/modalities/{modality_slug}/providers/{provider_slug}/client.py.template +++ b/templates/modalities/{modality_slug}/providers/{provider_slug}/client.py.template @@ -15,7 +15,6 @@ from ...io import ( {Modality}FinishReason, {Modality}Input, {Modality}Output, - {Modality}Usage, ) from ...parameters import {Modality}Parameters from ...streaming import {Modality}Stream @@ -61,11 +60,6 @@ class {Provider}{Modality}Client({Provider}{Api}Mixin, {Modality}Client): **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> {Modality}Usage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return {Modality}Usage(**usage) - def _parse_content( self, response_data: dict[str, Any], diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 1f7bb95..7364b93 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -122,8 +122,10 @@ def parameter_mappers(cls) -> list[ParameterMapper]: def _init_request(self, inputs: _TestInput) -> dict[str, Any]: return {"prompt": inputs.prompt, "model": self.model.id} - def _parse_usage(self, response_data: dict[str, Any]) -> Usage: - return Usage() + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: + return {} def _parse_content( # type: ignore[override] self, response_data: dict[str, Any], **parameters: Unpack[Parameters] @@ -326,3 +328,41 @@ def test_stream_raises_not_supported_for_non_streaming_model( error_msg = str(exc_info.value) assert "Streaming not supported" in error_msg assert "non-streaming-model" in error_msg + + +class TestGetUsageContract: + """Test that _get_usage correctly wraps _parse_usage dict into typed Usage.""" + + async def test_parse_usage_returns_dict_not_typed_object( + self, text_model: Model, api_key: str + ) -> None: + """Regression: _parse_usage must return a dict, not a typed Usage object. + + If _parse_usage returns a Usage object, _get_usage will crash with: + 'Usage() argument after ** must be a mapping, not Usage' + """ + client = ConcreteModalityClient( + modality=Modality.TEXT, + model=text_model, + provider=text_model.provider, + auth=APIKey(secret=SecretStr(api_key)), + ) + + raw = client._parse_usage({"some": "data"}) + assert isinstance(raw, dict), ( + f"_parse_usage must return a dict, got {type(raw).__name__}" + ) + + async def test_get_usage_wraps_dict_into_typed_usage( + self, text_model: Model, api_key: str + ) -> None: + """_get_usage must convert the raw dict from _parse_usage into typed Usage.""" + client = ConcreteModalityClient( + modality=Modality.TEXT, + model=text_model, + provider=text_model.provider, + auth=APIKey(secret=SecretStr(api_key)), + ) + + usage = client._get_usage({"some": "data"}) + assert isinstance(usage, Usage)