diff --git a/src/celeste/auth.py b/src/celeste/auth.py index 3affe19..14acadc 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -66,9 +66,9 @@ def get_auth_class(auth_type: str) -> type[Authentication]: Raises: ValueError: If auth type is not registered. """ - from celeste.registry import _load_from_entry_points + from celeste.registry import _load_providers_from_entry_points - _load_from_entry_points() + _load_providers_from_entry_points() if auth_type not in _auth_classes: msg = f"Unknown auth type: {auth_type}. Available: {list(_auth_classes.keys())}" diff --git a/src/celeste/client.py b/src/celeste/client.py index 92ee479..c9d5eb9 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -20,6 +20,7 @@ from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream +from celeste.types import StructuredOutput class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel): @@ -123,7 +124,7 @@ def _parse_content( self, response_data: dict[str, Any], **parameters: Unpack[Params], # type: ignore[misc] - ) -> object: + ) -> StructuredOutput: """Parse content from provider response.""" ... @@ -205,9 +206,9 @@ def _handle_error_response(self, response: httpx.Response) -> None: def _transform_output( self, - content: object, + content: StructuredOutput, **parameters: Unpack[Params], # type: ignore[misc] - ) -> object: + ) -> StructuredOutput: """Transform content using parameter mapper output transformations.""" for mapper in self.parameter_mappers(): value = parameters.get(mapper.name) diff --git a/src/celeste/http.py b/src/celeste/http.py index 2b0361b..c267d87 100644 --- a/src/celeste/http.py +++ b/src/celeste/http.py @@ -14,7 +14,7 @@ MAX_CONNECTIONS = 20 MAX_KEEPALIVE_CONNECTIONS = 10 -DEFAULT_TIMEOUT = 60.0 +DEFAULT_TIMEOUT = 180.0 class HTTPClient: diff --git a/src/celeste/parameters.py b/src/celeste/parameters.py index 2c2b791..d904bdc 100644 --- a/src/celeste/parameters.py +++ b/src/celeste/parameters.py @@ -6,6 +6,7 @@ from celeste.exceptions import UnsupportedParameterError from celeste.models import Model +from celeste.types import StructuredOutput class Parameters(TypedDict, total=False): @@ -32,7 +33,9 @@ def map(self, request: dict[str, Any], value: Any, model: Model) -> dict[str, An """ ... - def parse_output(self, content: Any, value: object | None) -> object: # noqa: ANN401 + def parse_output( + self, content: StructuredOutput, value: object | None + ) -> StructuredOutput: """Optionally transform parsed content based on parameter value (default: return unchanged).""" return content diff --git a/src/celeste/registry.py b/src/celeste/registry.py index 48690d9..514b51b 100644 --- a/src/celeste/registry.py +++ b/src/celeste/registry.py @@ -3,6 +3,7 @@ import importlib.metadata _loaded_packages: set[str] = set() +_loaded_providers: set[str] = set() def _load_from_entry_points() -> None: @@ -22,3 +23,22 @@ def _load_from_entry_points() -> None: # The function should register models and clients when called register_func() _loaded_packages.add(ep.name) + + +def _load_providers_from_entry_points() -> None: + """Load auth from installed provider packages via entry points.""" + + entry_points = importlib.metadata.entry_points(group="celeste.providers") + + # Early return if all providers are already loaded + entry_point_names = {ep.name for ep in entry_points} + if entry_point_names.issubset(_loaded_providers): + return + + for ep in entry_points: + if ep.name in _loaded_providers: + continue + register_func = ep.load() + # The function should register auth types when called + register_func() + _loaded_providers.add(ep.name) diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 48f7606..7bc2490 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -20,6 +20,7 @@ from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream +from celeste.types import StructuredOutput class ParamEnum(StrEnum): @@ -92,7 +93,9 @@ def map( request[actual_map_key] = value return request - def parse_output(self, content: object, value: object | None) -> object: + def parse_output( + self, content: StructuredOutput, value: object | None + ) -> StructuredOutput: return content return TestMapperClass() @@ -120,7 +123,9 @@ def map( request[actual_map_key] = value return request - def parse_output(self, content: object, value: object | None) -> object: + def parse_output( + self, content: StructuredOutput, value: object | None + ) -> StructuredOutput: if value is not None: return f"{content}_transformed_with_{value}" return content diff --git a/tests/unit_tests/test_http.py b/tests/unit_tests/test_http.py index 6fbd219..415d0a4 100644 --- a/tests/unit_tests/test_http.py +++ b/tests/unit_tests/test_http.py @@ -296,7 +296,7 @@ async def test_post_uses_default_timeout_when_not_specified( # Assert - Verify default timeout was used mock_httpx_client.post.assert_called_once() call_kwargs = mock_httpx_client.post.call_args[1] - assert call_kwargs["timeout"] == 60.0 + assert call_kwargs["timeout"] == 180.0 async def test_get_uses_default_timeout_when_not_specified( self, mock_httpx_client: AsyncMock @@ -316,7 +316,7 @@ async def test_get_uses_default_timeout_when_not_specified( # Assert - Verify default timeout was used mock_httpx_client.get.assert_called_once() call_kwargs = mock_httpx_client.get.call_args[1] - assert call_kwargs["timeout"] == 60.0 + assert call_kwargs["timeout"] == 180.0 async def test_custom_timeout_passed_to_httpx( self, mock_httpx_client: AsyncMock