Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def _output_class(cls) -> type[ImageGenerationOutput]:
def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Build metadata dictionary from response data."""
metadata = super()._build_metadata(response_data)
metadata["raw_response"] = (
response_data # Filtered response data (content fields removed by providers before calling super)
)
metadata["raw_response"] = response_data
return metadata

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def _output_class(cls) -> type[TextGenerationOutput]:
def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
"""Build metadata dictionary from response data."""
metadata = super()._build_metadata(response_data)
metadata["raw_response"] = (
response_data
)
metadata["raw_response"] = response_data
return metadata

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
TextGenerationInput,
TextGenerationUsage,
)
from celeste_text_generation.parameters import TextGenerationParameters
from celeste_text_generation.parameters import (
TextGenerationParameter,
TextGenerationParameters,
)

from . import config
from .parameters import ANTHROPIC_PARAMETER_MAPPERS
Expand Down Expand Up @@ -62,14 +65,6 @@ def _parse_content(
msg = "No content blocks in response"
raise ValueError(msg)

output_schema = parameters.get("output_schema")
if output_schema is not None:
for content_block in content:
if content_block.get("type") == "tool_use":
tool_input = content_block.get("input")
if tool_input is not None:
return self._transform_output(tool_input, **parameters)

text_content = ""
for content_block in content:
if content_block.get("type") == "text":
Expand Down Expand Up @@ -112,6 +107,9 @@ async def _make_request(
"Content-Type": ApplicationMimeType.JSON,
}

if parameters.get(TextGenerationParameter.OUTPUT_SCHEMA) is not None:
headers[config.ANTHROPIC_BETA_HEADER] = config.STRUCTURED_OUTPUTS_BETA

return await self.http_client.post(
f"{config.BASE_URL}{config.ENDPOINT}",
headers=headers,
Expand All @@ -138,6 +136,9 @@ def _make_stream_request(
"Content-Type": ApplicationMimeType.JSON,
}

if parameters.get(TextGenerationParameter.OUTPUT_SCHEMA) is not None:
headers[config.ANTHROPIC_BETA_HEADER] = config.STRUCTURED_OUTPUTS_BETA

return self.http_client.stream_post(
f"{config.BASE_URL}{config.STREAM_ENDPOINT}",
headers=headers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
# API Version Header (required by Anthropic)
ANTHROPIC_VERSION_HEADER = "anthropic-version"
ANTHROPIC_VERSION = "2023-06-01"

# Beta Features
ANTHROPIC_BETA_HEADER = "anthropic-beta"
STRUCTURED_OUTPUTS_BETA = "structured-outputs-2025-11-13"
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
streaming=True,
parameter_constraints={
TextGenerationParameter.THINKING_BUDGET: Range(min=-1, max=32000),
TextGenerationParameter.OUTPUT_SCHEMA: Schema(),
},
),
Model(
Expand All @@ -42,7 +41,6 @@
streaming=True,
parameter_constraints={
TextGenerationParameter.THINKING_BUDGET: Range(min=-1, max=64000),
TextGenerationParameter.OUTPUT_SCHEMA: Schema(),
},
),
Model(
Expand All @@ -52,7 +50,6 @@
streaming=True,
parameter_constraints={
TextGenerationParameter.THINKING_BUDGET: Range(min=-1, max=32000),
TextGenerationParameter.OUTPUT_SCHEMA: Schema(),
},
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel, TypeAdapter

from celeste.exceptions import ConstraintViolationError, ValidationError
from celeste.exceptions import ConstraintViolationError
from celeste.models import Model
from celeste.parameters import ParameterMapper
from celeste_text_generation.parameters import TextGenerationParameter
Expand Down Expand Up @@ -62,7 +62,7 @@ def map(


class OutputSchemaMapper(ParameterMapper):
"""Map output_schema parameter to Anthropic tools parameter (tool-based structured output)."""
"""Map output_schema parameter to Anthropic native structured outputs (output_format)."""

name: StrEnum = TextGenerationParameter.OUTPUT_SCHEMA

Expand All @@ -72,55 +72,44 @@ def map(
value: object,
model: Model,
) -> dict[str, Any]:
"""Transform output_schema into provider request.
"""Transform output_schema into provider request using native structured outputs.

Converts unified output_schema to Anthropic tools parameter:
- Creates a single tool definition with input_schema matching the output schema
- Sets tool_choice to force tool use
Converts unified output_schema to Anthropic output_format parameter:
- Uses output_format with type: "json_schema" and schema definition
- Handles both BaseModel and list[BaseModel] types
- For list[BaseModel], schema is array type directly

Args:
request: Provider request dict.
value: output_schema value (type[BaseModel] | None).
model: Model instance containing parameter_constraints for validation.
model: Model instance with parameter constraints for validation.

Returns:
Updated request dict with tools and tool_choice if value provided.
Updated request dict with output_format if value provided.
"""
validated_value = self._validate_value(value, model)
if validated_value is None:
return request

# Convert Pydantic model to JSON Schema
schema = self._convert_to_anthropic_schema(validated_value)
tool_name = self._get_tool_name(validated_value)
schema = self._convert_to_json_schema(validated_value)

# Create tool definition with input_schema matching output schema
tool_def = {
"name": tool_name,
"description": f"Extract structured data conforming to {self._get_schema_description(validated_value)}",
"input_schema": schema,
request["output_format"] = {
"type": "json_schema",
"schema": schema,
}

# Add tools array to request
request["tools"] = [tool_def]

# Force tool use by setting tool_choice
request["tool_choice"] = {"type": "tool", "name": tool_name}

return request

def parse_output(
self, content: str | dict[str, Any], value: object | None
) -> str | BaseModel:
"""Parse tool_use blocks from response to BaseModel instance.
"""Parse JSON text from response to BaseModel instance.

Extracts structured data from tool_use.input field and converts to BaseModel.
For list[BaseModel], extracts the "items" array from the wrapped object.
With native structured outputs, content is direct JSON text in content[0].text.
For list[BaseModel], content is array directly.

Args:
content: Either tool_use.input dict (from tool_use block) or JSON string.
content: JSON string from content[0].text.
value: Original output_schema parameter value.

Returns:
Expand All @@ -129,72 +118,69 @@ def parse_output(
if value is None:
return content if isinstance(content, str) else json.dumps(content)

# If content is already a dict (from tool_use.input), use it directly
if isinstance(content, dict):
parsed_json = content
else:
# Otherwise parse as JSON string
parsed_json = json.loads(content)

# Check if value is list[BaseModel] and content is wrapped in object
origin = get_origin(value)
if origin is list:
# Handle empty dict case FIRST - convert to empty array before checking for "items"
if isinstance(parsed_json, dict) and not parsed_json:
# Empty dict when expecting list - convert to empty array
parsed_json = []
elif isinstance(parsed_json, dict) and "items" in parsed_json:
# Extract items array from wrapped format
parsed_json = parsed_json["items"]
# If it's already an array (backward compatibility), use it directly
# parsed_json is now the array, ready for TypeAdapter
elif isinstance(parsed_json, dict) and not parsed_json:
# Empty dict for BaseModel (not list) - this is invalid, raise error
msg = "Empty tool_use input dict cannot be converted to BaseModel"
raise ValidationError(msg)

# Parse to BaseModel instance using TypeAdapter
# TypeAdapter handles both BaseModel and list[BaseModel]
return TypeAdapter(value).validate_json(json.dumps(parsed_json))

def _convert_to_anthropic_schema(self, output_schema: Any) -> dict[str, Any]: # noqa: ANN401
"""Convert Pydantic BaseModel or list[BaseModel] to Anthropic JSON Schema format.
def _convert_to_json_schema(self, output_schema: Any) -> dict[str, Any]: # noqa: ANN401
"""Convert Pydantic BaseModel or list[BaseModel] to JSON Schema format.

Anthropic requires input_schema to always be an object type.
For list[T], wraps array schema in an object with "items" property.
For native structured outputs, list[T] is array type directly.
Ensures all object types have additionalProperties: false as required by Anthropic.

Args:
output_schema: Pydantic BaseModel class or list[BaseModel] type.

Returns:
JSON Schema dictionary compatible with Anthropic (always object type).
JSON Schema dictionary compatible with Anthropic structured outputs.
"""
origin = get_origin(output_schema)
if origin is list:
# For list[T], wrap array schema in an object (Anthropic requirement)
inner_type = get_args(output_schema)[0]
items_schema = inner_type.model_json_schema()
# Resolve refs in items schema first
items_schema = self._resolve_refs(items_schema)
# Wrap in object with "items" property
json_schema = {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": items_schema,
},
},
"required": ["items"],
"type": "array",
"items": items_schema,
}
else:
# For BaseModel, use model_json_schema directly
json_schema = output_schema.model_json_schema()
# Resolve $ref references inline (Anthropic may not support $ref)
json_schema = self._resolve_refs(json_schema)

json_schema = self._ensure_additional_properties(json_schema)
return json_schema

def _ensure_additional_properties(self, schema: dict[str, Any]) -> dict[str, Any]:
"""Ensure all object types have additionalProperties: false."""
if not isinstance(schema, dict):
return schema

schema = schema.copy()

if schema.get("type") == "object":
schema["additionalProperties"] = False

for key in ["properties", "items"]:
if key in schema:
if key == "properties":
schema[key] = {
k: self._ensure_additional_properties(v)
for k, v in schema[key].items()
}
else:
schema[key] = self._ensure_additional_properties(schema[key])

for key in ["anyOf", "allOf"]:
if key in schema:
schema[key] = [
self._ensure_additional_properties(item) for item in schema[key]
]

return schema

def _resolve_refs(self, schema: dict[str, Any]) -> dict[str, Any]:
"""Resolve all $ref references and inline definitions.

Expand Down Expand Up @@ -250,36 +236,6 @@ def resolve(value: Any) -> Any: # noqa: ANN401

return resolve(schema)

def _get_tool_name(self, output_schema: Any) -> str: # noqa: ANN401
"""Derive tool name from model class name.

Args:
output_schema: Pydantic BaseModel class or list[BaseModel] type.

Returns:
Tool name (lowercase class name or "extract_data" as fallback).
"""
origin = get_origin(output_schema)
if origin is list:
inner_type = get_args(output_schema)[0]
return inner_type.__name__.lower() or "extract_data"
return output_schema.__name__.lower() or "extract_data"

def _get_schema_description(self, output_schema: Any) -> str: # noqa: ANN401
"""Get description for tool definition.

Args:
output_schema: Pydantic BaseModel class or list[BaseModel] type.

Returns:
Schema description string.
"""
origin = get_origin(output_schema)
if origin is list:
inner_type = get_args(output_schema)[0]
return f"array of {inner_type.__name__}"
return output_schema.__name__


ANTHROPIC_PARAMETER_MAPPERS: list[ParameterMapper] = [
ThinkingBudgetMapper(),
Expand Down
Loading
Loading