Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/celeste/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def get_auth_class(auth_type: str) -> type[Authentication]:
Raises:
ValueError: If auth type is not registered.
"""
from celeste.registry import _load_from_entry_points
from celeste.registry import _load_providers_from_entry_points

_load_from_entry_points()
_load_providers_from_entry_points()

if auth_type not in _auth_classes:
msg = f"Unknown auth type: {auth_type}. Available: {list(_auth_classes.keys())}"
Expand Down
7 changes: 4 additions & 3 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
from celeste.types import StructuredOutput


class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel):
Expand Down Expand Up @@ -123,7 +124,7 @@ def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[Params], # type: ignore[misc]
) -> object:
) -> StructuredOutput:
"""Parse content from provider response."""
...

Expand Down Expand Up @@ -205,9 +206,9 @@ def _handle_error_response(self, response: httpx.Response) -> None:

def _transform_output(
self,
content: object,
content: StructuredOutput,
**parameters: Unpack[Params], # type: ignore[misc]
) -> object:
) -> StructuredOutput:
"""Transform content using parameter mapper output transformations."""
for mapper in self.parameter_mappers():
value = parameters.get(mapper.name)
Expand Down
2 changes: 1 addition & 1 deletion src/celeste/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

MAX_CONNECTIONS = 20
MAX_KEEPALIVE_CONNECTIONS = 10
DEFAULT_TIMEOUT = 60.0
DEFAULT_TIMEOUT = 180.0


class HTTPClient:
Expand Down
5 changes: 4 additions & 1 deletion src/celeste/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from celeste.exceptions import UnsupportedParameterError
from celeste.models import Model
from celeste.types import StructuredOutput


class Parameters(TypedDict, total=False):
Expand All @@ -32,7 +33,9 @@ def map(self, request: dict[str, Any], value: Any, model: Model) -> dict[str, An
"""
...

def parse_output(self, content: Any, value: object | None) -> object: # noqa: ANN401
def parse_output(
self, content: StructuredOutput, value: object | None
) -> StructuredOutput:
"""Optionally transform parsed content based on parameter value (default: return unchanged)."""
return content

Expand Down
20 changes: 20 additions & 0 deletions src/celeste/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib.metadata

_loaded_packages: set[str] = set()
_loaded_providers: set[str] = set()


def _load_from_entry_points() -> None:
Expand All @@ -22,3 +23,22 @@ def _load_from_entry_points() -> None:
# The function should register models and clients when called
register_func()
_loaded_packages.add(ep.name)


def _load_providers_from_entry_points() -> None:
"""Load auth from installed provider packages via entry points."""

entry_points = importlib.metadata.entry_points(group="celeste.providers")

# Early return if all providers are already loaded
entry_point_names = {ep.name for ep in entry_points}
if entry_point_names.issubset(_loaded_providers):
return

for ep in entry_points:
if ep.name in _loaded_providers:
continue
register_func = ep.load()
# The function should register auth types when called
register_func()
_loaded_providers.add(ep.name)
9 changes: 7 additions & 2 deletions tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
from celeste.types import StructuredOutput


class ParamEnum(StrEnum):
Expand Down Expand Up @@ -92,7 +93,9 @@ def map(
request[actual_map_key] = value
return request

def parse_output(self, content: object, value: object | None) -> object:
def parse_output(
self, content: StructuredOutput, value: object | None
) -> StructuredOutput:
return content

return TestMapperClass()
Expand Down Expand Up @@ -120,7 +123,9 @@ def map(
request[actual_map_key] = value
return request

def parse_output(self, content: object, value: object | None) -> object:
def parse_output(
self, content: StructuredOutput, value: object | None
) -> StructuredOutput:
if value is not None:
return f"{content}_transformed_with_{value}"
return content
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ async def test_post_uses_default_timeout_when_not_specified(
# Assert - Verify default timeout was used
mock_httpx_client.post.assert_called_once()
call_kwargs = mock_httpx_client.post.call_args[1]
assert call_kwargs["timeout"] == 60.0
assert call_kwargs["timeout"] == 180.0

async def test_get_uses_default_timeout_when_not_specified(
self, mock_httpx_client: AsyncMock
Expand All @@ -316,7 +316,7 @@ async def test_get_uses_default_timeout_when_not_specified(
# Assert - Verify default timeout was used
mock_httpx_client.get.assert_called_once()
call_kwargs = mock_httpx_client.get.call_args[1]
assert call_kwargs["timeout"] == 60.0
assert call_kwargs["timeout"] == 180.0

async def test_custom_timeout_passed_to_httpx(
self, mock_httpx_client: AsyncMock
Expand Down
Loading