diff --git a/docs/agents.md b/docs/agents.md
index 68d1739770..5c9c0d3a11 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -114,14 +114,16 @@ print(result2.data)
```
1. Continue the conversation, without `message_history` the model would not know who "he" was referring to.
+_(This example is complete, it can be run "as is")_
+
## System Prompts
System prompts might seem simple at first glance since they're just strings (or sequences of strings that are concatenated), but crafting the right system prompt is key to getting the model to behave as you want.
Generally, system prompts fall into two categories:
-1. **Static system prompts**: These are known when writing the code and can be defined via the `system_prompt` parameter of the `Agent` constructor.
-2. **Dynamic system prompts**: These aren't known until runtime and should be defined via functions decorated with `@agent.system_prompt`.
+1. **Static system prompts**: These are known when writing the code and can be defined via the `system_prompt` parameter of the [`Agent` constructor][pydantic_ai.Agent.__init__].
+2. **Dynamic system prompts**: These aren't known until runtime and should be defined via functions decorated with [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt].
You can add both to a single agent; they're concatenated in the order they're defined at runtime.
@@ -159,28 +161,323 @@ print(result.data)
3. Dynamic system prompt defined via a decorator.
4. Another dynamic system prompt, system prompts don't have to have the `CallContext` parameter.
+_(This example is complete, it can be run "as is")_
+
## Retrievers
-* two different retriever decorators (`retriver_plain` and `retriever_context`) depending on whether you want to use the context or not, show an example using both
-* retriever parameters are extracted and used to build the schema for the tool, then validated with pydantic
-* if a retriever has a single "model like" parameter (e.g. pydantic mode, dataclass, typed dict), the schema for the tool will but just that type
-* docstrings are parsed to get the tool description, thanks to griffe docs for each parameter are extracting using Google, numpy or sphinx docstring styling
-* You can raise `ModelRetry` from within a retriever to suggest to the model it should retry
-* the return type of retriever can either be `str` or a JSON object typed as `dict[str, Any]` as some models (e.g. Gemini) support structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data
+Retrievers provide a mechanism for models to request extra information to help them generate a response.
+
+They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic by deferring some of the logic required to generate a response to another tool.
+
+!!! info "Retrievers vs. RAG"
+ Retrievers are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
+
+    The main semantic difference between PydanticAI Retrievers and RAG is that RAG is synonymous with vector search, while PydanticAI retrievers are more general-purpose. (Note: we might add support for some vector search functionality in the future, particularly an API for generating embeddings; see [#58](https://github.com/pydantic/pydantic-ai/issues/58))
+
+There are two different decorator functions to register retrievers:
+
+1. [`@agent.retriever_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext]
+2. [`@agent.retriever_context`][pydantic_ai.Agent.retriever_context] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext]
+
+Here's an example using both:
+
+```py title="dice_game.py"
+import random
+
+from pydantic_ai import Agent, CallContext
+
+agent = Agent(
+ 'gemini-1.5-flash', # (1)!
+ deps_type=str, # (2)!
+ system_prompt=(
+ "You're a dice game, you should roll the dice and see if the number "
+ "you got back matches the user's guess, if so tell them they're a winner. "
+ "Use the player's name in the response."
+ ),
+)
+
+
+@agent.retriever_plain # (3)!
+def roll_dice() -> str:
+ """Roll a six-sided dice and return the result."""
+ return str(random.randint(1, 6))
+
+
+@agent.retriever_context # (4)!
+def get_player_name(ctx: CallContext[str]) -> str:
+ """Get the player's name."""
+ return ctx.deps
+
+
+dice_result = agent.run_sync('My guess is 4', deps='Adam') # (5)!
+print(dice_result.data)
+#> Congratulations Adam, you guessed correctly! You're a winner!
+```
+
+1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
+2. We pass the user's name as the dependency; to keep things simple, the dependency is just the player's name as a string.
+3. This retriever doesn't need any context; it just returns a random number. You could probably use a dynamic system prompt in this case.
+4. This retriever needs the player's name, so it uses `CallContext` to access the dependency, which here is just the player's name.
+5. Run the agent, passing the player's name as the dependency.
+
+_(This example is complete, it can be run "as is")_
+
+Let's print the messages from that game to see what happened:
+
+```python title="dice_game_messages.py"
+from dice_game import dice_result
+
+print(dice_result.all_messages())
+"""
+[
+ SystemPrompt(
+ content="You're a dice game, you should roll the dice and see if the number you got back matches the user's guess, if so tell them they're a winner. Use the player's name in the response.",
+ role='system',
+ ),
+ UserPrompt(
+ content='My guess is 4',
+ timestamp=datetime.datetime(...),
+ role='user',
+ ),
+ ModelStructuredResponse(
+ calls=[
+ ToolCall(
+ tool_name='roll_dice', args=ArgsObject(args_object={}), tool_id=None
+ )
+ ],
+ timestamp=datetime.datetime(...),
+ role='model-structured-response',
+ ),
+ ToolReturn(
+ tool_name='roll_dice',
+ content='4',
+ tool_id=None,
+ timestamp=datetime.datetime(...),
+ role='tool-return',
+ ),
+ ModelStructuredResponse(
+ calls=[
+ ToolCall(
+ tool_name='get_player_name',
+ args=ArgsObject(args_object={}),
+ tool_id=None,
+ )
+ ],
+ timestamp=datetime.datetime(...),
+ role='model-structured-response',
+ ),
+ ToolReturn(
+ tool_name='get_player_name',
+ content='Adam',
+ tool_id=None,
+ timestamp=datetime.datetime(...),
+ role='tool-return',
+ ),
+ ModelTextResponse(
+ content="Congratulations Adam, you guessed correctly! You're a winner!",
+ timestamp=datetime.datetime(...),
+ role='model-text-response',
+ ),
+]
+"""
+```
+
+We can represent that as a flow diagram, thus:
+
+![dice game flow diagram](img/dice-diagram-light.svg#only-light)
+![dice game flow diagram](img/dice-diagram-dark.svg#only-dark)
+
+### Retrievers, tools, and schema
+
+Under the hood, retrievers use the model's "tools" or "functions" API to let the model know what retrievers are available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call retrievers while others end the run and return a result.
+
+Function parameters are extracted from the function signature, and all parameters except `CallContext` are used to build the schema for that tool call.
+
+Even better, PydanticAI extracts the docstring from retriever functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema.
+
+[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy` and `sphinx` style docstrings; PydanticAI will infer the format to use based on the docstring. In future we'll add support for explicitly setting the style to use, and for warning or erroring if not all parameters are documented; see [#59](https://github.com/pydantic/pydantic-ai/issues/59).
+
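+As a quick sketch (using the same hypothetical `foobar` function as the example below), the equivalent sphinx-style docstring would be inferred and parsed just the same:
+
+```py
+from pydantic_ai import Agent
+
+agent = Agent()
+
+
+@agent.retriever_plain
+def foobar(a: int, b: str) -> str:
+    """Get me foobar.
+
+    :param a: apple pie
+    :param b: banana cake
+    """
+    return f'{a} {b}'
+```
+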
+To demonstrate retriever schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive:
+
+```py title="retriever_schema.py"
+from pydantic_ai import Agent
+from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+
+agent = Agent()
+
+
+@agent.retriever_plain
+def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
+ """Get me foobar.
+
+ Args:
+ a: apple pie
+ b: banana cake
+ c: carrot smoothie
+ """
+ return f'{a} {b} {c}'
+
+
+def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
+ retriever = info.retrievers['foobar']
+ print(retriever.description)
+ #> Get me foobar.
+ print(retriever.json_schema)
+ """
+ {
+ 'description': 'Get me foobar.',
+ 'properties': {
+ 'a': {'description': 'apple pie', 'title': 'A', 'type': 'integer'},
+ 'b': {'description': 'banana cake', 'title': 'B', 'type': 'string'},
+ 'c': {
+ 'additionalProperties': {'items': {'type': 'number'}, 'type': 'array'},
+ 'description': 'carrot smoothie',
+ 'title': 'C',
+ 'type': 'object',
+ },
+ },
+ 'required': ['a', 'b', 'c'],
+ 'type': 'object',
+ 'additionalProperties': False,
+ }
+ """
+ return ModelTextResponse(content='foobar')
+
+
+agent.run_sync('hello', model=FunctionModel(print_schema))
+```
+
+_(This example is complete, it can be run "as is")_
+
+The return type of a retriever can be any valid JSON data ([`JsonData`][pydantic_ai.dependencies.JsonData]): some models (e.g. Gemini) support semi-structured return values, while others (e.g. OpenAI) expect text, but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
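+
+For example, a retriever returning structured data rather than a string might look like this (a minimal sketch; the `get_capitals` retriever and its data are hypothetical):
+
+```py
+from pydantic_ai import Agent
+
+agent = Agent('gemini-1.5-flash')
+
+
+@agent.retriever_plain
+def get_capitals(countries: list[str]) -> dict[str, str]:
+    """Get the capital city of each country."""
+    # any JSON-serializable structure is a valid return value; models that
+    # expect text will receive this serialized to JSON
+    capitals = {'France': 'Paris', 'Italy': 'Rome'}  # hypothetical example data
+    return {c: capitals.get(c, 'unknown') for c in countries}
+```
+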
## Reflection and self-correction
-* validation errors from both retrievers parameter validation and structured result validation can be passed back to the with a request to retry
-* as described above, you can also raise `ModelRetry` from within a retriever or result validator to tell the model it should retry
-* the default retry count is 1, but can be altered both on a whole agent, or on a per-retriever basis and result validator basis
-* you can access the current retry count from within a retriever or result validator via `ctx.retry`
+Validation errors from both retriever parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry.
+
+You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validator function](results.md#result-validators-functions) to tell the model it should retry.
+
+- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever_context], or a [result validator][pydantic_ai.Agent.__init__].
+- You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext].
+
+Here's an example:
+
+```py title="retriever_retry.py"
+from fake_database import DatabaseConn
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, CallContext, ModelRetry
+
+
+class ChatResult(BaseModel):
+ user_id: int
+ message: str
+
+
+agent = Agent(
+ 'openai:gpt-4o',
+ deps_type=DatabaseConn,
+ result_type=ChatResult,
+)
+
+
+@agent.retriever_context(retries=2)
+def get_user_by_name(ctx: CallContext[DatabaseConn], name: str) -> int:
+ """Get a user's ID from their full name."""
+ print(name)
+ #> John
+ #> John Doe
+ user_id = ctx.deps.users.get(name=name)
+ if user_id is None:
+ raise ModelRetry(
+ f'No user found with name {name!r}, remember to provide their full name'
+ )
+ return user_id
+
+
+result = agent.run_sync(
+ 'Send a message to John Doe asking for coffee next week', deps=DatabaseConn()
+)
+print(result.data)
+"""
+user_id=123 message='Hello John, would you be free for coffee sometime next week? Let me know what works for you!'
+"""
+```
## Model errors
-* If models behave unexpectedly, e.g. the retry limit is exceed, agent runs will raise `UnexpectedModelBehaviour` exceptions
-* If you use PydanticAI in correctly, we try to raise a `UserError` with a helpful message
-* show an except of a `UnexpectedModelBehaviour` being raised
-* if a `UnexpectedModelBehaviour` is raised, you may want to access the [`.last_run_messages`][pydantic_ai.Agent.last_run_messages] attribute of an agent to see the messages exchanged that led to the error, show an example of accessing `.last_run_messages` in an except block to get more details
+If models behave unexpectedly (e.g., the retry limit is exceeded, or their API returns `503`), agent runs will raise [`UnexpectedModelBehaviour`][pydantic_ai.exceptions.UnexpectedModelBehaviour].
+
+In these cases, [`agent.last_run_messages`][pydantic_ai.Agent.last_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue.
+
+```python
+from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour
+
+agent = Agent('openai:gpt-4o')
+
+
+@agent.retriever_plain
+def calc_volume(size: int) -> int: # (1)!
+ if size == 42:
+ return size**3
+ else:
+ raise ModelRetry('Please try again.')
+
+
+try:
+ result = agent.run_sync('Please get me the volume of a box with size 6.')
+except UnexpectedModelBehaviour as e:
+ print('An error occurred:', e)
+ #> An error occurred: Retriever exceeded max retries count of 1
+ print('cause:', repr(e.__cause__))
+ #> cause: ModelRetry('Please try again.')
+ print('messages:', agent.last_run_messages)
+ """
+ messages:
+ [
+ UserPrompt(
+ content='Please get me the volume of a box with size 6.',
+ timestamp=datetime.datetime(...),
+ role='user',
+ ),
+ ModelStructuredResponse(
+ calls=[
+ ToolCall(
+ tool_name='calc_volume',
+ args=ArgsObject(args_object={'size': 6}),
+ tool_id=None,
+ )
+ ],
+ timestamp=datetime.datetime(...),
+ role='model-structured-response',
+ ),
+ RetryPrompt(
+ content='Please try again.',
+ tool_name='calc_volume',
+ tool_id=None,
+ timestamp=datetime.datetime(...),
+ role='retry-prompt',
+ ),
+ ModelStructuredResponse(
+ calls=[
+ ToolCall(
+ tool_name='calc_volume',
+ args=ArgsObject(args_object={'size': 6}),
+ tool_id=None,
+ )
+ ],
+ timestamp=datetime.datetime(...),
+ role='model-structured-response',
+ ),
+ ]
+ """
+else:
+ print(result.data)
+```
+1. Define a retriever that will raise `ModelRetry` repeatedly in this case.
+
+_(This example is complete, it can be run "as is")_
## API Reference
diff --git a/docs/dependencies.md b/docs/dependencies.md
index 79a9660d1f..d97c66cd38 100644
--- a/docs/dependencies.md
+++ b/docs/dependencies.md
@@ -1,6 +1,6 @@
# Dependencies
-PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators).
+PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions).
Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production.
@@ -159,7 +159,7 @@ _(This example is complete, it can be run "as is")_
## Full Example
-As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators).
+As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions).
```python title="full_example.py" hl_lines="27-35 38-48"
from dataclasses import dataclass
diff --git a/docs/img/dice-diagram-dark.svg b/docs/img/dice-diagram-dark.svg
new file mode 100644
index 0000000000..b0b77e93f5
--- /dev/null
+++ b/docs/img/dice-diagram-dark.svg
@@ -0,0 +1,11 @@
+
diff --git a/docs/img/dice-diagram-light.svg b/docs/img/dice-diagram-light.svg
new file mode 100644
index 0000000000..64856c9340
--- /dev/null
+++ b/docs/img/dice-diagram-light.svg
@@ -0,0 +1,11 @@
+
diff --git a/docs/message-history.md b/docs/message-history.md
index cb2c8a8ff2..528b17e6d3 100644
--- a/docs/message-history.md
+++ b/docs/message-history.md
@@ -41,16 +41,12 @@ print(result.all_messages())
SystemPrompt(content='Be a helpful assistant.', role='system'),
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='Did you hear about the toothpaste scandal? They called it Colgate.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
]
@@ -62,16 +58,12 @@ print(result.new_messages())
[
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='Did you hear about the toothpaste scandal? They called it Colgate.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
]
@@ -96,9 +88,7 @@ async def main():
SystemPrompt(content='Be a helpful assistant.', role='system'),
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
]
@@ -120,16 +110,12 @@ async def main():
SystemPrompt(content='Be a helpful assistant.', role='system'),
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='Did you hear about the toothpaste scandal? They called it Colgate.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
]
@@ -175,30 +161,22 @@ print(result2.all_messages())
SystemPrompt(content='Be a helpful assistant.', role='system'),
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='Did you hear about the toothpaste scandal? They called it Colgate.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
UserPrompt(
content='Explain?',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
]
@@ -233,30 +211,22 @@ print(result2.all_messages())
SystemPrompt(content='Be a helpful assistant.', role='system'),
UserPrompt(
content='Tell me a joke.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='Did you hear about the toothpaste scandal? They called it Colgate.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
UserPrompt(
content='Explain?',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='user',
),
ModelTextResponse(
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
- timestamp=datetime.datetime(
- 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
- ),
+ timestamp=datetime.datetime(...),
role='model-text-response',
),
]
diff --git a/docs/results.md b/docs/results.md
index 94ef57df20..ca7f1ce1fe 100644
--- a/docs/results.md
+++ b/docs/results.md
@@ -2,7 +2,9 @@
TODO
-## Result Validators
+## Structured result validation
+
+## Result validators functions
TODO
diff --git a/mkdocs.yml b/mkdocs.yml
index b7ebc585bf..5bce3b0b07 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -120,6 +120,10 @@ markdown_extensions:
- pymdownx.tasklist:
custom_checkbox: true
- sane_lists # this means you can start a list from any number
+ - material.extensions.preview:
+ targets:
+ include:
+ - '*'
watch:
- pydantic_ai
diff --git a/pydantic_ai/_griffe.py b/pydantic_ai/_griffe.py
new file mode 100644
index 0000000000..e69dbf7003
--- /dev/null
+++ b/pydantic_ai/_griffe.py
@@ -0,0 +1,128 @@
+from __future__ import annotations as _annotations
+
+import re
+from inspect import Signature
+from typing import Any, Callable, Literal, cast
+
+from _griffe.enumerations import DocstringSectionKind
+from _griffe.models import Docstring, Object as GriffeObject
+
+DocstringStyle = Literal['google', 'numpy', 'sphinx']
+
+
+def doc_descriptions(
+ func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
+) -> tuple[str, dict[str, str]]:
+ """Extract the function description and parameter descriptions from a function's docstring.
+
+ Returns:
+ A tuple of (main function description, parameter descriptions).
+ """
+ doc = func.__doc__
+ if doc is None:
+ return '', {}
+
+ # see https://github.com/mkdocstrings/griffe/issues/293
+ parent = cast(GriffeObject, sig)
+
+ docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
+ sections = docstring.parse()
+
+ params = {}
+ if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None):
+ params = {p.name: p.description for p in parameters.value}
+
+ main_desc = ''
+ if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
+ main_desc = main.value
+
+ return main_desc, params
+
+
+def _infer_docstring_style(doc: str) -> DocstringStyle:
+ """Simplistic docstring style inference."""
+ for pattern, replacements, style in _docstring_style_patterns:
+ matches = (
+ re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements
+ )
+ if any(matches):
+ return style
+ # fallback to google style
+ return 'google'
+
+
+# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804
+_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
+ (
+ r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n',
+ [
+ 'param',
+ 'parameter',
+ 'arg',
+ 'argument',
+ 'key',
+ 'keyword',
+ 'type',
+ 'var',
+ 'ivar',
+ 'cvar',
+ 'vartype',
+ 'returns',
+ 'return',
+ 'rtype',
+ 'raises',
+ 'raise',
+ 'except',
+ 'exception',
+ ],
+ 'sphinx',
+ ),
+ (
+ r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+',
+ [
+ 'args',
+ 'arguments',
+ 'params',
+ 'parameters',
+ 'keyword args',
+ 'keyword arguments',
+ 'other args',
+ 'other arguments',
+ 'other params',
+ 'other parameters',
+ 'raises',
+ 'exceptions',
+ 'returns',
+ 'yields',
+ 'receives',
+ 'examples',
+ 'attributes',
+ 'functions',
+ 'methods',
+ 'classes',
+ 'modules',
+ 'warns',
+ 'warnings',
+ ],
+ 'google',
+ ),
+ (
+ r'\n[ \t]*{0}\n[ \t]*---+\n',
+ [
+ 'deprecated',
+ 'parameters',
+ 'other parameters',
+ 'returns',
+ 'yields',
+ 'receives',
+ 'raises',
+ 'warns',
+ 'attributes',
+ 'functions',
+ 'methods',
+ 'classes',
+ 'modules',
+ ],
+ 'numpy',
+ ),
+]
diff --git a/pydantic_ai/_pydantic.py b/pydantic_ai/_pydantic.py
index 62cf87f487..f6d2e180f5 100644
--- a/pydantic_ai/_pydantic.py
+++ b/pydantic_ai/_pydantic.py
@@ -5,12 +5,9 @@
from __future__ import annotations as _annotations
-import re
-from inspect import Parameter, Signature, signature
-from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, cast, get_origin
+from inspect import Parameter, signature
+from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
-from _griffe.enumerations import DocstringSectionKind
-from _griffe.models import Docstring, Object as GriffeObject
from pydantic import ConfigDict, TypeAdapter
from pydantic._internal import _decorators, _generate_schema, _typing_extra
from pydantic._internal._config import ConfigWrapper
@@ -19,6 +16,7 @@
from pydantic.plugin._schema_validator import create_schema_validator
from pydantic_core import SchemaValidator, core_schema
+from ._griffe import doc_descriptions
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
if TYPE_CHECKING:
@@ -66,7 +64,7 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, R
var_positional_field: str | None = None
errors: list[str] = []
decorators = _decorators.DecoratorInfos()
- description, field_descriptions = _doc_descriptions(function, sig)
+ description, field_descriptions = doc_descriptions(function, sig)
for index, (name, p) in enumerate(sig.parameters.items()):
if p.annotation is sig.empty:
@@ -192,127 +190,6 @@ def _build_schema(
return td_schema, None
-DocstringStyle = Literal['google', 'numpy', 'sphinx']
-
-
-def _doc_descriptions(
- func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
-) -> tuple[str, dict[str, str]]:
- """Extract the function description and parameter descriptions from a function's docstring.
-
- Returns:
- A tuple of (main function description, parameter descriptions).
- """
- doc = func.__doc__
- if doc is None:
- return '', {}
-
- # see https://github.com/mkdocstrings/griffe/issues/293
- parent = cast(GriffeObject, sig)
-
- docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
- sections = docstring.parse()
-
- params = {}
- if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None):
- params = {p.name: p.description for p in parameters.value}
-
- main_desc = ''
- if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
- main_desc = main.value
-
- return main_desc, params
-
-
-def _infer_docstring_style(doc: str) -> DocstringStyle:
- """Simplistic docstring style inference."""
- for pattern, replacements, style in _docstring_style_patterns:
- matches = (
- re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements
- )
- if any(matches):
- return style
- # fallback to google style
- return 'google'
-
-
-# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804
-_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
- (
- r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n',
- [
- 'param',
- 'parameter',
- 'arg',
- 'argument',
- 'key',
- 'keyword',
- 'type',
- 'var',
- 'ivar',
- 'cvar',
- 'vartype',
- 'returns',
- 'return',
- 'rtype',
- 'raises',
- 'raise',
- 'except',
- 'exception',
- ],
- 'sphinx',
- ),
- (
- r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+',
- [
- 'args',
- 'arguments',
- 'params',
- 'parameters',
- 'keyword args',
- 'keyword arguments',
- 'other args',
- 'other arguments',
- 'other params',
- 'other parameters',
- 'raises',
- 'exceptions',
- 'returns',
- 'yields',
- 'receives',
- 'examples',
- 'attributes',
- 'functions',
- 'methods',
- 'classes',
- 'modules',
- 'warns',
- 'warnings',
- ],
- 'google',
- ),
- (
- r'\n[ \t]*{0}\n[ \t]*---+\n',
- [
- 'deprecated',
- 'parameters',
- 'other parameters',
- 'returns',
- 'yields',
- 'receives',
- 'raises',
- 'warns',
- 'attributes',
- 'functions',
- 'methods',
- 'classes',
- 'modules',
- ],
- 'numpy',
- ),
-]
-
-
def _is_call_ctx(annotation: Any) -> bool:
from .dependencies import CallContext
diff --git a/pydantic_ai/_retriever.py b/pydantic_ai/_retriever.py
index 7de34a5fec..543f2add38 100644
--- a/pydantic_ai/_retriever.py
+++ b/pydantic_ai/_retriever.py
@@ -5,13 +5,12 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, cast
-import pydantic_core
from pydantic import ValidationError
from pydantic_core import SchemaValidator
from . import _pydantic, _utils, messages
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
-from .exceptions import ModelRetry
+from .exceptions import ModelRetry, UnexpectedModelBehaviour
# Usage `RetrieverEitherFunc[AgentDependencies, P]`
RetrieverEitherFunc = _utils.Either[
@@ -64,7 +63,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
else:
args_dict = self.validator.validate_python(message.args.args_object)
except ValidationError as e:
- return self._on_error(e.errors(include_url=False), message)
+ return self._on_error(e, message)
args, kwargs = self._call_args(deps, args_dict, message)
try:
@@ -75,7 +74,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
function = cast(Callable[[Any], str], self.function.whichever())
response_content = await _utils.run_in_executor(function, *args, **kwargs)
except ModelRetry as e:
- return self._on_error(e.message, message)
+ return self._on_error(e, message)
self._current_retry = 0
return messages.ToolReturn(
@@ -98,14 +97,16 @@ def _call_args(
return args, args_dict
- def _on_error(
- self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.ToolCall
- ) -> messages.RetryPrompt:
+ def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
self._current_retry += 1
if self._current_retry > self.max_retries:
# TODO custom error with details of the retriever
- raise
+ raise UnexpectedModelBehaviour(f'Retriever exceeded max retries count of {self.max_retries}') from exc
else:
+ if isinstance(exc, ValidationError):
+ content = exc.errors(include_url=False)
+ else:
+ content = exc.message
return messages.RetryPrompt(
tool_name=call_message.tool_name,
content=content,
diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index aa81d3d326..bd38da5878 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -326,7 +326,7 @@ def override_model(self, overriding_model: models.Model | models.KnownModelName)
def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
- """Decorator to register a system prompt function that takes `CallContext` as it's only argument."""
+ """Decorator to register a system prompt function that optionally takes `CallContext` as it's only argument."""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func
diff --git a/pydantic_ai/dependencies.py b/pydantic_ai/dependencies.py
index 96c82f8ff6..ee5a087830 100644
--- a/pydantic_ai/dependencies.py
+++ b/pydantic_ai/dependencies.py
@@ -1,10 +1,10 @@
from __future__ import annotations as _annotations
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, TypeAlias
if TYPE_CHECKING:
from .result import ResultData
@@ -20,6 +20,7 @@
'RetrieverContextFunc',
'RetrieverPlainFunc',
'RetrieverParams',
+ 'JsonData',
)
AgentDeps = TypeVar('AgentDeps')
@@ -65,7 +66,10 @@ class CallContext(Generic[AgentDeps]):
Usage `ResultValidator[AgentDeps, ResultData]`.
"""
-RetrieverReturnValue = Union[str, Awaitable[str], dict[str, Any], Awaitable[dict[str, Any]]]
+JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
+"""Type representing any JSON data."""
+
+RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]]
"""Return value of a retriever function."""
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue]
"""A retriever function that takes `CallContext` as the first argument.
diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py
index 4da93aa197..3d6ec3f45d 100644
--- a/pydantic_ai/messages.py
+++ b/pydantic_ai/messages.py
@@ -1,12 +1,15 @@
from __future__ import annotations as _annotations
import json
+from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime
-from typing import Annotated, Any, Literal, Union
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
import pydantic
import pydantic_core
+from pydantic import TypeAdapter
+from typing_extensions import TypeAlias, TypeAliasType
from . import _pydantic
from ._utils import now_utc as _now_utc
@@ -41,7 +44,13 @@ class UserPrompt:
"""Message type identifier, this type is available on all message as a discriminator."""
-tool_return_value_object = _pydantic.LazyTypeAdapter(dict[str, Any])
+JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]'
+if not TYPE_CHECKING:
+    # workaround for https://github.com/pydantic/pydantic/issues/10873
+    # this is needed for pydantic to work with both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
+ JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
+
+json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
@dataclass
@@ -50,7 +59,7 @@ class ToolReturn:
tool_name: str
"""The name of the "tool" was called."""
- content: str | dict[str, Any]
+ content: JsonData
"""The return value."""
tool_id: str | None = None
"""Optional tool identifier, this is used by some models including OpenAI."""
@@ -63,14 +72,15 @@ def model_response_str(self) -> str:
if isinstance(self.content, str):
return self.content
else:
- content = tool_return_value_object.validate_python(self.content)
- return tool_return_value_object.dump_json(content).decode()
+ content = json_ta.validate_python(self.content)
+ return json_ta.dump_json(content).decode()
- def model_response_object(self) -> dict[str, Any]:
- if isinstance(self.content, str):
- return {'return_value': self.content}
+ def model_response_object(self) -> dict[str, JsonData]:
+ # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
+ if isinstance(self.content, dict):
+ return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType]
else:
- return tool_return_value_object.validate_python(self.content)
+ return {'return_value': json_ta.validate_python(self.content)}
@dataclass
diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py
index d944f8ee05..8d123e7332 100644
--- a/pydantic_ai/models/gemini.py
+++ b/pydantic_ai/models/gemini.py
@@ -472,26 +472,30 @@ class _GeminiTextContent(TypedDict):
class _GeminiTools(TypedDict):
- function_declarations: list[_GeminiFunction]
+ function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
class _GeminiFunction(TypedDict):
name: str
description: str
- parameters: dict[str, Any]
+ parameters: NotRequired[dict[str, Any]]
"""
ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
"""
def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
json_schema = _GeminiJsonSchema(tool.json_schema).simplify()
- return _GeminiFunction(
+ f = _GeminiFunction(
name=tool.name,
description=tool.description,
- parameters=json_schema,
)
+ if json_schema.get('properties'):
+ f['parameters'] = json_schema
+ return f
class _GeminiToolConfig(TypedDict):
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 0ba44c2da9..807da50aed 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,9 +1,10 @@
from __future__ import annotations as _annotations
import asyncio
+import re
import sys
from collections.abc import AsyncIterator, Iterable
-from datetime import datetime
+from dataclasses import dataclass, field
from types import ModuleType
from typing import Any
@@ -23,9 +24,30 @@
)
from pydantic_ai.models import KnownModelName, Model
from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel
+from pydantic_ai.models.test import TestModel
from tests.conftest import ClientWithHandler
+@pytest.fixture(scope='module', autouse=True)
+def register_modules():
+ class FakeTable:
+ def get(self, name: str) -> int | None:
+ if name == 'John Doe':
+ return 123
+
+ @dataclass
+ class DatabaseConn:
+ users: FakeTable = field(default_factory=FakeTable)
+
+ module_name = 'fake_database'
+ sys.modules[module_name] = module = ModuleType(module_name)
+ module.__dict__.update({'DatabaseConn': DatabaseConn})
+
+ yield
+
+ sys.modules.pop(module_name)
+
+
def find_filter_examples() -> Iterable[CodeExample]:
for ex in find_examples('docs', 'pydantic_ai'):
if ex.path.name != '_utils.py':
@@ -36,21 +58,23 @@ def find_filter_examples() -> Iterable[CodeExample]:
def test_docs_examples(
example: CodeExample, eval_example: EvalExample, mocker: MockerFixture, client_with_handler: ClientWithHandler
):
- if example.path.name == '_utils.py':
- return
# debug(example)
mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
- mocker.patch('pydantic_ai._utils.datetime', MockedDatetime)
mocker.patch('httpx.Client.get', side_effect=http_request)
mocker.patch('httpx.Client.post', side_effect=http_request)
mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request)
mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request)
+ mocker.patch('random.randint', return_value=4)
+
+ prefix_settings = example.prefix_settings()
ruff_ignore: list[str] = ['D']
if str(example.path).endswith('docs/index.md'):
ruff_ignore.append('F841')
- eval_example.set_config(ruff_ignore=ruff_ignore)
+ eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39')
+
+ eval_example.print_callback = print_callback
call_name = 'main'
if 'def test_application_code' in example.source:
@@ -63,9 +87,15 @@ def test_docs_examples(
eval_example.lint(example)
module_dict = eval_example.run_print_check(example, call=call_name)
- if example.path.name == 'dependencies.md' and 'title="joke_app.py"' in example.prefix:
- sys.modules['joke_app'] = module = ModuleType('joke_app')
- module.__dict__.update(module_dict)
+ if title := prefix_settings.get('title'):
+ if title.endswith('.py'):
+ module_name = title[:-3]
+ sys.modules[module_name] = module = ModuleType(module_name)
+ module.__dict__.update(module_dict)
+
+
+def print_callback(s: str) -> str:
+ return re.sub(r'datetime.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL)
def http_request(url: str, **kwargs: Any) -> httpx.Response:
@@ -78,7 +108,7 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
return http_request(url, **kwargs)
-text_responses = {
+text_responses: dict[str, str | ToolCall] = {
'What is the weather like in West London and in Wiltshire?': 'The weather in West London is raining, while in Wiltshire it is sunny.',
'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
'Explain?': 'This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
@@ -88,22 +118,44 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.',
'What was his most famous equation?': "Albert Einstein's most famous equation is (E = mc^2).",
'What is the date?': 'Hello Frank, the date today is 2032-01-02.',
+ 'Put my money on square eighteen': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 18})),
+ 'I bet five is the winner': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 5})),
+ 'My guess is 4': ToolCall(tool_name='roll_dice', args=ArgsObject({})),
+ 'Send a message to John Doe asking for coffee next week': ToolCall(
+ tool_name='get_user_by_name', args=ArgsObject({'name': 'John'})
+ ),
+ 'Please get me the volume of a box with size 6.': ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6})),
}
async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
m = messages[-1]
if m.role == 'user':
- if text_response := text_responses.get(m.content):
- return ModelTextResponse(content=text_response)
+ if response := text_responses.get(m.content):
+ if isinstance(response, str):
+ return ModelTextResponse(content=response)
+ else:
+ return ModelStructuredResponse(calls=[response])
- if m.role == 'user' and m.content == 'Put my money on square eighteen':
- return ModelStructuredResponse(calls=[ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 18}))])
- elif m.role == 'user' and m.content == 'I bet five is the winner':
- return ModelStructuredResponse(calls=[ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 5}))])
elif m.role == 'tool-return' and m.tool_name == 'roulette_wheel':
win = m.content == 'winner'
return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject({'response': win}))])
+ elif m.role == 'tool-return' and m.tool_name == 'roll_dice':
+ return ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsObject({}))])
+ elif m.role == 'tool-return' and m.tool_name == 'get_player_name':
+ return ModelTextResponse(content="Congratulations Adam, you guessed correctly! You're a winner!")
+ if m.role == 'retry-prompt' and isinstance(m.content, str) and m.content.startswith("No user found with name 'Joh"):
+ return ModelStructuredResponse(
+ calls=[ToolCall(tool_name='get_user_by_name', args=ArgsObject({'name': 'John Doe'}))]
+ )
+ elif m.role == 'tool-return' and m.tool_name == 'get_user_by_name':
+ args = {
+ 'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!',
+ 'user_id': 123,
+ }
+ return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject(args))])
+ elif m.role == 'retry-prompt' and m.tool_name == 'calc_volume':
+ return ModelStructuredResponse(calls=[ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6}))])
else:
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message: {m}')
@@ -112,23 +164,23 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo
async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]:
m = messages[-1]
if m.role == 'user':
- if text_response := text_responses.get(m.content):
- *words, last_word = text_response.split(' ')
- for work in words:
- yield f'{work} '
- await asyncio.sleep(0.05)
- yield last_word
- return
+ if response := text_responses.get(m.content):
+ if isinstance(response, str):
+ *words, last_word = response.split(' ')
+                for word in words:
+                    yield f'{word} '
+ await asyncio.sleep(0.05)
+ yield last_word
+ return
+ else:
+ raise NotImplementedError('todo')
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message: {m}')
-def mock_infer_model(_model: Model | KnownModelName) -> Model:
- return FunctionModel(model_logic, stream_function=stream_model_logic)
-
-
-class MockedDatetime(datetime):
- @classmethod
- def now(cls, tz: Any = None) -> datetime: # type: ignore
- return datetime(2032, 1, 2, 3, 4, 5, 6, tzinfo=tz)
+def mock_infer_model(model: Model | KnownModelName) -> Model:
+ if isinstance(model, (FunctionModel, TestModel)):
+ return model
+ else:
+ return FunctionModel(model_logic, stream_function=stream_model_logic)
diff --git a/tests/typed_agent.py b/tests/typed_agent.py
index 6ade228ec5..d31293aef5 100644
--- a/tests/typed_agent.py
+++ b/tests/typed_agent.py
@@ -41,6 +41,11 @@ def ok_retriever_plain(x: str) -> dict[str, str]:
return {'x': x}
+@typed_agent.retriever_plain
+def ok_json_list(x: str) -> list[Union[str, int]]:
+ return [x, 1]
+
+
@typed_agent.retriever_context
async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str:
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
@@ -53,8 +58,8 @@ async def bad_retriever2(ctx: CallContext[int], x: str) -> str:
@typed_agent.retriever_plain # type: ignore[arg-type]
-async def bad_retriever_return(x: int) -> list[int]:
- return [x]
+async def bad_retriever_return(x: int) -> list[MyDeps]:
+ return [MyDeps(1, x)]
with expect_error(ValueError):