From 85d2ac828e8d77f8a09d1f8aa5a36f1900ee7ff9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 25 Nov 2024 23:15:38 +0000 Subject: [PATCH 1/4] rename retriever -> tool --- docs/agents.md | 18 ++++---- docs/dependencies.md | 7 ++- docs/index.md | 2 +- pydantic_ai_examples/bank_support.py | 2 +- pydantic_ai_examples/rag.py | 2 +- pydantic_ai_examples/roulette_wheel.py | 2 +- pydantic_ai_examples/weather_agent.py | 4 +- pydantic_ai_slim/pydantic_ai/_retriever.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 50 +++++++++++----------- tests/models/test_gemini.py | 6 +-- tests/models/test_groq.py | 2 +- tests/models/test_model_function.py | 20 ++++----- tests/models/test_model_test.py | 6 +-- tests/models/test_openai.py | 2 +- tests/test_agent.py | 2 +- tests/test_deps.py | 2 +- tests/test_logfire.py | 2 +- tests/test_retrievers.py | 20 ++++----- tests/test_streaming.py | 6 +-- tests/typed_agent.py | 14 +++--- 20 files changed, 85 insertions(+), 86 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 5de07cd31c..675c243efb 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -31,7 +31,7 @@ roulette_agent = Agent( # (1)! ) -@roulette_agent.retriever +@roulette_agent.tool async def roulette_wheel(ctx: CallContext[int], square: int) -> str: # (2)! """check if the square is a winner""" return 'winner' if square == ctx.deps else 'loser' @@ -179,10 +179,10 @@ They're useful when it is impractical or impossible to put all the context an ag There are two different decorator functions to register retrievers: -1. [`@agent.retriever_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext] -2. [`@agent.retriever`][pydantic_ai.Agent.retriever] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext] +1. [`@agent.tool_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext] +2. [`@agent.tool`][pydantic_ai.Agent.retriever] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext] -`@agent.retriever` is the default since in the majority of cases retrievers will need access to the agent context. +`@agent.tool` is the default since in the majority of cases retrievers will need access to the agent context. Here's an example using both: @@ -202,13 +202,13 @@ agent = Agent( ) -@agent.retriever_plain # (3)! +@agent.tool_plain # (3)! def roll_die() -> str: """Roll a six-sided die and return the result.""" return str(random.randint(1, 6)) -@agent.retriever # (4)! +@agent.tool # (4)! def get_player_name(ctx: CallContext[str]) -> str: """Get the player's name.""" return ctx.deps @@ -343,7 +343,7 @@ from pydantic_ai.models.function import AgentInfo, FunctionModel agent = Agent() -@agent.retriever_plain +@agent.tool_plain def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: """Get me foobar. @@ -420,7 +420,7 @@ agent = Agent( ) -@agent.retriever(retries=2) +@agent.tool(retries=2) def get_user_by_name(ctx: CallContext[DatabaseConn], name: str) -> int: """Get a user's ID from their full name.""" print(name) @@ -455,7 +455,7 @@ from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior agent = Agent('openai:gpt-4o') -@agent.retriever_plain +@agent.tool_plain def calc_volume(size: int) -> int: # (1)! 
if size == 42: return size**3 diff --git a/docs/dependencies.md b/docs/dependencies.md index c1f7aa1b2c..4ccba8a6d7 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -6,8 +6,7 @@ Matching PydanticAI's design philosophy, our dependency system tries to use exis ## Defining Dependencies -Dependencies can be any python type. While in simple cases you might be able to pass a single object -as a dependency (e.g. an HTTP connection), [dataclasses][] are generally a convenient container when your dependencies included multiple objects. +Dependencies can be any python type. While in simple cases you might be able to pass a single object as a dependency (e.g. an HTTP connection), [dataclasses][] are generally a convenient container when your dependencies included multiple objects. Here's an example of defining an agent that requires dependencies. @@ -188,7 +187,7 @@ async def get_system_prompt(ctx: CallContext[MyDeps]) -> str: return f'Prompt: {response.text}' -@agent.retriever # (1)! +@agent.tool # (1)! async def get_joke_material(ctx: CallContext[MyDeps], subject: str) -> str: response = await ctx.deps.http_client.get( 'https://example.com#jokes', @@ -324,7 +323,7 @@ joke_agent = Agent( factory_agent = Agent('gemini-1.5-pro', result_type=list[str]) -@joke_agent.retriever +@joke_agent.tool async def joke_factory(ctx: CallContext[MyDeps], count: int) -> str: r = await ctx.deps.factory_agent.run(f'Please generate {count} jokes.') return '\n'.join(r.data) diff --git a/docs/index.md b/docs/index.md index e0890303a3..fe04c68b23 100644 --- a/docs/index.md +++ b/docs/index.md @@ -92,7 +92,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str: return f"The customer's name is {customer_name!r}" -@support_agent.retriever # (6)! +@support_agent.tool # (6)! async def customer_balance( ctx: CallContext[SupportDependencies], include_pending: bool ) -> str: diff --git a/pydantic_ai_examples/bank_support.py b/pydantic_ai_examples/bank_support.py index 4f428ff2fb..021f10b905 100644 --- a/pydantic_ai_examples/bank_support.py +++ b/pydantic_ai_examples/bank_support.py @@ -62,7 +62,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str: return f"The customer's name is {customer_name!r}" -@support_agent.retriever +@support_agent.tool async def customer_balance( ctx: CallContext[SupportDependencies], include_pending: bool ) -> str: diff --git a/pydantic_ai_examples/rag.py b/pydantic_ai_examples/rag.py index 44aa88f2d9..d64ff50c93 100644 --- a/pydantic_ai_examples/rag.py +++ b/pydantic_ai_examples/rag.py @@ -51,7 +51,7 @@ class Deps: agent = Agent('openai:gpt-4o', deps_type=Deps) -@agent.retriever +@agent.tool async def retrieve(context: CallContext[Deps], search_query: str) -> str: """Retrieve documentation sections based on a search query. 
diff --git a/pydantic_ai_examples/roulette_wheel.py b/pydantic_ai_examples/roulette_wheel.py index 21820305e0..eee4947aaa 100644 --- a/pydantic_ai_examples/roulette_wheel.py +++ b/pydantic_ai_examples/roulette_wheel.py @@ -32,7 +32,7 @@ class Deps: ) -@roulette_agent.retriever +@roulette_agent.tool async def roulette_wheel( ctx: CallContext[Deps], square: int ) -> Literal['winner', 'loser']: diff --git a/pydantic_ai_examples/weather_agent.py b/pydantic_ai_examples/weather_agent.py index 6e69920dc2..7e62bf9e64 100644 --- a/pydantic_ai_examples/weather_agent.py +++ b/pydantic_ai_examples/weather_agent.py @@ -41,7 +41,7 @@ class Deps: ) -@weather_agent.retriever +@weather_agent.tool async def get_lat_lng( ctx: CallContext[Deps], location_description: str ) -> dict[str, float]: @@ -71,7 +71,7 @@ async def get_lat_lng( raise ModelRetry('Could not find the location') -@weather_agent.retriever +@weather_agent.tool async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]: """Get the weather at a location. diff --git a/pydantic_ai_slim/pydantic_ai/_retriever.py b/pydantic_ai_slim/pydantic_ai/_retriever.py index bb376bf7ce..2a24127d03 100644 --- a/pydantic_ai_slim/pydantic_ai/_retriever.py +++ b/pydantic_ai_slim/pydantic_ai/_retriever.py @@ -19,7 +19,7 @@ @dataclass(init=False) -class Retriever(Generic[AgentDeps, RetrieverParams]): +class Tool(Generic[AgentDeps, RetrieverParams]): """A retriever function for an agent.""" name: str diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b2c05b208f..e9b8b0c69e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]): _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) _allow_text_result: bool = field(repr=False) _system_prompts: tuple[str, ...] 
= field(repr=False) - _retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False) + _tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False) _default_retries: int = field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False) _deps_type: type[AgentDeps] = field(repr=False) @@ -119,7 +119,7 @@ def __init__( self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {} + self._tools: dict[str, _r.Tool[AgentDeps, Any]] = {} self._deps_type = deps_type self._default_retries = retries self._system_prompt_functions = [] @@ -153,8 +153,8 @@ async def run( new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages - for retriever in self._retrievers.values(): - retriever.reset() + for tool in self._tools.values(): + tool.reset() cost = result.Cost() @@ -246,8 +246,8 @@ async def run_stream( new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages - for retriever in self._retrievers.values(): - retriever.reset() + for tool in self._tools.values(): + tool.reset() cost = result.Cost() @@ -428,18 +428,18 @@ async def result_validator_deps(ctx: CallContext[str], data: str) -> str: return func @overload - def retriever( + def tool( self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], / ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ... @overload - def retriever( + def tool( self, /, *, retries: int | None = None ) -> Callable[ [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams] ]: ... - def retriever( + def tool( self, func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None, /, @@ -455,7 +455,7 @@ def retriever( [learn more](../agents.md#retrievers-tools-and-schema). We can't add overloads for every possible signature of retriever, since the return type is a recursive union - so the signature of functions decorated with `@agent.retriever` is obscured. + so the signature of functions decorated with `@agent.tool` is obscured. Example: ```py @@ -463,11 +463,11 @@ def retriever( agent = Agent('test', deps_type=int) - @agent.retriever + @agent.tool def foobar(ctx: CallContext[int], x: int) -> int: return ctx.deps + x - @agent.retriever(retries=2) + @agent.tool(retries=2) async def spam(ctx: CallContext[str], y: float) -> float: return ctx.deps + y @@ -497,14 +497,14 @@ def retriever_decorator( return func @overload - def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ... + def tool_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ... @overload - def retriever_plain( + def tool_plain( self, /, *, retries: int | None = None ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ... - def retriever_plain( + def tool_plain( self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None ) -> Any: """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument. @@ -515,7 +515,7 @@ def retriever_plain( [learn more](../agents.md#retrievers-tools-and-schema). 
We can't add overloads for every possible signature of retriever, since the return type is a recursive union - so the signature of functions decorated with `@agent.retriever` is obscured. + so the signature of functions decorated with `@agent.tool` is obscured. Example: ```py @@ -523,11 +523,11 @@ def retriever_plain( agent = Agent('test') - @agent.retriever + @agent.tool def foobar(ctx: CallContext[int]) -> int: return 123 - @agent.retriever(retries=2) + @agent.tool(retries=2) async def spam(ctx: CallContext[str]) -> float: return 3.14 @@ -560,15 +560,15 @@ def _register_retriever( ) -> None: """Private utility to register a retriever function.""" retries_ = retries if retries is not None else self._default_retries - retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_) + retriever = _r.Tool[AgentDeps, RetrieverParams](func, retries_) if self._result_schema and retriever.name in self._result_schema.tools: raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}') - if retriever.name in self._retrievers: + if retriever.name in self._tools: raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}') - self._retrievers[retriever.name] = retriever + self._tools[retriever.name] = retriever async def _get_agent_model( self, model: models.Model | models.KnownModelName | None @@ -601,7 +601,7 @@ async def _get_agent_model( raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.') result_tools = list(self._result_schema.tools.values()) if self._result_schema else None - agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools) + agent_model = await model_.agent_model(self._tools, self._allow_text_result, result_tools) return model_, custom_model, agent_model async def _prepare_messages( @@ -667,7 +667,7 @@ async def _handle_model_response( messages: list[_messages.Message] = [] tasks: list[asyncio.Task[_messages.Message]] = [] for call in model_response.calls: - if retriever := self._retrievers.get(call.tool_name): + if retriever := self._tools.get(call.tool_name): tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) @@ -730,7 +730,7 @@ async def _handle_streamed_model_response( # we now run all retriever functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] for call in structured_msg.calls: - if retriever := self._retrievers.get(call.tool_name): + if retriever := self._tools.get(call.tool_name): tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) @@ -763,7 +763,7 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]: def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: self._incr_result_retry() - names = list(self._retrievers.keys()) + names = list(self._tools.keys()) if self._result_schema: names.extend(self._result_schema.tool_names()) if names: diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 7ec2d0d3ff..72b1dee670 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -465,7 +465,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient): m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m, system_prompt='this is the system prompt') - @agent.retriever_plain + @agent.tool_plain async def get_location(loc_name: str) -> str: 
if loc_name == 'London': return json.dumps({'lat': 51, 'lng': 0}) @@ -630,12 +630,12 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): agent = Agent(model, result_type=tuple[int, int]) retriever_calls: list[str] = [] - @agent.retriever_plain + @agent.tool_plain async def foo(x: str) -> str: retriever_calls.append(f'foo({x=!r})') return x - @agent.retriever_plain + @agent.tool_plain async def bar(y: str) -> str: retriever_calls.append(f'bar({y=!r})') return y diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 18c6efaa3f..d19fa94355 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -240,7 +240,7 @@ async def test_request_tool_call(allow_model_requests: None): m = GroqModel('llama-3.1-70b-versatile', groq_client=mock_client) agent = Agent(m, system_prompt='this is the system prompt') - @agent.retriever_plain + @agent.tool_plain async def get_location(loc_name: str) -> str: if loc_name == 'London': return json.dumps({'lat': 51, 'lng': 0}) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index e88a44dfd0..da51d48a53 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -111,7 +111,7 @@ async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyRes weather_agent: Agent[None, str] = Agent(FunctionModel(weather_model)) -@weather_agent.retriever_plain +@weather_agent.tool_plain async def get_location(location_description: str) -> str: if location_description == 'London': lat_lng = {'lat': 51, 'lng': 0} @@ -120,7 +120,7 @@ async def get_location(location_description: str) -> str: return json.dumps(lat_lng) -@weather_agent.retriever +@weather_agent.tool async def get_weather(_: CallContext[None], lat: int, lng: int): if (lat, lng) == (51, 0): # it always rains in London @@ -200,7 +200,7 @@ async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAny var_args_agent = Agent(FunctionModel(call_function_model), deps_type=int) -@var_args_agent.retriever +@var_args_agent.tool def get_var_args(ctx: CallContext[int], *args: int): assert ctx.deps == 123 return json.dumps({'args': args}) @@ -234,7 +234,7 @@ async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyRe def test_deps_none(): agent = Agent(FunctionModel(call_retriever)) - @agent.retriever + @agent.tool async def get_none(ctx: CallContext[None]): nonlocal called @@ -260,7 +260,7 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str: return '' agent = Agent(FunctionModel(call_retriever), deps_type=tuple[str, str]) - agent.retriever(get_check_foobar) + agent.tool(get_check_foobar) called = False agent.run_sync('Hello', deps=('foo', 'bar')) assert called @@ -278,27 +278,27 @@ def test_model_arg(): agent_all = Agent() -@agent_all.retriever +@agent_all.tool async def foo(_: CallContext[None], x: int) -> str: return str(x + 1) -@agent_all.retriever(retries=3) +@agent_all.tool(retries=3) def bar(ctx, x: int) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] return str(x + 2) -@agent_all.retriever_plain +@agent_all.tool_plain async def baz(x: int) -> str: return str(x + 3) -@agent_all.retriever_plain(retries=1) +@agent_all.tool_plain(retries=1) def qux(x: int) -> str: return str(x + 4) -@agent_all.retriever_plain # pyright: ignore[reportUnknownArgumentType] +@agent_all.tool_plain # pyright: ignore[reportUnknownArgumentType] def quz(x) -> str: # pyright: 
ignore[reportUnknownParameterType,reportMissingParameterType] return str(x) # pyright: ignore[reportUnknownArgumentType] diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 331fb3ee38..ecd4cfd01a 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -28,12 +28,12 @@ def test_call_one(): agent = Agent() calls: list[str] = [] - @agent.retriever_plain + @agent.tool_plain async def ret_a(x: str) -> str: calls.append('a') return f'{x}-a' - @agent.retriever_plain + @agent.tool_plain async def ret_b(x: str) -> str: # pragma: no cover calls.append('b') return f'{x}-b' @@ -78,7 +78,7 @@ def test_retriever_retry(): agent = Agent() call_count = 0 - @agent.retriever_plain + @agent.tool_plain async def my_ret(x: int) -> str: nonlocal call_count call_count += 1 diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 52b90a29e9..671180f2ec 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -243,7 +243,7 @@ async def test_request_tool_call(allow_model_requests: None): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, system_prompt='this is the system prompt') - @agent.retriever_plain + @agent.tool_plain async def get_location(loc_name: str) -> str: if loc_name == 'London': return json.dumps({'lat': 51, 'lng': 0}) diff --git a/tests/test_agent.py b/tests/test_agent.py index 2687148ac4..b5112f3168 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -361,7 +361,7 @@ def test_run_with_history_new(): agent = Agent(m, system_prompt='Foobar') - @agent.retriever_plain + @agent.tool_plain async def ret_a(x: str) -> str: return f'{x}-apple' diff --git a/tests/test_deps.py b/tests/test_deps.py index 1c1ff29281..251a08c844 100644 --- a/tests/test_deps.py +++ b/tests/test_deps.py @@ -13,7 +13,7 @@ class MyDeps: agent = Agent(TestModel(), deps_type=MyDeps) -@agent.retriever +@agent.tool async def example_retriever(ctx: CallContext[MyDeps]) -> str: return f'{ctx.deps}' diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 90c7d4587a..0141fca633 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -62,7 +62,7 @@ def get_summary() -> LogfireSummary: def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None: agent = Agent(model=TestModel()) - @agent.retriever_plain + @agent.tool_plain async def my_ret(x: int) -> str: return str(x + 1) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 7ff73344d9..b713ba3595 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -16,7 +16,7 @@ def test_retriever_no_ctx(): with pytest.raises(UserError) as exc_info: - @agent.retriever # pyright: ignore[reportArgumentType] + @agent.tool # pyright: ignore[reportArgumentType] def invalid_retriever(x: int) -> str: # pragma: no cover return 'Hello' @@ -31,7 +31,7 @@ def test_retriever_plain_with_ctx(): with pytest.raises(UserError) as exc_info: - @agent.retriever_plain + @agent.tool_plain async def invalid_retriever(ctx: CallContext[None]) -> str: # pragma: no cover return 'Hello' @@ -46,7 +46,7 @@ def test_retriever_ctx_second(): with pytest.raises(UserError) as exc_info: - @agent.retriever # pyright: ignore[reportArgumentType] + @agent.tool # pyright: ignore[reportArgumentType] def invalid_retriever(x: int, ctx: CallContext[None]) -> str: # pragma: no cover return 'Hello' @@ -75,7 +75,7 @@ async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAny def test_docstring_google(): agent = 
Agent(FunctionModel(get_json_schema)) - agent.retriever_plain(google_style_docstring) + agent.tool_plain(google_style_docstring) result = agent.run_sync('Hello') json_schema = json.loads(result.data) @@ -106,7 +106,7 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover def test_docstring_sphinx(): agent = Agent(FunctionModel(get_json_schema)) - agent.retriever_plain(sphinx_style_docstring) + agent.tool_plain(sphinx_style_docstring) result = agent.run_sync('Hello') json_schema = json.loads(result.data) @@ -138,7 +138,7 @@ def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover def test_docstring_numpy(): agent = Agent(FunctionModel(get_json_schema)) - agent.retriever_plain(numpy_style_docstring) + agent.tool_plain(numpy_style_docstring) result = agent.run_sync('Hello') json_schema = json.loads(result.data) @@ -163,7 +163,7 @@ def unknown_docstring(**kwargs: int) -> str: # pragma: no cover def test_docstring_unknown(): agent = Agent(FunctionModel(get_json_schema)) - agent.retriever_plain(unknown_docstring) + agent.tool_plain(unknown_docstring) result = agent.run_sync('Hello') json_schema = json.loads(result.data) @@ -192,7 +192,7 @@ async def google_style_docstring_no_body( def test_docstring_google_no_body(): agent = Agent(FunctionModel(get_json_schema)) - agent.retriever_plain(google_style_docstring_no_body) + agent.tool_plain(google_style_docstring_no_body) result = agent.run_sync('') json_schema = json.loads(result.data) @@ -216,7 +216,7 @@ class Foo(BaseModel): x: int y: str - @agent.retriever_plain + @agent.tool_plain def takes_just_model(model: Foo) -> str: return f'{model.x} {model.y}' @@ -242,7 +242,7 @@ class Foo(BaseModel): x: int y: str - @agent.retriever_plain + @agent.tool_plain def takes_just_model(model: Foo, z: int) -> str: return f'{model.x} {model.y} {z}' diff --git a/tests/test_streaming.py b/tests/test_streaming.py index a0b9d94381..68ce8f9a62 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -33,7 +33,7 @@ async def test_streamed_text_response(): agent = Agent(m) - @agent.retriever_plain + @agent.tool_plain async def ret_a(x: str) -> str: return f'{x}-apple' @@ -189,7 +189,7 @@ async def stream_structured_function( agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) - @agent.retriever_plain + @agent.tool_plain async def ret_a(x: str) -> str: assert x == 'hello' return f'{x} world' @@ -244,7 +244,7 @@ async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) - @agent.retriever_plain + @agent.tool_plain async def ret_a(x: str) -> str: return x diff --git a/tests/typed_agent.py b/tests/typed_agent.py index f2d7e2d6af..dea55eba47 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -44,7 +44,7 @@ def expect_error(error_type: type[Exception]) -> Iterator[None]: raise AssertionError('Expected an error') -@typed_agent.retriever +@typed_agent.tool async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str: assert_type(ctx.deps, MyDeps) total = ctx.deps.foo + ctx.deps.bar @@ -55,35 +55,35 @@ async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str: assert_type(ok_retriever, Callable[[CallContext[MyDeps], str], str]) # type: ignore[assert-type] -@typed_agent.retriever_plain +@typed_agent.tool_plain def ok_retriever_plain(x: str) -> dict[str, str]: return {'x': x} -@typed_agent.retriever_plain +@typed_agent.tool_plain def 
ok_json_list(x: str) -> list[Union[str, int]]: return [x, 1] -@typed_agent.retriever +@typed_agent.tool async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str: total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined] return f'{x} {total}' -@typed_agent.retriever # type: ignore[arg-type] +@typed_agent.tool # type: ignore[arg-type] async def bad_retriever2(ctx: CallContext[int], x: str) -> str: return f'{x} {ctx.deps}' -@typed_agent.retriever_plain # type: ignore[arg-type] +@typed_agent.tool_plain # type: ignore[arg-type] async def bad_retriever_return(x: int) -> list[MyDeps]: return [MyDeps(1, x)] with expect_error(ValueError): - @typed_agent.retriever # type: ignore[arg-type] + @typed_agent.tool # type: ignore[arg-type] async def bad_retriever3(x: str) -> str: return x From bda8b16cb1c8ab9522780edfe1f78ee90e542c82 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 25 Nov 2024 23:43:29 +0000 Subject: [PATCH 2/4] rename remaining references --- docs/agents.md | 70 ++++++------ docs/api/agent.md | 4 +- docs/dependencies.md | 8 +- docs/examples/bank-support.md | 2 +- docs/examples/rag.md | 4 +- docs/examples/weather-agent.md | 2 +- docs/index.md | 18 ++-- docs/results.md | 4 +- pydantic_ai_slim/pydantic_ai/_pydantic.py | 12 +-- pydantic_ai_slim/pydantic_ai/_result.py | 2 +- .../pydantic_ai/{_retriever.py => _tool.py} | 24 ++--- pydantic_ai_slim/pydantic_ai/agent.py | 102 ++++++++---------- pydantic_ai_slim/pydantic_ai/dependencies.py | 26 ++--- pydantic_ai_slim/pydantic_ai/exceptions.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 10 +- .../pydantic_ai/models/__init__.py | 6 +- .../pydantic_ai/models/function.py | 8 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 4 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 +- pydantic_ai_slim/pydantic_ai/models/test.py | 40 +++---- .../pydantic_ai/models/vertexai.py | 4 +- tests/models/test_gemini.py | 12 +-- tests/models/test_model_function.py | 18 ++-- tests/models/test_model_test.py | 4 +- tests/test_agent.py | 12 +-- tests/test_deps.py | 12 +-- tests/test_retrievers.py | 28 ++--- tests/test_streaming.py | 12 +-- tests/typed_agent.py | 16 +-- 30 files changed, 234 insertions(+), 244 deletions(-) rename pydantic_ai_slim/pydantic_ai/{_retriever.py => _tool.py} (81%) diff --git a/docs/agents.md b/docs/agents.md index 675c243efb..81640462b6 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -8,9 +8,9 @@ but multiple agents can also interact to embody more complex workflows. 
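Concretely, the rename means tools are now registered like this — a minimal sketch of the post-rename API (the agent, tool bodies, and names here are illustrative, not part of this patch; `'test'` is the built-in stub model used in the docstring examples above):

```py
import random

from pydantic_ai import Agent, CallContext

agent = Agent('test', deps_type=str)


@agent.tool  # was `@agent.retriever`: receives the agent context
async def greet(ctx: CallContext[str], greeting: str) -> str:
    """Greet the player named in the dependencies."""
    return f'{greeting}, {ctx.deps}!'


@agent.tool_plain  # was `@agent.retriever_plain`: no CallContext argument
def roll_die() -> str:
    """Roll a six-sided die and return the result."""
    return str(random.randint(1, 6))


result = agent.run_sync('Greet the player, then roll the die.', deps='Anne')
print(result.data)
```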
The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptually you can think of an agent as a container for: * A [system prompt](#system-prompts) — a set of instructions for the LLM written by the developer -* One or more [retrievers](#retrievers) — functions that the LLM may call to get information while generating a response +* One or more [tools](#tools) — functions that the LLM may call to get information while generating a response * An optional structured [result type](results.md) — the structured datatype the LLM must return at the end of a run -* A [dependency](dependencies.md) type constraint — system prompt functions, retrievers and result validators may all use dependencies when they're run +* A [dependency](dependencies.md) type constraint — system prompt functions, tools and result validators may all use dependencies when they're run * Agents may optionally also have a default [model](api/models/base.md) associated with them; the model to use can also be specified when running the agent In typing terms, agents are generic in their dependency and result types, e.g., an agent which required dependencies of type `#!python Foobar` and returned results of type `#!python list[str]` would have type `#!python Agent[Foobar, list[str]]`. @@ -49,7 +49,7 @@ print(result.data) ``` 1. Create an agent, which expects an integer dependency and returns a boolean result. This agent will have type `#!python Agent[int, bool]`. -2. Define a retriever that checks if the square is a winner. Here [`CallContext`][pydantic_ai.dependencies.CallContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error. +2. Define a tool that checks if the square is a winner. Here [`CallContext`][pydantic_ai.dependencies.CallContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error. 3. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`. 4. `result.data` will be a boolean indicating if the square is a winner. Pydantic performs the result validation, it'll be typed as a `bool` since its type is derived from the `result_type` generic parameter of the agent. @@ -166,23 +166,23 @@ print(result.data) _(This example is complete, it can be run "as is")_ -## Retrievers +## Tools -Retrievers provide a mechanism for models to request extra information to help them generate a response. +Tools provide a mechanism for models to request extra information to help them generate a response. They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. -!!! info "Retrievers vs. RAG" - Retrievers are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. +!!! info "Tools vs. RAG" + Tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. - The main semantic difference between PydanticAI Retrievers and RAG is RAG is synonymous with vector search, while PydanticAI retrievers are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. 
See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) + The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) -There are two different decorator functions to register retrievers: +There are two different decorator functions to register tools: -1. [`@agent.tool_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext] -2. [`@agent.tool`][pydantic_ai.Agent.retriever] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext] +1. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that don't need access to the agent [context][pydantic_ai.dependencies.CallContext] +2. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that do need access to the agent [context][pydantic_ai.dependencies.CallContext] -`@agent.tool` is the default since in the majority of cases retrievers will need access to the agent context. +`@agent.tool` is the default since in the majority of cases tools will need access to the agent context. Here's an example using both: @@ -221,8 +221,8 @@ print(dice_result.data) 1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model. 2. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency. -3. This retriever doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case. -4. This retriever needs the player's name, so it uses `CallContext` to access dependencies which are just the player's name in this case. +3. This tool doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case. +4. This tool needs the player's name, so it uses `CallContext` to access dependencies which are just the player's name in this case. 5. Run the agent, passing the player's name as the dependency. _(This example is complete, it can be run "as is")_ @@ -297,9 +297,9 @@ sequenceDiagram Note over Agent: Send prompts Agent ->> LLM: System: "You're a dice game..."
User: "My guess is 4"
     activate LLM
-    Note over LLM: LLM decides to use<br>a retriever
+    Note over LLM: LLM decides to use<br>a tool
-    LLM ->> Agent: Call retriever<br>roll_die()
+    LLM ->> Agent: Call tool<br>roll_die()
     deactivate LLM
     activate Agent
     Note over Agent: Rolls a six-sided die
     Agent -->> LLM: ToolReturn<br>"4"
     deactivate Agent
     activate LLM
-    Note over LLM: LLM decides to use<br>another retriever
+    Note over LLM: LLM decides to use<br>another tool
-    LLM ->> Agent: Call retriever<br>get_player_name()
+    LLM ->> Agent: Call tool<br>
get_player_name() deactivate LLM activate Agent Note over Agent: Retrieves player name @@ -323,19 +323,19 @@ sequenceDiagram Note over Agent: Game session complete ``` -### Retrievers, tools, and schema +### tools and schema -Under the hood, retrievers use the model's "tools" or "functions" API to let the model know what retrievers are available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call retrievers while others end the run and return a result. +Under the hood, tools use the model's "tools" or "functions" API to let the model know what tools are available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call tools while others end the run and return a result. Function parameters are extracted from the function signature, and all parameters except `CallContext` are used to build the schema for that tool call. -Even better, PydanticAI extracts the docstring from retriever functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema. +Even better, PydanticAI extracts the docstring from tool functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema. [Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy` and `sphinx` style docstrings, and PydanticAI will infer the format to use based on the docstring. We plan to add support in the future to explicitly set the style to use, and warn/error if not all parameters are documented; see [#59](https://github.com/pydantic/pydantic-ai/issues/59). -To demonstrate a retriever's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive: +To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive: -```py title="retriever_schema.py" +```py title="tool_schema.py" from pydantic_ai import Agent from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse from pydantic_ai.models.function import AgentInfo, FunctionModel @@ -356,10 +356,10 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: - retriever = info.retrievers['foobar'] - print(retriever.description) + tool = info.tools['foobar'] + print(tool.description) #> Get me foobar. - print(retriever.json_schema) + print(tool.json_schema) """ { 'description': 'Get me foobar.', @@ -386,22 +386,22 @@ agent.run_sync('hello', model=FunctionModel(print_schema)) _(This example is complete, it can be run "as is")_ -The return type of retriever can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. +The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]) as some models (e.g. 
Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. -If a retriever has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the retriever is simplified to be just that object. (TODO example) +If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example) ## Reflection and self-correction -Validation errors from both retriever parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry. +Validation errors from both tool parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry. -You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response. +You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [tool](#tools) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response. -- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever], or a [result validator][pydantic_ai.Agent.__init__]. -- You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext]. +- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific tool][pydantic_ai.Agent.tool], or a [result validator][pydantic_ai.Agent.__init__]. +- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext]. Here's an example: -```py title="retriever_retry.py" +```py title="tool_retry.py" from fake_database import DatabaseConn from pydantic import BaseModel @@ -467,7 +467,7 @@ try: result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) - #> An error occurred: Retriever exceeded max retries count of 1 + #> An error occurred: Tool exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', agent.last_run_messages) @@ -513,6 +513,6 @@ except UnexpectedModelBehavior as e: else: print(result.data) ``` -1. Define a retriever that will raise `ModelRetry` repeatedly in this case. +1. Define a tool that will raise `ModelRetry` repeatedly in this case. 
_(This example is complete, it can be run "as is")_ diff --git a/docs/api/agent.md b/docs/api/agent.md index 9074c0bb15..6b012f1ea3 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -12,6 +12,6 @@ - override_model - last_run_messages - system_prompt - - retriever - - retriever_plain + - tool + - tool_plain - result_validator diff --git a/docs/dependencies.md b/docs/dependencies.md index 4ccba8a6d7..32e5474fd7 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -1,6 +1,6 @@ # Dependencies -PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions). +PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](agents.md#tools) and [result validators](results.md#result-validators-functions). Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. @@ -101,7 +101,7 @@ _(This example is complete, it can be run "as is")_ ### Asynchronous vs. Synchronous dependencies -System prompt functions, retriever functions and result validator are all run in the async context of an agent run. +System prompt functions, tool functions and result validator are all run in the async context of an agent run. If these functions are not coroutines (e.g. `async def`) they are called with [`run_in_executor`][asyncio.loop.run_in_executor] in a thread pool, it's therefore marginally preferable @@ -158,7 +158,7 @@ _(This example is complete, it can be run "as is")_ ## Full Example -As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions). +As well as system prompts, dependencies can be used in [tools](agents.md#tools) and [result validators](results.md#result-validators-functions). ```py title="full_example.py" hl_lines="27-35 38-48" from dataclasses import dataclass @@ -219,7 +219,7 @@ async def main(): #> Did you hear about the toothpaste scandal? They called it Colgate. ``` -1. To pass `CallContext` and to a retriever, us the [`retriever`][pydantic_ai.Agent.retriever] decorator. +1. To pass `CallContext` and to a tool, us the [`tool`][pydantic_ai.Agent.tool] decorator. 2. `CallContext` may optionally be passed to a [`result_validator`][pydantic_ai.Agent.result_validator] function as the first argument. _(This example is complete, it can be run "as is")_ diff --git a/docs/examples/bank-support.md b/docs/examples/bank-support.md index 9068590bea..0abddc5d14 100644 --- a/docs/examples/bank-support.md +++ b/docs/examples/bank-support.md @@ -4,7 +4,7 @@ Demonstrates: * [dynamic system prompt](../agents.md#system-prompts) * [structured `result_type`](../results.md#structured-result-validation) -* [retrievers](../agents.md#retrievers) +* [tools](../agents.md#tools) ## Running the Example diff --git a/docs/examples/rag.md b/docs/examples/rag.md index 9a4d5382db..64cf4af321 100644 --- a/docs/examples/rag.md +++ b/docs/examples/rag.md @@ -4,12 +4,12 @@ RAG search example. 
This demo allows you to ask question of the [logfire](https: Demonstrates: -* [retrievers](../agents.md#retrievers) +* [tools](../agents.md#tools) * [agent dependencies](../dependencies.md) * RAG search This is done by creating a database containing each section of the markdown documentation, then registering -the search tool as a retriever with the PydanticAI agent. +the search tool as a tool with the PydanticAI agent. Logic for extracting sections from markdown files and a JSON file with that data is available in [this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992). diff --git a/docs/examples/weather-agent.md b/docs/examples/weather-agent.md index f5112a0307..c770431134 100644 --- a/docs/examples/weather-agent.md +++ b/docs/examples/weather-agent.md @@ -2,7 +2,7 @@ Example of PydanticAI with multiple tools which the LLM needs to call in turn to Demonstrates: -* [retrievers](../agents.md#retrievers) +* [tools](../agents.md#tools) * [agent dependencies](../dependencies.md) In this case the idea is a "weather" agent — the user can ask for the weather in multiple locations, diff --git a/docs/index.md b/docs/index.md index fe04c68b23..2eea30ce77 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,9 +48,9 @@ The first known use of "hello, world" was in a 1974 textbook about the C program _(This example is complete, it can be run "as is")_ -Not very interesting yet, but we can easily add "retrievers", dynamic system prompts, and structured responses to build more powerful agents. +Not very interesting yet, but we can easily add "tools", dynamic system prompts, and structured responses to build more powerful agents. -## Retrievers & Dependency Injection Example +## Tools & Dependency Injection Example Here is a concise example using PydanticAI to build a support agent for a bank: @@ -124,15 +124,15 @@ async def main(): 1. This [agent](agents.md) will act as first-tier support in a bank. Agents are generic in the type of dependencies they accept and the type of result they return. In this case, the support agent has type `#!python Agent[SupportDependencies, SupportResult]`. 2. Here we configure the agent to use [OpenAI's GPT-4o model](api/models/openai.md), you can also set the model when running the agent. -3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [retriever](agents.md#retrievers) functions. PydanticAI's system of dependency injection provides a type-safe way to customise the behavior of your agents, and can be especially useful when running unit tests and evals. +3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](agents.md#tools) functions. PydanticAI's system of dependency injection provides a type-safe way to customise the behavior of your agents, and can be especially useful when running unit tests and evals. 4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent. 5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. 
Dependencies are carried via the [`CallContext`][pydantic_ai.dependencies.CallContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it. -6. [Retrievers](agents.md#retrievers) let you register "tools" which the LLM may call while responding to a user. Again, dependencies are carried via [`CallContext`][pydantic_ai.dependencies.CallContext], and any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry. -7. The docstring of a retriever also passed to the LLM as a description of the tool. Parameter descriptions are [extracted](agents.md#retrievers-tools-and-schema) from the docstring and added to the tool schema sent to the LLM. -8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as retrievers are called to retrieve a result. +6. [Tools](agents.md#tools) let you register "tools" which the LLM may call while responding to a user. Again, dependencies are carried via [`CallContext`][pydantic_ai.dependencies.CallContext], and any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry. +7. The docstring of a tool also passed to the LLM as a description of the tool. Parameter descriptions are [extracted](agents.md#tools-tools-and-schema) from the docstring and added to the tool schema sent to the LLM. +8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as tools are called to retrieve a result. 9. The response from the agent will, be guaranteed to be a `SupportResult`, if validation fails [reflection](agents.md#reflection-and-self-correction) will mean the agent is prompted to try again. 10. The result will be validated with Pydantic to guarantee it is a `SupportResult`, since the agent is generic, it'll also be typed as a `SupportResult` to aid with static type checking. -11. In a real use case, you'd add many more retrievers and a longer system prompt to the agent to extend the context it's equipped with and support it can provide. +11. In a real use case, you'd add many more tools and a longer system prompt to the agent to extend the context it's equipped with and support it can provide. 12. This is a simple sketch of a database connection, used to keep the example short and readable. In reality, you'd be connecting to an external database (e.g. PostgreSQL) to get information about customers. 13. This [Pydantic](https://docs.pydantic.dev) model is used to constrain the structured data returned by the agent. From this simple definition, Pydantic builds the JSON Schema that tells the LLM how to return the data, and performs validation to guarantee the data is correct at the end of the conversation. @@ -153,8 +153,8 @@ sequenceDiagram Agent ->> LLM: Request
System: "You are a support agent..."<br>System: "The customer's name is John"<br>User: "What is my balance?"
     activate LLM
-    Note over LLM: LLM decides to use a retriever
+    Note over LLM: LLM decides to use a tool
-    LLM ->> Agent: Call retriever<br>customer_balance()
+    LLM ->> Agent: Call tool<br>
customer_balance() deactivate LLM activate Agent Note over Agent: Retrieve account balance diff --git a/docs/results.md b/docs/results.md index 0c0a217677..4704c824e8 100644 --- a/docs/results.md +++ b/docs/results.md @@ -34,7 +34,7 @@ If the result type is a union with multiple members (after remove `str` from the If the result type schema is not of type `"object"`, the result type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas. -Structured results (like retrievers) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model. +Structured results (like tools) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model. !!! note "Bring on PEP-747" Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, unions are not valid as `type`s in Python. @@ -160,7 +160,7 @@ _(This example is complete, it can be run "as is")_ There two main challenges with streamed results: 1. Validating structured responses before they're complete, this is achieved by "partial validation" which was recently added to Pydantic in [pydantic/pydantic#10748](https://github.com/pydantic/pydantic/pull/10748). -2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a retriever call or a result, then streams the whole thing and calls retrievers, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult]. +2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a tool call or a result, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult]. ### Streaming Text diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index 3b241de35e..c0bdb2c4aa 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -20,8 +20,8 @@ from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like if TYPE_CHECKING: - from . import _retriever - from .dependencies import AgentDeps, RetrieverParams + from . import _tool + from .dependencies import AgentDeps, ToolParams __all__ = 'function_schema', 'LazyTypeAdapter' @@ -39,8 +39,8 @@ class FunctionSchema(TypedDict): var_positional_field: str | None -def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, RetrieverParams]) -> FunctionSchema: # noqa: C901 - """Build a Pydantic validator and JSON schema from a retriever function. +def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]) -> FunctionSchema: # noqa: C901 + """Build a Pydantic validator and JSON schema from a tool function. Args: either_function: The function to build a validator and JSON schema for. 
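The effect of `function_schema` is easiest to observe from the public API; here is a hedged sketch using `FunctionModel` to print the schema built for a tool (mirroring the `tool_schema.py` docs example above; the `double` tool is illustrative):

```py
from pydantic_ai import Agent
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel

agent = Agent()


@agent.tool_plain
def double(n: int) -> int:
    """Double a number.

    Args:
        n: the number to double.
    """
    return n * 2


def inspect_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
    # `info.tools` holds the schema built from each tool's signature and docstring
    print(info.tools['double'].json_schema)
    return ModelTextResponse(content='done')


agent.run_sync('hello', model=FunctionModel(inspect_schema))
```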
@@ -78,10 +78,10 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, R if index == 0 and takes_ctx: if not _is_call_ctx(annotation): - errors.append('First argument must be a CallContext instance when using `.retriever`') + errors.append('First argument must be a CallContext instance when using `.tool`') continue elif not takes_ctx and _is_call_ctx(annotation): - errors.append('CallContext instance can only be used with `.retriever`') + errors.append('CallContext instance can only be used with `.tool`') continue elif index != 0 and _is_call_ctx(annotation): errors.append('CallContext instance can only be used as the first argument') diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index d529ee546f..26d7ace457 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -75,7 +75,7 @@ def __init__(self, tool_retry: messages.RetryPrompt): class ResultSchema(Generic[ResultData]): """Model the final response from an agent run. - Similar to `Retriever` but for the final result of running an agent. + Similar to `Tool` but for the final result of running an agent. """ tools: dict[str, ResultTool[ResultData]] diff --git a/pydantic_ai_slim/pydantic_ai/_retriever.py b/pydantic_ai_slim/pydantic_ai/_tool.py similarity index 81% rename from pydantic_ai_slim/pydantic_ai/_retriever.py rename to pydantic_ai_slim/pydantic_ai/_tool.py index 2a24127d03..b808c3058e 100644 --- a/pydantic_ai_slim/pydantic_ai/_retriever.py +++ b/pydantic_ai_slim/pydantic_ai/_tool.py @@ -9,22 +9,20 @@ from pydantic_core import SchemaValidator from . import _pydantic, _utils, messages -from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc +from .dependencies import AgentDeps, CallContext, ToolContextFunc, ToolParams, ToolPlainFunc from .exceptions import ModelRetry, UnexpectedModelBehavior -# Usage `RetrieverEitherFunc[AgentDependencies, P]` -RetrieverEitherFunc = _utils.Either[ - RetrieverContextFunc[AgentDeps, RetrieverParams], RetrieverPlainFunc[RetrieverParams] -] +# Usage `ToolEitherFunc[AgentDependencies, P]` +ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]] @dataclass(init=False) -class Tool(Generic[AgentDeps, RetrieverParams]): - """A retriever function for an agent.""" +class Tool(Generic[AgentDeps, ToolParams]): + """A tool function for an agent.""" name: str description: str - function: RetrieverEitherFunc[AgentDeps, RetrieverParams] = field(repr=False) + function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False) is_async: bool single_arg_name: str | None positional_fields: list[str] @@ -35,8 +33,8 @@ class Tool(Generic[AgentDeps, RetrieverParams]): _current_retry: int = 0 outer_typed_dict_key: str | None = None - def __init__(self, function: RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int): - """Build a Retriever dataclass from a function.""" + def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int): + """Build a Tool dataclass from a function.""" self.function = function # noinspection PyTypeChecker f = _pydantic.function_schema(function) @@ -56,7 +54,7 @@ def reset(self) -> None: self._current_retry = 0 async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message: - """Run the retriever function asynchronously.""" + """Run the tool function asynchronously.""" try: if isinstance(message.args, messages.ArgsJson): args_dict = 
self.validator.validate_json(message.args.args_json)
@@ -100,8 +98,8 @@ def _call_args(
     def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
         self._current_retry += 1
         if self._current_retry > self.max_retries:
-            # TODO custom error with details of the retriever
-            raise UnexpectedModelBehavior(f'Retriever exceeded max retries count of {self.max_retries}') from exc
+            # TODO custom error with details of the tool
+            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
         else:
             if isinstance(exc, ValidationError):
                 content = exc.errors(include_url=False)
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index e9b8b0c69e..edb75fa25c 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -11,15 +11,15 @@
 from . import (
     _result,
-    _retriever as _r,
     _system_prompt,
+    _tool as _r,
     _utils,
     exceptions,
     messages as _messages,
     models,
     result,
 )
-from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
+from .dependencies import AgentDeps, CallContext, ToolContextFunc, ToolParams, ToolPlainFunc
 from .result import ResultData
 
 __all__ = ('Agent',)
@@ -366,7 +366,7 @@ async def async_system_prompt(ctx: CallContext[str]) -> str:
         result = agent.run_sync('foobar', deps='spam')
         print(result.data)
-        #> success (no retriever calls)
+        #> success (no tool calls)
         ```
         """
         self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
@@ -421,40 +421,36 @@ async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
         result = agent.run_sync('foobar', deps='spam')
         print(result.data)
-        #> success (no retriever calls)
+        #> success (no tool calls)
         ```
         """
         self._result_validators.append(_result.ResultValidator(func))
         return func
 
     @overload
-    def tool(
-        self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
-    ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
+    def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
 
     @overload
     def tool(
         self, /, *, retries: int | None = None
-    ) -> Callable[
-        [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
-    ]: ...
+    ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
 
     def tool(
         self,
-        func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
+        func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
         /,
         *,
         retries: int | None = None,
     ) -> Any:
-        """Decorator to register a retriever function which takes
+        """Decorator to register a tool function which takes
        [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
 
        Can decorate a sync or async function.
 
        The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../agents.md#retrievers-tools-and-schema).
+        [learn more](../agents.md#tools-tools-and-schema).
 
-        We can't add overloads for every possible signature of retriever, since the return type is a recursive union
+        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool` is obscured.
 
        Example:
@@ -477,44 +473,42 @@ async def spam(ctx: CallContext[str], y: float) -> float:
        ```
 
        Args:
-            func: The retriever function to register.
-            retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+            func: The tool function to register.
+            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
        """  # noqa: D205
        if func is None:
 
-            def retriever_decorator(
-                func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
-            ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
+            def tool_decorator(
+                func_: ToolContextFunc[AgentDeps, ToolParams],
+            ) -> ToolContextFunc[AgentDeps, ToolParams]:
                # noinspection PyTypeChecker
-                self._register_retriever(_utils.Either(left=func_), retries)
+                self._register_tool(_utils.Either(left=func_), retries)
                return func_
 
-            return retriever_decorator
+            return tool_decorator
        else:
            # noinspection PyTypeChecker
-            self._register_retriever(_utils.Either(left=func), retries)
+            self._register_tool(_utils.Either(left=func), retries)
            return func
 
    @overload
-    def tool_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
+    def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
 
    @overload
    def tool_plain(
        self, /, *, retries: int | None = None
-    ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
+    ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
 
-    def tool_plain(
-        self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
-    ) -> Any:
-        """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
+    def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
+        """Decorator to register a tool function which DOES NOT take `CallContext` as an argument.
 
        Can decorate a sync or async function.
 
        The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../agents.md#retrievers-tools-and-schema).
+        [learn more](../agents.md#tools-tools-and-schema).
 
-        We can't add overloads for every possible signature of retriever, since the return type is a recursive union
+        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool` is obscured.
 
        Example:
@@ -537,38 +531,36 @@ async def spam(ctx: CallContext[str]) -> float:
        ```
 
        Args:
-            func: The retriever function to register.
-            retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+            func: The tool function to register.
+            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
""" if func is None: - def retriever_decorator( - func_: RetrieverPlainFunc[RetrieverParams], - ) -> RetrieverPlainFunc[RetrieverParams]: + def tool_decorator( + func_: ToolPlainFunc[ToolParams], + ) -> ToolPlainFunc[ToolParams]: # noinspection PyTypeChecker - self._register_retriever(_utils.Either(right=func_), retries) + self._register_tool(_utils.Either(right=func_), retries) return func_ - return retriever_decorator + return tool_decorator else: - self._register_retriever(_utils.Either(right=func), retries) + self._register_tool(_utils.Either(right=func), retries) return func - def _register_retriever( - self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None - ) -> None: - """Private utility to register a retriever function.""" + def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None: + """Private utility to register a tool function.""" retries_ = retries if retries is not None else self._default_retries - retriever = _r.Tool[AgentDeps, RetrieverParams](func, retries_) + tool = _r.Tool[AgentDeps, ToolParams](func, retries_) - if self._result_schema and retriever.name in self._result_schema.tools: - raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}') + if self._result_schema and tool.name in self._result_schema.tools: + raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}') - if retriever.name in self._tools: - raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}') + if tool.name in self._tools: + raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}') - self._tools[retriever.name] = retriever + self._tools[tool.name] = tool async def _get_agent_model( self, model: models.Model | models.KnownModelName | None @@ -663,12 +655,12 @@ async def _handle_model_response( if not model_response.calls: raise exceptions.UnexpectedModelBehavior('Received empty tool call message') - # otherwise we run all retriever functions in parallel + # otherwise we run all tool functions in parallel messages: list[_messages.Message] = [] tasks: list[asyncio.Task[_messages.Message]] = [] for call in model_response.calls: - if retriever := self._tools.get(call.tool_name): - tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + if tool := self._tools.get(call.tool_name): + tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) @@ -719,7 +711,7 @@ async def _handle_streamed_model_response( if self._result_schema.find_tool(structured_msg): return _MarkFinalResult(model_response) - # the model is calling a retriever function, consume the response to get the next message + # the model is calling a tool function, consume the response to get the next message async for _ in model_response: pass structured_msg = model_response.get() @@ -727,11 +719,11 @@ async def _handle_streamed_model_response( raise exceptions.UnexpectedModelBehavior('Received empty tool call message') messages: list[_messages.Message] = [structured_msg] - # we now run all retriever functions in parallel + # we now run all tool functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] for call in structured_msg.calls: - if retriever := self._tools.get(call.tool_name): - tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + if tool := self._tools.get(call.tool_name): + 
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py index ee5a087830..6c22473040 100644 --- a/pydantic_ai_slim/pydantic_ai/dependencies.py +++ b/pydantic_ai_slim/pydantic_ai/dependencies.py @@ -16,10 +16,10 @@ 'CallContext', 'ResultValidatorFunc', 'SystemPromptFunc', - 'RetrieverReturnValue', - 'RetrieverContextFunc', - 'RetrieverPlainFunc', - 'RetrieverParams', + 'ToolReturnValue', + 'ToolContextFunc', + 'ToolPlainFunc', + 'ToolParams', 'JsonData', ) @@ -39,7 +39,7 @@ class CallContext(Generic[AgentDeps]): """Name of the tool being called.""" -RetrieverParams = ParamSpec('RetrieverParams') +ToolParams = ParamSpec('ToolParams') """Retrieval function param spec.""" SystemPromptFunc = Union[ @@ -69,15 +69,15 @@ class CallContext(Generic[AgentDeps]): JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]' """Type representing any JSON data.""" -RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]] -"""Return value of a retriever function.""" -RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue] -"""A retriever function that takes `CallContext` as the first argument. +ToolReturnValue = Union[JsonData, Awaitable[JsonData]] +"""Return value of a tool function.""" +ToolContextFunc = Callable[Concatenate[CallContext[AgentDeps], ToolParams], ToolReturnValue] +"""A tool function that takes `CallContext` as the first argument. -Usage `RetrieverContextFunc[AgentDeps, RetrieverParams]`. +Usage `ToolContextFunc[AgentDeps, ToolParams]`. """ -RetrieverPlainFunc = Callable[RetrieverParams, RetrieverReturnValue] -"""A retriever function that does not take `CallContext` as the first argument. +ToolPlainFunc = Callable[ToolParams, ToolReturnValue] +"""A tool function that does not take `CallContext` as the first argument. -Usage `RetrieverPlainFunc[RetrieverParams]`. +Usage `ToolPlainFunc[ToolParams]`. """ diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index c0df3a67b4..493f00d37f 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -6,7 +6,7 @@ class ModelRetry(Exception): - """Exception raised when a retriever function should be retried. + """Exception raised when a tool function should be retried. The agent will return the message to the model and ask it to try calling the function/tool again. 
""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 3d6ec3f45d..9331d25c5c 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -55,7 +55,7 @@ class UserPrompt: @dataclass class ToolReturn: - """A tool return message, this encodes the result of running a retriever.""" + """A tool return message, this encodes the result of running a tool.""" tool_name: str """The name of the "tool" was called.""" @@ -89,10 +89,10 @@ class RetryPrompt: This can be sent for a number of reasons: - * Pydantic validation of retriever arguments failed, here content is derived from a Pydantic + * Pydantic validation of tool arguments failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] - * a retriever raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception - * no retriever was found for the tool name + * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception + * no tool was found for the tool name * the model returned plain text when a structured response was expected * Pydantic validation of a structured response failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] @@ -182,7 +182,7 @@ def has_content(self) -> bool: class ModelStructuredResponse: """A structured response from a model. - This is used either to call a retriever or to return a structured response from an agent run. + This is used either to call a tool or to return a structured response from an agent run. """ calls: list[ToolCall] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index c56a5d19d9..6b3758bd57 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -63,7 +63,7 @@ class Model(ABC): @abstractmethod async def agent_model( self, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: @@ -72,7 +72,7 @@ async def agent_model( This is async in case slow/async config checks need to be performed that can't be done in `__init__`. Args: - retrievers: The retrievers available to the agent. + retrieval_tools: The tools available to the agent. allow_text_result: Whether a plain text final response/result is permitted. result_tools: Tool definitions for the final result tool(s), if any. @@ -259,7 +259,7 @@ def infer_model(model: Model | KnownModelName) -> Model: class AbstractToolDefinition(Protocol): """Abstract definition of a function/tool. - This is used for both retrievers and result tools. + This is used for both tools and result tools. 
""" name: str diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 1fb862b1d1..b8b310a427 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -67,13 +67,13 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre async def agent_model( self, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: result_tools = list(result_tools) if result_tools is not None else None return FunctionAgentModel( - self.function, self.stream_function, AgentInfo(retrievers, allow_text_result, result_tools) + self.function, self.stream_function, AgentInfo(retrieval_tools, allow_text_result, result_tools) ) def name(self) -> str: @@ -92,8 +92,8 @@ class AgentInfo: This is passed as the second to functions. """ - retrievers: Mapping[str, AbstractToolDefinition] - """The retrievers available on this agent.""" + tools: Mapping[str, AbstractToolDefinition] + """The tools available on this agent.""" allow_text_result: bool """Whether a plain text result is allowed.""" result_tools: list[AbstractToolDefinition] | None diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 3b825e0dde..c9bc54def7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -112,7 +112,7 @@ def __init__( async def agent_model( self, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> GeminiAgentModel: @@ -121,7 +121,7 @@ async def agent_model( model_name=self.model_name, auth=self.auth, url=self.url, - retrievers=retrievers, + retrieval_tools=retrieval_tools, allow_text_result=allow_text_result, result_tools=result_tools, ) @@ -160,12 +160,12 @@ def __init__( model_name: GeminiModelName, auth: AuthProtocol, url: str, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ): check_allow_model_requests() - tools = [_function_from_abstract_tool(t) for t in retrievers.values()] + tools = [_function_from_abstract_tool(t) for t in retrieval_tools.values()] if result_tools is not None: tools += [_function_from_abstract_tool(t) for t in result_tools] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 80d06f8f58..7d98579fb8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -109,12 +109,12 @@ def __init__( async def agent_model( self, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in retrievers.values()] + tools = [self._map_tool_definition(r) for r in retrieval_tools.values()] if result_tools is not None: tools += [self._map_tool_definition(r) for r in result_tools] return GroqAgentModel( diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 
8b8d1cf87b..58a0521768 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -89,12 +89,12 @@ def __init__(
     async def agent_model(
         self,
-        retrievers: Mapping[str, AbstractToolDefinition],
+        retrieval_tools: Mapping[str, AbstractToolDefinition],
         allow_text_result: bool,
         result_tools: Sequence[AbstractToolDefinition] | None,
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in retrievers.values()]
+        tools = [self._map_tool_definition(r) for r in retrieval_tools.values()]
         if result_tools is not None:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return OpenAIAgentModel(
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index cd7f920703..a231ebbf99 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -45,7 +45,7 @@ def __repr__(self):
 class TestModel(Model):
     """A model specifically for testing purposes.
 
-    This will (by default) call all retrievers in the agent model, then return a tool response if possible,
+    This will (by default) call all tools in the agent model, then return a tool response if possible,
     otherwise a plain response.
 
     How useful this function will be is unknown; it may be useless, it may require significant changes to be useful.
@@ -57,8 +57,8 @@ class TestModel(Model):
     # NOTE: Avoid test discovery by pytest.
     __test__ = False
 
-    call_retrievers: list[str] | Literal['all'] = 'all'
-    """List of retrievers to call. If `'all'`, all retrievers will be called."""
+    call_tools: list[str] | Literal['all'] = 'all'
+    """List of tools to call. If `'all'`, all tools will be called."""
     custom_result_text: str | None = None
     """If set, this text is returned as the final result."""
     custom_result_args: Any | None = None
@@ -66,25 +66,25 @@ class TestModel(Model):
     seed: int = 0
     """Seed for generating random data."""
     # these fields are set when the model is called by the agent
-    agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
+    agent_model_tools: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
     agent_model_allow_text_result: bool | None = field(default=None, init=False)
     agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)
 
     async def agent_model(
         self,
-        retrievers: Mapping[str, AbstractToolDefinition],
+        retrieval_tools: Mapping[str, AbstractToolDefinition],
         allow_text_result: bool,
         result_tools: Sequence[AbstractToolDefinition] | None,
     ) -> AgentModel:
-        self.agent_model_retrievers = retrievers
+        self.agent_model_tools = retrieval_tools
         self.agent_model_allow_text_result = allow_text_result
         self.agent_model_result_tools = list(result_tools) if result_tools is not None else None
 
-        if self.call_retrievers == 'all':
-            retriever_calls = [(r.name, r) for r in retrievers.values()]
+        if self.call_tools == 'all':
+            tool_calls = [(r.name, r) for r in retrieval_tools.values()]
         else:
-            retrievers_to_call = (retrievers[name] for name in self.call_retrievers)
-            retriever_calls = [(r.name, r) for r in retrievers_to_call]
+            tools_to_call = (retrieval_tools[name] for name in self.call_tools)
+            tool_calls = [(r.name, r) for r in tools_to_call]
 
         if self.custom_result_text is not None:
             assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
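As a usage sketch of the renamed `call_tools` field (mirroring `tests/models/test_model_test.py` further down; `TestModel` generates each tool's arguments from its JSON schema, so the exact output depends on the default `seed`):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent()


@agent.tool_plain
def ret_a(x: str) -> str:
    return f'{x}-a'


@agent.tool_plain
def ret_b(x: str) -> str:
    return f'{x}-b'


# by default (call_tools='all') TestModel calls every registered tool;
# here only `ret_a` is called, with arguments generated from its schema
result = agent.run_sync('x', model=TestModel(call_tools=['ret_a']))
print(result.data)
#> {"ret_a":"a-a"}
```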
@@ -104,7 +104,7 @@ async def agent_model( result = _utils.Either(right=None) else: result = _utils.Either(left=None) - return TestAgentModel(retriever_calls, result, self.agent_model_result_tools, self.seed) + return TestAgentModel(tool_calls, result, self.agent_model_result_tools, self.seed) def name(self) -> str: return 'test-model' @@ -117,7 +117,7 @@ class TestAgentModel(AgentModel): # NOTE: Avoid test discovery by pytest. __test__ = False - retriever_calls: list[tuple[str, AbstractToolDefinition]] + tool_calls: list[tuple[str, AbstractToolDefinition]] # left means the text is plain text; right means it's a function call result: _utils.Either[str | None, Any | None] result_tools: list[AbstractToolDefinition] | None @@ -137,12 +137,12 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS else: yield TestStreamStructuredResponse(msg, cost) - def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> Any: + def gen_tool_args(self, tool_def: AbstractToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.json_schema, self.seed).generate() def _request(self, messages: list[Message]) -> ModelAnyResponse: - if self.step == 0 and self.retriever_calls: - calls = [ToolCall.from_object(name, self.gen_retriever_args(args)) for name, args in self.retriever_calls] + if self.step == 0 and self.tool_calls: + calls = [ToolCall.from_object(name, self.gen_tool_args(args)) for name, args in self.tool_calls] self.step += 1 self.last_message_count = len(messages) return ModelStructuredResponse(calls=calls) @@ -152,8 +152,8 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)} if new_retry_names: calls = [ - ToolCall.from_object(name, self.gen_retriever_args(args)) - for name, args in self.retriever_calls + ToolCall.from_object(name, self.gen_tool_args(args)) + for name, args in self.tool_calls if name in new_retry_names ] self.step += 1 @@ -162,7 +162,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: if response_text := self.result.left: self.step += 1 if response_text.value is None: - # build up details of retriever responses + # build up details of tool responses output: dict[str, Any] = {} for message in messages: if isinstance(message, ToolReturn): @@ -170,7 +170,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: if output: return ModelTextResponse(content=pydantic_core.to_json(output).decode()) else: - return ModelTextResponse(content='success (no retriever calls)') + return ModelTextResponse(content='success (no tool calls)') else: return ModelTextResponse(content=response_text.value) else: @@ -181,7 +181,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: self.step += 1 return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, custom_result_args)]) else: - response_args = self.gen_retriever_args(result_tool) + response_args = self.gen_tool_args(result_tool) self.step += 1 return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, response_args)]) diff --git a/pydantic_ai_slim/pydantic_ai/models/vertexai.py b/pydantic_ai_slim/pydantic_ai/models/vertexai.py index fc48ab6bfd..2a511a4db3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/vertexai.py +++ b/pydantic_ai_slim/pydantic_ai/models/vertexai.py @@ -166,7 +166,7 @@ def __init__( async def agent_model( self, - retrievers: Mapping[str, AbstractToolDefinition], + retrieval_tools: Mapping[str, AbstractToolDefinition], 
allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> GeminiAgentModel: @@ -176,7 +176,7 @@ async def agent_model( model_name=self.model_name, auth=auth, url=url, - retrievers=retrievers, + retrieval_tools=retrieval_tools, allow_text_result=allow_text_result, result_tools=result_tools, ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 72b1dee670..476b4ec718 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -96,7 +96,7 @@ class TestToolDefinition: async def test_agent_model_tools(allow_model_requests: None): m = GeminiModel('gemini-1.5-flash', api_key='via-arg') - retrievers = { + tools = { 'foo': TestToolDefinition( 'foo', 'This is foo', @@ -118,7 +118,7 @@ async def test_agent_model_tools(allow_model_requests: None): 'This is the tool for the final Result', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']}, ) - agent_model = await m.agent_model(retrievers, True, [result_tool]) + agent_model = await m.agent_model(tools, True, [result_tool]) assert agent_model.tools == snapshot( _GeminiTools( function_declarations=[ @@ -628,16 +628,16 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): gemini_client = get_gemini_client([first_stream, second_stream]) model = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(model, result_type=tuple[int, int]) - retriever_calls: list[str] = [] + tool_calls: list[str] = [] @agent.tool_plain async def foo(x: str) -> str: - retriever_calls.append(f'foo({x=!r})') + tool_calls.append(f'foo({x=!r})') return x @agent.tool_plain async def bar(y: str) -> str: - retriever_calls.append(f'bar({y=!r})') + tool_calls.append(f'bar({y=!r})') return y async with agent.run_stream('Hello') as result: @@ -667,7 +667,7 @@ async def bar(y: str) -> str: ), ] ) - assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"]) + assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"]) async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index da51d48a53..b534ea9790 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -87,7 +87,7 @@ def test_simple(): async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover assert info.allow_text_result - assert info.retrievers.keys() == {'get_location', 'get_weather'} + assert info.tools.keys() == {'get_location', 'get_weather'} last = messages[-1] if last.role == 'user': return ModelStructuredResponse( @@ -222,17 +222,17 @@ def test_var_args(): ) -async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: +async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: if len(messages) == 1: - assert len(info.retrievers) == 1 - retriever_id = next(iter(info.retrievers.keys())) - return ModelStructuredResponse(calls=[ToolCall.from_json(retriever_id, '{}')]) + assert len(info.tools) == 1 + tool_id = next(iter(info.tools.keys())) + return ModelStructuredResponse(calls=[ToolCall.from_json(tool_id, '{}')]) else: return ModelTextResponse('final response') def test_deps_none(): - agent = Agent(FunctionModel(call_retriever)) + agent = Agent(FunctionModel(call_tool)) @agent.tool async def get_none(ctx: CallContext[None]): @@ -259,7 +259,7 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str: assert 
ctx.deps == ('foo', 'bar') return '' - agent = Agent(FunctionModel(call_retriever), deps_type=tuple[str, str]) + agent = Agent(FunctionModel(call_tool), deps_type=tuple[str, str]) agent.tool(get_check_foobar) called = False agent.run_sync('Hello', deps=('foo', 'bar')) @@ -311,11 +311,11 @@ def spam() -> str: def test_register_all(): async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: return ModelTextResponse( - f'messages={len(messages)} allow_text_result={info.allow_text_result} retrievers={len(info.retrievers)}' + f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.tools)}' ) result = agent_all.run_sync('Hello', model=FunctionModel(f)) - assert result.data == snapshot('messages=2 allow_text_result=True retrievers=5') + assert result.data == snapshot('messages=2 allow_text_result=True tools=5') def test_call_all(): diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index ecd4cfd01a..fac96568bd 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -38,7 +38,7 @@ async def ret_b(x: str) -> str: # pragma: no cover calls.append('b') return f'{x}-b' - result = agent.run_sync('x', model=TestModel(call_retrievers=['ret_a'])) + result = agent.run_sync('x', model=TestModel(call_tools=['ret_a'])) assert result.data == snapshot('{"ret_a":"a-a"}') assert calls == ['a'] @@ -74,7 +74,7 @@ def test_result_type(): assert result.data == ('a', 'a') -def test_retriever_retry(): +def test_tool_retry(): agent = Agent() call_count = 0 diff --git a/tests/test_agent.py b/tests/test_agent.py index b5112f3168..7b74a833c1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -179,7 +179,7 @@ def test_response_tuple(): result = agent.run_sync('Hello') assert result.data == snapshot(('a', 'a')) - assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_tools == snapshot({}) assert m.agent_model_allow_text_result is False assert m.agent_model_result_tools is not None @@ -235,10 +235,10 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any: assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] result = agent.run_sync('Hello') - assert result.data == snapshot('success (no retriever calls)') + assert result.data == snapshot('success (no tool calls)') assert got_tool_call_name == snapshot(None) - assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_tools == snapshot({}) assert m.agent_model_allow_text_result is True assert m.agent_model_result_tools is not None @@ -312,7 +312,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any: assert result.data == mod.Foo(a=0, b='a') assert got_tool_call_name == snapshot('final_result_Foo') - assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_tools == snapshot({}) assert m.agent_model_allow_text_result is False assert m.agent_model_result_tools is not None @@ -450,7 +450,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: agent.run_sync('Hello') -def test_unknown_retriever(): +def test_unknown_tool(): def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')]) @@ -472,7 +472,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse: ) -def test_unknown_retriever_fix(): +def test_unknown_tool_fix(): def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse: if len(m) > 1: return 
ModelTextResponse(content='success') diff --git a/tests/test_deps.py b/tests/test_deps.py index 251a08c844..c52f3e03e2 100644 --- a/tests/test_deps.py +++ b/tests/test_deps.py @@ -14,26 +14,26 @@ class MyDeps: @agent.tool -async def example_retriever(ctx: CallContext[MyDeps]) -> str: +async def example_tool(ctx: CallContext[MyDeps]) -> str: return f'{ctx.deps}' def test_deps_used(): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) - assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}' + assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}' def test_deps_override(): with agent.override_deps(MyDeps(foo=3, bar=4)): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) - assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}' + assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}' with agent.override_deps(MyDeps(foo=5, bar=6)): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) - assert result.data == '{"example_retriever":"MyDeps(foo=5, bar=6)"}' + assert result.data == '{"example_tool":"MyDeps(foo=5, bar=6)"}' result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) - assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}' + assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}' result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) - assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}' + assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}' diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index b713ba3595..5e9b1bdbcf 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -11,48 +11,48 @@ from pydantic_ai.models.test import TestModel -def test_retriever_no_ctx(): +def test_tool_no_ctx(): agent = Agent(TestModel()) with pytest.raises(UserError) as exc_info: @agent.tool # pyright: ignore[reportArgumentType] - def invalid_retriever(x: int) -> str: # pragma: no cover + def invalid_tool(x: int) -> str: # pragma: no cover return 'Hello' assert str(exc_info.value) == snapshot( - 'Error generating schema for test_retriever_no_ctx..invalid_retriever:\n' - ' First argument must be a CallContext instance when using `.retriever`' + 'Error generating schema for test_tool_no_ctx..invalid_tool:\n' + ' First argument must be a CallContext instance when using `.tool`' ) -def test_retriever_plain_with_ctx(): +def test_tool_plain_with_ctx(): agent = Agent(TestModel()) with pytest.raises(UserError) as exc_info: @agent.tool_plain - async def invalid_retriever(ctx: CallContext[None]) -> str: # pragma: no cover + async def invalid_tool(ctx: CallContext[None]) -> str: # pragma: no cover return 'Hello' assert str(exc_info.value) == snapshot( - 'Error generating schema for test_retriever_plain_with_ctx..invalid_retriever:\n' - ' CallContext instance can only be used with `.retriever`' + 'Error generating schema for test_tool_plain_with_ctx..invalid_tool:\n' + ' CallContext instance can only be used with `.tool`' ) -def test_retriever_ctx_second(): +def test_tool_ctx_second(): agent = Agent(TestModel()) with pytest.raises(UserError) as exc_info: @agent.tool # pyright: ignore[reportArgumentType] - def invalid_retriever(x: int, ctx: CallContext[None]) -> str: # pragma: no cover + def invalid_tool(x: int, ctx: CallContext[None]) -> str: # pragma: no cover return 'Hello' assert str(exc_info.value) == snapshot( - 'Error generating schema for test_retriever_ctx_second..invalid_retriever:\n' - ' First argument must be a CallContext instance when using 
`.retriever`\n' + 'Error generating schema for test_tool_ctx_second..invalid_tool:\n' + ' First argument must be a CallContext instance when using `.tool`\n' ' CallContext instance can only be used as the first argument' ) @@ -68,8 +68,8 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse: - assert len(info.retrievers) == 1 - r = next(iter(info.retrievers.values())) + assert len(info.tools) == 1 + r = next(iter(info.tools.values())) return ModelTextResponse(json.dumps(r.json_schema)) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 68ce8f9a62..e8ab93e2dd 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -162,14 +162,14 @@ async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[s assert call_index == 2 -async def test_call_retriever(): +async def test_call_tool(): async def stream_structured_function( messages: list[Message], agent_info: AgentInfo ) -> AsyncIterator[DeltaToolCalls | str]: if len(messages) == 1: - assert agent_info.retrievers is not None - assert len(agent_info.retrievers) == 1 - name = next(iter(agent_info.retrievers)) + assert agent_info.tools is not None + assert len(agent_info.tools) == 1 + name = next(iter(agent_info.tools)) first = messages[0] assert isinstance(first, UserPrompt) json_string = json.dumps({'x': first.content}) @@ -227,7 +227,7 @@ async def ret_a(x: str) -> str: ) -async def test_call_retriever_empty(): +async def test_call_tool_empty(): async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {} @@ -238,7 +238,7 @@ async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> pass -async def test_call_retriever_wrong_name(): +async def test_call_tool_wrong_name(): async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {0: DeltaToolCall(name='foobar', json_args='{}')} diff --git a/tests/typed_agent.py b/tests/typed_agent.py index dea55eba47..9295c5742b 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -45,18 +45,18 @@ def expect_error(error_type: type[Exception]) -> Iterator[None]: @typed_agent.tool -async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str: +async def ok_tool(ctx: CallContext[MyDeps], x: str) -> str: assert_type(ctx.deps, MyDeps) total = ctx.deps.foo + ctx.deps.bar return f'{x} {total}' -# we can't add overloads for every possible signature of retriever, so the type of ok_retriever is obscured -assert_type(ok_retriever, Callable[[CallContext[MyDeps], str], str]) # type: ignore[assert-type] +# we can't add overloads for every possible signature of tool, so the type of ok_tool is obscured +assert_type(ok_tool, Callable[[CallContext[MyDeps], str], str]) # type: ignore[assert-type] @typed_agent.tool_plain -def ok_retriever_plain(x: str) -> dict[str, str]: +def ok_tool_plain(x: str) -> dict[str, str]: return {'x': x} @@ -66,25 +66,25 @@ def ok_json_list(x: str) -> list[Union[str, int]]: @typed_agent.tool -async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str: +async def bad_tool1(ctx: CallContext[MyDeps], x: str) -> str: total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined] return f'{x} {total}' @typed_agent.tool # type: ignore[arg-type] -async def bad_retriever2(ctx: CallContext[int], x: str) -> str: +async def bad_tool2(ctx: CallContext[int], x: str) -> str: return f'{x} 
{ctx.deps}' @typed_agent.tool_plain # type: ignore[arg-type] -async def bad_retriever_return(x: int) -> list[MyDeps]: +async def bad_tool_return(x: int) -> list[MyDeps]: return [MyDeps(1, x)] with expect_error(ValueError): @typed_agent.tool # type: ignore[arg-type] - async def bad_retriever3(x: str) -> str: + async def bad_tool3(x: str) -> str: return x From d816ad4b578b59ee2325b7cebb3ba4081a3a2fad Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 26 Nov 2024 12:02:49 +0000 Subject: [PATCH 3/4] use "function_tools" name --- docs/agents.md | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 20 +++++++++---------- .../pydantic_ai/models/__init__.py | 4 ++-- .../pydantic_ai/models/function.py | 14 ++++++++----- pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/models/groq.py | 4 ++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 ++-- pydantic_ai_slim/pydantic_ai/models/test.py | 8 ++++---- .../pydantic_ai/models/vertexai.py | 4 ++-- tests/models/test_model_function.py | 8 ++++---- tests/test_retrievers.py | 4 ++-- tests/test_streaming.py | 6 +++--- 12 files changed, 45 insertions(+), 41 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 81640462b6..fca5191e3f 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -356,7 +356,7 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: - tool = info.tools['foobar'] + tool = info.function_tools['foobar'] print(tool.description) #> Get me foobar. print(tool.json_schema) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index edb75fa25c..5a40cd2cf5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]): _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) _allow_text_result: bool = field(repr=False) _system_prompts: tuple[str, ...] 
= field(repr=False) - _tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False) + _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False) _default_retries: int = field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False) _deps_type: type[AgentDeps] = field(repr=False) @@ -119,7 +119,7 @@ def __init__( self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._tools: dict[str, _r.Tool[AgentDeps, Any]] = {} + self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {} self._deps_type = deps_type self._default_retries = retries self._system_prompt_functions = [] @@ -153,7 +153,7 @@ async def run( new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages - for tool in self._tools.values(): + for tool in self._function_tools.values(): tool.reset() cost = result.Cost() @@ -246,7 +246,7 @@ async def run_stream( new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages - for tool in self._tools.values(): + for tool in self._function_tools.values(): tool.reset() cost = result.Cost() @@ -557,10 +557,10 @@ def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries if self._result_schema and tool.name in self._result_schema.tools: raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}') - if tool.name in self._tools: + if tool.name in self._function_tools: raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}') - self._tools[tool.name] = tool + self._function_tools[tool.name] = tool async def _get_agent_model( self, model: models.Model | models.KnownModelName | None @@ -593,7 +593,7 @@ async def _get_agent_model( raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.') result_tools = list(self._result_schema.tools.values()) if self._result_schema else None - agent_model = await model_.agent_model(self._tools, self._allow_text_result, result_tools) + agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools) return model_, custom_model, agent_model async def _prepare_messages( @@ -659,7 +659,7 @@ async def _handle_model_response( messages: list[_messages.Message] = [] tasks: list[asyncio.Task[_messages.Message]] = [] for call in model_response.calls: - if tool := self._tools.get(call.tool_name): + if tool := self._function_tools.get(call.tool_name): tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) @@ -722,7 +722,7 @@ async def _handle_streamed_model_response( # we now run all tool functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] for call in structured_msg.calls: - if tool := self._tools.get(call.tool_name): + if tool := self._function_tools.get(call.tool_name): tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) @@ -755,7 +755,7 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]: def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: self._incr_result_retry() - names = list(self._tools.keys()) + names = list(self._function_tools.keys()) if self._result_schema: 
            names.extend(self._result_schema.tool_names())
         if names:
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 6b3758bd57..19e615862f 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -63,7 +63,7 @@ class Model(ABC):
     @abstractmethod
     async def agent_model(
         self,
-        retrieval_tools: Mapping[str, AbstractToolDefinition],
+        function_tools: Mapping[str, AbstractToolDefinition],
         allow_text_result: bool,
         result_tools: Sequence[AbstractToolDefinition] | None,
     ) -> AgentModel:
@@ -72,7 +72,7 @@ async def agent_model(
         This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
 
         Args:
-            retrieval_tools: The tools available to the agent.
+            function_tools: The tools available to the agent.
             allow_text_result: Whether a plain text final response/result is permitted.
             result_tools: Tool definitions for the final result tool(s), if any.
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index b8b310a427..b9f4a4a811 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -67,13 +67,13 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
     async def agent_model(
         self,
-        retrieval_tools: Mapping[str, AbstractToolDefinition],
+        function_tools: Mapping[str, AbstractToolDefinition],
         allow_text_result: bool,
         result_tools: Sequence[AbstractToolDefinition] | None,
     ) -> AgentModel:
         result_tools = list(result_tools) if result_tools is not None else None
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(retrieval_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
         )
 
     def name(self) -> str:
@@ -89,11 +89,15 @@ def name(self) -> str:
 class AgentInfo:
     """Information about an agent.
 
-    This is passed as the second to functions.
+    This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
     """
 
-    tools: Mapping[str, AbstractToolDefinition]
-    """The tools available on this agent."""
+    function_tools: Mapping[str, AbstractToolDefinition]
+    """The function tools available on this agent.
+
+    These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
+    [`tool_plain`][pydantic_ai.Agent.tool_plain] decorators.
+ """ allow_text_result: bool """Whether a plain text result is allowed.""" result_tools: list[AbstractToolDefinition] | None diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c9bc54def7..f0e06684b2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -112,7 +112,7 @@ def __init__( async def agent_model( self, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> GeminiAgentModel: @@ -121,7 +121,7 @@ async def agent_model( model_name=self.model_name, auth=self.auth, url=self.url, - retrieval_tools=retrieval_tools, + function_tools=function_tools, allow_text_result=allow_text_result, result_tools=result_tools, ) @@ -160,12 +160,12 @@ def __init__( model_name: GeminiModelName, auth: AuthProtocol, url: str, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ): check_allow_model_requests() - tools = [_function_from_abstract_tool(t) for t in retrieval_tools.values()] + tools = [_function_from_abstract_tool(t) for t in function_tools.values()] if result_tools is not None: tools += [_function_from_abstract_tool(t) for t in result_tools] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 7d98579fb8..46420b48b1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -109,12 +109,12 @@ def __init__( async def agent_model( self, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in retrieval_tools.values()] + tools = [self._map_tool_definition(r) for r in function_tools.values()] if result_tools is not None: tools += [self._map_tool_definition(r) for r in result_tools] return GroqAgentModel( diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 58a0521768..4b409f720a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -89,12 +89,12 @@ def __init__( async def agent_model( self, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in retrieval_tools.values()] + tools = [self._map_tool_definition(r) for r in function_tools.values()] if result_tools is not None: tools += [self._map_tool_definition(r) for r in result_tools] return OpenAIAgentModel( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a231ebbf99..f8a1d5482e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -72,18 +72,18 @@ class TestModel(Model): async def agent_model( self, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: 
Sequence[AbstractToolDefinition] | None, ) -> AgentModel: - self.agent_model_tools = retrieval_tools + self.agent_model_tools = function_tools self.agent_model_allow_text_result = allow_text_result self.agent_model_result_tools = list(result_tools) if result_tools is not None else None if self.call_tools == 'all': - tool_calls = [(r.name, r) for r in retrieval_tools.values()] + tool_calls = [(r.name, r) for r in function_tools.values()] else: - tools_to_call = (retrieval_tools[name] for name in self.call_tools) + tools_to_call = (function_tools[name] for name in self.call_tools) tool_calls = [(r.name, r) for r in tools_to_call] if self.custom_result_text is not None: diff --git a/pydantic_ai_slim/pydantic_ai/models/vertexai.py b/pydantic_ai_slim/pydantic_ai/models/vertexai.py index 2a511a4db3..b28dbe2f67 100644 --- a/pydantic_ai_slim/pydantic_ai/models/vertexai.py +++ b/pydantic_ai_slim/pydantic_ai/models/vertexai.py @@ -166,7 +166,7 @@ def __init__( async def agent_model( self, - retrieval_tools: Mapping[str, AbstractToolDefinition], + function_tools: Mapping[str, AbstractToolDefinition], allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> GeminiAgentModel: @@ -176,7 +176,7 @@ async def agent_model( model_name=self.model_name, auth=auth, url=url, - retrieval_tools=retrieval_tools, + function_tools=function_tools, allow_text_result=allow_text_result, result_tools=result_tools, ) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index b534ea9790..44e01aa219 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -87,7 +87,7 @@ def test_simple(): async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover assert info.allow_text_result - assert info.tools.keys() == {'get_location', 'get_weather'} + assert info.function_tools.keys() == {'get_location', 'get_weather'} last = messages[-1] if last.role == 'user': return ModelStructuredResponse( @@ -224,8 +224,8 @@ def test_var_args(): async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: if len(messages) == 1: - assert len(info.tools) == 1 - tool_id = next(iter(info.tools.keys())) + assert len(info.function_tools) == 1 + tool_id = next(iter(info.function_tools.keys())) return ModelStructuredResponse(calls=[ToolCall.from_json(tool_id, '{}')]) else: return ModelTextResponse('final response') @@ -311,7 +311,7 @@ def spam() -> str: def test_register_all(): async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: return ModelTextResponse( - f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.tools)}' + f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}' ) result = agent_all.run_sync('Hello', model=FunctionModel(f)) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 5e9b1bdbcf..37258c4809 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -68,8 +68,8 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse: - assert len(info.tools) == 1 - r = next(iter(info.tools.values())) + assert len(info.function_tools) == 1 + r = next(iter(info.function_tools.values())) return ModelTextResponse(json.dumps(r.json_schema)) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index e8ab93e2dd..4b272a232d 100644 --- 
a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -167,9 +167,9 @@ async def stream_structured_function(
     messages: list[Message], agent_info: AgentInfo
 ) -> AsyncIterator[DeltaToolCalls | str]:
     if len(messages) == 1:
-        assert agent_info.tools is not None
-        assert len(agent_info.tools) == 1
-        name = next(iter(agent_info.tools))
+        assert agent_info.function_tools is not None
+        assert len(agent_info.function_tools) == 1
+        name = next(iter(agent_info.function_tools))
         first = messages[0]
         assert isinstance(first, UserPrompt)
         json_string = json.dumps({'x': first.content})

From eeec58773715951be53eaf518cc8711ead473625 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Tue, 26 Nov 2024 12:57:01 +0000
Subject: [PATCH 4/4] docs suggestion from @dmontagu

---
 docs/agents.md       | 24 +++++++++++++-----------
 docs/dependencies.md |  4 ++--
 docs/examples/rag.md |  2 +-
 docs/index.md        |  2 +-
 4 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/docs/agents.md b/docs/agents.md
index fca5191e3f..b77ddbf816 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -8,10 +8,10 @@ but multiple agents can also interact to embody more complex workflows.
 The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptually you can think of an agent as a container for:
 
 * A [system prompt](#system-prompts) — a set of instructions for the LLM written by the developer
-* One or more [tools](#tools) — functions that the LLM may call to get information while generating a response
+* One or more [retrieval tools](#tools) — functions that the LLM may call to get information while generating a response
 * An optional structured [result type](results.md) — the structured datatype the LLM must return at the end of a run
 * A [dependency](dependencies.md) type constraint — system prompt functions, tools and result validators may all use dependencies when they're run
-* Agents may optionally also have a default [model](api/models/base.md) associated with them; the model to use can also be specified when running the agent
+* Agents may optionally also have a default [LLM model](api/models/base.md) associated with them; the model to use can also be specified when running the agent
 
 In typing terms, agents are generic in their dependency and result types, e.g., an agent which required dependencies of type `#!python Foobar` and returned results of type `#!python list[str]` would have type `#!python Agent[Foobar, list[str]]`.
@@ -166,21 +166,21 @@ print(result.data)
 
 _(This example is complete, it can be run "as is")_
 
-## Tools
+## Function Tools
 
-Tools provide a mechanism for models to request extra information to help them generate a response.
+Function tools provide a mechanism for models to retrieve extra information to help them generate a response.
 
 They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool.
 
-!!! info "Tools vs. RAG"
-    Tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
+!!! info "Function tools vs. RAG"
+    Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
The main semantic difference between PydanticAI tools and RAG is that RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))

 There are two different decorator functions to register tools:

 1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.dependencies.CallContext]
 2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.dependencies.CallContext]

 `@agent.tool` is the default since in the majority of cases tools will need access to the agent context.

@@ -323,13 +323,15 @@ sequenceDiagram
     Note over Agent: Game session complete
 ```

-### tools and schema
+### Function Tools vs. Structured Results

-Under the hood, tools use the model's "tools" or "functions" API to let the model know what tools are available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call tools while others end the run and return a result.
+As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses; thus a model might have access to many tools, some of which call function tools while others end the run and return a result.
+
+### Function tools and schema

 Function parameters are extracted from the function signature, and all parameters except `CallContext` are used to build the schema for that tool call.

 Even better, PydanticAI extracts the docstring from functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema.

 [Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy` and `sphinx` style docstrings, and PydanticAI will infer the format to use based on the docstring. We plan to add support in the future to explicitly set the style to use, and warn/error if not all parameters are documented; see [#59](https://github.com/pydantic/pydantic-ai/issues/59).

diff --git a/docs/dependencies.md b/docs/dependencies.md
index 32e5474fd7..74ddc1a013 100644
--- a/docs/dependencies.md
+++ b/docs/dependencies.md
@@ -101,7 +101,7 @@ _(This example is complete, it can be run "as is")_

 ### Asynchronous vs. Synchronous dependencies

-System prompt functions, tool functions and result validator are all run in the async context of an agent run.
+[System prompt functions](agents.md#system-prompts), [function tools](agents.md#function-tools) and [result validators](results.md#result-validators-functions) are all run in the async context of an agent run.
If these functions are not coroutines (i.e. not defined with `async def`) they are called with [`run_in_executor`][asyncio.loop.run_in_executor] in a thread pool; it's therefore marginally preferable

@@ -219,7 +219,7 @@ async def main():
     #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```

-1. To pass `CallContext` and to a tool, us the [`tool`][pydantic_ai.Agent.tool] decorator.
+1. To pass `CallContext` to a tool, use the [`tool`][pydantic_ai.Agent.tool] decorator.
 2. `CallContext` may optionally be passed to a [`result_validator`][pydantic_ai.Agent.result_validator] function as the first argument.

 _(This example is complete, it can be run "as is")_

diff --git a/docs/examples/rag.md b/docs/examples/rag.md
index 64cf4af321..d7dee103d9 100644
--- a/docs/examples/rag.md
+++ b/docs/examples/rag.md
@@ -9,7 +9,7 @@ Demonstrates:
 * RAG search

 This is done by creating a database containing each section of the markdown documentation, then registering
-the search tool as a tool with the PydanticAI agent.
+the search tool with the PydanticAI agent.

 Logic for extracting sections from markdown files and a JSON file with that data is available in
 [this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992).

diff --git a/docs/index.md b/docs/index.md
index 2eea30ce77..938f372087 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -128,7 +128,7 @@ async def main():
 4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent.
 5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`CallContext`][pydantic_ai.dependencies.CallContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it.
 6. [Tools](agents.md#tools) let you register "tools" which the LLM may call while responding to a user. Again, dependencies are carried via [`CallContext`][pydantic_ai.dependencies.CallContext], and any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry.
-7. The docstring of a tool also passed to the LLM as a description of the tool. Parameter descriptions are [extracted](agents.md#tools-tools-and-schema) from the docstring and added to the tool schema sent to the LLM.
+7. The docstring of a tool is also passed to the LLM as the description of the tool. Parameter descriptions are [extracted](agents.md#function-tools-and-schema) from the docstring and added to the tool schema sent to the LLM.
 8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as tools are called to retrieve a result.
 9. The response from the agent will be guaranteed to be a `SupportResult`; if validation fails, [reflection](agents.md#reflection-and-self-correction) will mean the agent is prompted to try again.
 10. The result will be validated with Pydantic to guarantee it is a `SupportResult`; since the agent is generic, it'll also be typed as a `SupportResult` to aid with static type checking.
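
To see the renamed pieces working together (`@agent.tool`, `@agent.tool_plain`, and the `function_tools` attribute on `AgentInfo`), here is a minimal sketch in the style of the updated tests. It uses only names that appear in this patch series; the `pydantic_ai.messages` import path and the `deps_type`/`deps` arguments are assumptions inferred from the surrounding hunks rather than verified against this commit:

```python
import json
import random

from pydantic_ai import Agent, CallContext
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse  # assumed import path
from pydantic_ai.models.function import AgentInfo, FunctionModel

agent = Agent(deps_type=str)  # deps_type assumed, mirroring the docs examples


@agent.tool  # needs the agent context, so the first parameter is CallContext
async def get_player_name(ctx: CallContext[str]) -> str:
    """Get the player's name."""
    return ctx.deps


@agent.tool_plain  # needs no context, so no CallContext parameter
def roll_die() -> str:
    """Roll a six-sided die and return the result."""
    return str(random.randint(1, 6))


async def inspect_tools(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
    # After this patch series, registered tools reach custom models as
    # `info.function_tools` (previously `retrieval_tools` / `tools`).
    return ModelTextResponse(json.dumps(sorted(info.function_tools)))


result = agent.run_sync('Hello', deps='Anne', model=FunctionModel(inspect_tools))
print(result.data)
#> ["get_player_name", "roll_die"]
```

Because `FunctionModel` substitutes a plain Python function for the LLM, a sketch like this exercises the rename without any network calls, which is the same pattern the updated tests rely on.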