From 5534961bfb9285c5fd4b95809a150d591147c891 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Sun, 10 Nov 2024 14:59:15 +0000
Subject: [PATCH 1/2] last_run_messages and retry on unknown

---
 pydantic_ai/agent.py    | 56 +++++++++++++++++++++--------------------
 tests/test_agent.py     | 37 ++++++++++++++++++++++++++-
 tests/test_streaming.py | 19 +++++++++++++-
 3 files changed, 83 insertions(+), 29 deletions(-)

diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index d5792a1522..997edd99d1 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -2,7 +2,7 @@
 
 import asyncio
 from collections.abc import AsyncIterator, Sequence
-from contextlib import asynccontextmanager, suppress
+from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import Any, Callable, Generic, Literal, cast, final, overload
 
@@ -44,6 +44,11 @@ class Agent(Generic[AgentDeps, ResponseData]):
     _default_deps: AgentDeps
     _max_result_retries: int
     _current_result_retry: int
+    last_run_messages: list[_messages.Message] | None = None
+    """The messages from the last run, useful when a run raises an exception.
+
+    Note: these are not used by the agent (e.g. in future runs); they are just stored for developers' convenience.
+    """
 
     def __init__(
         self,
@@ -100,6 +105,7 @@ async def run(
             deps = self._default_deps
 
         new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
+        self.last_run_messages = messages
 
         for retriever in self._retrievers.values():
             retriever.reset()
@@ -197,6 +203,7 @@ async def run_stream(
             deps = self._default_deps
 
         new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
+        self.last_run_messages = messages
 
         for retriever in self._retrievers.values():
             retriever.reset()
@@ -413,21 +420,17 @@ async def _handle_model_response(
                 raise exceptions.UnexpectedModelBehaviour('Received empty tool call message')
 
             # otherwise we run all retriever functions in parallel
+            messages: list[_messages.Message] = []
             tasks: list[asyncio.Task[_messages.Message]] = []
-            try:
-                for call in model_response.calls:
-                    retriever = self._retrievers.get(call.tool_name)
-                    if retriever is None:
-                        # should this be a retry error?
- raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + for call in model_response.calls: + if retriever := self._retrievers.get(call.tool_name): tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) - except BaseException: - await _cancel_tasks(tasks) - raise + else: + messages.append(self._unknown_tool(call.tool_name)) with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - new_messages = await asyncio.gather(*tasks) - return _utils.Either(right=new_messages) + messages += await asyncio.gather(*tasks) + return _utils.Either(right=messages) else: assert_never(model_response) @@ -479,16 +482,11 @@ async def _handle_streamed_model_response( # we now run all retriever functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] - try: - for call in structured_msg.calls: - retriever = self._retrievers.get(call.tool_name) - if retriever is None: - raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + for call in structured_msg.calls: + if retriever := self._retrievers.get(call.tool_name): tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) - except BaseException: - # otherwise we'll get warnings about coroutines not awaited - await _cancel_tasks(tasks) - raise + else: + messages.append(self._unknown_tool(call.tool_name)) with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): messages += await asyncio.gather(*tasks) @@ -516,9 +514,13 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]: messages.append(_messages.SystemPrompt(prompt)) return messages - -async def _cancel_tasks(tasks: list[asyncio.Task[_messages.Message]]) -> None: - for task in tasks: - task.cancel() - with suppress(asyncio.CancelledError): - await asyncio.gather(*tasks) + def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: + self._incr_result_retry() + names = list(self._retrievers.keys()) + if self._result_schema: + names.extend(self._result_schema.tool_names()) + if names: + msg = f'Available tools: {", ".join(names)}' + else: + msg = 'No tools available.' + return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}') diff --git a/tests/test_agent.py b/tests/test_agent.py index ac6eafa263..3c1ad52f12 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -455,5 +455,40 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: agent = Agent(FunctionModel(empty), deps=None) - with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"): + with pytest.raises(AgentError, match=r'model behavior: Exceeded maximum retries \(1\) for result validation'): agent.run_sync('Hello') + assert agent.last_run_messages == snapshot( + [ + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ModelStructuredResponse( + calls=[ToolCall(tool_name='foobar', args=ArgsJson(args_json='{}'))], timestamp=IsNow(tz=timezone.utc) + ), + RetryPrompt(content="Unknown tool name: 'foobar'. 
No tools available.", timestamp=IsNow(tz=timezone.utc)), + ModelStructuredResponse( + calls=[ToolCall(tool_name='foobar', args=ArgsJson(args_json='{}'))], timestamp=IsNow(tz=timezone.utc) + ), + ] + ) + + +def test_unknown_retriever_fix(): + def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse: + if len(m) > 1: + return ModelTextResponse(content='success') + else: + return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')]) + + agent = Agent(FunctionModel(empty), deps=None) + + result = agent.run_sync('Hello') + assert result.response == 'success' + assert result.all_messages() == snapshot( + [ + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ModelStructuredResponse( + calls=[ToolCall(tool_name='foobar', args=ArgsJson(args_json='{}'))], timestamp=IsNow(tz=timezone.utc) + ), + RetryPrompt(content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)), + ModelTextResponse(content='success', timestamp=IsNow(tz=timezone.utc)), + ] + ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7d9179aa99..5db52bfae3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -14,6 +14,7 @@ Message, ModelStructuredResponse, ModelTextResponse, + RetryPrompt, ToolCall, ToolReturn, UserPrompt, @@ -242,6 +243,22 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) - with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"): + @agent.retriever_plain + async def ret_a(x: str) -> str: + return x + + with pytest.raises(AgentError, match=r'model behavior: Exceeded maximum retries \(1\) for result validation'): async with agent.run_stream('hello'): pass + assert agent.last_run_messages == snapshot( + [ + UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)), + ModelStructuredResponse( + calls=[ToolCall(tool_name='foobar', args=ArgsJson(args_json='{}'))], timestamp=IsNow(tz=timezone.utc) + ), + RetryPrompt( + content="Unknown tool name: 'foobar'. 
Available tools: ret_a, final_result", + timestamp=IsNow(tz=timezone.utc), + ), + ] + ) From 7015549c71c92d5a61be42cc0c62ff6377b7ed54 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 10 Nov 2024 15:06:13 +0000 Subject: [PATCH 2/2] remove AgentError --- pydantic_ai/__init__.py | 4 +- pydantic_ai/agent.py | 142 +++++++++++++++++------------------- pydantic_ai/exceptions.py | 50 +------------ tests/models/test_gemini.py | 21 ++---- tests/models/test_openai.py | 4 +- tests/test_agent.py | 6 +- tests/test_streaming.py | 12 +-- 7 files changed, 83 insertions(+), 156 deletions(-) diff --git a/pydantic_ai/__init__.py b/pydantic_ai/__init__.py index 1dea86b60f..5d9c31e1f1 100644 --- a/pydantic_ai/__init__.py +++ b/pydantic_ai/__init__.py @@ -2,7 +2,7 @@ from .agent import Agent from .call_typing import CallContext -from .exceptions import AgentError, ModelRetry, UnexpectedModelBehaviour, UserError +from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError -__all__ = 'Agent', 'AgentError', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__' +__all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__' __version__ = version('pydantic_ai') diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 997edd99d1..095f98aa4c 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -7,7 +7,6 @@ from typing import Any, Callable, Generic, Literal, cast, final, overload import logfire_api -from pydantic import ValidationError from typing_extensions import assert_never from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result @@ -120,39 +119,34 @@ async def run( model_name=model_used.name(), ) as run_span: run_step = 0 - try: - while True: - run_step += 1 - with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span: - model_response, request_cost = await agent_model.request(messages) - model_req_span.set_attribute('response', model_response) - model_req_span.set_attribute('cost', request_cost) - model_req_span.message = f'model request -> {model_response.role}' - - messages.append(model_response) - cost += request_cost - - with _logfire.span('handle model response') as handle_span: - either = await self._handle_model_response(model_response, deps) - - if left := either.left: - # left means return a streamed result - run_span.set_attribute('all_messages', messages) - run_span.set_attribute('cost', cost) - handle_span.set_attribute('result', left.value) - handle_span.message = 'handle model response -> final result' - return result.RunResult(messages, new_message_index, left.value, cost) - else: - # right means continue the conversation - tool_responses = either.right - handle_span.set_attribute('tool_responses', tool_responses) - response_msgs = ' '.join(m.role for m in tool_responses) - handle_span.message = f'handle model response -> {response_msgs}' - messages.extend(tool_responses) - except (ValidationError, exceptions.UnexpectedModelBehaviour) as e: - run_span.set_attribute('messages', messages) - # noinspection PyTypeChecker - raise exceptions.AgentError(messages, model_used) from e + while True: + run_step += 1 + with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span: + model_response, request_cost = await agent_model.request(messages) + model_req_span.set_attribute('response', model_response) + model_req_span.set_attribute('cost', request_cost) + model_req_span.message = f'model request -> 
{model_response.role}'
+
+                    messages.append(model_response)
+                    cost += request_cost
+
+                with _logfire.span('handle model response') as handle_span:
+                    either = await self._handle_model_response(model_response, deps)
+
+                    if left := either.left:
+                        # left means return the final result
+                        run_span.set_attribute('all_messages', messages)
+                        run_span.set_attribute('cost', cost)
+                        handle_span.set_attribute('result', left.value)
+                        handle_span.message = 'handle model response -> final result'
+                        return result.RunResult(messages, new_message_index, left.value, cost)
+                    else:
+                        # right means continue the conversation
+                        tool_responses = either.right
+                        handle_span.set_attribute('tool_responses', tool_responses)
+                        response_msgs = ' '.join(m.role for m in tool_responses)
+                        handle_span.message = f'handle model response -> {response_msgs}'
+                        messages.extend(tool_responses)
 
     def run_sync(
         self,
@@ -218,49 +212,43 @@ async def run_stream(
             model_name=model_used.name(),
         ) as run_span:
             run_step = 0
-            try:
-                while True:
-                    run_step += 1
-                    with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                        async with agent_model.request_stream(messages) as model_response:
-                            model_req_span.set_attribute('response_type', model_response.__class__.__name__)
-                            # We want to end the "model request" span here, but we can't exit the context manager
-                            # in the traditional way
-                            model_req_span.__exit__(None, None, None)
-
-                            with _logfire.span('handle model response') as handle_span:
-                                either = await self._handle_streamed_model_response(model_response, deps)
-
-                                if left := either.left:
-                                    # left means return a streamed result
-                                    result_stream = left.value
-                                    run_span.set_attribute('all_messages', messages)
-                                    handle_span.set_attribute('result_type', result_stream.__class__.__name__)
-                                    handle_span.message = 'handle model response -> final result'
-                                    yield result.StreamedRunResult(
-                                        messages,
-                                        new_message_index,
-                                        cost,
-                                        result_stream,
-                                        self._result_schema,
-                                        deps,
-                                        self._result_validators,
-                                    )
-                                    return
-                                else:
-                                    # right means continue the conversation
-                                    tool_responses = either.right
-                                    handle_span.set_attribute('tool_responses', tool_responses)
-                                    response_msgs = ' '.join(m.role for m in tool_responses)
-                                    handle_span.message = f'handle model response -> {response_msgs}'
-                                    messages.extend(tool_responses)
-                                # the model_response should have been fully streamed by now, we can add it's cost
-                                cost += model_response.cost()
-
-            except exceptions.UnexpectedModelBehaviour as e:
-                run_span.set_attribute('messages', messages)
-                # noinspection PyTypeChecker
-                raise exceptions.AgentError(messages, model_used) from e
+            while True:
+                run_step += 1
+                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                    async with agent_model.request_stream(messages) as model_response:
+                        model_req_span.set_attribute('response_type', model_response.__class__.__name__)
+                        # We want to end the "model request" span here, but we can't exit the context manager
+                        # in the traditional way
+                        model_req_span.__exit__(None, None, None)
+
+                        with _logfire.span('handle model response') as handle_span:
+                            either = await self._handle_streamed_model_response(model_response, deps)
+
+                            if left := either.left:
+                                # left means return a streamed result
+                                result_stream = left.value
+                                run_span.set_attribute('all_messages', messages)
+                                handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+                                handle_span.message = 'handle model response -> final result'
+                                yield result.StreamedRunResult(
+                                    messages,
+                                    new_message_index,
+                                    
cost,
+                                    result_stream,
+                                    self._result_schema,
+                                    deps,
+                                    self._result_validators,
+                                )
+                                return
+                            else:
+                                # right means continue the conversation
+                                tool_responses = either.right
+                                handle_span.set_attribute('tool_responses', tool_responses)
+                                response_msgs = ' '.join(m.role for m in tool_responses)
+                                handle_span.message = f'handle model response -> {response_msgs}'
+                                messages.extend(tool_responses)
+                            # the model_response should have been fully streamed by now, so we can add its cost
+                            cost += model_response.cost()
 
     def system_prompt(
         self, func: _system_prompt.SystemPromptFunc[AgentDeps]
diff --git a/pydantic_ai/exceptions.py b/pydantic_ai/exceptions.py
index c850adbf2a..3af9db0ca3 100644
--- a/pydantic_ai/exceptions.py
+++ b/pydantic_ai/exceptions.py
@@ -1,16 +1,8 @@
 from __future__ import annotations as _annotations
 
 import json
-from typing import TYPE_CHECKING
 
-from pydantic import ValidationError
-
-from . import messages
-
-if TYPE_CHECKING:
-    from .models import Model
-
-__all__ = 'ModelRetry', 'AgentError', 'UserError', 'UnexpectedModelBehaviour'
+__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehaviour'
 
 
 class ModelRetry(Exception):
@@ -26,46 +18,6 @@ def __init__(self, message: str):
         super().__init__(message)
 
 
-class AgentError(RuntimeError):
-    """Exception raised when an Agent run fails due to a problem with the LLM being used or.
-
-    This exception should always have a cause which you can access to find out what went wrong, it exists so you
-    can access the history of messages when the error occurred.
-    """
-
-    history: list[messages.Message]
-    agent_name: str
-
-    def __init__(self, history: list[messages.Message], model: Model):
-        self.history = history
-        self.model_name = model.name()
-        super().__init__(f'Error while running model {self.model_name}')
-
-    def cause(self) -> ValidationError | UnexpectedModelBehaviour:
-        """This is really just typing super and improved find-ability for `Exception.__cause__`."""
-        cause = self.__cause__
-        if isinstance(cause, (ValidationError, UnexpectedModelBehaviour)):
-            return cause
-        else:
-            raise TypeError(
-                f'Unexpected cause type for AgentError: {type(cause)}, '
-                f'expected ValidationError or UnexpectedModelBehaviour'
-            )
-
-    def __str__(self) -> str:
-        count = len(self.history)
-        plural = 's' if count != 1 else ''
-        msg = f'{super().__str__()} after {count} message{plural}'
-        cause = self.__cause__
-        if isinstance(cause, UnexpectedModelBehaviour):
-            return f'{msg}\n  caused by unexpected model behavior: {cause.message}'
-        elif isinstance(cause, ValidationError):
-            summary = str(cause).split('\n', 1)[0]
-            return f'{msg}\n  caused by: {summary}'
-        else:
-            return msg
-
-
 class UserError(RuntimeError):
     """Error caused by a usage mistake by the application developer — You!"""
 
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index aae1f02357..26a1167245 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -12,7 +12,7 @@
 from pydantic import BaseModel
 from typing_extensions import Literal, TypeAlias
 
-from pydantic_ai import Agent, AgentError, ModelRetry, UnexpectedModelBehaviour, UserError, _utils
+from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour, UserError, _utils
 from pydantic_ai.messages import (
     ArgsObject,
     ModelStructuredResponse,
@@ -494,17 +494,10 @@ def handler(_: httpx.Request):
     m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
     agent = Agent(m, deps=None, system_prompt='this is the system prompt')
 
-    with pytest.raises(AgentError, match='Error while running model 
gemini-1.5-flash') as exc_info: + with pytest.raises(UnexpectedModelBehaviour) as exc_info: await agent.run('Hello') - assert str(exc_info.value) == snapshot( - 'Error while running model gemini-1.5-flash after 2 messages\n' - ' caused by unexpected model behavior: Unexpected response from gemini 401' - ) - - cause = exc_info.value.cause() - assert isinstance(cause, UnexpectedModelBehaviour) - assert str(cause) == snapshot('Unexpected response from gemini 401, body:\ninvalid request') + assert str(exc_info.value) == snapshot('Unexpected response from gemini 401, body:\ninvalid request') async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient): @@ -525,12 +518,10 @@ async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient): gemini_client = get_gemini_client(response) m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m, deps=None) - with pytest.raises(AgentError, match='Error while running model gemini-1.5-flash') as exc_info: + with pytest.raises(UnexpectedModelBehaviour) as exc_info: await agent.run('Hello') - cause = exc_info.value.cause() - assert isinstance(cause, UnexpectedModelBehaviour) - assert str(cause) == snapshot( + assert str(exc_info.value) == snapshot( 'Unsupported response from Gemini, expected all parts to be function calls or text, got: ' "[{'text': 'foo'}, {'function_call': {'name': 'get_location', 'args': {'loc_name': 'San Fransisco'}}}]" ) @@ -565,7 +556,7 @@ async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): gemini_client = get_gemini_client(stream) m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m, deps=None) - with pytest.raises(AgentError, match='caused by unexpected model behavior: Streamed response ended without con'): + with pytest.raises(UnexpectedModelBehaviour, match='Streamed response ended without con'): async with agent.run_stream('Hello'): pass diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 5634cb5e88..b1c19cb7ed 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -23,7 +23,7 @@ from openai.types.completion_usage import CompletionUsage, PromptTokensDetails from typing_extensions import TypedDict -from pydantic_ai import Agent, AgentError, ModelRetry, _utils +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour, _utils from pydantic_ai.messages import ( ArgsJson, ModelStructuredResponse, @@ -424,6 +424,6 @@ async def test_no_content(): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, deps=None, result_type=MyTypedDict) - with pytest.raises(AgentError, match='caused by unexpected model behavior: Streamed response ended without con'): + with pytest.raises(UnexpectedModelBehaviour, match='Streamed response ended without con'): async with agent.run_stream(''): pass diff --git a/tests/test_agent.py b/tests/test_agent.py index 3c1ad52f12..4d845a4c3e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,7 +6,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, AgentError, CallContext, ModelRetry +from pydantic_ai import Agent, CallContext, ModelRetry, UnexpectedModelBehaviour from pydantic_ai.messages import ( ArgsJson, ArgsObject, @@ -445,7 +445,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: agent = Agent(FunctionModel(empty), deps=None) - with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'): + with 
pytest.raises(UnexpectedModelBehaviour, match='Received empty tool call message'): agent.run_sync('Hello') @@ -455,7 +455,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: agent = Agent(FunctionModel(empty), deps=None) - with pytest.raises(AgentError, match=r'model behavior: Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehaviour, match=r'Exceeded maximum retries \(1\) for result validation'): agent.run_sync('Hello') assert agent.last_run_messages == snapshot( [ diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5db52bfae3..da38cb03a7 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai import Agent, AgentError, UserError +from pydantic_ai import Agent, UnexpectedModelBehaviour, UserError from pydantic_ai.messages import ( ArgsJson, ArgsObject, @@ -150,14 +150,10 @@ def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]: agent = Agent(FunctionModel(stream_function=text_stream), deps=None, result_type=tuple[str, str]) - with pytest.raises(AgentError) as exc_info: + with pytest.raises(UnexpectedModelBehaviour, match=r'Exceeded maximum retries \(1\) for result validation'): async with agent.run_stream(''): pass - assert str(exc_info.value) == snapshot( - 'Error while running model function:stream-text_stream after 2 messages\n' - ' caused by unexpected model behavior: Exceeded maximum retries (1) for result validation' - ) assert call_index == 2 @@ -232,7 +228,7 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) - with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'): + with pytest.raises(UnexpectedModelBehaviour, match='Received empty tool call message'): async with agent.run_stream('hello'): pass @@ -247,7 +243,7 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab async def ret_a(x: str) -> str: return x - with pytest.raises(AgentError, match=r'model behavior: Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehaviour, match=r'Exceeded maximum retries \(1\) for result validation'): async with agent.run_stream('hello'): pass assert agent.last_run_messages == snapshot(
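
Taken together, the two patches change the failure mode exercised in the tests above: an
unknown tool call now becomes a RetryPrompt sent back to the model instead of raising
immediately, and once retries are exhausted the underlying UnexpectedModelBehaviour is
raised directly (no AgentError wrapper), with agent.last_run_messages left populated for
inspection. A minimal sketch of that flow, reusing the helpers from the tests above; the
model function, its name, and the FunctionModel import path are assumptions based on the
test modules, not part of the patches:

    from pydantic_ai import Agent, UnexpectedModelBehaviour
    from pydantic_ai.messages import Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
    from pydantic_ai.models.function import AgentInfo, FunctionModel


    def always_unknown_tool(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
        # a model that keeps calling a tool that doesn't exist
        return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')])


    agent = Agent(FunctionModel(always_unknown_tool), deps=None)

    try:
        agent.run_sync('Hello')
    except UnexpectedModelBehaviour:
        # the first unknown call produced a RetryPrompt; the second exhausted the
        # single allowed retry (per the tests above), so the error is raised
        # directly and the full history stays on the agent for debugging
        for message in agent.last_run_messages or []:
            print(message.role)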