From deb5d5170f6236561174e2a9f223c243a2d30922 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 19:27:34 +0000 Subject: [PATCH 1/4] adding Agent name --- pydantic_ai_examples/pydantic_model.py | 4 +-- pydantic_ai_slim/pydantic_ai/agent.py | 43 ++++++++++++++++++++++++-- tests/test_agent.py | 39 +++++++++++++++++++++++ tests/test_streaming.py | 11 ++++--- 4 files changed, 88 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index 9c654e9e1a..a247a37b26 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -25,9 +25,9 @@ class MyModel(BaseModel): model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o')) print(f'Using model: {model}') -agent = Agent(model, result_type=MyModel) +pydantic_agent = Agent(model, result_type=MyModel) if __name__ == '__main__': - result = agent.run_sync('The windy city in the US of A.') + result = pydantic_agent.run_sync('The windy city in the US of A.') print(result.data) print(result.cost()) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 5f8e2a8d78..383c7117de 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -2,9 +2,11 @@ import asyncio import dataclasses +import inspect from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field +from types import FrameType from typing import Any, Callable, Generic, cast, final, overload import logfire_api @@ -54,6 +56,8 @@ class Agent(Generic[AgentDeps, ResultData]): # dataclass fields mostly for my sanity — knowing what attributes are available model: models.Model | models.KnownModelName | None """The default model configured for this agent.""" + name: str | None + """The name of the agent, used for logging.""" _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False) _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) _allow_text_result: bool = field(repr=False) @@ -79,6 +83,7 @@ def __init__( result_type: type[ResultData] = str, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDeps] = NoneType, + name: str | None = None, retries: int = 1, result_tool_name: str = 'final_result', result_tool_description: str | None = None, @@ -98,6 +103,8 @@ def __init__( parameterize the agent, and therefore get the best out of static type checking. If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright or add a type hint `: Agent[None, ]`. + name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame + when the agent is first run. retries: The default number of retries to allow before raising an error. result_tool_name: The name of the tool to use for the final result. result_tool_description: The description of the final result tool. @@ -115,6 +122,7 @@ def __init__( else: self.model = models.infer_model(model) + self.name = name self._result_schema = _result.ResultSchema[result_type].build( result_type, result_tool_name, result_tool_description ) @@ -139,6 +147,7 @@ async def run( message_history: list[_messages.Message] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, + infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -147,16 +156,19 @@ async def run( message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) model_used, custom_model, agent_model = await self._get_agent_model(model) deps = self._get_deps(deps) with _logfire.span( - 'agent run {prompt=}', + '{agent.name} run {prompt=}', prompt=user_prompt, agent=self, custom_model=custom_model, @@ -208,6 +220,7 @@ def run_sync( message_history: list[_messages.Message] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, + infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. @@ -218,12 +231,17 @@ def run_sync( message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) loop = asyncio.get_event_loop() - return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps)) + return loop.run_until_complete( + self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False) + ) @asynccontextmanager async def run_stream( @@ -233,6 +251,7 @@ async def run_stream( message_history: list[_messages.Message] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, + infer_name: bool = True, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -241,16 +260,21 @@ async def run_stream( message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ + if infer_name and self.name is None: + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): + self._infer_name(frame.f_back) model_used, custom_model, agent_model = await self._get_agent_model(model) deps = self._get_deps(deps) with _logfire.span( - 'agent run stream {prompt=}', + '{agent.name} run stream {prompt=}', prompt=user_prompt, agent=self, custom_model=custom_model, @@ -798,6 +822,19 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps: else: return deps + def _infer_name(self, function_frame: FrameType | None) -> None: + """Infer the agent name from the call frame. + + Usage should be `self._infer_name(inspect.currentframe())`. + """ + assert self.name is None, 'Name already set' + if function_frame is not None: + if parent_frame := function_frame.f_back: + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + @dataclass class _MarkFinalResult(Generic[ResultData]): diff --git a/tests/test_agent.py b/tests/test_agent.py index d4fc1ccf46..7bad9889d7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -27,6 +27,8 @@ from .conftest import IsNow, TestEnv +pytestmark = pytest.mark.anyio + def test_result_tuple(set_event_loop: None): def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse: @@ -69,7 +71,10 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: agent = Agent(FunctionModel(return_model), result_type=Foo) + assert agent.name is None + result = agent.run_sync('Hello') + assert agent.name == 'agent' assert isinstance(result.data, Foo) assert result.data.model_dump() == {'a': 42, 'b': 'foo'} assert result.all_messages() == snapshot( @@ -535,3 +540,37 @@ async def make_request() -> str: for _ in range(2): result = agent.run_sync('Hello') assert result.data == '{"make_request":"200"}' + + +async def test_agent_name(): + my_agent = Agent('test') + + assert my_agent.name is None + + await my_agent.run('Hello', infer_name=False) + assert my_agent.name is None + + await my_agent.run('Hello') + assert my_agent.name == 'my_agent' + + +async def test_agent_name_already_set(): + my_agent = Agent('test', name='fig_tree') + + assert my_agent.name == 'fig_tree' + + await my_agent.run('Hello') + assert my_agent.name == 'fig_tree' + + +async def test_agent_name_changes(): + my_agent = Agent('test') + + await my_agent.run('Hello') + assert my_agent.name == 'my_agent' + + new_agent = my_agent + del my_agent + + await new_agent.run('Hello') + assert new_agent.name == 'my_agent' diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 24016281db..3102d0b667 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -31,13 +31,15 @@ async def test_streamed_text_response(): m = TestModel() - agent = Agent(m) + test_agent = Agent(m) + assert test_agent.name is None - @agent.tool_plain + @test_agent.tool_plain async def ret_a(x: str) -> str: return f'{x}-apple' - async with agent.run_stream('Hello') as result: + async with test_agent.run_stream('Hello') as result: + assert test_agent.name == 'test_agent' assert not result.is_structured assert not result.is_complete assert result.all_messages() == snapshot( @@ -71,9 +73,10 @@ async def ret_a(x: str) -> str: async def test_streamed_structured_response(): m = TestModel() - agent = Agent(m, result_type=tuple[str, str]) + agent = Agent(m, result_type=tuple[str, str], name='fig_jam') async with agent.run_stream('') as result: + assert agent.name == 'fig_jam' assert result.is_structured assert not result.is_complete response = await result.get_data() From 4cc93e53a7cc0c9dc4dde716f45cf1e4a30cfa14 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 19:33:02 +0000 Subject: [PATCH 2/4] revert unnecessary change --- pydantic_ai_examples/pydantic_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index a247a37b26..9c654e9e1a 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -25,9 +25,9 @@ class MyModel(BaseModel): model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o')) print(f'Using model: {model}') -pydantic_agent = Agent(model, result_type=MyModel) +agent = Agent(model, result_type=MyModel) if __name__ == '__main__': - result = pydantic_agent.run_sync('The windy city in the US of A.') + result = agent.run_sync('The windy city in the US of A.') print(result.data) print(result.cost()) From dc8a72b2e502adba53ac5549bd65e57d10885a50 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 20:02:22 +0000 Subject: [PATCH 3/4] improve coverage --- pydantic_ai_slim/pydantic_ai/agent.py | 6 +++--- tests/test_tools.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 383c7117de..467787df92 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -267,7 +267,7 @@ async def run_stream( """ if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame - if frame := inspect.currentframe(): + if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) model_used, custom_model, agent_model = await self._get_agent_model(model) @@ -828,8 +828,8 @@ def _infer_name(self, function_frame: FrameType | None) -> None: Usage should be `self._infer_name(inspect.currentframe())`. """ assert self.name is None, 'Name already set' - if function_frame is not None: - if parent_frame := function_frame.f_back: + if function_frame is not None: # pragma: no branch + if parent_frame := function_frame.f_back: # pragma: no branch for name, item in parent_frame.f_locals.items(): if item is self: self.name = name diff --git a/tests/test_tools.py b/tests/test_tools.py index 39905ae18a..7f56159650 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -335,7 +335,7 @@ def test_tool_return_conflict(): def test_init_ctx_tool_invalid(): - def plain_tool(x: int) -> int: + def plain_tool(x: int) -> int: # pragma: no cover return x + 1 m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]' From f9c8249097ca872ded0354623a60d6ff304c1de6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 20:13:53 +0000 Subject: [PATCH 4/4] add docs --- docs/api/agent.md | 1 + pydantic_ai_slim/pydantic_ai/agent.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/api/agent.md b/docs/api/agent.md index 085abe6c1b..c4b43af6c1 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -4,6 +4,7 @@ options: members: - __init__ + - name - run - run_sync - run_stream diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 467787df92..7b83180760 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -57,7 +57,10 @@ class Agent(Generic[AgentDeps, ResultData]): model: models.Model | models.KnownModelName | None """The default model configured for this agent.""" name: str | None - """The name of the agent, used for logging.""" + """The name of the agent, used for logging. + + If `None`, we try to infer the agent name from the call frame when the agent is first run. + """ _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False) _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) _allow_text_result: bool = field(repr=False)