Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from json import JSONDecodeError
from typing import Any, Unpack
from typing import Any, ClassVar, Unpack

import httpx
from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -16,7 +16,7 @@
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
from celeste.types import TextContent
from celeste.types import RawUsage, TextContent


class APIMixin(ABC):
Expand Down Expand Up @@ -124,6 +124,9 @@ async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput:

model_config = ConfigDict(arbitrary_types_allowed=True)

_usage_class: ClassVar[type[Usage]] = Usage
_finish_reason_class: ClassVar[type[FinishReason]] = FinishReason

modality: Modality
model: Model
provider: Provider
Expand Down Expand Up @@ -171,8 +174,8 @@ async def _predict(
)
return self._output_class()(
content=self._parse_content(response_data, **parameters),
usage=self._parse_usage(response_data),
finish_reason=self._parse_finish_reason(response_data),
usage=self._get_usage(response_data),
finish_reason=self._get_finish_reason(response_data),
metadata=self._build_metadata(response_data),
)

Expand Down Expand Up @@ -227,7 +230,7 @@ def parameter_mappers(cls) -> list[ParameterMapper]:
...

@abstractmethod
def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
def _parse_usage(self, response_data: dict[str, Any]) -> RawUsage:
"""Parse usage information from provider response."""
...

Expand All @@ -246,6 +249,20 @@ def _parse_finish_reason(
"""Parse finish reason from provider response."""
return None

def _get_usage(self, response_data: dict[str, Any]) -> Usage:
"""Get modality-typed usage from response."""
raw = self._parse_usage(response_data)
return self._usage_class(**raw)

def _get_finish_reason(self, response_data: dict[str, Any]) -> FinishReason | None:
"""Get modality-typed finish reason from response."""
raw = self._parse_finish_reason(response_data)
if raw is None:
return None
if isinstance(raw, self._finish_reason_class):
return raw
return self._finish_reason_class(reason=raw.reason)

@classmethod
@abstractmethod
def _output_class(cls) -> type[Out]:
Expand Down
4 changes: 3 additions & 1 deletion src/celeste/modalities/audio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from celeste.core import Modality
from celeste.types import AudioContent

from .io import AudioInput, AudioOutput
from .io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage
from .parameters import AudioParameters
from .streaming import AudioStream

Expand All @@ -19,6 +19,8 @@ class AudioClient(
"""Base audio client. Providers implement speak() method."""

modality: Modality = Modality.AUDIO
_usage_class = AudioUsage
_finish_reason_class = AudioFinishReason

@classmethod
def _output_class(cls) -> type[AudioOutput]:
Expand Down
36 changes: 4 additions & 32 deletions src/celeste/modalities/audio/providers/elevenlabs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
from ...client import AudioClient
from ...io import (
AudioChunk,
AudioFinishReason,
AudioInput,
AudioOutput,
AudioUsage,
)
from ...parameters import AudioParameters
from ...streaming import AudioStream
Expand All @@ -28,28 +26,12 @@
class ElevenLabsAudioStream(_ElevenLabsTextToSpeechStream, AudioStream):
"""ElevenLabs streaming for audio modality."""

def _parse_chunk_usage(self, event_data: dict[str, Any]) -> AudioUsage | None:
"""Parse and wrap usage from event."""
usage = super()._parse_chunk_usage(event_data)
if usage:
return AudioUsage(**usage)
return None

def _parse_chunk_finish_reason(
self, event_data: dict[str, Any]
) -> AudioFinishReason | None:
"""Parse and wrap finish reason from event."""
finish_reason = super()._parse_chunk_finish_reason(event_data)
if finish_reason:
return AudioFinishReason(reason=finish_reason.reason)
return None

def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None:
"""Parse binary audio chunk from stream event."""
chunk_data = self._parse_chunk_content(event_data)
if not chunk_data:
usage = self._parse_chunk_usage(event_data)
finish_reason = self._parse_chunk_finish_reason(event_data)
usage = self._get_chunk_usage(event_data)
finish_reason = self._get_chunk_finish_reason(event_data)
if usage is None and finish_reason is None:
return None
# Chunk with usage/finish_reason only (no audio)
Expand All @@ -62,8 +44,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None:

return AudioChunk(
content=chunk_data,
finish_reason=self._parse_chunk_finish_reason(event_data),
usage=self._parse_chunk_usage(event_data),
finish_reason=self._get_chunk_finish_reason(event_data),
usage=self._get_chunk_usage(event_data),
metadata={"event_data": event_data},
)

Expand Down Expand Up @@ -110,11 +92,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
"""Initialize request with text input."""
return {"text": inputs.text}

def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
return AudioUsage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand All @@ -131,11 +108,6 @@ def _parse_content(

return AudioArtifact(data=audio_bytes, mime_type=mime_type)

def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason:
"""Parse finish reason from response."""
finish_reason = super()._parse_finish_reason(response_data)
return AudioFinishReason(reason=finish_reason.reason)

def _stream_class(self) -> type[AudioStream]:
"""Return the Stream class for this provider."""
return ElevenLabsAudioStream
Expand Down
12 changes: 0 additions & 12 deletions src/celeste/modalities/audio/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@

from ...client import AudioClient
from ...io import (
AudioFinishReason,
AudioInput,
AudioOutput,
AudioUsage,
)
from ...parameters import AudioParameter, AudioParameters
from .parameters import GOOGLE_PARAMETER_MAPPERS
Expand Down Expand Up @@ -49,11 +47,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
"audioConfig": {},
}

def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
return AudioUsage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand All @@ -67,10 +60,5 @@ def _parse_content(

return AudioArtifact(data=audio_b64, mime_type=mime_type)

def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason:
"""Parse finish reason from response."""
finish_reason = super()._parse_finish_reason(response_data)
return AudioFinishReason(reason=finish_reason.reason)


__all__ = ["GoogleAudioClient"]
36 changes: 4 additions & 32 deletions src/celeste/modalities/audio/providers/gradium/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
from ...client import AudioClient
from ...io import (
AudioChunk,
AudioFinishReason,
AudioInput,
AudioOutput,
AudioUsage,
)
from ...parameters import AudioParameters
from ...streaming import AudioStream
Expand All @@ -28,28 +26,12 @@
class GradiumAudioStream(_GradiumTextToSpeechStream, AudioStream):
"""Gradium streaming for audio modality."""

def _parse_chunk_usage(self, event_data: dict[str, Any]) -> AudioUsage | None:
"""Parse and wrap usage from event."""
usage = super()._parse_chunk_usage(event_data)
if usage:
return AudioUsage(**usage)
return None

def _parse_chunk_finish_reason(
self, event_data: dict[str, Any]
) -> AudioFinishReason | None:
"""Parse and wrap finish reason from event."""
finish_reason = super()._parse_chunk_finish_reason(event_data)
if finish_reason:
return AudioFinishReason(reason=finish_reason.reason)
return None

def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None:
"""Parse binary audio chunk from stream event."""
chunk_data = self._parse_chunk_content(event_data)
if not chunk_data:
usage = self._parse_chunk_usage(event_data)
finish_reason = self._parse_chunk_finish_reason(event_data)
usage = self._get_chunk_usage(event_data)
finish_reason = self._get_chunk_finish_reason(event_data)
if usage is None and finish_reason is None:
return None
# Chunk with usage/finish_reason only (no audio)
Expand All @@ -62,8 +44,8 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> AudioChunk | None:

return AudioChunk(
content=chunk_data,
finish_reason=self._parse_chunk_finish_reason(event_data),
usage=self._parse_chunk_usage(event_data),
finish_reason=self._get_chunk_finish_reason(event_data),
usage=self._get_chunk_usage(event_data),
metadata={"event_data": event_data},
)

Expand Down Expand Up @@ -110,11 +92,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
"""Initialize request with text input."""
return {"text": inputs.text}

def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
return AudioUsage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand All @@ -131,11 +108,6 @@ def _parse_content(

return AudioArtifact(data=audio_bytes, mime_type=mime_type)

def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason:
"""Parse finish reason from response."""
finish_reason = super()._parse_finish_reason(response_data)
return AudioFinishReason(reason=finish_reason.reason)

def _stream_class(self) -> type[AudioStream]:
"""Return the Stream class for this provider."""
return GradiumAudioStream
Expand Down
7 changes: 1 addition & 6 deletions src/celeste/modalities/audio/providers/openai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from celeste.providers.openai.audio.client import OpenAIAudioClient as OpenAIAudioMixin

from ...client import AudioClient
from ...io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage
from ...io import AudioFinishReason, AudioInput, AudioOutput
from ...parameters import AudioParameters
from .parameters import OPENAI_PARAMETER_MAPPERS

Expand Down Expand Up @@ -37,11 +37,6 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
"""Initialize request with text input."""
return {"input": inputs.text}

def _parse_usage(self, response_data: dict[str, Any]) -> AudioUsage:
"""Parse usage from response."""
usage = super()._parse_usage(response_data)
return AudioUsage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand Down
3 changes: 3 additions & 0 deletions src/celeste/modalities/audio/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
class AudioStream(Stream[AudioOutput, AudioParameters, AudioChunk]):
"""Streaming for audio modality."""

_usage_class = AudioUsage
_finish_reason_class = AudioFinishReason

def __init__(
self,
sse_iterator: AsyncIterator[dict[str, Any]],
Expand Down
9 changes: 8 additions & 1 deletion src/celeste/modalities/embeddings/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from celeste.core import Modality
from celeste.types import EmbeddingsContent

from .io import EmbeddingsInput, EmbeddingsOutput
from .io import (
EmbeddingsFinishReason,
EmbeddingsInput,
EmbeddingsOutput,
EmbeddingsUsage,
)
from .parameters import EmbeddingsParameters


Expand All @@ -20,6 +25,8 @@ class EmbeddingsClient(
"""Base embeddings client. Providers implement operation methods."""

modality: Modality = Modality.EMBEDDINGS
_usage_class = EmbeddingsUsage
_finish_reason_class = EmbeddingsFinishReason

@classmethod
def _output_class(cls) -> type[EmbeddingsOutput]:
Expand Down
18 changes: 1 addition & 17 deletions src/celeste/modalities/embeddings/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
from celeste.types import EmbeddingsContent

from ...client import EmbeddingsClient
from ...io import (
EmbeddingsFinishReason,
EmbeddingsInput,
EmbeddingsUsage,
)
from ...io import EmbeddingsInput
from ...parameters import EmbeddingsParameters
from .parameters import GOOGLE_PARAMETER_MAPPERS

Expand Down Expand Up @@ -43,11 +39,6 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]:
]
}

def _parse_usage(self, response_data: dict[str, Any]) -> EmbeddingsUsage:
"""Parse usage from response (embeddings API doesn't provide usage)."""
usage = super()._parse_usage(response_data)
return EmbeddingsUsage(**usage)

def _parse_content(
self,
response_data: dict[str, Any],
Expand All @@ -56,12 +47,5 @@ def _parse_content(
"""Parse embedding vectors from response."""
return super()._parse_content(response_data)

def _parse_finish_reason(
self, response_data: dict[str, Any]
) -> EmbeddingsFinishReason:
"""Parse finish reason (embeddings API doesn't provide finish reasons)."""
finish_reason = super()._parse_finish_reason(response_data)
return EmbeddingsFinishReason(reason=finish_reason.reason)


__all__ = ["GoogleEmbeddingsClient"]
4 changes: 3 additions & 1 deletion src/celeste/modalities/images/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from celeste.core import Modality
from celeste.types import ImageContent

from .io import ImageInput, ImageOutput
from .io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage
from .parameters import ImageParameters
from .streaming import ImagesStream

Expand All @@ -20,6 +20,8 @@ class ImagesClient(
"""Base images client. Providers implement generate/edit methods."""

modality: Modality = Modality.IMAGES
_usage_class = ImageUsage
_finish_reason_class = ImageFinishReason

@classmethod
def _output_class(cls) -> type[ImageOutput]:
Expand Down
Loading
Loading