Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ help:
@echo " make format - Apply Ruff formatting"
@echo " make typecheck - Run mypy type checking"
@echo " make test - Run all tests (core + packages) with coverage"
@echo " make integration-test - Run integration tests (requires API keys)"
@echo " make integration-test [capability] - Run integration tests (all or specific)"
@echo " (e.g., make integration-test image-intelligence)"
@echo " make security - Run Bandit security scan"
@echo " make ci - Run full CI/CD pipeline"
@echo " make clean - Clean cache directories"
Expand All @@ -33,15 +34,22 @@ format:

# Type checking (fail fast on any error)
typecheck:
@uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/capabilities/image-generation packages/capabilities/text-generation packages/capabilities/video-generation packages/capabilities/speech-generation
@uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/*/*/src/

# Testing
test:
uv run pytest tests/unit_tests packages/capabilities/*/tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=80 -v
uv run pytest tests/unit_tests --cov=celeste --cov-report=term-missing --cov-fail-under=80 -v

# Integration testing (requires API keys)
# Usage: make integration-test [capability]
integration-test:
uv run pytest packages/capabilities/*/tests/integration_tests/ -m integration -v --dist=worksteal -n auto
@cap="$(filter-out $@,$(MAKECMDGOALS))"; \
if [ -z "$$cap" ]; then cap="*"; fi; \
uv run pytest packages/capabilities/$$cap/tests/integration_tests/ -m integration -v --dist=worksteal -n auto

# Catch capability names as no-op targets
%:
@:

# Security scanning (config reads from pyproject.toml)
security:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def _make_request(
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
"Accept": ApplicationMimeType.JSON,
}
Expand All @@ -103,7 +103,7 @@ async def _make_request(

start_time = time.monotonic()
poll_headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Accept": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _make_request(
request_body["stream"] = False

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -142,7 +142,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class GoogleImageGenerationClient(ImageGenerationClient):

model_config = ConfigDict(extra="allow")

def model_post_init(self, __context: Any) -> None: # noqa: ANN401
def model_post_init(self, __context: Any) -> None:
"""Initialize API adapter based on model type."""
super().model_post_init(__context)

Expand Down Expand Up @@ -103,7 +103,7 @@ async def _make_request(
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def _make_request(
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -120,7 +120,7 @@ def _make_stream_request(
request_body["partial_images"] = 1

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from celeste_image_generation.providers.google.client import GoogleImageGenerationClient
from pydantic import SecretStr

from celeste.auth import APIKey
from celeste.core import Capability, Provider
from celeste.models import Model

Expand All @@ -25,7 +26,7 @@ def client(self) -> GoogleImageGenerationClient:
),
provider=Provider.GOOGLE,
capability=Capability.IMAGE_GENERATION,
api_key=SecretStr("test-key"),
auth=APIKey(key=SecretStr("test-key")),
)

@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def _make_request(
endpoint = config.ENDPOINT.format(voice_id=voice_id)

headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down Expand Up @@ -152,7 +152,7 @@ def _make_stream_request(
stream_endpoint = config.STREAM_ENDPOINT.format(voice_id=voice_id)

headers = {
config.AUTH_HEADER_NAME: self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def _make_request(
endpoint = config.ENDPOINT.format(model_id=self.model.id)

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _make_request(
request_body["model"] = self.model.id

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def _make_request(
request_body["max_tokens"] = parameters.get("max_tokens") or 1024

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION,
"Content-Type": ApplicationMimeType.JSON,
}
Expand Down Expand Up @@ -129,7 +129,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
config.ANTHROPIC_VERSION_HEADER: config.ANTHROPIC_VERSION,
"Content-Type": ApplicationMimeType.JSON,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def _make_request(
request_body["model"] = self.model.id

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -129,7 +129,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def _make_request(
endpoint = config.ENDPOINT.format(model_id=self.model.id)

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -131,7 +131,7 @@ def _make_stream_request(
stream_endpoint = config.STREAM_ENDPOINT.format(model_id=self.model.id)

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def _make_request(
request_body["model"] = self.model.id

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -132,7 +132,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def _make_request(
request_body["model"] = self.model.id

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -139,7 +139,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _make_request(
request_body["model"] = self.model.id

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -127,7 +127,7 @@ def _make_stream_request(
request_body["stream"] = True

headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def _make_request(
) -> httpx.Response:
"""Make HTTP request with async polling."""
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def _make_request(
url = f"{config.BASE_URL}{endpoint}"

headers = {
"x-goog-api-key": self.api_key.get_secret_value(),
**self.auth.get_headers(),
"Content-Type": ApplicationMimeType.JSON,
}

Expand All @@ -147,7 +147,7 @@ async def _make_request(
logger.info(f"Video generation started: {operation_name}")

poll_url = f"{config.BASE_URL}{config.POLL_ENDPOINT.format(operation_name=operation_name)}"
poll_headers = {"x-goog-api-key": self.api_key.get_secret_value()}
poll_headers = self.auth.get_headers()

while True:
await asyncio.sleep(config.POLL_INTERVAL)
Expand Down Expand Up @@ -196,7 +196,7 @@ async def download_content(self, artifact: VideoArtifact) -> VideoArtifact:

logger.info(f"Downloading video from: {download_url}")

headers = {"x-goog-api-key": self.api_key.get_secret_value()}
headers = self.auth.get_headers()

response = await self.http_client.get(
download_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ async def _make_request(
**parameters: Unpack[VideoGenerationParameters],
) -> httpx.Response:
"""Make HTTP request with async polling for OpenAI video generation."""
headers = {
config.AUTH_HEADER_NAME: f"{config.AUTH_HEADER_PREFIX}{self.api_key.get_secret_value()}",
}
headers = self.auth.get_headers()

files, data = await self._prepare_multipart_request(request_body.copy())

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ dev = [
]

[tool.uv.workspace]
members = ["packages/capabilities/*"]
members = [
"packages/providers/*",
"packages/capabilities/*",
]

[tool.uv.sources]
celeste-text-generation = { workspace = true }
Expand Down Expand Up @@ -135,6 +138,7 @@ known-first-party = ["celeste"]

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"] # No docstrings required in tests
"**/client.py" = ["ANN401"] # Allow Any return types in mixin client classes

[tool.mypy]
python_version = "3.12"
Expand Down
19 changes: 13 additions & 6 deletions src/celeste/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import SecretStr

from celeste.auth import APIKey, Authentication
from celeste.client import Client, get_client_class, register_client
from celeste.core import Capability, Parameter, Provider
from celeste.credentials import credentials
Expand Down Expand Up @@ -61,15 +62,17 @@ def create_client(
capability: Capability,
provider: Provider | None = None,
model: Model | str | None = None,
api_key: SecretStr | None = None,
api_key: str | SecretStr | None = None,
auth: Authentication | None = None,
) -> Client:
"""Create an async client for the specified AI capability.

Args:
capability: The AI capability to use (e.g., TEXT_GENERATION).
provider: Optional provider. Required if model is a string ID.
model: Model object, string model ID, or None for auto-selection.
api_key: Optional SecretStr override. If not specified, loaded from environment.
api_key: Optional API key override (string or SecretStr).
auth: Optional Authentication object for custom auth (e.g., GoogleADC).

Returns:
Configured client instance ready for generation operations.
Expand All @@ -85,23 +88,27 @@ def create_client(
# Resolve model
resolved_model = _resolve_model(capability, provider, model)

# Get client class and credentials
# Get client class and authentication
client_class = get_client_class(capability, resolved_model.provider)
resolved_key = credentials.get_credentials(
resolved_model.provider, override_key=api_key
resolved_auth = credentials.get_auth(
resolved_model.provider,
override_auth=auth,
override_key=api_key,
)

# Create and return client
return client_class(
model=resolved_model,
provider=resolved_model.provider,
capability=capability,
api_key=resolved_key,
auth=resolved_auth,
)


# Exports
__all__ = [
"APIKey",
"Authentication",
"Capability",
"Client",
"ClientNotFoundError",
Expand Down
Loading
Loading