diff --git a/docs/api/agent.md b/docs/api/agent.md
index 085abe6c1b..c4b43af6c1 100644
--- a/docs/api/agent.md
+++ b/docs/api/agent.md
@@ -4,6 +4,7 @@
     options:
       members:
         - __init__
+        - name
         - run
         - run_sync
         - run_stream
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 5f8e2a8d78..7b83180760 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -2,9 +2,11 @@
 
 import asyncio
 import dataclasses
+import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
+from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
@@ -54,6 +56,11 @@ class Agent(Generic[AgentDeps, ResultData]):
     # dataclass fields mostly for my sanity — knowing what attributes are available
     model: models.Model | models.KnownModelName | None
     """The default model configured for this agent."""
+    name: str | None
+    """The name of the agent, used for logging.
+
+    If `None`, we try to infer the agent name from the call frame when the agent is first run.
+    """
     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
     _allow_text_result: bool = field(repr=False)
@@ -79,6 +86,7 @@ def __init__(
         result_type: type[ResultData] = str,
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDeps] = NoneType,
+        name: str | None = None,
         retries: int = 1,
         result_tool_name: str = 'final_result',
         result_tool_description: str | None = None,
@@ -98,6 +106,8 @@ def __init__(
                 parameterize the agent, and therefore get the best out of static type checking.
                 If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
                 or add a type hint `: Agent[None, <return type>]`.
+            name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
+                when the agent is first run.
             retries: The default number of retries to allow before raising an error.
             result_tool_name: The name of the tool to use for the final result.
             result_tool_description: The description of the final result tool.
@@ -115,6 +125,7 @@ def __init__(
         else:
             self.model = models.infer_model(model)
 
+        self.name = name
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
@@ -139,6 +150,7 @@ async def run(
         message_history: list[_messages.Message] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
 
@@ -147,16 +159,19 @@ async def run(
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
         """
+        if infer_name and self.name is None:
+            self._infer_name(inspect.currentframe())
         model_used, custom_model, agent_model = await self._get_agent_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            'agent run {prompt=}',
+            '{agent.name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
             custom_model=custom_model,
@@ -208,6 +223,7 @@ def run_sync(
         message_history: list[_messages.Message] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
 
@@ -218,12 +234,17 @@ def run_sync(
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
         """
+        if infer_name and self.name is None:
+            self._infer_name(inspect.currentframe())
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
+        return loop.run_until_complete(
+            self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
+        )
 
     @asynccontextmanager
     async def run_stream(
@@ -233,6 +254,7 @@ async def run_stream(
         message_history: list[_messages.Message] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
 
@@ -241,16 +263,21 @@ async def run_stream(
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
         """
+        if infer_name and self.name is None:
+            # f_back because `asynccontextmanager` adds one frame
+            if frame := inspect.currentframe():  # pragma: no branch
+                self._infer_name(frame.f_back)
         model_used, custom_model, agent_model = await self._get_agent_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            'agent run stream {prompt=}',
+            '{agent.name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
             custom_model=custom_model,
@@ -798,6 +825,19 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:
         else:
             return deps
 
+    def _infer_name(self, function_frame: FrameType | None) -> None:
+        """Infer the agent name from the call frame.
+
+        Usage should be `self._infer_name(inspect.currentframe())`.
+        """
+        assert self.name is None, 'Name already set'
+        if function_frame is not None:  # pragma: no branch
+            if parent_frame := function_frame.f_back:  # pragma: no branch
+                for name, item in parent_frame.f_locals.items():
+                    if item is self:
+                        self.name = name
+                        return
+
 
 @dataclass
 class _MarkFinalResult(Generic[ResultData]):
diff --git a/tests/test_agent.py b/tests/test_agent.py
index d4fc1ccf46..7bad9889d7 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -27,6 +27,8 @@
 
 from .conftest import IsNow, TestEnv
 
+pytestmark = pytest.mark.anyio
+
 
 def test_result_tuple(set_event_loop: None):
     def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
@@ -69,7 +71,10 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
 
     agent = Agent(FunctionModel(return_model), result_type=Foo)
 
+    assert agent.name is None
+
     result = agent.run_sync('Hello')
+    assert agent.name == 'agent'
     assert isinstance(result.data, Foo)
     assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
     assert result.all_messages() == snapshot(
@@ -535,3 +540,37 @@ async def make_request() -> str:
     for _ in range(2):
         result = agent.run_sync('Hello')
         assert result.data == '{"make_request":"200"}'
+
+
+async def test_agent_name():
+    my_agent = Agent('test')
+
+    assert my_agent.name is None
+
+    await my_agent.run('Hello', infer_name=False)
+    assert my_agent.name is None
+
+    await my_agent.run('Hello')
+    assert my_agent.name == 'my_agent'
+
+
+async def test_agent_name_already_set():
+    my_agent = Agent('test', name='fig_tree')
+
+    assert my_agent.name == 'fig_tree'
+
+    await my_agent.run('Hello')
+    assert my_agent.name == 'fig_tree'
+
+
+async def test_agent_name_changes():
+    my_agent = Agent('test')
+
+    await my_agent.run('Hello')
+    assert my_agent.name == 'my_agent'
+
+    new_agent = my_agent
+    del my_agent
+
+    await new_agent.run('Hello')
+    assert new_agent.name == 'my_agent'
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 24016281db..3102d0b667 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -31,13 +31,15 @@
 
 async def test_streamed_text_response():
     m = TestModel()
-    agent = Agent(m)
+    test_agent = Agent(m)
+    assert test_agent.name is None
 
-    @agent.tool_plain
+    @test_agent.tool_plain
     async def ret_a(x: str) -> str:
         return f'{x}-apple'
 
-    async with agent.run_stream('Hello') as result:
+    async with test_agent.run_stream('Hello') as result:
+        assert test_agent.name == 'test_agent'
         assert not result.is_structured
         assert not result.is_complete
         assert result.all_messages() == snapshot(
@@ -71,9 +73,10 @@ async def ret_a(x: str) -> str:
 
 async def test_streamed_structured_response():
     m = TestModel()
-    agent = Agent(m, result_type=tuple[str, str])
+    agent = Agent(m, result_type=tuple[str, str], name='fig_jam')
 
     async with agent.run_stream('') as result:
+        assert agent.name == 'fig_jam'
         assert result.is_structured
         assert not result.is_complete
         response = await result.get_data()
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 39905ae18a..7f56159650 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -335,7 +335,7 @@ def test_tool_return_conflict():
 
 
 def test_init_ctx_tool_invalid():
-    def plain_tool(x: int) -> int:
+    def plain_tool(x: int) -> int:  # pragma: no cover
        return x + 1
 
     m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'