diff --git a/pydantic_ai_slim/pydantic_ai/_json_schema.py b/pydantic_ai_slim/pydantic_ai/_json_schema.py index cde9eeb215..cbaa180208 100644 --- a/pydantic_ai_slim/pydantic_ai/_json_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_json_schema.py @@ -54,14 +54,14 @@ def walk(self) -> JsonSchema: if not self.prefer_inlined_defs and self.defs: handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()} - elif self.recursive_refs: # pragma: no cover + elif self.recursive_refs: # If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure # We try to use whatever the original root key was, but if it is already in use, # we modify it to avoid collisions. defs = {key: self.defs[key] for key in self.recursive_refs} root_ref = self.schema.get('$ref') root_key = None if root_ref is None else re.sub(r'^#/\$defs/', '', root_ref) - if root_key is None: + if root_key is None: # pragma: no cover root_key = self.schema.get('title', 'root') while root_key in defs: # Modify the root key until it is not already in use @@ -77,6 +77,8 @@ def _handle(self, schema: JsonSchema) -> JsonSchema: if self.prefer_inlined_defs: while ref := schema.get('$ref'): key = re.sub(r'^#/\$defs/', '', ref) + if key in self.recursive_refs: + break if key in self.refs_stack: self.recursive_refs.add(key) break # recursive ref can't be unpacked diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 2a4c821dc4..beb0b84adb 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -9,7 +9,7 @@ from pydantic_core import core_schema from typing_extensions import TypeAliasType, TypeVar, deprecated -from . import _utils +from . import _utils, exceptions from ._json_schema import InlineDefsJsonSchemaTransformer from .messages import ToolCallPart from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition @@ -316,6 +316,10 @@ def StructuredDict( # See https://github.com/pydantic/pydantic/issues/12145 if '$defs' in json_schema: json_schema = InlineDefsJsonSchemaTransformer(json_schema).walk() + if '$defs' in json_schema: + raise exceptions.UserError( + '`StructuredDict` does not currently support recursive `$ref`s and `$defs`. See https://github.com/pydantic/pydantic/issues/12145 for more information.' + ) if name: json_schema['title'] = name diff --git a/pydantic_evals/pydantic_evals/generation.py b/pydantic_evals/pydantic_evals/generation.py index c1e68a6ea8..fd3b034573 100644 --- a/pydantic_evals/pydantic_evals/generation.py +++ b/pydantic_evals/pydantic_evals/generation.py @@ -59,7 +59,8 @@ async def generate_dataset( """ output_schema = dataset_type.model_json_schema_with_evaluators(custom_evaluator_types) - # TODO(DavidM): Update this once we add better response_format and/or ResultTool support to Pydantic AI + # TODO: Use `output_type=StructuredDict(output_schema)` (and `from_dict` below) once https://github.com/pydantic/pydantic/issues/12145 + # is fixed and `StructuredDict` no longer needs to use `InlineDefsJsonSchemaTransformer`. agent = Agent( model, system_prompt=( diff --git a/tests/test_agent.py b/tests/test_agent.py index c382725c9d..fcffd2b846 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1466,6 +1466,41 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert result.output == snapshot({'make': 'Toyota', 'model': 'Camry', 'tires': [{'brand': 'Michelin', 'size': 17}]}) +def test_structured_dict_recursive_refs(): + class Node(BaseModel): + nodes: list['Node'] | dict[str, 'Node'] + + schema = Node.model_json_schema() + assert schema == snapshot( + { + '$defs': { + 'Node': { + 'properties': { + 'nodes': { + 'anyOf': [ + {'items': {'$ref': '#/$defs/Node'}, 'type': 'array'}, + {'additionalProperties': {'$ref': '#/$defs/Node'}, 'type': 'object'}, + ], + 'title': 'Nodes', + } + }, + 'required': ['nodes'], + 'title': 'Node', + 'type': 'object', + } + }, + '$ref': '#/$defs/Node', + } + ) + with pytest.raises( + UserError, + match=re.escape( + '`StructuredDict` does not currently support recursive `$ref`s and `$defs`. See https://github.com/pydantic/pydantic/issues/12145 for more information.' + ), + ): + StructuredDict(schema) + + def test_default_structured_output_mode(): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover