Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .agent import Agent
from .call_typing import CallContext
from .exceptions import AgentError, ModelRetry, UnexpectedModelBehaviour, UserError
from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError

__all__ = 'Agent', 'AgentError', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__'
__all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__'
__version__ = version('pydantic_ai')
198 changes: 94 additions & 104 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import asyncio
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager, suppress
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, final, overload

import logfire_api
from pydantic import ValidationError
from typing_extensions import assert_never

from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result
Expand Down Expand Up @@ -44,6 +43,11 @@ class Agent(Generic[AgentDeps, ResponseData]):
_default_deps: AgentDeps
_max_result_retries: int
_current_result_retry: int
last_run_messages: list[_messages.Message] | None = None
"""The messages from the last run, useful when a run raised an exception.

Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
"""

def __init__(
self,
Expand Down Expand Up @@ -100,6 +104,7 @@ async def run(
deps = self._default_deps

new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages

for retriever in self._retrievers.values():
retriever.reset()
Expand All @@ -114,39 +119,34 @@ async def run(
model_name=model_used.name(),
) as run_span:
run_step = 0
try:
while True:
run_step += 1
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
model_response, request_cost = await agent_model.request(messages)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('cost', request_cost)
model_req_span.message = f'model request -> {model_response.role}'

messages.append(model_response)
cost += request_cost

with _logfire.span('handle model response') as handle_span:
either = await self._handle_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
run_span.set_attribute('all_messages', messages)
run_span.set_attribute('cost', cost)
handle_span.set_attribute('result', left.value)
handle_span.message = 'handle model response -> final result'
return result.RunResult(messages, new_message_index, left.value, cost)
else:
# right means continue the conversation
tool_responses = either.right
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
messages.extend(tool_responses)
except (ValidationError, exceptions.UnexpectedModelBehaviour) as e:
run_span.set_attribute('messages', messages)
# noinspection PyTypeChecker
raise exceptions.AgentError(messages, model_used) from e
while True:
run_step += 1
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
model_response, request_cost = await agent_model.request(messages)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('cost', request_cost)
model_req_span.message = f'model request -> {model_response.role}'

messages.append(model_response)
cost += request_cost

with _logfire.span('handle model response') as handle_span:
either = await self._handle_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
run_span.set_attribute('all_messages', messages)
run_span.set_attribute('cost', cost)
handle_span.set_attribute('result', left.value)
handle_span.message = 'handle model response -> final result'
return result.RunResult(messages, new_message_index, left.value, cost)
else:
# right means continue the conversation
tool_responses = either.right
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
messages.extend(tool_responses)

def run_sync(
self,
Expand Down Expand Up @@ -197,6 +197,7 @@ async def run_stream(
deps = self._default_deps

new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages

for retriever in self._retrievers.values():
retriever.reset()
Expand All @@ -211,49 +212,43 @@ async def run_stream(
model_name=model_used.name(),
) as run_span:
run_step = 0
try:
while True:
run_step += 1
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
async with agent_model.request_stream(messages) as model_response:
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
model_req_span.__exit__(None, None, None)

with _logfire.span('handle model response') as handle_span:
either = await self._handle_streamed_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
result_stream = left.value
run_span.set_attribute('all_messages', messages)
handle_span.set_attribute('result_type', result_stream.__class__.__name__)
handle_span.message = 'handle model response -> final result'
yield result.StreamedRunResult(
messages,
new_message_index,
cost,
result_stream,
self._result_schema,
deps,
self._result_validators,
)
return
else:
# right means continue the conversation
tool_responses = either.right
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
messages.extend(tool_responses)
# the model_response should have been fully streamed by now, we can add it's cost
cost += model_response.cost()

except exceptions.UnexpectedModelBehaviour as e:
run_span.set_attribute('messages', messages)
# noinspection PyTypeChecker
raise exceptions.AgentError(messages, model_used) from e
while True:
run_step += 1
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
async with agent_model.request_stream(messages) as model_response:
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
model_req_span.__exit__(None, None, None)

with _logfire.span('handle model response') as handle_span:
either = await self._handle_streamed_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
result_stream = left.value
run_span.set_attribute('all_messages', messages)
handle_span.set_attribute('result_type', result_stream.__class__.__name__)
handle_span.message = 'handle model response -> final result'
yield result.StreamedRunResult(
messages,
new_message_index,
cost,
result_stream,
self._result_schema,
deps,
self._result_validators,
)
return
else:
# right means continue the conversation
tool_responses = either.right
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
messages.extend(tool_responses)
# the model_response should have been fully streamed by now, we can add it's cost
cost += model_response.cost()

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
Expand Down Expand Up @@ -413,21 +408,17 @@ async def _handle_model_response(
raise exceptions.UnexpectedModelBehaviour('Received empty tool call message')

# otherwise we run all retriever functions in parallel
messages: list[_messages.Message] = []
tasks: list[asyncio.Task[_messages.Message]] = []
try:
for call in model_response.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
# should this be a retry error?
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
for call in model_response.calls:
if retriever := self._retrievers.get(call.tool_name):
tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
except BaseException:
await _cancel_tasks(tasks)
raise
else:
messages.append(self._unknown_tool(call.tool_name))

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
new_messages = await asyncio.gather(*tasks)
return _utils.Either(right=new_messages)
messages += await asyncio.gather(*tasks)
return _utils.Either(right=messages)
else:
assert_never(model_response)

Expand Down Expand Up @@ -479,16 +470,11 @@ async def _handle_streamed_model_response(

# we now run all retriever functions in parallel
tasks: list[asyncio.Task[_messages.Message]] = []
try:
for call in structured_msg.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
for call in structured_msg.calls:
if retriever := self._retrievers.get(call.tool_name):
tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
except BaseException:
# otherwise we'll get warnings about coroutines not awaited
await _cancel_tasks(tasks)
raise
else:
messages.append(self._unknown_tool(call.tool_name))

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
messages += await asyncio.gather(*tasks)
Expand Down Expand Up @@ -516,9 +502,13 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
messages.append(_messages.SystemPrompt(prompt))
return messages


async def _cancel_tasks(tasks: list[asyncio.Task[_messages.Message]]) -> None:
for task in tasks:
task.cancel()
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks)
def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
self._incr_result_retry()
names = list(self._retrievers.keys())
if self._result_schema:
names.extend(self._result_schema.tool_names())
if names:
msg = f'Available tools: {", ".join(names)}'
else:
msg = 'No tools available.'
return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
50 changes: 1 addition & 49 deletions pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
from __future__ import annotations as _annotations

import json
from typing import TYPE_CHECKING

from pydantic import ValidationError

from . import messages

if TYPE_CHECKING:
from .models import Model

__all__ = 'ModelRetry', 'AgentError', 'UserError', 'UnexpectedModelBehaviour'
__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehaviour'


class ModelRetry(Exception):
Expand All @@ -26,46 +18,6 @@ def __init__(self, message: str):
super().__init__(message)


class AgentError(RuntimeError):
"""Exception raised when an Agent run fails due to a problem with the LLM being used or.

This exception should always have a cause which you can access to find out what went wrong, it exists so you
can access the history of messages when the error occurred.
"""

history: list[messages.Message]
agent_name: str

def __init__(self, history: list[messages.Message], model: Model):
self.history = history
self.model_name = model.name()
super().__init__(f'Error while running model {self.model_name}')

def cause(self) -> ValidationError | UnexpectedModelBehaviour:
"""This is really just typing super and improved find-ability for `Exception.__cause__`."""
cause = self.__cause__
if isinstance(cause, (ValidationError, UnexpectedModelBehaviour)):
return cause
else:
raise TypeError(
f'Unexpected cause type for AgentError: {type(cause)}, '
f'expected ValidationError or UnexpectedModelBehaviour'
)

def __str__(self) -> str:
count = len(self.history)
plural = 's' if count != 1 else ''
msg = f'{super().__str__()} after {count} message{plural}'
cause = self.__cause__
if isinstance(cause, UnexpectedModelBehaviour):
return f'{msg}\n caused by unexpected model behavior: {cause.message}'
elif isinstance(cause, ValidationError):
summary = str(cause).split('\n', 1)[0]
return f'{msg}\n caused by: {summary}'
else:
return msg


class UserError(RuntimeError):
"""Error caused by a usage mistake by the application developer — You!"""

Expand Down
21 changes: 6 additions & 15 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, TypeAlias

from pydantic_ai import Agent, AgentError, ModelRetry, UnexpectedModelBehaviour, UserError, _utils
from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour, UserError, _utils
from pydantic_ai.messages import (
ArgsObject,
ModelStructuredResponse,
Expand Down Expand Up @@ -494,17 +494,10 @@ def handler(_: httpx.Request):
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None, system_prompt='this is the system prompt')

with pytest.raises(AgentError, match='Error while running model gemini-1.5-flash') as exc_info:
with pytest.raises(UnexpectedModelBehaviour) as exc_info:
await agent.run('Hello')

assert str(exc_info.value) == snapshot(
'Error while running model gemini-1.5-flash after 2 messages\n'
' caused by unexpected model behavior: Unexpected response from gemini 401'
)

cause = exc_info.value.cause()
assert isinstance(cause, UnexpectedModelBehaviour)
assert str(cause) == snapshot('Unexpected response from gemini 401, body:\ninvalid request')
assert str(exc_info.value) == snapshot('Unexpected response from gemini 401, body:\ninvalid request')


async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
Expand All @@ -525,12 +518,10 @@ async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
gemini_client = get_gemini_client(response)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None)
with pytest.raises(AgentError, match='Error while running model gemini-1.5-flash') as exc_info:
with pytest.raises(UnexpectedModelBehaviour) as exc_info:
await agent.run('Hello')

cause = exc_info.value.cause()
assert isinstance(cause, UnexpectedModelBehaviour)
assert str(cause) == snapshot(
assert str(exc_info.value) == snapshot(
'Unsupported response from Gemini, expected all parts to be function calls or text, got: '
"[{'text': 'foo'}, {'function_call': {'name': 'get_location', 'args': {'loc_name': 'San Fransisco'}}}]"
)
Expand Down Expand Up @@ -565,7 +556,7 @@ async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
gemini_client = get_gemini_client(stream)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None)
with pytest.raises(AgentError, match='caused by unexpected model behavior: Streamed response ended without con'):
with pytest.raises(UnexpectedModelBehaviour, match='Streamed response ended without con'):
async with agent.run_stream('Hello'):
pass

Expand Down
Loading
Loading