diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 7a0038a..e2b923f 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -6,6 +6,18 @@ from celeste.client import Client, get_client_class, register_client from celeste.core import Capability, Parameter, Provider from celeste.credentials import credentials +from celeste.exceptions import ( + ClientNotFoundError, + ConstraintViolationError, + Error, + MissingCredentialsError, + ModelNotFoundError, + StreamEmptyError, + StreamingNotSupportedError, + StreamNotExhaustedError, + UnsupportedCapabilityError, + UnsupportedParameterError, +) from celeste.http import HTTPClient, close_all_http_clients from celeste.io import Input, Output, Usage from celeste.models import Model, get_model, list_models, register_models @@ -24,8 +36,10 @@ def _resolve_model( # Auto-select first available model models = list_models(provider=provider, capability=capability) if not models: - msg = f"No model found for {capability}" - raise ValueError(msg) + raise ModelNotFoundError( + capability=capability, + provider=provider if provider else None, + ) return models[0] if isinstance(model, str): @@ -35,8 +49,7 @@ def _resolve_model( raise ValueError(msg) found = get_model(model, provider) if not found: - msg = f"Model '{model}' not found for provider {provider}" - raise ValueError(msg) + raise ModelNotFoundError(model_id=model, provider=provider) return found return model @@ -60,8 +73,10 @@ def create_client( Configured client instance ready for generation operations. Raises: - ValueError: If no model found or resolution fails. - NotImplementedError: If no client registered for capability/provider. + ModelNotFoundError: If no model found for the specified capability/provider. + ClientNotFoundError: If no client registered for capability/provider. + MissingCredentialsError: If required credentials are not configured. + UnsupportedCapabilityError: If the resolved model doesn't support the requested capability. """ # Resolve model resolved_model = _resolve_model(capability, provider, model) @@ -98,13 +113,23 @@ def _load_from_entry_points() -> None: __all__ = [ "Capability", "Client", + "ClientNotFoundError", + "ConstraintViolationError", + "Error", "HTTPClient", "Input", + "MissingCredentialsError", "Model", + "ModelNotFoundError", "Output", "Parameter", "Parameters", "Provider", + "StreamEmptyError", + "StreamNotExhaustedError", + "StreamingNotSupportedError", + "UnsupportedCapabilityError", + "UnsupportedParameterError", "Usage", "close_all_http_clients", "create_client", diff --git a/src/celeste/client.py b/src/celeste/client.py index d35cd23..8676c78 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -9,6 +9,11 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr from celeste.core import Capability, Provider +from celeste.exceptions import ( + ClientNotFoundError, + StreamingNotSupportedError, + UnsupportedCapabilityError, +) from celeste.http import HTTPClient, get_http_client from celeste.io import Input, Output, Usage from celeste.models import Model @@ -29,8 +34,9 @@ class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel): def model_post_init(self, __context: object) -> None: """Validate capability compatibility.""" if self.capability not in self.model.capabilities: - raise ValueError( - f"Model '{self.model.id}' does not support capability {self.capability}" + raise UnsupportedCapabilityError( + model_id=self.model.id, + capability=self.capability, ) @property @@ -78,11 +84,10 @@ def stream( Stream yielding chunks and providing final Output. Raises: - NotImplementedError: If model doesn't support streaming. + StreamingNotSupportedError: If model doesn't support streaming. """ if not self.model.streaming: - msg = f"Streaming not supported for model '{self.model.id}'" - raise NotImplementedError(msg) + raise StreamingNotSupportedError(model_id=self.model.id) inputs = self._create_inputs(*args, **parameters) request_body = self._build_request(inputs, **parameters) @@ -160,7 +165,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: """Build metadata dictionary from response data.""" return { "model": self.model.id, - "provider": self.provider.value, + "provider": self.provider, } def _handle_error_response(self, response: httpx.Response) -> None: @@ -173,7 +178,7 @@ def _handle_error_response(self, response: httpx.Response) -> None: error_msg = response.text or f"HTTP {response.status_code}" raise httpx.HTTPStatusError( - f"{self.provider.value} API error: {error_msg}", + f"{self.provider} API error: {error_msg}", request=response.request, response=response, ) @@ -236,11 +241,12 @@ def get_client_class( The registered client class. Raises: - NotImplementedError: If no client is registered for this capability/provider. + ClientNotFoundError: If no client is registered for this capability/provider. """ if (capability, provider) not in _clients: - raise NotImplementedError( - f"No client registered for {capability} with provider {provider}" + raise ClientNotFoundError( + capability=capability, + provider=provider, ) return _clients[(capability, provider)] diff --git a/src/celeste/constraints.py b/src/celeste/constraints.py index 725536c..cff104d 100644 --- a/src/celeste/constraints.py +++ b/src/celeste/constraints.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field +from celeste.exceptions import ConstraintViolationError + class Constraint(BaseModel, ABC): """Base constraint for parameter validation.""" @@ -26,7 +28,7 @@ def __call__(self, value: T) -> T: """Validate value is in options.""" if value not in self.options: msg = f"Must be one of {self.options}, got {value!r}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -46,7 +48,7 @@ def __call__(self, value: float | int) -> float | int: """Validate value is within range and matches step increment.""" if not isinstance(value, (int, float)): msg = f"Must be numeric, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) # Check if value is a special value that bypasses range check if self.special_values is not None and value in self.special_values: @@ -58,7 +60,7 @@ def __call__(self, value: float | int) -> float | int: f" or one of {self.special_values}" if self.special_values else "" ) msg = f"Must be between {self.min} and {self.max}{special_msg}, got {value}" - raise ValueError(msg) + raise ConstraintViolationError(msg) # Validate step if provided if self.step is not None: @@ -74,7 +76,7 @@ def __call__(self, value: float | int) -> float | int: ) closest_above = closest_below + self.step msg = f"Value must match step {self.step}. Nearest valid: {closest_below} or {closest_above}, got {value}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -88,11 +90,11 @@ def __call__(self, value: str) -> str: """Validate value matches pattern.""" if not isinstance(value, str): msg = f"Must be string, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) if not re.fullmatch(self.pattern, value): msg = f"Must match pattern {self.pattern!r}, got {value!r}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -107,15 +109,15 @@ def __call__(self, value: str) -> str: """Validate value is a string.""" if not isinstance(value, str): msg = f"Must be string, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) if self.min_length is not None and len(value) < self.min_length: msg = f"String too short (min {self.min_length}), got {len(value)}" - raise ValueError(msg) + raise ConstraintViolationError(msg) if self.max_length is not None and len(value) > self.max_length: msg = f"String too long (max {self.max_length}), got {len(value)}" - raise ValueError(msg) + raise ConstraintViolationError(msg) return value @@ -128,7 +130,7 @@ def __call__(self, value: int) -> int: # isinstance(True, int) is True, so exclude bools explicitly if not isinstance(value, int) or isinstance(value, bool): msg = f"Must be int, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value @@ -140,7 +142,7 @@ def __call__(self, value: float) -> float: """Validate value is numeric.""" if not isinstance(value, (int, float)) or isinstance(value, bool): msg = f"Must be float or int, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return float(value) @@ -152,7 +154,7 @@ def __call__(self, value: bool) -> bool: """Validate value is a boolean.""" if not isinstance(value, bool): msg = f"Must be bool, got {type(value).__name__}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value @@ -167,13 +169,13 @@ def __call__(self, value: type[BaseModel]) -> type[BaseModel]: inner = get_args(value)[0] if not (isinstance(inner, type) and issubclass(inner, BaseModel)): msg = f"List type must be BaseModel, got {inner}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value # For plain type, validate directly if not (isinstance(value, type) and issubclass(value, BaseModel)): msg = f"Must be BaseModel, got {value}" - raise TypeError(msg) + raise ConstraintViolationError(msg) return value diff --git a/src/celeste/credentials.py b/src/celeste/credentials.py index a173598..124a001 100644 --- a/src/celeste/credentials.py +++ b/src/celeste/credentials.py @@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings from celeste.core import Provider +from celeste.exceptions import MissingCredentialsError # Provider to credential field mapping PROVIDER_CREDENTIAL_MAP = { @@ -61,15 +62,13 @@ def get_credentials( SecretStr containing the API key for the provider. Raises: - ValueError: If provider requires credentials but none are configured, - or if provider is not supported (no credential mapping). + MissingCredentialsError: If provider requires credentials but none are configured. """ if override_key: return override_key if not self.has_credential(provider): - msg = f"Provider {provider} has no credentials configured." - raise ValueError(msg) + raise MissingCredentialsError(provider=provider) credential: SecretStr = getattr(self, PROVIDER_CREDENTIAL_MAP[provider]) return credential diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py new file mode 100644 index 0000000..d05b17a --- /dev/null +++ b/src/celeste/exceptions.py @@ -0,0 +1,173 @@ +"""Custom exceptions for Celeste.""" + + +class Error(Exception): + """Base exception for all Celeste errors.""" + + pass + + +class ModelError(Error): + """Errors related to model operations and registry.""" + + pass + + +class ModelNotFoundError(ModelError): + """Raised when a requested model cannot be found.""" + + def __init__( + self, + model_id: str | None = None, + provider: str | None = None, + capability: str | None = None, + ) -> None: + """Initialize with model details. + + Args: + model_id: Optional specific model ID that was not found. + provider: Optional provider name. + capability: Optional capability name (used when no specific model_id). + """ + self.model_id = model_id + self.provider = provider + self.capability = capability + + # Generate appropriate error message based on available parameters + if model_id and provider: + msg = f"Model '{model_id}' not found for provider {provider}" + elif capability and provider: + msg = ( + f"No model found for capability '{capability}' with provider {provider}" + ) + elif capability: + msg = f"No model found for capability '{capability}'" + else: + msg = "Model not found" + + super().__init__(msg) + + +class CapabilityError(Error): + """Errors related to capability compatibility.""" + + pass + + +class UnsupportedCapabilityError(CapabilityError): + """Raised when a model doesn't support a requested capability.""" + + def __init__(self, model_id: str, capability: str) -> None: + """Initialize with model and capability details.""" + self.model_id = model_id + self.capability = capability + super().__init__( + f"Model '{model_id}' does not support capability '{capability}'" + ) + + +class ClientError(Error): + """Errors related to client operations.""" + + pass + + +class ClientNotFoundError(ClientError): + """Raised when no client is registered for a capability/provider combination.""" + + def __init__(self, capability: str, provider: str) -> None: + """Initialize with capability and provider details.""" + self.capability = capability + self.provider = provider + super().__init__( + f"No client registered for {capability} with provider {provider}" + ) + + +class StreamingError(Error): + """Errors related to streaming operations.""" + + pass + + +class StreamingNotSupportedError(StreamingError): + """Raised when streaming is requested for a model that doesn't support it.""" + + def __init__(self, model_id: str) -> None: + """Initialize with model details.""" + self.model_id = model_id + super().__init__(f"Streaming not supported for model '{model_id}'") + + +class StreamNotExhaustedError(StreamingError): + """Raised when accessing stream output before consuming all chunks.""" + + def __init__(self) -> None: + """Initialize with helpful message.""" + super().__init__( + "Stream not exhausted. Consume all chunks before accessing .output" + ) + + +class StreamEmptyError(StreamingError): + """Raised when a stream completes without producing any chunks.""" + + def __init__(self) -> None: + """Initialize with helpful message.""" + super().__init__("Stream completed but no chunks were produced") + + +class CredentialsError(Error): + """Errors related to API credentials.""" + + pass + + +class MissingCredentialsError(CredentialsError): + """Raised when required credentials are not configured.""" + + def __init__(self, provider: str) -> None: + """Initialize with provider details.""" + self.provider = provider + super().__init__( + f"Provider {provider} has no credentials configured. " + f"Set the appropriate environment variable or pass api_key parameter." + ) + + +class ValidationError(Error): + """Errors related to parameter and constraint validation.""" + + pass + + +class ConstraintViolationError(ValidationError): + """Raised when a value violates a constraint.""" + + pass + + +class UnsupportedParameterError(ValidationError): + """Raised when a parameter is not supported by a model.""" + + def __init__(self, parameter: str, model_id: str) -> None: + """Initialize with parameter and model details.""" + self.parameter = parameter + self.model_id = model_id + super().__init__( + f"Parameter '{parameter}' is not supported by model '{model_id}'" + ) + + +__all__ = [ + "ClientNotFoundError", + "ConstraintViolationError", + "Error", + "MissingCredentialsError", + "ModelNotFoundError", + "StreamEmptyError", + "StreamNotExhaustedError", + "StreamingNotSupportedError", + "UnsupportedCapabilityError", + "UnsupportedParameterError", +] diff --git a/src/celeste/http.py b/src/celeste/http.py index 756ad5e..dbd448c 100644 --- a/src/celeste/http.py +++ b/src/celeste/http.py @@ -65,7 +65,11 @@ async def post( Raises: httpx.HTTPError: On network or timeout errors. + ValueError: If URL is empty or invalid. """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty") + client = await self._get_client() return await client.post( url, @@ -94,7 +98,11 @@ async def get( Raises: httpx.HTTPError: On network or timeout errors. + ValueError: If URL is empty or invalid. """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty") + client = await self._get_client() return await client.get( url, diff --git a/src/celeste/parameters.py b/src/celeste/parameters.py index ff85861..2c2b791 100644 --- a/src/celeste/parameters.py +++ b/src/celeste/parameters.py @@ -4,6 +4,7 @@ from enum import StrEnum from typing import Any, TypedDict +from celeste.exceptions import UnsupportedParameterError from celeste.models import Model @@ -36,14 +37,20 @@ def parse_output(self, content: Any, value: object | None) -> object: # noqa: A return content def _validate_value(self, value: Any, model: Model) -> Any: # noqa: ANN401 - """Validate parameter value using model constraint, raising ValueError if no constraint exists.""" + """Validate parameter value using model constraint. + + Raises: + UnsupportedParameterError: If parameter is not supported by the model. + """ if value is None: return None constraint = model.parameter_constraints.get(self.name) if constraint is None: - msg = f"Parameter {self.name.value} is not supported by model {model.id}" - raise ValueError(msg) + raise UnsupportedParameterError( + parameter=self.name, + model_id=model.id, + ) return constraint(value) diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index 16233cc..cdb2a8e 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -5,6 +5,7 @@ from types import TracebackType from typing import Any, Self, Unpack +from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError from celeste.io import Chunk, Output from celeste.parameters import Parameters @@ -67,8 +68,7 @@ async def __anext__(self) -> Chunk: # Stream exhausted - validate and parse final output if not self._chunks: - msg = "Stream completed but no chunks were produced" - raise RuntimeError(msg) + raise StreamEmptyError() self._output = self._parse_output(self._chunks, **self._parameters) except Exception: @@ -96,10 +96,9 @@ async def __aexit__( @property def output(self) -> Out: - """Access final Output after stream exhaustion (raises RuntimeError if not ready).""" + """Access final Output after stream exhaustion (raises StreamNotExhaustedError if not ready).""" if self._output is None: - msg = "Stream not exhausted. Consume all chunks before accessing .output" - raise RuntimeError(msg) + raise StreamNotExhaustedError() return self._output async def aclose(self) -> None: diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 7f8317a..8bc96f5 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -6,10 +6,15 @@ import httpx import pytest -from pydantic import SecretStr, ValidationError +from pydantic import SecretStr from celeste.client import Client, _clients, get_client_class, register_client from celeste.core import Capability, Provider +from celeste.exceptions import ( + ClientNotFoundError, + StreamingNotSupportedError, + UnsupportedCapabilityError, +) from celeste.io import Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters @@ -234,8 +239,8 @@ def test_validation_failure_with_incompatible_capability( """Client rejects model that lacks required capability.""" # Arrange & Act & Assert with pytest.raises( - ValidationError, - match=rf"Model 'gpt-4' does not support capability {Capability.IMAGE_GENERATION}", + UnsupportedCapabilityError, + match=rf"Model 'gpt-4' does not support capability '{Capability.IMAGE_GENERATION}'", ): ConcreteClient( model=text_model, @@ -286,8 +291,8 @@ def test_validation_fails_with_model_lacking_any_capabilities( # Act & Assert with pytest.raises( - ValidationError, - match=rf"Model 'broken-model' does not support capability {Capability.TEXT_GENERATION}", + UnsupportedCapabilityError, + match=rf"Model 'broken-model' does not support capability '{Capability.TEXT_GENERATION}'", ): ConcreteClient( model=empty_model, @@ -328,14 +333,14 @@ def test_register_and_retrieve_client_success(self) -> None: assert retrieved_class is ConcreteClient def test_get_client_class_raises_for_unregistered_capability(self) -> None: - """get_client_class raises NotImplementedError for unregistered capabilities.""" + """get_client_class raises ClientNotFoundError for unregistered capabilities.""" # Arrange unregistered_capability = Capability.IMAGE_GENERATION provider = Provider.OPENAI # Act & Assert with pytest.raises( - NotImplementedError, + ClientNotFoundError, match=rf"No client registered for {Capability.IMAGE_GENERATION}", ): get_client_class(unregistered_capability, provider) @@ -400,9 +405,9 @@ def test_exception_message_includes_capability_and_provider( expected_capability_str: str, expected_provider_str: str, ) -> None: - """NotImplementedError includes both capability and provider for debugging.""" + """ClientNotFoundError includes both capability and provider for debugging.""" # Arrange & Act & Assert - with pytest.raises(NotImplementedError) as exc_info: + with pytest.raises(ClientNotFoundError) as exc_info: get_client_class(missing_capability, provider) # Assert both parts in error message @@ -529,7 +534,7 @@ class TestClientStreaming: def test_stream_raises_not_implemented_with_descriptive_error( self, text_model: Model, api_key: str ) -> None: - """stream() raises NotImplementedError with capability and provider info.""" + """stream() raises StreamingNotSupportedError with capability and provider info.""" # Arrange client = ConcreteClient( model=text_model, @@ -539,7 +544,7 @@ def test_stream_raises_not_implemented_with_descriptive_error( ) # Act & Assert - with pytest.raises(NotImplementedError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: client.stream("test prompt") # Verify error message contains all debugging info diff --git a/tests/unit_tests/test_constraints.py b/tests/unit_tests/test_constraints.py index 90f039a..4ce7448 100644 --- a/tests/unit_tests/test_constraints.py +++ b/tests/unit_tests/test_constraints.py @@ -3,6 +3,7 @@ import pytest from celeste.constraints import Bool, Choice, Float, Int, Pattern, Range, Str +from celeste.exceptions import ConstraintViolationError class TestChoice: @@ -18,11 +19,11 @@ def test_validates_value_in_options(self) -> None: assert result == "b" def test_rejects_value_not_in_options(self) -> None: - """Test that invalid choice raises ValueError.""" + """Test that invalid choice raises ConstraintViolationError.""" constraint = Choice[str](options=["a", "b", "c"]) with pytest.raises( - ValueError, match=r"Must be one of \['a', 'b', 'c'\], got 'd'" + ConstraintViolationError, match=r"Must be one of \['a', 'b', 'c'\], got 'd'" ): constraint("d") @@ -60,24 +61,28 @@ def test_validates_boundary_values(self) -> None: assert constraint(10) == 10 def test_rejects_value_below_min(self) -> None: - """Test that value below min raises ValueError.""" + """Test that value below min raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(ValueError, match=r"Must be between 0 and 10, got -1"): + with pytest.raises( + ConstraintViolationError, match=r"Must be between 0 and 10, got -1" + ): constraint(-1) def test_rejects_value_above_max(self) -> None: - """Test that value above max raises ValueError.""" + """Test that value above max raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(ValueError, match=r"Must be between 0 and 10, got 11"): + with pytest.raises( + ConstraintViolationError, match=r"Must be between 0 and 10, got 11" + ): constraint(11) def test_rejects_non_numeric_value(self) -> None: - """Test that non-numeric value raises TypeError.""" + """Test that non-numeric value raises ConstraintViolationError.""" constraint = Range(min=0, max=10) - with pytest.raises(TypeError, match=r"Must be numeric, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be numeric, got str"): constraint("5") # type: ignore[arg-type] def test_accepts_both_int_and_float(self) -> None: @@ -97,11 +102,11 @@ def test_validates_value_with_step(self) -> None: assert constraint(10) == 10 # max def test_rejects_value_not_on_step(self) -> None: - """Test that value not on step increment raises ValueError.""" + """Test that value not on step increment raises ConstraintViolationError.""" constraint = Range(min=0, max=10, step=2) with pytest.raises( - ValueError, + ConstraintViolationError, match=r"Value must match step 2(\.0)?. Nearest valid: 2(\.0)? or 4(\.0)?, got 3", ): constraint(3) @@ -126,7 +131,7 @@ def test_step_validation_with_non_zero_min(self) -> None: assert constraint(14) == 14 # min + 9 with pytest.raises( - ValueError, + ConstraintViolationError, match=r"Value must match step 3(\.0)?. Nearest valid: 5(\.0)? or 8(\.0)?, got 7", ): constraint(7) @@ -169,17 +174,17 @@ def test_validates_matching_pattern(self) -> None: assert result == "123-4567" def test_rejects_non_matching_pattern(self) -> None: - """Test that non-matching pattern raises ValueError.""" + """Test that non-matching pattern raises ConstraintViolationError.""" constraint = Pattern(pattern=r"^\d{3}-\d{4}$") - with pytest.raises(ValueError, match=r"Must match pattern"): + with pytest.raises(ConstraintViolationError, match=r"Must match pattern"): constraint("abc-defg") def test_rejects_non_string_value(self) -> None: - """Test that non-string value raises TypeError.""" + """Test that non-string value raises ConstraintViolationError.""" constraint = Pattern(pattern=r"^\d+$") - with pytest.raises(TypeError, match=r"Must be string, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be string, got int"): constraint(123) # type: ignore[arg-type] def test_validates_complex_regex_patterns(self) -> None: @@ -213,24 +218,28 @@ def test_validates_string_within_length_bounds(self) -> None: assert result == "valid" def test_rejects_string_below_min_length(self) -> None: - """Test string shorter than min_length raises ValueError.""" + """Test string shorter than min_length raises ConstraintViolationError.""" constraint = Str(min_length=5) - with pytest.raises(ValueError, match=r"String too short \(min 5\), got 3"): + with pytest.raises( + ConstraintViolationError, match=r"String too short \(min 5\), got 3" + ): constraint("abc") def test_rejects_string_above_max_length(self) -> None: - """Test string longer than max_length raises ValueError.""" + """Test string longer than max_length raises ConstraintViolationError.""" constraint = Str(max_length=5) - with pytest.raises(ValueError, match=r"String too long \(max 5\), got 10"): + with pytest.raises( + ConstraintViolationError, match=r"String too long \(max 5\), got 10" + ): constraint("too long!!") def test_rejects_non_string_value(self) -> None: - """Test non-string value raises TypeError.""" + """Test non-string value raises ConstraintViolationError.""" constraint = Str() - with pytest.raises(TypeError, match=r"Must be string, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be string, got int"): constraint(123) # type: ignore[arg-type] def test_validates_boundary_lengths(self) -> None: @@ -254,24 +263,24 @@ def test_validates_integer_value(self) -> None: assert result == 42 def test_rejects_float_value(self) -> None: - """Test that float raises TypeError.""" + """Test that float raises ConstraintViolationError.""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got float"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got float"): constraint(42.0) # type: ignore[arg-type] def test_rejects_boolean_value(self) -> None: - """Test that bool raises TypeError despite isinstance(True, int).""" + """Test that bool raises ConstraintViolationError despite isinstance(True, int).""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got bool"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got bool"): constraint(True) def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Int() - with pytest.raises(TypeError, match=r"Must be int, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be int, got str"): constraint("42") # type: ignore[arg-type] @@ -297,17 +306,21 @@ def test_accepts_and_converts_int_to_float(self) -> None: assert isinstance(result, float) def test_rejects_boolean_value(self) -> None: - """Test that bool raises TypeError despite isinstance(True, int).""" + """Test that bool raises ConstraintViolationError despite isinstance(True, int).""" constraint = Float() - with pytest.raises(TypeError, match=r"Must be float or int, got bool"): + with pytest.raises( + ConstraintViolationError, match=r"Must be float or int, got bool" + ): constraint(True) def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Float() - with pytest.raises(TypeError, match=r"Must be float or int, got str"): + with pytest.raises( + ConstraintViolationError, match=r"Must be float or int, got str" + ): constraint("3.14") # type: ignore[arg-type] @@ -323,15 +336,15 @@ def test_validates_boolean_value(self) -> None: assert constraint(False) is False def test_rejects_int_value(self) -> None: - """Test that int raises TypeError (no implicit 0/1 conversion).""" + """Test that int raises ConstraintViolationError (no implicit 0/1 conversion).""" constraint = Bool() - with pytest.raises(TypeError, match=r"Must be bool, got int"): + with pytest.raises(ConstraintViolationError, match=r"Must be bool, got int"): constraint(1) # type: ignore[arg-type] def test_rejects_string_value(self) -> None: - """Test that string raises TypeError.""" + """Test that string raises ConstraintViolationError.""" constraint = Bool() - with pytest.raises(TypeError, match=r"Must be bool, got str"): + with pytest.raises(ConstraintViolationError, match=r"Must be bool, got str"): constraint("true") # type: ignore[arg-type] diff --git a/tests/unit_tests/test_credentials.py b/tests/unit_tests/test_credentials.py index 573f0fa..7c537f8 100644 --- a/tests/unit_tests/test_credentials.py +++ b/tests/unit_tests/test_credentials.py @@ -9,6 +9,7 @@ from celeste.core import Provider from celeste.credentials import PROVIDER_CREDENTIAL_MAP, Credentials +from celeste.exceptions import MissingCredentialsError # Single source of truth for environment variable names ENV_VAR_NAMES = [ @@ -137,7 +138,7 @@ def test_get_missing_credential_raises( creds = Credentials() # type: ignore[call-arg] # Act & Assert - with pytest.raises(ValueError, match="no credentials configured"): + with pytest.raises(MissingCredentialsError, match="no credentials configured"): creds.get_credentials(Provider.OPENAI) @pytest.mark.parametrize( diff --git a/tests/unit_tests/test_exceptions.py b/tests/unit_tests/test_exceptions.py new file mode 100644 index 0000000..b88ade6 --- /dev/null +++ b/tests/unit_tests/test_exceptions.py @@ -0,0 +1,243 @@ +"""Tests for custom exception classes.""" + +import pytest + +from celeste.exceptions import ( + ClientNotFoundError, + ConstraintViolationError, + Error, + MissingCredentialsError, + ModelNotFoundError, + StreamEmptyError, + StreamingNotSupportedError, + StreamNotExhaustedError, + UnsupportedCapabilityError, + UnsupportedParameterError, +) + + +class TestExceptionHierarchy: + """Test exception hierarchy and inheritance.""" + + def test_all_exceptions_inherit_from_error(self) -> None: + """Test that all custom exceptions inherit from Error.""" + exceptions = [ + ModelNotFoundError("model-1", "openai"), + UnsupportedCapabilityError("model-1", "text-generation"), + ClientNotFoundError("text-generation", "openai"), + StreamingNotSupportedError("model-1"), + StreamNotExhaustedError(), + StreamEmptyError(), + MissingCredentialsError("openai"), + UnsupportedParameterError("temperature", "model-1"), + ConstraintViolationError("Must be between 0 and 1, got 2"), + ] + + for exc in exceptions: + assert isinstance(exc, Error) + assert isinstance(exc, Exception) + + +class TestModelNotFoundError: + """Test ModelNotFoundError exception.""" + + def test_creates_with_model_and_provider(self) -> None: + """Test exception stores model and provider attributes.""" + exc = ModelNotFoundError(model_id="gpt-4", provider="openai") + + assert exc.model_id == "gpt-4" + assert exc.provider == "openai" + assert exc.capability is None + assert "gpt-4" in str(exc) + assert "openai" in str(exc) + + def test_message_is_descriptive_with_model_and_provider(self) -> None: + """Test exception message is clear and actionable for specific model.""" + exc = ModelNotFoundError(model_id="claude-3", provider="anthropic") + + assert str(exc) == "Model 'claude-3' not found for provider anthropic" + + def test_creates_with_capability_only(self) -> None: + """Test exception with capability only (no models available for capability).""" + exc = ModelNotFoundError(capability="text-generation") + + assert exc.model_id is None + assert exc.provider is None + assert exc.capability == "text-generation" + assert "No model found for capability 'text-generation'" in str(exc) + + def test_creates_with_capability_and_provider(self) -> None: + """Test exception with capability and provider (no models for capability/provider combo).""" + exc = ModelNotFoundError(capability="text-generation", provider="openai") + + assert exc.model_id is None + assert exc.provider == "openai" + assert exc.capability == "text-generation" + assert ( + "No model found for capability 'text-generation' with provider openai" + in str(exc) + ) + + def test_backwards_compatibility_positional_args(self) -> None: + """Test that positional arguments still work for backwards compatibility.""" + exc = ModelNotFoundError("gpt-4", "openai") + + assert exc.model_id == "gpt-4" + assert exc.provider == "openai" + assert "gpt-4" in str(exc) + + +class TestUnsupportedCapabilityError: + """Test UnsupportedCapabilityError exception.""" + + def test_creates_with_model_and_capability(self) -> None: + """Test exception stores model and capability attributes.""" + exc = UnsupportedCapabilityError("gpt-4", "image-generation") + + assert exc.model_id == "gpt-4" + assert exc.capability == "image-generation" + assert "gpt-4" in str(exc) + assert "image-generation" in str(exc) + + def test_message_is_descriptive(self) -> None: + """Test exception message is clear and actionable.""" + exc = UnsupportedCapabilityError("gpt-3.5-turbo", "video-generation") + + assert ( + str(exc) + == "Model 'gpt-3.5-turbo' does not support capability 'video-generation'" + ) + + +class TestClientNotFoundError: + """Test ClientNotFoundError exception.""" + + def test_creates_with_capability_and_provider(self) -> None: + """Test exception stores capability and provider attributes.""" + exc = ClientNotFoundError("text-generation", "unknown_provider") + + assert exc.capability == "text-generation" + assert exc.provider == "unknown_provider" + assert "text-generation" in str(exc) + assert "unknown_provider" in str(exc) + + +class TestStreamingNotSupportedError: + """Test StreamingNotSupportedError exception.""" + + def test_creates_with_model_id(self) -> None: + """Test exception stores model_id attribute.""" + exc = StreamingNotSupportedError("dall-e-3") + + assert exc.model_id == "dall-e-3" + assert "dall-e-3" in str(exc) + assert "Streaming not supported" in str(exc) + + +class TestStreamNotExhaustedError: + """Test StreamNotExhaustedError exception.""" + + def test_has_helpful_message(self) -> None: + """Test exception message guides user to consume chunks first.""" + exc = StreamNotExhaustedError() + + message = str(exc) + assert "not exhausted" in message.lower() + assert "consume all chunks" in message.lower() + assert ".output" in message + + +class TestStreamEmptyError: + """Test StreamEmptyError exception.""" + + def test_has_descriptive_message(self) -> None: + """Test exception message describes the problem clearly.""" + exc = StreamEmptyError() + + message = str(exc) + assert "completed" in message.lower() + assert "no chunks" in message.lower() + + +class TestMissingCredentialsError: + """Test MissingCredentialsError exception.""" + + def test_creates_with_provider(self) -> None: + """Test exception stores provider attribute.""" + exc = MissingCredentialsError("openai") + + assert exc.provider == "openai" + assert "openai" in str(exc) + + def test_message_provides_guidance(self) -> None: + """Test exception message helps user resolve the issue.""" + exc = MissingCredentialsError("anthropic") + + message = str(exc) + assert "anthropic" in message + assert "environment variable" in message.lower() + assert "api_key" in message.lower() + + +class TestUnsupportedParameterError: + """Test UnsupportedParameterError exception.""" + + def test_creates_with_parameter_and_model(self) -> None: + """Test exception stores parameter and model_id attributes.""" + exc = UnsupportedParameterError("temperature", "dall-e-3") + + assert exc.parameter == "temperature" + assert exc.model_id == "dall-e-3" + assert "temperature" in str(exc) + assert "dall-e-3" in str(exc) + + def test_message_is_clear(self) -> None: + """Test exception message clearly indicates the problem.""" + exc = UnsupportedParameterError("max_tokens", "whisper-1") + + assert ( + str(exc) == "Parameter 'max_tokens' is not supported by model 'whisper-1'" + ) + + +class TestConstraintViolationError: + """Test ConstraintViolationError exception.""" + + def test_creates_with_message(self) -> None: + """Test exception can be created with a message.""" + exc = ConstraintViolationError("Must be between 0 and 1, got 2") + + assert "Must be between 0 and 1" in str(exc) + assert "got 2" in str(exc) + + def test_inherits_from_validation_error(self) -> None: + """Test ConstraintViolationError inherits from ValidationError.""" + from celeste.exceptions import ValidationError + + exc = ConstraintViolationError("Test message") + assert isinstance(exc, ValidationError) + assert isinstance(exc, Error) + + +class TestExceptionUsability: + """Test that exceptions can be raised and caught properly.""" + + def test_can_catch_specific_exception(self) -> None: + """Test specific exception types can be caught.""" + with pytest.raises(ModelNotFoundError) as exc_info: + raise ModelNotFoundError("test-model", "test-provider") + + assert exc_info.value.model_id == "test-model" + + def test_can_catch_base_exception(self) -> None: + """Test all custom exceptions can be caught as Error.""" + with pytest.raises(Error): + raise StreamEmptyError() + + def test_can_access_exception_attributes(self) -> None: + """Test exception attributes are accessible after catching.""" + try: + raise UnsupportedParameterError("seed", "gpt-4") + except UnsupportedParameterError as e: + assert e.parameter == "seed" + assert e.model_id == "gpt-4" diff --git a/tests/unit_tests/test_init.py b/tests/unit_tests/test_init.py index 4cd4e81..e4affcd 100644 --- a/tests/unit_tests/test_init.py +++ b/tests/unit_tests/test_init.py @@ -6,6 +6,7 @@ from pydantic import SecretStr from celeste import Capability, Model, Provider, create_client +from celeste.exceptions import ModelNotFoundError @pytest.fixture @@ -37,14 +38,15 @@ class TestCreateClient: """Test the create_client factory function.""" def test_create_client_no_models_available_raises_error(self) -> None: - """Test that create_client raises ValueError when no models are available.""" + """Test that create_client raises ModelNotFoundError when no models are available.""" with patch("celeste.list_models", autospec=True) as mock_list_models: # Arrange mock_list_models.return_value = [] # Act & Assert with pytest.raises( - ValueError, match=rf"No model found for.*{Capability.TEXT_GENERATION}" + ModelNotFoundError, + match=rf"No model found for capability.*{Capability.TEXT_GENERATION}", ): create_client(capability=Capability.TEXT_GENERATION) @@ -55,7 +57,7 @@ def test_create_client_specific_model_not_found_raises_error(self) -> None: mock_get_model.return_value = None # Act & Assert - with pytest.raises(ValueError, match=r"Model.*not found"): + with pytest.raises(ModelNotFoundError, match=r"Model.*not found"): create_client( capability=Capability.TEXT_GENERATION, provider=Provider.OPENAI, diff --git a/tests/unit_tests/test_streaming.py b/tests/unit_tests/test_streaming.py index bd60c71..d6af270 100644 --- a/tests/unit_tests/test_streaming.py +++ b/tests/unit_tests/test_streaming.py @@ -7,6 +7,7 @@ import pytest from pydantic import Field +from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError from celeste.io import Chunk, FinishReason, Output, Usage from celeste.parameters import Parameters from celeste.streaming import Stream @@ -155,14 +156,14 @@ class TestStreamOutputProperty: """Test Stream.output property - access guard and final result.""" async def test_output_access_before_exhaustion_raises_error(self) -> None: - """Accessing .output before stream exhaustion must raise RuntimeError.""" + """Accessing .output before stream exhaustion must raise StreamNotExhaustedError.""" # Arrange events = [{"delta": "test"}] stream = ConcreteStream(_async_iter(events)) - # Act & Assert - Premature access raises RuntimeError + # Act & Assert - Premature access raises StreamNotExhaustedError with pytest.raises( - RuntimeError, match=r"Stream not exhausted\. Consume all chunks" + StreamNotExhaustedError, match=r"Stream not exhausted\. Consume all chunks" ): _ = stream.output @@ -478,7 +479,7 @@ class TestStreamEmptyStreamError: """Test Stream empty stream error handling.""" async def test_empty_stream_raises_runtime_error(self) -> None: - """Stream must raise RuntimeError when exhausted with no chunks.""" + """Stream must raise StreamEmptyError when exhausted with no chunks.""" # Arrange - Create stream where all events return None from _parse_chunk async def empty_iter() -> AsyncIterator[dict[str, Any]]: @@ -487,15 +488,15 @@ async def empty_iter() -> AsyncIterator[dict[str, Any]]: stream = ConcreteStream(empty_iter()) - # Act & Assert - Exhaustion raises RuntimeError + # Act & Assert - Exhaustion raises StreamEmptyError with pytest.raises( - RuntimeError, match=r"Stream completed but no chunks were produced" + StreamEmptyError, match=r"Stream completed but no chunks were produced" ): async for _ in stream: pass async def test_stream_with_only_lifecycle_events_raises_error(self) -> None: - """Stream raises RuntimeError when SSE yields events but all chunks are filtered to None.""" + """Stream raises StreamEmptyError when SSE yields events but all chunks are filtered to None.""" # Arrange - Events that all return None from _parse_chunk events = [ {"type": "ping"}, # Lifecycle event (no delta/content) @@ -505,9 +506,9 @@ async def test_stream_with_only_lifecycle_events_raises_error(self) -> None: ] stream = ConcreteStream(_async_iter(events)) - # Act & Assert - Should raise RuntimeError when exhausted + # Act & Assert - Should raise StreamEmptyError when exhausted with pytest.raises( - RuntimeError, match=r"Stream completed but no chunks were produced" + StreamEmptyError, match=r"Stream completed but no chunks were produced" ): async for _ in stream: pass