Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/celeste/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@


class Artifact(BaseModel):
"""Base class for all media artifacts."""
"""Base class for all media artifacts.

Artifacts can be represented in three ways:
- url: Remote HTTP/HTTPS URL (may expire, e.g., DALL-E URLs last 1 hour)
- data: In-memory bytes (for immediate use without download)
- path: Local filesystem path (for local providers or saved files)

Providers typically populate only one of these fields.
"""

url: str | None = None
data: bytes | None = None
Expand Down
125 changes: 89 additions & 36 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base client and client registry for AI capabilities."""

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from json import JSONDecodeError
from typing import Any, Unpack

Expand Down Expand Up @@ -37,6 +38,53 @@ def http_client(self) -> HTTPClient:
"""Shared HTTP client with connection pooling for this provider."""
return get_http_client(self.provider, self.capability)

async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
"""Generate content - signature varies by capability.

Args:
*args: Capability-specific positional arguments (prompt, image, video, etc.).
**parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.).

Returns:
Output of the parameterized type (e.g., TextGenerationOutput).
"""
inputs = self._create_inputs(*args, **parameters)
request_body = self._build_request(inputs, **parameters)
response = await self._make_request(request_body, **parameters)
self._handle_error_response(response)
response_data = response.json()
return self._output_class()(
content=self._parse_content(response_data, **parameters),
usage=self._parse_usage(response_data),
metadata=self._build_metadata(response_data),
)

def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
"""Stream content - signature varies by capability.

Args:
*args: Capability-specific positional arguments (same as generate).
**parameters: Capability-specific keyword arguments (same as generate).

Returns:
Stream yielding chunks and providing final Output.

Raises:
NotImplementedError: If model doesn't support streaming.
"""
if not self.model.streaming:
msg = f"Streaming not supported for model '{self.model.id}'"
raise NotImplementedError(msg)

inputs = self._create_inputs(*args, **parameters)
request_body = self._build_request(inputs, **parameters)
sse_iterator = self._make_stream_request(request_body, **parameters)
return self._stream_class()( # type: ignore[call-arg]
sse_iterator,
transform_output=self._transform_output,
**parameters,
)

@classmethod
@abstractmethod
def parameter_mappers(cls) -> list[ParameterMapper]:
Expand All @@ -46,31 +94,66 @@ def parameter_mappers(cls) -> list[ParameterMapper]:
@abstractmethod
def _init_request(self, inputs: In) -> dict[str, Any]:
"""Initialize provider-specific base request structure."""
pass
...

@abstractmethod
def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
"""Parse usage information from provider response."""
pass
...

@abstractmethod
def _parse_content(
self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
) -> object:
"""Parse content from provider response."""
pass
...

@abstractmethod
def _create_inputs(self, *args: Any, **parameters: Unpack[Parameters]) -> In: # noqa: ANN401
"""Map positional arguments to Input type."""
...

@classmethod
@abstractmethod
def _output_class(cls) -> type[Out]:
"""Return the Output class for this client."""
...

@abstractmethod
async def _make_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
...

@abstractmethod
def _stream_class(self) -> type[Stream[Out]]:
"""Return the Stream class for this client."""
...

@abstractmethod
def _make_stream_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
) -> AsyncIterator[dict[str, Any]]:
"""Make HTTP streaming request and return async iterator of events."""
...

def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Build metadata dictionary from response data."""
return {
"model": self.model.id,
"provider": self.provider.value,
}

def _handle_error_response(self, response: httpx.Response) -> None:
"""Handle error responses from provider APIs"""
"""Handle error responses from provider APIs."""
if not response.is_success:
# Try to extract error message from JSON response
try:
error_data = response.json()
error_msg = error_data.get("error", {}).get("message", response.text)
except JSONDecodeError:
error_msg = response.text or f"HTTP {response.status_code}"

# Raise HTTPStatusError with provider context
raise httpx.HTTPStatusError(
f"{self.provider.value} API error: {error_msg}",
request=response.request,
Expand All @@ -93,42 +176,12 @@ def _build_request(
"""Build complete request by combining base request with parameters."""
request = self._init_request(inputs)

# Apply parameter mappers from registry
for mapper in self.parameter_mappers():
value = parameters.get(mapper.name)
request = mapper.map(request, value, self.model)

return request

def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
"""Stream content - signature varies by capability.

Args:
*args: Capability-specific positional arguments (same as generate).
**parameters: Capability-specific keyword arguments (same as generate).

Returns:
Stream yielding chunks and providing final Output.

Raises:
NotImplementedError: If capability doesn't support streaming.
"""
msg = f"Streaming not supported for {self.capability.value} with provider {self.provider.value}"
raise NotImplementedError(msg)

@abstractmethod
async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
"""Generate content - signature varies by capability.

Args:
*args: Capability-specific positional arguments (prompt, text, image_url, etc.).
**parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.).

Returns:
Output of the parameterized type (e.g., TextGenerationOutput).
"""
pass


_clients: dict[tuple[Capability, Provider], type[Client]] = {}

Expand Down
2 changes: 2 additions & 0 deletions src/celeste/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Model(BaseModel):
display_name: str
capabilities: set[Capability] = Field(default_factory=set)
parameter_constraints: dict[str, Constraint] = Field(default_factory=dict)
streaming: bool = Field(default=False)

@property
def supported_parameters(self) -> set[str]:
Expand Down Expand Up @@ -51,6 +52,7 @@ def register_models(models: Model | list[Model], capability: Capability) -> None
display_name=model.display_name,
capabilities=set(),
parameter_constraints={},
streaming=model.streaming,
),
)

Expand Down
48 changes: 42 additions & 6 deletions tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""High-value tests for Client - focusing on critical validation and framework behavior."""

from collections.abc import Generator
from collections.abc import AsyncIterator, Generator
from enum import StrEnum
from typing import Any, Unpack

import httpx
import pytest
from pydantic import SecretStr, ValidationError

Expand All @@ -12,6 +13,7 @@
from celeste.io import Input, Output, Usage
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream


class ParamEnum(StrEnum):
Expand Down Expand Up @@ -167,8 +169,43 @@ def _parse_content(
) -> Any: # noqa: ANN401
return response_data.get("content", "test content")

async def generate(self, **parameters: Unpack[Parameters]) -> Output:
return Output(content="test output")
def _create_inputs(
self,
*args: Any, # noqa: ANN401
**parameters: Unpack[Parameters],
) -> Input:
"""Map positional arguments to Input type."""
if args:
prompt = str(args[0])
return _TestInput(prompt=prompt)
prompt_value = parameters.get("prompt", "test prompt")
prompt = str(prompt_value) if prompt_value is not None else "test prompt"
return _TestInput(prompt=prompt)

@classmethod
def _output_class(cls) -> type[Output]:
"""Return the Output class for this client."""
return Output

async def _make_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
return httpx.Response(
200,
json={"content": "test content"},
request=httpx.Request("POST", "https://test.com"),
)

def _stream_class(self) -> type[Stream[Output]]:
"""Return the Stream class for this client."""
raise NotImplementedError("Streaming not implemented in test client")

def _make_stream_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
) -> AsyncIterator[dict[str, Any]]:
"""Make HTTP streaming request and return async iterator of events."""
raise NotImplementedError("Streaming not implemented in test client")


class TestClientValidation:
Expand Down Expand Up @@ -502,10 +539,9 @@ def test_stream_raises_not_implemented_with_descriptive_error(

# Act & Assert
with pytest.raises(NotImplementedError) as exc_info:
client.stream()
client.stream("test prompt")

# Verify error message contains all debugging info
error_msg = str(exc_info.value)
assert "Streaming not supported" in error_msg
assert "text_generation" in error_msg
assert "openai" in error_msg
assert "gpt-4" in error_msg
Loading