Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,19 @@ async def _call_tools(

async def handle_call_or_result(
coro_or_task: Awaitable[
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
]
]
| Task[tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]],
| Task[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
]
],
index: int,
) -> _messages.HandleResponseEvent | None:
try:
tool_part, tool_user_part = (
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
Expand All @@ -928,10 +934,10 @@ async def handle_call_or_result(
deferred_calls_by_index[index] = 'unapproved'
else:
tool_parts_by_index[index] = tool_part
if tool_user_part:
user_parts_by_index[index] = tool_user_part
if tool_user_content:
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)

return _messages.FunctionToolResultEvent(tool_part)
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
Expand Down Expand Up @@ -971,7 +977,7 @@ async def _call_tool(
tool_manager: ToolManager[DepsT],
tool_call: _messages.ToolCallPart,
tool_call_result: DeferredToolResult | None,
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]:
try:
if tool_call_result is None:
tool_result = await tool_manager.handle_call(tool_call)
Expand Down Expand Up @@ -1048,14 +1054,7 @@ async def _call_tool(
metadata=tool_return.metadata,
)

user_part: _messages.UserPromptPart | None = None
if tool_return.content:
user_part = _messages.UserPromptPart(
content=tool_return.content,
part_kind='user-prompt',
)

return return_part, user_part
return return_part, tool_return.content or None


@dataclasses.dataclass
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,9 @@ class FunctionToolResultEvent:

_: KW_ONLY

content: str | Sequence[UserContent] | None = None
"""The content that will be sent to the model as a UserPromptPart following the result."""

event_kind: Literal['function_tool_result'] = 'function_tool_result'
"""Event type identifier, used as a discriminator."""

Expand Down
6 changes: 3 additions & 3 deletions tests/test_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
BasicSpan(content='ctx.run_step=1'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand All @@ -374,7 +374,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
BasicSpan(content='ctx.run_step=1'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand Down Expand Up @@ -435,7 +435,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
BasicSpan(content='ctx.run_step=2'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand Down
50 changes: 50 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FinalResultEvent,
FunctionToolCallEvent,
FunctionToolResultEvent,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -1575,3 +1576,52 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
FinalResultEvent(tool_name=None, tool_call_id=None),
]
)


async def test_stream_tool_returning_user_content():
m = TestModel()

agent = Agent(m)
assert agent.name is None

@agent.tool_plain
async def get_image() -> ImageUrl:
return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg')

events: list[AgentStreamEvent] = []

async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
async for event in stream:
events.append(event)

await agent.run('Hello', event_stream_handler=event_stream_handler)

assert events == snapshot(
[
PartStartEvent(
index=0,
part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()),
),
FunctionToolCallEvent(part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr())),
FunctionToolResultEvent(
result=ToolReturnPart(
tool_name='get_image',
content='See file bd38f5',
tool_call_id=IsStr(),
timestamp=IsNow(tz=timezone.utc),
),
content=[
'This is file bd38f5:',
ImageUrl(
url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg',
identifier='bd38f5',
),
],
),
PartStartEvent(index=0, part=TextPart(content='')),
FinalResultEvent(tool_name=None, tool_call_id=None),
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"get_image":"See ')),
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='file ')),
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='bd38f5"}')),
]
)
6 changes: 3 additions & 3 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ async def test_complex_agent_run_in_workflow(
BasicSpan(content='ctx.run_step=1'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand Down Expand Up @@ -453,7 +453,7 @@ async def test_complex_agent_run_in_workflow(
BasicSpan(content='ctx.run_step=1'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand Down Expand Up @@ -544,7 +544,7 @@ async def test_complex_agent_run_in_workflow(
BasicSpan(content='ctx.run_step=2'),
BasicSpan(
content=IsStr(
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
)
),
],
Expand Down