From 33903b459883fb6b3e6f5aa8c9ff05f0eb5c00fa Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 18 Nov 2024 17:43:29 +0000 Subject: [PATCH 1/4] fixing agent docs and refactoring retriever return types --- docs/agents.md | 323 ++++++++++++++++++++++++++++++-- docs/dependencies.md | 4 +- docs/img/dice-diagram-dark.svg | 11 ++ docs/img/dice-diagram-light.svg | 11 ++ docs/message-history.md | 60 ++---- docs/results.md | 4 +- pydantic_ai/_griffe.py | 128 +++++++++++++ pydantic_ai/_pydantic.py | 131 +------------ pydantic_ai/_retriever.py | 17 +- pydantic_ai/agent.py | 2 +- pydantic_ai/dependencies.py | 10 +- pydantic_ai/messages.py | 28 ++- pydantic_ai/models/gemini.py | 12 +- tests/test_examples.py | 112 ++++++++--- 14 files changed, 607 insertions(+), 246 deletions(-) create mode 100644 docs/img/dice-diagram-dark.svg create mode 100644 docs/img/dice-diagram-light.svg create mode 100644 pydantic_ai/_griffe.py diff --git a/docs/agents.md b/docs/agents.md index 68d1739770..7b73aec716 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -120,8 +120,8 @@ System prompts might seem simple at first glance since they're just strings (or Generally, system prompts fall into two categories: -1. **Static system prompts**: These are known when writing the code and can be defined via the `system_prompt` parameter of the `Agent` constructor. -2. **Dynamic system prompts**: These aren't known until runtime and should be defined via functions decorated with `@agent.system_prompt`. +1. **Static system prompts**: These are known when writing the code and can be defined via the `system_prompt` parameter of the [`Agent` constructor][pydantic_ai.Agent.__init__]. +2. **Dynamic system prompts**: These aren't known until runtime and should be defined via functions decorated with [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt]. You can add both to a single agent; they're concatenated in the order they're defined at runtime. @@ -161,26 +161,317 @@ print(result.data) ## Retrievers -* two different retriever decorators (`retriver_plain` and `retriever_context`) depending on whether you want to use the context or not, show an example using both -* retriever parameters are extracted and used to build the schema for the tool, then validated with pydantic -* if a retriever has a single "model like" parameter (e.g. pydantic mode, dataclass, typed dict), the schema for the tool will but just that type -* docstrings are parsed to get the tool description, thanks to griffe docs for each parameter are extracting using Google, numpy or sphinx docstring styling -* You can raise `ModelRetry` from within a retriever to suggest to the model it should retry -* the return type of retriever can either be `str` or a JSON object typed as `dict[str, Any]` as some models (e.g. Gemini) support structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data +Retrievers provide a mechanism for models to request extra information to help them generate a response. + +They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic by deferring some of the logic required to generate a response to another tool. + +!!! info "Retrievers vs. RAG" + Retrievers are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. 
+
+    The main semantic difference between PydanticAI retrievers and RAG is that RAG is synonymous with vector search, while PydanticAI retrievers are more general-purpose. (Note: we might add support for some vector search functionality in the future, particularly an API for generating embeddings, see [#58](https://github.com/pydantic/pydantic-ai/issues/58).)
+
+There are two different decorator functions to register retrievers:
+
+1. [`@agent.retriever_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext]
+2. [`@agent.retriever_context`][pydantic_ai.Agent.retriever_context] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext]
+
+Here's an example using both:
+
+```py title="dice_game.py"
+import random
+
+from pydantic_ai import Agent, CallContext
+
+agent = Agent(
+    'gemini-1.5-flash',  # (1)!
+    deps_type=str,  # (2)!
+    system_prompt=(
+        "You're a dice game, you should roll the dice and see if the number "
+        "you got back matches the user's guess, if so tell them they're a winner. "
+        "Use the player's name in the response."
+    ),
+)
+
+
+@agent.retriever_plain  # (3)!
+def roll_dice() -> str:
+    """Roll a six-sided dice and return the result."""
+    return str(random.randint(1, 6))
+
+
+@agent.retriever_context  # (4)!
+def get_player_name(ctx: CallContext[str]) -> str:
+    """Get the player's name."""
+    return ctx.deps
+
+
+dice_result = agent.run_sync('My guess is 4', deps='Adam')  # (5)!
+print(dice_result.data)
+#> Congratulations Adam, you guessed correctly! You're a winner!
+```
+
+1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
+2. We pass the player's name as the dependency; to keep things simple, the dependency is just a string.
+3. This retriever doesn't need any context; it just returns a random number. You could probably use a dynamic system prompt in this case.
+4. This retriever needs the player's name, so it uses `CallContext` to access the dependencies, which here are just the player's name.
+5. Run the agent, passing the player's name as the dependency.
+
+_(This example is complete, it can be run "as is")_
+
+Let's print the messages from that game to see what happened:
+
+```python title="dice_game_messages.py"
+from dice_game import dice_result
+
+print(dice_result.all_messages())
+"""
+[
+    SystemPrompt(
+        content="You're a dice game, you should roll the dice and see if the number you got back matches the user's guess, if so tell them they're a winner. Use the player's name in the response.",
+        role='system',
+    ),
+    UserPrompt(
+        content='My guess is 4',
+        timestamp=datetime.datetime(...),
+        role='user',
+    ),
+    ModelStructuredResponse(
+        calls=[
+            ToolCall(
+                tool_name='roll_dice', args=ArgsObject(args_object={}), tool_id=None
+            )
+        ],
+        timestamp=datetime.datetime(...),
+        role='model-structured-response',
+    ),
+    ToolReturn(
+        tool_name='roll_dice',
+        content='4',
+        tool_id=None,
+        timestamp=datetime.datetime(...),
+        role='tool-return',
+    ),
+    ModelStructuredResponse(
+        calls=[
+            ToolCall(
+                tool_name='get_player_name',
+                args=ArgsObject(args_object={}),
+                tool_id=None,
+            )
+        ],
+        timestamp=datetime.datetime(...),
+        role='model-structured-response',
+    ),
+    ToolReturn(
+        tool_name='get_player_name',
+        content='Adam',
+        tool_id=None,
+        timestamp=datetime.datetime(...),
+        role='tool-return',
+    ),
+    ModelTextResponse(
+        content="Congratulations Adam, you guessed correctly! You're a winner!",
+        timestamp=datetime.datetime(...),
+        role='model-text-response',
+    ),
+]
+"""
+```
+
+We can represent that flow as a diagram:
+
+![Dice game flow diagram](./img/dice-diagram-light.svg#only-light)
+![Dice game flow diagram](./img/dice-diagram-dark.svg#only-dark)
+
+### Retrievers, tools, and schema
+
+Under the hood, retrievers use the model's "tools" or "functions" API to let the model know what retrievers are available to call. Tools or functions are also used to define the schema(s) for structured responses; thus a model might have access to many tools, some of which call retrievers while others end the run and return a result.
+
+Function parameters are extracted from the function signature, and all parameters except `CallContext` are used to build the schema for that tool call.
+
+Even better, PydanticAI extracts the docstring from retriever functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema.
+
+[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy` and `sphinx` style docstrings. PydanticAI will infer the format to use based on the docstring. We'll add support in the future for explicitly setting the style to use, and for warning/erroring if not all parameters are documented; see [#59](https://github.com/pydantic/pydantic-ai/issues/59).
+
+To demonstrate a retriever's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive:
+
+```py title="retriever_schema.py"
+from pydantic_ai import Agent
+from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+
+agent = Agent()
+
+
+@agent.retriever_plain
+def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
+    """Get me foobar.
+
+    Args:
+        a: apple pie
+        b: banana cake
+        c: carrot smoothie
+    """
+    return f'{a} {b} {c}'
+
+
+def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
+    retriever = info.retrievers['foobar']
+    print(retriever.description)
+    #> Get me foobar.
+    print(retriever.json_schema)
+    """
+    {
+        'description': 'Get me foobar.',
+        'properties': {
+            'a': {'description': 'apple pie', 'title': 'A', 'type': 'integer'},
+            'b': {'description': 'banana cake', 'title': 'B', 'type': 'string'},
+            'c': {
+                'additionalProperties': {'items': {'type': 'number'}, 'type': 'array'},
+                'description': 'carrot smoothie',
+                'title': 'C',
+                'type': 'object',
+            },
+        },
+        'required': ['a', 'b', 'c'],
+        'type': 'object',
+        'additionalProperties': False,
+    }
+    """
+    return ModelTextResponse(content='foobar')
+
+
+agent.run_sync('hello', model=FunctionModel(print_schema))
+```
+
+_(This example is complete, it can be run "as is")_
+
+The return type of retriever can either be `str` or a JSON object typed as `dict[str, Any]` as some models (e.g. Gemini) support structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data.
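+
+To illustrate, here's a minimal sketch of a retriever that returns a JSON object rather than a string (the `get_capital` function, its file name, and its data are purely illustrative, not part of PydanticAI):
+
+```py title="retriever_json_return.py"
+from typing import Any
+
+from pydantic_ai import Agent
+
+agent = Agent('gemini-1.5-flash')
+
+
+@agent.retriever_plain
+def get_capital(country: str) -> dict[str, Any]:
+    """Get the name and population of a country's capital city."""
+    # hypothetical static data, standing in for a real database or API lookup
+    capitals = {
+        'France': {'name': 'Paris', 'population': 2_102_650},
+        'Japan': {'name': 'Tokyo', 'population': 14_094_034},
+    }
+    return capitals.get(country, {'name': 'unknown', 'population': None})
+```
+
+Models that support structured returns receive the object directly; for models that expect text, the object would be serialized to JSON first.
+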
## Reflection and self-correction -* validation errors from both retrievers parameter validation and structured result validation can be passed back to the with a request to retry -* as described above, you can also raise `ModelRetry` from within a retriever or result validator to tell the model it should retry -* the default retry count is 1, but can be altered both on a whole agent, or on a per-retriever basis and result validator basis -* you can access the current retry count from within a retriever or result validator via `ctx.retry` +Validation errors from both retriever parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry. + +You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validators functions](results.md#result-validators-functions) to tell the model it should retry. + +- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever_context], or a [result validator][pydantic_ai.Agent.__init__]. +- You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext]. + +Here's an example: + +```py title="retriever_retry.py" +from fake_database import DatabaseConn +from pydantic import BaseModel + +from pydantic_ai import Agent, CallContext, ModelRetry + + +class ChatResult(BaseModel): + user_id: int + message: str + + +agent = Agent( + 'openai:gpt-4o', + deps_type=DatabaseConn, + result_type=ChatResult, +) + + +@agent.retriever_context(retries=2) +def get_user_by_name(ctx: CallContext[DatabaseConn], name: str) -> int: + """Get a user's ID from their full name.""" + print(name) + #> John + #> John Doe + user_id = ctx.deps.users.get(name=name) + if user_id is None: + raise ModelRetry( + f'No user found with name {name!r}, remember to provide their full name' + ) + return user_id + + +result = agent.run_sync( + 'Send a message to John Doe asking for coffee next week', deps=DatabaseConn() +) +print(result.data) +""" +user_id=123 message='Hello John, would you be free for coffee sometime next week? Let me know what works for you!' +""" +``` ## Model errors -* If models behave unexpectedly, e.g. the retry limit is exceed, agent runs will raise `UnexpectedModelBehaviour` exceptions -* If you use PydanticAI in correctly, we try to raise a `UserError` with a helpful message -* show an except of a `UnexpectedModelBehaviour` being raised -* if a `UnexpectedModelBehaviour` is raised, you may want to access the [`.last_run_messages`][pydantic_ai.Agent.last_run_messages] attribute of an agent to see the messages exchanged that led to the error, show an example of accessing `.last_run_messages` in an except block to get more details +If models behave unexpectedly (e.g., the retry limit is exceeded, or their api returns `503`), agent runs will raise [`UnexpectedModelBehaviour`][pydantic_ai.exceptions.UnexpectedModelBehaviour]. + +In these cases, [`agent.last_run_messages`][pydantic_ai.Agent.last_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue. + +```python +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour + +agent = Agent('openai:gpt-4o') + + +@agent.retriever_plain +def calc_volume(size: int) -> int: # (1)! 
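+    # the prompt below asks for size 6, so this check never passes and ModelRetry is raised until the retry limit is exceeded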
+ if size == 42: + return size**3 + else: + raise ModelRetry('Please try again.') + + +try: + result = agent.run_sync('Please get me the volume of a box with size 6.') +except UnexpectedModelBehaviour as e: + print('An error occurred:', e) + #> An error occurred: Retriever exceeded max retries count of 1 + print('cause:', repr(e.__cause__)) + #> cause: ModelRetry('Please try again.') + print('messages:', agent.last_run_messages) + """ + messages: + [ + UserPrompt( + content='Please get me the volume of a box with size 6.', + timestamp=datetime.datetime(...), + role='user', + ), + ModelStructuredResponse( + calls=[ + ToolCall( + tool_name='calc_volume', + args=ArgsObject(args_object={'size': 6}), + tool_id=None, + ) + ], + timestamp=datetime.datetime(...), + role='model-structured-response', + ), + RetryPrompt( + content='Please try again.', + tool_name='calc_volume', + tool_id=None, + timestamp=datetime.datetime(...), + role='retry-prompt', + ), + ModelStructuredResponse( + calls=[ + ToolCall( + tool_name='calc_volume', + args=ArgsObject(args_object={'size': 6}), + tool_id=None, + ) + ], + timestamp=datetime.datetime(...), + role='model-structured-response', + ), + ] + """ +else: + print(result.data) +``` +1. Define a retriever that will raise `ModelRetry` repeatedly in this case. ## API Reference diff --git a/docs/dependencies.md b/docs/dependencies.md index 79a9660d1f..d97c66cd38 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -1,6 +1,6 @@ # Dependencies -PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators). +PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions). Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. @@ -159,7 +159,7 @@ _(This example is complete, it can be run "as is")_ ## Full Example -As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators). +As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions). 
```python title="full_example.py" hl_lines="27-35 38-48" from dataclasses import dataclass diff --git a/docs/img/dice-diagram-dark.svg b/docs/img/dice-diagram-dark.svg new file mode 100644 index 0000000000..b0b77e93f5 --- /dev/null +++ b/docs/img/dice-diagram-dark.svg @@ -0,0 +1,11 @@ + + + + + + + + dice_game.pyLLMSystemPromptYou're a dice game...UserPromptMy guess is 4ModelStructuredResponseroll_dice()ToolReturn4ModelStructuredResponseget_player_name()ToolReturnAdamModelTextResponseCongratulations Adam, ...End Run diff --git a/docs/img/dice-diagram-light.svg b/docs/img/dice-diagram-light.svg new file mode 100644 index 0000000000..64856c9340 --- /dev/null +++ b/docs/img/dice-diagram-light.svg @@ -0,0 +1,11 @@ + + + + + + + + dice_game.pyLLMSystemPromptYou're a dice game...UserPromptMy guess is 4ModelStructuredResponseroll_dice()ToolReturn4ModelStructuredResponseget_player_name()ToolReturnAdamModelTextResponseCongratulations Adam, ...End Run diff --git a/docs/message-history.md b/docs/message-history.md index cb2c8a8ff2..528b17e6d3 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -41,16 +41,12 @@ print(result.all_messages()) SystemPrompt(content='Be a helpful assistant.', role='system'), UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='Did you hear about the toothpaste scandal? They called it Colgate.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), ] @@ -62,16 +58,12 @@ print(result.new_messages()) [ UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='Did you hear about the toothpaste scandal? They called it Colgate.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), ] @@ -96,9 +88,7 @@ async def main(): SystemPrompt(content='Be a helpful assistant.', role='system'), UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ] @@ -120,16 +110,12 @@ async def main(): SystemPrompt(content='Be a helpful assistant.', role='system'), UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='Did you hear about the toothpaste scandal? They called it Colgate.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), ] @@ -175,30 +161,22 @@ print(result2.all_messages()) SystemPrompt(content='Be a helpful assistant.', role='system'), UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='Did you hear about the toothpaste scandal? 
They called it Colgate.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), UserPrompt( content='Explain?', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), ] @@ -233,30 +211,22 @@ print(result2.all_messages()) SystemPrompt(content='Be a helpful assistant.', role='system'), UserPrompt( content='Tell me a joke.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='Did you hear about the toothpaste scandal? They called it Colgate.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), UserPrompt( content='Explain?', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='user', ), ModelTextResponse( content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.', - timestamp=datetime.datetime( - 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc - ), + timestamp=datetime.datetime(...), role='model-text-response', ), ] diff --git a/docs/results.md b/docs/results.md index 94ef57df20..ca7f1ce1fe 100644 --- a/docs/results.md +++ b/docs/results.md @@ -2,7 +2,9 @@ TODO -## Result Validators +## Structured result validation + +## Result validators functions TODO diff --git a/pydantic_ai/_griffe.py b/pydantic_ai/_griffe.py new file mode 100644 index 0000000000..e69dbf7003 --- /dev/null +++ b/pydantic_ai/_griffe.py @@ -0,0 +1,128 @@ +from __future__ import annotations as _annotations + +import re +from inspect import Signature +from typing import Any, Callable, Literal, cast + +from _griffe.enumerations import DocstringSectionKind +from _griffe.models import Docstring, Object as GriffeObject + +DocstringStyle = Literal['google', 'numpy', 'sphinx'] + + +def doc_descriptions( + func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None +) -> tuple[str, dict[str, str]]: + """Extract the function description and parameter descriptions from a function's docstring. + + Returns: + A tuple of (main function description, parameter descriptions). 
+ """ + doc = func.__doc__ + if doc is None: + return '', {} + + # see https://github.com/mkdocstrings/griffe/issues/293 + parent = cast(GriffeObject, sig) + + docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent) + sections = docstring.parse() + + params = {} + if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None): + params = {p.name: p.description for p in parameters.value} + + main_desc = '' + if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None): + main_desc = main.value + + return main_desc, params + + +def _infer_docstring_style(doc: str) -> DocstringStyle: + """Simplistic docstring style inference.""" + for pattern, replacements, style in _docstring_style_patterns: + matches = ( + re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements + ) + if any(matches): + return style + # fallback to google style + return 'google' + + +# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804 +_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [ + ( + r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n', + [ + 'param', + 'parameter', + 'arg', + 'argument', + 'key', + 'keyword', + 'type', + 'var', + 'ivar', + 'cvar', + 'vartype', + 'returns', + 'return', + 'rtype', + 'raises', + 'raise', + 'except', + 'exception', + ], + 'sphinx', + ), + ( + r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+', + [ + 'args', + 'arguments', + 'params', + 'parameters', + 'keyword args', + 'keyword arguments', + 'other args', + 'other arguments', + 'other params', + 'other parameters', + 'raises', + 'exceptions', + 'returns', + 'yields', + 'receives', + 'examples', + 'attributes', + 'functions', + 'methods', + 'classes', + 'modules', + 'warns', + 'warnings', + ], + 'google', + ), + ( + r'\n[ \t]*{0}\n[ \t]*---+\n', + [ + 'deprecated', + 'parameters', + 'other parameters', + 'returns', + 'yields', + 'receives', + 'raises', + 'warns', + 'attributes', + 'functions', + 'methods', + 'classes', + 'modules', + ], + 'numpy', + ), +] diff --git a/pydantic_ai/_pydantic.py b/pydantic_ai/_pydantic.py index 62cf87f487..f6d2e180f5 100644 --- a/pydantic_ai/_pydantic.py +++ b/pydantic_ai/_pydantic.py @@ -5,12 +5,9 @@ from __future__ import annotations as _annotations -import re -from inspect import Parameter, Signature, signature -from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, cast, get_origin +from inspect import Parameter, signature +from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin -from _griffe.enumerations import DocstringSectionKind -from _griffe.models import Docstring, Object as GriffeObject from pydantic import ConfigDict, TypeAdapter from pydantic._internal import _decorators, _generate_schema, _typing_extra from pydantic._internal._config import ConfigWrapper @@ -19,6 +16,7 @@ from pydantic.plugin._schema_validator import create_schema_validator from pydantic_core import SchemaValidator, core_schema +from ._griffe import doc_descriptions from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like if TYPE_CHECKING: @@ -66,7 +64,7 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, R var_positional_field: str | None = None errors: list[str] = [] decorators = _decorators.DecoratorInfos() - description, field_descriptions = _doc_descriptions(function, sig) + description, field_descriptions = doc_descriptions(function, sig) for index, (name, p) in 
enumerate(sig.parameters.items()): if p.annotation is sig.empty: @@ -192,127 +190,6 @@ def _build_schema( return td_schema, None -DocstringStyle = Literal['google', 'numpy', 'sphinx'] - - -def _doc_descriptions( - func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None -) -> tuple[str, dict[str, str]]: - """Extract the function description and parameter descriptions from a function's docstring. - - Returns: - A tuple of (main function description, parameter descriptions). - """ - doc = func.__doc__ - if doc is None: - return '', {} - - # see https://github.com/mkdocstrings/griffe/issues/293 - parent = cast(GriffeObject, sig) - - docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent) - sections = docstring.parse() - - params = {} - if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None): - params = {p.name: p.description for p in parameters.value} - - main_desc = '' - if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None): - main_desc = main.value - - return main_desc, params - - -def _infer_docstring_style(doc: str) -> DocstringStyle: - """Simplistic docstring style inference.""" - for pattern, replacements, style in _docstring_style_patterns: - matches = ( - re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements - ) - if any(matches): - return style - # fallback to google style - return 'google' - - -# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804 -_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [ - ( - r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n', - [ - 'param', - 'parameter', - 'arg', - 'argument', - 'key', - 'keyword', - 'type', - 'var', - 'ivar', - 'cvar', - 'vartype', - 'returns', - 'return', - 'rtype', - 'raises', - 'raise', - 'except', - 'exception', - ], - 'sphinx', - ), - ( - r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+', - [ - 'args', - 'arguments', - 'params', - 'parameters', - 'keyword args', - 'keyword arguments', - 'other args', - 'other arguments', - 'other params', - 'other parameters', - 'raises', - 'exceptions', - 'returns', - 'yields', - 'receives', - 'examples', - 'attributes', - 'functions', - 'methods', - 'classes', - 'modules', - 'warns', - 'warnings', - ], - 'google', - ), - ( - r'\n[ \t]*{0}\n[ \t]*---+\n', - [ - 'deprecated', - 'parameters', - 'other parameters', - 'returns', - 'yields', - 'receives', - 'raises', - 'warns', - 'attributes', - 'functions', - 'methods', - 'classes', - 'modules', - ], - 'numpy', - ), -] - - def _is_call_ctx(annotation: Any) -> bool: from .dependencies import CallContext diff --git a/pydantic_ai/_retriever.py b/pydantic_ai/_retriever.py index 7de34a5fec..543f2add38 100644 --- a/pydantic_ai/_retriever.py +++ b/pydantic_ai/_retriever.py @@ -5,13 +5,12 @@ from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast -import pydantic_core from pydantic import ValidationError from pydantic_core import SchemaValidator from . 
import _pydantic, _utils, messages
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
-from .exceptions import ModelRetry
+from .exceptions import ModelRetry, UnexpectedModelBehaviour

# Usage `RetrieverEitherFunc[AgentDependencies, P]`
RetrieverEitherFunc = _utils.Either[
@@ -64,7 +63,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
            else:
                args_dict = self.validator.validate_python(message.args.args_object)
        except ValidationError as e:
-            return self._on_error(e.errors(include_url=False), message)
+            return self._on_error(e, message)

        args, kwargs = self._call_args(deps, args_dict, message)
        try:
@@ -75,7 +74,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
                function = cast(Callable[[Any], str], self.function.whichever())
                response_content = await _utils.run_in_executor(function, *args, **kwargs)
        except ModelRetry as e:
-            return self._on_error(e.message, message)
+            return self._on_error(e, message)

        self._current_retry = 0
        return messages.ToolReturn(
@@ -98,14 +97,16 @@ def _call_args(

        return args, args_dict

-    def _on_error(
-        self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.ToolCall
-    ) -> messages.RetryPrompt:
+    def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
        self._current_retry += 1
        if self._current_retry > self.max_retries:
            # TODO custom error with details of the retriever
-            raise
+            raise UnexpectedModelBehaviour(f'Retriever exceeded max retries count of {self.max_retries}') from exc
        else:
+            if isinstance(exc, ValidationError):
+                content = exc.errors(include_url=False)
+            else:
+                content = exc.message
            return messages.RetryPrompt(
                tool_name=call_message.tool_name,
                content=content,
diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index aa81d3d326..bd38da5878 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -326,7 +326,7 @@ def override_model(self, overriding_model: models.Model | models.KnownModelName)
    def system_prompt(
        self, func: _system_prompt.SystemPromptFunc[AgentDeps]
    ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
-        """Decorator to register a system prompt function that takes `CallContext` as it's only argument."""
+        """Decorator to register a system prompt function that optionally takes `CallContext` as its only argument."""
        self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
        return func

diff --git a/pydantic_ai/dependencies.py b/pydantic_ai/dependencies.py
index 96c82f8ff6..90e8b7e5ba 100644
--- a/pydantic_ai/dependencies.py
+++ b/pydantic_ai/dependencies.py
@@ -1,10 +1,10 @@
from __future__ import annotations as _annotations

-from collections.abc import Awaitable
+from collections.abc import Awaitable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union

-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, TypeAlias

if TYPE_CHECKING:
    from .result import ResultData

@@ -20,6 +20,7 @@
    'RetrieverContextFunc',
    'RetrieverPlainFunc',
    'RetrieverParams',
+    'JsonData',
)

AgentDeps = TypeVar('AgentDeps')

@@ -65,7 +66,10 @@ class CallContext(Generic[AgentDeps]):

Usage `ResultValidator[AgentDeps, ResultData]`.
""" -RetrieverReturnValue = Union[str, Awaitable[str], dict[str, Any], Awaitable[dict[str, Any]]] +JsonData: TypeAlias = 'None | str | int | float | list[JsonData] | Mapping[str, JsonData]' +"""JSON data type alias.""" + +RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]] """Return value of a retriever function.""" RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue] """A retriever function that takes `CallContext` as the first argument. diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py index 4da93aa197..26cfe65bba 100644 --- a/pydantic_ai/messages.py +++ b/pydantic_ai/messages.py @@ -1,12 +1,15 @@ from __future__ import annotations as _annotations import json +from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Any, Literal, Union +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union import pydantic import pydantic_core +from pydantic import TypeAdapter +from typing_extensions import TypeAlias, TypeAliasType from . import _pydantic from ._utils import now_utc as _now_utc @@ -41,7 +44,13 @@ class UserPrompt: """Message type identifier, this type is available on all message as a discriminator.""" -tool_return_value_object = _pydantic.LazyTypeAdapter(dict[str, Any]) +JsonData: TypeAlias = 'Union[str, int, float, None, list[JsonData], Mapping[str, JsonData]]' +if not TYPE_CHECKING: + # work around for https://github.com/pydantic/pydantic/issues/10873 + # this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file + JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, list[JsonData], Mapping[str, JsonData]]') + +json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData) @dataclass @@ -50,7 +59,7 @@ class ToolReturn: tool_name: str """The name of the "tool" was called.""" - content: str | dict[str, Any] + content: JsonData """The return value.""" tool_id: str | None = None """Optional tool identifier, this is used by some models including OpenAI.""" @@ -63,14 +72,15 @@ def model_response_str(self) -> str: if isinstance(self.content, str): return self.content else: - content = tool_return_value_object.validate_python(self.content) - return tool_return_value_object.dump_json(content).decode() + content = json_ta.validate_python(self.content) + return json_ta.dump_json(content).decode() - def model_response_object(self) -> dict[str, Any]: - if isinstance(self.content, str): - return {'return_value': self.content} + def model_response_object(self) -> dict[str, JsonData]: + # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict + if isinstance(self.content, dict): + return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType] else: - return tool_return_value_object.validate_python(self.content) + return {'return_value': json_ta.validate_python(self.content)} @dataclass diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index d944f8ee05..8d123e7332 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -472,26 +472,30 @@ class _GeminiTextContent(TypedDict): class _GeminiTools(TypedDict): - function_declarations: list[_GeminiFunction] + function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]] class _GeminiFunction(TypedDict): name: str description: str - parameters: dict[str, Any] + parameters: NotRequired[dict[str, Any]] """ 
ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema + and + """ def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction: json_schema = _GeminiJsonSchema(tool.json_schema).simplify() - return _GeminiFunction( + f = _GeminiFunction( name=tool.name, description=tool.description, - parameters=json_schema, ) + if json_schema.get('properties'): + f['parameters'] = json_schema + return f class _GeminiToolConfig(TypedDict): diff --git a/tests/test_examples.py b/tests/test_examples.py index 0ba44c2da9..807da50aed 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,9 +1,10 @@ from __future__ import annotations as _annotations import asyncio +import re import sys from collections.abc import AsyncIterator, Iterable -from datetime import datetime +from dataclasses import dataclass, field from types import ModuleType from typing import Any @@ -23,9 +24,30 @@ ) from pydantic_ai.models import KnownModelName, Model from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel +from pydantic_ai.models.test import TestModel from tests.conftest import ClientWithHandler +@pytest.fixture(scope='module', autouse=True) +def register_modules(): + class FakeTable: + def get(self, name: str) -> int | None: + if name == 'John Doe': + return 123 + + @dataclass + class DatabaseConn: + users: FakeTable = field(default_factory=FakeTable) + + module_name = 'fake_database' + sys.modules[module_name] = module = ModuleType(module_name) + module.__dict__.update({'DatabaseConn': DatabaseConn}) + + yield + + sys.modules.pop(module_name) + + def find_filter_examples() -> Iterable[CodeExample]: for ex in find_examples('docs', 'pydantic_ai'): if ex.path.name != '_utils.py': @@ -36,21 +58,23 @@ def find_filter_examples() -> Iterable[CodeExample]: def test_docs_examples( example: CodeExample, eval_example: EvalExample, mocker: MockerFixture, client_with_handler: ClientWithHandler ): - if example.path.name == '_utils.py': - return # debug(example) mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) - mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) mocker.patch('httpx.Client.get', side_effect=http_request) mocker.patch('httpx.Client.post', side_effect=http_request) mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request) mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request) + mocker.patch('random.randint', return_value=4) + + prefix_settings = example.prefix_settings() ruff_ignore: list[str] = ['D'] if str(example.path).endswith('docs/index.md'): ruff_ignore.append('F841') - eval_example.set_config(ruff_ignore=ruff_ignore) + eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39') + + eval_example.print_callback = print_callback call_name = 'main' if 'def test_application_code' in example.source: @@ -63,9 +87,15 @@ def test_docs_examples( eval_example.lint(example) module_dict = eval_example.run_print_check(example, call=call_name) - if example.path.name == 'dependencies.md' and 'title="joke_app.py"' in example.prefix: - sys.modules['joke_app'] = module = ModuleType('joke_app') - module.__dict__.update(module_dict) + if title := prefix_settings.get('title'): + if title.endswith('.py'): + module_name = title[:-3] + sys.modules[module_name] = module = ModuleType(module_name) + module.__dict__.update(module_dict) + + +def print_callback(s: str) -> str: + return re.sub(r'datetime.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL) def 
http_request(url: str, **kwargs: Any) -> httpx.Response: @@ -78,7 +108,7 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: return http_request(url, **kwargs) -text_responses = { +text_responses: dict[str, str | ToolCall] = { 'What is the weather like in West London and in Wiltshire?': 'The weather in West London is raining, while in Wiltshire it is sunny.', 'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.', 'Explain?': 'This is an excellent joke invent by Samuel Colvin, it needs no explanation.', @@ -88,22 +118,44 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: 'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.', 'What was his most famous equation?': "Albert Einstein's most famous equation is (E = mc^2).", 'What is the date?': 'Hello Frank, the date today is 2032-01-02.', + 'Put my money on square eighteen': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 18})), + 'I bet five is the winner': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 5})), + 'My guess is 4': ToolCall(tool_name='roll_dice', args=ArgsObject({})), + 'Send a message to John Doe asking for coffee next week': ToolCall( + tool_name='get_user_by_name', args=ArgsObject({'name': 'John'}) + ), + 'Please get me the volume of a box with size 6.': ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6})), } async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: m = messages[-1] if m.role == 'user': - if text_response := text_responses.get(m.content): - return ModelTextResponse(content=text_response) + if response := text_responses.get(m.content): + if isinstance(response, str): + return ModelTextResponse(content=response) + else: + return ModelStructuredResponse(calls=[response]) - if m.role == 'user' and m.content == 'Put my money on square eighteen': - return ModelStructuredResponse(calls=[ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 18}))]) - elif m.role == 'user' and m.content == 'I bet five is the winner': - return ModelStructuredResponse(calls=[ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 5}))]) elif m.role == 'tool-return' and m.tool_name == 'roulette_wheel': win = m.content == 'winner' return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject({'response': win}))]) + elif m.role == 'tool-return' and m.tool_name == 'roll_dice': + return ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsObject({}))]) + elif m.role == 'tool-return' and m.tool_name == 'get_player_name': + return ModelTextResponse(content="Congratulations Adam, you guessed correctly! You're a winner!") + if m.role == 'retry-prompt' and isinstance(m.content, str) and m.content.startswith("No user found with name 'Joh"): + return ModelStructuredResponse( + calls=[ToolCall(tool_name='get_user_by_name', args=ArgsObject({'name': 'John Doe'}))] + ) + elif m.role == 'tool-return' and m.tool_name == 'get_user_by_name': + args = { + 'message': 'Hello John, would you be free for coffee sometime next week? 
Let me know what works for you!', + 'user_id': 123, + } + return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject(args))]) + elif m.role == 'retry-prompt' and m.tool_name == 'calc_volume': + return ModelStructuredResponse(calls=[ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6}))]) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') @@ -112,23 +164,23 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: m = messages[-1] if m.role == 'user': - if text_response := text_responses.get(m.content): - *words, last_word = text_response.split(' ') - for work in words: - yield f'{work} ' - await asyncio.sleep(0.05) - yield last_word - return + if response := text_responses.get(m.content): + if isinstance(response, str): + *words, last_word = response.split(' ') + for work in words: + yield f'{work} ' + await asyncio.sleep(0.05) + yield last_word + return + else: + raise NotImplementedError('todo') sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') -def mock_infer_model(_model: Model | KnownModelName) -> Model: - return FunctionModel(model_logic, stream_function=stream_model_logic) - - -class MockedDatetime(datetime): - @classmethod - def now(cls, tz: Any = None) -> datetime: # type: ignore - return datetime(2032, 1, 2, 3, 4, 5, 6, tzinfo=tz) +def mock_infer_model(model: Model | KnownModelName) -> Model: + if isinstance(model, (FunctionModel, TestModel)): + return model + else: + return FunctionModel(model_logic, stream_function=stream_model_logic) From 2f5c6112ea0988eec2adbb4e49f050a12eec4c3c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 18 Nov 2024 17:49:41 +0000 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: hyperlint-ai[bot] <154288675+hyperlint-ai[bot]@users.noreply.github.com> --- docs/agents.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 7b73aec716..9f11aaebad 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -352,7 +352,7 @@ The return type of retriever can either be `str` or a JSON object typed as `dict Validation errors from both retriever parameter validation and [structured result validation](results.md#structured-result-validation) can be passed back to the model with a request to retry. -You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validators functions](results.md#result-validators-functions) to tell the model it should retry. +You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validator functions](results.md#result-validators-functions) to tell the model it should retry. - The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever_context], or a [result validator][pydantic_ai.Agent.__init__]. - You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext]. 
@@ -403,7 +403,7 @@ user_id=123 message='Hello John, would you be free for coffee sometime next week

## Model errors

-If models behave unexpectedly (e.g., the retry limit is exceeded, or their api returns `503`), agent runs will raise [`UnexpectedModelBehaviour`][pydantic_ai.exceptions.UnexpectedModelBehaviour].
+If models behave unexpectedly (e.g., the retry limit is exceeded, or their API returns `503`), agent runs will raise [`UnexpectedModelBehaviour`][pydantic_ai.exceptions.UnexpectedModelBehaviour].

In these cases, [`agent.last_run_messages`][pydantic_ai.Agent.last_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue.


From 43d4cc3e5df0f09c381f70ab6ce72247b0cc46c6 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Mon, 18 Nov 2024 18:16:05 +0000
Subject: [PATCH 3/4] cleanup

---
 docs/agents.md              | 2 +-
 mkdocs.yml                  | 4 ++++
 pydantic_ai/dependencies.py | 6 +++---
 pydantic_ai/messages.py     | 6 +++---
 tests/typed_agent.py        | 9 +++++++--
 5 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/docs/agents.md b/docs/agents.md
index 9f11aaebad..c36d56ad66 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -346,7 +346,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema))

 _(This example is complete, it can be run "as is")_

-The return type of retriever can either be `str` or a JSON object typed as `dict[str, Any]` as some models (e.g. Gemini) support structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data.
+The return type of a retriever can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]): some models (e.g. Gemini) support semi-structured return values, while others (e.g. OpenAI) expect text but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.

 ## Reflection and self-correction

diff --git a/mkdocs.yml b/mkdocs.yml
index b7ebc585bf..5bce3b0b07 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -120,6 +120,10 @@ markdown_extensions:
   - pymdownx.tasklist:
       custom_checkbox: true
   - sane_lists # this means you can start a list from any number
+  - material.extensions.preview:
+      targets:
+        include:
+          - '*'

 watch:
   - pydantic_ai
diff --git a/pydantic_ai/dependencies.py b/pydantic_ai/dependencies.py
index 90e8b7e5ba..ee5a087830 100644
--- a/pydantic_ai/dependencies.py
+++ b/pydantic_ai/dependencies.py
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations

-from collections.abc import Awaitable, Mapping
+from collections.abc import Awaitable, Mapping, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union

@@ -66,8 +66,8 @@ class CallContext(Generic[AgentDeps]):

 Usage `ResultValidator[AgentDeps, ResultData]`.
""" -JsonData: TypeAlias = 'None | str | int | float | list[JsonData] | Mapping[str, JsonData]' -"""JSON data type alias.""" +JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]' +"""Type representing any JSON data.""" RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]] """Return value of a retriever function.""" diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py index 26cfe65bba..3d6ec3f45d 100644 --- a/pydantic_ai/messages.py +++ b/pydantic_ai/messages.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime from typing import TYPE_CHECKING, Annotated, Any, Literal, Union @@ -44,11 +44,11 @@ class UserPrompt: """Message type identifier, this type is available on all message as a discriminator.""" -JsonData: TypeAlias = 'Union[str, int, float, None, list[JsonData], Mapping[str, JsonData]]' +JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]' if not TYPE_CHECKING: # work around for https://github.com/pydantic/pydantic/issues/10873 # this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file - JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, list[JsonData], Mapping[str, JsonData]]') + JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]') json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 6ade228ec5..d31293aef5 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -41,6 +41,11 @@ def ok_retriever_plain(x: str) -> dict[str, str]: return {'x': x} +@typed_agent.retriever_plain +def ok_json_list(x: str) -> list[Union[str, int]]: + return [x, 1] + + @typed_agent.retriever_context async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str: total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined] @@ -53,8 +58,8 @@ async def bad_retriever2(ctx: CallContext[int], x: str) -> str: @typed_agent.retriever_plain # type: ignore[arg-type] -async def bad_retriever_return(x: int) -> list[int]: - return [x] +async def bad_retriever_return(x: int) -> list[MyDeps]: + return [MyDeps(1, x)] with expect_error(ValueError): From 9684df53f2660b5b8333675d50275a7a2e45aee9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 18 Nov 2024 18:18:05 +0000 Subject: [PATCH 4/4] add "as is" comments --- docs/agents.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/agents.md b/docs/agents.md index c36d56ad66..5c9c0d3a11 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -114,6 +114,8 @@ print(result2.data) ``` 1. Continue the conversation, without `message_history` the model would not know who "he" was referring to. +_(This example is complete, it can be run "as is")_ + ## System Prompts System prompts might seem simple at first glance since they're just strings (or sequences of strings that are concatenated), but crafting the right system prompt is key to getting the model to behave as you want. @@ -159,6 +161,8 @@ print(result.data) 3. Dynamic system prompt defined via a decorator. 4. Another dynamic system prompt, system prompts don't have to have the `CallContext` parameter. 
+_(This example is complete, it can be run "as is")_ + ## Retrievers Retrievers provide a mechanism for models to request extra information to help them generate a response. @@ -473,6 +477,8 @@ else: ``` 1. Define a retriever that will raise `ModelRetry` repeatedly in this case. +_(This example is complete, it can be run "as is")_ + ## API Reference ::: pydantic_ai.Agent