From 099b83ce6c6144024d4e573d4abbe169444ff0a6 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 15 Dec 2025 15:48:20 +0100 Subject: [PATCH 1/3] feat(providers): add Google Veo API mixin ## Client (GoogleVeoClient mixin) - HTTP POST to /v1beta/models/{model_id}:predictLongRunning endpoint - Async polling for long-running video generation operations - Parse generatedSamples[0].video from response - Download video content from GCS URLs ## Parameters - AspectRatioMapper: parameters.aspectRatio - ResolutionMapper: parameters.resolution - DurationSecondsMapper: parameters.durationSeconds - ReferenceImagesMapper: instances.referenceImages with base64 encoding - FirstFrameMapper: instances.image for image-to-video - LastFrameMapper: instances.lastFrame for interpolation ## Config - API base URL: https://generativelanguage.googleapis.com - Endpoints: predictLongRunning, operation polling - Polling interval: 10s; per-request HTTP timeout: 300s (the polling loop itself has no overall deadline) - GCS storage URL for video downloads --- .../google/src/celeste_google/veo/__init__.py | 1 + .../google/src/celeste_google/veo/client.py | 143 +++++++++++++ .../google/src/celeste_google/veo/config.py | 20 ++ .../src/celeste_google/veo/parameters.py | 195 ++++++++++++++++++ 4 files changed, 359 insertions(+) create mode 100644 packages/providers/google/src/celeste_google/veo/__init__.py create mode 100644 packages/providers/google/src/celeste_google/veo/client.py create mode 100644 packages/providers/google/src/celeste_google/veo/config.py create mode 100644 packages/providers/google/src/celeste_google/veo/parameters.py diff --git a/packages/providers/google/src/celeste_google/veo/__init__.py b/packages/providers/google/src/celeste_google/veo/__init__.py new file mode 100644 index 0000000..e1fabad --- /dev/null +++ b/packages/providers/google/src/celeste_google/veo/__init__.py @@ -0,0 +1 @@ +"""Google Veo API provider package.""" diff --git a/packages/providers/google/src/celeste_google/veo/client.py
b/packages/providers/google/src/celeste_google/veo/client.py new file mode 100644 index 0000000..1f84a62 --- /dev/null +++ b/packages/providers/google/src/celeste_google/veo/client.py @@ -0,0 +1,143 @@ +"""Google Veo API client with shared implementation.""" + +import asyncio +import json +import logging +from typing import Any + +import httpx + +from celeste.mime_types import ApplicationMimeType + +from . import config + +logger = logging.getLogger(__name__) + + +class GoogleVeoClient: + """Mixin for Veo API capabilities. + + Provides shared implementation for video generation using the Veo API: + - _make_request() - HTTP POST with async polling for long-running operations + - _parse_content() - Extract raw video dict from response (generic) + - download_content() - Download from GCS URL, returns raw bytes (generic) + + Capability clients extend via super() to wrap results in artifacts: + class GoogleVideoGenerationClient(GoogleVeoClient, VideoGenerationClient): + def _parse_content(self, response_data, **params): + video_data = super()._parse_content(response_data) # Get generic dict + return VideoArtifact(url=video_data["uri"]) # Capability-specific + + async def download_content(self, artifact: VideoArtifact) -> VideoArtifact: + video_bytes = await super().download_content(artifact.url) # Get raw bytes + return VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4, ...) 
+ """ + + async def _make_request( + self, + request_body: dict[str, Any], + **parameters: Any, + ) -> httpx.Response: + """Make HTTP request with async polling for Veo video generation.""" + model_id = self.model.id # type: ignore[attr-defined] + endpoint = config.GoogleVeoEndpoint.CREATE_VIDEO.format(model_id=model_id) + url = f"{config.BASE_URL}{endpoint}" + + auth_headers = self.auth.get_headers() # type: ignore[attr-defined] + headers = { + **auth_headers, + "Content-Type": ApplicationMimeType.JSON, + } + + logger.info(f"Initiating video generation with model {model_id}") + response = await self.http_client.post( # type: ignore[attr-defined] + url, + headers=headers, + json_body=request_body, + timeout=config.DEFAULT_TIMEOUT, + ) + + self._handle_error_response(response) # type: ignore[attr-defined] + operation_data = response.json() + + operation_name = operation_data["name"] + logger.info(f"Video generation started: {operation_name}") + + poll_url = f"{config.BASE_URL}{config.GoogleVeoEndpoint.GET_OPERATION.format(operation_name=operation_name)}" + poll_headers = auth_headers + + while True: + await asyncio.sleep(config.POLL_INTERVAL) + logger.debug(f"Polling operation status: {operation_name}") + + poll_response = await self.http_client.get( # type: ignore[attr-defined] + poll_url, + headers=poll_headers, + timeout=config.DEFAULT_TIMEOUT, + ) + + self._handle_error_response(poll_response) # type: ignore[attr-defined] + operation_data = poll_response.json() + + if operation_data.get("done"): + if "error" in operation_data: + error = operation_data["error"] + error_msg = error.get("message", "Unknown error") + error_code = error.get("code", "UNKNOWN") + msg = f"Video generation failed: {error_code} - {error_msg}" + raise ValueError(msg) + + logger.info(f"Video generation completed: {operation_name}") + break + + return httpx.Response( + 200, + content=json.dumps(operation_data).encode(), + headers={"Content-Type": ApplicationMimeType.JSON}, + ) + + def 
_parse_content(self, response_data: dict[str, Any]) -> Any: + """Extract raw video dict from response. + + Returns generic dict with video data that capability clients wrap in artifacts. + """ + generate_response = response_data.get("response", {}).get( + "generateVideoResponse", {} + ) + generated_samples = generate_response.get("generatedSamples", []) + if not generated_samples: + msg = "No generated samples in response" + raise ValueError(msg) + return generated_samples[0].get("video", {}) + + async def download_content(self, url: str) -> bytes: + """Download video content from GCS URL. + + Returns raw bytes that capability clients wrap in VideoArtifact. + + Args: + url: GCS URL (gs://) or HTTPS URL to download from. + + Returns: + Raw video bytes. + """ + download_url = url + if download_url.startswith("gs://"): + download_url = download_url.replace("gs://", config.STORAGE_BASE_URL, 1) + + logger.info(f"Downloading video from: {download_url}") + + headers = self.auth.get_headers() # type: ignore[attr-defined] + + response = await self.http_client.get( # type: ignore[attr-defined] + download_url, + headers=headers, + timeout=config.DEFAULT_TIMEOUT, + follow_redirects=True, + ) + + self._handle_error_response(response) # type: ignore[attr-defined] + return response.content # type: ignore[no-any-return] + + +__all__ = ["GoogleVeoClient"] diff --git a/packages/providers/google/src/celeste_google/veo/config.py b/packages/providers/google/src/celeste_google/veo/config.py new file mode 100644 index 0000000..41405db --- /dev/null +++ b/packages/providers/google/src/celeste_google/veo/config.py @@ -0,0 +1,20 @@ +"""Configuration for Google Veo API.""" + +from enum import StrEnum + + +class GoogleVeoEndpoint(StrEnum): + """Endpoints for Veo API.""" + + CREATE_VIDEO = "/v1beta/models/{model_id}:predictLongRunning" + GET_OPERATION = "/v1beta/{operation_name}" + + +BASE_URL = "https://generativelanguage.googleapis.com" + +# Polling Configuration +POLL_INTERVAL = 10 # 
seconds +DEFAULT_TIMEOUT = 300.0 # 5 minutes for long-running operations + +# Storage Configuration +STORAGE_BASE_URL = "https://storage.googleapis.com/" diff --git a/packages/providers/google/src/celeste_google/veo/parameters.py b/packages/providers/google/src/celeste_google/veo/parameters.py new file mode 100644 index 0000000..de6f9b3 --- /dev/null +++ b/packages/providers/google/src/celeste_google/veo/parameters.py @@ -0,0 +1,195 @@ +"""Google Veo API parameter mappers.""" + +from typing import Any + +from celeste import image_to_data_uri +from celeste.exceptions import ValidationError +from celeste.models import Model +from celeste.parameters import ParameterMapper + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio to Google Veo parameters.aspectRatio field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format + request.setdefault("parameters", {})["aspectRatio"] = validated_value + return request + + +class ResolutionMapper(ParameterMapper): + """Map resolution to Google Veo parameters.resolution field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform resolution into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format + request.setdefault("parameters", {})["resolution"] = validated_value + return request + + +class DurationSecondsMapper(ParameterMapper): + """Map duration to Google Veo parameters.durationSeconds field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform duration into provider request.""" + # Coerce to 
integer if string provided + if isinstance(value, str): + value = int(value) + + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Transform to provider-specific request format (API expects integer) + request.setdefault("parameters", {})["durationSeconds"] = validated_value + return request + + +class ReferenceImagesMapper(ParameterMapper): + """Map reference_images to Google Veo instances.referenceImages field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform reference_images into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + reference_images = [] + # Validated value is list[ImageArtifact] based on capability constraints + for img in validated_value: + ref_image: dict[str, Any] = { + "image": {}, + "referenceType": "asset", + } + + # Convert to data URI using core utility + try: + data_uri = image_to_data_uri(img) + + # Extract base64 data from data URI + header, encoded = data_uri.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + + ref_image["image"]["bytesBase64Encoded"] = encoded + ref_image["image"]["mimeType"] = mime_type + except (ValueError, IndexError, OSError): + msg = "Failed to process reference image. Ensure valid data/path/url." 
+ raise ValidationError(msg) from None + + reference_images.append(ref_image) + + request.setdefault("instances", [{}])[0]["referenceImages"] = reference_images + return request + + +class FirstFrameMapper(ParameterMapper): + """Map first_frame to Google Veo instances.image field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform first_frame into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Convert to data URI using core utility + try: + data_uri = image_to_data_uri(validated_value) + + # Extract base64 data from data URI + header, encoded = data_uri.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + + # Set image in instances[0].image + request.setdefault("instances", [{}])[0]["image"] = { + "bytesBase64Encoded": encoded, + "mimeType": mime_type, + } + except (ValueError, IndexError, OSError): + msg = "Failed to process first_frame. Ensure valid data/path/url." 
+ raise ValidationError(msg) from None + + return request + + +class LastFrameMapper(ParameterMapper): + """Map last_frame to Google Veo instances.lastFrame field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform last_frame into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + # Check if first_frame (image) exists - lastFrame requires image per API docs + instances = request.get("instances", [{}]) + if not instances or "image" not in instances[0]: + msg = "last_frame requires first_frame to be provided" + raise ValidationError(msg) + + # Convert to data URI using core utility + try: + data_uri = image_to_data_uri(validated_value) + + # Extract base64 data from data URI + header, encoded = data_uri.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + + # Set lastFrame in instances[0] to match image structure + request.setdefault("instances", [{}])[0]["lastFrame"] = { + "bytesBase64Encoded": encoded, + "mimeType": mime_type, + } + except (ValueError, IndexError, OSError): + msg = "Failed to process last_frame. Ensure valid data/path/url." 
+ raise ValidationError(msg) from None + + return request + + +__all__ = [ + "AspectRatioMapper", + "DurationSecondsMapper", + "FirstFrameMapper", + "LastFrameMapper", + "ReferenceImagesMapper", + "ResolutionMapper", +] From 867dcbc6a0d9b5604e4b99f72976a9d30259de01 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 15 Dec 2025 15:48:42 +0100 Subject: [PATCH 2/3] feat(video-generation): migrate Google provider to Veo API mixin - Update GoogleVideoGenerationClient to extend GoogleVeoClient mixin - Simplify client by delegating HTTP, polling, downloads to API layer - Update parameter mappers to extend base Veo API mappers - Remove config.py (now in API layer) - Add celeste-google as workspace dependency - Fix _create_inputs to get prompt from parameters dict --- .../video-generation/pyproject.toml | 1 + .../src/celeste_video_generation/client.py | 6 +- .../providers/google/client.py | 189 +++-------------- .../providers/google/config.py | 8 - .../providers/google/parameters.py | 196 +++--------------- .../test_video_generation/__init__.py | 2 +- .../test_video_generation/test_generate.py | 18 +- 7 files changed, 60 insertions(+), 360 deletions(-) delete mode 100644 packages/capabilities/video-generation/src/celeste_video_generation/providers/google/config.py diff --git a/packages/capabilities/video-generation/pyproject.toml b/packages/capabilities/video-generation/pyproject.toml index dcc187e..6e6dcf9 100644 --- a/packages/capabilities/video-generation/pyproject.toml +++ b/packages/capabilities/video-generation/pyproject.toml @@ -30,6 +30,7 @@ Issues = "https://github.com/withceleste/celeste-python/issues" [tool.uv.sources] celeste-ai = { workspace = true } +celeste-google = { workspace = true } [project.entry-points."celeste.packages"] video-generation = "celeste_video_generation:register_package" diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/client.py 
b/packages/capabilities/video-generation/src/celeste_video_generation/client.py index 8530513..680c232 100644 --- a/packages/capabilities/video-generation/src/celeste_video_generation/client.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/client.py @@ -38,14 +38,12 @@ def _parse_content( """Parse content from provider response.""" def _create_inputs( - self, - *args: str, - prompt: str | None = None, - **parameters: Unpack[VideoGenerationParameters], + self, *args: str, **parameters: Unpack[VideoGenerationParameters] ) -> VideoGenerationInput: """Map positional arguments to Input type.""" if args: return VideoGenerationInput(prompt=args[0]) + prompt: str | None = parameters.get("prompt") if prompt is None: msg = ( "prompt is required (either as positional argument or keyword argument)" diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py index b950d63..51dd140 100644 --- a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py @@ -1,15 +1,11 @@ -"""Google provider client for video generation.""" +"""Google client implementation for video generation.""" -import asyncio -import base64 -import json -import logging from typing import Any, Unpack -import httpx +from celeste_google.veo.client import GoogleVeoClient -from celeste.artifacts import ImageArtifact, VideoArtifact -from celeste.mime_types import ApplicationMimeType, VideoMimeType +from celeste.artifacts import VideoArtifact +from celeste.mime_types import VideoMimeType from celeste.parameters import ParameterMapper from celeste_video_generation.client import VideoGenerationClient from celeste_video_generation.io import ( @@ -17,62 +13,17 @@ VideoGenerationUsage, ) from 
celeste_video_generation.parameters import VideoGenerationParameters -from celeste_video_generation.providers.google import config -from celeste_video_generation.providers.google.parameters import ( - GOOGLE_PARAMETER_MAPPERS, -) -logger = logging.getLogger(__name__) +from .parameters import GOOGLE_PARAMETER_MAPPERS -class GoogleVideoGenerationClient(VideoGenerationClient): +class GoogleVideoGenerationClient(GoogleVeoClient, VideoGenerationClient): """Google client for video generation.""" @classmethod def parameter_mappers(cls) -> list[ParameterMapper]: return GOOGLE_PARAMETER_MAPPERS - def _validate_artifacts( - self, - inputs: VideoGenerationInput, - **parameters: Unpack[VideoGenerationParameters], - ) -> tuple[VideoGenerationInput, dict[str, Any]]: - """Validate and prepare artifacts for Google Veo API.""" - - def convert_to_base64_uri(img: ImageArtifact) -> ImageArtifact: - if img.data: - file_data = img.data - elif img.path: - with open(img.path, "rb") as f: - file_data = f.read() - else: - msg = "ImageArtifact must have data or path" - raise ValueError(msg) - - base64_data = base64.b64encode(file_data).decode("utf-8") - mime_type = img.mime_type.value if img.mime_type else "image/jpeg" - - return ImageArtifact( - url=f"data:image/{mime_type.split('/')[-1]};base64,{base64_data}", - mime_type=img.mime_type, - metadata=img.metadata, - ) - - reference_images = parameters.get("reference_images") - if reference_images: - converted_images = [convert_to_base64_uri(img) for img in reference_images] - parameters["reference_images"] = converted_images - - first_frame = parameters.get("first_frame") - if first_frame: - parameters["first_frame"] = convert_to_base64_uri(first_frame) - - last_frame = parameters.get("last_frame") - if last_frame: - parameters["last_frame"] = convert_to_base64_uri(last_frame) - - return inputs, dict(parameters) - def _init_request(self, inputs: VideoGenerationInput) -> dict[str, Any]: """Initialize request from Google API format.""" 
instance: dict[str, Any] = {"prompt": inputs.prompt} @@ -92,125 +43,33 @@ def _parse_content( **parameters: Unpack[VideoGenerationParameters], ) -> VideoArtifact: """Parse content from response.""" - try: - generate_response = response_data.get("response", {}).get( - "generateVideoResponse", {} - ) - generated_samples = generate_response.get("generatedSamples", []) - if not generated_samples: - msg = "No generated samples in response" - raise ValueError(msg) - - video_data = generated_samples[0].get("video", {}) - uri = video_data.get("uri") - if not uri: - msg = "No video URI in response" - raise ValueError(msg) - - video_artifact = VideoArtifact(url=uri) - - transformed = self._transform_output(video_artifact, **parameters) - if isinstance(transformed, VideoArtifact): - return transformed - return video_artifact - except (KeyError, IndexError) as e: - msg = f"Invalid response structure: {e}" - raise ValueError(msg) from e - - async def _make_request( - self, - request_body: dict[str, Any], - **parameters: Unpack[VideoGenerationParameters], - ) -> httpx.Response: - """Make HTTP request with async polling for Google video generation.""" - model_id = self.model.id - endpoint = config.GENERATE_ENDPOINT.format(model_id=model_id) - url = f"{config.BASE_URL}{endpoint}" - - headers = { - **self.auth.get_headers(), - "Content-Type": ApplicationMimeType.JSON, - } - - logger.info(f"Initiating video generation with model {model_id}") - response = await self.http_client.post( - url, - headers=headers, - json_body=request_body, - timeout=config.DEFAULT_TIMEOUT, - ) - - self._handle_error_response(response) - operation_data = response.json() - - operation_name = operation_data["name"] - logger.info(f"Video generation started: {operation_name}") - - poll_url = f"{config.BASE_URL}{config.POLL_ENDPOINT.format(operation_name=operation_name)}" - poll_headers = self.auth.get_headers() - - while True: - await asyncio.sleep(config.POLL_INTERVAL) - logger.debug(f"Polling operation 
status: {operation_name}") - - poll_response = await self.http_client.get( - poll_url, - headers=poll_headers, - timeout=config.DEFAULT_TIMEOUT, - ) - - self._handle_error_response(poll_response) - operation_data = poll_response.json() - - if operation_data.get("done"): - if "error" in operation_data: - error = operation_data["error"] - error_msg = error.get("message", "Unknown error") - error_code = error.get("code", "UNKNOWN") - msg = f"Video generation failed: {error_code} - {error_msg}" - raise ValueError(msg) + video_data = super()._parse_content(response_data) + uri = video_data.get("uri") + if not uri: + msg = "No video URI in response" + raise ValueError(msg) - logger.info(f"Video generation completed: {operation_name}") - break + video_artifact = VideoArtifact(url=uri) - return httpx.Response( - 200, - content=json.dumps(operation_data).encode(), - headers={"Content-Type": ApplicationMimeType.JSON}, - ) + transformed = self._transform_output(video_artifact, **parameters) + if isinstance(transformed, VideoArtifact): + return transformed + return video_artifact async def download_content(self, artifact: VideoArtifact) -> VideoArtifact: - """Download video content from URI. - - Google-specific method. Google Veo returns gs:// URIs that require - downloading with API key authentication. Other providers return video - content directly in the response. 
- """ + """Download video content from URI.""" if not artifact.url: msg = "VideoArtifact has no URL to download from" raise ValueError(msg) - download_url = artifact.url - if download_url.startswith("gs://"): - download_url = download_url.replace("gs://", config.STORAGE_BASE_URL, 1) - - logger.info(f"Downloading video from: {download_url}") - - headers = self.auth.get_headers() - - response = await self.http_client.get( - download_url, - headers=headers, - timeout=config.DEFAULT_TIMEOUT, - follow_redirects=True, - ) - - self._handle_error_response(response) - video_data = response.content + video_bytes = await super().download_content(artifact.url) return VideoArtifact( - url=artifact.url, # Keep original URI - data=video_data, - mime_type=VideoMimeType.MP4, # Default to MP4 for videos + url=artifact.url, + data=video_bytes, + mime_type=VideoMimeType.MP4, metadata=artifact.metadata, ) + + +__all__ = ["GoogleVideoGenerationClient"] diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/config.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/config.py deleted file mode 100644 index 29f53b6..0000000 --- a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/config.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Google provider configuration for video generation.""" - -BASE_URL = "https://generativelanguage.googleapis.com/v1beta" -GENERATE_ENDPOINT = "/models/{model_id}:predictLongRunning" -POLL_ENDPOINT = "/{operation_name}" -POLL_INTERVAL = 10 # seconds -DEFAULT_TIMEOUT = 300.0 # 5 minutes for long-running operations -STORAGE_BASE_URL = "https://storage.googleapis.com/" diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/parameters.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/parameters.py index 7fb63a6..217946f 100644 --- 
a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/parameters.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/parameters.py @@ -1,198 +1,56 @@ -"""Google parameter mappers for video generation.""" +"""Google Veo parameter mappers for video generation.""" + +from celeste_google.veo.parameters import ( + AspectRatioMapper as _AspectRatioMapper, +) +from celeste_google.veo.parameters import ( + DurationSecondsMapper as _DurationSecondsMapper, +) +from celeste_google.veo.parameters import ( + FirstFrameMapper as _FirstFrameMapper, +) +from celeste_google.veo.parameters import ( + LastFrameMapper as _LastFrameMapper, +) +from celeste_google.veo.parameters import ( + ReferenceImagesMapper as _ReferenceImagesMapper, +) +from celeste_google.veo.parameters import ( + ResolutionMapper as _ResolutionMapper, +) -from typing import Any - -from celeste.exceptions import ValidationError -from celeste.models import Model from celeste.parameters import ParameterMapper from celeste_video_generation.parameters import VideoGenerationParameter -class AspectRatioMapper(ParameterMapper): - """Map aspect_ratio parameter to Google API format.""" - +class AspectRatioMapper(_AspectRatioMapper): name = VideoGenerationParameter.ASPECT_RATIO - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform aspect_ratio into provider request.""" - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - # Transform to provider-specific request format - request.setdefault("parameters", {})["aspectRatio"] = validated_value - return request - - -class ResolutionMapper(ParameterMapper): - """Map resolution parameter to Google API format.""" +class ResolutionMapper(_ResolutionMapper): name = VideoGenerationParameter.RESOLUTION - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, 
Any]: - """Transform resolution into provider request.""" - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - # Transform to provider-specific request format - request.setdefault("parameters", {})["resolution"] = validated_value - return request - - -class DurationSecondsMapper(ParameterMapper): - """Map duration parameter to Google API format.""" +class DurationMapper(_DurationSecondsMapper): name = VideoGenerationParameter.DURATION - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform duration into provider request.""" - # Coerce to integer if string provided (for backward compatibility) - if isinstance(value, str): - value = int(value) - - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - # Transform to provider-specific request format (API expects integer) - request.setdefault("parameters", {})["durationSeconds"] = validated_value - return request - - -class ReferenceImagesMapper(ParameterMapper): - """Map reference_images parameter to Google API format.""" +class ReferenceImagesMapper(_ReferenceImagesMapper): name = VideoGenerationParameter.REFERENCE_IMAGES - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform reference_images into provider request.""" - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - reference_images = [] - for img in validated_value: - ref_image: dict[str, Any] = { - "image": {}, - "referenceType": "asset", - } - - # Check if URL is base64 data URI - if img.url and img.url.startswith("data:image/"): - # Extract base64 data from data URI - header, encoded = img.url.split(",", 1) - mime_type = header.split(":")[1].split(";")[0] - ref_image["image"]["bytesBase64Encoded"] = encoded - ref_image["image"]["mimeType"] = mime_type - else: - msg = "ImageArtifact must 
have data or path for reference images (base64 encoding required)" - raise ValidationError(msg) - - reference_images.append(ref_image) - - request.setdefault("instances", [{}])[0]["referenceImages"] = reference_images - return request - - -class FirstFrameMapper(ParameterMapper): - """Map first_frame parameter to Google API format.""" +class FirstFrameMapper(_FirstFrameMapper): name = VideoGenerationParameter.FIRST_FRAME - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform first_frame into provider request.""" - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - # Check if URL is base64 data URI - if validated_value.url and validated_value.url.startswith("data:image/"): - # Extract base64 data from data URI - header, encoded = validated_value.url.split(",", 1) - mime_type = header.split(":")[1].split(";")[0] - - # Set image in instances[0].image - request.setdefault("instances", [{}])[0]["image"] = { - "bytesBase64Encoded": encoded, - "mimeType": mime_type, - } - else: - msg = "ImageArtifact must have data or path for first_frame (base64 encoding required)" - raise ValidationError(msg) - - return request - - -class LastFrameMapper(ParameterMapper): - """Map last_frame parameter to Google API format.""" +class LastFrameMapper(_LastFrameMapper): name = VideoGenerationParameter.LAST_FRAME - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform last_frame into provider request.""" - validated_value = self._validate_value(value, model) - if validated_value is None: - return request - - # Check if first_frame (image) exists - lastFrame requires image per API docs - instances = request.get("instances", [{}]) - if not instances or "image" not in instances[0]: - msg = "last_frame requires first_frame to be provided" - raise ValidationError(msg) - - # Check if URL is base64 data URI - if 
validated_value.url and validated_value.url.startswith("data:image/"): - # Extract base64 data from data URI - header, encoded = validated_value.url.split(",", 1) - mime_type = header.split(":")[1].split(";")[0] - - # Set lastFrame in instances[0] to match image structure - request.setdefault("instances", [{}])[0]["lastFrame"] = { - "bytesBase64Encoded": encoded, - "mimeType": mime_type, - } - else: - msg = "ImageArtifact must have data or path for last_frame (base64 encoding required)" - raise ValidationError(msg) - - return request - GOOGLE_PARAMETER_MAPPERS: list[ParameterMapper] = [ AspectRatioMapper(), ResolutionMapper(), - DurationSecondsMapper(), + DurationMapper(), ReferenceImagesMapper(), FirstFrameMapper(), LastFrameMapper(), diff --git a/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/__init__.py b/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/__init__.py index c6eb386..25f049b 100644 --- a/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/__init__.py +++ b/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/__init__.py @@ -1 +1 @@ -"""Integration tests for video generation.""" +"""Video generation integration test module.""" diff --git a/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/test_generate.py b/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/test_generate.py index 90e59de..455d92f 100644 --- a/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/test_generate.py +++ b/packages/capabilities/video-generation/tests/integration_tests/test_video_generation/test_generate.py @@ -28,17 +28,9 @@ @pytest.mark.integration @pytest.mark.asyncio async def test_generate(provider: Provider, model: str, parameters: dict) -> None: - """Test video generation with prompt parameter across all providers. 
- - This test demonstrates that the unified API works identically across - all providers using the same code - proving the abstraction value. - Uses cheapest models with minimum duration and lowest resolution to minimize costs. - """ - # Import here to avoid circular import during pytest collection - from celeste_video_generation import ( - VideoGenerationOutput, - VideoGenerationUsage, - ) + """Test video generation across all providers. Uses cheapest models.""" + # Import inside function to avoid circular import + from celeste_video_generation import VideoGenerationOutput, VideoGenerationUsage from celeste.artifacts import VideoArtifact @@ -46,13 +38,13 @@ async def test_generate(provider: Provider, model: str, parameters: dict) -> Non client = create_client( capability=Capability.VIDEO_GENERATION, provider=provider, + model=model, ) - prompt = "A cinematic video of a sunset over mountains" + prompt = "A cat playing with a ball" # Act response = await client.generate( prompt=prompt, - model=model, **parameters, ) From f0634fe82ca29c6967fde791d353c5cc862253b2 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Mon, 15 Dec 2025 15:51:15 +0100 Subject: [PATCH 3/3] feat(core): add image_to_data_uri utility function Converts ImageArtifact to base64 data URI string for API requests that require inline image data (e.g., Google Veo reference images). 
--- src/celeste/__init__.py | 2 ++ src/celeste/utils.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 src/celeste/utils.py diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 113a8cc..d13f655 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -25,6 +25,7 @@ from celeste.models import Model, get_model, list_models, register_models from celeste.parameters import Parameters from celeste.registry import _load_from_entry_points +from celeste.utils import image_to_data_uri logger = logging.getLogger(__name__) @@ -135,6 +136,7 @@ def create_client( "create_client", "get_client_class", "get_model", + "image_to_data_uri", "list_models", "register_client", "register_models", diff --git a/src/celeste/utils.py b/src/celeste/utils.py new file mode 100644 index 0000000..13eba3d --- /dev/null +++ b/src/celeste/utils.py @@ -0,0 +1,35 @@ +"""Utility functions for Celeste.""" + +import base64 + +from celeste.artifacts import ImageArtifact + + +def image_to_data_uri(image: ImageArtifact) -> str: + """Convert an ImageArtifact to a base64 data URI string. + + Args: + image: ImageArtifact with data or path. + + Returns: + Data URI string (e.g., "data:image/png;base64,iVBORw0KGgo..."). + + Raises: + ValueError: If image has neither data nor path. + """ + if image.data: + file_data = image.data + elif image.path: + with open(image.path, "rb") as f: + file_data = f.read() + else: + msg = "ImageArtifact must have data or path" + raise ValueError(msg) + + base64_data = base64.b64encode(file_data).decode("utf-8") + mime_type = image.mime_type.value if image.mime_type else "image/jpeg" + + return f"data:{mime_type};base64,{base64_data}" + + +__all__ = ["image_to_data_uri"]