Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "celeste-ai"
version = "0.9.4"
version = "0.9.5"
description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface"
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
readme = "README.md"
Expand Down
15 changes: 12 additions & 3 deletions src/celeste/modalities/audio/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Audio modality client."""

from typing import Unpack
from typing import Any, Unpack

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -45,13 +45,16 @@ def __init__(self, client: AudioClient) -> None:
def speak(
    self,
    text: str,
    *,
    extra_body: dict[str, Any] | None = None,
    **parameters: Unpack[AudioParameters],
) -> AudioStream:
    """Stream speech generation.

    Args:
        text: Text to synthesize into speech.
        extra_body: Additional provider-specific fields to merge into the request.
        **parameters: Audio parameters forwarded to the underlying client.

    Returns:
        An AudioStream that yields speech as it is generated.
    """
    client = self._client
    return client._stream(
        AudioInput(text=text),
        stream_class=client._stream_class(),
        extra_body=extra_body,
        **parameters,
    )

Expand All @@ -65,11 +68,15 @@ def __init__(self, client: AudioClient) -> None:
def speak(
self,
text: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[AudioParameters],
) -> AudioOutput:
"""Blocking speech generation."""
inputs = AudioInput(text=text)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
)

@property
def stream(self) -> "AudioSyncStreamNamespace":
Expand All @@ -86,6 +93,8 @@ def __init__(self, client: AudioClient) -> None:
def speak(
self,
text: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[AudioParameters],
) -> AudioStream:
"""Sync streaming speech generation.
Expand All @@ -99,7 +108,7 @@ def speak(
stream.output.content.save("output.mp3")
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.speak(text, **parameters)
return self._client.stream.speak(text, extra_body=extra_body, **parameters)


__all__ = [
Expand Down
13 changes: 10 additions & 3 deletions src/celeste/modalities/embeddings/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Embeddings modality client."""

from typing import Unpack
from typing import Any, Unpack

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -29,12 +29,15 @@ def _output_class(cls) -> type[EmbeddingsOutput]:
async def embed(
self,
text: str | list[str],
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[EmbeddingsParameters],
) -> EmbeddingsOutput:
"""Generate embeddings from text.

Args:
text: Text to embed. Single string or list of strings.
extra_body: Additional provider-specific fields to merge into request.
**parameters: Embedding parameters (e.g., dimensions).

Returns:
Expand All @@ -43,7 +46,7 @@ async def embed(
- list[list[float]] if text was a list
"""
inputs = EmbeddingsInput(text=text)
output = await self._predict(inputs, **parameters)
output = await self._predict(inputs, extra_body=extra_body, **parameters)

# If single text input, unwrap from batch format to single embedding
if (
Expand Down Expand Up @@ -71,10 +74,14 @@ def __init__(self, client: EmbeddingsClient) -> None:
def embed(
self,
text: str | list[str],
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[EmbeddingsParameters],
) -> EmbeddingsOutput:
"""Blocking embeddings generation."""
return async_to_sync(self._client.embed)(text, **parameters)
return async_to_sync(self._client.embed)(
text, extra_body=extra_body, **parameters
)


__all__ = [
Expand Down
30 changes: 25 additions & 5 deletions src/celeste/modalities/images/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Images modality client."""

from typing import Unpack
from typing import Any, Unpack

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -49,27 +49,33 @@ def __init__(self, client: ImagesClient) -> None:
def generate(
    self,
    prompt: str,
    *,
    extra_body: dict[str, Any] | None = None,
    **parameters: Unpack[ImageParameters],
) -> ImagesStream:
    """Stream image generation.

    Args:
        prompt: Text description of the image to generate.
        extra_body: Additional provider-specific fields to merge into the request.
        **parameters: Image parameters forwarded to the underlying client.

    Returns:
        An ImagesStream that yields the image as it is generated.
    """
    client = self._client
    return client._stream(
        ImageInput(prompt=prompt),
        stream_class=client._stream_class(),
        extra_body=extra_body,
        **parameters,
    )

def edit(
    self,
    image: ImageArtifact,
    prompt: str,
    *,
    extra_body: dict[str, Any] | None = None,
    **parameters: Unpack[ImageParameters],
) -> ImagesStream:
    """Stream image editing.

    Args:
        image: Source image artifact to edit.
        prompt: Text description of the edit to apply.
        extra_body: Additional provider-specific fields to merge into the request.
        **parameters: Image parameters forwarded to the underlying client.

    Returns:
        An ImagesStream that yields the edited image as it is generated.
    """
    client = self._client
    return client._stream(
        ImageInput(prompt=prompt, image=image),
        stream_class=client._stream_class(),
        extra_body=extra_body,
        **parameters,
    )

Expand All @@ -86,6 +92,8 @@ def __init__(self, client: ImagesClient) -> None:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImageOutput:
"""Blocking image generation.
Expand All @@ -95,12 +103,16 @@ def generate(
result.content.show()
"""
inputs = ImageInput(prompt=prompt)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
)

def edit(
self,
image: ImageArtifact,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImageOutput:
"""Blocking image edit.
Expand All @@ -110,7 +122,9 @@ def edit(
result.content.show()
"""
inputs = ImageInput(prompt=prompt, image=image)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
)

@property
def stream(self) -> "ImagesSyncStreamNamespace":
Expand All @@ -127,6 +141,8 @@ def __init__(self, client: ImagesClient) -> None:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Sync streaming image generation.
Expand All @@ -140,12 +156,14 @@ def generate(
print(stream.output.usage)
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.generate(prompt, **parameters)
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)

def edit(
self,
image: ImageArtifact,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Sync streaming image editing.
Expand All @@ -158,7 +176,9 @@ def edit(
print(chunk.content)
print(stream.output.usage)
"""
return self._client.stream.edit(image, prompt, **parameters)
return self._client.stream.edit(
image, prompt, extra_body=extra_body, **parameters
)


__all__ = [
Expand Down
8 changes: 6 additions & 2 deletions src/celeste/modalities/videos/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Videos modality client."""

from typing import Unpack
from typing import Any, Unpack

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -42,6 +42,8 @@ def __init__(self, client: VideosClient) -> None:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[VideoParameters],
) -> VideoOutput:
"""Blocking video generation.
Expand All @@ -51,7 +53,9 @@ def generate(
result.content.save("video.mp4")
"""
inputs = VideoInput(prompt=prompt)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
)


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""{Modality} modality client."""

from typing import Unpack
from typing import Any, Unpack

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -70,6 +70,8 @@ class {Modality}StreamNamespace:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Stream:
"""Stream {modality} generation.
Expand All @@ -82,6 +84,7 @@ class {Modality}StreamNamespace:
return self._client._stream(
inputs,
stream_class=self._client._stream_class(),
extra_body=extra_body,
**parameters,
)

Expand All @@ -92,6 +95,7 @@ class {Modality}StreamNamespace:
image: ImageContent | None = None,
video: VideoContent | None = None,
audio: AudioContent | None = None,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Stream:
"""Stream media analysis (image, video, or audio).
Expand All @@ -105,6 +109,7 @@ class {Modality}StreamNamespace:
return self._client._stream(
inputs,
stream_class=self._client._stream_class(),
extra_body=extra_body,
**parameters,
)

Expand All @@ -121,6 +126,8 @@ class {Modality}SyncNamespace:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Output:
"""Blocking {modality} generation.
Expand All @@ -130,7 +137,7 @@ class {Modality}SyncNamespace:
print(result.content)
"""
inputs = {Modality}Input(prompt=prompt)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters)

def analyze(
self,
Expand All @@ -139,6 +146,7 @@ class {Modality}SyncNamespace:
image: ImageContent | None = None,
video: VideoContent | None = None,
audio: AudioContent | None = None,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Output:
"""Blocking media analysis (image, video, or audio).
Expand All @@ -149,7 +157,7 @@ class {Modality}SyncNamespace:
"""
self._client._check_media_support(image=image, video=video, audio=audio)
inputs = {Modality}Input(prompt=prompt, image=image, video=video, audio=audio)
return async_to_sync(self._client._predict)(inputs, **parameters)
return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters)

@property
def stream(self) -> "{Modality}SyncStreamNamespace":
Expand All @@ -166,6 +174,8 @@ class {Modality}SyncStreamNamespace:
def generate(
self,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Stream:
"""Sync streaming {modality} generation.
Expand All @@ -179,7 +189,7 @@ class {Modality}SyncStreamNamespace:
print(stream.output.usage)
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.generate(prompt, **parameters)
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)

def analyze(
self,
Expand All @@ -188,6 +198,7 @@ class {Modality}SyncStreamNamespace:
image: ImageContent | None = None,
video: VideoContent | None = None,
audio: AudioContent | None = None,
extra_body: dict[str, Any] | None = None,
**parameters: Unpack[{Modality}Parameters],
) -> {Modality}Stream:
"""Sync streaming media analysis (image, video, or audio).
Expand All @@ -202,7 +213,7 @@ class {Modality}SyncStreamNamespace:
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.analyze(
prompt, image=image, video=video, audio=audio, **parameters
prompt, image=image, video=video, audio=audio, extra_body=extra_body, **parameters
)


Expand Down
Loading