21 changes: 14 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -221,6 +221,9 @@ async def run( # noqa: C901

next_message: _messages.ModelRequest | None = None

run_context: RunContext[DepsT] | None = None
instructions: str | None = None

if messages and (last_message := messages[-1]):
if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
# Drop last message from history and reuse its parts
@@ -242,15 +245,19 @@ async def run( # noqa: C901
ctx.deps.prompt = combined_content
elif isinstance(last_message, _messages.ModelResponse):
if self.user_prompt is None:
# Skip ModelRequestNode and go directly to CallToolsNode
return CallToolsNode[DepsT, NodeRunEndT](last_message)
run_context = build_run_context(ctx)
instructions = await ctx.deps.get_instructions(run_context)
if not instructions:
# If there's no new prompt or instructions, skip ModelRequestNode and go directly to CallToolsNode
return CallToolsNode[DepsT, NodeRunEndT](last_message)
elif last_message.tool_calls:
raise exceptions.UserError(
'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
)

# Build the run context after `ctx.deps.prompt` has been updated
run_context = build_run_context(ctx)
if not run_context:
run_context = build_run_context(ctx)
instructions = await ctx.deps.get_instructions(run_context)

if messages:
await self._reevaluate_dynamic_prompts(messages, run_context)
@@ -267,7 +274,7 @@ async def run( # noqa: C901

next_message = _messages.ModelRequest(parts=parts)

next_message.instructions = await ctx.deps.get_instructions(run_context)
next_message.instructions = instructions

if not messages and not next_message.parts and not next_message.instructions:
raise exceptions.UserError('No message history, user prompt, or instructions provided')
@@ -592,8 +599,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
try:
self._next_node = await self._handle_text_response(ctx, text, text_processor)
return
except ToolRetryError:
# If the text from the preview response was invalid, ignore it.
except ToolRetryError: # pragma: no cover
# If the text from the previous response was invalid, ignore it.
pass

# Go back to the model request node with an empty request, which means we'll essentially
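In plain terms, the change above means a run no longer short-circuits when the supplied message history ends on a ModelResponse: if the agent has instructions (or a new user prompt), a fresh model request is made instead of returning the last response as-is. A minimal sketch of the resulting behavior, not part of the diff, assuming TestModel and the public Agent API (the history and instructions below are illustrative, mirroring the new test):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.models.test import TestModel

# Agent with instructions but no new user prompt for this run.
agent = Agent(TestModel(), instructions='Summarize the conversation so far.')

history = [
    ModelRequest(parts=[UserPromptPart(content='Hi, my name is James')]),
    ModelResponse(parts=[TextPart(content='Nice to meet you, James.')]),
]

# Before this change the last ModelResponse was returned directly; now the
# instructions trigger a new request, so the model produces a fresh response.
result = agent.run_sync(message_history=history)
print(result.output)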
82 changes: 74 additions & 8 deletions tests/test_agent.py
@@ -2492,8 +2492,6 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes


def test_run_with_history_ending_on_model_response_without_tool_calls_or_user_prompt():
"""Test that an agent run raises error when message_history ends on ModelResponse without tool calls or a new prompt."""

def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover

@@ -2522,7 +2520,54 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
]
)

assert result.new_messages() == []
assert result.new_messages() == snapshot([])


async def test_message_history_ending_on_model_response_with_instructions():
model = TestModel(custom_output_text='James likes cars in general, especially the Fiat 126p that his parents had.')
summarize_agent = Agent(
model,
instructions="""
Summarize this conversation to include all important facts about the user and
what their interactions were about.
""",
)

message_history = [
ModelRequest(parts=[UserPromptPart(content='Hi, my name is James')]),
ModelResponse(parts=[TextPart(content='Nice to meet you, James.')]),
ModelRequest(parts=[UserPromptPart(content='I like cars')]),
ModelResponse(parts=[TextPart(content='I like them too. Sport cars?')]),
ModelRequest(parts=[UserPromptPart(content='No, cars in general.')]),
ModelResponse(parts=[TextPart(content='Awesome. Which one do you like most?')]),
ModelRequest(parts=[UserPromptPart(content='Fiat 126p')]),
ModelResponse(parts=[TextPart(content="That's an old one, isn't it?")]),
ModelRequest(parts=[UserPromptPart(content='Yes, it is. My parents had one.')]),
ModelResponse(parts=[TextPart(content='Cool. Was it fast?')]),
]

result = await summarize_agent.run(message_history=message_history)

assert result.output == snapshot('James likes cars in general, especially the Fiat 126p that his parents had.')
assert result.new_messages() == snapshot(
[
ModelRequest(
parts=[],
instructions="""\
Summarize this conversation to include all important facts about the user and
what their interactions were about.\
""",
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='James likes cars in general, especially the Fiat 126p that his parents had.')],
usage=RequestUsage(input_tokens=73, output_tokens=43),
model_name='test',
timestamp=IsDatetime(),
run_id=IsStr(),
),
]
)


def test_empty_response():
@@ -4216,11 +4261,32 @@ class Output(BaseModel):

result1 = agent1.run_sync('Hello')

# TestModel doesn't support structured output, so this will fail with retries
# But we can still verify that Agent2's instructions are used in retry requests
with capture_run_messages() as messages:
with pytest.raises(UnexpectedModelBehavior):
agent2.run_sync(message_history=result1.new_messages())
result2 = agent2.run_sync(message_history=result1.new_messages())
messages = result2.new_messages()

assert messages == snapshot(
[
ModelRequest(parts=[], instructions='Agent 2 instructions', run_id=IsStr()),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args={'text': 'a'}, tool_call_id=IsStr())],
usage=RequestUsage(input_tokens=51, output_tokens=9),
model_name='test',
timestamp=IsDatetime(),
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
]
)

# Verify Agent2's retry requests used Agent2's instructions (not Agent1's)
requests = [m for m in messages if isinstance(m, ModelRequest)]
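For context, here is a minimal sketch (not part of the diff) of the handoff scenario this last test now exercises, assuming two agents where the second configures its own instructions and structured output via output_type (names mirror the test): with this change, the second agent can continue a history that ends on the first agent's ModelResponse instead of raising UnexpectedModelBehavior.

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel


class Output(BaseModel):
    text: str


agent1 = Agent(TestModel(), instructions='Agent 1 instructions')
agent2 = Agent(TestModel(), instructions='Agent 2 instructions', output_type=Output)

result1 = agent1.run_sync('Hello')

# The history ends on agent1's ModelResponse; agent2's instructions mean a new
# request is sent, and its output tool produces structured output.
result2 = agent2.run_sync(message_history=result1.new_messages())
print(result2.output)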