diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 6d70471bd7..bd0245ae83 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload -from pydantic import TypeAdapter, ValidationError +from pydantic import Json, TypeAdapter, ValidationError from pydantic_core import SchemaValidator, to_json from typing_extensions import Self, TypedDict, TypeVar, assert_never @@ -624,21 +624,33 @@ def __init__( json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: - type_adapter: TypeAdapter[Any] + json_schema_type_adapter: TypeAdapter[Any] + validation_type_adapter: TypeAdapter[Any] if _utils.is_model_like(output): - type_adapter = TypeAdapter(output) + json_schema_type_adapter = validation_type_adapter = TypeAdapter(output) else: self.outer_typed_dict_key = 'response' + output_type: type[OutputDataT] = cast(type[OutputDataT], output) + response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', - {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm] + {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] + ) + json_schema_type_adapter = TypeAdapter(response_data_typed_dict) + + # More lenient validator: allow either the native type or a JSON string containing it + # i.e. `response: OutputDataT | Json[OutputDataT]`, as some models don't follow the schema correctly, + # e.g. `BedrockConverseModel('us.meta.llama3-2-11b-instruct-v1:0')` + response_validation_typed_dict = TypedDict( # noqa: UP013 + 'response_validation_typed_dict', + {'response': output_type | Json[output_type]}, # pyright: ignore[reportInvalidTypeForm] ) - type_adapter = TypeAdapter(response_data_typed_dict) + validation_type_adapter = TypeAdapter(response_validation_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self.validator = cast(SchemaValidator, type_adapter.validator) + self.validator = cast(SchemaValidator, validation_type_adapter.validator) json_schema = _utils.check_object_json_schema( - type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + json_schema_type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) if self.outer_typed_dict_key: diff --git a/tests/test_agent.py b/tests/test_agent.py index d86b17f164..2194160cd6 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -75,6 +75,37 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert result.output == ('foo', 'bar') +class Person(BaseModel): + name: str + + +def test_result_list_of_models_with_stringified_response(): + def return_list(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + # Simulate providers that return the nested payload as a JSON string under "response" + args_json = json.dumps( + { + 'response': json.dumps( + [ + {'name': 'John Doe'}, + {'name': 'Jane Smith'}, + ] + ) + } + ) + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(return_list), output_type=list[Person]) + + result = agent.run_sync('Hello') + assert result.output == snapshot( + [ + Person(name='John Doe'), + Person(name='Jane Smith'), + ] + ) + + class Foo(BaseModel): a: int b: str