diff --git a/docs/output.md b/docs/output.md index 182a753944..2fa3719b76 100644 --- a/docs/output.md +++ b/docs/output.md @@ -306,6 +306,54 @@ print(repr(result.output)) _(This example is complete, it can be run "as is")_ +##### Handling Multiple Output Tool Calls + +When a model calls multiple output tools in the same response (for example, when you have multiple output types in a union or list), the agent's `end_strategy` parameter controls whether all output tool functions are executed or only the first one: + +- `'early'` (default): Only the first output tool is executed, and additional output tool calls are skipped once a final result is found. This is the default behavior. +- `'exhaustive'`: All output tool functions are executed, even after a final result is found. The first tool's result is still used as the final output. + +This parameter also applies to [function tools](tools.md), not just output tools. + +```python {title="exhaustive_output.py"} +from pydantic import BaseModel + +from pydantic_ai import Agent, ToolOutput + +calls_made: list[str] = [] + +class ResultA(BaseModel): + value: str + +class ResultB(BaseModel): + value: str + +def process_a(result: ResultA) -> ResultA: + calls_made.append('A') + return result + +def process_b(result: ResultB) -> ResultB: + calls_made.append('B') + return result + +# With 'exhaustive' strategy, both functions will be called +agent = Agent( + 'openai:gpt-5', + output_type=[ + ToolOutput(process_a, name='return_a'), + ToolOutput(process_b, name='return_b'), + ], + end_strategy='exhaustive', # (1)! +) + +# If the model calls both output tools, both process_a and process_b will execute +# calls_made will be ['A', 'B'] +``` + +1. Setting `end_strategy='exhaustive'` ensures all output tool functions are executed, which can be useful for side effects like logging, metrics, or validation. + +_(This example is complete, it can be run "as is")_ + #### Native Output Native Output mode uses a model's native "Structured Outputs" feature (aka "JSON Schema response format"), where the model is forced to only output text matching the provided JSON schema. Note that this is not supported by all models, and sometimes comes with restrictions. For example, Anthropic does not support this at all, and Gemini cannot use tools at the same time as structured output, and attempting to do so will result in an error. diff --git a/docs/tools-advanced.md b/docs/tools-advanced.md index f01a243c11..b257764b85 100644 --- a/docs/tools-advanced.md +++ b/docs/tools-advanced.md @@ -378,6 +378,14 @@ If a tool requires sequential/serial execution, you can pass the [`sequential`][ Async functions are run on the event loop, while sync functions are offloaded to threads. To get the best performance, _always_ use an async function _unless_ you're doing blocking I/O (and there's no way to use a non-blocking library instead) or CPU-bound work (like `numpy` or `scikit-learn` operations), so that simple functions are not offloaded to threads unnecessarily. +!!! note "Handling tool calls after a final result" + When a model returns multiple tool calls including an [output tool](output.md) (which produces a final result), the agent's `end_strategy` parameter controls whether remaining function tools are executed: + + - `'early'` (default): Function tools are skipped once a final result is found + - `'exhaustive'`: All function tools are executed even after a final result is found + + This is useful when function tools have side effects (like logging or metrics) that should always execute. See [Handling Multiple Output Tool Calls](output.md#handling-multiple-output-tool-calls) for more details. + !!! note "Limiting tool executions" You can cap tool executions within a run using [`UsageLimits(tool_calls_limit=...)`](agents.md#usage-limits). The counter increments only after a successful tool invocation. Output tools (used for [structured output](output.md)) are not counted in the `tool_calls` metric. diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 186a386e4a..1cdd2d6c39 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -61,8 +61,10 @@ EndStrategy = Literal['early', 'exhaustive'] """The strategy for handling multiple tool calls when a final result is found. -- `'early'`: Stop processing other tool calls once a final result is found -- `'exhaustive'`: Process all tool calls even after finding a final result +- `'early'`: Stop processing other tool calls (both function tools and output tools) once a final result is found +- `'exhaustive'`: Process all tool calls (both function tools and output tools) even after finding a final result + +This applies to both function tools and output tools. """ DepsT = TypeVar('DepsT') OutputT = TypeVar('OutputT') @@ -833,24 +835,27 @@ async def process_tool_calls( # noqa: C901 # First, we handle output tool calls for call in tool_calls_by_kind['output']: - if final_result: - if final_result.tool_call_id == call.tool_call_id: - part = _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Final result processed.', - tool_call_id=call.tool_call_id, - ) - else: - yield _messages.FunctionToolCallEvent(call) - part = _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Output tool not used - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - yield _messages.FunctionToolResultEvent(part) - + # In case we got two tool calls with the same ID + if final_result and final_result.tool_call_id == call.tool_call_id: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) output_parts.append(part) - else: + # Early strategy is chosen and final result is already set + elif ctx.deps.end_strategy == 'early' and final_result: + yield _messages.FunctionToolCallEvent(call) + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Output tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + yield _messages.FunctionToolResultEvent(part) + output_parts.append(part) + # Early strategy is chosen and final result is not yet set + # Or exhaustive strategy is chosen + elif (ctx.deps.end_strategy == 'early' and not final_result) or ctx.deps.end_strategy == 'exhaustive': try: result_data = await tool_manager.handle_call(call) except exceptions.UnexpectedModelBehavior as e: @@ -872,7 +877,15 @@ async def process_tool_calls( # noqa: C901 tool_call_id=call.tool_call_id, ) output_parts.append(part) - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + + # Even in exhaustive mode, use the first output tool's result as the final result + if not final_result: + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + # Unknown strategy or invalid state + else: + assert False, ( + f'Unexpected state: end_strategy={ctx.deps.end_strategy!r}, final_result={final_result!r};' + ) # pragma: no cover # Then, we handle function tool calls calls_to_run: list[_messages.ToolCallPart] = [] diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4cd353b44a..e21af557d9 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -115,7 +115,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _name: str | None end_strategy: EndStrategy - """Strategy for handling tool calls when a final result is found.""" + """Strategy for handling tool calls when a final result is found. + + - `'early'`: Stop processing other tool calls once a final result is found (default) + - `'exhaustive'`: Process all tool calls even after finding a final result + + This applies to both function tools and output tools. + """ model_settings: ModelSettings | None """Optional model request settings to use for this agents's runs, by default. diff --git a/tests/test_agent.py b/tests/test_agent.py index c912334434..802b74bec2 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3132,7 +3132,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover ), ToolReturnPart( tool_name='final_result', - content='Output tool not used - a final result was already processed.', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), @@ -3163,6 +3163,174 @@ def deferred_tool(x: int) -> int: # pragma: no cover ] ) + def test_exhaustive_strategy_calls_all_output_tools(self): + """Test that 'exhaustive' strategy executes all output tool functions.""" + output_tools_called: list[str] = [] + + class FirstOutput(BaseModel): + value: str + + class SecondOutput(BaseModel): + value: str + + def process_first(output: FirstOutput) -> FirstOutput: + """Process first output.""" + output_tools_called.append('first') + return output + + def process_second(output: SecondOutput) -> SecondOutput: + """Process second output.""" + output_tools_called.append('second') + return output + + def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + return ModelResponse( + parts=[ + ToolCallPart('first_output', {'value': 'first'}), + ToolCallPart('second_output', {'value': 'second'}), + ], + ) + + agent = Agent( + FunctionModel(return_model), + output_type=[ + ToolOutput(process_first, name='first_output'), + ToolOutput(process_second, name='second_output'), + ], + end_strategy='exhaustive', + ) + + result = agent.run_sync('test exhaustive output tools') + + # Verify the result came from the first output tool + assert isinstance(result.output, FirstOutput) + assert result.output.value == 'first' + + # Verify both output tools were called + assert output_tools_called == ['first', 'second'] + + # Verify we got tool returns in the correct order + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='test exhaustive output tools', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='first_output', args={'value': 'first'}, tool_call_id=IsStr()), + ToolCallPart(tool_name='second_output', args={'value': 'second'}, tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=54, output_tokens=10), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='first_output', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart( + tool_name='second_output', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + ], + run_id=IsStr(), + ), + ] + ) + + def test_early_strategy_does_not_call_additional_output_tools(self): + """Test that 'early' strategy does not execute additional output tool functions.""" + output_tools_called: list[str] = [] + + class FirstOutput(BaseModel): + value: str + + class SecondOutput(BaseModel): + value: str + + def process_first(output: FirstOutput) -> FirstOutput: + """Process first output.""" + output_tools_called.append('first') + return output + + def process_second(output: SecondOutput) -> SecondOutput: # pragma: no cover + """Process second output.""" + output_tools_called.append('second') + return output + + def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + return ModelResponse( + parts=[ + ToolCallPart('first_output', {'value': 'first'}), + ToolCallPart('second_output', {'value': 'second'}), + ], + ) + + agent = Agent( + FunctionModel(return_model), + output_type=[ + ToolOutput(process_first, name='first_output'), + ToolOutput(process_second, name='second_output'), + ], + end_strategy='early', + ) + + result = agent.run_sync('test early output tools') + + # Verify the result came from the first output tool + assert isinstance(result.output, FirstOutput) + assert result.output.value == 'first' + + # Verify only the first output tool was called + assert output_tools_called == ['first'] + + # Verify we got tool returns in the correct order + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='test early output tools', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='first_output', args={'value': 'first'}, tool_call_id=IsStr()), + ToolCallPart(tool_name='second_output', args={'value': 'second'}, tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=54, output_tokens=10), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='first_output', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart( + tool_name='second_output', + content='Output tool not used - a final result was already processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + ], + run_id=IsStr(), + ), + ] + ) + def test_early_strategy_with_final_result_in_middle(self): """Test that 'early' strategy stops at first final result, regardless of position.""" tool_called = [] diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 0c6a46f3c0..b1d28ca9e8 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -12,6 +12,7 @@ import pytest from inline_snapshot import snapshot from pydantic import BaseModel +from pydantic_core import ErrorDetails from pydantic_ai import ( Agent, @@ -943,7 +944,7 @@ def another_tool(y: int) -> int: ), ToolReturnPart( tool_name='final_result', - content='Output tool not used - a final result was already processed.', + content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), @@ -1585,12 +1586,12 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf FunctionToolResultEvent( result=RetryPromptPart( content=[ - { - 'type': 'missing', - 'loc': ('value',), - 'msg': 'Field required', - 'input': {'bad_value': 'invalid'}, - } + ErrorDetails( + type='missing', + loc=('value',), + msg='Field required', + input={'bad_value': 'invalid'}, + ), ], tool_name='final_result', tool_call_id=IsStr(), @@ -2191,3 +2192,81 @@ async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> A chunks.append(response_data) assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]) + + +async def test_exhaustive_strategy_multiple_final_results_with_retry(): + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + # Return two final_result calls in one response + # The second one will fail validation + yield {1: DeltaToolCall('final_result', '{"value": "first"}')} + yield {2: DeltaToolCall('final_result', '{"value": 123}')} # Invalid - not a string + + agent = Agent(FunctionModel(stream_function=sf), output_type=OutputType, end_strategy='exhaustive') + + async with agent.run_stream('test retry') as result: + response = await result.get_output() + # Should use the first final result + assert response.value == snapshot('first') + + # Verify we got the expected message flow with retry for the second final_result + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='test retry', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args='{"value": "first"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='final_result', args='{"value": 123}', tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=50, output_tokens=7), + model_name='function::sf', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + RetryPromptPart( + content=[ + ErrorDetails( + type='string_type', + loc=('value',), + msg='Input should be a valid string', + input=123, + ) + ], + tool_name='final_result', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + ], + run_id=IsStr(), + ), + ] + ) + + +async def test_exhaustive_strategy_final_result_unexpected_behavior(): + class CustomOutputType(BaseModel): + value: str + + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + # Return two final_result calls where second one is malformed + yield {1: DeltaToolCall('final_result', '{"value": "first"}')} + yield {2: DeltaToolCall('final_result', 'not valid json')} # Malformed JSON + + agent = Agent( + FunctionModel(stream_function=sf), output_type=CustomOutputType, end_strategy='exhaustive', output_retries=0 + ) + + # Should raise because the second final_result has invalid JSON + with pytest.raises(UnexpectedModelBehavior): + async with agent.run_stream('test') as result: + await result.get_output()