Allow user_prompt in HITL #3528
base: main
pydantic_ai_slim/pydantic_ai/_agent_graph.py:

```diff
@@ -307,10 +307,6 @@ async def _handle_deferred_tool_results(  # noqa: C901
             raise exceptions.UserError(
                 'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
             )
-        if self.user_prompt is not None:
-            raise exceptions.UserError(
-                'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
-            )

         tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
         tool_call_results = {}
```
```diff
@@ -338,7 +334,9 @@ async def _handle_deferred_tool_results(  # noqa: C901
                     tool_call_results[part.tool_call_id] = 'skip'

         # Skip ModelRequestNode and go directly to CallToolsNode
-        return CallToolsNode[DepsT, NodeRunEndT](last_model_response, tool_call_results=tool_call_results)
+        return CallToolsNode[DepsT, NodeRunEndT](
+            last_model_response, tool_call_results=tool_call_results, user_prompt=self.user_prompt
+        )

     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
```
```diff
@@ -543,6 +541,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):

     model_response: _messages.ModelResponse
     tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
+    user_prompt: str | Sequence[_messages.UserContent] | None = None

     _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
```
Collaborator (on the new `user_prompt` field): Let's add a docstring to clarify that this user prompt will only be sent to the model along with tool call results if the model response being processed has tool calls. If it had final output, this user prompt is ignored.
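A docstring along those lines might read roughly as follows (a sketch only, paraphrasing the reviewer's wording; not committed code):

```python
user_prompt: str | Sequence[_messages.UserContent] | None = None
"""User prompt to send to the model together with the tool call results.

Only used when the model response being processed contains tool calls that
get executed; if the response contained final output, this prompt is ignored.
"""
```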
```diff
@@ -723,6 +722,10 @@ async def _handle_tool_calls(
             final_result = output_final_result[0]
             self._next_node = self._handle_final_result(ctx, final_result, output_parts)
         else:
+            # Add user prompt if provided, after all tool return parts
+            if self.user_prompt is not None:
+                output_parts.append(_messages.UserPromptPart(self.user_prompt))
+
             instructions = await ctx.deps.get_instructions(run_context)
             self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
                 _messages.ModelRequest(parts=output_parts, instructions=instructions)
```
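Taken together, these changes let a caller resume a human-in-the-loop run by supplying tool approvals and a fresh user prompt in a single call. A minimal sketch of the resulting call pattern, mirroring the new test below (the `update_file` tool and the approval flow are borrowed from that test):

```python
# First run ends with DeferredToolRequests output: the model asked to call
# a tool that requires human approval.
result = await agent.run('Update .env file')
requests = result.output  # DeferredToolRequests

# Second run: approve the deferred tool call AND send a new user prompt.
# Before this PR, combining the two raised a UserError.
results = DeferredToolResults(approvals={requests.approvals[0].tool_call_id: True})
result2 = await agent.run(
    'continue with the operation',
    message_history=result.all_messages(),
    deferred_tool_results=results,
)
```

Per the diff, the new `UserPromptPart` is appended after all tool return parts in the next `ModelRequest`, so the model sees the tool results first and the follow-up prompt last.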
Tests:

```diff
@@ -5583,15 +5583,6 @@ async def test_run_with_deferred_tool_results_errors():
             deferred_tool_results=DeferredToolResults(approvals={'create_file': True}),
         )

-    with pytest.raises(
-        UserError, match='Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
-    ):
-        await agent.run(
-            'Hello again',
-            message_history=message_history,
-            deferred_tool_results=DeferredToolResults(approvals={'create_file': True}),
-        )

     message_history: list[ModelMessage] = [
         ModelRequest(parts=[UserPromptPart(content='Hello')]),
         ModelResponse(
```
```diff
@@ -5628,6 +5619,57 @@
         )


+async def test_user_prompt_with_deferred_tool_results():
+    """Test that user_prompt can be provided alongside deferred_tool_results."""
+    from pydantic_ai.exceptions import ApprovalRequired
+
+    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        # First call: model requests tool approval
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart(
+                        tool_name='update_file', tool_call_id='update_file_1', args={'path': '.env', 'content': ''}
+                    ),
+                ]
+            )
+        # Second call: model responds to tool results and user prompt
+        else:
+            # Verify we received both tool results and user prompt
+            last_request = messages[-1]
+            assert isinstance(last_request, ModelRequest)
+            has_tool_return = any(isinstance(p, ToolReturnPart) for p in last_request.parts)
+            has_user_prompt = any(isinstance(p, UserPromptPart) for p in last_request.parts)
+            assert has_tool_return, 'Expected tool return part in request'
+            assert has_user_prompt, 'Expected user prompt part in request'
+
+            # Get user prompt content
+            user_prompt_content = next(p.content for p in last_request.parts if isinstance(p, UserPromptPart))
+            return ModelResponse(parts=[TextPart(f'Approved and {user_prompt_content}')])
+
+    agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+    @agent.tool
+    def update_file(ctx: RunContext, path: str, content: str) -> str:
+        if path == '.env' and not ctx.tool_call_approved:
+            raise ApprovalRequired
+        return f'File {path!r} updated'
+
+    # First run: get deferred tool requests
+    result = await agent.run('Update .env file')
+    assert isinstance(result.output, DeferredToolRequests)
+    assert len(result.output.approvals) == 1
+
+    messages = result.all_messages()
```
Collaborator (on `messages = result.all_messages()`): Can we include this in the test with a snapshot, as well as …
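If the test is tightened as suggested, the snapshot assertion could use inline-snapshot, which (as far as I can tell) this test suite already uses elsewhere; an empty `snapshot()` gets filled in by running pytest with `--inline-snapshot=create`:

```python
# Sketch: assert the full message history, including the ToolReturnPart for
# 'update_file' followed by the new UserPromptPart, via inline-snapshot.
from inline_snapshot import snapshot

assert result2.all_messages() == snapshot()  # populated on --inline-snapshot=create
```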
```diff
+    # Second run: provide approvals AND user prompt
+    results = DeferredToolResults(approvals={result.output.approvals[0].tool_call_id: True})
+    result2 = await agent.run('continue with the operation', message_history=messages, deferred_tool_results=results)
+
+    assert isinstance(result2.output, str)
+    assert 'continue with the operation' in result2.output
+
+
 def test_tool_requires_approval_error():
     agent = Agent('test')
```
I wonder if we can also remove the error here, now that we have `CallToolsNode.user_prompt`:

pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py, lines 246 to 256 in 5544e9f