Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ typecheck:

# Testing
test:
uv run pytest tests/unit_tests packages/*/tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=90 -v
uv run pytest tests/unit_tests packages/*/tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=80 -v

# Integration testing (requires API keys)
integration-test:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ class ImageGenerationClient(
@abstractmethod
def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]:
"""Initialize provider-specific request structure."""
...

@abstractmethod
def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage:
"""Parse usage information from provider response."""
...

@abstractmethod
def _parse_content(
Expand All @@ -39,22 +37,20 @@ def _parse_content(
**parameters: Unpack[ImageGenerationParameters],
) -> ImageArtifact:
"""Parse content from provider response."""
...

@abstractmethod
def _parse_finish_reason(
self, response_data: dict[str, Any]
) -> ImageGenerationFinishReason | None:
"""Parse finish reason from provider response."""
...

def _create_inputs(
self, *args: str, **parameters: Unpack[ImageGenerationParameters]
) -> ImageGenerationInput:
"""Map positional arguments to Input type."""
if args:
return ImageGenerationInput(prompt=args[0])
prompt = parameters.get("prompt")
prompt: str | None = parameters.get("prompt")
if prompt is None:
msg = (
"prompt is required (either as positional argument or keyword argument)"
Expand Down Expand Up @@ -84,4 +80,3 @@ async def _make_request(
**parameters: Unpack[ImageGenerationParameters],
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from celeste.artifacts import ImageArtifact
from celeste.exceptions import ConstraintViolationError, ValidationError
from celeste.mime_types import ImageMimeType
from celeste.mime_types import ApplicationMimeType, ImageMimeType
from celeste.parameters import ParameterMapper
from celeste_image_generation.client import ImageGenerationClient
from celeste_image_generation.io import (
Expand Down Expand Up @@ -39,7 +39,7 @@ def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]:
}

def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage:
"""Parse usage from ByteDance response."""
"""Parse usage from response."""
usage_data = response_data.get("usage", {})

return ImageGenerationUsage(
Expand All @@ -53,7 +53,7 @@ def _parse_content(
response_data: dict[str, Any],
**parameters: Unpack[ImageGenerationParameters],
) -> ImageArtifact:
"""Parse image content from ByteDance response."""
"""Parse content from response."""
images = response_data.get("images", [])
if images and images[0].get("url"):
return ImageArtifact(
Expand Down Expand Up @@ -81,7 +81,7 @@ def _parse_content(
def _parse_finish_reason(
self, response_data: dict[str, Any]
) -> ImageGenerationFinishReason | None:
"""Parse finish reason from provider response.
"""Parse finish reason from response.

ByteDance doesn't provide finish reasons for image generation.
"""
Expand Down Expand Up @@ -120,7 +120,7 @@ async def _make_request(

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return await self.http_client.post(
Expand All @@ -143,7 +143,7 @@ def _make_stream_request(

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return self.http_client.stream_post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from celeste.artifacts import ImageArtifact
from celeste.core import Provider
from celeste.exceptions import ModelNotFoundError
from celeste.mime_types import ImageMimeType
from celeste.mime_types import ApplicationMimeType, ImageMimeType
from celeste.parameters import ParameterMapper
from celeste_image_generation.client import ImageGenerationClient
from celeste_image_generation.io import (
Expand All @@ -24,11 +24,7 @@


class GoogleImageGenerationClient(ImageGenerationClient):
"""Google client for image generation.

Supports both Imagen API and Gemini multimodal API via adapter pattern.
Adapter selection happens automatically based on model type.
"""
"""Google client for image generation."""

model_config = ConfigDict(extra="allow")

Expand All @@ -42,23 +38,22 @@ def model_post_init(self, __context: Any) -> None: # noqa: ANN401

@classmethod
def parameter_mappers(cls) -> list[ParameterMapper]:
"""Return parameter mappers for Google provider."""
return GOOGLE_PARAMETER_MAPPERS

def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]:
"""Initialize request using API adapter."""
"""Initialize request from Google API format."""
return self.api.build_request(inputs.prompt, {})

def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage:
"""Parse usage from response using API adapter."""
"""Parse usage from response."""
return self.api.parse_usage(response_data)

def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageGenerationParameters],
) -> ImageArtifact:
"""Parse content from response using API adapter."""
"""Parse content from response."""
prediction = self.api.parse_response(response_data)

if prediction is None:
Expand All @@ -73,7 +68,7 @@ def _parse_content(
def _parse_finish_reason(
self, response_data: dict[str, Any]
) -> ImageGenerationFinishReason | None:
"""Parse finish reason from provider response.
"""Parse finish reason from response.

For Gemini models, extracts finishReason from candidates[0].
For Imagen models, returns None (not provided).
Expand Down Expand Up @@ -111,7 +106,7 @@ async def _make_request(
"""Make HTTP request(s) and return response object."""
headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return await self.http_client.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import httpx

from celeste.artifacts import ImageArtifact
from celeste.exceptions import ValidationError
from celeste.mime_types import ApplicationMimeType
from celeste.parameters import ParameterMapper
from celeste_image_generation.client import ImageGenerationClient
from celeste_image_generation.io import (
Expand All @@ -23,14 +23,14 @@


class OpenAIImageGenerationClient(ImageGenerationClient):
"""OpenAI client."""
"""OpenAI client for image generation."""

@classmethod
def parameter_mappers(cls) -> list[ParameterMapper]:
return OPENAI_PARAMETER_MAPPERS

def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]:
"""Initialize request from inputs."""
"""Initialize request from OpenAI API format."""
request = {
"model": self.model.id,
"prompt": inputs.prompt,
Expand All @@ -55,7 +55,7 @@ def _parse_content(
data = response_data.get("data", [])
if not data:
msg = "No image data in response"
raise ValidationError(msg)
raise ValueError(msg)

image_data = data[0]

Expand All @@ -69,7 +69,7 @@ def _parse_content(
return ImageArtifact(url=url)

msg = "No image URL or base64 data in response"
raise ValidationError(msg)
raise ValueError(msg)

def _parse_finish_reason(
self, response_data: dict[str, Any]
Expand All @@ -95,7 +95,7 @@ async def _make_request(
"""Make HTTP request(s) and return response object."""
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return await self.http_client.post(
Expand Down Expand Up @@ -125,7 +125,7 @@ def _make_stream_request(

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return self.http_client.stream_post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from celeste import Capability, Provider, create_client

# Integration tests require API credentials configured in CI environment


@pytest.mark.parametrize(
("provider", "model", "parameters"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ class TextGenerationClient(
@abstractmethod
def _init_request(self, inputs: TextGenerationInput) -> dict[str, Any]:
"""Initialize provider-specific request structure."""
...

@abstractmethod
def _parse_usage(self, response_data: dict[str, Any]) -> TextGenerationUsage:
"""Parse usage information from provider response."""
...

@abstractmethod
def _parse_content(
Expand All @@ -39,14 +37,12 @@ def _parse_content(
**parameters: Unpack[TextGenerationParameters],
) -> str | BaseModel:
"""Parse content from provider response."""
...

@abstractmethod
def _parse_finish_reason(
self, response_data: dict[str, Any]
) -> TextGenerationFinishReason | None:
"""Parse finish reason from provider response."""
...

def _create_inputs(
self, *args: str, **parameters: Unpack[TextGenerationParameters]
Expand Down Expand Up @@ -80,4 +76,3 @@ async def _make_request(
**parameters: Unpack[TextGenerationParameters],
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
...
2 changes: 1 addition & 1 deletion packages/text-generation/src/celeste_text_generation/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class TextGenerationInput(Input):
"""Input for text generation requests."""
"""Input for text generation operations."""

prompt: str

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import httpx
from pydantic import BaseModel

from celeste.mime_types import ApplicationMimeType
from celeste.parameters import ParameterMapper
from celeste_text_generation.client import TextGenerationClient
from celeste_text_generation.io import (
Expand All @@ -21,7 +22,7 @@


class AnthropicTextGenerationClient(TextGenerationClient):
"""Anthropic client."""
"""Anthropic client for text generation."""

@classmethod
def parameter_mappers(cls) -> list[ParameterMapper]:
Expand Down Expand Up @@ -99,7 +100,7 @@ async def _make_request(
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION,
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return await self.http_client.post(
Expand All @@ -125,7 +126,7 @@ def _make_stream_request(
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION,
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}

return self.http_client.stream_post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class AnthropicTextGenerationStream(TextGenerationStream):
def __init__(
self,
sse_iterator: Any, # noqa: ANN401
transform_output: Callable[[object, Any], object],
transform_output: Callable[..., object],
**parameters: Unpack[TextGenerationParameters],
) -> None:
"""Initialize stream with output transformation support.
Expand Down Expand Up @@ -306,31 +306,29 @@ def _parse_output(
# Empty dict for BaseModel - try text chunks, but if none, raise error
text_content = "".join(chunk.content for chunk in chunks)
if text_content:
content = self._transform_output(
text_content, **self._parameters
)
content = self._transform_output(text_content, **parameters)
else:
msg = "Empty tool_use input dict and no text chunks available for BaseModel"
raise ValidationError(msg)
else:
# Empty dict for list[BaseModel] - OK, parse_output will convert to []
content = self._transform_output(tool_input, **self._parameters)
content = self._transform_output(tool_input, **parameters)
else:
# Valid tool_input - transform to BaseModel
content = self._transform_output(tool_input, **self._parameters)
content = self._transform_output(tool_input, **parameters)
else:
# Fallback: concatenate text chunks
text_content = "".join(chunk.content for chunk in chunks)
if text_content:
content = self._transform_output(text_content, **self._parameters)
content = self._transform_output(text_content, **parameters)
else:
msg = "No tool_use input and no text chunks available"
raise ValidationError(msg)
else:
# No tool_use blocks or no output_schema: concatenate text chunks
content = "".join(chunk.content for chunk in chunks)
# Apply parameter transformations (e.g., JSON → BaseModel if output_schema provided)
content = self._transform_output(content, **self._parameters)
content = self._transform_output(content, **parameters)

usage = self._parse_usage(chunks)
finish_reason = chunks[-1].finish_reason if chunks else None
Expand Down
Loading
Loading