1 change: 1 addition & 0 deletions docs/api/agent.md
@@ -4,6 +4,7 @@
options:
members:
- __init__
- name
- run
- run_sync
- run_stream
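To make the newly documented `name` member concrete, a rough usage sketch follows (an assumption-level example, not taken from the PR: it uses the `'test'` model shortcut seen in the test suite below, the top-level `pydantic_ai.Agent` import, and made-up variable names):

from pydantic_ai import Agent

# Explicit name: used for logging (e.g. the '{agent.name} run ...' spans) from the start.
support_agent = Agent('test', name='support_agent')
assert support_agent.name == 'support_agent'

# No name given: on the first run the agent tries to infer its name from the
# caller's variable that references it, so it becomes 'billing_agent' here.
billing_agent = Agent('test')
assert billing_agent.name is None
billing_agent.run_sync('Hello')
assert billing_agent.name == 'billing_agent'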
46 changes: 43 additions & 3 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -2,9 +2,11 @@

import asyncio
import dataclasses
import inspect
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from types import FrameType
from typing import Any, Callable, Generic, cast, final, overload

import logfire_api
@@ -54,6 +56,11 @@ class Agent(Generic[AgentDeps, ResultData]):
# dataclass fields mostly for my sanity — knowing what attributes are available
model: models.Model | models.KnownModelName | None
"""The default model configured for this agent."""
name: str | None
"""The name of the agent, used for logging.
If `None`, we try to infer the agent name from the call frame when the agent is first run.
"""
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
_allow_text_result: bool = field(repr=False)
@@ -79,6 +86,7 @@ def __init__(
result_type: type[ResultData] = str,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDeps] = NoneType,
name: str | None = None,
retries: int = 1,
result_tool_name: str = 'final_result',
result_tool_description: str | None = None,
@@ -98,6 +106,8 @@
parameterize the agent, and therefore get the best out of static type checking.
If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
or add a type hint `: Agent[None, <return type>]`.
name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
when the agent is first run.
retries: The default number of retries to allow before raising an error.
result_tool_name: The name of the tool to use for the final result.
result_tool_description: The description of the final result tool.
@@ -115,6 +125,7 @@ def __init__(
else:
self.model = models.infer_model(model)

self.name = name
self._result_schema = _result.ResultSchema[result_type].build(
result_type, result_tool_name, result_tool_description
)
@@ -139,6 +150,7 @@ async def run(
message_history: list[_messages.Message] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt in async mode.
@@ -147,16 +159,19 @@
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
model_used, custom_model, agent_model = await self._get_agent_model(model)

deps = self._get_deps(deps)

with _logfire.span(
'agent run {prompt=}',
'{agent.name} run {prompt=}',
prompt=user_prompt,
agent=self,
custom_model=custom_model,
@@ -208,6 +223,7 @@ def run_sync(
message_history: list[_messages.Message] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.
@@ -218,12 +234,17 @@
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
return loop.run_until_complete(
self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
)

@asynccontextmanager
async def run_stream(
@@ -233,6 +254,7 @@ async def run_stream(
message_history: list[_messages.Message] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
infer_name: bool = True,
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -241,16 +263,21 @@
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
"""
if infer_name and self.name is None:
# f_back because `asynccontextmanager` adds one frame
if frame := inspect.currentframe(): # pragma: no branch
self._infer_name(frame.f_back)
model_used, custom_model, agent_model = await self._get_agent_model(model)

deps = self._get_deps(deps)

with _logfire.span(
'agent run stream {prompt=}',
'{agent.name} run stream {prompt=}',
prompt=user_prompt,
agent=self,
custom_model=custom_model,
@@ -798,6 +825,19 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:
else:
return deps

def _infer_name(self, function_frame: FrameType | None) -> None:
"""Infer the agent name from the call frame.
Usage should be `self._infer_name(inspect.currentframe())`.
"""
assert self.name is None, 'Name already set'
if function_frame is not None: # pragma: no branch
if parent_frame := function_frame.f_back: # pragma: no branch
for name, item in parent_frame.f_locals.items():
if item is self:
self.name = name
return


@dataclass
class _MarkFinalResult(Generic[ResultData]):
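For context, here is a minimal standalone sketch of the frame-inspection pattern that the new `_infer_name` helper relies on (not part of this PR; the `Named` class and variable names are purely illustrative). The idea is to look at the frame that called the public method and search its locals for a variable bound to the instance itself:

import inspect


class Named:
    def __init__(self) -> None:
        self.name: str | None = None

    def run(self) -> None:
        # If no name was set, scan the caller's local variables for the one
        # bound to this instance and adopt that variable's name.
        if self.name is not None:
            return
        frame = inspect.currentframe()
        if frame is not None and frame.f_back is not None:
            for var_name, value in frame.f_back.f_locals.items():
                if value is self:
                    self.name = var_name
                    return


my_thing = Named()
my_thing.run()
assert my_thing.name == 'my_thing'

This is also why `run_sync` above passes `infer_name=False` when delegating to `run`: once `run` is driven by the event loop it no longer sees the user's frame as its direct parent, so the synchronous wrapper performs the inference itself before delegating.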
39 changes: 39 additions & 0 deletions tests/test_agent.py
@@ -27,6 +27,8 @@

from .conftest import IsNow, TestEnv

pytestmark = pytest.mark.anyio


def test_result_tuple(set_event_loop: None):
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
@@ -69,7 +71,10 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:

agent = Agent(FunctionModel(return_model), result_type=Foo)

assert agent.name is None

result = agent.run_sync('Hello')
assert agent.name == 'agent'
assert isinstance(result.data, Foo)
assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
assert result.all_messages() == snapshot(
@@ -535,3 +540,37 @@ async def make_request() -> str:
for _ in range(2):
result = agent.run_sync('Hello')
assert result.data == '{"make_request":"200"}'


async def test_agent_name():
my_agent = Agent('test')

assert my_agent.name is None

await my_agent.run('Hello', infer_name=False)
assert my_agent.name is None

await my_agent.run('Hello')
assert my_agent.name == 'my_agent'


async def test_agent_name_already_set():
my_agent = Agent('test', name='fig_tree')

assert my_agent.name == 'fig_tree'

await my_agent.run('Hello')
assert my_agent.name == 'fig_tree'


async def test_agent_name_changes():
my_agent = Agent('test')

await my_agent.run('Hello')
assert my_agent.name == 'my_agent'

new_agent = my_agent
del my_agent

await new_agent.run('Hello')
assert new_agent.name == 'my_agent'
11 changes: 7 additions & 4 deletions tests/test_streaming.py
@@ -31,13 +31,15 @@
async def test_streamed_text_response():
m = TestModel()

agent = Agent(m)
test_agent = Agent(m)
assert test_agent.name is None

@agent.tool_plain
@test_agent.tool_plain
async def ret_a(x: str) -> str:
return f'{x}-apple'

async with agent.run_stream('Hello') as result:
async with test_agent.run_stream('Hello') as result:
assert test_agent.name == 'test_agent'
assert not result.is_structured
assert not result.is_complete
assert result.all_messages() == snapshot(
@@ -71,9 +73,10 @@ async def ret_a(x: str) -> str:
async def test_streamed_structured_response():
m = TestModel()

agent = Agent(m, result_type=tuple[str, str])
agent = Agent(m, result_type=tuple[str, str], name='fig_jam')

async with agent.run_stream('') as result:
assert agent.name == 'fig_jam'
assert result.is_structured
assert not result.is_complete
response = await result.get_data()
2 changes: 1 addition & 1 deletion tests/test_tools.py
@@ -335,7 +335,7 @@ def test_tool_return_conflict():


def test_init_ctx_tool_invalid():
def plain_tool(x: int) -> int:
def plain_tool(x: int) -> int: # pragma: no cover
return x + 1

m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'