From 2f5d892eda9d5cfa568c0517e2cb81069a9b60d2 Mon Sep 17 00:00:00 2001 From: Abir Rahman <118625188+abirzishan32@users.noreply.github.com> Date: Sat, 8 Nov 2025 13:10:41 +0600 Subject: [PATCH 1/7] implement custom exceptions for Celeste and add unit tests --- src/celeste/__init__.py | 25 +++- src/celeste/client.py | 20 ++- src/celeste/credentials.py | 7 +- src/celeste/exceptions.py | 152 ++++++++++++++++++++++ src/celeste/http.py | 8 ++ src/celeste/parameters.py | 13 +- src/celeste/streaming.py | 9 +- tests/unit_tests/test_exceptions.py | 193 ++++++++++++++++++++++++++++ 8 files changed, 406 insertions(+), 21 deletions(-) create mode 100644 src/celeste/exceptions.py create mode 100644 tests/unit_tests/test_exceptions.py diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 7a0038a..22d83a9 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -6,6 +6,18 @@ from celeste.client import Client, get_client_class, register_client from celeste.core import Capability, Parameter, Provider from celeste.credentials import credentials +from celeste.exceptions import ( + CelesteError, + ClientNotFoundError, + ConstraintViolationError, + MissingCredentialsError, + ModelNotFoundError, + StreamEmptyError, + StreamingNotSupportedError, + StreamNotExhaustedError, + UnsupportedCapabilityError, + UnsupportedParameterError, +) from celeste.http import HTTPClient, close_all_http_clients from celeste.io import Input, Output, Usage from celeste.models import Model, get_model, list_models, register_models @@ -35,8 +47,7 @@ def _resolve_model( raise ValueError(msg) found = get_model(model, provider) if not found: - msg = f"Model '{model}' not found for provider {provider}" - raise ValueError(msg) + raise ModelNotFoundError(model_id=model, provider=provider.value) return found return model @@ -97,14 +108,24 @@ def _load_from_entry_points() -> None: # Exports __all__ = [ "Capability", + "CelesteError", "Client", + "ClientNotFoundError", + "ConstraintViolationError", "HTTPClient", "Input", + "MissingCredentialsError", "Model", + "ModelNotFoundError", "Output", "Parameter", "Parameters", "Provider", + "StreamEmptyError", + "StreamingNotSupportedError", + "StreamNotExhaustedError", + "UnsupportedCapabilityError", + "UnsupportedParameterError", "Usage", "close_all_http_clients", "create_client", diff --git a/src/celeste/client.py b/src/celeste/client.py index 00e5cb6..b352cb5 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -9,6 +9,11 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr from celeste.core import Capability, Provider +from celeste.exceptions import ( + ClientNotFoundError, + StreamingNotSupportedError, + UnsupportedCapabilityError, +) from celeste.http import HTTPClient, get_http_client from celeste.io import Input, Output, Usage from celeste.models import Model @@ -29,8 +34,9 @@ class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel): def model_post_init(self, __context: object) -> None: """Validate capability compatibility.""" if self.capability not in self.model.capabilities: - raise ValueError( - f"Model '{self.model.id}' does not support capability {self.capability.value}" + raise UnsupportedCapabilityError( + model_id=self.model.id, + capability=self.capability.value, ) @property @@ -81,8 +87,7 @@ def stream( NotImplementedError: If model doesn't support streaming. """ if not self.model.streaming: - msg = f"Streaming not supported for model '{self.model.id}'" - raise NotImplementedError(msg) + raise StreamingNotSupportedError(model_id=self.model.id) inputs = self._create_inputs(*args, **parameters) request_body = self._build_request(inputs, **parameters) @@ -236,11 +241,12 @@ def get_client_class( The registered client class. Raises: - NotImplementedError: If no client is registered for this capability/provider. + ClientNotFoundError: If no client is registered for this capability/provider. """ if (capability, provider) not in _clients: - raise NotImplementedError( - f"No client registered for {capability.value} with provider {provider.value}" + raise ClientNotFoundError( + capability=capability.value, + provider=provider.value, ) return _clients[(capability, provider)] diff --git a/src/celeste/credentials.py b/src/celeste/credentials.py index a173598..a0a8b46 100644 --- a/src/celeste/credentials.py +++ b/src/celeste/credentials.py @@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings from celeste.core import Provider +from celeste.exceptions import MissingCredentialsError # Provider to credential field mapping PROVIDER_CREDENTIAL_MAP = { @@ -61,15 +62,13 @@ def get_credentials( SecretStr containing the API key for the provider. Raises: - ValueError: If provider requires credentials but none are configured, - or if provider is not supported (no credential mapping). + MissingCredentialsError: If provider requires credentials but none are configured. """ if override_key: return override_key if not self.has_credential(provider): - msg = f"Provider {provider} has no credentials configured." - raise ValueError(msg) + raise MissingCredentialsError(provider=provider.value) credential: SecretStr = getattr(self, PROVIDER_CREDENTIAL_MAP[provider]) return credential diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py new file mode 100644 index 0000000..21ae18f --- /dev/null +++ b/src/celeste/exceptions.py @@ -0,0 +1,152 @@ +"""Custom exceptions for Celeste.""" + + +class CelesteError(Exception): + """Base exception for all Celeste errors.""" + + pass + + +class ModelError(CelesteError): + """Errors related to model operations and registry.""" + + pass + + +class ModelNotFoundError(ModelError): + """Raised when a requested model cannot be found.""" + + def __init__(self, model_id: str, provider: str) -> None: + """Initialize with model details.""" + self.model_id = model_id + self.provider = provider + super().__init__(f"Model '{model_id}' not found for provider {provider}") + + +class CapabilityError(CelesteError): + """Errors related to capability compatibility.""" + + pass + + +class UnsupportedCapabilityError(CapabilityError): + """Raised when a model doesn't support a requested capability.""" + + def __init__(self, model_id: str, capability: str) -> None: + """Initialize with model and capability details.""" + self.model_id = model_id + self.capability = capability + super().__init__( + f"Model '{model_id}' does not support capability '{capability}'" + ) + + +class ClientError(CelesteError): + """Errors related to client operations.""" + + pass + + +class ClientNotFoundError(ClientError): + """Raised when no client is registered for a capability/provider combination.""" + + def __init__(self, capability: str, provider: str) -> None: + """Initialize with capability and provider details.""" + self.capability = capability + self.provider = provider + super().__init__( + f"No client registered for {capability} with provider {provider}" + ) + + +class StreamingError(CelesteError): + """Errors related to streaming operations.""" + + pass + + +class StreamingNotSupportedError(StreamingError): + """Raised when streaming is requested for a model that doesn't support it.""" + + def __init__(self, model_id: str) -> None: + """Initialize with model details.""" + self.model_id = model_id + super().__init__(f"Streaming not supported for model '{model_id}'") + + +class StreamNotExhaustedError(StreamingError): + """Raised when accessing stream output before consuming all chunks.""" + + def __init__(self) -> None: + """Initialize with helpful message.""" + super().__init__("Stream not exhausted. Consume all chunks before accessing .output") + + +class StreamEmptyError(StreamingError): + """Raised when a stream completes without producing any chunks.""" + + def __init__(self) -> None: + """Initialize with helpful message.""" + super().__init__("Stream completed but no chunks were produced") + + +class CredentialsError(CelesteError): + """Errors related to API credentials.""" + + pass + + +class MissingCredentialsError(CredentialsError): + """Raised when required credentials are not configured.""" + + def __init__(self, provider: str) -> None: + """Initialize with provider details.""" + self.provider = provider + super().__init__( + f"Provider {provider} has no credentials configured. " + f"Set the appropriate environment variable or pass api_key parameter." + ) + + +class ValidationError(CelesteError): + """Errors related to parameter and constraint validation.""" + + pass + + +class ConstraintViolationError(ValidationError): + """Raised when a value violates a constraint.""" + + pass + + +class UnsupportedParameterError(ValidationError): + """Raised when a parameter is not supported by a model.""" + + def __init__(self, parameter: str, model_id: str) -> None: + """Initialize with parameter and model details.""" + self.parameter = parameter + self.model_id = model_id + super().__init__( + f"Parameter '{parameter}' is not supported by model '{model_id}'" + ) + + +__all__ = [ + "CelesteError", + "ModelError", + "ModelNotFoundError", + "CapabilityError", + "UnsupportedCapabilityError", + "ClientError", + "ClientNotFoundError", + "StreamingError", + "StreamingNotSupportedError", + "StreamNotExhaustedError", + "StreamEmptyError", + "CredentialsError", + "MissingCredentialsError", + "ValidationError", + "ConstraintViolationError", + "UnsupportedParameterError", +] diff --git a/src/celeste/http.py b/src/celeste/http.py index 756ad5e..8d6746b 100644 --- a/src/celeste/http.py +++ b/src/celeste/http.py @@ -65,7 +65,11 @@ async def post( Raises: httpx.HTTPError: On network or timeout errors. + ValueError: If URL is empty or invalid. """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty") + client = await self._get_client() return await client.post( url, @@ -94,7 +98,11 @@ async def get( Raises: httpx.HTTPError: On network or timeout errors. + ValueError: If URL is empty or invalid. """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty") + client = await self._get_client() return await client.get( url, diff --git a/src/celeste/parameters.py b/src/celeste/parameters.py index ff85861..88cdb32 100644 --- a/src/celeste/parameters.py +++ b/src/celeste/parameters.py @@ -4,6 +4,7 @@ from enum import StrEnum from typing import Any, TypedDict +from celeste.exceptions import UnsupportedParameterError from celeste.models import Model @@ -36,14 +37,20 @@ def parse_output(self, content: Any, value: object | None) -> object: # noqa: A return content def _validate_value(self, value: Any, model: Model) -> Any: # noqa: ANN401 - """Validate parameter value using model constraint, raising ValueError if no constraint exists.""" + """Validate parameter value using model constraint. + + Raises: + UnsupportedParameterError: If parameter is not supported by the model. + """ if value is None: return None constraint = model.parameter_constraints.get(self.name) if constraint is None: - msg = f"Parameter {self.name.value} is not supported by model {model.id}" - raise ValueError(msg) + raise UnsupportedParameterError( + parameter=self.name.value, + model_id=model.id, + ) return constraint(value) diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index 16233cc..cdb2a8e 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -5,6 +5,7 @@ from types import TracebackType from typing import Any, Self, Unpack +from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError from celeste.io import Chunk, Output from celeste.parameters import Parameters @@ -67,8 +68,7 @@ async def __anext__(self) -> Chunk: # Stream exhausted - validate and parse final output if not self._chunks: - msg = "Stream completed but no chunks were produced" - raise RuntimeError(msg) + raise StreamEmptyError() self._output = self._parse_output(self._chunks, **self._parameters) except Exception: @@ -96,10 +96,9 @@ async def __aexit__( @property def output(self) -> Out: - """Access final Output after stream exhaustion (raises RuntimeError if not ready).""" + """Access final Output after stream exhaustion (raises StreamNotExhaustedError if not ready).""" if self._output is None: - msg = "Stream not exhausted. Consume all chunks before accessing .output" - raise RuntimeError(msg) + raise StreamNotExhaustedError() return self._output async def aclose(self) -> None: diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py new file mode 100644 index 0000000..57d95e8 --- /dev/null +++ b/tests/unit_tests/test_exceptions.py @@ -0,0 +1,193 @@ +"""Tests for custom exception classes.""" + +import pytest + +from celeste.exceptions import ( + CelesteError, + ClientNotFoundError, + ConstraintViolationError, + MissingCredentialsError, + ModelNotFoundError, + StreamEmptyError, + StreamingNotSupportedError, + StreamNotExhaustedError, + UnsupportedCapabilityError, + UnsupportedParameterError, +) + + +class TestExceptionHierarchy: + """Test exception hierarchy and inheritance.""" + + def test_all_exceptions_inherit_from_celeste_error(self) -> None: + """Test that all custom exceptions inherit from CelesteError.""" + exceptions = [ + ModelNotFoundError("model-1", "openai"), + UnsupportedCapabilityError("model-1", "text_generation"), + ClientNotFoundError("text_generation", "openai"), + StreamingNotSupportedError("model-1"), + StreamNotExhaustedError(), + StreamEmptyError(), + MissingCredentialsError("openai"), + UnsupportedParameterError("temperature", "model-1"), + ] + + for exc in exceptions: + assert isinstance(exc, CelesteError) + assert isinstance(exc, Exception) + + +class TestModelNotFoundError: + """Test ModelNotFoundError exception.""" + + def test_creates_with_model_and_provider(self) -> None: + """Test exception stores model and provider attributes.""" + exc = ModelNotFoundError("gpt-4", "openai") + + assert exc.model_id == "gpt-4" + assert exc.provider == "openai" + assert "gpt-4" in str(exc) + assert "openai" in str(exc) + + def test_message_is_descriptive(self) -> None: + """Test exception message is clear and actionable.""" + exc = ModelNotFoundError("claude-3", "anthropic") + + assert str(exc) == "Model 'claude-3' not found for provider anthropic" + + +class TestUnsupportedCapabilityError: + """Test UnsupportedCapabilityError exception.""" + + def test_creates_with_model_and_capability(self) -> None: + """Test exception stores model and capability attributes.""" + exc = UnsupportedCapabilityError("gpt-4", "image_generation") + + assert exc.model_id == "gpt-4" + assert exc.capability == "image_generation" + assert "gpt-4" in str(exc) + assert "image_generation" in str(exc) + + def test_message_is_descriptive(self) -> None: + """Test exception message is clear and actionable.""" + exc = UnsupportedCapabilityError("gpt-3.5-turbo", "video_generation") + + assert ( + str(exc) + == "Model 'gpt-3.5-turbo' does not support capability 'video_generation'" + ) + + +class TestClientNotFoundError: + """Test ClientNotFoundError exception.""" + + def test_creates_with_capability_and_provider(self) -> None: + """Test exception stores capability and provider attributes.""" + exc = ClientNotFoundError("text_generation", "unknown_provider") + + assert exc.capability == "text_generation" + assert exc.provider == "unknown_provider" + assert "text_generation" in str(exc) + assert "unknown_provider" in str(exc) + + +class TestStreamingNotSupportedError: + """Test StreamingNotSupportedError exception.""" + + def test_creates_with_model_id(self) -> None: + """Test exception stores model_id attribute.""" + exc = StreamingNotSupportedError("dall-e-3") + + assert exc.model_id == "dall-e-3" + assert "dall-e-3" in str(exc) + assert "Streaming not supported" in str(exc) + + +class TestStreamNotExhaustedError: + """Test StreamNotExhaustedError exception.""" + + def test_has_helpful_message(self) -> None: + """Test exception message guides user to consume chunks first.""" + exc = StreamNotExhaustedError() + + message = str(exc) + assert "not exhausted" in message.lower() + assert "consume all chunks" in message.lower() + assert ".output" in message + + +class TestStreamEmptyError: + """Test StreamEmptyError exception.""" + + def test_has_descriptive_message(self) -> None: + """Test exception message describes the problem clearly.""" + exc = StreamEmptyError() + + message = str(exc) + assert "completed" in message.lower() + assert "no chunks" in message.lower() + + +class TestMissingCredentialsError: + """Test MissingCredentialsError exception.""" + + def test_creates_with_provider(self) -> None: + """Test exception stores provider attribute.""" + exc = MissingCredentialsError("openai") + + assert exc.provider == "openai" + assert "openai" in str(exc) + + def test_message_provides_guidance(self) -> None: + """Test exception message helps user resolve the issue.""" + exc = MissingCredentialsError("anthropic") + + message = str(exc) + assert "anthropic" in message + assert "environment variable" in message.lower() + assert "api_key" in message.lower() + + +class TestUnsupportedParameterError: + """Test UnsupportedParameterError exception.""" + + def test_creates_with_parameter_and_model(self) -> None: + """Test exception stores parameter and model_id attributes.""" + exc = UnsupportedParameterError("temperature", "dall-e-3") + + assert exc.parameter == "temperature" + assert exc.model_id == "dall-e-3" + assert "temperature" in str(exc) + assert "dall-e-3" in str(exc) + + def test_message_is_clear(self) -> None: + """Test exception message clearly indicates the problem.""" + exc = UnsupportedParameterError("max_tokens", "whisper-1") + + assert ( + str(exc) == "Parameter 'max_tokens' is not supported by model 'whisper-1'" + ) + + +class TestExceptionUsability: + """Test that exceptions can be raised and caught properly.""" + + def test_can_catch_specific_exception(self) -> None: + """Test specific exception types can be caught.""" + with pytest.raises(ModelNotFoundError) as exc_info: + raise ModelNotFoundError("test-model", "test-provider") + + assert exc_info.value.model_id == "test-model" + + def test_can_catch_base_exception(self) -> None: + """Test all custom exceptions can be caught as CelesteError.""" + with pytest.raises(CelesteError): + raise StreamEmptyError() + + def test_can_access_exception_attributes(self) -> None: + """Test exception attributes are accessible after catching.""" + try: + raise UnsupportedParameterError("seed", "gpt-4") + except UnsupportedParameterError as e: + assert e.parameter == "seed" + assert e.model_id == "gpt-4" From cf731d060115a322ae0cf359b061cfe999be8b8b Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:08:27 +0100 Subject: [PATCH 2/7] fix: update tests to use new exception types and kebab-case capabilities - Update test_init.py to expect ModelNotFoundError instead of ValueError - Update test_credentials.py to expect MissingCredentialsError instead of ValueError - Update test_client.py to expect ClientNotFoundError instead of NotImplementedError - Update test_exceptions.py to use kebab-case capability strings (text-generation, image-generation, etc.) --- tests/unit_tests/test_client.py | 6 ++++-- tests/unit_tests/test_credentials.py | 4 +++- tests/unit_tests/test_exceptions.py | 21 ++++++++++----------- tests/unit_tests/test_init.py | 3 ++- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 7f8317a..bd6fb91 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -328,14 +328,16 @@ def test_register_and_retrieve_client_success(self) -> None: assert retrieved_class is ConcreteClient def test_get_client_class_raises_for_unregistered_capability(self) -> None: - """get_client_class raises NotImplementedError for unregistered capabilities.""" + """get_client_class raises ClientNotFoundError for unregistered capabilities.""" # Arrange + from celeste.exceptions import ClientNotFoundError + unregistered_capability = Capability.IMAGE_GENERATION provider = Provider.OPENAI # Act & Assert with pytest.raises( - NotImplementedError, + ClientNotFoundError, match=rf"No client registered for {Capability.IMAGE_GENERATION}", ): get_client_class(unregistered_capability, provider) diff --git a/tests/unit_tests/test_credentials.py b/tests/unit_tests/test_credentials.py index 573f0fa..2ea11d0 100644 --- a/tests/unit_tests/test_credentials.py +++ b/tests/unit_tests/test_credentials.py @@ -137,7 +137,9 @@ def test_get_missing_credential_raises( creds = Credentials() # type: ignore[call-arg] # Act & Assert - with pytest.raises(ValueError, match="no credentials configured"): + from celeste.exceptions import MissingCredentialsError + + with pytest.raises(MissingCredentialsError, match="no credentials configured"): creds.get_credentials(Provider.OPENAI) @pytest.mark.parametrize( diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py index 57d95e8..0fc59b3 100644 --- a/tests/unit_tests/test_exceptions.py +++ b/tests/unit_tests/test_exceptions.py @@ -5,7 +5,6 @@ from celeste.exceptions import ( CelesteError, ClientNotFoundError, - ConstraintViolationError, MissingCredentialsError, ModelNotFoundError, StreamEmptyError, @@ -23,8 +22,8 @@ def test_all_exceptions_inherit_from_celeste_error(self) -> None: """Test that all custom exceptions inherit from CelesteError.""" exceptions = [ ModelNotFoundError("model-1", "openai"), - UnsupportedCapabilityError("model-1", "text_generation"), - ClientNotFoundError("text_generation", "openai"), + UnsupportedCapabilityError("model-1", "text-generation"), + ClientNotFoundError("text-generation", "openai"), StreamingNotSupportedError("model-1"), StreamNotExhaustedError(), StreamEmptyError(), @@ -61,20 +60,20 @@ class TestUnsupportedCapabilityError: def test_creates_with_model_and_capability(self) -> None: """Test exception stores model and capability attributes.""" - exc = UnsupportedCapabilityError("gpt-4", "image_generation") + exc = UnsupportedCapabilityError("gpt-4", "image-generation") assert exc.model_id == "gpt-4" - assert exc.capability == "image_generation" + assert exc.capability == "image-generation" assert "gpt-4" in str(exc) - assert "image_generation" in str(exc) + assert "image-generation" in str(exc) def test_message_is_descriptive(self) -> None: """Test exception message is clear and actionable.""" - exc = UnsupportedCapabilityError("gpt-3.5-turbo", "video_generation") + exc = UnsupportedCapabilityError("gpt-3.5-turbo", "video-generation") assert ( str(exc) - == "Model 'gpt-3.5-turbo' does not support capability 'video_generation'" + == "Model 'gpt-3.5-turbo' does not support capability 'video-generation'" ) @@ -83,11 +82,11 @@ class TestClientNotFoundError: def test_creates_with_capability_and_provider(self) -> None: """Test exception stores capability and provider attributes.""" - exc = ClientNotFoundError("text_generation", "unknown_provider") + exc = ClientNotFoundError("text-generation", "unknown_provider") - assert exc.capability == "text_generation" + assert exc.capability == "text-generation" assert exc.provider == "unknown_provider" - assert "text_generation" in str(exc) + assert "text-generation" in str(exc) assert "unknown_provider" in str(exc) diff --git a/tests/unit_tests/test_init.py b/tests/unit_tests/test_init.py index 4cd4e81..3ccd4a1 100644 --- a/tests/unit_tests/test_init.py +++ b/tests/unit_tests/test_init.py @@ -6,6 +6,7 @@ from pydantic import SecretStr from celeste import Capability, Model, Provider, create_client +from celeste.exceptions import ModelNotFoundError @pytest.fixture @@ -55,7 +56,7 @@ def test_create_client_specific_model_not_found_raises_error(self) -> None: mock_get_model.return_value = None # Act & Assert - with pytest.raises(ValueError, match=r"Model.*not found"): + with pytest.raises(ModelNotFoundError, match=r"Model.*not found"): create_client( capability=Capability.TEXT_GENERATION, provider=Provider.OPENAI, From 8bf65d2109c0c2be5cf9c586154c54c3149e7ac3 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:09:43 +0100 Subject: [PATCH 3/7] refactor: rename CelesteError to Error for naming consistency Following the pattern of Client (not CelesteClient) and Parameter (not CelesteParameter), the base exception is now Error (not CelesteError) for consistency. --- src/celeste/__init__.py | 6 ++--- src/celeste/exceptions.py | 38 +++++++++++++++-------------- tests/unit_tests/test_exceptions.py | 12 ++++----- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 22d83a9..1d3debe 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -7,9 +7,9 @@ from celeste.core import Capability, Parameter, Provider from celeste.credentials import credentials from celeste.exceptions import ( - CelesteError, ClientNotFoundError, ConstraintViolationError, + Error, MissingCredentialsError, ModelNotFoundError, StreamEmptyError, @@ -108,10 +108,10 @@ def _load_from_entry_points() -> None: # Exports __all__ = [ "Capability", - "CelesteError", "Client", "ClientNotFoundError", "ConstraintViolationError", + "Error", "HTTPClient", "Input", "MissingCredentialsError", @@ -122,8 +122,8 @@ def _load_from_entry_points() -> None: "Parameters", "Provider", "StreamEmptyError", - "StreamingNotSupportedError", "StreamNotExhaustedError", + "StreamingNotSupportedError", "UnsupportedCapabilityError", "UnsupportedParameterError", "Usage", diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 21ae18f..e37a116 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -1,13 +1,13 @@ """Custom exceptions for Celeste.""" -class CelesteError(Exception): +class Error(Exception): """Base exception for all Celeste errors.""" pass -class ModelError(CelesteError): +class ModelError(Error): """Errors related to model operations and registry.""" pass @@ -23,7 +23,7 @@ def __init__(self, model_id: str, provider: str) -> None: super().__init__(f"Model '{model_id}' not found for provider {provider}") -class CapabilityError(CelesteError): +class CapabilityError(Error): """Errors related to capability compatibility.""" pass @@ -41,7 +41,7 @@ def __init__(self, model_id: str, capability: str) -> None: ) -class ClientError(CelesteError): +class ClientError(Error): """Errors related to client operations.""" pass @@ -59,7 +59,7 @@ def __init__(self, capability: str, provider: str) -> None: ) -class StreamingError(CelesteError): +class StreamingError(Error): """Errors related to streaming operations.""" pass @@ -79,7 +79,9 @@ class StreamNotExhaustedError(StreamingError): def __init__(self) -> None: """Initialize with helpful message.""" - super().__init__("Stream not exhausted. Consume all chunks before accessing .output") + super().__init__( + "Stream not exhausted. Consume all chunks before accessing .output" + ) class StreamEmptyError(StreamingError): @@ -90,7 +92,7 @@ def __init__(self) -> None: super().__init__("Stream completed but no chunks were produced") -class CredentialsError(CelesteError): +class CredentialsError(Error): """Errors related to API credentials.""" pass @@ -108,7 +110,7 @@ def __init__(self, provider: str) -> None: ) -class ValidationError(CelesteError): +class ValidationError(Error): """Errors related to parameter and constraint validation.""" pass @@ -133,20 +135,20 @@ def __init__(self, parameter: str, model_id: str) -> None: __all__ = [ - "CelesteError", - "ModelError", - "ModelNotFoundError", "CapabilityError", - "UnsupportedCapabilityError", "ClientError", "ClientNotFoundError", - "StreamingError", - "StreamingNotSupportedError", - "StreamNotExhaustedError", - "StreamEmptyError", + "ConstraintViolationError", "CredentialsError", + "Error", "MissingCredentialsError", - "ValidationError", - "ConstraintViolationError", + "ModelError", + "ModelNotFoundError", + "StreamEmptyError", + "StreamNotExhaustedError", + "StreamingError", + "StreamingNotSupportedError", + "UnsupportedCapabilityError", "UnsupportedParameterError", + "ValidationError", ] diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py index 0fc59b3..d03c0a1 100644 --- a/tests/unit_tests/test_exceptions.py +++ b/tests/unit_tests/test_exceptions.py @@ -3,8 +3,8 @@ import pytest from celeste.exceptions import ( - CelesteError, ClientNotFoundError, + Error, MissingCredentialsError, ModelNotFoundError, StreamEmptyError, @@ -18,8 +18,8 @@ class TestExceptionHierarchy: """Test exception hierarchy and inheritance.""" - def test_all_exceptions_inherit_from_celeste_error(self) -> None: - """Test that all custom exceptions inherit from CelesteError.""" + def test_all_exceptions_inherit_from_error(self) -> None: + """Test that all custom exceptions inherit from Error.""" exceptions = [ ModelNotFoundError("model-1", "openai"), UnsupportedCapabilityError("model-1", "text-generation"), @@ -32,7 +32,7 @@ def test_all_exceptions_inherit_from_celeste_error(self) -> None: ] for exc in exceptions: - assert isinstance(exc, CelesteError) + assert isinstance(exc, Error) assert isinstance(exc, Exception) @@ -179,8 +179,8 @@ def test_can_catch_specific_exception(self) -> None: assert exc_info.value.model_id == "test-model" def test_can_catch_base_exception(self) -> None: - """Test all custom exceptions can be caught as CelesteError.""" - with pytest.raises(CelesteError): + """Test all custom exceptions can be caught as Error.""" + with pytest.raises(Error): raise StreamEmptyError() def test_can_access_exception_attributes(self) -> None: From d5abf651c64444b57e14bb02fe6d9895a3dc9c66 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:10:47 +0100 Subject: [PATCH 4/7] fix: move imports to top of file Move exception imports from inside test functions to top of file for proper Python import ordering. --- tests/unit_tests/test_client.py | 3 +-- tests/unit_tests/test_credentials.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index bd6fb91..fd4709f 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -10,6 +10,7 @@ from celeste.client import Client, _clients, get_client_class, register_client from celeste.core import Capability, Provider +from celeste.exceptions import ClientNotFoundError from celeste.io import Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters @@ -330,8 +331,6 @@ def test_register_and_retrieve_client_success(self) -> None: def test_get_client_class_raises_for_unregistered_capability(self) -> None: """get_client_class raises ClientNotFoundError for unregistered capabilities.""" # Arrange - from celeste.exceptions import ClientNotFoundError - unregistered_capability = Capability.IMAGE_GENERATION provider = Provider.OPENAI diff --git a/tests/unit_tests/test_credentials.py b/tests/unit_tests/test_credentials.py index 2ea11d0..7c537f8 100644 --- a/tests/unit_tests/test_credentials.py +++ b/tests/unit_tests/test_credentials.py @@ -9,6 +9,7 @@ from celeste.core import Provider from celeste.credentials import PROVIDER_CREDENTIAL_MAP, Credentials +from celeste.exceptions import MissingCredentialsError # Single source of truth for environment variable names ENV_VAR_NAMES = [ @@ -137,8 +138,6 @@ def test_get_missing_credential_raises( creds = Credentials() # type: ignore[call-arg] # Act & Assert - from celeste.exceptions import MissingCredentialsError - with pytest.raises(MissingCredentialsError, match="no credentials configured"): creds.get_credentials(Provider.OPENAI) From 89deced74ce7603a7e55896481c39bbaa31ed80c Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:23:40 +0100 Subject: [PATCH 5/7] refactor: use ConstraintViolationError and remove base classes from exports - Replace all ValueError/TypeError in constraints.py with ConstraintViolationError - Remove base exception classes from __all__ exports (keep for inheritance) - Update all constraint tests to expect ConstraintViolationError - Add ConstraintViolationError tests to test_exceptions.py This makes constraint violations catchable as a specific exception type while keeping the API surface minimal by not exporting base classes. --- src/celeste/constraints.py | 30 +++++----- src/celeste/exceptions.py | 6 -- tests/unit_tests/test_constraints.py | 83 ++++++++++++++++------------ tests/unit_tests/test_exceptions.py | 21 +++++++ 4 files changed, 85 insertions(+), 55 deletions(-) diff --git a/src/celeste/constraints.py b/src/celeste/constraints.py index 725536c..cff104d 100644 --- a/src/celeste/constraints.py +++ b/src/celeste/constraints.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field +from celeste.exceptions import ConstraintViolationError + class Constraint(BaseModel, ABC): """Base constraint for parameter validation.""" @@ -26,7 +28,7 @@ def __call__(self, value: T) -> T: """Validate value is in options.""" if value not in self.options: msg = f"Must be one of {self.options}, got {value!r}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -46,7 +48,7 @@ def __call__(self, value: float | int) -> float | int: """Validate value is within range and matches step increment.""" if not isinstance(value, (int, float)): msg = f"Must be numeric, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) # Check if value is a special value that bypasses range check if self.special_values is not None and value in self.special_values: @@ -58,7 +60,7 @@ def __call__(self, value: float | int) -> float | int: f" or one of {self.special_values}" if self.special_values else "" ) msg = f"Must be between {self.min} and {self.max}{special_msg}, got {value}" - raise ValueError(msg) + raise ConstraintViolationError(msg) # Validate step if provided if self.step is not None: @@ -74,7 +76,7 @@ def __call__(self, value: float | int) -> float | int: ) closest_above = closest_below + self.step msg = f"Value must match step {self.step}. Nearest valid: {closest_below} or {closest_above}, got {value}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -88,11 +90,11 @@ def __call__(self, value: str) -> str: """Validate value matches pattern.""" if not isinstance(value, str): msg = f"Must be string, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) if not re.fullmatch(self.pattern, value): msg = f"Must match pattern {self.pattern!r}, got {value!r}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -107,15 +109,15 @@ def __call__(self, value: str) -> str: """Validate value is a string.""" if not isinstance(value, str): msg = f"Must be string, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) if self.min_length is not None and len(value) < self.min_length: msg = f"String too short (min {self.min_length}), got {len(value)}" - raise ValueError(msg) + raise ConstraintViolationError(msg) if self.max_length is not None and len(value) > self.max_length: msg = f"String too long (max {self.max_length}), got {len(value)}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -128,7 +130,7 @@ def __call__(self, value: int) -> int: # isinstance(True, int) is True, so exclude bools explicitly if not isinstance(value, int) or isinstance(value, bool): msg = f"Must be int, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value @@ -140,7 +142,7 @@ def __call__(self, value: float) -> float: """Validate value is numeric.""" if not isinstance(value, (int, float)) or isinstance(value, bool): msg = f"Must be float or int, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return float(value) @@ -152,7 +154,7 @@ def __call__(self, value: bool) -> bool: """Validate value is a boolean.""" if not isinstance(value, bool): msg = f"Must be bool, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value @@ -167,13 +169,13 @@ def __call__(self, value: type[BaseModel]) -> type[BaseModel]: inner = get_args(value)[0] if not (isinstance(inner, type) and issubclass(inner, BaseModel)): msg = f"List type must be BaseModel, got {inner}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value # For plain type, validate directly if not (isinstance(value, type) and issubclass(value, BaseModel)): msg = f"Must be BaseModel, got {value}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index e37a116..8c0d943 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -135,20 +135,14 @@ def __init__(self, parameter: str, model_id: str) -> None: __all__ = [ - "CapabilityError", - "ClientError", "ClientNotFoundError", "ConstraintViolationError", - "CredentialsError", "Error", "MissingCredentialsError", - "ModelError", "ModelNotFoundError", "StreamEmptyError", "StreamNotExhaustedError", - "StreamingError", "StreamingNotSupportedError", "UnsupportedCapabilityError", "UnsupportedParameterError", - "ValidationError", ] diff --git a/tests/unit_tests/test_constraints.py b/tests/unit_tests/test_constraints.py index 90f039a..4ce7448 100644 --- a/tests/unit_tests/test_constraints.py +++ b/tests/unit_tests/test_constraints.py @@ -3,6 +3,7 @@ import pytest from celeste.constraints import Bool, Choice, Float, Int, Pattern, Range, Str +from celeste.exceptions import ConstraintViolationError class TestChoice: @@ -18,11 +19,11 @@ def test_validates_value_in_options(self) -> None: assert result == "b" def test_rejects_value_not_in_options(self) -> None: - """Test that invalid choice raises ValueError.""" + """Test that invalid choice raises ConstraintViolationError.""" constraint = Choice[str](options=["a", "b", "c"]) with pytest.raises( - ValueError, match=r"Must be one of \['a', 'b', 'c'\], got 'd'" + ConstraintViolationError, match=r"Must be one of \['a', 'b', 'c'\], got 'd'" ): constraint("d") @@ -60,24 +61,28 @@ def test_validates_boundary_values(self) -> None: assert constraint(10) == 10 def test_rejects_value_below_min(self) -> None: - """Test that value below min raises ValueError.""" + """Test that value below min raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(ValueError, match=r"Must be between 0 and 10, got -1"): + with pytest.raises( + ConstraintViolationError, match=r"Must be between 0 and 10, got -1" + ): constraint(-1) def test_rejects_value_above_max(self) -> None: - """Test that value above max raises ValueError.""" + """Test that value above max raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(ValueError, match=r"Must be between 0 and 10, got 11"): + with pytest.raises( + ConstraintViolationError, match=r"Must be between 0 and 10, got 11" + ): constraint(11) def test_rejects_non_numeric_value(self) -> None: - """Test that non-numeric value raises TypeError.""" + """Test that non-numeric value raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(TypeError, match=r"Must be numeric, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be numeric, got str"): constraint("5") # type: ignore[arg-type] def test_accepts_both_int_and_float(self) -> None: @@ -97,11 +102,11 @@ def test_validates_value_with_step(self) -> None: assert constraint(10) == 10 # max def test_rejects_value_not_on_step(self) -> None: - """Test that value not on step increment raises ValueError.""" + """Test that value not on step increment raises ConstraintViolationError.""" constraint = Range(min=0, max=10, step=2) with pytest.raises( - ValueError, + ConstraintViolationError, match=r"Value must match step 2(\.0)?. Nearest valid: 2(\.0)? or 4(\.0)?, got 3", ): constraint(3) @@ -126,7 +131,7 @@ def test_step_validation_with_non_zero_min(self) -> None: assert constraint(14) == 14 # min + 9 with pytest.raises( - ValueError, + ConstraintViolationError, match=r"Value must match step 3(\.0)?. Nearest valid: 5(\.0)? or 8(\.0)?, got 7", ): constraint(7) @@ -169,17 +174,17 @@ def test_validates_matching_pattern(self) -> None: assert result == "123-4567" def test_rejects_non_matching_pattern(self) -> None: - """Test that non-matching pattern raises ValueError.""" + """Test that non-matching pattern raises ConstraintViolationError.""" constraint = Pattern(pattern=r"^\d{3}-\d{4}$") - with pytest.raises(ValueError, match=r"Must match pattern"): + with pytest.raises(ConstraintViolationError, match=r"Must match pattern"): constraint("abc-defg") def test_rejects_non_string_value(self) -> None: - """Test that non-string value raises TypeError.""" + """Test that non-string value raises ConstraintViolationError.""" constraint = Pattern(pattern=r"^\d+$") - with pytest.raises(TypeError, match=r"Must be string, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be string, got int"): constraint(123) # type: ignore[arg-type] def test_validates_complex_regex_patterns(self) -> None: @@ -213,24 +218,28 @@ def test_validates_string_within_length_bounds(self) -> None: assert result == "valid" def test_rejects_string_below_min_length(self) -> None: - """Test string shorter than min_length raises ValueError.""" + """Test string shorter than min_length raises ConstraintViolationError.""" constraint = Str(min_length=5) - with pytest.raises(ValueError, match=r"String too short \(min 5\), got 3"): + with pytest.raises( + ConstraintViolationError, match=r"String too short \(min 5\), got 3" + ): constraint("abc") def test_rejects_string_above_max_length(self) -> None: - """Test string longer than max_length raises ValueError.""" + """Test string longer than max_length raises ConstraintViolationError.""" constraint = Str(max_length=5) - with pytest.raises(ValueError, match=r"String too long \(max 5\), got 10"): + with pytest.raises( + ConstraintViolationError, match=r"String too long \(max 5\), got 10" + ): constraint("too long!!") def test_rejects_non_string_value(self) -> None: - """Test non-string value raises TypeError.""" + """Test non-string value raises ConstraintViolationError.""" constraint = Str() - with pytest.raises(TypeError, match=r"Must be string, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be string, got int"): constraint(123) # type: ignore[arg-type] def test_validates_boundary_lengths(self) -> None: @@ -254,24 +263,24 @@ def test_validates_integer_value(self) -> None: assert result == 42 def test_rejects_float_value(self) -> None: - """Test that float raises TypeError.""" + """Test that float raises ConstraintViolationError.""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got float"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got float"): constraint(42.0) # type: ignore[arg-type] def test_rejects_boolean_value(self) -> None: - """Test that bool raises TypeError despite isinstance(True, int).""" + """Test that bool raises ConstraintViolationError despite isinstance(True, int).""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got bool"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got bool"): constraint(True) def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got str"): constraint("42") # type: ignore[arg-type] @@ -297,17 +306,21 @@ def test_accepts_and_converts_int_to_float(self) -> None: assert isinstance(result, float) def test_rejects_boolean_value(self) -> None: - """Test that bool raises TypeError despite isinstance(True, int).""" + """Test that bool raises ConstraintViolationError despite isinstance(True, int).""" constraint = Float() - with pytest.raises(TypeError, match=r"Must be float or int, got bool"): + with pytest.raises( + ConstraintViolationError, match=r"Must be float or int, got bool" + ): constraint(True) def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Float() - with pytest.raises(TypeError, match=r"Must be float or int, got str"): + with pytest.raises( + ConstraintViolationError, match=r"Must be float or int, got str" + ): constraint("3.14") # type: ignore[arg-type] @@ -323,15 +336,15 @@ def test_validates_boolean_value(self) -> None: assert constraint(False) is False def test_rejects_int_value(self) -> None: - """Test that int raises TypeError (no implicit 0/1 conversion).""" + """Test that int raises ConstraintViolationError (no implicit 0/1 conversion).""" constraint = Bool() - with pytest.raises(TypeError, match=r"Must be bool, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be bool, got int"): constraint(1) # type: ignore[arg-type] def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Bool() - with pytest.raises(TypeError, match=r"Must be bool, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be bool, got str"): constraint("true") # type: ignore[arg-type] diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py index d03c0a1..2817a3b 100644 --- a/tests/unit_tests/test_exceptions.py +++ b/tests/unit_tests/test_exceptions.py @@ -4,6 +4,7 @@ from celeste.exceptions import ( ClientNotFoundError, + ConstraintViolationError, Error, MissingCredentialsError, ModelNotFoundError, @@ -29,6 +30,7 @@ def test_all_exceptions_inherit_from_error(self) -> None: StreamEmptyError(), MissingCredentialsError("openai"), UnsupportedParameterError("temperature", "model-1"), + ConstraintViolationError("Must be between 0 and 1, got 2"), ] for exc in exceptions: @@ -168,6 +170,25 @@ def test_message_is_clear(self) -> None: ) +class TestConstraintViolationError: + """Test ConstraintViolationError exception.""" + + def test_creates_with_message(self) -> None: + """Test exception can be created with a message.""" + exc = ConstraintViolationError("Must be between 0 and 1, got 2") + + assert "Must be between 0 and 1" in str(exc) + assert "got 2" in str(exc) + + def test_inherits_from_validation_error(self) -> None: + """Test ConstraintViolationError inherits from ValidationError.""" + from celeste.exceptions import ValidationError + + exc = ConstraintViolationError("Test message") + assert isinstance(exc, ValidationError) + assert isinstance(exc, Error) + + class TestExceptionUsability: """Test that exceptions can be raised and caught properly.""" From 69b60a50789537f0f9000ba9f2513a9fb335cf16 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:31:36 +0100 Subject: [PATCH 6/7] refactor: make ModelNotFoundError flexible and update exception docstrings - Make ModelNotFoundError support optional parameters (model_id, provider, capability) - Update _resolve_model to use ModelNotFoundError instead of ValueError for no models available - Add UnsupportedCapabilityError to create_client docstring Raises section - Update Client.stream() docstring to mention StreamingNotSupportedError - Update tests to expect correct custom exceptions instead of generic ones --- src/celeste/__init__.py | 12 ++++++--- src/celeste/client.py | 2 +- src/celeste/exceptions.py | 31 ++++++++++++++++++++--- tests/unit_tests/test_client.py | 24 ++++++++++-------- tests/unit_tests/test_exceptions.py | 38 ++++++++++++++++++++++++++--- tests/unit_tests/test_init.py | 5 ++-- tests/unit_tests/test_streaming.py | 19 ++++++++------- 7 files changed, 98 insertions(+), 33 deletions(-) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 1d3debe..23360a8 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -36,8 +36,10 @@ def _resolve_model( # Auto-select first available model models = list_models(provider=provider, capability=capability) if not models: - msg = f"No model found for {capability}" - raise ValueError(msg) + raise ModelNotFoundError( + capability=capability.value, + provider=provider.value if provider else None, + ) return models[0] if isinstance(model, str): @@ -71,8 +73,10 @@ def create_client( Configured client instance ready for generation operations. Raises: - ValueError: If no model found or resolution fails. - NotImplementedError: If no client registered for capability/provider. + ModelNotFoundError: If no model found for the specified capability/provider. + ClientNotFoundError: If no client registered for capability/provider. + MissingCredentialsError: If required credentials are not configured. + UnsupportedCapabilityError: If the resolved model doesn't support the requested capability. """ # Resolve model resolved_model = _resolve_model(capability, provider, model) diff --git a/src/celeste/client.py b/src/celeste/client.py index b352cb5..7d6a844 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -84,7 +84,7 @@ def stream( Stream yielding chunks and providing final Output. Raises: - NotImplementedError: If model doesn't support streaming. + StreamingNotSupportedError: If model doesn't support streaming. """ if not self.model.streaming: raise StreamingNotSupportedError(model_id=self.model.id) diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 8c0d943..d05b17a 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -16,11 +16,36 @@ class ModelError(Error): class ModelNotFoundError(ModelError): """Raised when a requested model cannot be found.""" - def __init__(self, model_id: str, provider: str) -> None: - """Initialize with model details.""" + def __init__( + self, + model_id: str | None = None, + provider: str | None = None, + capability: str | None = None, + ) -> None: + """Initialize with model details. + + Args: + model_id: Optional specific model ID that was not found. + provider: Optional provider name. + capability: Optional capability name (used when no specific model_id). + """ self.model_id = model_id self.provider = provider - super().__init__(f"Model '{model_id}' not found for provider {provider}") + self.capability = capability + + # Generate appropriate error message based on available parameters + if model_id and provider: + msg = f"Model '{model_id}' not found for provider {provider}" + elif capability and provider: + msg = ( + f"No model found for capability '{capability}' with provider {provider}" + ) + elif capability: + msg = f"No model found for capability '{capability}'" + else: + msg = "Model not found" + + super().__init__(msg) class CapabilityError(Error): diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index fd4709f..8bc96f5 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -6,11 +6,15 @@ import httpx import pytest -from pydantic import SecretStr, ValidationError +from pydantic import SecretStr from celeste.client import Client, _clients, get_client_class, register_client from celeste.core import Capability, Provider -from celeste.exceptions import ClientNotFoundError +from celeste.exceptions import ( + ClientNotFoundError, + StreamingNotSupportedError, + UnsupportedCapabilityError, +) from celeste.io import Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters @@ -235,8 +239,8 @@ def test_validation_failure_with_incompatible_capability( """Client rejects model that lacks required capability.""" # Arrange & Act & Assert with pytest.raises( - ValidationError, - match=rf"Model 'gpt-4' does not support capability {Capability.IMAGE_GENERATION}", + UnsupportedCapabilityError, + match=rf"Model 'gpt-4' does not support capability '{Capability.IMAGE_GENERATION}'", ): ConcreteClient( model=text_model, @@ -287,8 +291,8 @@ def test_validation_fails_with_model_lacking_any_capabilities( # Act & Assert with pytest.raises( - ValidationError, - match=rf"Model 'broken-model' does not support capability {Capability.TEXT_GENERATION}", + UnsupportedCapabilityError, + match=rf"Model 'broken-model' does not support capability '{Capability.TEXT_GENERATION}'", ): ConcreteClient( model=empty_model, @@ -401,9 +405,9 @@ def test_exception_message_includes_capability_and_provider( expected_capability_str: str, expected_provider_str: str, ) -> None: - """NotImplementedError includes both capability and provider for debugging.""" + """ClientNotFoundError includes both capability and provider for debugging.""" # Arrange & Act & Assert - with pytest.raises(NotImplementedError) as exc_info: + with pytest.raises(ClientNotFoundError) as exc_info: get_client_class(missing_capability, provider) # Assert both parts in error message @@ -530,7 +534,7 @@ class TestClientStreaming: def test_stream_raises_not_implemented_with_descriptive_error( self, text_model: Model, api_key: str ) -> None: - """stream() raises NotImplementedError with capability and provider info.""" + """stream() raises StreamingNotSupportedError with capability and provider info.""" # Arrange client = ConcreteClient( model=text_model, @@ -540,7 +544,7 @@ def test_stream_raises_not_implemented_with_descriptive_error( ) # Act & Assert - with pytest.raises(NotImplementedError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: client.stream("test prompt") # Verify error message contains all debugging info diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py index 2817a3b..b88ade6 100644 --- a/tests/unit_tests/test_exceptions.py +++ b/tests/unit_tests/test_exceptions.py @@ -43,19 +43,49 @@ class TestModelNotFoundError: def test_creates_with_model_and_provider(self) -> None: """Test exception stores model and provider attributes.""" - exc = ModelNotFoundError("gpt-4", "openai") + exc = ModelNotFoundError(model_id="gpt-4", provider="openai") assert exc.model_id == "gpt-4" assert exc.provider == "openai" + assert exc.capability is None assert "gpt-4" in str(exc) assert "openai" in str(exc) - def test_message_is_descriptive(self) -> None: - """Test exception message is clear and actionable.""" - exc = ModelNotFoundError("claude-3", "anthropic") + def test_message_is_descriptive_with_model_and_provider(self) -> None: + """Test exception message is clear and actionable for specific model.""" + exc = ModelNotFoundError(model_id="claude-3", provider="anthropic") assert str(exc) == "Model 'claude-3' not found for provider anthropic" + def test_creates_with_capability_only(self) -> None: + """Test exception with capability only (no models available for capability).""" + exc = ModelNotFoundError(capability="text-generation") + + assert exc.model_id is None + assert exc.provider is None + assert exc.capability == "text-generation" + assert "No model found for capability 'text-generation'" in str(exc) + + def test_creates_with_capability_and_provider(self) -> None: + """Test exception with capability and provider (no models for capability/provider combo).""" + exc = ModelNotFoundError(capability="text-generation", provider="openai") + + assert exc.model_id is None + assert exc.provider == "openai" + assert exc.capability == "text-generation" + assert ( + "No model found for capability 'text-generation' with provider openai" + in str(exc) + ) + + def test_backwards_compatibility_positional_args(self) -> None: + """Test that positional arguments still work for backwards compatibility.""" + exc = ModelNotFoundError("gpt-4", "openai") + + assert exc.model_id == "gpt-4" + assert exc.provider == "openai" + assert "gpt-4" in str(exc) + class TestUnsupportedCapabilityError: """Test UnsupportedCapabilityError exception.""" diff --git a/tests/unit_tests/test_init.py b/tests/unit_tests/test_init.py index 3ccd4a1..e4affcd 100644 --- a/tests/unit_tests/test_init.py +++ b/tests/unit_tests/test_init.py @@ -38,14 +38,15 @@ class TestCreateClient: """Test the create_client factory function.""" def test_create_client_no_models_available_raises_error(self) -> None: - """Test that create_client raises ValueError when no models are available.""" + """Test that create_client raises ModelNotFoundError when no models are available.""" with patch("celeste.list_models", autospec=True) as mock_list_models: # Arrange mock_list_models.return_value = [] # Act & Assert with pytest.raises( - ValueError, match=rf"No model found for.*{Capability.TEXT_GENERATION}" + ModelNotFoundError, + match=rf"No model found for capability.*{Capability.TEXT_GENERATION}", ): create_client(capability=Capability.TEXT_GENERATION) diff --git a/tests/unit_tests/test_streaming.py b/tests/unit_tests/test_streaming.py index bd60c71..d6af270 100644 --- a/tests/unit_tests/test_streaming.py +++ b/tests/unit_tests/test_streaming.py @@ -7,6 +7,7 @@ import pytest from pydantic import Field +from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError from celeste.io import Chunk, FinishReason, Output, Usage from celeste.parameters import Parameters from celeste.streaming import Stream @@ -155,14 +156,14 @@ class TestStreamOutputProperty: """Test Stream.output property - access guard and final result.""" async def test_output_access_before_exhaustion_raises_error(self) -> None: - """Accessing .output before stream exhaustion must raise RuntimeError.""" + """Accessing .output before stream exhaustion must raise StreamNotExhaustedError.""" # Arrange events = [{"delta": "test"}] stream = ConcreteStream(_async_iter(events)) - # Act & Assert - Premature access raises RuntimeError + # Act & Assert - Premature access raises StreamNotExhaustedError with pytest.raises( - RuntimeError, match=r"Stream not exhausted\. Consume all chunks" + StreamNotExhaustedError, match=r"Stream not exhausted\. Consume all chunks" ): _ = stream.output @@ -478,7 +479,7 @@ class TestStreamEmptyStreamError: """Test Stream empty stream error handling.""" async def test_empty_stream_raises_runtime_error(self) -> None: - """Stream must raise RuntimeError when exhausted with no chunks.""" + """Stream must raise StreamEmptyError when exhausted with no chunks.""" # Arrange - Create stream where all events return None from _parse_chunk async def empty_iter() -> AsyncIterator[dict[str, Any]]: @@ -487,15 +488,15 @@ async def empty_iter() -> AsyncIterator[dict[str, Any]]: stream = ConcreteStream(empty_iter()) - # Act & Assert - Exhaustion raises RuntimeError + # Act & Assert - Exhaustion raises StreamEmptyError with pytest.raises( - RuntimeError, match=r"Stream completed but no chunks were produced" + StreamEmptyError, match=r"Stream completed but no chunks were produced" ): async for _ in stream: pass async def test_stream_with_only_lifecycle_events_raises_error(self) -> None: - """Stream raises RuntimeError when SSE yields events but all chunks are filtered to None.""" + """Stream raises StreamEmptyError when SSE yields events but all chunks are filtered to None.""" # Arrange - Events that all return None from _parse_chunk events = [ {"type": "ping"}, # Lifecycle event (no delta/content) @@ -505,9 +506,9 @@ async def test_stream_with_only_lifecycle_events_raises_error(self) -> None: ] stream = ConcreteStream(_async_iter(events)) - # Act & Assert - Should raise RuntimeError when exhausted + # Act & Assert - Should raise StreamEmptyError when exhausted with pytest.raises( - RuntimeError, match=r"Stream completed but no chunks were produced" + StreamEmptyError, match=r"Stream completed but no chunks were produced" ): async for _ in stream: pass From 2ca27884563e8a75f6f948f1e4a72daeb6432714 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Sun, 9 Nov 2025 22:35:13 +0100 Subject: [PATCH 7/7] refactor: remove .value from StrEnum usage - Remove .value attribute access from StrEnum instances (Capability, Provider, Parameter) - StrEnum instances can be used directly as strings - Updated all exception calls to pass StrEnum directly instead of .value --- src/celeste/__init__.py | 6 +++--- src/celeste/client.py | 10 +++++----- src/celeste/credentials.py | 2 +- src/celeste/http.py | 4 ++-- src/celeste/parameters.py | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 23360a8..e2b923f 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -37,8 +37,8 @@ def _resolve_model( models = list_models(provider=provider, capability=capability) if not models: raise ModelNotFoundError( - capability=capability.value, - provider=provider.value if provider else None, + capability=capability, + provider=provider if provider else None, ) return models[0] @@ -49,7 +49,7 @@ def _resolve_model( raise ValueError(msg) found = get_model(model, provider) if not found: - raise ModelNotFoundError(model_id=model, provider=provider.value) + raise ModelNotFoundError(model_id=model, provider=provider) return found return model diff --git a/src/celeste/client.py b/src/celeste/client.py index 7d6a844..8676c78 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -36,7 +36,7 @@ def model_post_init(self, __context: object) -> None: if self.capability not in self.model.capabilities: raise UnsupportedCapabilityError( model_id=self.model.id, - capability=self.capability.value, + capability=self.capability, ) @property @@ -165,7 +165,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: """Build metadata dictionary from response data.""" return { "model": self.model.id, - "provider": self.provider.value, + "provider": self.provider, } def _handle_error_response(self, response: httpx.Response) -> None: @@ -178,7 +178,7 @@ def _handle_error_response(self, response: httpx.Response) -> None: error_msg = response.text or f"HTTP {response.status_code}" raise httpx.HTTPStatusError( - f"{self.provider.value} API error: {error_msg}", + f"{self.provider} API error: {error_msg}", request=response.request, response=response, ) @@ -245,8 +245,8 @@ def get_client_class( """ if (capability, provider) not in _clients: raise ClientNotFoundError( - capability=capability.value, - provider=provider.value, + capability=capability, + provider=provider, ) return _clients[(capability, provider)] diff --git a/src/celeste/credentials.py b/src/celeste/credentials.py index a0a8b46..124a001 100644 --- a/src/celeste/credentials.py +++ b/src/celeste/credentials.py @@ -68,7 +68,7 @@ def get_credentials( return override_key if not self.has_credential(provider): - raise MissingCredentialsError(provider=provider.value) + raise MissingCredentialsError(provider=provider) credential: SecretStr = getattr(self, PROVIDER_CREDENTIAL_MAP[provider]) return credential diff --git a/src/celeste/http.py b/src/celeste/http.py index 8d6746b..dbd448c 100644 --- a/src/celeste/http.py +++ b/src/celeste/http.py @@ -69,7 +69,7 @@ async def post( """ if not url or not url.strip(): raise ValueError("URL cannot be empty") - + client = await self._get_client() return await client.post( url, @@ -102,7 +102,7 @@ async def get( """ if not url or not url.strip(): raise ValueError("URL cannot be empty") - + client = await self._get_client() return await client.get( url, diff --git a/src/celeste/parameters.py b/src/celeste/parameters.py index 88cdb32..2c2b791 100644 --- a/src/celeste/parameters.py +++ b/src/celeste/parameters.py @@ -38,7 +38,7 @@ def parse_output(self, content: Any, value: object | None) -> object: # noqa: A def _validate_value(self, value: Any, model: Model) -> Any: # noqa: ANN401 """Validate parameter value using model constraint. - + Raises: UnsupportedParameterError: If parameter is not supported by the model. """ @@ -48,7 +48,7 @@ def _validate_value(self, value: Any, model: Model) -> Any: # noqa: ANN401 constraint = model.parameter_constraints.get(self.name) if constraint is None: raise UnsupportedParameterError( - parameter=self.name.value, + parameter=self.name, model_id=model.id, )