From ac0342dc6b34501698e8f24416028725a2864ea5 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 10 Nov 2025 18:57:15 +0100 Subject: [PATCH 1/5] refactor: extract registry module to fix circular imports --- src/celeste/__init__.py | 17 +++-------------- src/celeste/models.py | 4 ++++ src/celeste/registry.py | 24 +++++++++++++++++++++++ tests/unit_tests/test_models.py | 34 ++++++++++++++++++++++++--------- 4 files changed, 56 insertions(+), 23 deletions(-) create mode 100644 src/celeste/registry.py diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 4ac3e73..e80e8e3 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -1,4 +1,3 @@ -import importlib.metadata import logging from pydantic import SecretStr @@ -23,6 +22,7 @@ from celeste.io import Input, Output, Usage from celeste.models import Model, get_model, list_models, register_models from celeste.parameters import Parameters +from celeste.registry import _load_from_entry_points logger = logging.getLogger(__name__) @@ -79,6 +79,8 @@ def create_client( MissingCredentialsError: If required credentials are not configured. UnsupportedCapabilityError: If the resolved model doesn't support the requested capability. """ + # Load packages lazily when create_client is called + _load_from_entry_points() # Resolve model resolved_model = _resolve_model(capability, provider, model) @@ -97,19 +99,6 @@ def create_client( ) -def _load_from_entry_points() -> None: - """Load models and clients from installed packages via entry points.""" - entry_points = importlib.metadata.entry_points(group="celeste.packages") - - for ep in entry_points: - register_func = ep.load() - # The function should register models and clients when called - register_func() - - -# Load from entry points on import -_load_from_entry_points() - # Exports __all__ = [ "Capability", diff --git a/src/celeste/models.py b/src/celeste/models.py index 4b66f20..5b6382b 100644 --- a/src/celeste/models.py +++ b/src/celeste/models.py @@ -94,6 +94,10 @@ def list_models( Returns: List of Model instances matching the filters. """ + # Load packages lazily to avoid circular imports + from celeste.registry import _load_from_entry_points + + _load_from_entry_points() models = list(_models.values()) if provider is not None: diff --git a/src/celeste/registry.py b/src/celeste/registry.py new file mode 100644 index 0000000..48690d9 --- /dev/null +++ b/src/celeste/registry.py @@ -0,0 +1,24 @@ +"""Package registry for lazy loading entry points.""" + +import importlib.metadata + +_loaded_packages: set[str] = set() + + +def _load_from_entry_points() -> None: + """Load models and clients from installed packages via entry points.""" + + entry_points = importlib.metadata.entry_points(group="celeste.packages") + + # Early return if all packages are already loaded + entry_point_names = {ep.name for ep in entry_points} + if entry_point_names.issubset(_loaded_packages): + return + + for ep in entry_points: + if ep.name in _loaded_packages: + continue + register_func = ep.load() + # The function should register models and clients when called + register_func() + _loaded_packages.add(ep.name) diff --git a/tests/unit_tests/test_models.py b/tests/unit_tests/test_models.py index 34d96e4..0baac72 100644 --- a/tests/unit_tests/test_models.py +++ b/tests/unit_tests/test_models.py @@ -41,8 +41,11 @@ class TestRegisterModels: """Test model registration functionality.""" @pytest.mark.smoke - def test_register_models_accepts_single_or_list(self) -> None: + @patch("celeste.registry._load_from_entry_points") + def test_register_models_accepts_single_or_list(self, mock_load: Mock) -> None: """Registering models works with both single model and list.""" + # Prevent entry point loading from interfering with test isolation + mock_load.return_value = None single_model = SAMPLE_MODELS[0] register_models(single_model, Capability.TEXT_GENERATION) retrieved = get_model(single_model.id, single_model.provider) @@ -63,8 +66,11 @@ def test_register_models_accepts_single_or_list(self) -> None: assert model.provider == retrieved.provider assert Capability.TEXT_GENERATION in retrieved.capabilities - def test_reregistering_same_key_raises_error(self) -> None: + @patch("celeste.registry._load_from_entry_points") + def test_reregistering_same_key_raises_error(self, mock_load: Mock) -> None: """Re-registering with same (id, provider) but different display_name raises ValueError.""" + # Prevent entry point loading from interfering with test isolation + mock_load.return_value = None original = SAMPLE_MODELS[0] register_models(original, Capability.TEXT_GENERATION) @@ -82,8 +88,13 @@ def test_reregistering_same_key_raises_error(self) -> None: assert result.display_name == original.display_name assert len(list_models()) == 1 - def test_registering_same_model_for_multiple_capabilities_merges(self) -> None: + @patch("celeste.registry._load_from_entry_points") + def test_registering_same_model_for_multiple_capabilities_merges( + self, mock_load: Mock + ) -> None: """Registering the same model for multiple capabilities merges capabilities.""" + # Prevent entry point loading from interfering with test isolation + mock_load.return_value = None model = Model( id="multi-cap-model", provider=Provider.OPENAI, @@ -115,8 +126,10 @@ class TestListModels: """Test model listing and filtering functionality.""" @pytest.fixture(autouse=True) - def setup_models(self) -> None: + def setup_models(self, monkeypatch: pytest.MonkeyPatch) -> None: """Set up test models for filtering tests.""" + # Prevent entry point loading from interfering with test isolation + monkeypatch.setattr("celeste.registry._load_from_entry_points", lambda: None) register_models(SAMPLE_MODELS[0], Capability.TEXT_GENERATION) register_models(SAMPLE_MODELS[1], Capability.IMAGE_GENERATION) register_models(SAMPLE_MODELS[2], Capability.TEXT_GENERATION) @@ -232,7 +245,7 @@ def test_same_id_different_providers_are_distinct(self) -> None: class TestEntryPoints: """Test entry point loading functionality.""" - @patch("celeste.importlib.metadata.entry_points") + @patch("celeste.registry.importlib.metadata.entry_points") def test_entry_point_loading_success( self, mock_entry_points: Mock, capsys: pytest.CaptureFixture[str] ) -> None: @@ -251,7 +264,7 @@ def test_entry_point_loading_success( mock_entry_points.return_value = [mock_ep] clear() - from celeste import _load_from_entry_points + from celeste.registry import _load_from_entry_points _load_from_entry_points() @@ -261,7 +274,7 @@ def test_entry_point_loading_success( captured = capsys.readouterr() assert captured.err == "" - @patch("celeste.importlib.metadata.entry_points") + @patch("celeste.registry.importlib.metadata.entry_points") def test_entry_point_returns_none_handled( self, mock_entry_points: Mock, capsys: pytest.CaptureFixture[str] ) -> None: @@ -273,7 +286,7 @@ def test_entry_point_returns_none_handled( mock_entry_points.return_value = [mock_ep] clear() - from celeste import _load_from_entry_points + from celeste.registry import _load_from_entry_points _load_from_entry_points() @@ -335,8 +348,11 @@ def test_list_models_includes_parameters(self) -> None: class TestClear: """Test registry clearing functionality.""" - def test_clear_removes_all_models(self) -> None: + @patch("celeste.registry._load_from_entry_points") + def test_clear_removes_all_models(self, mock_load: Mock) -> None: """clear removes all registered models.""" + # Prevent entry point loading from interfering with test isolation + mock_load.return_value = None register_models(SAMPLE_MODELS[0], Capability.TEXT_GENERATION) register_models(SAMPLE_MODELS[1], Capability.IMAGE_GENERATION) register_models(SAMPLE_MODELS[2], Capability.TEXT_GENERATION) From a057ceccce87623a5dcee7978ed4c253601fcdfa Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 10 Nov 2025 18:57:53 +0100 Subject: [PATCH 2/5] feat: add image-generation package with OpenAI, Google, and ByteDance providers --- Makefile | 4 +- packages/image-generation/README.md | 79 +++++++++ packages/image-generation/pyproject.toml | 39 +++++ .../src/celeste_image_generation/__init__.py | 36 ++++ .../src/celeste_image_generation/client.py | 87 +++++++++ .../celeste_image_generation/constraints.py | 72 ++++++++ .../src/celeste_image_generation/io.py | 58 ++++++ .../src/celeste_image_generation/models.py | 14 ++ .../celeste_image_generation/parameters.py | 23 +++ .../providers/__init__.py | 28 +++ .../providers/bytedance/README.md | 13 ++ .../providers/bytedance/__init__.py | 6 + .../providers/bytedance/client.py | 160 +++++++++++++++++ .../providers/bytedance/config.py | 10 ++ .../providers/bytedance/models.py | 34 ++++ .../providers/bytedance/parameters.py | 101 +++++++++++ .../providers/bytedance/streaming.py | 67 +++++++ .../providers/google/__init__.py | 6 + .../providers/google/client.py | 152 ++++++++++++++++ .../providers/google/config.py | 10 ++ .../providers/google/gemini_api.py | 82 +++++++++ .../providers/google/imagen_api.py | 67 +++++++ .../providers/google/models.py | 86 +++++++++ .../providers/google/parameters.py | 67 +++++++ .../providers/openai/__init__.py | 7 + .../providers/openai/client.py | 142 +++++++++++++++ .../providers/openai/config.py | 10 ++ .../providers/openai/models.py | 44 +++++ .../providers/openai/parameters.py | 118 +++++++++++++ .../providers/openai/streaming.py | 76 ++++++++ .../src/celeste_image_generation/py.typed | 0 .../src/celeste_image_generation/streaming.py | 48 +++++ .../test_image_generation/__init__.py | 1 + .../test_image_generation/test_generate.py | 61 +++++++ .../tests/unit_tests/__init__.py | 0 .../unit_tests/providers/google/__init__.py | 0 .../providers/google/test_finish_reason.py | 165 ++++++++++++++++++ pyproject.toml | 22 ++- 38 files changed, 1990 insertions(+), 5 deletions(-) create mode 100644 packages/image-generation/README.md create mode 100644 packages/image-generation/pyproject.toml create mode 100644 packages/image-generation/src/celeste_image_generation/__init__.py create mode 100644 packages/image-generation/src/celeste_image_generation/client.py create mode 100644 packages/image-generation/src/celeste_image_generation/constraints.py create mode 100644 packages/image-generation/src/celeste_image_generation/io.py create mode 100644 packages/image-generation/src/celeste_image_generation/models.py create mode 100644 packages/image-generation/src/celeste_image_generation/parameters.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/__init__.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/README.md create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/__init__.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/config.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/models.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/parameters.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/__init__.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/client.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/config.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/gemini_api.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/imagen_api.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/models.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/google/parameters.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/__init__.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/client.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/config.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/models.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/parameters.py create mode 100644 packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py create mode 100644 packages/image-generation/src/celeste_image_generation/py.typed create mode 100644 packages/image-generation/src/celeste_image_generation/streaming.py create mode 100644 packages/image-generation/tests/integration_tests/test_image_generation/__init__.py create mode 100644 packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py create mode 100644 packages/image-generation/tests/unit_tests/__init__.py create mode 100644 packages/image-generation/tests/unit_tests/providers/google/__init__.py create mode 100644 packages/image-generation/tests/unit_tests/providers/google/test_finish_reason.py diff --git a/Makefile b/Makefile index 19a5e34..7691661 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ help: @echo " make lint - Run Ruff linting" @echo " make format - Apply Ruff formatting" @echo " make typecheck - Run mypy type checking" - @echo " make test - Run pytest with coverage" + @echo " make test - Run all tests (core + packages) with coverage" @echo " make integration-test - Run integration tests (requires API keys)" @echo " make security - Run Bandit security scan" @echo " make ci - Run full CI/CD pipeline" @@ -37,7 +37,7 @@ typecheck: # Testing test: - uv run pytest tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=90 + uv run pytest tests/unit_tests packages/*/tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=90 -v # Integration testing (requires API keys) integration-test: diff --git a/packages/image-generation/README.md b/packages/image-generation/README.md new file mode 100644 index 0000000..bc87b36 --- /dev/null +++ b/packages/image-generation/README.md @@ -0,0 +1,79 @@ +
+ +# Celeste Logo Celeste Image Generation + +**Image Generation capability for Celeste AI** + +[![Python](https://img.shields.io/badge/Python-3.12+-blue?style=for-the-badge)](https://www.python.org/) +[![License](https://img.shields.io/badge/License-Apache_2.0-red?style=for-the-badge)](../../LICENSE) + +[Quick Start](#-quick-start) • [Documentation](https://withceleste.ai/docs) • [Request Provider](https://github.com/withceleste/celeste-python/issues/new) + +
+ +--- + +## 🚀 Quick Start + +```python +from celeste import create_client, Capability, Provider + +client = create_client( + capability=Capability.IMAGE_GENERATION, + provider=Provider.OPENAI, +) + +response = await client.generate(prompt="A red apple on a white background") +print(response.content) +``` + +**Install:** +```bash +uv add "celeste-ai[image-generation]" +``` + +--- + +## Supported Providers + + +
+ +OpenAI +Google +ByteDance + + +**Missing a provider?** [Request it](https://github.com/withceleste/celeste-python/issues/new) – ⚡ **we ship fast**. + +
+ +--- + +**Streaming**: ✅ Supported + +**Parameters**: See [API Documentation](https://withceleste.ai/docs/api) for full parameter reference. + +--- + +## 🤝 Contributing + +See [CONTRIBUTING.md](../../CONTRIBUTING.md) for guidelines. + +**Request a provider:** [GitHub Issues](https://github.com/withceleste/celeste-python/issues/new) + +--- + +## 📄 License + +Apache 2.0 License – see [LICENSE](../../LICENSE) for details. + +--- + +
+ +**[Get Started](https://withceleste.ai/docs/quickstart)** • **[Documentation](https://withceleste.ai/docs)** • **[GitHub](https://github.com/withceleste/celeste-python)** + +Made with ❤️ by developers tired of framework lock-in + +
diff --git a/packages/image-generation/pyproject.toml b/packages/image-generation/pyproject.toml new file mode 100644 index 0000000..b941fcb --- /dev/null +++ b/packages/image-generation/pyproject.toml @@ -0,0 +1,39 @@ +[project] +name = "celeste-image-generation" +version = "0.2.1" +description = "Type-safe image generation for Celeste AI. Unified interface for OpenAI, Google, ByteDance, and more" +authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] +readme = "README.md" +license = {text = "Apache-2.0"} +requires-python = ">=3.12" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Typing :: Typed", +] +keywords = ["ai", "image-generation", "dall-e", "imagen", "openai", "google", "bytedance"] + +[project.urls] +Homepage = "https://withceleste.ai" +Documentation = "https://withceleste.ai/docs" +Repository = "https://github.com/withceleste/celeste-python" +Issues = "https://github.com/withceleste/celeste-python/issues" + +[tool.uv.sources] +celeste-ai = { workspace = true } + +[project.entry-points."celeste.packages"] +image-generation = "celeste_image_generation:register_package" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/celeste_image_generation"] diff --git a/packages/image-generation/src/celeste_image_generation/__init__.py b/packages/image-generation/src/celeste_image_generation/__init__.py new file mode 100644 index 0000000..41b1231 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/__init__.py @@ -0,0 +1,36 @@ +"""Celeste image generation capability.""" + + +def register_package() -> None: + """Register image generation package (client and models).""" + from celeste.client import register_client + from celeste.core import Capability + from celeste.models import register_models + from celeste_image_generation.models import MODELS + from celeste_image_generation.providers import PROVIDERS + + # Register provider-specific clients + for provider, client_class in PROVIDERS: + register_client(Capability.IMAGE_GENERATION, provider, client_class) + + register_models(MODELS, capability=Capability.IMAGE_GENERATION) + + +from celeste_image_generation.io import ( # noqa: E402 + ImageGenerationChunk, + ImageGenerationFinishReason, + ImageGenerationInput, + ImageGenerationOutput, + ImageGenerationUsage, +) +from celeste_image_generation.streaming import ImageGenerationStream # noqa: E402 + +__all__ = [ + "ImageGenerationChunk", + "ImageGenerationFinishReason", + "ImageGenerationInput", + "ImageGenerationOutput", + "ImageGenerationStream", + "ImageGenerationUsage", + "register_package", +] diff --git a/packages/image-generation/src/celeste_image_generation/client.py b/packages/image-generation/src/celeste_image_generation/client.py new file mode 100644 index 0000000..a555dc6 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/client.py @@ -0,0 +1,87 @@ +"""Base client for image generation.""" + +from abc import abstractmethod +from typing import Any, Unpack + +import httpx + +from celeste.artifacts import ImageArtifact +from celeste.client import Client +from celeste.exceptions import ValidationError +from celeste_image_generation.io import ( + ImageGenerationFinishReason, + ImageGenerationInput, + ImageGenerationOutput, + ImageGenerationUsage, +) +from celeste_image_generation.parameters import ImageGenerationParameters + + +class ImageGenerationClient( + Client[ImageGenerationInput, ImageGenerationOutput, ImageGenerationParameters] +): + """Client for image generation operations.""" + + @abstractmethod + def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]: + """Initialize provider-specific request structure.""" + ... + + @abstractmethod + def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage information from provider response.""" + ... + + @abstractmethod + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> ImageArtifact: + """Parse content from provider response.""" + ... + + @abstractmethod + def _parse_finish_reason( + self, response_data: dict[str, Any] + ) -> ImageGenerationFinishReason | None: + """Parse finish reason from provider response.""" + ... + + def _create_inputs( + self, *args: str, **parameters: Unpack[ImageGenerationParameters] + ) -> ImageGenerationInput: + """Map positional arguments to Input type.""" + if args: + return ImageGenerationInput(prompt=args[0]) + prompt = parameters.get("prompt") + if prompt is None: + msg = ( + "prompt is required (either as positional argument or keyword argument)" + ) + raise ValidationError(msg) + return ImageGenerationInput(prompt=prompt) + + @classmethod + def _output_class(cls) -> type[ImageGenerationOutput]: + """Return the Output class for this client.""" + return ImageGenerationOutput + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary from response data.""" + metadata = super()._build_metadata(response_data) + # Only parse finish_reason if not already set by provider override + if "finish_reason" not in metadata: + finish_reason = self._parse_finish_reason(response_data) + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + return metadata + + @abstractmethod + async def _make_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + ... diff --git a/packages/image-generation/src/celeste_image_generation/constraints.py b/packages/image-generation/src/celeste_image_generation/constraints.py new file mode 100644 index 0000000..615e3f2 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/constraints.py @@ -0,0 +1,72 @@ +"""Image generation specific constraints.""" + +from celeste.constraints import Constraint +from celeste.exceptions import ConstraintViolationError + + +class Dimensions(Constraint): + """Dimension string constraint with pixel and aspect ratio bounds.""" + + min_pixels: int + max_pixels: int + min_aspect_ratio: float + max_aspect_ratio: float + presets: dict[str, str] | None = None + + def __call__(self, value: str) -> str: + """Validate dimension string against pixel and aspect ratio bounds.""" + if not isinstance(value, str): + msg = f"Must be string, got {type(value).__name__}" + raise ConstraintViolationError(msg) + + # Check if value is a preset key + if self.presets and value in self.presets: + actual_value = self.presets[value] + else: + actual_value = value + + # Parse dimension format "WIDTHxHEIGHT" + parts = actual_value.lower().split("x") + if len(parts) != 2: + msg = f"Invalid dimension format: {actual_value!r}. Expected 'WIDTHxHEIGHT'" + raise ConstraintViolationError(msg) + + # Validate parts are numeric + if not parts[0].isdigit() or not parts[1].isdigit(): + msg = ( + f"Invalid dimension format: {actual_value!r}. " + f"Width and height must be positive integers" + ) + raise ConstraintViolationError(msg) + + width = int(parts[0]) + height = int(parts[1]) + + # Validate dimensions are positive + if width <= 0 or height <= 0: + msg = f"Width and height must be positive, got {width}x{height}" + raise ConstraintViolationError(msg) + + # Validate total pixels + total_pixels = width * height + if not (self.min_pixels <= total_pixels <= self.max_pixels): + msg = ( + f"Total pixels {total_pixels:,} outside valid range " + f"[{self.min_pixels:,}, {self.max_pixels:,}]" + ) + raise ConstraintViolationError(msg) + + # Validate aspect ratio + aspect_ratio = width / height + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + msg = ( + f"Aspect ratio {aspect_ratio:.3f} outside valid range " + f"[{self.min_aspect_ratio:.3f}, {self.max_aspect_ratio:.1f}]" + ) + raise ConstraintViolationError(msg) + + # Return normalized format + return f"{width}x{height}" + + +__all__ = ["Dimensions"] diff --git a/packages/image-generation/src/celeste_image_generation/io.py b/packages/image-generation/src/celeste_image_generation/io.py new file mode 100644 index 0000000..54047c5 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/io.py @@ -0,0 +1,58 @@ +"""Input and output types for image generation.""" + +from celeste.artifacts import ImageArtifact +from celeste.io import Chunk, FinishReason, Input, Output, Usage + + +class ImageGenerationInput(Input): + """Input for image generation operations.""" + + prompt: str + + +class ImageGenerationFinishReason(FinishReason): + """Image generation finish reason. + + Stores raw provider reason. Providers map their values in implementation. + """ + + reason: str | None = ( + None # Raw provider string (e.g., "STOP", "NO_IMAGE", "PROHIBITED_CONTENT") + ) + message: str | None = None # Optional human-readable explanation from provider + + +class ImageGenerationUsage(Usage): + """Image generation usage metrics. + + Most providers don't report usage metrics for image generation. + OpenAI gpt-image-1 reports usage only in streaming mode. + ByteDance reports tokens_used for billing tracking. + """ + + total_tokens: int | None = None + input_tokens: int | None = None + output_tokens: int | None = None + generated_images: int | None = None + + +class ImageGenerationOutput(Output[ImageArtifact]): + """Output with ImageArtifact content.""" + + pass + + +class ImageGenerationChunk(Chunk[ImageArtifact]): + """Typed chunk for image generation streaming.""" + + finish_reason: ImageGenerationFinishReason | None = None + usage: ImageGenerationUsage | None = None + + +__all__ = [ + "ImageGenerationChunk", + "ImageGenerationFinishReason", + "ImageGenerationInput", + "ImageGenerationOutput", + "ImageGenerationUsage", +] diff --git a/packages/image-generation/src/celeste_image_generation/models.py b/packages/image-generation/src/celeste_image_generation/models.py new file mode 100644 index 0000000..f8ef972 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/models.py @@ -0,0 +1,14 @@ +"""Model definitions for image generation.""" + +from celeste import Model +from celeste_image_generation.providers.bytedance.models import ( + MODELS as BYTEDANCE_MODELS, +) +from celeste_image_generation.providers.google.models import MODELS as GOOGLE_MODELS +from celeste_image_generation.providers.openai.models import MODELS as OPENAI_MODELS + +MODELS: list[Model] = [ + *BYTEDANCE_MODELS, + *GOOGLE_MODELS, + *OPENAI_MODELS, +] diff --git a/packages/image-generation/src/celeste_image_generation/parameters.py b/packages/image-generation/src/celeste_image_generation/parameters.py new file mode 100644 index 0000000..ba5ad24 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/parameters.py @@ -0,0 +1,23 @@ +"""Parameters for image generation.""" + +from enum import StrEnum + +from celeste.parameters import Parameters + + +class ImageGenerationParameter(StrEnum): + """Unified parameter names for image generation capability.""" + + ASPECT_RATIO = "aspect_ratio" + PARTIAL_IMAGES = "partial_images" + QUALITY = "quality" + WATERMARK = "watermark" + + +class ImageGenerationParameters(Parameters): + """Parameters for image generation.""" + + aspect_ratio: str | None + partial_images: int | None + quality: str | None + watermark: bool | None diff --git a/packages/image-generation/src/celeste_image_generation/providers/__init__.py b/packages/image-generation/src/celeste_image_generation/providers/__init__.py new file mode 100644 index 0000000..1549ef5 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/__init__.py @@ -0,0 +1,28 @@ +"""Provider implementations for image generation.""" + +from celeste import Client, Provider + +__all__ = ["PROVIDERS"] + + +def _get_providers() -> list[tuple[Provider, type[Client]]]: + """Lazy-load providers.""" + # Import clients directly from .client modules to avoid __init__.py imports + from celeste_image_generation.providers.bytedance.client import ( + ByteDanceImageGenerationClient, + ) + from celeste_image_generation.providers.google.client import ( + GoogleImageGenerationClient, + ) + from celeste_image_generation.providers.openai.client import ( + OpenAIImageGenerationClient, + ) + + return [ + (Provider.BYTEDANCE, ByteDanceImageGenerationClient), + (Provider.GOOGLE, GoogleImageGenerationClient), + (Provider.OPENAI, OpenAIImageGenerationClient), + ] + + +PROVIDERS: list[tuple[Provider, type[Client]]] = _get_providers() diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/README.md b/packages/image-generation/src/celeste_image_generation/providers/bytedance/README.md new file mode 100644 index 0000000..ffbcc75 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/README.md @@ -0,0 +1,13 @@ +# ByteDance Image Generation Provider + +## Credentials + +**Environment variable:** `BYTEDANCE_API_KEY` + +**Setup:** +1. Register at [console.byteplus.com](https://console.byteplus.com) +2. Activate model in ModelArk section +3. Generate API key with image generation permissions +4. Set environment variable: `export BYTEDANCE_API_KEY="your-key"` + +**Note:** Models must be activated in BytePlus console before use. If you get a 404 error, activate the model or use an Endpoint ID (`ep-xxx`) instead of Model ID. diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/__init__.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/__init__.py new file mode 100644 index 0000000..1076978 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/__init__.py @@ -0,0 +1,6 @@ +"""ByteDance provider.""" + +from .client import ByteDanceImageGenerationClient +from .models import MODELS + +__all__ = ["MODELS", "ByteDanceImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py new file mode 100644 index 0000000..822f6a2 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py @@ -0,0 +1,160 @@ +"""ByteDance client implementation.""" + +import base64 +from collections.abc import AsyncIterator +from typing import Any, Unpack + +import httpx + +from celeste.artifacts import ImageArtifact +from celeste.mime_types import ImageMimeType +from celeste.parameters import ParameterMapper +from celeste_image_generation.client import ImageGenerationClient +from celeste_image_generation.io import ( + ImageGenerationFinishReason, + ImageGenerationInput, + ImageGenerationUsage, +) +from celeste_image_generation.parameters import ImageGenerationParameters + +from . import config +from .parameters import BYTEDANCE_PARAMETER_MAPPERS +from .streaming import ByteDanceImageGenerationStream + + +class ByteDanceImageGenerationClient(ImageGenerationClient): + """ByteDance client for image generation.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper]: + return BYTEDANCE_PARAMETER_MAPPERS + + def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]: + """Initialize request from ByteDance API structure.""" + return { + "model": self.model.id, + "prompt": inputs.prompt, + "response_format": "url", + } + + def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage from ByteDance response.""" + usage_data = response_data.get("usage", {}) + + return ImageGenerationUsage( + total_tokens=usage_data.get("total_tokens"), + output_tokens=usage_data.get("output_tokens"), + generated_images=usage_data.get("generated_images"), + ) + + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> ImageArtifact: + """Parse image content from ByteDance response.""" + images = response_data.get("images", []) + if images and images[0].get("url"): + return ImageArtifact( + url=images[0]["url"], + mime_type=ImageMimeType.PNG, + ) + + data = response_data.get("data", []) + if data: + if data[0].get("url"): + return ImageArtifact( + url=data[0]["url"], + mime_type=ImageMimeType.PNG, + ) + if data[0].get("b64_json"): + image_bytes = base64.b64decode(data[0]["b64_json"]) + return ImageArtifact( + data=image_bytes, + mime_type=ImageMimeType.PNG, + ) + + msg = "No image content found in ByteDance response" + raise ValueError(msg) + + def _parse_finish_reason( + self, response_data: dict[str, Any] + ) -> ImageGenerationFinishReason | None: + """Parse finish reason from provider response. + + ByteDance doesn't provide finish reasons for image generation. + """ + return None + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary from response data. + + Extracts seed if present. + """ + # Filter content fields before calling super + content_fields = {"images", "data"} + filtered_data = { + k: v for k, v in response_data.items() if k not in content_fields + } + metadata = super()._build_metadata(filtered_data) + # Add provider-specific parsed fields + seed = response_data.get("seed") + if seed is not None: + metadata["seed"] = seed + return metadata + + async def _make_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + # Validate mutually exclusive parameters + if parameters.get("aspect_ratio") and parameters.get("quality"): + msg = ( + "Cannot use both 'aspect_ratio' and 'quality' parameters. " + "ByteDance's 'size' field supports two methods that cannot be combined:\n" + " • quality: Resolution class ('1K', '2K', '4K')\n" + " • aspect_ratio: Exact dimensions (e.g., '2048x2048', '3840x2160')\n" + "Use one or the other, not both." + ) + raise ValueError(msg) + + request_body["stream"] = False + + headers = { + config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + return await self.http_client.post( + f"{config.BASE_URL}{config.ENDPOINT}", + headers=headers, + json_body=request_body, + ) + + def _stream_class(self) -> type[ByteDanceImageGenerationStream]: + """Return the Stream class for this client.""" + return ByteDanceImageGenerationStream + + def _make_stream_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> AsyncIterator[dict[str, Any]]: + """Make HTTP streaming request and return async iterator of events.""" + request_body["stream"] = True + + headers = { + config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + return self.http_client.stream_post( + f"{config.BASE_URL}{config.ENDPOINT}", + headers=headers, + json_body=request_body, + ) + + +__all__ = ["ByteDanceImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/config.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/config.py new file mode 100644 index 0000000..7554ddd --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/config.py @@ -0,0 +1,10 @@ +"""ByteDance provider configuration.""" + +# HTTP Configuration +BASE_URL = "https://ark.ap-southeast.bytepluses.com/api/v3" +ENDPOINT = "/images/generations" +STREAM_ENDPOINT = "/responses" + +# Authentication +AUTH_HEADER_NAME = "Authorization" +AUTH_HEADER_PREFIX = "Bearer " diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/models.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/models.py new file mode 100644 index 0000000..5249525 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/models.py @@ -0,0 +1,34 @@ +"""ByteDance models.""" + +from celeste import Model, Provider +from celeste.constraints import Bool, Choice +from celeste_image_generation.constraints import Dimensions +from celeste_image_generation.parameters import ImageGenerationParameter + +MODELS: list[Model] = [ + Model( + id="seedream-4-0-250828", + provider=Provider.BYTEDANCE, + display_name="Seedream 4.0", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Dimensions( + min_pixels=1280 * 720, # 921,600 + max_pixels=4096 * 4096, # 16,777,216 + min_aspect_ratio=1 / 16, # 0.0625 + max_aspect_ratio=16, + presets={ + "Square 2K": "2048x2048", + "Square 4K": "4096x4096", + "HD 16:9": "1920x1080", + "2K 16:9": "2560x1440", + "4K 16:9": "3840x2160", + "Portrait HD": "1080x1920", + "Portrait 2K": "1440x2560", + "Ultra-wide 21:9": "3024x1296", + }, + ), + ImageGenerationParameter.QUALITY: Choice(options=["1K", "2K", "4K"]), + ImageGenerationParameter.WATERMARK: Bool(), + }, + ), +] diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/parameters.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/parameters.py new file mode 100644 index 0000000..f6e145f --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/parameters.py @@ -0,0 +1,101 @@ +"""ByteDance parameter mappers.""" + +from typing import Any + +from celeste import Model +from celeste.parameters import ParameterMapper +from celeste_image_generation.parameters import ImageGenerationParameter + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio to dimension string. + + Accepts freeform dimension strings (e.g., "2048x2048", "3840x2160") + validated by Dimensions constraint against ByteDance's pixel and aspect ratio bounds. + """ + + name = ImageGenerationParameter.ASPECT_RATIO + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request. + + The Dimensions constraint validates: + - Format: "WIDTHxHEIGHT" + - Total pixels: [921,600, 16,777,216] + - Aspect ratio: [1/16, 16] + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (top-level field) + request["size"] = validated_value + return request + + +class QualityMapper(ParameterMapper): + """Map quality parameter with validation.""" + + name = ImageGenerationParameter.QUALITY + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform quality into provider request. + + Maps quality levels ("1K", "2K", "4K") to ByteDance's size parameter. + Skips if size is already set by aspect_ratio (conflict resolution). + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Skip if size already set by aspect_ratio parameter (conflict resolution) + if "size" in request: + return request + + # Transform to provider-specific request format (top-level field) + request["size"] = validated_value + return request + + +class WatermarkMapper(ParameterMapper): + """Map watermark parameter with validation.""" + + name = ImageGenerationParameter.WATERMARK + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform watermark into provider request. + + Adds "AI generated" watermark to bottom-right corner when true. + Default is true if omitted. + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (top-level field) + request["watermark"] = validated_value + return request + + +BYTEDANCE_PARAMETER_MAPPERS: list[ParameterMapper] = [ + AspectRatioMapper(), + QualityMapper(), + WatermarkMapper(), +] + +__all__ = ["BYTEDANCE_PARAMETER_MAPPERS"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py new file mode 100644 index 0000000..215ba82 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py @@ -0,0 +1,67 @@ +"""ByteDance streaming for image generation.""" + +import logging +from collections.abc import AsyncIterator +from typing import Any + +from celeste.artifacts import ImageArtifact +from celeste.io import Chunk +from celeste.mime_types import ImageMimeType +from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage +from celeste_image_generation.streaming import ImageGenerationStream + +logger = logging.getLogger(__name__) + + +class ByteDanceImageGenerationStream(ImageGenerationStream): + """ByteDance streaming for image generation.""" + + def __init__(self, sse_iterator: AsyncIterator[dict[str, Any]]) -> None: + """Initialize stream and track completed event usage.""" + super().__init__(sse_iterator) + self._completed_usage: ImageGenerationUsage | None = None + + def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None: + """Parse chunk from SSE event.""" + event_type = chunk_data.get("type") + + if event_type == "image_generation.partial_succeeded": + url = chunk_data.get("url") + if not url: + logger.warning("partial_succeeded event missing URL") + return None + + artifact = ImageArtifact(url=url, mime_type=ImageMimeType.PNG) + return ImageGenerationChunk(content=artifact) + + if event_type == "image_generation.completed": + usage_data = chunk_data.get("usage") + if usage_data: + self._completed_usage = ImageGenerationUsage( + total_tokens=usage_data.get("total_tokens"), + input_tokens=None, + output_tokens=None, + ) + return None + + if event_type == "image_generation.partial_failed": + error = chunk_data.get("error", {}) + logger.error( + "Image generation failed: %s - %s", + error.get("code"), + error.get("message"), + ) + return None + + logger.warning("Unknown event type: %s", event_type) + return None + + def _parse_usage(self, chunks: list[ImageGenerationChunk]) -> ImageGenerationUsage: + """Parse usage from chunks.""" + if self._completed_usage is not None: + return self._completed_usage + + return ImageGenerationUsage() + + +__all__ = ["ByteDanceImageGenerationStream"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/__init__.py b/packages/image-generation/src/celeste_image_generation/providers/google/__init__.py new file mode 100644 index 0000000..39c5b9b --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/__init__.py @@ -0,0 +1,6 @@ +"""Google provider.""" + +from .client import GoogleImageGenerationClient +from .models import MODELS + +__all__ = ["MODELS", "GoogleImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/client.py b/packages/image-generation/src/celeste_image_generation/providers/google/client.py new file mode 100644 index 0000000..7525f74 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/client.py @@ -0,0 +1,152 @@ +"""Google client implementation.""" + +import base64 +from typing import Any, Unpack + +import httpx +from pydantic import ConfigDict + +from celeste.artifacts import ImageArtifact +from celeste.mime_types import ImageMimeType +from celeste.parameters import ParameterMapper +from celeste_image_generation.client import ImageGenerationClient +from celeste_image_generation.io import ( + ImageGenerationFinishReason, + ImageGenerationInput, + ImageGenerationUsage, +) +from celeste_image_generation.parameters import ImageGenerationParameters + +from . import config +from .parameters import GOOGLE_PARAMETER_MAPPERS + + +class GoogleImageGenerationClient(ImageGenerationClient): + """Google client for image generation. + + Supports both Imagen API and Gemini multimodal API via adapter pattern. + Adapter selection happens automatically based on model type. + """ + + model_config = ConfigDict(extra="allow") + + def model_post_init(self, __context: Any) -> None: # noqa: ANN401 + """Initialize API adapter based on model type.""" + super().model_post_init(__context) + + adapter_class, _ = _get_adapter_for_model(self.model.id) + self.api = adapter_class() + self.endpoint = self.api.endpoint(self.model.id) + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper]: + """Return parameter mappers for Google provider.""" + return GOOGLE_PARAMETER_MAPPERS + + def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]: + """Initialize request using API adapter.""" + return self.api.build_request(inputs.prompt, {}) + + def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage from response using API adapter.""" + return self.api.parse_usage(response_data) + + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> ImageArtifact: + """Parse content from response using API adapter.""" + prediction = self.api.parse_response(response_data) + + if prediction is None: + return ImageArtifact() + + base64_data = prediction.get("bytesBase64Encoded") or prediction["data"] + mime_type = ImageMimeType(prediction.get("mimeType", "image/png")) + image_bytes = base64.b64decode(base64_data) + + return ImageArtifact(data=image_bytes, mime_type=mime_type) + + def _parse_finish_reason( + self, response_data: dict[str, Any] + ) -> ImageGenerationFinishReason | None: + """Parse finish reason from provider response. + + For Gemini models, extracts finishReason from candidates[0]. + For Imagen models, returns None (not provided). + """ + # Check if this is a Gemini response (has "candidates") + candidates = response_data.get("candidates") + if candidates: + candidate = candidates[0] + finish_reason_str = candidate.get("finishReason") + if finish_reason_str: + finish_message = candidate.get("finishMessage") + return ImageGenerationFinishReason( + reason=finish_reason_str, + message=finish_message, + ) + # Imagen models don't provide finish reasons + return None + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary from response data.""" + # Parse finish_reason from full response_data before filtering (needs "candidates") + finish_reason = self._parse_finish_reason(response_data) + + # Filter content fields before calling super (Imagen uses "predictions", Gemini uses "candidates") + content_fields = {"predictions", "candidates"} + filtered_data = { + k: v for k, v in response_data.items() if k not in content_fields + } + metadata = super()._build_metadata(filtered_data) + # Override with pre-parsed finish_reason (base class parsed from filtered_data which returns None) + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + return metadata + + async def _make_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + headers = { + config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + "Content-Type": "application/json", + } + + return await self.http_client.post( + f"{config.BASE_URL}{self.endpoint}", + headers=headers, + json_body=request_body, + ) + + +def _get_adapter_for_model(model_id: str) -> tuple[type, str]: + """Get adapter class and endpoint for model ID. + + Returns: + Tuple of (adapter_class, endpoint_template). + """ + from .models import GEMINI_MODELS, IMAGEN_MODELS + + # Create sets for O(1) lookup (computed once per import) + imagen_model_ids = {model.id for model in IMAGEN_MODELS} + gemini_model_ids = {model.id for model in GEMINI_MODELS} + + if model_id in imagen_model_ids: + from .imagen_api import ImagenAPIAdapter + + return ImagenAPIAdapter, config.IMAGEN_ENDPOINT + if model_id in gemini_model_ids: + from .gemini_api import GeminiImageAPIAdapter + + return GeminiImageAPIAdapter, config.GEMINI_ENDPOINT + + msg = f"Unknown Google image generation model: {model_id}" + raise ValueError(msg) + + +__all__ = ["GoogleImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/config.py b/packages/image-generation/src/celeste_image_generation/providers/google/config.py new file mode 100644 index 0000000..83901df --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/config.py @@ -0,0 +1,10 @@ +"""Google provider configuration.""" + +# HTTP Configuration +BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models" +IMAGEN_ENDPOINT = "/{model_id}:predict" +GEMINI_ENDPOINT = "/{model_id}:generateContent" + +# Authentication +AUTH_HEADER_NAME = "x-goog-api-key" +AUTH_HEADER_PREFIX = "" # Direct API key, no prefix diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/gemini_api.py b/packages/image-generation/src/celeste_image_generation/providers/google/gemini_api.py new file mode 100644 index 0000000..3011578 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/gemini_api.py @@ -0,0 +1,82 @@ +"""Gemini API adapter for Google image generation. + +Pure data transformer for Gemini multimodal models (gemini-2.5-flash-image). +Handles request/response structure transformation only. +""" + +from typing import Any + +from celeste_image_generation.io import ImageGenerationUsage + +from . import config + + +class GeminiImageAPIAdapter: + """Adapter for Gemini multimodal API request/response transformation. + + Request format: contents[].parts[] + generationConfig.responseModalities + imageConfig + Response format: candidates[].content.parts[].inlineData (camelCase in REST API) + """ + + def build_request(self, prompt: str, parameters: dict[str, Any]) -> dict[str, Any]: + """Build Gemini API request structure. + + Args: + prompt: Text prompt for image generation. + parameters: Parameter dictionary (aspectRatio, etc.). + + Returns: + Gemini-formatted request with contents[] and generationConfig. + """ + return { + "contents": [{"parts": [{"text": prompt}]}], + "generationConfig": { + "responseModalities": ["Image"], + "imageConfig": parameters, + }, + } + + def parse_response(self, response_data: dict[str, Any]) -> dict[str, Any] | None: + """Parse Gemini API response structure. + + Args: + response_data: Raw API response. + + Returns: + First part containing inlineData with base64 image, or None if blocked. + """ + candidates = response_data.get("candidates", []) + if not candidates: + return None + + candidate = candidates[0] + if candidate.get("finishReason") != "STOP": + return None + return candidate["content"]["parts"][0]["inlineData"] + + def parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage from Gemini API response. + + Args: + response_data: Raw API response. + + Returns: + ImageGenerationUsage with token counts and generated_images count. + """ + usage_metadata = response_data.get("usageMetadata", {}) + candidates = response_data.get("candidates", []) + + return ImageGenerationUsage( + input_tokens=usage_metadata.get("promptTokenCount"), + output_tokens=usage_metadata.get("candidatesTokenCount"), + total_tokens=usage_metadata.get("totalTokenCount"), + generated_images=len(candidates), + ) + + @staticmethod + def endpoint(model_id: str) -> str: + """Get endpoint for model.""" + return config.GEMINI_ENDPOINT.format(model_id=model_id) + + +__all__ = ["GeminiImageAPIAdapter"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/imagen_api.py b/packages/image-generation/src/celeste_image_generation/providers/google/imagen_api.py new file mode 100644 index 0000000..3db7b85 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/imagen_api.py @@ -0,0 +1,67 @@ +"""Imagen API adapter for Google image generation. + +Pure data transformer for Imagen models (imagen-3.x, imagen-4.x). +Handles request/response structure transformation only. +""" + +from typing import Any + +from celeste_image_generation.io import ImageGenerationUsage + +from . import config + + +class ImagenAPIAdapter: + """Adapter for Imagen API request/response transformation. + + Request format: instances[].prompt + parameters + Response format: predictions[].bytesBase64Encoded + """ + + def build_request(self, prompt: str, parameters: dict[str, Any]) -> dict[str, Any]: + """Build Imagen API request structure. + + Args: + prompt: Text prompt for image generation. + parameters: Parameter dictionary (aspectRatio, imageSize, etc.). + + Returns: + Imagen-formatted request with instances[] and parameters. + """ + return { + "instances": [{"prompt": prompt}], + "parameters": parameters, + } + + def parse_response(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Parse Imagen API response structure. + + Args: + response_data: Raw API response. + + Returns: + First prediction containing bytesBase64Encoded and mimeType. + """ + return response_data["predictions"][0] + + def parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage from Imagen API response. + + Args: + response_data: Raw API response. + + Returns: + ImageGenerationUsage with generated_images count from predictions array. + """ + predictions = response_data.get("predictions", []) + return ImageGenerationUsage( + generated_images=len(predictions), + ) + + @staticmethod + def endpoint(model_id: str) -> str: + """Get endpoint for model.""" + return config.IMAGEN_ENDPOINT.format(model_id=model_id) + + +__all__ = ["ImagenAPIAdapter"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/models.py b/packages/image-generation/src/celeste_image_generation/providers/google/models.py new file mode 100644 index 0000000..1730529 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/models.py @@ -0,0 +1,86 @@ +"""Google models.""" + +from celeste import Model, Provider +from celeste.constraints import Choice +from celeste_image_generation.parameters import ImageGenerationParameter + +# Imagen API models (instances[].prompt → predictions[]) +IMAGEN_MODELS: list[Model] = [ + # Imagen 4 models (text-to-image) - Current GA + Model( + id="imagen-4.0-generate-001", + provider=Provider.GOOGLE, + display_name="Imagen 4", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1:1", "3:4", "4:3", "9:16", "16:9"] + ), + ImageGenerationParameter.QUALITY: Choice(options=["1K", "2K"]), + }, + ), + Model( + id="imagen-4.0-fast-generate-001", + provider=Provider.GOOGLE, + display_name="Imagen 4 Fast", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1:1", "3:4", "4:3", "9:16", "16:9"] + ), + ImageGenerationParameter.QUALITY: Choice(options=["1K"]), + }, + ), + Model( + id="imagen-4.0-ultra-generate-001", + provider=Provider.GOOGLE, + display_name="Imagen 4 Ultra", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1:1", "3:4", "4:3", "9:16", "16:9"] + ), + ImageGenerationParameter.QUALITY: Choice(options=["1K", "2K"]), + }, + ), + # Imagen 3 models (deprecated June 24, 2025) - Support for backwards compatibility + Model( + id="imagen-3.0-generate-002", + provider=Provider.GOOGLE, + display_name="Imagen 3", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1:1", "3:4", "4:3", "9:16", "16:9"] + ), + ImageGenerationParameter.QUALITY: Choice(options=["1K"]), + }, + ), +] + +# Gemini API models (contents[].parts[] → candidates[]) +GEMINI_MODELS: list[Model] = [ + Model( + id="gemini-2.5-flash-image", + provider=Provider.GOOGLE, + display_name="Gemini 2.5 Flash Image", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + ] + ), + }, + ), +] + +# Unified model list for registration +MODELS: list[Model] = [ + *IMAGEN_MODELS, + *GEMINI_MODELS, +] diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/parameters.py b/packages/image-generation/src/celeste_image_generation/providers/google/parameters.py new file mode 100644 index 0000000..a773085 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/google/parameters.py @@ -0,0 +1,67 @@ +"""Google parameter mappers.""" + +from typing import Any + +from celeste import Model +from celeste.parameters import ParameterMapper +from celeste_image_generation.parameters import ImageGenerationParameter + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio parameter with validation.""" + + name = ImageGenerationParameter.ASPECT_RATIO + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + if "generationConfig" in request: + request.setdefault("generationConfig", {}).setdefault("imageConfig", {})[ + "aspectRatio" + ] = validated_value + else: + request.setdefault("parameters", {})["aspectRatio"] = validated_value + + return request + + +class QualityMapper(ParameterMapper): + """Map quality parameter to imageSize.""" + + name = ImageGenerationParameter.QUALITY + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform quality into provider imageSize request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + if "generationConfig" in request: + request.setdefault("generationConfig", {}).setdefault("imageConfig", {})[ + "imageSize" + ] = validated_value + else: + request.setdefault("parameters", {})["imageSize"] = validated_value + + return request + + +GOOGLE_PARAMETER_MAPPERS: list[ParameterMapper] = [ + AspectRatioMapper(), + QualityMapper(), +] + +__all__ = ["GOOGLE_PARAMETER_MAPPERS"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/__init__.py b/packages/image-generation/src/celeste_image_generation/providers/openai/__init__.py new file mode 100644 index 0000000..29b11cb --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/__init__.py @@ -0,0 +1,7 @@ +"""OpenAI provider.""" + +from .client import OpenAIImageGenerationClient +from .models import MODELS +from .streaming import OpenAIImageGenerationStream + +__all__ = ["MODELS", "OpenAIImageGenerationClient", "OpenAIImageGenerationStream"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/client.py b/packages/image-generation/src/celeste_image_generation/providers/openai/client.py new file mode 100644 index 0000000..3248fa2 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/client.py @@ -0,0 +1,142 @@ +"""OpenAI client implementation.""" + +import base64 +from collections.abc import AsyncIterator +from typing import Any, Unpack + +import httpx + +from celeste.artifacts import ImageArtifact +from celeste.parameters import ParameterMapper +from celeste_image_generation.client import ImageGenerationClient +from celeste_image_generation.io import ( + ImageGenerationFinishReason, + ImageGenerationInput, + ImageGenerationUsage, +) +from celeste_image_generation.parameters import ImageGenerationParameters + +from . import config +from .parameters import OPENAI_PARAMETER_MAPPERS +from .streaming import OpenAIImageGenerationStream + + +class OpenAIImageGenerationClient(ImageGenerationClient): + """OpenAI client.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper]: + return OPENAI_PARAMETER_MAPPERS + + def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]: + """Initialize request from inputs.""" + request = { + "model": self.model.id, + "prompt": inputs.prompt, + "n": 1, + } + + if self.model.id in ("dall-e-2", "dall-e-3"): + request["response_format"] = "b64_json" + + return request + + def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage: + """Parse usage from response.""" + return ImageGenerationUsage() + + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> ImageArtifact: + """Parse content from response.""" + data = response_data.get("data", []) + if not data: + msg = "No image data in response" + raise ValueError(msg) + + image_data = data[0] + + b64_json = image_data.get("b64_json") + if b64_json: + image_bytes = base64.b64decode(b64_json) + return ImageArtifact(data=image_bytes) + + url = image_data.get("url") + if url: + return ImageArtifact(url=url) + + msg = "No image URL or base64 data in response" + raise ValueError(msg) + + def _parse_finish_reason( + self, response_data: dict[str, Any] + ) -> ImageGenerationFinishReason | None: + """Parse finish reason from response.""" + return None + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary from response data.""" + # Filter content field before calling super + content_fields = {"data"} + filtered_data = { + k: v for k, v in response_data.items() if k not in content_fields + } + metadata = super()._build_metadata(filtered_data) + # Add provider-specific parsed fields + if response_data.get("data") and response_data["data"]: + revised_prompt = response_data["data"][0].get("revised_prompt") + if revised_prompt: + metadata["revised_prompt"] = revised_prompt + return metadata + + async def _make_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + headers = { + config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + return await self.http_client.post( + f"{config.BASE_URL}{config.ENDPOINT}", + headers=headers, + json_body=request_body, + ) + + def _stream_class(self) -> type[OpenAIImageGenerationStream]: + """Return the Stream class for this client.""" + return OpenAIImageGenerationStream + + def _make_stream_request( + self, + request_body: dict[str, Any], + **parameters: Unpack[ImageGenerationParameters], + ) -> AsyncIterator[dict[str, Any]]: + """Make HTTP streaming request and return async iterator of events.""" + if self.model.id != "gpt-image-1": + msg = f"Streaming not supported for model '{self.model.id}'. Only 'gpt-image-1' supports streaming." + raise ValueError(msg) + + request_body["stream"] = True + + if "partial_images" not in request_body: + request_body["partial_images"] = 1 + + headers = { + config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + return self.http_client.stream_post( + f"{config.BASE_URL}{config.STREAM_ENDPOINT}", + headers=headers, + json_body=request_body, + ) + + +__all__ = ["OpenAIImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/config.py b/packages/image-generation/src/celeste_image_generation/providers/openai/config.py new file mode 100644 index 0000000..0af0f9e --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/config.py @@ -0,0 +1,10 @@ +"""OpenAI provider configuration.""" + +# HTTP Configuration +BASE_URL = "https://api.openai.com" +ENDPOINT = "/v1/images/generations" +STREAM_ENDPOINT = ENDPOINT # Same endpoint, streaming enabled via request parameter + +# Authentication +AUTH_HEADER_NAME = "Authorization" +AUTH_HEADER_PREFIX = "Bearer " diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/models.py b/packages/image-generation/src/celeste_image_generation/providers/openai/models.py new file mode 100644 index 0000000..90e9ddd --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/models.py @@ -0,0 +1,44 @@ +"""OpenAI models.""" + +from celeste import Model, Provider +from celeste.constraints import Choice, Range +from celeste_image_generation.parameters import ImageGenerationParameter + +MODELS: list[Model] = [ + Model( + id="dall-e-2", + provider=Provider.OPENAI, + display_name="DALL-E 2", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["256x256", "512x512", "1024x1024"] + ), + }, + ), + Model( + id="dall-e-3", + provider=Provider.OPENAI, + display_name="DALL-E 3", + parameter_constraints={ + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1024x1024", "1792x1024", "1024x1792"] + ), + ImageGenerationParameter.QUALITY: Choice(options=["standard", "hd"]), + }, + ), + Model( + id="gpt-image-1", + provider=Provider.OPENAI, + display_name="GPT Image 1", + streaming=True, + parameter_constraints={ + ImageGenerationParameter.PARTIAL_IMAGES: Range(min=0, max=3), + ImageGenerationParameter.ASPECT_RATIO: Choice( + options=["1024x1024", "1536x1024", "1024x1536", "auto"] + ), + ImageGenerationParameter.QUALITY: Choice( + options=["low", "medium", "high", "auto"] + ), + }, + ), +] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/parameters.py b/packages/image-generation/src/celeste_image_generation/providers/openai/parameters.py new file mode 100644 index 0000000..1db7202 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/parameters.py @@ -0,0 +1,118 @@ +"""OpenAI parameter mappers.""" + +from typing import Any + +from celeste import Model +from celeste.parameters import ParameterMapper +from celeste_image_generation.parameters import ImageGenerationParameter + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio parameter to OpenAI's size parameter.""" + + name = ImageGenerationParameter.ASPECT_RATIO + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request. + + Maps unified aspect_ratio parameter to OpenAI's size format. + Values are OpenAI's native size strings (e.g., "1024x1024", "1792x1024"). + Coercion from ratio format ("16:9") to size format can be added later. + + Args: + request: Provider request dictionary to modify. + value: The aspect_ratio value (OpenAI size string). + model: Model instance with parameter constraints. + + Returns: + Modified request dictionary with size parameter. + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (size parameter) + request["size"] = validated_value + return request + + +class PartialImagesMapper(ParameterMapper): + """Map partial_images parameter for streaming (gpt-image-1 only).""" + + name = ImageGenerationParameter.PARTIAL_IMAGES + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform partial_images into provider request. + + Controls number of partial images during streaming (0-3). + Only supported by gpt-image-1 model. + + Args: + request: Provider request dictionary to modify. + value: The partial_images value (0-3). + model: Model instance with parameter constraints. + + Returns: + Modified request dictionary with partial_images parameter. + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (top-level field) + request["partial_images"] = validated_value + return request + + +class QualityMapper(ParameterMapper): + """Map quality parameter for DALL-E 3 and gpt-image-1.""" + + name = ImageGenerationParameter.QUALITY + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform quality into provider request. + + Controls image quality/detail level. + - DALL-E 3: "standard" or "hd" + - gpt-image-1: "low", "medium", "high", or "auto" + - DALL-E 2: Not supported (no constraint in model) + + Args: + request: Provider request dictionary to modify. + value: The quality value. + model: Model instance with parameter constraints. + + Returns: + Modified request dictionary with quality parameter. + """ + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (top-level field) + request["quality"] = validated_value + return request + + +OPENAI_PARAMETER_MAPPERS: list[ParameterMapper] = [ + AspectRatioMapper(), + PartialImagesMapper(), + QualityMapper(), +] + +__all__ = ["OPENAI_PARAMETER_MAPPERS"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py b/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py new file mode 100644 index 0000000..182ab15 --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py @@ -0,0 +1,76 @@ +"""OpenAI streaming for image generation.""" + +import base64 +import logging +from typing import Any + +from celeste.artifacts import ImageArtifact +from celeste.io import Chunk +from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage +from celeste_image_generation.streaming import ImageGenerationStream + +logger = logging.getLogger(__name__) + + +class OpenAIImageGenerationStream(ImageGenerationStream): + """OpenAI streaming for image generation.""" + + def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None: + """Parse chunk from SSE event. + + OpenAI returns two event types: + - image_generation.partial_image: Progressive image chunks + - image_generation.completed: Final image with usage data + """ + event_type = chunk_data.get("type") + + if event_type == "image_generation.partial_image": + # Partial image chunk + b64_json = chunk_data.get("b64_json") + if not b64_json: + return None + + image_data = base64.b64decode(b64_json) + artifact = ImageArtifact(data=image_data) + + return ImageGenerationChunk(content=artifact) + + if event_type == "image_generation.completed": + # Final image with usage + b64_json = chunk_data.get("b64_json") + if not b64_json: + return None + + image_data = base64.b64decode(b64_json) + artifact = ImageArtifact(data=image_data) + + # Parse usage from completed event + usage_data = chunk_data.get("usage") + usage = None + if usage_data: + usage = ImageGenerationUsage( + total_tokens=usage_data.get("total_tokens"), + input_tokens=usage_data.get("input_tokens"), + output_tokens=usage_data.get("output_tokens"), + ) + + return ImageGenerationChunk(content=artifact, usage=usage) + + logger.warning("Unknown event type: %s", event_type) + return None + + def _parse_usage(self, chunks: list[ImageGenerationChunk]) -> ImageGenerationUsage: + """Parse usage from chunks. + + Usage is only available in the final completed event. + """ + # Look for usage in final chunk (completed event) + for chunk in reversed(chunks): + if chunk.usage is not None: + return chunk.usage + + # No usage found + return ImageGenerationUsage() + + +__all__ = ["OpenAIImageGenerationStream"] diff --git a/packages/image-generation/src/celeste_image_generation/py.typed b/packages/image-generation/src/celeste_image_generation/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/packages/image-generation/src/celeste_image_generation/streaming.py b/packages/image-generation/src/celeste_image_generation/streaming.py new file mode 100644 index 0000000..5748b3b --- /dev/null +++ b/packages/image-generation/src/celeste_image_generation/streaming.py @@ -0,0 +1,48 @@ +"""Streaming for image generation.""" + +from abc import abstractmethod +from typing import Unpack + +from celeste.streaming import Stream +from celeste_image_generation.io import ( + ImageGenerationChunk, + ImageGenerationOutput, + ImageGenerationUsage, +) +from celeste_image_generation.parameters import ImageGenerationParameters + + +class ImageGenerationStream(Stream[ImageGenerationOutput, ImageGenerationParameters]): + """Streaming for image generation.""" + + def _parse_output( + self, + chunks: list[ImageGenerationChunk], + **parameters: Unpack[ImageGenerationParameters], + ) -> ImageGenerationOutput: + """Assemble chunks into final output. + + For image generation, the final chunk contains the complete image. + Progressive chunks may contain partial/preview images. + """ + if not chunks: + msg = "No chunks received from stream" + raise ValueError(msg) + + # Final chunk contains complete image + content = chunks[-1].content + usage = self._parse_usage(chunks) + finish_reason = chunks[-1].finish_reason if chunks else None + + return ImageGenerationOutput( + content=content, + usage=usage, + metadata={"finish_reason": finish_reason}, + ) + + @abstractmethod + def _parse_usage(self, chunks: list[ImageGenerationChunk]) -> ImageGenerationUsage: + """Parse usage from chunks (provider-specific).""" + + +__all__ = ["ImageGenerationStream"] diff --git a/packages/image-generation/tests/integration_tests/test_image_generation/__init__.py b/packages/image-generation/tests/integration_tests/test_image_generation/__init__.py new file mode 100644 index 0000000..6b6119e --- /dev/null +++ b/packages/image-generation/tests/integration_tests/test_image_generation/__init__.py @@ -0,0 +1 @@ +"""Integration tests for image generation capability.""" diff --git a/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py b/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py new file mode 100644 index 0000000..73dd915 --- /dev/null +++ b/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py @@ -0,0 +1,61 @@ +"""Integration tests for image generation across all providers.""" + +import pytest + +from celeste import Capability, Provider, create_client + + +@pytest.mark.parametrize( + ("provider", "model", "parameters"), + [ + (Provider.OPENAI, "dall-e-2", {}), + (Provider.GOOGLE, "imagen-4.0-fast-generate-001", {}), + (Provider.BYTEDANCE, "seedream-4-0-250828", {}), + ], +) +@pytest.mark.integration +@pytest.mark.asyncio +async def test_generate(provider: Provider, model: str, parameters: dict) -> None: + """Test image generation with prompt parameter across all providers. + + This test demonstrates that the unified API works identically across + all providers using the same code - proving the abstraction value. + Uses cheapest models to minimize costs. + """ + # Import here to avoid circular import during pytest collection + from celeste_image_generation import ( + ImageGenerationOutput, + ImageGenerationUsage, + ) + + from celeste.artifacts import ImageArtifact + + # Arrange + client = create_client( + capability=Capability.IMAGE_GENERATION, + provider=provider, + ) + prompt = "A red apple on a white background" + + # Act + response = await client.generate( + prompt=prompt, + model=model, + **parameters, + ) + + # Assert + assert isinstance(response, ImageGenerationOutput), ( + f"Expected ImageGenerationOutput, got {type(response)}" + ) + assert isinstance(response.content, ImageArtifact), ( + f"Expected ImageArtifact content, got {type(response.content)}" + ) + assert response.content.has_content, ( + f"ImageArtifact has no content (url/data/path): {response.content}" + ) + + # Validate usage metrics + assert isinstance(response.usage, ImageGenerationUsage), ( + f"Expected ImageGenerationUsage, got {type(response.usage)}" + ) diff --git a/packages/image-generation/tests/unit_tests/__init__.py b/packages/image-generation/tests/unit_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/image-generation/tests/unit_tests/providers/google/__init__.py b/packages/image-generation/tests/unit_tests/providers/google/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/image-generation/tests/unit_tests/providers/google/test_finish_reason.py b/packages/image-generation/tests/unit_tests/providers/google/test_finish_reason.py new file mode 100644 index 0000000..466f608 --- /dev/null +++ b/packages/image-generation/tests/unit_tests/providers/google/test_finish_reason.py @@ -0,0 +1,165 @@ +"""Unit tests for Google image generation finish reason parsing.""" + +from typing import Any + +import pytest +from celeste_image_generation.providers.google.client import GoogleImageGenerationClient +from pydantic import SecretStr + +from celeste.core import Capability, Provider +from celeste.models import Model + + +class TestParseFinishReason: + """Test _parse_finish_reason method for Google image generation client.""" + + @pytest.fixture + def client(self) -> GoogleImageGenerationClient: + """Create a Google image generation client for testing.""" + return GoogleImageGenerationClient( + model=Model( + id="gemini-2.5-flash-image", + provider=Provider.GOOGLE, + display_name="Gemini 2.5 Flash Image", + capabilities={Capability.IMAGE_GENERATION}, + ), + provider=Provider.GOOGLE, + capability=Capability.IMAGE_GENERATION, + api_key=SecretStr("test-key"), + ) + + @pytest.mark.parametrize( + ("finish_reason", "finish_message", "expected_reason", "expected_message"), + [ + ("STOP", None, "STOP", None), + ( + "PROHIBITED_CONTENT", + "Content blocked due to policy violation", + "PROHIBITED_CONTENT", + "Content blocked due to policy violation", + ), + ("PROHIBITED_CONTENT", None, "PROHIBITED_CONTENT", None), + ("NO_IMAGE", "Prompt too vague", "NO_IMAGE", "Prompt too vague"), + ( + "SAFETY", + "Safety filters detected inappropriate content", + "SAFETY", + "Safety filters detected inappropriate content", + ), + ], + ids=[ + "stop_without_message", + "prohibited_content_with_message", + "prohibited_content_without_message", + "no_image_with_message", + "safety_with_message", + ], + ) + def test_parse_finish_reason_with_valid_reason( + self, + client: GoogleImageGenerationClient, + finish_reason: str, + finish_message: str | None, + expected_reason: str, + expected_message: str | None, + ) -> None: + """Test parsing finish reason with valid finishReason values.""" + # Arrange + candidate: dict[str, Any] = {"finishReason": finish_reason} + if finish_message is not None: + candidate["finishMessage"] = finish_message + + response_data: dict[str, Any] = { + "candidates": [candidate], + "usageMetadata": {}, + } + + # Act + result = client._parse_finish_reason(response_data) + + # Assert + assert result is not None + assert result.reason == expected_reason + assert result.message == expected_message + + @pytest.mark.parametrize( + "response_data", + [ + {"candidates": [], "usageMetadata": {}}, # Empty candidates + {"predictions": [], "usageMetadata": {}}, # No candidates key (Imagen) + { + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "mimeType": "image/png", + "data": "base64data", + } + } + ] + } + } + ], + "usageMetadata": {}, + }, # Candidate without finishReason + ], + ids=[ + "empty_candidates", + "no_candidates_key", + "candidate_without_finish_reason", + ], + ) + def test_parse_finish_reason_returns_none_for_invalid_input( + self, + client: GoogleImageGenerationClient, + response_data: dict[str, Any], + ) -> None: + """Test parsing finish reason returns None for invalid/missing input.""" + # Act + result = client._parse_finish_reason(response_data) + + # Assert + assert result is None + + def test_parse_finish_reason_empty_string_finish_reason( + self, client: GoogleImageGenerationClient + ) -> None: + """Test parsing finish reason when finishReason is empty string.""" + # Arrange + response_data: dict[str, Any] = { + "candidates": [{"finishReason": ""}], + "usageMetadata": {}, + } + + # Act + result = client._parse_finish_reason(response_data) + + # Assert + # Empty string is falsy, so should return None + assert result is None + + def test_parse_finish_reason_empty_string_message( + self, client: GoogleImageGenerationClient + ) -> None: + """Test parsing finish reason when finishMessage is empty string.""" + # Arrange + response_data: dict[str, Any] = { + "candidates": [ + { + "finishReason": "STOP", + "finishMessage": "", # Empty string vs None + } + ], + "usageMetadata": {}, + } + + # Act + result = client._parse_finish_reason(response_data) + + # Assert + assert result is not None + assert result.reason == "STOP" + # Empty string is preserved (candidate.get("finishMessage") returns "") + assert result.message == "" diff --git a/pyproject.toml b/pyproject.toml index 31def19..c413378 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-ai" -version = "0.2.0" +version = "0.2.1" description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" @@ -33,8 +33,9 @@ Repository = "https://github.com/withceleste/celeste-python" Issues = "https://github.com/withceleste/celeste-python/issues" [project.optional-dependencies] -text-generation = ["celeste-text-generation>=0.1.0"] -all = ["celeste-text-generation>=0.1.0"] +text-generation = ["celeste-text-generation>=0.2.1"] +image-generation = ["celeste-image-generation>=0.2.1"] +all = ["celeste-text-generation>=0.2.1", "celeste-image-generation>=0.2.1"] [dependency-groups] dev = [ @@ -55,6 +56,7 @@ members = ["packages/*"] [tool.uv.sources] celeste-text-generation = { workspace = true } +celeste-image-generation = { workspace = true } [build-system] requires = ["hatchling"] @@ -68,6 +70,11 @@ minversion = "8.0" testpaths = ["tests"] addopts = "-ra --strict-markers --strict-config" asyncio_mode = "auto" +pythonpath = [ + "src", + "packages/text-generation/src", + "packages/image-generation/src", +] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "smoke: quick checks for critical paths", @@ -156,6 +163,15 @@ module = [ ] disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"] +[[tool.mypy.overrides]] +module = [ + "celeste_image_generation.*", + "celeste_image_generation.client", + "celeste_image_generation.streaming", + "celeste_image_generation.providers.*", +] +disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"] + [tool.bandit] exclude_dirs = [".venv", "__pycache__"] skips = ["B101"] # Skip B101 (assert_used) since we use pytest From 5bc67c783b48b318bad56aaa591f60767b34a96d Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 10 Nov 2025 18:58:19 +0100 Subject: [PATCH 3/5] chore: bump versions to 0.2.1 and optimize text-generation imports --- packages/text-generation/pyproject.toml | 2 +- .../providers/__init__.py | 30 +++++++++++-------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/packages/text-generation/pyproject.toml b/packages/text-generation/pyproject.toml index 35c0c96..873fbf4 100644 --- a/packages/text-generation/pyproject.toml +++ b/packages/text-generation/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-text-generation" -version = "0.2.0" +version = "0.2.1" description = "Type-safe text generation for Celeste AI. Unified interface for OpenAI, Anthropic, Google, Mistral, Cohere, and more" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" diff --git a/packages/text-generation/src/celeste_text_generation/providers/__init__.py b/packages/text-generation/src/celeste_text_generation/providers/__init__.py index d884fc4..96600e1 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/__init__.py +++ b/packages/text-generation/src/celeste_text_generation/providers/__init__.py @@ -1,24 +1,28 @@ """Provider implementations for text generation.""" -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from celeste.client import Client - from celeste.core import Provider +from celeste import Client, Provider __all__ = ["PROVIDERS"] -def _get_providers() -> list[tuple["Provider", type["Client"]]]: +def _get_providers() -> list[tuple[Provider, type[Client]]]: """Lazy-load providers.""" - from celeste.core import Provider - from celeste_text_generation.providers.anthropic import ( + # Import clients directly from .client modules to avoid __init__.py imports + from celeste_text_generation.providers.anthropic.client import ( AnthropicTextGenerationClient, ) - from celeste_text_generation.providers.cohere import CohereTextGenerationClient - from celeste_text_generation.providers.google import GoogleTextGenerationClient - from celeste_text_generation.providers.mistral import MistralTextGenerationClient - from celeste_text_generation.providers.openai import OpenAITextGenerationClient + from celeste_text_generation.providers.cohere.client import ( + CohereTextGenerationClient, + ) + from celeste_text_generation.providers.google.client import ( + GoogleTextGenerationClient, + ) + from celeste_text_generation.providers.mistral.client import ( + MistralTextGenerationClient, + ) + from celeste_text_generation.providers.openai.client import ( + OpenAITextGenerationClient, + ) return [ (Provider.ANTHROPIC, AnthropicTextGenerationClient), @@ -29,4 +33,4 @@ def _get_providers() -> list[tuple["Provider", type["Client"]]]: ] -PROVIDERS: list[tuple["Provider", type["Client"]]] = _get_providers() +PROVIDERS: list[tuple[Provider, type[Client]]] = _get_providers() From c23c5face990cad8d5aa15715f4d84ae7346c971 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 10 Nov 2025 19:08:25 +0100 Subject: [PATCH 4/5] fix: replace ValueError with custom exceptions and remove metadata filtering - Replace ValueError with ValidationError, ModelNotFoundError, and ConstraintViolationError - Use Provider.GOOGLE enum instead of string 'google' - Remove content field filtering from _build_metadata methods to match text-generation pattern --- .../providers/bytedance/client.py | 12 ++++-------- .../providers/google/client.py | 16 ++++++---------- .../providers/openai/client.py | 12 ++++-------- 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py index 822f6a2..b72daa1 100644 --- a/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/client.py @@ -7,6 +7,7 @@ import httpx from celeste.artifacts import ImageArtifact +from celeste.exceptions import ConstraintViolationError, ValidationError from celeste.mime_types import ImageMimeType from celeste.parameters import ParameterMapper from celeste_image_generation.client import ImageGenerationClient @@ -75,7 +76,7 @@ def _parse_content( ) msg = "No image content found in ByteDance response" - raise ValueError(msg) + raise ValidationError(msg) def _parse_finish_reason( self, response_data: dict[str, Any] @@ -91,12 +92,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: Extracts seed if present. """ - # Filter content fields before calling super - content_fields = {"images", "data"} - filtered_data = { - k: v for k, v in response_data.items() if k not in content_fields - } - metadata = super()._build_metadata(filtered_data) + metadata = super()._build_metadata(response_data) # Add provider-specific parsed fields seed = response_data.get("seed") if seed is not None: @@ -118,7 +114,7 @@ async def _make_request( " • aspect_ratio: Exact dimensions (e.g., '2048x2048', '3840x2160')\n" "Use one or the other, not both." ) - raise ValueError(msg) + raise ConstraintViolationError(msg) request_body["stream"] = False diff --git a/packages/image-generation/src/celeste_image_generation/providers/google/client.py b/packages/image-generation/src/celeste_image_generation/providers/google/client.py index 7525f74..214731f 100644 --- a/packages/image-generation/src/celeste_image_generation/providers/google/client.py +++ b/packages/image-generation/src/celeste_image_generation/providers/google/client.py @@ -7,6 +7,8 @@ from pydantic import ConfigDict from celeste.artifacts import ImageArtifact +from celeste.core import Provider +from celeste.exceptions import ModelNotFoundError from celeste.mime_types import ImageMimeType from celeste.parameters import ParameterMapper from celeste_image_generation.client import ImageGenerationClient @@ -92,16 +94,11 @@ def _parse_finish_reason( def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: """Build metadata dictionary from response data.""" - # Parse finish_reason from full response_data before filtering (needs "candidates") + # Parse finish_reason from full response_data before calling super (needs "candidates") finish_reason = self._parse_finish_reason(response_data) - # Filter content fields before calling super (Imagen uses "predictions", Gemini uses "candidates") - content_fields = {"predictions", "candidates"} - filtered_data = { - k: v for k, v in response_data.items() if k not in content_fields - } - metadata = super()._build_metadata(filtered_data) - # Override with pre-parsed finish_reason (base class parsed from filtered_data which returns None) + metadata = super()._build_metadata(response_data) + # Override with pre-parsed finish_reason if finish_reason is not None: metadata["finish_reason"] = finish_reason return metadata @@ -145,8 +142,7 @@ def _get_adapter_for_model(model_id: str) -> tuple[type, str]: return GeminiImageAPIAdapter, config.GEMINI_ENDPOINT - msg = f"Unknown Google image generation model: {model_id}" - raise ValueError(msg) + raise ModelNotFoundError(model_id=model_id, provider=Provider.GOOGLE) __all__ = ["GoogleImageGenerationClient"] diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/client.py b/packages/image-generation/src/celeste_image_generation/providers/openai/client.py index 3248fa2..1b3d49e 100644 --- a/packages/image-generation/src/celeste_image_generation/providers/openai/client.py +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/client.py @@ -7,6 +7,7 @@ import httpx from celeste.artifacts import ImageArtifact +from celeste.exceptions import ValidationError from celeste.parameters import ParameterMapper from celeste_image_generation.client import ImageGenerationClient from celeste_image_generation.io import ( @@ -54,7 +55,7 @@ def _parse_content( data = response_data.get("data", []) if not data: msg = "No image data in response" - raise ValueError(msg) + raise ValidationError(msg) image_data = data[0] @@ -68,7 +69,7 @@ def _parse_content( return ImageArtifact(url=url) msg = "No image URL or base64 data in response" - raise ValueError(msg) + raise ValidationError(msg) def _parse_finish_reason( self, response_data: dict[str, Any] @@ -78,12 +79,7 @@ def _parse_finish_reason( def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: """Build metadata dictionary from response data.""" - # Filter content field before calling super - content_fields = {"data"} - filtered_data = { - k: v for k, v in response_data.items() if k not in content_fields - } - metadata = super()._build_metadata(filtered_data) + metadata = super()._build_metadata(response_data) # Add provider-specific parsed fields if response_data.get("data") and response_data["data"]: revised_prompt = response_data["data"][0].get("revised_prompt") From ea53d63e0e2c2dedfb2139e955fddca913f9af46 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 10 Nov 2025 19:18:26 +0100 Subject: [PATCH 5/5] chore: add comment to integration test file --- .../integration_tests/test_image_generation/test_generate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py b/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py index 73dd915..9b419c1 100644 --- a/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py +++ b/packages/image-generation/tests/integration_tests/test_image_generation/test_generate.py @@ -4,6 +4,8 @@ from celeste import Capability, Provider, create_client +# Integration tests require API credentials configured in CI environment + @pytest.mark.parametrize( ("provider", "model", "parameters"),