diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index b1a0dd1350..a7b9d8d415 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -221,6 +221,9 @@ async def run(  # noqa: C901
 
         next_message: _messages.ModelRequest | None = None
 
+        run_context: RunContext[DepsT] | None = None
+        instructions: str | None = None
+
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
@@ -242,15 +245,19 @@ async def run(  # noqa: C901
                         ctx.deps.prompt = combined_content
             elif isinstance(last_message, _messages.ModelResponse):
                 if self.user_prompt is None:
-                    # Skip ModelRequestNode and go directly to CallToolsNode
-                    return CallToolsNode[DepsT, NodeRunEndT](last_message)
+                    run_context = build_run_context(ctx)
+                    instructions = await ctx.deps.get_instructions(run_context)
+                    if not instructions:
+                        # If there's no new prompt or instructions, skip ModelRequestNode and go directly to CallToolsNode
+                        return CallToolsNode[DepsT, NodeRunEndT](last_message)
                 elif last_message.tool_calls:
                     raise exceptions.UserError(
                         'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
                     )
 
-        # Build the run context after `ctx.deps.prompt` has been updated
-        run_context = build_run_context(ctx)
+        if not run_context:
+            run_context = build_run_context(ctx)
+            instructions = await ctx.deps.get_instructions(run_context)
 
         if messages:
             await self._reevaluate_dynamic_prompts(messages, run_context)
@@ -267,7 +274,7 @@ async def run(  # noqa: C901
 
             next_message = _messages.ModelRequest(parts=parts)
 
-        next_message.instructions = await ctx.deps.get_instructions(run_context)
+        next_message.instructions = instructions
 
         if not messages and not next_message.parts and not next_message.instructions:
             raise exceptions.UserError('No message history, user prompt, or instructions provided')
@@ -592,8 +599,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
                     try:
                         self._next_node = await self._handle_text_response(ctx, text, text_processor)
                         return
-                    except ToolRetryError:
-                        # If the text from the preview response was invalid, ignore it.
+                    except ToolRetryError:  # pragma: no cover
+                        # If the text from the previous response was invalid, ignore it.
                         pass
 
                 # Go back to the model request node with an empty request, which means we'll essentially
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 29b54643f0..57fbcfcedd 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -2492,8 +2492,6 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
 
 
 def test_run_with_history_ending_on_model_response_without_tool_calls_or_user_prompt():
-    """Test that an agent run raises error when message_history ends on ModelResponse without tool calls or a new prompt."""
-
     def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
         return ModelResponse(parts=[TextPart(content='Final response')])  # pragma: no cover
 
@@ -2522,7 +2520,54 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
         ]
     )
 
-    assert result.new_messages() == []
+    assert result.new_messages() == snapshot([])
+
+
+async def test_message_history_ending_on_model_response_with_instructions():
+    model = TestModel(custom_output_text='James likes cars in general, especially the Fiat 126p that his parents had.')
+    summarize_agent = Agent(
+        model,
+        instructions="""
+        Summarize this conversation to include all important facts about the user and
+        what their interactions were about.
+        """,
+    )
+
+    message_history = [
+        ModelRequest(parts=[UserPromptPart(content='Hi, my name is James')]),
+        ModelResponse(parts=[TextPart(content='Nice to meet you, James.')]),
+        ModelRequest(parts=[UserPromptPart(content='I like cars')]),
+        ModelResponse(parts=[TextPart(content='I like them too. Sport cars?')]),
+        ModelRequest(parts=[UserPromptPart(content='No, cars in general.')]),
+        ModelResponse(parts=[TextPart(content='Awesome. Which one do you like most?')]),
+        ModelRequest(parts=[UserPromptPart(content='Fiat 126p')]),
+        ModelResponse(parts=[TextPart(content="That's an old one, isn't it?")]),
+        ModelRequest(parts=[UserPromptPart(content='Yes, it is. My parents had one.')]),
+        ModelResponse(parts=[TextPart(content='Cool. Was it fast?')]),
+    ]
+
+    result = await summarize_agent.run(message_history=message_history)
+
+    assert result.output == snapshot('James likes cars in general, especially the Fiat 126p that his parents had.')
+    assert result.new_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[],
+                instructions="""\
+Summarize this conversation to include all important facts about the user and
+        what their interactions were about.\
+""",
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='James likes cars in general, especially the Fiat 126p that his parents had.')],
+                usage=RequestUsage(input_tokens=73, output_tokens=43),
+                model_name='test',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+            ),
+        ]
+    )
 
 
 def test_empty_response():
@@ -4216,11 +4261,32 @@ class Output(BaseModel):
 
     result1 = agent1.run_sync('Hello')
 
-    # TestModel doesn't support structured output, so this will fail with retries
-    # But we can still verify that Agent2's instructions are used in retry requests
-    with capture_run_messages() as messages:
-        with pytest.raises(UnexpectedModelBehavior):
-            agent2.run_sync(message_history=result1.new_messages())
+    result2 = agent2.run_sync(message_history=result1.new_messages())
+    messages = result2.new_messages()
+
+    assert messages == snapshot(
+        [
+            ModelRequest(parts=[], instructions='Agent 2 instructions', run_id=IsStr()),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='final_result', args={'text': 'a'}, tool_call_id=IsStr())],
+                usage=RequestUsage(input_tokens=51, output_tokens=9),
+                model_name='test',
+                timestamp=IsDatetime(),
+                run_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='final_result',
+                        content='Final result processed.',
+                        tool_call_id=IsStr(),
+                        timestamp=IsDatetime(),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+        ]
+    )
 
     # Verify Agent2's retry requests used Agent2's instructions (not Agent1's)
     requests = [m for m in messages if isinstance(m, ModelRequest)]