Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def walk(self) -> JsonSchema:
if not self.prefer_inlined_defs and self.defs:
handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}

elif self.recursive_refs: # pragma: no cover
elif self.recursive_refs:
# If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure
# We try to use whatever the original root key was, but if it is already in use,
# we modify it to avoid collisions.
defs = {key: self.defs[key] for key in self.recursive_refs}
root_ref = self.schema.get('$ref')
root_key = None if root_ref is None else re.sub(r'^#/\$defs/', '', root_ref)
if root_key is None:
if root_key is None: # pragma: no cover
root_key = self.schema.get('title', 'root')
while root_key in defs:
# Modify the root key until it is not already in use
Expand All @@ -77,6 +77,8 @@ def _handle(self, schema: JsonSchema) -> JsonSchema:
if self.prefer_inlined_defs:
while ref := schema.get('$ref'):
key = re.sub(r'^#/\$defs/', '', ref)
if key in self.recursive_refs:
break
if key in self.refs_stack:
self.recursive_refs.add(key)
break # recursive ref can't be unpacked
Expand Down
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic_core import core_schema
from typing_extensions import TypeAliasType, TypeVar, deprecated

from . import _utils
from . import _utils, exceptions
from ._json_schema import InlineDefsJsonSchemaTransformer
from .messages import ToolCallPart
from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition
Expand Down Expand Up @@ -316,6 +316,10 @@ def StructuredDict(
# See https://github.com/pydantic/pydantic/issues/12145
if '$defs' in json_schema:
json_schema = InlineDefsJsonSchemaTransformer(json_schema).walk()
if '$defs' in json_schema:
raise exceptions.UserError(
'`StructuredDict` does not currently support recursive `$ref`s and `$defs`. See https://github.com/pydantic/pydantic/issues/12145 for more information.'
)

if name:
json_schema['title'] = name
Expand Down
3 changes: 2 additions & 1 deletion pydantic_evals/pydantic_evals/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ async def generate_dataset(
"""
output_schema = dataset_type.model_json_schema_with_evaluators(custom_evaluator_types)

# TODO(DavidM): Update this once we add better response_format and/or ResultTool support to Pydantic AI
# TODO: Use `output_type=StructuredDict(output_schema)` (and `from_dict` below) once https://github.com/pydantic/pydantic/issues/12145
# is fixed and `StructuredDict` no longer needs to use `InlineDefsJsonSchemaTransformer`.
agent = Agent(
model,
system_prompt=(
Expand Down
35 changes: 35 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,41 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert result.output == snapshot({'make': 'Toyota', 'model': 'Camry', 'tires': [{'brand': 'Michelin', 'size': 17}]})


def test_structured_dict_recursive_refs():
class Node(BaseModel):
nodes: list['Node'] | dict[str, 'Node']

schema = Node.model_json_schema()
assert schema == snapshot(
{
'$defs': {
'Node': {
'properties': {
'nodes': {
'anyOf': [
{'items': {'$ref': '#/$defs/Node'}, 'type': 'array'},
{'additionalProperties': {'$ref': '#/$defs/Node'}, 'type': 'object'},
],
'title': 'Nodes',
}
},
'required': ['nodes'],
'title': 'Node',
'type': 'object',
}
},
'$ref': '#/$defs/Node',
}
)
with pytest.raises(
UserError,
match=re.escape(
'`StructuredDict` does not currently support recursive `$ref`s and `$defs`. See https://github.com/pydantic/pydantic/issues/12145 for more information.'
),
):
StructuredDict(schema)


def test_default_structured_output_mode():
def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover
Expand Down