From 9efdd40d1f965428cf0974105f6ff5199e99c5a5 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 24 Oct 2025 22:06:53 +0000 Subject: [PATCH] Ensure ToolCallPart resulting from TestModel(custom_output_args=...) always holds a dict --- pydantic_ai_slim/pydantic_ai/models/test.py | 9 +- tests/models/test_model_test.py | 102 ++++++++++++++++++++ 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 6b772365ba..170113a999 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -44,11 +44,14 @@ class _WrappedTextOutput: value: str | None -@dataclass +@dataclass(init=False) class _WrappedToolOutput: """A wrapper class to tag an output that came from the custom_output_args field.""" - value: Any | None + value: dict[str, Any] | None + + def __init__(self, value: Any | None): + self.value = pydantic_core.to_jsonable_python(value) @dataclass(init=False) @@ -364,7 +367,7 @@ def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0): self.defs = schema.get('$defs', {}) self.seed = seed - def generate(self) -> Any: + def generate(self) -> dict[str, Any]: """Generate data for the JSON schema.""" return self._gen_any(self.schema) diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index d73e8579c3..f6b4af74b1 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -69,6 +69,40 @@ def test_custom_output_args(): agent = Agent(output_type=tuple[str, str]) result = agent.run_sync('x', model=TestModel(custom_output_args=['a', 'b'])) assert result.output == ('a', 'b') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='x', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'response': ['a', 'b']}, + tool_call_id='pyd_ai_tool_call_id__final_result', + ) + ], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='pyd_ai_tool_call_id__final_result', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ] + ) def test_custom_output_args_model(): @@ -79,12 +113,80 @@ class Foo(BaseModel): agent = Agent(output_type=Foo) result = agent.run_sync('x', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) assert result.output == Foo(foo='a', bar=1) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='x', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'foo': 'a', 'bar': 1}, + tool_call_id='pyd_ai_tool_call_id__final_result', + ) + ], + usage=RequestUsage(input_tokens=51, output_tokens=6), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='pyd_ai_tool_call_id__final_result', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ] + ) def test_output_type(): agent = Agent(output_type=tuple[str, str]) result = agent.run_sync('x', model=TestModel()) assert result.output == ('a', 'a') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='x', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'response': ['a', 'a']}, + tool_call_id='pyd_ai_tool_call_id__final_result', + ) + ], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='pyd_ai_tool_call_id__final_result', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ] + ) def test_tool_retry():