diff --git a/src/celeste/client.py b/src/celeste/client.py index 6df75c6..3cb1443 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from json import JSONDecodeError -from typing import Any, Unpack +from typing import Any, ClassVar, Unpack import httpx from pydantic import BaseModel, ConfigDict, Field @@ -16,7 +16,7 @@ from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream -from celeste.types import TextContent +from celeste.types import RawUsage, TextContent class APIMixin(ABC): @@ -124,6 +124,9 @@ async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput: model_config = ConfigDict(arbitrary_types_allowed=True) + _usage_class: ClassVar[type[Usage]] = Usage + _finish_reason_class: ClassVar[type[FinishReason]] = FinishReason + modality: Modality model: Model provider: Provider @@ -171,8 +174,8 @@ async def _predict( ) return self._output_class()( content=self._parse_content(response_data, **parameters), - usage=self._parse_usage(response_data), - finish_reason=self._parse_finish_reason(response_data), + usage=self._get_usage(response_data), + finish_reason=self._get_finish_reason(response_data), metadata=self._build_metadata(response_data), ) @@ -227,7 +230,7 @@ def parameter_mappers(cls) -> list[ParameterMapper]: ... @abstractmethod - def _parse_usage(self, response_data: dict[str, Any]) -> Usage: + def _parse_usage(self, response_data: dict[str, Any]) -> RawUsage: """Parse usage information from provider response.""" ... @@ -246,6 +249,20 @@ def _parse_finish_reason( """Parse finish reason from provider response.""" return None + def _get_usage(self, response_data: dict[str, Any]) -> Usage: + """Get modality-typed usage from response.""" + raw = self._parse_usage(response_data) + return self._usage_class(**raw) + + def _get_finish_reason(self, response_data: dict[str, Any]) -> FinishReason | None: + """Get modality-typed finish reason from response.""" + raw = self._parse_finish_reason(response_data) + if raw is None: + return None + if isinstance(raw, self._finish_reason_class): + return raw + return self._finish_reason_class(reason=raw.reason) + @classmethod @abstractmethod def _output_class(cls) -> type[Out]: diff --git a/src/celeste/modalities/audio/client.py b/src/celeste/modalities/audio/client.py index eada7e7..2b3bf50 100644 --- a/src/celeste/modalities/audio/client.py +++ b/src/celeste/modalities/audio/client.py @@ -8,7 +8,7 @@ from celeste.core import Modality from celeste.types import AudioContent -from .io import AudioInput, AudioOutput +from .io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage from .parameters import AudioParameters from .streaming import AudioStream @@ -19,6 +19,8 @@ class AudioClient( """Base audio client. Providers implement speak() method.""" modality: Modality = Modality.AUDIO + _usage_class = AudioUsage + _finish_reason_class = AudioFinishReason @classmethod def _output_class(cls) -> type[AudioOutput]: diff --git a/src/celeste/modalities/audio/providers/elevenlabs/client.py b/src/celeste/modalities/audio/providers/elevenlabs/client.py index 89e9656..395d505 100644 --- a/src/celeste/modalities/audio/providers/elevenlabs/client.py +++ b/src/celeste/modalities/audio/providers/elevenlabs/client.py @@ -15,10 +15,8 @@ from ...client import AudioClient from ...io import ( AudioChunk, - AudioFinishReason, AudioInput, AudioOutput, - AudioUsage, ) from ...parameters import AudioParameters from ...streaming import AudioStream @@ -28,28 +26,12 @@ class ElevenLabsAudioStream(_ElevenLabsTextToSpeechStream, AudioStream): """ElevenLabs streaming for audio modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> AudioUsage | None: - """Parse and wrap usage from event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return AudioUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> AudioFinishReason | None: - """Parse and wrap finish reason from event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return AudioFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None: """Parse binary audio chunk from stream event.""" chunk_data = self._parse_chunk_content(event_data) if not chunk_data: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None # Chunk with usage/finish_reason only (no audio) @@ -62,8 +44,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None: return AudioChunk( content=chunk_data, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -110,11 +92,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]: """Initialize request with text input.""" return {"text": inputs.text} - def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return AudioUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -131,11 +108,6 @@ def _parse_content( return AudioArtifact(data=audio_bytes, mime_type=mime_type) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return AudioFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[AudioStream]: """Return the Stream class for this provider.""" return ElevenLabsAudioStream diff --git a/src/celeste/modalities/audio/providers/google/client.py b/src/celeste/modalities/audio/providers/google/client.py index 71edb77..800346e 100644 --- a/src/celeste/modalities/audio/providers/google/client.py +++ b/src/celeste/modalities/audio/providers/google/client.py @@ -12,10 +12,8 @@ from ...client import AudioClient from ...io import ( - AudioFinishReason, AudioInput, AudioOutput, - AudioUsage, ) from ...parameters import AudioParameter, AudioParameters from .parameters import GOOGLE_PARAMETER_MAPPERS @@ -49,11 +47,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]: "audioConfig": {}, } - def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return AudioUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -67,10 +60,5 @@ def _parse_content( return AudioArtifact(data=audio_b64, mime_type=mime_type) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return AudioFinishReason(reason=finish_reason.reason) - __all__ = ["GoogleAudioClient"] diff --git a/src/celeste/modalities/audio/providers/gradium/client.py b/src/celeste/modalities/audio/providers/gradium/client.py index 48586a6..d963c16 100644 --- a/src/celeste/modalities/audio/providers/gradium/client.py +++ b/src/celeste/modalities/audio/providers/gradium/client.py @@ -15,10 +15,8 @@ from ...client import AudioClient from ...io import ( AudioChunk, - AudioFinishReason, AudioInput, AudioOutput, - AudioUsage, ) from ...parameters import AudioParameters from ...streaming import AudioStream @@ -28,28 +26,12 @@ class GradiumAudioStream(_GradiumTextToSpeechStream, AudioStream): """Gradium streaming for audio modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> AudioUsage | None: - """Parse and wrap usage from event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return AudioUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> AudioFinishReason | None: - """Parse and wrap finish reason from event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return AudioFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None: """Parse binary audio chunk from stream event.""" chunk_data = self._parse_chunk_content(event_data) if not chunk_data: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None # Chunk with usage/finish_reason only (no audio) @@ -62,8 +44,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None: return AudioChunk( content=chunk_data, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -110,11 +92,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]: """Initialize request with text input.""" return {"text": inputs.text} - def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return AudioUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -131,11 +108,6 @@ def _parse_content( return AudioArtifact(data=audio_bytes, mime_type=mime_type) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return AudioFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[AudioStream]: """Return the Stream class for this provider.""" return GradiumAudioStream diff --git a/src/celeste/modalities/audio/providers/openai/client.py b/src/celeste/modalities/audio/providers/openai/client.py index c0ba702..e712f1e 100644 --- a/src/celeste/modalities/audio/providers/openai/client.py +++ b/src/celeste/modalities/audio/providers/openai/client.py @@ -8,7 +8,7 @@ from celeste.providers.openai.audio.client import OpenAIAudioClient as OpenAIAudioMixin from ...client import AudioClient -from ...io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage +from ...io import AudioFinishReason, AudioInput, AudioOutput from ...parameters import AudioParameters from .parameters import OPENAI_PARAMETER_MAPPERS @@ -37,11 +37,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]: """Initialize request with text input.""" return {"input": inputs.text} - def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return AudioUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], diff --git a/src/celeste/modalities/audio/streaming.py b/src/celeste/modalities/audio/streaming.py index 7609c3f..32fe00a 100644 --- a/src/celeste/modalities/audio/streaming.py +++ b/src/celeste/modalities/audio/streaming.py @@ -20,6 +20,9 @@ class AudioStream(Stream[AudioOutput, AudioParameters, AudioChunk]): """Streaming for audio modality.""" + _usage_class = AudioUsage + _finish_reason_class = AudioFinishReason + def __init__( self, sse_iterator: AsyncIterator[dict[str, Any]], diff --git a/src/celeste/modalities/embeddings/client.py b/src/celeste/modalities/embeddings/client.py index 1d46110..2b6e7a3 100644 --- a/src/celeste/modalities/embeddings/client.py +++ b/src/celeste/modalities/embeddings/client.py @@ -8,7 +8,12 @@ from celeste.core import Modality from celeste.types import EmbeddingsContent -from .io import EmbeddingsInput, EmbeddingsOutput +from .io import ( + EmbeddingsFinishReason, + EmbeddingsInput, + EmbeddingsOutput, + EmbeddingsUsage, +) from .parameters import EmbeddingsParameters @@ -20,6 +25,8 @@ class EmbeddingsClient( """Base embeddings client. Providers implement operation methods.""" modality: Modality = Modality.EMBEDDINGS + _usage_class = EmbeddingsUsage + _finish_reason_class = EmbeddingsFinishReason @classmethod def _output_class(cls) -> type[EmbeddingsOutput]: diff --git a/src/celeste/modalities/embeddings/providers/google/client.py b/src/celeste/modalities/embeddings/providers/google/client.py index d09cc43..39bfff9 100644 --- a/src/celeste/modalities/embeddings/providers/google/client.py +++ b/src/celeste/modalities/embeddings/providers/google/client.py @@ -9,11 +9,7 @@ from celeste.types import EmbeddingsContent from ...client import EmbeddingsClient -from ...io import ( - EmbeddingsFinishReason, - EmbeddingsInput, - EmbeddingsUsage, -) +from ...io import EmbeddingsInput from ...parameters import EmbeddingsParameters from .parameters import GOOGLE_PARAMETER_MAPPERS @@ -43,11 +39,6 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: ] } - def _parse_usage(self, response_data: dict[str, Any]) -> EmbeddingsUsage: - """Parse usage from response (embeddings API doesn't provide usage).""" - usage = super()._parse_usage(response_data) - return EmbeddingsUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -56,12 +47,5 @@ def _parse_content( """Parse embedding vectors from response.""" return super()._parse_content(response_data) - def _parse_finish_reason( - self, response_data: dict[str, Any] - ) -> EmbeddingsFinishReason: - """Parse finish reason (embeddings API doesn't provide finish reasons).""" - finish_reason = super()._parse_finish_reason(response_data) - return EmbeddingsFinishReason(reason=finish_reason.reason) - __all__ = ["GoogleEmbeddingsClient"] diff --git a/src/celeste/modalities/images/client.py b/src/celeste/modalities/images/client.py index b0d7aff..ac01237 100644 --- a/src/celeste/modalities/images/client.py +++ b/src/celeste/modalities/images/client.py @@ -9,7 +9,7 @@ from celeste.core import Modality from celeste.types import ImageContent -from .io import ImageInput, ImageOutput +from .io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage from .parameters import ImageParameters from .streaming import ImagesStream @@ -20,6 +20,8 @@ class ImagesClient( """Base images client. Providers implement generate/edit methods.""" modality: Modality = Modality.IMAGES + _usage_class = ImageUsage + _finish_reason_class = ImageFinishReason @classmethod def _output_class(cls) -> type[ImageOutput]: diff --git a/src/celeste/modalities/images/providers/bfl/client.py b/src/celeste/modalities/images/providers/bfl/client.py index 4c7c171..6842afd 100644 --- a/src/celeste/modalities/images/providers/bfl/client.py +++ b/src/celeste/modalities/images/providers/bfl/client.py @@ -9,7 +9,7 @@ from celeste.providers.bfl.images.utils import encode_image from ...client import ImagesClient -from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage +from ...io import ImageFinishReason, ImageInput, ImageOutput from ...parameters import ImageParameters from .parameters import BFL_PARAMETER_MAPPERS @@ -55,11 +55,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: request["input_image"] = encode_image(inputs.image) return request - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return ImageUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], diff --git a/src/celeste/modalities/images/providers/byteplus/client.py b/src/celeste/modalities/images/providers/byteplus/client.py index 435cd7c..df8ac54 100644 --- a/src/celeste/modalities/images/providers/byteplus/client.py +++ b/src/celeste/modalities/images/providers/byteplus/client.py @@ -17,7 +17,6 @@ from ...client import ImagesClient from ...io import ( ImageChunk, - ImageFinishReason, ImageInput, ImageOutput, ImageUsage, @@ -35,22 +34,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._completed_usage: ImageUsage | None = None self._completed_event_data: dict[str, Any] | None = None - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> ImageUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return ImageUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> ImageFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return ImageFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: """Parse one SSE event into a typed chunk.""" # Handle error events (partial_failed) @@ -64,7 +47,7 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: ) # Handle completed event (usage only) - usage = self._parse_chunk_usage(event_data) + usage = self._get_chunk_usage(event_data) if usage is not None: self._completed_usage = usage self._completed_event_data = event_data @@ -83,7 +66,7 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: return ImageChunk( content=artifact, - finish_reason=self._parse_chunk_finish_reason(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), usage=None, metadata={"event_data": event_data}, ) @@ -137,11 +120,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: "response_format": "url", } - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return ImageUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -168,11 +146,6 @@ def _parse_content( msg = "No image URL or base64 data in BytePlus response" raise ValidationError(msg) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return ImageFinishReason(reason=finish_reason.reason) - async def _make_request( self, request_body: dict[str, Any], diff --git a/src/celeste/modalities/images/providers/google/imagen.py b/src/celeste/modalities/images/providers/google/imagen.py index fe95a31..c415fc3 100644 --- a/src/celeste/modalities/images/providers/google/imagen.py +++ b/src/celeste/modalities/images/providers/google/imagen.py @@ -10,7 +10,7 @@ from celeste.types import ImageContent from ...client import ImagesClient -from ...io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage +from ...io import ImageInput, ImageOutput from ...parameters import ImageParameters from .parameters import IMAGEN_PARAMETER_MAPPERS @@ -41,11 +41,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: "parameters": {}, } - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return ImageUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -71,10 +66,5 @@ def _parse_content( return images[0] return images if images else ImageArtifact() - def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: - """Imagen API doesn't provide finish reasons.""" - finish_reason = super()._parse_finish_reason(response_data) - return ImageFinishReason(reason=finish_reason.reason) - __all__ = ["ImagenImagesClient"] diff --git a/src/celeste/modalities/images/providers/ollama/client.py b/src/celeste/modalities/images/providers/ollama/client.py index 5dd9175..c707dff 100644 --- a/src/celeste/modalities/images/providers/ollama/client.py +++ b/src/celeste/modalities/images/providers/ollama/client.py @@ -13,7 +13,6 @@ from ...client import ImagesClient from ...io import ( ImageChunk, - ImageFinishReason, ImageInput, ImageOutput, ImageUsage, @@ -26,22 +25,6 @@ class OllamaImagesStream(_OllamaGenerateStream, ImagesStream): """Ollama NDJSON streaming for images.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> ImageUsage | None: - """Parse and wrap usage from NDJSON event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return ImageUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> ImageFinishReason | None: - """Parse and wrap finish reason from NDJSON event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return ImageFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: """Parse NDJSON event into ImageChunk.""" b64_image = self._parse_chunk_content(event_data) @@ -54,8 +37,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: return ImageChunk( content=ImageArtifact(data=b64_image), - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata=self._parse_chunk_metadata(event_data), ) @@ -115,11 +98,6 @@ def _parse_content( image_b64 = super()._parse_content(response_data) return ImageArtifact(data=image_b64) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return ImageFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[ImagesStream]: """Return the Stream class for Ollama images.""" return OllamaImagesStream diff --git a/src/celeste/modalities/images/providers/openai/client.py b/src/celeste/modalities/images/providers/openai/client.py index bf76de9..ab1d805 100644 --- a/src/celeste/modalities/images/providers/openai/client.py +++ b/src/celeste/modalities/images/providers/openai/client.py @@ -15,10 +15,8 @@ from ...client import ImagesClient from ...io import ( ImageChunk, - ImageFinishReason, ImageInput, ImageOutput, - ImageUsage, ) from ...parameters import ImageParameters from ...streaming import ImagesStream @@ -28,28 +26,12 @@ class OpenAIImagesStream(_OpenAIImagesStream, ImagesStream): """OpenAI streaming for images modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> ImageUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return ImageUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> ImageFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return ImageFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: """Parse one SSE event into a typed chunk.""" b64_json = self._parse_chunk_content(event_data) if not b64_json: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None # Chunk with usage/finish_reason only (no image) @@ -64,8 +46,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> ImageChunk | None: return ImageChunk( content=artifact, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -125,11 +107,6 @@ async def edit( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return ImageUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -150,11 +127,6 @@ def _parse_content( msg = "No image URL or base64 data in response" raise ValueError(msg) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return ImageFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[ImagesStream]: """Return the Stream class for this provider.""" return OpenAIImagesStream diff --git a/src/celeste/modalities/images/providers/xai/client.py b/src/celeste/modalities/images/providers/xai/client.py index ba60d68..6c27e0e 100644 --- a/src/celeste/modalities/images/providers/xai/client.py +++ b/src/celeste/modalities/images/providers/xai/client.py @@ -9,10 +9,8 @@ from ...client import ImagesClient from ...io import ( - ImageFinishReason, ImageInput, ImageOutput, - ImageUsage, ) from ...parameters import ImageParameters from .parameters import XAI_PARAMETER_MAPPERS @@ -65,11 +63,6 @@ async def edit( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return ImageUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -91,10 +84,5 @@ def _parse_content( msg = "No image URL or base64 data in response" raise ValueError(msg) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return ImageFinishReason(reason=finish_reason.reason) - __all__ = ["XAIImagesClient"] diff --git a/src/celeste/modalities/images/streaming.py b/src/celeste/modalities/images/streaming.py index d809451..f2acebc 100644 --- a/src/celeste/modalities/images/streaming.py +++ b/src/celeste/modalities/images/streaming.py @@ -16,6 +16,9 @@ class ImagesStream(Stream[ImageOutput, ImageParameters, ImageChunk]): """Streaming for images modality.""" + _usage_class = ImageUsage + _finish_reason_class = ImageFinishReason + def __init__( self, sse_iterator: AsyncIterator[dict[str, Any]], diff --git a/src/celeste/modalities/text/client.py b/src/celeste/modalities/text/client.py index b3dc169..dd59013 100644 --- a/src/celeste/modalities/text/client.py +++ b/src/celeste/modalities/text/client.py @@ -8,7 +8,7 @@ from celeste.core import InputType, Modality from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent -from .io import TextInput, TextOutput +from .io import TextFinishReason, TextInput, TextOutput, TextUsage from .parameters import TextParameters from .streaming import TextStream @@ -20,6 +20,8 @@ class TextClient(ModalityClient[TextInput, TextOutput, TextParameters, TextConte """ modality: Modality = Modality.TEXT + _usage_class = TextUsage + _finish_reason_class = TextFinishReason @classmethod def _output_class(cls) -> type[TextOutput]: diff --git a/src/celeste/modalities/text/providers/anthropic/client.py b/src/celeste/modalities/text/providers/anthropic/client.py index e99047d..13f2464 100644 --- a/src/celeste/modalities/text/providers/anthropic/client.py +++ b/src/celeste/modalities/text/providers/anthropic/client.py @@ -16,10 +16,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -33,22 +31,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._message_start: dict[str, Any] | None = None - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" event_type = event_data.get("type") @@ -60,16 +42,16 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -189,11 +171,6 @@ def _build_image_source(self, img: ImageArtifact) -> dict[str, Any]: "data": base64_data, } - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -210,11 +187,6 @@ def _parse_content( return self._transform_output(text_content, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return AnthropicTextStream diff --git a/src/celeste/modalities/text/providers/cohere/client.py b/src/celeste/modalities/text/providers/cohere/client.py index 5160af6..6f9aba4 100644 --- a/src/celeste/modalities/text/providers/cohere/client.py +++ b/src/celeste/modalities/text/providers/cohere/client.py @@ -13,10 +13,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -26,36 +24,20 @@ class CohereTextStream(_CohereChatStream, TextStream): """Cohere streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -126,11 +108,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: return {"messages": [{"role": "user", "content": content}]} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -142,11 +119,6 @@ def _parse_content( text = first_content.get("text") or "" return self._transform_output(text, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return CohereTextStream diff --git a/src/celeste/modalities/text/providers/deepseek/client.py b/src/celeste/modalities/text/providers/deepseek/client.py index e59752c..b308e15 100644 --- a/src/celeste/modalities/text/providers/deepseek/client.py +++ b/src/celeste/modalities/text/providers/deepseek/client.py @@ -12,10 +12,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -25,36 +23,20 @@ class DeepSeekTextStream(_DeepSeekChatStream, TextStream): """DeepSeek streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -106,11 +88,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: return {"messages": messages} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -122,11 +99,6 @@ def _parse_content( content = message.get("content") or "" return self._transform_output(content, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return DeepSeekTextStream diff --git a/src/celeste/modalities/text/providers/google/client.py b/src/celeste/modalities/text/providers/google/client.py index 46a33ab..3a5b917 100644 --- a/src/celeste/modalities/text/providers/google/client.py +++ b/src/celeste/modalities/text/providers/google/client.py @@ -16,10 +16,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -29,36 +27,20 @@ class GoogleTextStream(_GoogleGenerateContentStream, TextStream): """Google streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -218,11 +200,6 @@ def _build_audio_part(self, audio: AudioArtifact) -> dict[str, Any]: return {"inline_data": {"mime_type": mime_str, "data": b64}} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -234,11 +211,6 @@ def _parse_content( text = parts[0].get("text") if parts else "" return self._transform_output(text or "", **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return GoogleTextStream diff --git a/src/celeste/modalities/text/providers/groq/client.py b/src/celeste/modalities/text/providers/groq/client.py index ea6ccb5..1d2b536 100644 --- a/src/celeste/modalities/text/providers/groq/client.py +++ b/src/celeste/modalities/text/providers/groq/client.py @@ -11,10 +11,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -24,36 +22,20 @@ class GroqTextStream(_GroqChatStream, TextStream): """Groq streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -124,11 +106,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: return {"messages": [{"role": "user", "content": content}]} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -140,11 +117,6 @@ def _parse_content( content = message.get("content") or "" return self._transform_output(content, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return GroqTextStream diff --git a/src/celeste/modalities/text/providers/mistral/client.py b/src/celeste/modalities/text/providers/mistral/client.py index 9fb755f..a502ca3 100644 --- a/src/celeste/modalities/text/providers/mistral/client.py +++ b/src/celeste/modalities/text/providers/mistral/client.py @@ -13,10 +13,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -26,36 +24,20 @@ class MistralTextStream(_MistralChatStream, TextStream): """Mistral streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -123,11 +105,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: return {"messages": [{"role": "user", "content": content}]} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -149,11 +126,6 @@ def _parse_content( return self._transform_output(content, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return MistralTextStream diff --git a/src/celeste/modalities/text/providers/moonshot/client.py b/src/celeste/modalities/text/providers/moonshot/client.py index 9c06e6a..9bcde03 100644 --- a/src/celeste/modalities/text/providers/moonshot/client.py +++ b/src/celeste/modalities/text/providers/moonshot/client.py @@ -13,10 +13,8 @@ from ...client import TextClient from ...io import ( TextChunk, - TextFinishReason, TextInput, TextOutput, - TextUsage, ) from ...parameters import TextParameters from ...streaming import TextStream @@ -26,36 +24,20 @@ class MoonshotTextStream(_MoonshotChatStream, TextStream): """Moonshot streaming for text modality.""" - def _parse_chunk_usage(self, event_data: dict[str, Any]) -> TextUsage | None: - """Parse and wrap usage from SSE event.""" - usage = super()._parse_chunk_usage(event_data) - if usage: - return TextUsage(**usage) - return None - - def _parse_chunk_finish_reason( - self, event_data: dict[str, Any] - ) -> TextFinishReason | None: - """Parse and wrap finish reason from SSE event.""" - finish_reason = super()._parse_chunk_finish_reason(event_data) - if finish_reason: - return TextFinishReason(reason=finish_reason.reason) - return None - def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: """Parse one SSE event into a typed chunk.""" content = self._parse_chunk_content(event_data) if content is None: - usage = self._parse_chunk_usage(event_data) - finish_reason = self._parse_chunk_finish_reason(event_data) + usage = self._get_chunk_usage(event_data) + finish_reason = self._get_chunk_finish_reason(event_data) if usage is None and finish_reason is None: return None content = "" return TextChunk( content=content, - finish_reason=self._parse_chunk_finish_reason(event_data), - usage=self._parse_chunk_usage(event_data), + finish_reason=self._get_chunk_finish_reason(event_data), + usage=self._get_chunk_usage(event_data), metadata={"event_data": event_data}, ) @@ -123,11 +105,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: return {"messages": [{"role": "user", "content": content}]} - def _parse_usage(self, response_data: dict[str, Any]) -> TextUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return TextUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -139,11 +116,6 @@ def _parse_content( content = message.get("content") or "" return self._transform_output(content, **parameters) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> TextFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return TextFinishReason(reason=finish_reason.reason) - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return MoonshotTextStream diff --git a/src/celeste/modalities/text/streaming.py b/src/celeste/modalities/text/streaming.py index 072c740..c0d41ae 100644 --- a/src/celeste/modalities/text/streaming.py +++ b/src/celeste/modalities/text/streaming.py @@ -15,6 +15,9 @@ class TextStream(Stream[TextOutput, TextParameters, TextChunk]): """Streaming for text modality.""" + _usage_class = TextUsage + _finish_reason_class = TextFinishReason + def __init__( self, sse_iterator: AsyncIterator[dict[str, Any]], diff --git a/src/celeste/modalities/videos/client.py b/src/celeste/modalities/videos/client.py index 80a89be..58955fe 100644 --- a/src/celeste/modalities/videos/client.py +++ b/src/celeste/modalities/videos/client.py @@ -8,7 +8,7 @@ from celeste.core import Modality from celeste.types import VideoContent -from .io import VideoInput, VideoOutput +from .io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage from .parameters import VideoParameters @@ -18,6 +18,8 @@ class VideosClient( """Base videos client. Providers implement generate method.""" modality: Modality = Modality.VIDEOS + _usage_class = VideoUsage + _finish_reason_class = VideoFinishReason @classmethod def _output_class(cls) -> type[VideoOutput]: diff --git a/src/celeste/modalities/videos/providers/byteplus/client.py b/src/celeste/modalities/videos/providers/byteplus/client.py index afe9036..05d06d2 100644 --- a/src/celeste/modalities/videos/providers/byteplus/client.py +++ b/src/celeste/modalities/videos/providers/byteplus/client.py @@ -10,7 +10,7 @@ ) from ...client import VideosClient -from ...io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage +from ...io import VideoInput, VideoOutput from ...parameters import VideoParameters from .parameters import BYTEPLUS_PARAMETER_MAPPERS @@ -41,11 +41,6 @@ async def generate( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> VideoUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return VideoUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -59,10 +54,5 @@ def _parse_content( raise ValueError(msg) return VideoArtifact(url=video_url) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return VideoFinishReason(reason=finish_reason.reason) - __all__ = ["BytePlusVideosClient"] diff --git a/src/celeste/modalities/videos/providers/google/client.py b/src/celeste/modalities/videos/providers/google/client.py index c8a24e7..171aff3 100644 --- a/src/celeste/modalities/videos/providers/google/client.py +++ b/src/celeste/modalities/videos/providers/google/client.py @@ -9,7 +9,7 @@ from celeste.providers.google.veo.client import GoogleVeoClient as GoogleVeoMixin from ...client import VideosClient -from ...io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage +from ...io import VideoInput, VideoOutput from ...parameters import VideoParameters from .parameters import GOOGLE_PARAMETER_MAPPERS @@ -40,11 +40,6 @@ async def generate( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> VideoUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return VideoUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -60,11 +55,6 @@ def _parse_content( ) return VideoArtifact(url=video_data.get("uri")) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return VideoFinishReason(reason=finish_reason.reason) - async def download_content(self, artifact: VideoArtifact) -> VideoArtifact: """Download video content from GCS URL. diff --git a/src/celeste/modalities/videos/providers/openai/client.py b/src/celeste/modalities/videos/providers/openai/client.py index fe8aefb..e228449 100644 --- a/src/celeste/modalities/videos/providers/openai/client.py +++ b/src/celeste/modalities/videos/providers/openai/client.py @@ -11,7 +11,7 @@ ) from ...client import VideosClient -from ...io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage +from ...io import VideoInput, VideoOutput from ...parameters import VideoParameters from .parameters import OPENAI_PARAMETER_MAPPERS @@ -42,11 +42,6 @@ async def generate( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> VideoUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return VideoUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -59,10 +54,5 @@ def _parse_content( mime_type=VideoMimeType.MP4, ) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return VideoFinishReason(reason=finish_reason.reason) - __all__ = ["OpenAIVideosClient"] diff --git a/src/celeste/modalities/videos/providers/xai/client.py b/src/celeste/modalities/videos/providers/xai/client.py index 3225a85..8acdf48 100644 --- a/src/celeste/modalities/videos/providers/xai/client.py +++ b/src/celeste/modalities/videos/providers/xai/client.py @@ -8,7 +8,7 @@ from celeste.providers.xai.videos.client import XAIVideosClient as XAIVideosMixin from ...client import VideosClient -from ...io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage +from ...io import VideoInput, VideoOutput from ...parameters import VideoParameters from .parameters import XAI_PARAMETER_MAPPERS @@ -54,11 +54,6 @@ async def edit( **parameters, ) - def _parse_usage(self, response_data: dict[str, Any]) -> VideoUsage: - """Parse usage from response.""" - usage = super()._parse_usage(response_data) - return VideoUsage(**usage) - def _parse_content( self, response_data: dict[str, Any], @@ -69,10 +64,5 @@ def _parse_content( url = super()._parse_content(response_data) return VideoArtifact(url=url) - def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason: - """Parse finish reason from response.""" - finish_reason = super()._parse_finish_reason(response_data) - return VideoFinishReason(reason=finish_reason.reason) - __all__ = ["XAIVideosClient"] diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index 6e6d9bc..f438e80 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -4,14 +4,15 @@ from collections.abc import AsyncIterator from contextlib import AbstractContextManager, suppress from types import TracebackType -from typing import Any, Self, Unpack +from typing import Any, ClassVar, Self, Unpack from anyio.from_thread import BlockingPortal, start_blocking_portal from celeste.exceptions import StreamNotExhaustedError from celeste.io import Chunk as ChunkBase -from celeste.io import Output +from celeste.io import FinishReason, Output, Usage from celeste.parameters import Parameters +from celeste.types import RawUsage class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC): @@ -25,6 +26,9 @@ class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC): creates a background thread per stream. """ + _usage_class: ClassVar[type[Usage]] = Usage + _finish_reason_class: ClassVar[type[FinishReason]] = FinishReason + def __init__( self, sse_iterator: AsyncIterator[dict[str, Any]], @@ -50,6 +54,34 @@ def _parse_output(self, chunks: list[Chunk], **parameters: Unpack[Params]) -> Ou """Parse final Output from accumulated chunks.""" ... + def _parse_chunk_usage(self, event_data: dict[str, Any]) -> RawUsage | None: + """Parse usage from chunk event. Override in provider mixin.""" + return None + + def _parse_chunk_finish_reason( + self, event_data: dict[str, Any] + ) -> FinishReason | None: + """Parse finish reason from chunk event. Override in provider mixin.""" + return None + + def _get_chunk_usage(self, event_data: dict[str, Any]) -> Usage | None: + """Get modality-typed usage from chunk event.""" + raw = self._parse_chunk_usage(event_data) + if raw is None: + return None + return self._usage_class(**raw) + + def _get_chunk_finish_reason( + self, event_data: dict[str, Any] + ) -> FinishReason | None: + """Get modality-typed finish reason from chunk event.""" + raw = self._parse_chunk_finish_reason(event_data) + if raw is None: + return None + if isinstance(raw, self._finish_reason_class): + return raw + return self._finish_reason_class(reason=raw.reason) + def _build_stream_metadata( self, raw_events: list[dict[str, Any]] ) -> dict[str, Any]: diff --git a/src/celeste/types.py b/src/celeste/types.py index cbee754..7dadd71 100644 --- a/src/celeste/types.py +++ b/src/celeste/types.py @@ -19,6 +19,8 @@ type Content = str | JsonValue | dict[str, Any] | list[JsonValue | dict[str, Any]] +type RawUsage = dict[str, int | float | None] + class Role(StrEnum): """Message role in a conversation.""" @@ -45,6 +47,7 @@ class Message(BaseModel): "ImageContent", "JsonValue", "Message", + "RawUsage", "Role", "TextContent", "VideoContent",