diff --git a/docs/agents.md b/docs/agents.md
index 5de07cd31c..b77ddbf816 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -8,10 +8,10 @@ but multiple agents can also interact to embody more complex workflows.
The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptually you can think of an agent as a container for:
* A [system prompt](#system-prompts) — a set of instructions for the LLM written by the developer
-* One or more [retrievers](#retrievers) — functions that the LLM may call to get information while generating a response
+* One or more [function tools](#function-tools) — functions that the LLM may call to get information while generating a response
* An optional structured [result type](results.md) — the structured datatype the LLM must return at the end of a run
-* A [dependency](dependencies.md) type constraint — system prompt functions, retrievers and result validators may all use dependencies when they're run
-* Agents may optionally also have a default [model](api/models/base.md) associated with them; the model to use can also be specified when running the agent
+* A [dependency](dependencies.md) type constraint — system prompt functions, tools and result validators may all use dependencies when they're run
+* Agents may optionally also have a default [LLM](api/models/base.md) associated with them; the model to use can also be specified when running the agent
In typing terms, agents are generic in their dependency and result types, e.g., an agent which required dependencies of type `#!python Foobar` and returned results of type `#!python list[str]` would have type `#!python Agent[Foobar, list[str]]`.
@@ -31,7 +31,7 @@ roulette_agent = Agent( # (1)!
)
-@roulette_agent.retriever
+@roulette_agent.tool
async def roulette_wheel(ctx: CallContext[int], square: int) -> str: # (2)!
"""check if the square is a winner"""
return 'winner' if square == ctx.deps else 'loser'
@@ -49,7 +49,7 @@ print(result.data)
```
1. Create an agent, which expects an integer dependency and returns a boolean result. This agent will have type `#!python Agent[int, bool]`.
-2. Define a retriever that checks if the square is a winner. Here [`CallContext`][pydantic_ai.dependencies.CallContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error.
+2. Define a tool that checks if the square is a winner. Here [`CallContext`][pydantic_ai.dependencies.CallContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error.
3. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`.
4. `result.data` will be a boolean indicating if the square is a winner. Pydantic performs the result validation, it'll be typed as a `bool` since its type is derived from the `result_type` generic parameter of the agent.
@@ -166,23 +166,23 @@ print(result.data)
_(This example is complete, it can be run "as is")_
-## Retrievers
+## Function Tools
-Retrievers provide a mechanism for models to request extra information to help them generate a response.
+Function tools provide a mechanism for models to retrieve extra information to help them generate a response.
They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool.
-!!! info "Retrievers vs. RAG"
- Retrievers are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
+!!! info "Function tools vs. RAG"
+ Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
- The main semantic difference between PydanticAI Retrievers and RAG is RAG is synonymous with vector search, while PydanticAI retrievers are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
+ The main semantic difference between PydanticAI Tools and RAG is that RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
-There are two different decorator functions to register retrievers:
+There are two different decorator functions to register tools:
-1. [`@agent.retriever_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext]
-2. [`@agent.retriever`][pydantic_ai.Agent.retriever] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext]
+1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.dependencies.CallContext]
+2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.dependencies.CallContext]
-`@agent.retriever` is the default since in the majority of cases retrievers will need access to the agent context.
+`@agent.tool` is the default since in the majority of cases tools will need access to the agent context.
Here's an example using both:
@@ -202,13 +202,13 @@ agent = Agent(
)
-@agent.retriever_plain # (3)!
+@agent.tool_plain # (3)!
def roll_die() -> str:
"""Roll a six-sided die and return the result."""
return str(random.randint(1, 6))
-@agent.retriever # (4)!
+@agent.tool # (4)!
def get_player_name(ctx: CallContext[str]) -> str:
"""Get the player's name."""
return ctx.deps
@@ -221,8 +221,8 @@ print(dice_result.data)
1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
2. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency.
-3. This retriever doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case.
-4. This retriever needs the player's name, so it uses `CallContext` to access dependencies which are just the player's name in this case.
+3. This tool doesn't need any context; it just returns a random number. You could probably use a dynamic system prompt in this case.
+4. This tool needs the player's name, so it uses `CallContext` to access dependencies, which are just the player's name in this case.
5. Run the agent, passing the player's name as the dependency.
_(This example is complete, it can be run "as is")_
@@ -297,9 +297,9 @@ sequenceDiagram
Note over Agent: Send prompts
Agent ->> LLM: System: "You're a dice game..."
User: "My guess is 4"
activate LLM
- Note over LLM: LLM decides to use
a retriever
+ Note over LLM: LLM decides to use
a tool
- LLM ->> Agent: Call retriever
roll_die()
+ LLM ->> Agent: Call tool
roll_die()
deactivate LLM
activate Agent
Note over Agent: Rolls a six-sided die
@@ -307,9 +307,9 @@ sequenceDiagram
Agent -->> LLM: ToolReturn
"4"
deactivate Agent
activate LLM
- Note over LLM: LLM decides to use
another retriever
+ Note over LLM: LLM decides to use
another tool
- LLM ->> Agent: Call retriever
get_player_name()
+ LLM ->> Agent: Call tool
get_player_name()
deactivate LLM
activate Agent
Note over Agent: Retrieves player name
@@ -323,19 +323,21 @@ sequenceDiagram
Note over Agent: Game session complete
```
-### Retrievers, tools, and schema
+### Function Tools vs. Structured Results
-Under the hood, retrievers use the model's "tools" or "functions" API to let the model know what retrievers are available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call retrievers while others end the run and return a result.
+As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, so a model might have access to many tools, some of which are function tools while others end the run and return a result.
+
+### Function Tools and Schema
Function parameters are extracted from the function signature, and all parameters except `CallContext` are used to build the schema for that tool call.
-Even better, PydanticAI extracts the docstring from retriever functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema.
+Even better, PydanticAI extracts the docstring from functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema.
[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy` and `sphinx` style docstrings, and PydanticAI will infer the format to use based on the docstring. We plan to add support in the future to explicitly set the style to use, and warn/error if not all parameters are documented; see [#59](https://github.com/pydantic/pydantic-ai/issues/59).
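In stdlib terms, the signature-based extraction can be sketched roughly like this (a simplified illustration using `inspect`; the real implementation builds full Pydantic validators, and `extract_params` is a made-up name):

```python
import inspect
from typing import get_type_hints


def extract_params(func, skip_first: bool = False) -> dict[str, str]:
    """Map a function's parameter names to their annotated type names."""
    hints = get_type_hints(func)
    names = list(inspect.signature(func).parameters)
    if skip_first:
        names = names[1:]  # drop the CallContext parameter
    return {name: hints[name].__name__ for name in names}


def roll_die(sides: int, rigged: bool) -> str:
    """Roll a die."""
    return '4'


print(extract_params(roll_die))
#> {'sides': 'int', 'rigged': 'bool'}
```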
-To demonstrate a retriever's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive:
+To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive:
-```py title="retriever_schema.py"
+```py title="tool_schema.py"
from pydantic_ai import Agent
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel
@@ -343,7 +345,7 @@ from pydantic_ai.models.function import AgentInfo, FunctionModel
agent = Agent()
-@agent.retriever_plain
+@agent.tool_plain
def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
"""Get me foobar.
@@ -356,10 +358,10 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
- retriever = info.retrievers['foobar']
- print(retriever.description)
+ tool = info.function_tools['foobar']
+ print(tool.description)
#> Get me foobar.
- print(retriever.json_schema)
+ print(tool.json_schema)
"""
{
'description': 'Get me foobar.',
@@ -386,22 +388,22 @@ agent.run_sync('hello', model=FunctionModel(print_schema))
_(This example is complete, it can be run "as is")_
-The return type of retriever can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
+The return type of a tool can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]), as some models (e.g. Gemini) support semi-structured return values, while others (e.g. OpenAI) expect text but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
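That fallback serialization step can be sketched as follows (a minimal illustration; `tool_return_to_text` is a hypothetical name, not the library's API):

```python
import json


def tool_return_to_text(value) -> str:
    """Serialize a tool's return value for models that expect plain text."""
    if isinstance(value, str):
        return value  # already text, pass through unchanged
    return json.dumps(value)


print(tool_return_to_text({'lat': 51.5, 'lng': -0.1}))
#> {"lat": 51.5, "lng": -0.1}
```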
-If a retriever has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the retriever is simplified to be just that object. (TODO example)
+If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, Pydantic model), the schema for the tool is simplified to be just that object. (TODO example)
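The simplification might look roughly like this (illustrative only; `object_schema` is a made-up stand-in for the schema Pydantic actually builds):

```python
from dataclasses import dataclass, fields


@dataclass
class Query:
    name: str
    limit: int


def object_schema(cls) -> dict:
    """Build a toy JSON schema for a dataclass (stand-in for Pydantic)."""
    type_map = {'str': 'string', 'int': 'integer'}
    props = {}
    for f in fields(cls):
        tname = f.type if isinstance(f.type, str) else f.type.__name__
        props[f.name] = {'type': type_map[tname]}
    return {'type': 'object', 'properties': props, 'required': [f.name for f in fields(cls)]}


# A tool like `def search(query: Query) -> str` would get this object schema
# directly, rather than a wrapper object with a single "query" property.
print(object_schema(Query))
#> {'type': 'object', 'properties': {'name': {'type': 'string'}, 'limit': {'type': 'integer'}}, 'required': ['name', 'limit']}
```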
## Reflection and self-correction
-Validation errors from both retriever parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry.
+Validation errors from both tool parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry.
-You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response.
+You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [tool](#function-tools) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response.
-- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever], or a [result validator][pydantic_ai.Agent.__init__].
-- You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext].
+- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific tool][pydantic_ai.Agent.tool], or a [result validator][pydantic_ai.Agent.__init__].
+- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext].
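Stripped of the model round-trip, the retry accounting can be sketched like this (`call_tool_with_retries` and the error message format are illustrative, not the library's internals — in the real flow the model produces new arguments on each retry):

```python
class ModelRetry(Exception):
    """Stand-in for pydantic_ai.exceptions.ModelRetry (sketch only)."""


def call_tool_with_retries(tool, args, max_retries: int = 1):
    """Re-invoke the tool on ModelRetry until the retry count is exceeded."""
    retry = 0
    while True:
        try:
            return tool(*args)
        except ModelRetry as exc:
            retry += 1
            if retry > max_retries:
                raise RuntimeError(
                    f'Tool exceeded max retries count of {max_retries}'
                ) from exc


def fussy_tool(size: int) -> int:
    if size != 42:
        raise ModelRetry('Please try again.')
    return size**3


try:
    call_tool_with_retries(fussy_tool, (6,), max_retries=1)
except RuntimeError as e:
    print(e)
    #> Tool exceeded max retries count of 1
```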
Here's an example:
-```py title="retriever_retry.py"
+```py title="tool_retry.py"
from fake_database import DatabaseConn
from pydantic import BaseModel
@@ -420,7 +422,7 @@ agent = Agent(
)
-@agent.retriever(retries=2)
+@agent.tool(retries=2)
def get_user_by_name(ctx: CallContext[DatabaseConn], name: str) -> int:
"""Get a user's ID from their full name."""
print(name)
@@ -455,7 +457,7 @@ from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior
agent = Agent('openai:gpt-4o')
-@agent.retriever_plain
+@agent.tool_plain
def calc_volume(size: int) -> int: # (1)!
if size == 42:
return size**3
@@ -467,7 +469,7 @@ try:
result = agent.run_sync('Please get me the volume of a box with size 6.')
except UnexpectedModelBehavior as e:
print('An error occurred:', e)
- #> An error occurred: Retriever exceeded max retries count of 1
+ #> An error occurred: Tool exceeded max retries count of 1
print('cause:', repr(e.__cause__))
#> cause: ModelRetry('Please try again.')
print('messages:', agent.last_run_messages)
@@ -513,6 +515,6 @@ except UnexpectedModelBehavior as e:
else:
print(result.data)
```
-1. Define a retriever that will raise `ModelRetry` repeatedly in this case.
+1. Define a tool that will raise `ModelRetry` repeatedly in this case.
_(This example is complete, it can be run "as is")_
diff --git a/docs/api/agent.md b/docs/api/agent.md
index 9074c0bb15..6b012f1ea3 100644
--- a/docs/api/agent.md
+++ b/docs/api/agent.md
@@ -12,6 +12,6 @@
- override_model
- last_run_messages
- system_prompt
- - retriever
- - retriever_plain
+ - tool
+ - tool_plain
- result_validator
diff --git a/docs/dependencies.md b/docs/dependencies.md
index c1f7aa1b2c..74ddc1a013 100644
--- a/docs/dependencies.md
+++ b/docs/dependencies.md
@@ -1,13 +1,12 @@
# Dependencies
-PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions).
+PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](agents.md#function-tools) and [result validators](results.md#result-validators-functions).
Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic"; this should make dependencies type-safe, understandable, easier to test and ultimately easier to deploy in production.
## Defining Dependencies
-Dependencies can be any python type. While in simple cases you might be able to pass a single object
-as a dependency (e.g. an HTTP connection), [dataclasses][] are generally a convenient container when your dependencies included multiple objects.
+Dependencies can be any Python type. While in simple cases you might be able to pass a single object as a dependency (e.g. an HTTP connection), [dataclasses][] are generally a convenient container when your dependencies include multiple objects.
Here's an example of defining an agent that requires dependencies.
@@ -102,7 +101,7 @@ _(This example is complete, it can be run "as is")_
### Asynchronous vs. Synchronous dependencies
-System prompt functions, retriever functions and result validator are all run in the async context of an agent run.
+[System prompt functions](agents.md#system-prompts), [function tools](agents.md#function-tools) and [result validators](results.md#result-validators-functions) are all run in the async context of an agent run.
If these functions are not coroutines (i.e. not defined with `async def`) they are called with
[`run_in_executor`][asyncio.loop.run_in_executor] in a thread pool, it's therefore marginally preferable
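The dispatch described above can be sketched with the stdlib (a simplified illustration; `call_dep_function` is a made-up name):

```python
import asyncio
import inspect


async def call_dep_function(func, *args):
    """Await coroutines directly; run plain functions in the thread pool."""
    if inspect.iscoroutinefunction(func):
        return await func(*args)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def sync_prompt() -> str:
    return 'system prompt'


print(asyncio.run(call_dep_function(sync_prompt)))
#> system prompt
```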
@@ -159,7 +158,7 @@ _(This example is complete, it can be run "as is")_
## Full Example
-As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions).
+As well as system prompts, dependencies can be used in [tools](agents.md#function-tools) and [result validators](results.md#result-validators-functions).
```py title="full_example.py" hl_lines="27-35 38-48"
from dataclasses import dataclass
@@ -188,7 +187,7 @@ async def get_system_prompt(ctx: CallContext[MyDeps]) -> str:
return f'Prompt: {response.text}'
-@agent.retriever # (1)!
+@agent.tool # (1)!
async def get_joke_material(ctx: CallContext[MyDeps], subject: str) -> str:
response = await ctx.deps.http_client.get(
'https://example.com#jokes',
@@ -220,7 +219,7 @@ async def main():
#> Did you hear about the toothpaste scandal? They called it Colgate.
```
-1. To pass `CallContext` and to a retriever, us the [`retriever`][pydantic_ai.Agent.retriever] decorator.
+1. To pass `CallContext` to a tool, use the [`tool`][pydantic_ai.Agent.tool] decorator.
2. `CallContext` may optionally be passed to a [`result_validator`][pydantic_ai.Agent.result_validator] function as the first argument.
_(This example is complete, it can be run "as is")_
@@ -324,7 +323,7 @@ joke_agent = Agent(
factory_agent = Agent('gemini-1.5-pro', result_type=list[str])
-@joke_agent.retriever
+@joke_agent.tool
async def joke_factory(ctx: CallContext[MyDeps], count: int) -> str:
r = await ctx.deps.factory_agent.run(f'Please generate {count} jokes.')
return '\n'.join(r.data)
diff --git a/docs/examples/bank-support.md b/docs/examples/bank-support.md
index 9068590bea..0abddc5d14 100644
--- a/docs/examples/bank-support.md
+++ b/docs/examples/bank-support.md
@@ -4,7 +4,7 @@ Demonstrates:
* [dynamic system prompt](../agents.md#system-prompts)
* [structured `result_type`](../results.md#structured-result-validation)
-* [retrievers](../agents.md#retrievers)
+* [tools](../agents.md#function-tools)
## Running the Example
diff --git a/docs/examples/rag.md b/docs/examples/rag.md
index 9a4d5382db..d7dee103d9 100644
--- a/docs/examples/rag.md
+++ b/docs/examples/rag.md
@@ -4,12 +4,12 @@ RAG search example. This demo allows you to ask question of the [logfire](https:
Demonstrates:
-* [retrievers](../agents.md#retrievers)
+* [tools](../agents.md#function-tools)
* [agent dependencies](../dependencies.md)
* RAG search
This is done by creating a database containing each section of the markdown documentation, then registering
-the search tool as a retriever with the PydanticAI agent.
+the search tool with the PydanticAI agent.
Logic for extracting sections from markdown files and a JSON file with that data is available in
[this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992).
diff --git a/docs/examples/weather-agent.md b/docs/examples/weather-agent.md
index f5112a0307..c770431134 100644
--- a/docs/examples/weather-agent.md
+++ b/docs/examples/weather-agent.md
@@ -2,7 +2,7 @@ Example of PydanticAI with multiple tools which the LLM needs to call in turn to
Demonstrates:
-* [retrievers](../agents.md#retrievers)
+* [tools](../agents.md#function-tools)
* [agent dependencies](../dependencies.md)
In this case the idea is a "weather" agent — the user can ask for the weather in multiple locations,
diff --git a/docs/index.md b/docs/index.md
index e0890303a3..938f372087 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -48,9 +48,9 @@ The first known use of "hello, world" was in a 1974 textbook about the C program
_(This example is complete, it can be run "as is")_
-Not very interesting yet, but we can easily add "retrievers", dynamic system prompts, and structured responses to build more powerful agents.
+Not very interesting yet, but we can easily add "tools", dynamic system prompts, and structured responses to build more powerful agents.
-## Retrievers & Dependency Injection Example
+## Tools & Dependency Injection Example
Here is a concise example using PydanticAI to build a support agent for a bank:
@@ -92,7 +92,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str:
return f"The customer's name is {customer_name!r}"
-@support_agent.retriever # (6)!
+@support_agent.tool # (6)!
async def customer_balance(
ctx: CallContext[SupportDependencies], include_pending: bool
) -> str:
@@ -124,15 +124,15 @@ async def main():
1. This [agent](agents.md) will act as first-tier support in a bank. Agents are generic in the type of dependencies they accept and the type of result they return. In this case, the support agent has type `#!python Agent[SupportDependencies, SupportResult]`.
2. Here we configure the agent to use [OpenAI's GPT-4o model](api/models/openai.md), you can also set the model when running the agent.
-3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [retriever](agents.md#retrievers) functions. PydanticAI's system of dependency injection provides a type-safe way to customise the behavior of your agents, and can be especially useful when running unit tests and evals.
+3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](agents.md#function-tools) functions. PydanticAI's system of dependency injection provides a type-safe way to customise the behavior of your agents, and can be especially useful when running unit tests and evals.
4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent.
5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`CallContext`][pydantic_ai.dependencies.CallContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it.
-6. [Retrievers](agents.md#retrievers) let you register "tools" which the LLM may call while responding to a user. Again, dependencies are carried via [`CallContext`][pydantic_ai.dependencies.CallContext], and any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry.
-7. The docstring of a retriever also passed to the LLM as a description of the tool. Parameter descriptions are [extracted](agents.md#retrievers-tools-and-schema) from the docstring and added to the tool schema sent to the LLM.
-8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as retrievers are called to retrieve a result.
+6. [Function tools](agents.md#function-tools) let you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`CallContext`][pydantic_ai.dependencies.CallContext], and any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry.
+7. The docstring of a tool is also passed to the LLM as the tool's description. Parameter descriptions are [extracted](agents.md#function-tools-and-schema) from the docstring and added to the tool schema sent to the LLM.
+8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as tools are called to retrieve a result.
9. The response from the agent will be guaranteed to be a `SupportResult`; if validation fails, [reflection](agents.md#reflection-and-self-correction) will mean the agent is prompted to try again.
10. The result will be validated with Pydantic to guarantee it is a `SupportResult`, since the agent is generic, it'll also be typed as a `SupportResult` to aid with static type checking.
-11. In a real use case, you'd add many more retrievers and a longer system prompt to the agent to extend the context it's equipped with and support it can provide.
+11. In a real use case, you'd add many more tools and a longer system prompt to the agent to extend the context it's equipped with and support it can provide.
12. This is a simple sketch of a database connection, used to keep the example short and readable. In reality, you'd be connecting to an external database (e.g. PostgreSQL) to get information about customers.
13. This [Pydantic](https://docs.pydantic.dev) model is used to constrain the structured data returned by the agent. From this simple definition, Pydantic builds the JSON Schema that tells the LLM how to return the data, and performs validation to guarantee the data is correct at the end of the conversation.
@@ -153,8 +153,8 @@ sequenceDiagram
Agent ->> LLM: Request
System: "You are a support agent..."
System: "The customer's name is John"
User: "What is my balance?"
activate LLM
- Note over LLM: LLM decides to use a retriever
- LLM ->> Agent: Call retriever
customer_balance()
+ Note over LLM: LLM decides to use a tool
+ LLM ->> Agent: Call tool
customer_balance()
deactivate LLM
activate Agent
Note over Agent: Retrieve account balance
diff --git a/docs/results.md b/docs/results.md
index 0c0a217677..4704c824e8 100644
--- a/docs/results.md
+++ b/docs/results.md
@@ -34,7 +34,7 @@ If the result type is a union with multiple members (after remove `str` from the
If the result type schema is not of type `"object"`, the result type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas.
-Structured results (like retrievers) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model.
+Structured results (like tools) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model.
!!! note "Bring on PEP-747"
Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, unions are not valid as `type`s in Python.
@@ -160,7 +160,7 @@ _(This example is complete, it can be run "as is")_
There are two main challenges with streamed results:
1. Validating structured responses before they're complete, this is achieved by "partial validation" which was recently added to Pydantic in [pydantic/pydantic#10748](https://github.com/pydantic/pydantic/pull/10748).
-2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a retriever call or a result, then streams the whole thing and calls retrievers, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].
+2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a tool call or a result, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].
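The "sniffing" step can be sketched as follows (illustrative only; `sniff_stream` is a hypothetical helper, and real responses carry structured deltas rather than raw text):

```python
def sniff_stream(chunks):
    """Peek at the first chunk to classify the response, then replay it."""
    it = iter(chunks)
    first = next(it)
    kind = 'tool_call' if first.lstrip().startswith('{') else 'text'

    def replay():
        yield first  # re-emit the peeked chunk so nothing is lost
        yield from it

    return kind, replay()


kind, stream = sniff_stream(['{"name": "roll_die"', ', "args": {}}'])
print(kind, ''.join(stream))
#> tool_call {"name": "roll_die", "args": {}}
```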
### Streaming Text
diff --git a/pydantic_ai_examples/bank_support.py b/pydantic_ai_examples/bank_support.py
index 4f428ff2fb..021f10b905 100644
--- a/pydantic_ai_examples/bank_support.py
+++ b/pydantic_ai_examples/bank_support.py
@@ -62,7 +62,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str:
return f"The customer's name is {customer_name!r}"
-@support_agent.retriever
+@support_agent.tool
async def customer_balance(
ctx: CallContext[SupportDependencies], include_pending: bool
) -> str:
diff --git a/pydantic_ai_examples/rag.py b/pydantic_ai_examples/rag.py
index 44aa88f2d9..d64ff50c93 100644
--- a/pydantic_ai_examples/rag.py
+++ b/pydantic_ai_examples/rag.py
@@ -51,7 +51,7 @@ class Deps:
agent = Agent('openai:gpt-4o', deps_type=Deps)
-@agent.retriever
+@agent.tool
async def retrieve(context: CallContext[Deps], search_query: str) -> str:
"""Retrieve documentation sections based on a search query.
diff --git a/pydantic_ai_examples/roulette_wheel.py b/pydantic_ai_examples/roulette_wheel.py
index 21820305e0..eee4947aaa 100644
--- a/pydantic_ai_examples/roulette_wheel.py
+++ b/pydantic_ai_examples/roulette_wheel.py
@@ -32,7 +32,7 @@ class Deps:
)
-@roulette_agent.retriever
+@roulette_agent.tool
async def roulette_wheel(
ctx: CallContext[Deps], square: int
) -> Literal['winner', 'loser']:
diff --git a/pydantic_ai_examples/weather_agent.py b/pydantic_ai_examples/weather_agent.py
index 6e69920dc2..7e62bf9e64 100644
--- a/pydantic_ai_examples/weather_agent.py
+++ b/pydantic_ai_examples/weather_agent.py
@@ -41,7 +41,7 @@ class Deps:
)
-@weather_agent.retriever
+@weather_agent.tool
async def get_lat_lng(
ctx: CallContext[Deps], location_description: str
) -> dict[str, float]:
@@ -71,7 +71,7 @@ async def get_lat_lng(
raise ModelRetry('Could not find the location')
-@weather_agent.retriever
+@weather_agent.tool
async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]:
"""Get the weather at a location.
diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py
index 3b241de35e..c0bdb2c4aa 100644
--- a/pydantic_ai_slim/pydantic_ai/_pydantic.py
+++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py
@@ -20,8 +20,8 @@
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
if TYPE_CHECKING:
- from . import _retriever
- from .dependencies import AgentDeps, RetrieverParams
+ from . import _tool
+ from .dependencies import AgentDeps, ToolParams
__all__ = 'function_schema', 'LazyTypeAdapter'
@@ -39,8 +39,8 @@ class FunctionSchema(TypedDict):
var_positional_field: str | None
-def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, RetrieverParams]) -> FunctionSchema: # noqa: C901
- """Build a Pydantic validator and JSON schema from a retriever function.
+def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]) -> FunctionSchema: # noqa: C901
+ """Build a Pydantic validator and JSON schema from a tool function.
Args:
either_function: The function to build a validator and JSON schema for.
@@ -78,10 +78,10 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, R
if index == 0 and takes_ctx:
if not _is_call_ctx(annotation):
- errors.append('First argument must be a CallContext instance when using `.retriever`')
+ errors.append('First argument must be a CallContext instance when using `.tool`')
continue
elif not takes_ctx and _is_call_ctx(annotation):
- errors.append('CallContext instance can only be used with `.retriever`')
+ errors.append('CallContext instance can only be used with `.tool`')
continue
elif index != 0 and _is_call_ctx(annotation):
errors.append('CallContext instance can only be used as the first argument')
diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py
index d529ee546f..26d7ace457 100644
--- a/pydantic_ai_slim/pydantic_ai/_result.py
+++ b/pydantic_ai_slim/pydantic_ai/_result.py
@@ -75,7 +75,7 @@ def __init__(self, tool_retry: messages.RetryPrompt):
class ResultSchema(Generic[ResultData]):
"""Model the final response from an agent run.
- Similar to `Retriever` but for the final result of running an agent.
+ Similar to `Tool` but for the final result of running an agent.
"""
tools: dict[str, ResultTool[ResultData]]
diff --git a/pydantic_ai_slim/pydantic_ai/_retriever.py b/pydantic_ai_slim/pydantic_ai/_tool.py
similarity index 81%
rename from pydantic_ai_slim/pydantic_ai/_retriever.py
rename to pydantic_ai_slim/pydantic_ai/_tool.py
index bb376bf7ce..b808c3058e 100644
--- a/pydantic_ai_slim/pydantic_ai/_retriever.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool.py
@@ -9,22 +9,20 @@
from pydantic_core import SchemaValidator
from . import _pydantic, _utils, messages
-from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
+from .dependencies import AgentDeps, CallContext, ToolContextFunc, ToolParams, ToolPlainFunc
from .exceptions import ModelRetry, UnexpectedModelBehavior
-# Usage `RetrieverEitherFunc[AgentDependencies, P]`
-RetrieverEitherFunc = _utils.Either[
- RetrieverContextFunc[AgentDeps, RetrieverParams], RetrieverPlainFunc[RetrieverParams]
-]
+# Usage `ToolEitherFunc[AgentDependencies, P]`
+ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]]
@dataclass(init=False)
-class Retriever(Generic[AgentDeps, RetrieverParams]):
- """A retriever function for an agent."""
+class Tool(Generic[AgentDeps, ToolParams]):
+ """A tool function for an agent."""
name: str
description: str
- function: RetrieverEitherFunc[AgentDeps, RetrieverParams] = field(repr=False)
+ function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False)
is_async: bool
single_arg_name: str | None
positional_fields: list[str]
@@ -35,8 +33,8 @@ class Retriever(Generic[AgentDeps, RetrieverParams]):
_current_retry: int = 0
outer_typed_dict_key: str | None = None
- def __init__(self, function: RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int):
- """Build a Retriever dataclass from a function."""
+ def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int):
+ """Build a Tool dataclass from a function."""
self.function = function
# noinspection PyTypeChecker
f = _pydantic.function_schema(function)
@@ -56,7 +54,7 @@ def reset(self) -> None:
self._current_retry = 0
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
- """Run the retriever function asynchronously."""
+ """Run the tool function asynchronously."""
try:
if isinstance(message.args, messages.ArgsJson):
args_dict = self.validator.validate_json(message.args.args_json)
@@ -100,8 +98,8 @@ def _call_args(
def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
self._current_retry += 1
if self._current_retry > self.max_retries:
- # TODO custom error with details of the retriever
- raise UnexpectedModelBehavior(f'Retriever exceeded max retries count of {self.max_retries}') from exc
+ # TODO custom error with details of the tool
+ raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
else:
if isinstance(exc, ValidationError):
content = exc.errors(include_url=False)
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index b2c05b208f..5a40cd2cf5 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -11,15 +11,15 @@
from . import (
_result,
- _retriever as _r,
_system_prompt,
+ _tool as _r,
_utils,
exceptions,
messages as _messages,
models,
result,
)
-from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
+from .dependencies import AgentDeps, CallContext, ToolContextFunc, ToolParams, ToolPlainFunc
from .result import ResultData
__all__ = ('Agent',)
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
_allow_text_result: bool = field(repr=False)
_system_prompts: tuple[str, ...] = field(repr=False)
- _retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False)
+ _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
_default_retries: int = field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
_deps_type: type[AgentDeps] = field(repr=False)
@@ -119,7 +119,7 @@ def __init__(
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
- self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
+ self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
self._deps_type = deps_type
self._default_retries = retries
self._system_prompt_functions = []
@@ -153,8 +153,8 @@ async def run(
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages
- for retriever in self._retrievers.values():
- retriever.reset()
+ for tool in self._function_tools.values():
+ tool.reset()
cost = result.Cost()
@@ -246,8 +246,8 @@ async def run_stream(
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages
- for retriever in self._retrievers.values():
- retriever.reset()
+ for tool in self._function_tools.values():
+ tool.reset()
cost = result.Cost()
@@ -366,7 +366,7 @@ async def async_system_prompt(ctx: CallContext[str]) -> str:
result = agent.run_sync('foobar', deps='spam')
print(result.data)
- #> success (no retriever calls)
+ #> success (no tool calls)
```
"""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
@@ -421,41 +421,37 @@ async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
result = agent.run_sync('foobar', deps='spam')
print(result.data)
- #> success (no retriever calls)
+ #> success (no tool calls)
```
"""
self._result_validators.append(_result.ResultValidator(func))
return func
@overload
- def retriever(
- self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
+ def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
@overload
- def retriever(
+ def tool(
self, /, *, retries: int | None = None
- ) -> Callable[
- [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
- ]: ...
+ ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
- def retriever(
+ def tool(
self,
- func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
+ func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
/,
*,
retries: int | None = None,
) -> Any:
- """Decorator to register a retriever function which takes
+ """Decorator to register a tool function which takes
[`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
Can decorate a sync or async function.
The docstring is inspected to extract both the tool description and the description of each parameter,
- [learn more](../agents.md#retrievers-tools-and-schema).
+ [learn more](../agents.md#tools).
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
- so the signature of functions decorated with `@agent.retriever` is obscured.
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
+ so the signature of functions decorated with `@agent.tool` is obscured.
Example:
```py
@@ -463,11 +459,11 @@ def retriever(
agent = Agent('test', deps_type=int)
- @agent.retriever
+ @agent.tool
def foobar(ctx: CallContext[int], x: int) -> int:
return ctx.deps + x
- @agent.retriever(retries=2)
+ @agent.tool(retries=2)
async def spam(ctx: CallContext[str], y: float) -> float:
return ctx.deps + y
@@ -477,45 +473,43 @@ async def spam(ctx: CallContext[str], y: float) -> float:
```
Args:
- func: The retriever function to register.
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+ func: The tool function to register.
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
which defaults to 1.
""" # noqa: D205
if func is None:
- def retriever_decorator(
- func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
+ def tool_decorator(
+ func_: ToolContextFunc[AgentDeps, ToolParams],
+ ) -> ToolContextFunc[AgentDeps, ToolParams]:
# noinspection PyTypeChecker
- self._register_retriever(_utils.Either(left=func_), retries)
+ self._register_tool(_utils.Either(left=func_), retries)
return func_
- return retriever_decorator
+ return tool_decorator
else:
# noinspection PyTypeChecker
- self._register_retriever(_utils.Either(left=func), retries)
+ self._register_tool(_utils.Either(left=func), retries)
return func
@overload
- def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
+ def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
@overload
- def retriever_plain(
+ def tool_plain(
self, /, *, retries: int | None = None
- ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
+ ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
- def retriever_plain(
- self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
- ) -> Any:
- """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
+ def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
+ """Decorator to register a tool function which DOES NOT take `CallContext` as an argument.
Can decorate a sync or async function.
The docstring is inspected to extract both the tool description and the description of each parameter,
- [learn more](../agents.md#retrievers-tools-and-schema).
+ [learn more](../agents.md#tools).
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
- so the signature of functions decorated with `@agent.retriever` is obscured.
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
+ so the signature of functions decorated with `@agent.tool` is obscured.
Example:
```py
@@ -523,11 +517,11 @@ def retriever_plain(
agent = Agent('test')
- @agent.retriever
+ @agent.tool
def foobar(ctx: CallContext[int]) -> int:
return 123
- @agent.retriever(retries=2)
+ @agent.tool(retries=2)
async def spam(ctx: CallContext[str]) -> float:
return 3.14
@@ -537,38 +531,36 @@ async def spam(ctx: CallContext[str]) -> float:
```
Args:
- func: The retriever function to register.
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+ func: The tool function to register.
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
which defaults to 1.
"""
if func is None:
- def retriever_decorator(
- func_: RetrieverPlainFunc[RetrieverParams],
- ) -> RetrieverPlainFunc[RetrieverParams]:
+ def tool_decorator(
+ func_: ToolPlainFunc[ToolParams],
+ ) -> ToolPlainFunc[ToolParams]:
# noinspection PyTypeChecker
- self._register_retriever(_utils.Either(right=func_), retries)
+ self._register_tool(_utils.Either(right=func_), retries)
return func_
- return retriever_decorator
+ return tool_decorator
else:
- self._register_retriever(_utils.Either(right=func), retries)
+ self._register_tool(_utils.Either(right=func), retries)
return func
- def _register_retriever(
- self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
- ) -> None:
- """Private utility to register a retriever function."""
+ def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
+ """Private utility to register a tool function."""
retries_ = retries if retries is not None else self._default_retries
- retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
+ tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
- if self._result_schema and retriever.name in self._result_schema.tools:
- raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}')
+ if self._result_schema and tool.name in self._result_schema.tools:
+ raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
- if retriever.name in self._retrievers:
- raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')
+ if tool.name in self._function_tools:
+ raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
- self._retrievers[retriever.name] = retriever
+ self._function_tools[tool.name] = tool
async def _get_agent_model(
self, model: models.Model | models.KnownModelName | None
@@ -601,7 +593,7 @@ async def _get_agent_model(
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
- agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
+ agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
return model_, custom_model, agent_model
async def _prepare_messages(
@@ -663,12 +655,12 @@ async def _handle_model_response(
if not model_response.calls:
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
- # otherwise we run all retriever functions in parallel
+ # otherwise we run all tool functions in parallel
messages: list[_messages.Message] = []
tasks: list[asyncio.Task[_messages.Message]] = []
for call in model_response.calls:
- if retriever := self._retrievers.get(call.tool_name):
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
+ if tool := self._function_tools.get(call.tool_name):
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
else:
messages.append(self._unknown_tool(call.tool_name))
@@ -719,7 +711,7 @@ async def _handle_streamed_model_response(
if self._result_schema.find_tool(structured_msg):
return _MarkFinalResult(model_response)
- # the model is calling a retriever function, consume the response to get the next message
+ # the model is calling a tool function, consume the response to get the next message
async for _ in model_response:
pass
structured_msg = model_response.get()
@@ -727,11 +719,11 @@ async def _handle_streamed_model_response(
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
messages: list[_messages.Message] = [structured_msg]
- # we now run all retriever functions in parallel
+ # we now run all tool functions in parallel
tasks: list[asyncio.Task[_messages.Message]] = []
for call in structured_msg.calls:
- if retriever := self._retrievers.get(call.tool_name):
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
+ if tool := self._function_tools.get(call.tool_name):
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
else:
messages.append(self._unknown_tool(call.tool_name))
@@ -763,7 +755,7 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
self._incr_result_retry()
- names = list(self._retrievers.keys())
+ names = list(self._function_tools.keys())
if self._result_schema:
names.extend(self._result_schema.tool_names())
if names:
diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py
index ee5a087830..6c22473040 100644
--- a/pydantic_ai_slim/pydantic_ai/dependencies.py
+++ b/pydantic_ai_slim/pydantic_ai/dependencies.py
@@ -16,10 +16,10 @@
'CallContext',
'ResultValidatorFunc',
'SystemPromptFunc',
- 'RetrieverReturnValue',
- 'RetrieverContextFunc',
- 'RetrieverPlainFunc',
- 'RetrieverParams',
+ 'ToolReturnValue',
+ 'ToolContextFunc',
+ 'ToolPlainFunc',
+ 'ToolParams',
'JsonData',
)
@@ -39,7 +39,7 @@ class CallContext(Generic[AgentDeps]):
"""Name of the tool being called."""
-RetrieverParams = ParamSpec('RetrieverParams')
+ToolParams = ParamSpec('ToolParams')
"""Retrieval function param spec."""
SystemPromptFunc = Union[
@@ -69,15 +69,15 @@ class CallContext(Generic[AgentDeps]):
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
"""Type representing any JSON data."""
-RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]]
-"""Return value of a retriever function."""
-RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue]
-"""A retriever function that takes `CallContext` as the first argument.
+ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
+"""Return value of a tool function."""
+ToolContextFunc = Callable[Concatenate[CallContext[AgentDeps], ToolParams], ToolReturnValue]
+"""A tool function that takes `CallContext` as the first argument.
-Usage `RetrieverContextFunc[AgentDeps, RetrieverParams]`.
+Usage `ToolContextFunc[AgentDeps, ToolParams]`.
"""
-RetrieverPlainFunc = Callable[RetrieverParams, RetrieverReturnValue]
-"""A retriever function that does not take `CallContext` as the first argument.
+ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
+"""A tool function that does not take `CallContext` as the first argument.
-Usage `RetrieverPlainFunc[RetrieverParams]`.
+Usage `ToolPlainFunc[ToolParams]`.
"""
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index c0df3a67b4..493f00d37f 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -6,7 +6,7 @@
class ModelRetry(Exception):
- """Exception raised when a retriever function should be retried.
+ """Exception raised when a tool function should be retried.
The agent will return the message to the model and ask it to try calling the function/tool again.
"""
diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index 3d6ec3f45d..9331d25c5c 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -55,7 +55,7 @@ class UserPrompt:
@dataclass
class ToolReturn:
- """A tool return message, this encodes the result of running a retriever."""
+ """A tool return message, this encodes the result of running a tool."""
tool_name: str
"""The name of the "tool" was called."""
@@ -89,10 +89,10 @@ class RetryPrompt:
This can be sent for a number of reasons:
- * Pydantic validation of retriever arguments failed, here content is derived from a Pydantic
+ * Pydantic validation of tool arguments failed, here content is derived from a Pydantic
[`ValidationError`][pydantic_core.ValidationError]
- * a retriever raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
- * no retriever was found for the tool name
+ * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
+ * no tool was found for the tool name
* the model returned plain text when a structured response was expected
* Pydantic validation of a structured response failed, here content is derived from a Pydantic
[`ValidationError`][pydantic_core.ValidationError]
@@ -182,7 +182,7 @@ def has_content(self) -> bool:
class ModelStructuredResponse:
"""A structured response from a model.
- This is used either to call a retriever or to return a structured response from an agent run.
+ This is used either to call a tool or to return a structured response from an agent run.
"""
calls: list[ToolCall]
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index c56a5d19d9..19e615862f 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -63,7 +63,7 @@ class Model(ABC):
@abstractmethod
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
@@ -72,7 +72,7 @@ async def agent_model(
This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
Args:
- retrievers: The retrievers available to the agent.
+ function_tools: The tools available to the agent.
allow_text_result: Whether a plain text final response/result is permitted.
result_tools: Tool definitions for the final result tool(s), if any.
@@ -259,7 +259,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
class AbstractToolDefinition(Protocol):
"""Abstract definition of a function/tool.
- This is used for both retrievers and result tools.
+ This is used for both tools and result tools.
"""
name: str
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index 1fb862b1d1..b9f4a4a811 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -67,13 +67,13 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
result_tools = list(result_tools) if result_tools is not None else None
return FunctionAgentModel(
- self.function, self.stream_function, AgentInfo(retrievers, allow_text_result, result_tools)
+ self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
)
def name(self) -> str:
@@ -89,11 +89,15 @@ def name(self) -> str:
class AgentInfo:
"""Information about an agent.
- This is passed as the second to functions.
+ This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
"""
- retrievers: Mapping[str, AbstractToolDefinition]
- """The retrievers available on this agent."""
+ function_tools: Mapping[str, AbstractToolDefinition]
+ """The function tools available on this agent.
+
+ These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
+ [`tool_plain`][pydantic_ai.Agent.tool_plain] decorators.
+ """
allow_text_result: bool
"""Whether a plain text result is allowed."""
result_tools: list[AbstractToolDefinition] | None
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 3b825e0dde..f0e06684b2 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -112,7 +112,7 @@ def __init__(
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> GeminiAgentModel:
@@ -121,7 +121,7 @@ async def agent_model(
model_name=self.model_name,
auth=self.auth,
url=self.url,
- retrievers=retrievers,
+ function_tools=function_tools,
allow_text_result=allow_text_result,
result_tools=result_tools,
)
@@ -160,12 +160,12 @@ def __init__(
model_name: GeminiModelName,
auth: AuthProtocol,
url: str,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
):
check_allow_model_requests()
- tools = [_function_from_abstract_tool(t) for t in retrievers.values()]
+ tools = [_function_from_abstract_tool(t) for t in function_tools.values()]
if result_tools is not None:
tools += [_function_from_abstract_tool(t) for t in result_tools]
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index 80d06f8f58..46420b48b1 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -109,12 +109,12 @@ def __init__(
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
check_allow_model_requests()
- tools = [self._map_tool_definition(r) for r in retrievers.values()]
+ tools = [self._map_tool_definition(r) for r in function_tools.values()]
if result_tools is not None:
tools += [self._map_tool_definition(r) for r in result_tools]
return GroqAgentModel(
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 8b8d1cf87b..4b409f720a 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -89,12 +89,12 @@ def __init__(
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
check_allow_model_requests()
- tools = [self._map_tool_definition(r) for r in retrievers.values()]
+ tools = [self._map_tool_definition(r) for r in function_tools.values()]
if result_tools is not None:
tools += [self._map_tool_definition(r) for r in result_tools]
return OpenAIAgentModel(
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index cd7f920703..f8a1d5482e 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -45,7 +45,7 @@ def __repr__(self):
class TestModel(Model):
"""A model specifically for testing purposes.
- This will (by default) call all retrievers in the agent model, then return a tool response if possible,
+ This will (by default) call all tools in the agent model, then return a tool response if possible,
otherwise a plain response.
How useful this function will be is unknown; it may be useless, or it may require significant changes to be useful.
@@ -57,8 +57,8 @@ class TestModel(Model):
# NOTE: Avoid test discovery by pytest.
__test__ = False
- call_retrievers: list[str] | Literal['all'] = 'all'
- """List of retrievers to call. If `'all'`, all retrievers will be called."""
+ call_tools: list[str] | Literal['all'] = 'all'
+ """List of tools to call. If `'all'`, all tools will be called."""
custom_result_text: str | None = None
"""If set, this text is return as the final result."""
custom_result_args: Any | None = None
@@ -66,25 +66,25 @@ class TestModel(Model):
seed: int = 0
"""Seed for generating random data."""
# these fields are set when the model is called by the agent
- agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
+ agent_model_tools: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
agent_model_allow_text_result: bool | None = field(default=None, init=False)
agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
- self.agent_model_retrievers = retrievers
+ self.agent_model_tools = function_tools
self.agent_model_allow_text_result = allow_text_result
self.agent_model_result_tools = list(result_tools) if result_tools is not None else None
- if self.call_retrievers == 'all':
- retriever_calls = [(r.name, r) for r in retrievers.values()]
+ if self.call_tools == 'all':
+ tool_calls = [(r.name, r) for r in function_tools.values()]
else:
- retrievers_to_call = (retrievers[name] for name in self.call_retrievers)
- retriever_calls = [(r.name, r) for r in retrievers_to_call]
+ tools_to_call = (function_tools[name] for name in self.call_tools)
+ tool_calls = [(r.name, r) for r in tools_to_call]
if self.custom_result_text is not None:
assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
@@ -104,7 +104,7 @@ async def agent_model(
result = _utils.Either(right=None)
else:
result = _utils.Either(left=None)
- return TestAgentModel(retriever_calls, result, self.agent_model_result_tools, self.seed)
+ return TestAgentModel(tool_calls, result, self.agent_model_result_tools, self.seed)
def name(self) -> str:
return 'test-model'
@@ -117,7 +117,7 @@ class TestAgentModel(AgentModel):
# NOTE: Avoid test discovery by pytest.
__test__ = False
- retriever_calls: list[tuple[str, AbstractToolDefinition]]
+ tool_calls: list[tuple[str, AbstractToolDefinition]]
# left means the text is plain text; right means it's a function call
result: _utils.Either[str | None, Any | None]
result_tools: list[AbstractToolDefinition] | None
@@ -137,12 +137,12 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS
else:
yield TestStreamStructuredResponse(msg, cost)
- def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> Any:
+ def gen_tool_args(self, tool_def: AbstractToolDefinition) -> Any:
return _JsonSchemaTestData(tool_def.json_schema, self.seed).generate()
def _request(self, messages: list[Message]) -> ModelAnyResponse:
- if self.step == 0 and self.retriever_calls:
- calls = [ToolCall.from_object(name, self.gen_retriever_args(args)) for name, args in self.retriever_calls]
+ if self.step == 0 and self.tool_calls:
+ calls = [ToolCall.from_object(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
self.step += 1
self.last_message_count = len(messages)
return ModelStructuredResponse(calls=calls)
@@ -152,8 +152,8 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)}
if new_retry_names:
calls = [
- ToolCall.from_object(name, self.gen_retriever_args(args))
- for name, args in self.retriever_calls
+ ToolCall.from_object(name, self.gen_tool_args(args))
+ for name, args in self.tool_calls
if name in new_retry_names
]
self.step += 1
@@ -162,7 +162,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
if response_text := self.result.left:
self.step += 1
if response_text.value is None:
- # build up details of retriever responses
+ # build up details of tool responses
output: dict[str, Any] = {}
for message in messages:
if isinstance(message, ToolReturn):
@@ -170,7 +170,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
if output:
return ModelTextResponse(content=pydantic_core.to_json(output).decode())
else:
- return ModelTextResponse(content='success (no retriever calls)')
+ return ModelTextResponse(content='success (no tool calls)')
else:
return ModelTextResponse(content=response_text.value)
else:
@@ -181,7 +181,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
self.step += 1
return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, custom_result_args)])
else:
- response_args = self.gen_retriever_args(result_tool)
+ response_args = self.gen_tool_args(result_tool)
self.step += 1
return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, response_args)])
diff --git a/pydantic_ai_slim/pydantic_ai/models/vertexai.py b/pydantic_ai_slim/pydantic_ai/models/vertexai.py
index fc48ab6bfd..b28dbe2f67 100644
--- a/pydantic_ai_slim/pydantic_ai/models/vertexai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/vertexai.py
@@ -166,7 +166,7 @@ def __init__(
async def agent_model(
self,
- retrievers: Mapping[str, AbstractToolDefinition],
+ function_tools: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> GeminiAgentModel:
@@ -176,7 +176,7 @@ async def agent_model(
model_name=self.model_name,
auth=auth,
url=url,
- retrievers=retrievers,
+ function_tools=function_tools,
allow_text_result=allow_text_result,
result_tools=result_tools,
)
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index 7ec2d0d3ff..476b4ec718 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -96,7 +96,7 @@ class TestToolDefinition:
async def test_agent_model_tools(allow_model_requests: None):
m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
- retrievers = {
+ tools = {
'foo': TestToolDefinition(
'foo',
'This is foo',
@@ -118,7 +118,7 @@ async def test_agent_model_tools(allow_model_requests: None):
'This is the tool for the final Result',
{'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
)
- agent_model = await m.agent_model(retrievers, True, [result_tool])
+ agent_model = await m.agent_model(tools, True, [result_tool])
assert agent_model.tools == snapshot(
_GeminiTools(
function_declarations=[
@@ -465,7 +465,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, system_prompt='this is the system prompt')
- @agent.retriever_plain
+ @agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
@@ -628,16 +628,16 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
gemini_client = get_gemini_client([first_stream, second_stream])
model = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(model, result_type=tuple[int, int])
- retriever_calls: list[str] = []
+ tool_calls: list[str] = []
- @agent.retriever_plain
+ @agent.tool_plain
async def foo(x: str) -> str:
- retriever_calls.append(f'foo({x=!r})')
+ tool_calls.append(f'foo({x=!r})')
return x
- @agent.retriever_plain
+ @agent.tool_plain
async def bar(y: str) -> str:
- retriever_calls.append(f'bar({y=!r})')
+ tool_calls.append(f'bar({y=!r})')
return y
async with agent.run_stream('Hello') as result:
@@ -667,7 +667,7 @@ async def bar(y: str) -> str:
),
]
)
- assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"])
+ assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])
async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py
index 18c6efaa3f..d19fa94355 100644
--- a/tests/models/test_groq.py
+++ b/tests/models/test_groq.py
@@ -240,7 +240,7 @@ async def test_request_tool_call(allow_model_requests: None):
m = GroqModel('llama-3.1-70b-versatile', groq_client=mock_client)
agent = Agent(m, system_prompt='this is the system prompt')
- @agent.retriever_plain
+ @agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py
index e88a44dfd0..44e01aa219 100644
--- a/tests/models/test_model_function.py
+++ b/tests/models/test_model_function.py
@@ -87,7 +87,7 @@ def test_simple():
async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover
assert info.allow_text_result
- assert info.retrievers.keys() == {'get_location', 'get_weather'}
+ assert info.function_tools.keys() == {'get_location', 'get_weather'}
last = messages[-1]
if last.role == 'user':
return ModelStructuredResponse(
@@ -111,7 +111,7 @@ async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyRes
weather_agent: Agent[None, str] = Agent(FunctionModel(weather_model))
-@weather_agent.retriever_plain
+@weather_agent.tool_plain
async def get_location(location_description: str) -> str:
if location_description == 'London':
lat_lng = {'lat': 51, 'lng': 0}
@@ -120,7 +120,7 @@ async def get_location(location_description: str) -> str:
return json.dumps(lat_lng)
-@weather_agent.retriever
+@weather_agent.tool
async def get_weather(_: CallContext[None], lat: int, lng: int):
if (lat, lng) == (51, 0):
# it always rains in London
@@ -200,7 +200,7 @@ async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAny
var_args_agent = Agent(FunctionModel(call_function_model), deps_type=int)
-@var_args_agent.retriever
+@var_args_agent.tool
def get_var_args(ctx: CallContext[int], *args: int):
assert ctx.deps == 123
return json.dumps({'args': args})
@@ -222,19 +222,19 @@ def test_var_args():
)
-async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
+async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
if len(messages) == 1:
- assert len(info.retrievers) == 1
- retriever_id = next(iter(info.retrievers.keys()))
- return ModelStructuredResponse(calls=[ToolCall.from_json(retriever_id, '{}')])
+ assert len(info.function_tools) == 1
+ tool_id = next(iter(info.function_tools.keys()))
+ return ModelStructuredResponse(calls=[ToolCall.from_json(tool_id, '{}')])
else:
return ModelTextResponse('final response')
def test_deps_none():
- agent = Agent(FunctionModel(call_retriever))
+ agent = Agent(FunctionModel(call_tool))
- @agent.retriever
+ @agent.tool
async def get_none(ctx: CallContext[None]):
nonlocal called
@@ -259,8 +259,8 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str:
assert ctx.deps == ('foo', 'bar')
return ''
- agent = Agent(FunctionModel(call_retriever), deps_type=tuple[str, str])
- agent.retriever(get_check_foobar)
+ agent = Agent(FunctionModel(call_tool), deps_type=tuple[str, str])
+ agent.tool(get_check_foobar)
called = False
agent.run_sync('Hello', deps=('foo', 'bar'))
assert called
@@ -278,27 +278,27 @@ def test_model_arg():
agent_all = Agent()
-@agent_all.retriever
+@agent_all.tool
async def foo(_: CallContext[None], x: int) -> str:
return str(x + 1)
-@agent_all.retriever(retries=3)
+@agent_all.tool(retries=3)
def bar(ctx, x: int) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType]
return str(x + 2)
-@agent_all.retriever_plain
+@agent_all.tool_plain
async def baz(x: int) -> str:
return str(x + 3)
-@agent_all.retriever_plain(retries=1)
+@agent_all.tool_plain(retries=1)
def qux(x: int) -> str:
return str(x + 4)
-@agent_all.retriever_plain # pyright: ignore[reportUnknownArgumentType]
+@agent_all.tool_plain # pyright: ignore[reportUnknownArgumentType]
def quz(x) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType]
return str(x) # pyright: ignore[reportUnknownArgumentType]
@@ -311,11 +311,11 @@ def spam() -> str:
def test_register_all():
async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
return ModelTextResponse(
- f'messages={len(messages)} allow_text_result={info.allow_text_result} retrievers={len(info.retrievers)}'
+ f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
)
result = agent_all.run_sync('Hello', model=FunctionModel(f))
- assert result.data == snapshot('messages=2 allow_text_result=True retrievers=5')
+ assert result.data == snapshot('messages=2 allow_text_result=True tools=5')
def test_call_all():
diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py
index 331fb3ee38..fac96568bd 100644
--- a/tests/models/test_model_test.py
+++ b/tests/models/test_model_test.py
@@ -28,17 +28,17 @@ def test_call_one():
agent = Agent()
calls: list[str] = []
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_a(x: str) -> str:
calls.append('a')
return f'{x}-a'
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_b(x: str) -> str: # pragma: no cover
calls.append('b')
return f'{x}-b'
- result = agent.run_sync('x', model=TestModel(call_retrievers=['ret_a']))
+ result = agent.run_sync('x', model=TestModel(call_tools=['ret_a']))
assert result.data == snapshot('{"ret_a":"a-a"}')
assert calls == ['a']
@@ -74,11 +74,11 @@ def test_result_type():
assert result.data == ('a', 'a')
-def test_retriever_retry():
+def test_tool_retry():
agent = Agent()
call_count = 0
- @agent.retriever_plain
+ @agent.tool_plain
async def my_ret(x: int) -> str:
nonlocal call_count
call_count += 1
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 52b90a29e9..671180f2ec 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -243,7 +243,7 @@ async def test_request_tool_call(allow_model_requests: None):
m = OpenAIModel('gpt-4', openai_client=mock_client)
agent = Agent(m, system_prompt='this is the system prompt')
- @agent.retriever_plain
+ @agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 2687148ac4..7b74a833c1 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -179,7 +179,7 @@ def test_response_tuple():
result = agent.run_sync('Hello')
assert result.data == snapshot(('a', 'a'))
- assert m.agent_model_retrievers == snapshot({})
+ assert m.agent_model_tools == snapshot({})
assert m.agent_model_allow_text_result is False
assert m.agent_model_result_tools is not None
@@ -235,10 +235,10 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]
result = agent.run_sync('Hello')
- assert result.data == snapshot('success (no retriever calls)')
+ assert result.data == snapshot('success (no tool calls)')
assert got_tool_call_name == snapshot(None)
- assert m.agent_model_retrievers == snapshot({})
+ assert m.agent_model_tools == snapshot({})
assert m.agent_model_allow_text_result is True
assert m.agent_model_result_tools is not None
@@ -312,7 +312,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
assert result.data == mod.Foo(a=0, b='a')
assert got_tool_call_name == snapshot('final_result_Foo')
- assert m.agent_model_retrievers == snapshot({})
+ assert m.agent_model_tools == snapshot({})
assert m.agent_model_allow_text_result is False
assert m.agent_model_result_tools is not None
@@ -361,7 +361,7 @@ def test_run_with_history_new():
agent = Agent(m, system_prompt='Foobar')
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_a(x: str) -> str:
return f'{x}-apple'
@@ -450,7 +450,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
agent.run_sync('Hello')
-def test_unknown_retriever():
+def test_unknown_tool():
def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')])
@@ -472,7 +472,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
)
-def test_unknown_retriever_fix():
+def test_unknown_tool_fix():
def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
if len(m) > 1:
return ModelTextResponse(content='success')
diff --git a/tests/test_deps.py b/tests/test_deps.py
index 1c1ff29281..c52f3e03e2 100644
--- a/tests/test_deps.py
+++ b/tests/test_deps.py
@@ -13,27 +13,27 @@ class MyDeps:
agent = Agent(TestModel(), deps_type=MyDeps)
-@agent.retriever
-async def example_retriever(ctx: CallContext[MyDeps]) -> str:
+@agent.tool
+async def example_tool(ctx: CallContext[MyDeps]) -> str:
return f'{ctx.deps}'
def test_deps_used():
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
- assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'
+ assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'
def test_deps_override():
with agent.override_deps(MyDeps(foo=3, bar=4)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
- assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'
+ assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'
with agent.override_deps(MyDeps(foo=5, bar=6)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
- assert result.data == '{"example_retriever":"MyDeps(foo=5, bar=6)"}'
+ assert result.data == '{"example_tool":"MyDeps(foo=5, bar=6)"}'
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
- assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'
+ assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
- assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'
+ assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'
diff --git a/tests/test_logfire.py b/tests/test_logfire.py
index 90c7d4587a..0141fca633 100644
--- a/tests/test_logfire.py
+++ b/tests/test_logfire.py
@@ -62,7 +62,7 @@ def get_summary() -> LogfireSummary:
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
agent = Agent(model=TestModel())
- @agent.retriever_plain
+ @agent.tool_plain
async def my_ret(x: int) -> str:
return str(x + 1)
diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py
index 7ff73344d9..37258c4809 100644
--- a/tests/test_retrievers.py
+++ b/tests/test_retrievers.py
@@ -11,48 +11,48 @@
from pydantic_ai.models.test import TestModel
-def test_retriever_no_ctx():
+def test_tool_no_ctx():
agent = Agent(TestModel())
with pytest.raises(UserError) as exc_info:
- @agent.retriever # pyright: ignore[reportArgumentType]
- def invalid_retriever(x: int) -> str: # pragma: no cover
+ @agent.tool # pyright: ignore[reportArgumentType]
+ def invalid_tool(x: int) -> str: # pragma: no cover
return 'Hello'
assert str(exc_info.value) == snapshot(
- 'Error generating schema for test_retriever_no_ctx..invalid_retriever:\n'
- ' First argument must be a CallContext instance when using `.retriever`'
+ 'Error generating schema for test_tool_no_ctx..invalid_tool:\n'
+ ' First argument must be a CallContext instance when using `.tool`'
)
-def test_retriever_plain_with_ctx():
+def test_tool_plain_with_ctx():
agent = Agent(TestModel())
with pytest.raises(UserError) as exc_info:
- @agent.retriever_plain
- async def invalid_retriever(ctx: CallContext[None]) -> str: # pragma: no cover
+ @agent.tool_plain
+ async def invalid_tool(ctx: CallContext[None]) -> str: # pragma: no cover
return 'Hello'
assert str(exc_info.value) == snapshot(
- 'Error generating schema for test_retriever_plain_with_ctx..invalid_retriever:\n'
- ' CallContext instance can only be used with `.retriever`'
+ 'Error generating schema for test_tool_plain_with_ctx..invalid_tool:\n'
+ ' CallContext instance can only be used with `.tool`'
)
-def test_retriever_ctx_second():
+def test_tool_ctx_second():
agent = Agent(TestModel())
with pytest.raises(UserError) as exc_info:
- @agent.retriever # pyright: ignore[reportArgumentType]
- def invalid_retriever(x: int, ctx: CallContext[None]) -> str: # pragma: no cover
+ @agent.tool # pyright: ignore[reportArgumentType]
+ def invalid_tool(x: int, ctx: CallContext[None]) -> str: # pragma: no cover
return 'Hello'
assert str(exc_info.value) == snapshot(
- 'Error generating schema for test_retriever_ctx_second..invalid_retriever:\n'
- ' First argument must be a CallContext instance when using `.retriever`\n'
+ 'Error generating schema for test_tool_ctx_second..invalid_tool:\n'
+ ' First argument must be a CallContext instance when using `.tool`\n'
' CallContext instance can only be used as the first argument'
)
@@ -68,14 +68,14 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover
async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
- assert len(info.retrievers) == 1
- r = next(iter(info.retrievers.values()))
+ assert len(info.function_tools) == 1
+ r = next(iter(info.function_tools.values()))
return ModelTextResponse(json.dumps(r.json_schema))
def test_docstring_google():
agent = Agent(FunctionModel(get_json_schema))
- agent.retriever_plain(google_style_docstring)
+ agent.tool_plain(google_style_docstring)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
@@ -106,7 +106,7 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
def test_docstring_sphinx():
agent = Agent(FunctionModel(get_json_schema))
- agent.retriever_plain(sphinx_style_docstring)
+ agent.tool_plain(sphinx_style_docstring)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
@@ -138,7 +138,7 @@ def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover
def test_docstring_numpy():
agent = Agent(FunctionModel(get_json_schema))
- agent.retriever_plain(numpy_style_docstring)
+ agent.tool_plain(numpy_style_docstring)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
@@ -163,7 +163,7 @@ def unknown_docstring(**kwargs: int) -> str: # pragma: no cover
def test_docstring_unknown():
agent = Agent(FunctionModel(get_json_schema))
- agent.retriever_plain(unknown_docstring)
+ agent.tool_plain(unknown_docstring)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
@@ -192,7 +192,7 @@ async def google_style_docstring_no_body(
def test_docstring_google_no_body():
agent = Agent(FunctionModel(get_json_schema))
- agent.retriever_plain(google_style_docstring_no_body)
+ agent.tool_plain(google_style_docstring_no_body)
result = agent.run_sync('')
json_schema = json.loads(result.data)
@@ -216,7 +216,7 @@ class Foo(BaseModel):
x: int
y: str
- @agent.retriever_plain
+ @agent.tool_plain
def takes_just_model(model: Foo) -> str:
return f'{model.x} {model.y}'
@@ -242,7 +242,7 @@ class Foo(BaseModel):
x: int
y: str
- @agent.retriever_plain
+ @agent.tool_plain
def takes_just_model(model: Foo, z: int) -> str:
return f'{model.x} {model.y} {z}'
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index a0b9d94381..4b272a232d 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -33,7 +33,7 @@ async def test_streamed_text_response():
agent = Agent(m)
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_a(x: str) -> str:
return f'{x}-apple'
@@ -162,14 +162,14 @@ async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[s
assert call_index == 2
-async def test_call_retriever():
+async def test_call_tool():
async def stream_structured_function(
messages: list[Message], agent_info: AgentInfo
) -> AsyncIterator[DeltaToolCalls | str]:
if len(messages) == 1:
- assert agent_info.retrievers is not None
- assert len(agent_info.retrievers) == 1
- name = next(iter(agent_info.retrievers))
+ assert agent_info.function_tools is not None
+ assert len(agent_info.function_tools) == 1
+ name = next(iter(agent_info.function_tools))
first = messages[0]
assert isinstance(first, UserPrompt)
json_string = json.dumps({'x': first.content})
@@ -189,7 +189,7 @@ async def stream_structured_function(
agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int])
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_a(x: str) -> str:
assert x == 'hello'
return f'{x} world'
@@ -227,7 +227,7 @@ async def ret_a(x: str) -> str:
)
-async def test_call_retriever_empty():
+async def test_call_tool_empty():
async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
yield {}
@@ -238,13 +238,13 @@ async def stream_structured_function(_messages: list[Message], _: AgentInfo) ->
pass
-async def test_call_retriever_wrong_name():
+async def test_call_tool_wrong_name():
async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
yield {0: DeltaToolCall(name='foobar', json_args='{}')}
agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int])
- @agent.retriever_plain
+ @agent.tool_plain
async def ret_a(x: str) -> str:
return x
diff --git a/tests/typed_agent.py b/tests/typed_agent.py
index f2d7e2d6af..9295c5742b 100644
--- a/tests/typed_agent.py
+++ b/tests/typed_agent.py
@@ -44,47 +44,47 @@ def expect_error(error_type: type[Exception]) -> Iterator[None]:
raise AssertionError('Expected an error')
-@typed_agent.retriever
-async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str:
+@typed_agent.tool
+async def ok_tool(ctx: CallContext[MyDeps], x: str) -> str:
assert_type(ctx.deps, MyDeps)
total = ctx.deps.foo + ctx.deps.bar
return f'{x} {total}'
-# we can't add overloads for every possible signature of retriever, so the type of ok_retriever is obscured
-assert_type(ok_retriever, Callable[[CallContext[MyDeps], str], str]) # type: ignore[assert-type]
+# we can't add overloads for every possible signature of tool, so the type of ok_tool is obscured
+assert_type(ok_tool, Callable[[CallContext[MyDeps], str], str]) # type: ignore[assert-type]
-@typed_agent.retriever_plain
-def ok_retriever_plain(x: str) -> dict[str, str]:
+@typed_agent.tool_plain
+def ok_tool_plain(x: str) -> dict[str, str]:
return {'x': x}
-@typed_agent.retriever_plain
+@typed_agent.tool_plain
def ok_json_list(x: str) -> list[Union[str, int]]:
return [x, 1]
-@typed_agent.retriever
-async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str:
+@typed_agent.tool
+async def bad_tool1(ctx: CallContext[MyDeps], x: str) -> str:
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
return f'{x} {total}'
-@typed_agent.retriever # type: ignore[arg-type]
-async def bad_retriever2(ctx: CallContext[int], x: str) -> str:
+@typed_agent.tool # type: ignore[arg-type]
+async def bad_tool2(ctx: CallContext[int], x: str) -> str:
return f'{x} {ctx.deps}'
-@typed_agent.retriever_plain # type: ignore[arg-type]
-async def bad_retriever_return(x: int) -> list[MyDeps]:
+@typed_agent.tool_plain # type: ignore[arg-type]
+async def bad_tool_return(x: int) -> list[MyDeps]:
return [MyDeps(1, x)]
with expect_error(ValueError):
- @typed_agent.retriever # type: ignore[arg-type]
- async def bad_retriever3(x: str) -> str:
+ @typed_agent.tool # type: ignore[arg-type]
+ async def bad_tool3(x: str) -> str:
return x