Skip to content
37 changes: 31 additions & 6 deletions src/celeste/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@
from celeste.client import Client, get_client_class, register_client
from celeste.core import Capability, Parameter, Provider
from celeste.credentials import credentials
from celeste.exceptions import (
ClientNotFoundError,
ConstraintViolationError,
Error,
MissingCredentialsError,
ModelNotFoundError,
StreamEmptyError,
StreamingNotSupportedError,
StreamNotExhaustedError,
UnsupportedCapabilityError,
UnsupportedParameterError,
)
from celeste.http import HTTPClient, close_all_http_clients
from celeste.io import Input, Output, Usage
from celeste.models import Model, get_model, list_models, register_models
Expand All @@ -24,8 +36,10 @@ def _resolve_model(
# Auto-select first available model
models = list_models(provider=provider, capability=capability)
if not models:
msg = f"No model found for {capability}"
raise ValueError(msg)
raise ModelNotFoundError(
capability=capability,
provider=provider if provider else None,
)
return models[0]

if isinstance(model, str):
Expand All @@ -35,8 +49,7 @@ def _resolve_model(
raise ValueError(msg)
found = get_model(model, provider)
if not found:
msg = f"Model '{model}' not found for provider {provider}"
raise ValueError(msg)
raise ModelNotFoundError(model_id=model, provider=provider)
return found

return model
Expand All @@ -60,8 +73,10 @@ def create_client(
Configured client instance ready for generation operations.

Raises:
ValueError: If no model found or resolution fails.
NotImplementedError: If no client registered for capability/provider.
ModelNotFoundError: If no model found for the specified capability/provider.
ClientNotFoundError: If no client registered for capability/provider.
MissingCredentialsError: If required credentials are not configured.
UnsupportedCapabilityError: If the resolved model doesn't support the requested capability.
"""
# Resolve model
resolved_model = _resolve_model(capability, provider, model)
Expand Down Expand Up @@ -98,13 +113,23 @@ def _load_from_entry_points() -> None:
__all__ = [
"Capability",
"Client",
"ClientNotFoundError",
"ConstraintViolationError",
"Error",
"HTTPClient",
"Input",
"MissingCredentialsError",
"Model",
"ModelNotFoundError",
"Output",
"Parameter",
"Parameters",
"Provider",
"StreamEmptyError",
"StreamNotExhaustedError",
"StreamingNotSupportedError",
"UnsupportedCapabilityError",
"UnsupportedParameterError",
"Usage",
"close_all_http_clients",
"create_client",
Expand Down
26 changes: 16 additions & 10 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr

from celeste.core import Capability, Provider
from celeste.exceptions import (
ClientNotFoundError,
StreamingNotSupportedError,
UnsupportedCapabilityError,
)
from celeste.http import HTTPClient, get_http_client
from celeste.io import Input, Output, Usage
from celeste.models import Model
Expand All @@ -29,8 +34,9 @@ class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel):
def model_post_init(self, __context: object) -> None:
"""Validate capability compatibility."""
if self.capability not in self.model.capabilities:
raise ValueError(
f"Model '{self.model.id}' does not support capability {self.capability}"
raise UnsupportedCapabilityError(
model_id=self.model.id,
capability=self.capability,
)

@property
Expand Down Expand Up @@ -78,11 +84,10 @@ def stream(
Stream yielding chunks and providing final Output.

Raises:
NotImplementedError: If model doesn't support streaming.
StreamingNotSupportedError: If model doesn't support streaming.
"""
if not self.model.streaming:
msg = f"Streaming not supported for model '{self.model.id}'"
raise NotImplementedError(msg)
raise StreamingNotSupportedError(model_id=self.model.id)

inputs = self._create_inputs(*args, **parameters)
request_body = self._build_request(inputs, **parameters)
Expand Down Expand Up @@ -160,7 +165,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Build metadata dictionary from response data."""
return {
"model": self.model.id,
"provider": self.provider.value,
"provider": self.provider,
}

def _handle_error_response(self, response: httpx.Response) -> None:
Expand All @@ -173,7 +178,7 @@ def _handle_error_response(self, response: httpx.Response) -> None:
error_msg = response.text or f"HTTP {response.status_code}"

raise httpx.HTTPStatusError(
f"{self.provider.value} API error: {error_msg}",
f"{self.provider} API error: {error_msg}",
request=response.request,
response=response,
)
Expand Down Expand Up @@ -236,11 +241,12 @@ def get_client_class(
The registered client class.

Raises:
NotImplementedError: If no client is registered for this capability/provider.
ClientNotFoundError: If no client is registered for this capability/provider.
"""
if (capability, provider) not in _clients:
raise NotImplementedError(
f"No client registered for {capability} with provider {provider}"
raise ClientNotFoundError(
capability=capability,
provider=provider,
)
return _clients[(capability, provider)]

Expand Down
30 changes: 16 additions & 14 deletions src/celeste/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from pydantic import BaseModel, Field

from celeste.exceptions import ConstraintViolationError


class Constraint(BaseModel, ABC):
"""Base constraint for parameter validation."""
Expand All @@ -26,7 +28,7 @@ def __call__(self, value: T) -> T:
"""Validate value is in options."""
if value not in self.options:
msg = f"Must be one of {self.options}, got {value!r}"
raise ValueError(msg)
raise ConstraintViolationError(msg)
return value


Expand All @@ -46,7 +48,7 @@ def __call__(self, value: float | int) -> float | int:
"""Validate value is within range and matches step increment."""
if not isinstance(value, (int, float)):
msg = f"Must be numeric, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

# Check if value is a special value that bypasses range check
if self.special_values is not None and value in self.special_values:
Expand All @@ -58,7 +60,7 @@ def __call__(self, value: float | int) -> float | int:
f" or one of {self.special_values}" if self.special_values else ""
)
msg = f"Must be between {self.min} and {self.max}{special_msg}, got {value}"
raise ValueError(msg)
raise ConstraintViolationError(msg)

# Validate step if provided
if self.step is not None:
Expand All @@ -74,7 +76,7 @@ def __call__(self, value: float | int) -> float | int:
)
closest_above = closest_below + self.step
msg = f"Value must match step {self.step}. Nearest valid: {closest_below} or {closest_above}, got {value}"
raise ValueError(msg)
raise ConstraintViolationError(msg)

return value

Expand All @@ -88,11 +90,11 @@ def __call__(self, value: str) -> str:
"""Validate value matches pattern."""
if not isinstance(value, str):
msg = f"Must be string, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

if not re.fullmatch(self.pattern, value):
msg = f"Must match pattern {self.pattern!r}, got {value!r}"
raise ValueError(msg)
raise ConstraintViolationError(msg)

return value

Expand All @@ -107,15 +109,15 @@ def __call__(self, value: str) -> str:
"""Validate value is a string."""
if not isinstance(value, str):
msg = f"Must be string, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

if self.min_length is not None and len(value) < self.min_length:
msg = f"String too short (min {self.min_length}), got {len(value)}"
raise ValueError(msg)
raise ConstraintViolationError(msg)

if self.max_length is not None and len(value) > self.max_length:
msg = f"String too long (max {self.max_length}), got {len(value)}"
raise ValueError(msg)
raise ConstraintViolationError(msg)

return value

Expand All @@ -128,7 +130,7 @@ def __call__(self, value: int) -> int:
# isinstance(True, int) is True, so exclude bools explicitly
if not isinstance(value, int) or isinstance(value, bool):
msg = f"Must be int, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

return value

Expand All @@ -140,7 +142,7 @@ def __call__(self, value: float) -> float:
"""Validate value is numeric."""
if not isinstance(value, (int, float)) or isinstance(value, bool):
msg = f"Must be float or int, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

return float(value)

Expand All @@ -152,7 +154,7 @@ def __call__(self, value: bool) -> bool:
"""Validate value is a boolean."""
if not isinstance(value, bool):
msg = f"Must be bool, got {type(value).__name__}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

return value

Expand All @@ -167,13 +169,13 @@ def __call__(self, value: type[BaseModel]) -> type[BaseModel]:
inner = get_args(value)[0]
if not (isinstance(inner, type) and issubclass(inner, BaseModel)):
msg = f"List type must be BaseModel, got {inner}"
raise TypeError(msg)
raise ConstraintViolationError(msg)
return value

# For plain type, validate directly
if not (isinstance(value, type) and issubclass(value, BaseModel)):
msg = f"Must be BaseModel, got {value}"
raise TypeError(msg)
raise ConstraintViolationError(msg)

return value

Expand Down
7 changes: 3 additions & 4 deletions src/celeste/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic_settings import BaseSettings

from celeste.core import Provider
from celeste.exceptions import MissingCredentialsError

# Provider to credential field mapping
PROVIDER_CREDENTIAL_MAP = {
Expand Down Expand Up @@ -61,15 +62,13 @@ def get_credentials(
SecretStr containing the API key for the provider.

Raises:
ValueError: If provider requires credentials but none are configured,
or if provider is not supported (no credential mapping).
MissingCredentialsError: If provider requires credentials but none are configured.
"""
if override_key:
return override_key

if not self.has_credential(provider):
msg = f"Provider {provider} has no credentials configured."
raise ValueError(msg)
raise MissingCredentialsError(provider=provider)

credential: SecretStr = getattr(self, PROVIDER_CREDENTIAL_MAP[provider])
return credential
Expand Down
Loading
Loading