Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
367d860
Add metadata support to deferred tool exceptions
cjohnhanson Nov 5, 2025
fbb0457
Reduce complexity of _call_tools function
cjohnhanson Nov 5, 2025
93c24ab
Fix typo in test_reporting.py snapshot
cjohnhanson Nov 6, 2025
42cf5b8
Address Douwe's review feedback on deferred tool metadata
cjohnhanson Nov 7, 2025
c7aaca9
Improve deferred tools metadata documentation example
cjohnhanson Nov 7, 2025
b23609a
Fix evaluation reporting snapshots for CI (COLUMNS=150)
cjohnhanson Nov 7, 2025
d4f1613
Replace ML training example with flight booking example
cjohnhanson Nov 7, 2025
8e6ac24
test: Add test support for deferred_tools_with_metadata example
cjohnhanson Nov 7, 2025
09a38a2
test: Complete deferred tool tests and remove dead branch
cjohnhanson Nov 7, 2025
c3c543f
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 10, 2025
bb91b35
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 10, 2025
e4b86a3
chore: Remove unused ML training test expectations and codespell ignore
cjohnhanson Nov 10, 2025
37f5790
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 10, 2025
e22754d
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 14, 2025
c5527af
Address PR review comments
cjohnhanson Nov 14, 2025
10c24a6
Fix lint errors and remove orphaned mocks
cjohnhanson Nov 14, 2025
f3fd801
Fix IsStr imports - use conftest version for type safety
cjohnhanson Nov 14, 2025
d3c7603
Update docs/deferred-tools.md
cjohnhanson Nov 17, 2025
d933a1d
Address Douwe's feedback
cjohnhanson Nov 17, 2025
fe54136
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 17, 2025
41791d5
Merge branch 'main' into feat/deferred-tool-metadata
cjohnhanson Nov 17, 2025
d366b1e
Restore critical test infrastructure
cjohnhanson Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions docs/deferred-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
@agent.tool
def update_file(ctx: RunContext, path: str, content: str) -> str:
if path in PROTECTED_FILES and not ctx.tool_call_approved:
raise ApprovalRequired
raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)!
return f'File {path!r} updated: {content!r}'


Expand Down Expand Up @@ -77,6 +77,7 @@ DeferredToolRequests(
tool_call_id='delete_file',
),
],
metadata={'update_file_dotenv': {'reason': 'protected'}},
)
"""

Expand Down Expand Up @@ -175,6 +176,8 @@ print(result.all_messages())
"""
```

1. The optional `metadata` parameter attaches arbitrary context to a deferred tool call. That context is exposed in `DeferredToolRequests.metadata`, keyed by `tool_call_id`.

_(This example is complete, it can be run "as is")_

## External Tool Execution
Expand Down Expand Up @@ -209,13 +212,13 @@ from pydantic_ai import (

@dataclass
class TaskResult:
tool_call_id: str
task_id: str
result: Any


async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
await asyncio.sleep(1)
return TaskResult(tool_call_id=tool_call_id, result=42)
return TaskResult(task_id=task_id, result=42)


agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
Expand All @@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []

@agent.tool
async def calculate_answer(ctx: RunContext, question: str) -> str:
assert ctx.tool_call_id is not None

task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)!
task_id = f'task_{len(tasks)}' # (1)!
task = asyncio.create_task(calculate_answer_task(task_id, question))
tasks.append(task)

raise CallDeferred
raise CallDeferred(metadata={'task_id': task_id}) # (2)!


async def main():
Expand All @@ -252,17 +254,19 @@ async def main():
)
],
approvals=[],
metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
)
"""

done, _ = await asyncio.wait(tasks) # (2)!
done, _ = await asyncio.wait(tasks) # (3)!
task_results = [task.result() for task in done]
task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
task_results_by_task_id = {result.task_id: result.result for result in task_results}

results = DeferredToolResults()
for call in requests.calls:
try:
result = task_results_by_tool_call_id[call.tool_call_id]
task_id = requests.metadata[call.tool_call_id]['task_id']
result = task_results_by_task_id[task_id]
except KeyError:
result = ModelRetry('No result for this tool call was found.')

Expand Down Expand Up @@ -324,8 +328,9 @@ async def main():
"""
```

1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
1. Generate a task ID that can be tracked independently of the tool call ID.
2. The optional `metadata` parameter carries the `task_id` so the tool call can be matched with its result later. The metadata is exposed in `DeferredToolRequests.metadata`, keyed by `tool_call_id`.
3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.

_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_

Expand Down
1 change: 1 addition & 0 deletions docs/toolsets.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ DeferredToolRequests(
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
),
],
metadata={},
)
"""

Expand Down
30 changes: 27 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]

deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
deferred_metadata: dict[str, dict[str, Any]] = {}

if calls_to_run:
async for event in _call_tools(
Expand All @@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
output_deferred_metadata=deferred_metadata,
):
yield event

Expand Down Expand Up @@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901
deferred_tool_requests = _output.DeferredToolRequests(
calls=deferred_calls['external'],
approvals=deferred_calls['unapproved'],
metadata=deferred_metadata,
)

final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
Expand All @@ -949,10 +952,12 @@ async def _call_tools(
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}

if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
Expand Down Expand Up @@ -987,10 +992,12 @@ async def handle_call_or_result(
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
except exceptions.CallDeferred as e:
deferred_calls_by_index[index] = 'external'
except exceptions.ApprovalRequired:
deferred_metadata_by_index[index] = e.metadata
except exceptions.ApprovalRequired as e:
deferred_calls_by_index[index] = 'unapproved'
deferred_metadata_by_index[index] = e.metadata
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
Expand Down Expand Up @@ -1028,8 +1035,25 @@ async def handle_call_or_result(
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])

_populate_deferred_calls(
tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
)


def _populate_deferred_calls(
tool_calls: list[_messages.ToolCallPart],
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
deferred_metadata_by_index: dict[int, dict[str, Any] | None],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> None:
"""Populate deferred calls and metadata from indexed mappings."""
for k in sorted(deferred_calls_by_index):
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
call = tool_calls[k]
output_deferred_calls[deferred_calls_by_index[k]].append(call)
metadata = deferred_metadata_by_index[k]
if metadata is not None:
output_deferred_metadata[call.tool_call_id] = metadata


async def _call_tool(
Expand Down
14 changes: 8 additions & 6 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class CallToolParams:

@dataclass
class _ApprovalRequired:
metadata: dict[str, Any] | None = None
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
metadata: dict[str, Any] | None = None
kind: Literal['call_deferred'] = 'call_deferred'


Expand Down Expand Up @@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
except ApprovalRequired:
return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ApprovalRequired as e:
return _ApprovalRequired(metadata=e.metadata)
except CallDeferred as e:
return _CallDeferred(metadata=e.metadata)
except ModelRetry as e:
return _ModelRetry(message=e.message)

def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
raise ApprovalRequired()
raise ApprovalRequired(metadata=result.metadata)
elif isinstance(result, _CallDeferred):
raise CallDeferred()
raise CallDeferred(metadata=result.metadata)
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
Expand Down
16 changes: 14 additions & 2 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,30 @@ class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.

See [tools docs](../deferred-tools.md#deferred-tools) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class ApprovalRequired(Exception):
"""Exception to raise when a tool call requires human-in-the-loop approval.

See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class UserError(RuntimeError):
Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class DeferredToolRequests:
"""Tool calls that require external execution."""
approvals: list[ToolCallPart] = field(default_factory=list)
"""Tool calls that require human-in-the-loop approval."""
metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Metadata for deferred tool calls, keyed by `tool_call_id`."""


@dataclass(kw_only=True)
Expand Down
8 changes: 2 additions & 6 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,9 +1712,7 @@ def my_tool(x: int) -> int:
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
)
assert await result.get_output() == snapshot(
DeferredToolRequests(
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
)
DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
assert responses == snapshot(
Expand Down Expand Up @@ -1757,9 +1755,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
messages = result.all_messages()
output = await result.get_output()
assert output == snapshot(
DeferredToolRequests(
approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
)
DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
)
assert result.is_complete

Expand Down
Loading