diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 9919eb85b..36aeae59b 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -913,13 +913,19 @@ async def _call_tools( async def handle_call_or_result( coro_or_task: Awaitable[ - tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None] + tuple[ + _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None + ] ] - | Task[tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]], + | Task[ + tuple[ + _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None + ] + ], index: int, ) -> _messages.HandleResponseEvent | None: try: - tool_part, tool_user_part = ( + tool_part, tool_user_content = ( (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() ) except exceptions.CallDeferred: @@ -928,10 +934,10 @@ async def handle_call_or_result( deferred_calls_by_index[index] = 'unapproved' else: tool_parts_by_index[index] = tool_part - if tool_user_part: - user_parts_by_index[index] = tool_user_part + if tool_user_content: + user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content) - return _messages.FunctionToolResultEvent(tool_part) + return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content) if tool_manager.should_call_sequentially(tool_calls): for index, call in enumerate(tool_calls): @@ -971,7 +977,7 @@ async def _call_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, tool_call_result: DeferredToolResult | None, -) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]: try: if tool_call_result is None: tool_result = await tool_manager.handle_call(tool_call) @@ -1048,14 +1054,7 @@ async def _call_tool( metadata=tool_return.metadata, ) - user_part: _messages.UserPromptPart | None = None - if tool_return.content: - user_part = _messages.UserPromptPart( - content=tool_return.content, - part_kind='user-prompt', - ) - - return return_part, user_part + return return_part, tool_return.content or None @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 71e95c57e..89e82c652 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1654,6 +1654,9 @@ class FunctionToolResultEvent: _: KW_ONLY + content: str | Sequence[UserContent] | None = None + """The content that will be sent to the model as a UserPromptPart following the result.""" + event_kind: Literal['function_tool_result'] = 'function_tool_result' """Event type identifier, used as a discriminator.""" diff --git a/tests/test_dbos.py b/tests/test_dbos.py index dfe95580b..4f2abb897 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -359,7 +359,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -374,7 +374,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -435,7 +435,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ], diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5cdfd10db..5bf3be23a 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -19,6 +19,7 @@ FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, @@ -1575,3 +1576,52 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen FinalResultEvent(tool_name=None, tool_call_id=None), ] ) + + +async def test_stream_tool_returning_user_content(): + m = TestModel() + + agent = Agent(m) + assert agent.name is None + + @agent.tool_plain + async def get_image() -> ImageUrl: + return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg') + + events: list[AgentStreamEvent] = [] + + async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]): + async for event in stream: + events.append(event) + + await agent.run('Hello', event_stream_handler=event_stream_handler) + + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()), + ), + FunctionToolCallEvent(part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr())), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='get_image', + content='See file bd38f5', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + content=[ + 'This is file bd38f5:', + ImageUrl( + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg', + identifier='bd38f5', + ), + ], + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"get_image":"See ')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='file ')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='bd38f5"}')), + ] + ) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 1f54e9fab..9347c002a 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -424,7 +424,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -453,7 +453,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -544,7 +544,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ) ), ],