diff --git a/pyproject.toml b/pyproject.toml index cd58c3e..8bcb744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-ai" -version = "0.3.7" +version = "0.3.8" description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" diff --git a/src/celeste/constraints.py b/src/celeste/constraints.py index 9ec6387..491d8f7 100644 --- a/src/celeste/constraints.py +++ b/src/celeste/constraints.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import Any, get_args, get_origin -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field from celeste.artifacts import ImageArtifact from celeste.exceptions import ConstraintViolationError @@ -15,6 +15,12 @@ class Constraint(BaseModel, ABC): """Base constraint for parameter validation.""" + @computed_field # type: ignore[prop-decorator] + @property + def type(self) -> str: + """Constraint type identifier for serialization.""" + return self.__class__.__name__ + @abstractmethod def __call__(self, value: Any) -> Any: # noqa: ANN401 """Validate value against constraint and return validated value.""" diff --git a/src/celeste/core.py b/src/celeste/core.py index f663f33..804e7fa 100644 --- a/src/celeste/core.py +++ b/src/celeste/core.py @@ -49,6 +49,15 @@ class Capability(StrEnum): SEARCH = "search" +class InputType(StrEnum): + """Input types for capabilities.""" + + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + class Parameter(StrEnum): """Universal parameters across most capabilities.""" @@ -77,4 +86,4 @@ class UsageField(StrEnum): CACHE_READ_INPUT_TOKENS = "cache_read_input_tokens" -__all__ = ["Capability", "Parameter", "Provider", "UsageField"] +__all__ = ["Capability", "InputType", "Parameter", "Provider", "UsageField"] diff --git a/src/celeste/io.py b/src/celeste/io.py index e74024d..d71912d 100644 --- a/src/celeste/io.py +++ b/src/celeste/io.py @@ -1,10 +1,14 @@ """Input and output types for generation operations.""" -from typing import Any +import inspect +import types +from typing import Any, get_args, get_origin from pydantic import BaseModel, Field -from celeste.core import Capability +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact +from celeste.constraints import Constraint +from celeste.core import Capability, InputType class Input(BaseModel): @@ -59,12 +63,88 @@ def get_input_class(capability: Capability) -> type[Input]: return _inputs[capability] +# Centralized mapping: field type → InputType +INPUT_TYPE_MAPPING: dict[type, InputType] = { + str: InputType.TEXT, + ImageArtifact: InputType.IMAGE, + VideoArtifact: InputType.VIDEO, + AudioArtifact: InputType.AUDIO, +} + + +def get_required_input_types(capability: Capability) -> set[InputType]: + """Derive required input types from Input class fields. + + Introspects the Input class registered for a capability and returns + the set of InputTypes based on field annotations. + + Args: + capability: The capability to get required input types for. + + Returns: + Set of InputType values required by the capability's Input class. + """ + input_class = get_input_class(capability) + return { + INPUT_TYPE_MAPPING[field.annotation] + for field in input_class.model_fields.values() + if field.annotation in INPUT_TYPE_MAPPING + } + + +def _extract_input_type(param_type: type) -> InputType | None: + """Extract InputType from a type, handling unions and generics. + + Args: + param_type: The type annotation to inspect. + + Returns: + InputType if found in the type or its nested types, None otherwise. + """ + # Direct match + if param_type in INPUT_TYPE_MAPPING: + return INPUT_TYPE_MAPPING[param_type] + + # Handle union types (X | Y) and generics (list[X]) + origin = get_origin(param_type) + if origin is types.UnionType or origin is not None: + for arg in get_args(param_type): + result = _extract_input_type(arg) + if result is not None: + return result + + return None + + +def get_constraint_input_type(constraint: Constraint) -> InputType | None: + """Get InputType from constraint's __call__ signature. + + Introspects the constraint's __call__ method to find what artifact type + it accepts, then maps to InputType using INPUT_TYPE_MAPPING. + + Args: + constraint: The constraint to inspect. + + Returns: + InputType if the constraint accepts a mapped artifact type, None otherwise. + """ + annotations = inspect.get_annotations(constraint.__call__, eval_str=True) + for param_type in annotations.values(): + result = _extract_input_type(param_type) + if result is not None: + return result + return None + + __all__ = [ + "INPUT_TYPE_MAPPING", "Chunk", "FinishReason", "Input", "Output", "Usage", + "get_constraint_input_type", "get_input_class", + "get_required_input_types", "register_input", ] diff --git a/src/celeste/models.py b/src/celeste/models.py index ddf015d..2c0eee8 100644 --- a/src/celeste/models.py +++ b/src/celeste/models.py @@ -1,9 +1,10 @@ """Models and model registry for Celeste.""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SerializeAsAny, computed_field from celeste.constraints import Constraint -from celeste.core import Capability, Provider +from celeste.core import Capability, InputType, Provider +from celeste.io import get_constraint_input_type, get_required_input_types class Model(BaseModel): @@ -13,7 +14,9 @@ class Model(BaseModel): provider: Provider display_name: str capabilities: set[Capability] = Field(default_factory=set) - parameter_constraints: dict[str, Constraint] = Field(default_factory=dict) + parameter_constraints: dict[str, SerializeAsAny[Constraint]] = Field( + default_factory=dict + ) streaming: bool = Field(default=False) @property @@ -21,6 +24,23 @@ def supported_parameters(self) -> set[str]: """Compute supported parameter names from parameter_constraints.""" return set(self.parameter_constraints.keys()) + @computed_field # type: ignore[prop-decorator] + @property + def supported_input_types(self) -> dict[Capability, set[InputType]]: + """Input types supported per capability (derived from Input class fields).""" + return {cap: get_required_input_types(cap) for cap in self.capabilities} + + @computed_field # type: ignore[prop-decorator] + @property + def optional_input_types(self) -> set[InputType]: + """Optional input types accepted via parameter_constraints.""" + types: set[InputType] = set() + for constraint in self.parameter_constraints.values(): + input_type = get_constraint_input_type(constraint) + if input_type is not None: + types.add(input_type) + return types + # Module-level registry mapping (model_id, provider) to model _models: dict[tuple[str, Provider], Model] = {}