diff --git a/Makefile b/Makefile index 8be072a..a412fc1 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,8 @@ help: @echo " make format - Apply Ruff formatting" @echo " make typecheck - Run mypy type checking" @echo " make test - Run all tests (core + packages) with coverage" - @echo " make integration-test - Run integration tests (requires API keys)" + @echo " make integration-test [capability] - Run integration tests (all or specific)" + @echo " (e.g., make integration-test image-intelligence)" @echo " make security - Run Bandit security scan" @echo " make ci - Run full CI/CD pipeline" @echo " make clean - Clean cache directories" @@ -33,15 +34,22 @@ format: # Type checking (fail fast on any error) typecheck: - @uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/capabilities/image-generation packages/capabilities/text-generation packages/capabilities/video-generation packages/capabilities/speech-generation + @uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/*/*/src/ # Testing test: - uv run pytest tests/unit_tests packages/capabilities/*/tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=80 -v + uv run pytest tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=80 -v # Integration testing (requires API keys) +# Usage: make integration-test [capability] integration-test: - uv run pytest packages/capabilities/*/tests/integration_tests/ -m integration -v --dist=worksteal -n auto + @cap="$(filter-out $@,$(MAKECMDGOALS))"; \ + if [ -z "$$cap" ]; then cap="*"; fi; \ + uv run pytest packages/capabilities/$$cap/tests/integration_tests/ -m integration -v --dist=worksteal -n auto + +# Catch capability names as no-op targets +%: + @: # Security scanning (config reads from pyproject.toml) security: diff --git a/packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/client.py b/packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/client.py index 92cd1b5..cbdbcab 100644 --- a/packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/client.py +++ b/packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/client.py @@ -78,7 +78,7 @@ async def _make_request( ) -> httpx.Response: """Make HTTP request(s) and return response object.""" headers = { - config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, "Accept": ApplicationMimeType.JSON, } @@ -103,7 +103,7 @@ async def _make_request( start_time = time.monotonic() poll_headers = { - config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + **self.auth.get_headers(), "Accept": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/image-generation/src/celeste_image_generation/providers/bytedance/client.py b/packages/capabilities/image-generation/src/celeste_image_generation/providers/bytedance/client.py index f86ced4..f4cc752 100644 --- a/packages/capabilities/image-generation/src/celeste_image_generation/providers/bytedance/client.py +++ b/packages/capabilities/image-generation/src/celeste_image_generation/providers/bytedance/client.py @@ -119,7 +119,7 @@ async def _make_request( request_body["stream"] = False headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -142,7 +142,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/image-generation/src/celeste_image_generation/providers/google/client.py b/packages/capabilities/image-generation/src/celeste_image_generation/providers/google/client.py index ab5a40a..faca14f 100644 --- a/packages/capabilities/image-generation/src/celeste_image_generation/providers/google/client.py +++ b/packages/capabilities/image-generation/src/celeste_image_generation/providers/google/client.py @@ -28,7 +28,7 @@ class GoogleImageGenerationClient(ImageGenerationClient): model_config = ConfigDict(extra="allow") - def model_post_init(self, __context: Any) -> None: # noqa: ANN401 + def model_post_init(self, __context: Any) -> None: """Initialize API adapter based on model type.""" super().model_post_init(__context) @@ -103,7 +103,7 @@ async def _make_request( ) -> httpx.Response: """Make HTTP request(s) and return response object.""" headers = { - config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/image-generation/src/celeste_image_generation/providers/openai/client.py b/packages/capabilities/image-generation/src/celeste_image_generation/providers/openai/client.py index 226af15..69a8bd8 100644 --- a/packages/capabilities/image-generation/src/celeste_image_generation/providers/openai/client.py +++ b/packages/capabilities/image-generation/src/celeste_image_generation/providers/openai/client.py @@ -94,7 +94,7 @@ async def _make_request( ) -> httpx.Response: """Make HTTP request(s) and return response object.""" headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -120,7 +120,7 @@ def _make_stream_request( request_body["partial_images"] = 1 headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/image-generation/tests/unit_tests/providers/google/test_finish_reason.py b/packages/capabilities/image-generation/tests/unit_tests/providers/google/test_finish_reason.py index 466f608..7943df1 100644 --- a/packages/capabilities/image-generation/tests/unit_tests/providers/google/test_finish_reason.py +++ b/packages/capabilities/image-generation/tests/unit_tests/providers/google/test_finish_reason.py @@ -6,6 +6,7 @@ from celeste_image_generation.providers.google.client import GoogleImageGenerationClient from pydantic import SecretStr +from celeste.auth import APIKey from celeste.core import Capability, Provider from celeste.models import Model @@ -25,7 +26,7 @@ def client(self) -> GoogleImageGenerationClient: ), provider=Provider.GOOGLE, capability=Capability.IMAGE_GENERATION, - api_key=SecretStr("test-key"), + auth=APIKey(key=SecretStr("test-key")), ) @pytest.mark.parametrize( diff --git a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/elevenlabs/client.py b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/elevenlabs/client.py index 1c3ae0b..8324bc2 100644 --- a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/elevenlabs/client.py +++ b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/elevenlabs/client.py @@ -88,7 +88,7 @@ async def _make_request( endpoint = config.ENDPOINT.format(voice_id=voice_id) headers = { - config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -152,7 +152,7 @@ def _make_stream_request( stream_endpoint = config.STREAM_ENDPOINT.format(voice_id=voice_id) headers = { - config.AUTH_HEADER_NAME: self.api_key.get_secret_value(), + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/google/client.py b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/google/client.py index 3ccfee9..612fd28 100644 --- a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/google/client.py +++ b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/google/client.py @@ -109,7 +109,7 @@ async def _make_request( endpoint = config.ENDPOINT.format(model_id=self.model.id) headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/openai/client.py b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/openai/client.py index 3465fc3..f968580 100644 --- a/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/openai/client.py +++ b/packages/capabilities/speech-generation/src/celeste_speech_generation/providers/openai/client.py @@ -74,7 +74,7 @@ async def _make_request( request_body["model"] = self.model.id headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/anthropic/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/anthropic/client.py index 6436750..66c6210 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/anthropic/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/anthropic/client.py @@ -100,7 +100,7 @@ async def _make_request( request_body["max_tokens"] = parameters.get("max_tokens") or 1024 headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION, "Content-Type": ApplicationMimeType.JSON, } @@ -129,7 +129,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION, "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/cohere/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/cohere/client.py index 631d611..6ec85c9 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/cohere/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/cohere/client.py @@ -105,7 +105,7 @@ async def _make_request( request_body["model"] = self.model.id headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -129,7 +129,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/google/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/google/client.py index d547464..73cc955 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/google/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/google/client.py @@ -108,7 +108,7 @@ async def _make_request( endpoint = config.ENDPOINT.format(model_id=self.model.id) headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -131,7 +131,7 @@ def _make_stream_request( stream_endpoint = config.STREAM_ENDPOINT.format(model_id=self.model.id) headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/mistral/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/mistral/client.py index 669c36c..bcb2ed6 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/mistral/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/mistral/client.py @@ -108,7 +108,7 @@ async def _make_request( request_body["model"] = self.model.id headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -132,7 +132,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/openai/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/openai/client.py index 3621c3a..ea9fedc 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/openai/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/openai/client.py @@ -115,7 +115,7 @@ async def _make_request( request_body["model"] = self.model.id headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -139,7 +139,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/text-generation/src/celeste_text_generation/providers/xai/client.py b/packages/capabilities/text-generation/src/celeste_text_generation/providers/xai/client.py index a4ef602..ea0530f 100644 --- a/packages/capabilities/text-generation/src/celeste_text_generation/providers/xai/client.py +++ b/packages/capabilities/text-generation/src/celeste_text_generation/providers/xai/client.py @@ -103,7 +103,7 @@ async def _make_request( request_body["model"] = self.model.id headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -127,7 +127,7 @@ def _make_stream_request( request_body["stream"] = True headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/bytedance/client.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/bytedance/client.py index 322c137..198c149 100644 --- a/packages/capabilities/video-generation/src/celeste_video_generation/providers/bytedance/client.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/providers/bytedance/client.py @@ -180,7 +180,7 @@ async def _make_request( ) -> httpx.Response: """Make HTTP request with async polling.""" headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py index 4e17ddf..b950d63 100644 --- a/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/providers/google/client.py @@ -128,7 +128,7 @@ async def _make_request( url = f"{config.BASE_URL}{endpoint}" headers = { - "x-goog-api-key": self.api_key.get_secret_value(), + **self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON, } @@ -147,7 +147,7 @@ async def _make_request( logger.info(f"Video generation started: {operation_name}") poll_url = f"{config.BASE_URL}{config.POLL_ENDPOINT.format(operation_name=operation_name)}" - poll_headers = {"x-goog-api-key": self.api_key.get_secret_value()} + poll_headers = self.auth.get_headers() while True: await asyncio.sleep(config.POLL_INTERVAL) @@ -196,7 +196,7 @@ async def download_content(self, artifact: VideoArtifact) -> VideoArtifact: logger.info(f"Downloading video from: {download_url}") - headers = {"x-goog-api-key": self.api_key.get_secret_value()} + headers = self.auth.get_headers() response = await self.http_client.get( download_url, diff --git a/packages/capabilities/video-generation/src/celeste_video_generation/providers/openai/client.py b/packages/capabilities/video-generation/src/celeste_video_generation/providers/openai/client.py index b920853..e69b270 100644 --- a/packages/capabilities/video-generation/src/celeste_video_generation/providers/openai/client.py +++ b/packages/capabilities/video-generation/src/celeste_video_generation/providers/openai/client.py @@ -149,9 +149,7 @@ async def _make_request( **parameters: Unpack[VideoGenerationParameters], ) -> httpx.Response: """Make HTTP request with async polling for OpenAI video generation.""" - headers = { - config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}", - } + headers = self.auth.get_headers() files, data = await self._prepare_multipart_request(request_body.copy()) diff --git a/pyproject.toml b/pyproject.toml index 6c3a5fe..9df91a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,10 @@ dev = [ ] [tool.uv.workspace] -members = ["packages/capabilities/*"] +members = [ + "packages/providers/*", + "packages/capabilities/*", +] [tool.uv.sources] celeste-text-generation = { workspace = true } @@ -135,6 +138,7 @@ known-first-party = ["celeste"] [tool.ruff.lint.per-file-ignores] "tests/*" = ["D"] # No docstrings required in tests +"**/client.py" = ["ANN401"] # Allow Any return types in mixin client classes [tool.mypy] python_version = "3.12" diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 281c4a1..113a8cc 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -2,6 +2,7 @@ from pydantic import SecretStr +from celeste.auth import APIKey, Authentication from celeste.client import Client, get_client_class, register_client from celeste.core import Capability, Parameter, Provider from celeste.credentials import credentials @@ -61,7 +62,8 @@ def create_client( capability: Capability, provider: Provider | None = None, model: Model | str | None = None, - api_key: SecretStr | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, ) -> Client: """Create an async client for the specified AI capability. @@ -69,7 +71,8 @@ def create_client( capability: The AI capability to use (e.g., TEXT_GENERATION). provider: Optional provider. Required if model is a string ID. model: Model object, string model ID, or None for auto-selection. - api_key: Optional SecretStr override. If not specified, loaded from environment. + api_key: Optional API key override (string or SecretStr). + auth: Optional Authentication object for custom auth (e.g., GoogleADC). Returns: Configured client instance ready for generation operations. @@ -85,10 +88,12 @@ def create_client( # Resolve model resolved_model = _resolve_model(capability, provider, model) - # Get client class and credentials + # Get client class and authentication client_class = get_client_class(capability, resolved_model.provider) - resolved_key = credentials.get_credentials( - resolved_model.provider, override_key=api_key + resolved_auth = credentials.get_auth( + resolved_model.provider, + override_auth=auth, + override_key=api_key, ) # Create and return client @@ -96,12 +101,14 @@ def create_client( model=resolved_model, provider=resolved_model.provider, capability=capability, - api_key=resolved_key, + auth=resolved_auth, ) # Exports __all__ = [ + "APIKey", + "Authentication", "Capability", "Client", "ClientNotFoundError", diff --git a/src/celeste/auth.py b/src/celeste/auth.py new file mode 100644 index 0000000..3affe19 --- /dev/null +++ b/src/celeste/auth.py @@ -0,0 +1,80 @@ +"""Authentication methods for Celeste providers.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, SecretStr, field_validator + +# Module-level registry (same pattern as _clients and _models) +_auth_classes: dict[str, type["Authentication"]] = {} + + +class Authentication(ABC, BaseModel): + """Base class for authentication methods.""" + + @abstractmethod + def get_headers(self) -> dict[str, str]: + """Return authentication headers for HTTP requests.""" + ... + + +class APIKey(Authentication): + """API key authentication. + + Supports configurable header name and prefix for different provider formats: + - OpenAI: Authorization: Bearer + - Anthropic: x-api-key: + - Google: x-goog-api-key: + - ElevenLabs: xi-api-key: + """ + + key: SecretStr + header: str = "Authorization" + prefix: str = "Bearer " + + @field_validator("key", mode="before") + @classmethod + def convert_to_secret(cls, v: str | SecretStr) -> SecretStr: + """Accept plain strings, auto-convert to SecretStr.""" + if isinstance(v, str): + return SecretStr(v) + return v + + def get_headers(self) -> dict[str, str]: + """Return API key authentication header.""" + return {self.header: f"{self.prefix}{self.key.get_secret_value()}"} + + +def register_auth(auth_type: str, auth_class: type[Authentication]) -> None: + """Register an authentication class. + + Args: + auth_type: The auth type identifier (e.g., "google_adc"). + auth_class: The Authentication subclass to register. + """ + _auth_classes[auth_type] = auth_class + + +def get_auth_class(auth_type: str) -> type[Authentication]: + """Get a registered authentication class by type. + + Args: + auth_type: The auth type identifier. + + Returns: + The registered Authentication subclass. + + Raises: + ValueError: If auth type is not registered. + """ + from celeste.registry import _load_from_entry_points + + _load_from_entry_points() + + if auth_type not in _auth_classes: + msg = f"Unknown auth type: {auth_type}. Available: {list(_auth_classes.keys())}" + raise ValueError(msg) + + return _auth_classes[auth_type] + + +__all__ = ["APIKey", "Authentication", "get_auth_class", "register_auth"] diff --git a/src/celeste/client.py b/src/celeste/client.py index c2510b8..92ee479 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -6,8 +6,9 @@ from typing import Any, Unpack import httpx -from pydantic import BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field +from celeste.auth import Authentication from celeste.core import Capability, Provider from celeste.exceptions import ( ClientNotFoundError, @@ -29,7 +30,7 @@ class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel): model: Model provider: Provider capability: Capability - api_key: SecretStr = Field(exclude=True) + auth: Authentication = Field(exclude=True) def model_post_init(self, __context: object) -> None: """Validate capability compatibility.""" @@ -46,7 +47,7 @@ def http_client(self) -> HTTPClient: async def generate( self, - *args: Any, # noqa: ANN401 + *args: Any, **parameters: Unpack[Params], # type: ignore[misc] ) -> Out: """Generate content - signature varies by capability. @@ -73,7 +74,7 @@ async def generate( def stream( self, - *args: Any, # noqa: ANN401 + *args: Any, **parameters: Unpack[Params], # type: ignore[misc] ) -> Stream[Out, Params, Chunk]: """Stream content - signature varies by capability. @@ -139,7 +140,7 @@ def _parse_finish_reason( @abstractmethod def _create_inputs( self, - *args: Any, # noqa: ANN401 + *args: Any, **parameters: Unpack[Params], # type: ignore[misc] ) -> In: """Map positional arguments to Input type.""" diff --git a/src/celeste/credentials.py b/src/celeste/credentials.py index 39949e9..382ca97 100644 --- a/src/celeste/credentials.py +++ b/src/celeste/credentials.py @@ -4,14 +4,34 @@ from pydantic import Field, SecretStr from pydantic_settings import BaseSettings +from celeste.auth import APIKey, Authentication from celeste.core import Provider from celeste.exceptions import MissingCredentialsError, UnsupportedProviderError +# Provider to auth configuration mapping +# Maps provider to (package_name, header, prefix) for API key auth +PROVIDER_AUTH_CONFIG: dict[Provider, tuple[str, str, str]] = { + Provider.OPENAI: ("celeste_openai", "Authorization", "Bearer "), + Provider.ANTHROPIC: ("celeste_anthropic", "x-api-key", ""), + Provider.GOOGLE: ("celeste_google", "x-goog-api-key", ""), + Provider.MISTRAL: ("celeste_mistral", "Authorization", "Bearer "), + Provider.HUGGINGFACE: ("celeste_huggingface", "Authorization", "Bearer "), + Provider.STABILITYAI: ("celeste_stabilityai", "Authorization", "Bearer "), + Provider.REPLICATE: ("celeste_replicate", "Authorization", "Bearer "), + Provider.COHERE: ("celeste_cohere", "Authorization", "bearer "), + Provider.XAI: ("celeste_xai", "Authorization", "Bearer "), + Provider.LUMA: ("celeste_luma", "Authorization", "Bearer "), + Provider.TOPAZLABS: ("celeste_topazlabs", "X-API-Key", ""), + Provider.PERPLEXITY: ("celeste_perplexity", "Authorization", "Bearer "), + Provider.BYTEDANCE: ("celeste_bytedance", "Authorization", "Bearer "), + Provider.ELEVENLABS: ("celeste_elevenlabs", "xi-api-key", ""), + Provider.BFL: ("celeste_bfl", "x-key", ""), +} + # Provider to credential field mapping PROVIDER_CREDENTIAL_MAP = { Provider.OPENAI: "openai_api_key", Provider.ANTHROPIC: "anthropic_api_key", - Provider.BFL: "bfl_api_key", Provider.GOOGLE: "google_api_key", Provider.MISTRAL: "mistral_api_key", Provider.HUGGINGFACE: "huggingface_token", @@ -24,6 +44,7 @@ Provider.PERPLEXITY: "perplexity_api_key", Provider.BYTEDANCE: "bytedance_api_key", Provider.ELEVENLABS: "elevenlabs_api_key", + Provider.BFL: "bfl_api_key", } @@ -32,7 +53,6 @@ class Credentials(BaseSettings): openai_api_key: SecretStr | None = Field(None, alias="OPENAI_API_KEY") anthropic_api_key: SecretStr | None = Field(None, alias="ANTHROPIC_API_KEY") - bfl_api_key: SecretStr | None = Field(None, alias="BFL_API_KEY") google_api_key: SecretStr | None = Field(None, alias="GOOGLE_API_KEY") mistral_api_key: SecretStr | None = Field(None, alias="MISTRAL_API_KEY") huggingface_token: SecretStr | None = Field(None, alias="HUGGINGFACE_TOKEN") @@ -45,6 +65,7 @@ class Credentials(BaseSettings): perplexity_api_key: SecretStr | None = Field(None, alias="PERPLEXITY_API_KEY") bytedance_api_key: SecretStr | None = Field(None, alias="BYTEDANCE_API_KEY") elevenlabs_api_key: SecretStr | None = Field(None, alias="ELEVENLABS_API_KEY") + bfl_api_key: SecretStr | None = Field(None, alias="BFL_API_KEY") model_config = { "env_file": find_dotenv(), @@ -54,13 +75,13 @@ class Credentials(BaseSettings): } def get_credentials( - self, provider: Provider, override_key: SecretStr | None = None + self, provider: Provider, override_key: str | SecretStr | None = None ) -> SecretStr: """Get credentials for a specific provider with optional override. Args: provider: The AI provider to get credentials for. - override_key: Optional SecretStr to use instead of environment variable. + override_key: Optional key to use instead of environment variable. Returns: SecretStr containing the API key for the provider. @@ -69,6 +90,8 @@ def get_credentials( MissingCredentialsError: If provider requires credentials but none are configured. """ if override_key: + if isinstance(override_key, str): + return SecretStr(override_key) return override_key if not self.has_credential(provider): @@ -106,7 +129,48 @@ def has_credential(self, provider: Provider) -> bool: raise UnsupportedProviderError(provider=provider) return getattr(self, credential_field, None) is not None + def get_auth( + self, + provider: Provider, + override_auth: Authentication | None = None, + override_key: str | SecretStr | None = None, + ) -> Authentication: + """Get authentication for a specific provider. + + Args: + provider: The AI provider to authenticate with. + override_auth: Optional Authentication object to use directly. + override_key: Optional API key to use instead of environment variable. + + Returns: + Authentication object configured for the provider. + + Raises: + MissingCredentialsError: If provider requires credentials but none configured. + UnsupportedProviderError: If provider has no auth configuration. + """ + # Direct auth object takes precedence + if override_auth is not None: + return override_auth + + # Get auth config for provider + auth_config = PROVIDER_AUTH_CONFIG.get(provider) + if not auth_config: + raise UnsupportedProviderError(provider=provider) + + # API key authentication + _package_name, header, prefix = auth_config + + # Get API key (override or from environment) + api_key = self.get_credentials(provider, override_key) + + return APIKey( + key=api_key, + header=header, + prefix=prefix, + ) + credentials = Credentials.model_validate({}) -__all__ = ["Credentials", "credentials"] +__all__ = ["PROVIDER_AUTH_CONFIG", "Credentials", "credentials"] diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 8d027f2..48f7606 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -8,6 +8,7 @@ import pytest from pydantic import SecretStr +from celeste.auth import APIKey from celeste.client import Client, _clients, get_client_class, register_client from celeste.core import Capability, Provider from celeste.exceptions import ( @@ -226,7 +227,7 @@ def test_successful_creation_with_compatible_capability( model=text_model, provider=text_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) # Assert @@ -246,7 +247,7 @@ def test_validation_failure_with_incompatible_capability( model=text_model, provider=text_model.provider, capability=Capability.IMAGE_GENERATION, # Model doesn't support this - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) @pytest.mark.parametrize( @@ -270,7 +271,7 @@ def test_validation_success_with_supported_capabilities( model=multimodal_model, provider=multimodal_model.provider, capability=capability, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) # Assert @@ -298,7 +299,7 @@ def test_validation_fails_with_model_lacking_any_capabilities( model=empty_model, provider=empty_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) @@ -438,7 +439,7 @@ def parameter_mappers(cls) -> list[ParameterMapper]: model=text_model, provider=text_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) inputs = _TestInput(prompt="test prompt") @@ -471,7 +472,7 @@ def parameter_mappers(cls) -> list[ParameterMapper]: model=text_model, provider=text_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) inputs = _TestInput(prompt="test prompt") @@ -515,7 +516,7 @@ def parameter_mappers(cls) -> list[ParameterMapper]: model=text_model, provider=text_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) original_content = "original content" @@ -540,7 +541,7 @@ def test_stream_raises_not_implemented_with_descriptive_error( model=text_model, provider=text_model.provider, capability=Capability.TEXT_GENERATION, - api_key=SecretStr(api_key), + auth=APIKey(key=SecretStr(api_key)), ) # Act & Assert