From f65d728c4240e1f48d1a929dd2acb49a854403c5 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Thu, 6 Nov 2025 16:39:35 +0100 Subject: [PATCH] refactor: unify generate() and stream() methods to base Client class --- src/celeste/artifacts.py | 10 ++- src/celeste/client.py | 125 +++++++++++++++++++++++--------- src/celeste/models.py | 2 + tests/unit_tests/test_client.py | 48 ++++++++++-- 4 files changed, 142 insertions(+), 43 deletions(-) diff --git a/src/celeste/artifacts.py b/src/celeste/artifacts.py index e6424f2..113fa1b 100644 --- a/src/celeste/artifacts.py +++ b/src/celeste/artifacts.py @@ -8,7 +8,15 @@ class Artifact(BaseModel): - """Base class for all media artifacts.""" + """Base class for all media artifacts. + + Artifacts can be represented in three ways: + - url: Remote HTTP/HTTPS URL (may expire, e.g., DALL-E URLs last 1 hour) + - data: In-memory bytes (for immediate use without download) + - path: Local filesystem path (for local providers or saved files) + + Providers typically populate only one of these fields. + """ url: str | None = None data: bytes | None = None diff --git a/src/celeste/client.py b/src/celeste/client.py index bee3d9e..fe68e94 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -1,6 +1,7 @@ """Base client and client registry for AI capabilities.""" from abc import ABC, abstractmethod +from collections.abc import AsyncIterator from json import JSONDecodeError from typing import Any, Unpack @@ -37,6 +38,53 @@ def http_client(self) -> HTTPClient: """Shared HTTP client with connection pooling for this provider.""" return get_http_client(self.provider, self.capability) + async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401 + """Generate content - signature varies by capability. + + Args: + *args: Capability-specific positional arguments (prompt, image, video, etc.). + **parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.). + + Returns: + Output of the parameterized type (e.g., TextGenerationOutput). + """ + inputs = self._create_inputs(*args, **parameters) + request_body = self._build_request(inputs, **parameters) + response = await self._make_request(request_body, **parameters) + self._handle_error_response(response) + response_data = response.json() + return self._output_class()( + content=self._parse_content(response_data, **parameters), + usage=self._parse_usage(response_data), + metadata=self._build_metadata(response_data), + ) + + def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401 + """Stream content - signature varies by capability. + + Args: + *args: Capability-specific positional arguments (same as generate). + **parameters: Capability-specific keyword arguments (same as generate). + + Returns: + Stream yielding chunks and providing final Output. + + Raises: + NotImplementedError: If model doesn't support streaming. + """ + if not self.model.streaming: + msg = f"Streaming not supported for model '{self.model.id}'" + raise NotImplementedError(msg) + + inputs = self._create_inputs(*args, **parameters) + request_body = self._build_request(inputs, **parameters) + sse_iterator = self._make_stream_request(request_body, **parameters) + return self._stream_class()( # type: ignore[call-arg] + sse_iterator, + transform_output=self._transform_output, + **parameters, + ) + @classmethod @abstractmethod def parameter_mappers(cls) -> list[ParameterMapper]: @@ -46,31 +94,66 @@ def parameter_mappers(cls) -> list[ParameterMapper]: @abstractmethod def _init_request(self, inputs: In) -> dict[str, Any]: """Initialize provider-specific base request structure.""" - pass + ... @abstractmethod def _parse_usage(self, response_data: dict[str, Any]) -> Usage: """Parse usage information from provider response.""" - pass + ... @abstractmethod def _parse_content( self, response_data: dict[str, Any], **parameters: Unpack[Parameters] ) -> object: """Parse content from provider response.""" - pass + ... + + @abstractmethod + def _create_inputs(self, *args: Any, **parameters: Unpack[Parameters]) -> In: # noqa: ANN401 + """Map positional arguments to Input type.""" + ... + + @classmethod + @abstractmethod + def _output_class(cls) -> type[Out]: + """Return the Output class for this client.""" + ... + + @abstractmethod + async def _make_request( + self, request_body: dict[str, Any], **parameters: Unpack[Parameters] + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + ... + + @abstractmethod + def _stream_class(self) -> type[Stream[Out]]: + """Return the Stream class for this client.""" + ... + + @abstractmethod + def _make_stream_request( + self, request_body: dict[str, Any], **parameters: Unpack[Parameters] + ) -> AsyncIterator[dict[str, Any]]: + """Make HTTP streaming request and return async iterator of events.""" + ... + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary from response data.""" + return { + "model": self.model.id, + "provider": self.provider.value, + } def _handle_error_response(self, response: httpx.Response) -> None: - """Handle error responses from provider APIs""" + """Handle error responses from provider APIs.""" if not response.is_success: - # Try to extract error message from JSON response try: error_data = response.json() error_msg = error_data.get("error", {}).get("message", response.text) except JSONDecodeError: error_msg = response.text or f"HTTP {response.status_code}" - # Raise HTTPStatusError with provider context raise httpx.HTTPStatusError( f"{self.provider.value} API error: {error_msg}", request=response.request, @@ -93,42 +176,12 @@ def _build_request( """Build complete request by combining base request with parameters.""" request = self._init_request(inputs) - # Apply parameter mappers from registry for mapper in self.parameter_mappers(): value = parameters.get(mapper.name) request = mapper.map(request, value, self.model) return request - def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401 - """Stream content - signature varies by capability. - - Args: - *args: Capability-specific positional arguments (same as generate). - **parameters: Capability-specific keyword arguments (same as generate). - - Returns: - Stream yielding chunks and providing final Output. - - Raises: - NotImplementedError: If capability doesn't support streaming. - """ - msg = f"Streaming not supported for {self.capability.value} with provider {self.provider.value}" - raise NotImplementedError(msg) - - @abstractmethod - async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401 - """Generate content - signature varies by capability. - - Args: - *args: Capability-specific positional arguments (prompt, text, image_url, etc.). - **parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.). - - Returns: - Output of the parameterized type (e.g., TextGenerationOutput). - """ - pass - _clients: dict[tuple[Capability, Provider], type[Client]] = {} diff --git a/src/celeste/models.py b/src/celeste/models.py index 8ed0879..4b66f20 100644 --- a/src/celeste/models.py +++ b/src/celeste/models.py @@ -14,6 +14,7 @@ class Model(BaseModel): display_name: str capabilities: set[Capability] = Field(default_factory=set) parameter_constraints: dict[str, Constraint] = Field(default_factory=dict) + streaming: bool = Field(default=False) @property def supported_parameters(self) -> set[str]: @@ -51,6 +52,7 @@ def register_models(models: Model | list[Model], capability: Capability) -> None display_name=model.display_name, capabilities=set(), parameter_constraints={}, + streaming=model.streaming, ), ) diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 37c32a0..2747deb 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -1,9 +1,10 @@ """High-value tests for Client - focusing on critical validation and framework behavior.""" -from collections.abc import Generator +from collections.abc import AsyncIterator, Generator from enum import StrEnum from typing import Any, Unpack +import httpx import pytest from pydantic import SecretStr, ValidationError @@ -12,6 +13,7 @@ from celeste.io import Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters +from celeste.streaming import Stream class ParamEnum(StrEnum): @@ -167,8 +169,43 @@ def _parse_content( ) -> Any: # noqa: ANN401 return response_data.get("content", "test content") - async def generate(self, **parameters: Unpack[Parameters]) -> Output: - return Output(content="test output") + def _create_inputs( + self, + *args: Any, # noqa: ANN401 + **parameters: Unpack[Parameters], + ) -> Input: + """Map positional arguments to Input type.""" + if args: + prompt = str(args[0]) + return _TestInput(prompt=prompt) + prompt_value = parameters.get("prompt", "test prompt") + prompt = str(prompt_value) if prompt_value is not None else "test prompt" + return _TestInput(prompt=prompt) + + @classmethod + def _output_class(cls) -> type[Output]: + """Return the Output class for this client.""" + return Output + + async def _make_request( + self, request_body: dict[str, Any], **parameters: Unpack[Parameters] + ) -> httpx.Response: + """Make HTTP request(s) and return response object.""" + return httpx.Response( + 200, + json={"content": "test content"}, + request=httpx.Request("POST", "https://test.com"), + ) + + def _stream_class(self) -> type[Stream[Output]]: + """Return the Stream class for this client.""" + raise NotImplementedError("Streaming not implemented in test client") + + def _make_stream_request( + self, request_body: dict[str, Any], **parameters: Unpack[Parameters] + ) -> AsyncIterator[dict[str, Any]]: + """Make HTTP streaming request and return async iterator of events.""" + raise NotImplementedError("Streaming not implemented in test client") class TestClientValidation: @@ -502,10 +539,9 @@ def test_stream_raises_not_implemented_with_descriptive_error( # Act & Assert with pytest.raises(NotImplementedError) as exc_info: - client.stream() + client.stream("test prompt") # Verify error message contains all debugging info error_msg = str(exc_info.value) assert "Streaming not supported" in error_msg - assert "text_generation" in error_msg - assert "openai" in error_msg + assert "gpt-4" in error_msg