From 33dfd08cf2728ae4313ab7fca7ca0aeee1b4b0a5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 11:09:43 +0000 Subject: [PATCH 1/7] adding pytest-examples --- .github/workflows/ci.yml | 17 ++++++- Makefile | 15 +++--- docs/concepts/{index.md => agents.md} | 0 docs/concepts/dependencies.md | 4 +- .../{result-validation.md => results.md} | 0 mkdocs.yml | 4 +- pyproject.toml | 16 ++++--- tests/test_examples.py | 10 ++++ uv.lock | 46 ++++++++++++++++--- 9 files changed, 88 insertions(+), 24 deletions(-) rename docs/concepts/{index.md => agents.md} (100%) rename docs/concepts/{result-validation.md => results.md} (100%) create mode 100644 tests/test_examples.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de1eb4fefa..3d16317658 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: enable-cache: true - name: Install dependencies - run: uv sync --python 3.12 --frozen --all-extras + run: uv sync --python 3.12 --frozen --all-extras --no-dev --group lint - uses: pre-commit/action@v3.0.0 with: @@ -34,6 +34,19 @@ jobs: env: SKIP: no-commit-to-branch + # mypy and lint are a bit slower than other jobs, so we run them separately + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --python 3.12 --frozen --no-dev --group lint + - run: make typecheck-mypy docs: @@ -138,7 +151,7 @@ jobs: # https://github.com/marketplace/actions/alls-green#why used for branch protection checks check: if: always() - needs: [lint, docs, test-live, test, coverage] + needs: [lint, mypy, docs, test-live, test, coverage] runs-on: ubuntu-latest steps: diff --git a/Makefile b/Makefile index ac81b50190..35ccea20cb 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ .PHONY: install # Install the package, dependencies, and pre-commit for local development install: .uv .pre-commit - uv sync --frozen --all-extras --group docs + uv sync --frozen --all-extras --group lint --group docs pre-commit install --install-hooks .PHONY: format # Format the code @@ -67,12 +67,15 @@ docs: docs-serve: uv run mkdocs serve --no-strict -# install insiders packages for docs (to avoid running this on every build, `touch .docs-insiders-install`) +.PHONY: .docs-insiders-install # install insiders packages for docs if necessary .docs-insiders-install: - @echo 'installing insiders packages' - @uv pip install -U \ - --extra-index-url https://pydantic:${PPPR_TOKEN}@pppr.pydantic.dev/simple/ \ - mkdocs-material mkdocstrings-python +ifeq ($(shell uv pip show mkdocs-material | grep -q insiders && echo 'installed'), installed) + @echo 'insiders packages already installed' +else + @echo 'installing insiders packages...' 
+ @uv pip install -U mkdocs-material mkdocstrings-python \ + --extra-index-url https://pydantic:${PPPR_TOKEN}@pppr.pydantic.dev/simple/ +endif .PHONY: docs-insiders # Build the documentation using insiders packages docs-insiders: .docs-insiders-install diff --git a/docs/concepts/index.md b/docs/concepts/agents.md similarity index 100% rename from docs/concepts/index.md rename to docs/concepts/agents.md diff --git a/docs/concepts/dependencies.md b/docs/concepts/dependencies.md index 843ddcfeca..202c6e56b6 100644 --- a/docs/concepts/dependencies.md +++ b/docs/concepts/dependencies.md @@ -1,6 +1,6 @@ # Dependencies -PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](system-prompt.md), [retrievers](retrievers.md) and [result validators](result-validation.md#TODO). +PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](system-prompt.md), [retrievers](retrievers.md) and [result validators](results.md#TODO). Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. @@ -157,7 +157,7 @@ _(This example is complete, it can be run "as is")_ ## Full Example -As well as system prompts, dependencies can be used in [retrievers](retrievers.md) and [result validators](result-validation.md#TODO). +As well as system prompts, dependencies can be used in [retrievers](retrievers.md) and [result validators](results.md#TODO). ```python title="full_example.py" hl_lines="27-35 38-48" from dataclasses import dataclass diff --git a/docs/concepts/result-validation.md b/docs/concepts/results.md similarity index 100% rename from docs/concepts/result-validation.md rename to docs/concepts/results.md diff --git a/mkdocs.yml b/mkdocs.yml index 27c227126c..d1934ebb19 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,11 +14,11 @@ nav: - Introduction: index.md - install.md - Concepts: - - concepts/index.md + - concepts/agents.md - concepts/dependencies.md - concepts/retrievers.md - concepts/system-prompt.md - - concepts/result-validation.md + - concepts/results.md - concepts/message-history.md - concepts/streaming.md - concepts/testing-evals.md diff --git a/pyproject.toml b/pyproject.toml index dab9f85a91..a3886577d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,16 +57,20 @@ examples = [ [dependency-groups] dev = [ + "anyio>=4.5.0", + "devtools>=0.12.2", + "coverage[toml]>=7.6.2", "dirty-equals>=0.8.0", - "mypy>=1.11.2", - "pyright>=1.1.388", + "inline-snapshot>=0.14", "pytest>=8.3.3", + "pytest-examples>=0.0.13", + "pytest-mock>=3.14.0", "pytest-pretty>=1.2.0", - "inline-snapshot>=0.14", +] +lint = [ + "mypy>=1.11.2", + "pyright>=1.1.388", "ruff>=0.6.9", - "coverage[toml]>=7.6.2", - "devtools>=0.12.2", - "anyio>=4.5.0", ] docs = [ "mkdocs", diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 0000000000..0d3171b96f --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,10 @@ +import pytest +from pytest_examples import CodeExample, EvalExample, find_examples +from pytest_mock import MockerFixture + + +@pytest.mark.parametrize('example', find_examples('docs'), ids=str) +def test_docs_examples(example: CodeExample, eval_example: EvalExample, mocker: MockerFixture): + # debug(example) + eval_example.lint(example) + eval_example.run(example) diff --git a/uv.lock 
b/uv.lock index db6f810631..7adfb412ed 100644 --- a/uv.lock +++ b/uv.lock @@ -1447,17 +1447,21 @@ dev = [ { name = "devtools" }, { name = "dirty-equals" }, { name = "inline-snapshot" }, - { name = "mypy" }, - { name = "pyright" }, { name = "pytest" }, + { name = "pytest-examples" }, + { name = "pytest-mock" }, { name = "pytest-pretty" }, - { name = "ruff" }, ] docs = [ { name = "mkdocs" }, { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] +lint = [ + { name = "mypy" }, + { name = "pyright" }, + { name = "ruff" }, +] [package.metadata] requires-dist = [ @@ -1483,17 +1487,21 @@ dev = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.14" }, - { name = "mypy", specifier = ">=1.11.2" }, - { name = "pyright", specifier = ">=1.1.388" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-examples", specifier = ">=0.0.13" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, - { name = "ruff", specifier = ">=0.6.9" }, ] docs = [ { name = "mkdocs" }, { name = "mkdocs-material", extras = ["imaging"] }, { name = "mkdocstrings-python" }, ] +lint = [ + { name = "mypy", specifier = ">=1.11.2" }, + { name = "pyright", specifier = ">=1.1.388" }, + { name = "ruff", specifier = ">=0.6.9" }, +] [[package]] name = "pydantic-core" @@ -1644,6 +1652,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-examples" +version = "0.0.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "black" }, + { name = "pytest" }, + { name = "ruff" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/67/9d3486460bde76a5fd3709de38eea6cbeab7f41b1fe2c657f52a71b19575/pytest_examples-0.0.13.tar.gz", hash = "sha256:4d6fd78154953e84444f58f193eb6cc8d853bca7f0ee9f44ea75db043a2c19b5", size = 20445 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/c4/9b639b46dc0aa246847c5d82171d4167218bcceb2108f566e25afdfd229f/pytest_examples-0.0.13-py3-none-any.whl", hash = "sha256:8c05cc66459199963c7970ff417300ae7e5b7b4c9d5bd859f1e74a35db75555b", size = 17313 }, +] + +[[package]] +name = "pytest-mock" +version = "3.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/90/a955c3ab35ccd41ad4de556596fa86685bf4fc5ffcc62d22d856cfd4e29a/pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0", size = 32814 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/3b/b26f90f74e2986a82df6e7ac7e319b8ea7ccece1caec9f8ab6104dc70603/pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f", size = 9863 }, +] + [[package]] name = "pytest-pretty" version = "1.2.0" From 6654872abadf789b1183c0a4ca16dd7e674872b0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 11:21:26 +0000 Subject: [PATCH 2/7] formatting examples --- docs/concepts/dependencies.md | 10 ++++------ docs/concepts/message-history.md | 4 +++- docs/index.md | 3 ++- tests/test_examples.py | 16 ++++++++++++++-- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git 
a/docs/concepts/dependencies.md b/docs/concepts/dependencies.md index 202c6e56b6..2425d7ab1a 100644 --- a/docs/concepts/dependencies.md +++ b/docs/concepts/dependencies.md @@ -78,7 +78,7 @@ agent = Agent( async def get_system_prompt(ctx: CallContext[MyDeps]) -> str: # (2)! response = await ctx.deps.http_client.get( # (3)! 'https://example.com', - headers={'Authorization': f'Bearer {ctx.deps.api_key}'} # (4)! + headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, # (4)! ) response.raise_for_status() return f'Prompt: {response.text}' @@ -134,8 +134,7 @@ agent = Agent( @agent.system_prompt def get_system_prompt(ctx: CallContext[MyDeps]) -> str: # (2)! response = ctx.deps.http_client.get( - 'https://example.com', - headers={'Authorization': f'Bearer {ctx.deps.api_key}'} + 'https://example.com', headers={'Authorization': f'Bearer {ctx.deps.api_key}'} ) response.raise_for_status() return f'Prompt: {response.text}' @@ -264,7 +263,6 @@ async def application_code(prompt: str) -> str: # (3)! app_deps = MyDeps('foobar', client) result = await joke_agent.run(prompt, deps=app_deps) # (4)! return result.data - ``` 1. Define a method on the dependency to make the system prompt easier to customise. @@ -273,7 +271,7 @@ async def application_code(prompt: str) -> str: # (3)! 4. Call the agent from within the application code, in a real application this call might be deep within a call stack. Note `app_deps` here will NOT be used when deps are overridden. ```py title="test_joke_app.py" hl_lines="10-12" -from joke_app import application_code, joke_agent, MyDeps +from joke_app import MyDeps, application_code, joke_agent class TestMyDeps(MyDeps): # (1)! @@ -314,7 +312,7 @@ joke_agent = Agent( system_prompt=( 'Use the "joke_factory" to generate some jokes, then choose the best. ' 'You must return just a single joke.' - ) + ), ) factory_agent = Agent('gemini-1.5-pro', result_type=list[str]) diff --git a/docs/concepts/message-history.md b/docs/concepts/message-history.md index 8fff79a8c9..2d63de15f4 100644 --- a/docs/concepts/message-history.md +++ b/docs/concepts/message-history.md @@ -114,7 +114,9 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') print(result1.data) -result2 = agent.run_sync('Explain?', model='gemini-1.5-pro', message_history=result1.new_messages()) +result2 = agent.run_sync( + 'Explain?', model='gemini-1.5-pro', message_history=result1.new_messages() +) print(result2.data) print(result2.all_messages()) diff --git a/docs/index.md b/docs/index.md index 5b93a77ce1..2d6f5c55e0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -38,6 +38,7 @@ Partial example of using retrievers to help an LLM respond to a user's query abo ```py title="weather_agent.py" import httpx + from pydantic_ai import Agent, CallContext weather_agent = Agent( # (1)! @@ -75,7 +76,7 @@ async def main(): deps=client, ) print(result.data) # (9)! - # > 'The weather in West London is raining, while in Wiltshire it is sunny.' + #> 'The weather in West London is raining, while in Wiltshire it is sunny.' print(result.all_messages()) # (10)! 
``` diff --git a/tests/test_examples.py b/tests/test_examples.py index 0d3171b96f..4a8cebaf51 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -2,9 +2,21 @@ from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture +LINE_LENGTH = 88 + @pytest.mark.parametrize('example', find_examples('docs'), ids=str) def test_docs_examples(example: CodeExample, eval_example: EvalExample, mocker: MockerFixture): # debug(example) - eval_example.lint(example) - eval_example.run(example) + ruff_ignore: list[str] = ['D'] + if str(example.path).endswith('docs/index.md'): + ruff_ignore.append('F841') + eval_example.set_config(ruff_ignore=ruff_ignore, line_length=LINE_LENGTH) + if eval_example.update_examples: + eval_example.format(example) + # eval_example.run_print_update(example) + else: + eval_example.lint(example) + # eval_example.run_print_check(example) + + # eval_example.run(example) From d0431cef2c8fcfd9e2c6634dcb2a26f02beeaada Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 16:54:48 +0000 Subject: [PATCH 3/7] running examples, WIP --- docs/api/agent.md | 5 --- docs/api/models/base.md | 1 + docs/concepts/dependencies.md | 11 ++++-- docs/index.md | 6 +-- pydantic_ai/agent.py | 37 +++++------------- pydantic_ai/models/__init__.py | 21 ++++++++++- tests/conftest.py | 4 +- tests/test_examples.py | 69 +++++++++++++++++++++++++++++++--- 8 files changed, 106 insertions(+), 48 deletions(-) diff --git a/docs/api/agent.md b/docs/api/agent.md index 2dd569c96f..06de0fcc4d 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -15,8 +15,3 @@ - retriever_plain - retriever_context - result_validator - -::: pydantic_ai.agent - options: - members: - - KnownModelName diff --git a/docs/api/models/base.md b/docs/api/models/base.md index 48672830f6..ea2ed5f098 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -3,6 +3,7 @@ ::: pydantic_ai.models options: members: + - KnownModelName - Model - AgentModel - AbstractToolDefinition diff --git a/docs/concepts/dependencies.md b/docs/concepts/dependencies.md index 2425d7ab1a..f566113998 100644 --- a/docs/concepts/dependencies.md +++ b/docs/concepts/dependencies.md @@ -13,7 +13,7 @@ Here's an example of defining an agent that requires dependencies. (**Note:** dependencies aren't actually used in this example, see [Accessing Dependencies](#accessing-dependencies) below) -```python title="unused_dependencies.py" +```py title="unused_dependencies.py" from dataclasses import dataclass import httpx @@ -41,6 +41,7 @@ async def main(): deps=deps, # (3)! ) print(result.data) + #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. Define a dataclass to hold dependencies. @@ -89,6 +90,7 @@ async def main(): deps = MyDeps('foobar', client) result = await agent.run('Tell me a joke.', deps=deps) print(result.data) + #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. [`CallContext`][pydantic_ai.dependencies.CallContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument. @@ -147,6 +149,7 @@ async def main(): deps=deps, ) print(result.data) + #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. Here we use a synchronous `httpx.Client` instead of an asynchronous `httpx.AsyncClient`. 
@@ -214,6 +217,7 @@ async def main():
         deps = MyDeps('foobar', client)
         result = await agent.run('Tell me a joke.', deps=deps)
         print(result.data)
+        #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 1. To pass `CallContext` to a retriever, use the [`retriever_context`][pydantic_ai.Agent.retriever_context] decorator.
@@ -282,8 +286,8 @@ class TestMyDeps(MyDeps):  # (1)!
 async def test_application_code():
     test_deps = TestMyDeps('test_key', None)  # (2)!
     with joke_agent.override_deps(test_deps):  # (3)!
-        joke = application_code('Tell me a joke.')  # (4)!
-        assert joke == 'funny'
+        joke = await application_code('Tell me a joke.')  # (4)!
+        assert joke.startswith('Did you hear about the toothpaste scandal?')
 ```
 
 1. Define a subclass of `MyDeps` in tests to customise the system prompt factory.
@@ -326,6 +330,7 @@ async def joke_factory(ctx: CallContext[MyDeps], count: int) -> str:
 
 result = joke_agent.run_sync('Tell me a joke.', deps=MyDeps(factory_agent))
 print(result.data)
+#> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 ## Examples
diff --git a/docs/index.md b/docs/index.md
index 2d6f5c55e0..ee0d3991f5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -76,9 +76,9 @@ async def main():
         deps=client,
     )
     print(result.data)  # (9)!
-    #> 'The weather in West London is raining, while in Wiltshire it is sunny.'
+    #> The weather in West London is raining, while in Wiltshire it is sunny.
 
-    print(result.all_messages())  # (10)!
+    messages = result.all_messages()  # (10)!
 ```
 
 1. An agent that can tell users about the weather in a particular location. Agents combine a system prompt, a response type (here `str`) and "retrievers" (aka tools).
@@ -90,7 +90,7 @@ async def main():
 7. Multiple retrievers can be registered with the same agent; the LLM can choose which (if any) retrievers to call in order to respond to a user.
 8. Run the agent asynchronously, conducting a conversation with the LLM until a final response is reached. You can also run agents synchronously with `run_sync`. Internally agents are all async, so `run_sync` is a helper using `asyncio.run` to call `run()`.
 9. The response from the LLM, in this case a `str`. Agents are generic in both the type of `deps` and `result_type`, so calls are typed end-to-end.
-10. `result.all_messages()` includes details of messages exchanged, this is useful both to understand the conversation that took place and useful if you want to continue the conversation later — messages can be passed back to later `run/run_sync` calls.
+10. [`result.all_messages()`](concepts/message-history.md) includes details of messages exchanged, this is useful both to understand the conversation that took place and useful if you want to continue the conversation later — messages can be passed back to later `run/run_sync` calls.
 
 !!! tip "Complete `weather_agent.py` example"
    This example is incomplete for the sake of brevity; you can find a complete `weather_agent.py` example [here](examples/weather-agent.md).
diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index fb4211b06b..36e8fdb2b4 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -4,7 +4,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass -from typing import Any, Callable, Generic, Literal, cast, final, overload +from typing import Any, Callable, Generic, cast, final, overload import logfire_api from typing_extensions import assert_never @@ -22,24 +22,7 @@ from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc from .result import ResultData -__all__ = 'Agent', 'KnownModelName' - -KnownModelName = Literal[ - 'openai:gpt-4o', - 'openai:gpt-4o-mini', - 'openai:gpt-4-turbo', - 'openai:gpt-4', - 'openai:o1-preview', - 'openai:o1-mini', - 'openai:gpt-3.5-turbo', - 'gemini-1.5-flash', - 'gemini-1.5-pro', - 'test', -] -"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. - -`KnownModelName` is provided as a concise way to specify a model. -""" +__all__ = ('Agent',) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -52,7 +35,7 @@ class Agent(Generic[AgentDeps, ResultData]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.""" # dataclass fields mostly for my sanity — knowing what attributes are available - model: models.Model | KnownModelName | None + model: models.Model | models.KnownModelName | None """The default model configured for this agent.""" _result_schema: _result.ResultSchema[ResultData] | None _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -74,7 +57,7 @@ class Agent(Generic[AgentDeps, ResultData]): def __init__( self, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, result_type: type[ResultData] = str, *, system_prompt: str | Sequence[str] = (), @@ -101,7 +84,7 @@ def __init__( result_tool_name: The name of the tool to use for the final result. result_tool_description: The description of the final result tool. result_retries: The maximum number of retries to allow for result validation, defaults to `retries`. - defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model, + defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` to defer the evaluation until the first run. Useful if you want to @@ -132,7 +115,7 @@ async def run( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -201,7 +184,7 @@ def run_sync( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. 
@@ -225,7 +208,7 @@ async def run_stream( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -310,7 +293,7 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]: self._override_deps = override_deps_before @contextmanager - def override_model(self, overriding_model: models.Model | KnownModelName) -> Iterator[None]: + def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]: """Context manager to temporarily override the model used by the agent. Args: @@ -411,7 +394,7 @@ def _register_retriever( return retriever async def _get_agent_model( - self, model: models.Model | KnownModelName | None + self, model: models.Model | models.KnownModelName | None ) -> tuple[models.Model, models.Model | None, models.AgentModel]: """Create a model configured for this agent. diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 0378ab4c16..27e8c91202 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Protocol, Union +from typing import TYPE_CHECKING, Literal, Protocol, Union import httpx @@ -19,10 +19,27 @@ if TYPE_CHECKING: from .._utils import ObjectJsonSchema - from ..agent import KnownModelName from ..result import Cost +KnownModelName = Literal[ + 'openai:gpt-4o', + 'openai:gpt-4o-mini', + 'openai:gpt-4-turbo', + 'openai:gpt-4', + 'openai:o1-preview', + 'openai:o1-mini', + 'openai:gpt-3.5-turbo', + 'gemini-1.5-flash', + 'gemini-1.5-pro', + 'test', +] +"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. + +`KnownModelName` is provided as a concise way to specify a model. 
+""" + + class Model(ABC): """Abstract class for a model.""" diff --git a/tests/conftest.py b/tests/conftest.py index 75156fae02..7592d82fc3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import re import secrets import sys -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator from datetime import datetime from pathlib import Path from types import ModuleType @@ -82,7 +82,7 @@ def allow_model_requests(): @pytest.fixture -async def client_with_handler(): +async def client_with_handler() -> AsyncIterator[ClientWithHandler]: client: httpx.AsyncClient | None = None def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient: diff --git a/tests/test_examples.py b/tests/test_examples.py index 4a8cebaf51..9d1d201b2a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,22 +1,79 @@ +from __future__ import annotations as _annotations + +import sys +from datetime import datetime +from types import ModuleType + +import httpx import pytest +from devtools import debug from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture -LINE_LENGTH = 88 +from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse +from pydantic_ai.models import KnownModelName, Model +from pydantic_ai.models.function import AgentInfo, FunctionModel +from tests.conftest import ClientWithHandler @pytest.mark.parametrize('example', find_examples('docs'), ids=str) -def test_docs_examples(example: CodeExample, eval_example: EvalExample, mocker: MockerFixture): +def test_docs_examples( + example: CodeExample, eval_example: EvalExample, mocker: MockerFixture, client_with_handler: ClientWithHandler +): # debug(example) + mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) + mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) + + def get_request(url, **kwargs): + # sys.stdout.write(f'GET {args=} {kwargs=}\n') + request = httpx.Request('GET', url, **kwargs) + return httpx.Response(status_code=202, content='', request=request) + + async def async_get_request(url, **kwargs): + return get_request(url, **kwargs) + + mocker.patch('httpx.Client.get', side_effect=get_request) + mocker.patch('httpx.Client.post', side_effect=get_request) + mocker.patch('httpx.AsyncClient.get', side_effect=async_get_request) + mocker.patch('httpx.AsyncClient.post', side_effect=async_get_request) + ruff_ignore: list[str] = ['D'] if str(example.path).endswith('docs/index.md'): ruff_ignore.append('F841') - eval_example.set_config(ruff_ignore=ruff_ignore, line_length=LINE_LENGTH) + eval_example.set_config(ruff_ignore=ruff_ignore) + + call_name = 'main' + if 'def test_application_code' in example.source: + call_name = 'test_application_code' + if eval_example.update_examples: eval_example.format(example) - # eval_example.run_print_update(example) + module_dict = eval_example.run_print_update(example, call=call_name) else: eval_example.lint(example) - # eval_example.run_print_check(example) + module_dict = eval_example.run_print_check(example, call=call_name) + + if example.path.name == 'dependencies.md' and 'title="joke_app.py"' in example.prefix: + sys.modules['joke_app'] = module = ModuleType('joke_app') + module.__dict__.update(module_dict) + + +def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: + m = messages[-1] + if m.role == 'user' and m.content == 'What is the weather like in West London and in Wiltshire?': + return 
ModelTextResponse(content='The weather in West London is raining, while in Wiltshire it is sunny.') + if m.role == 'user' and m.content == 'Tell me a joke.': + return ModelTextResponse(content='Did you hear about the toothpaste scandal? They called it Colgate.') + else: + sys.stdout.write(str(debug.format(messages, info))) + raise RuntimeError(f'Unexpected message: {m}') + + +def mock_infer_model(model: Model | KnownModelName) -> Model: + return FunctionModel(model_logic) + - # eval_example.run(example) +class MockedDatetime(datetime): + @classmethod + def now(cls, *args, tz=None, **kwargs): + return datetime(2032, 1, 2, 3, 4, 5, 6, tzinfo=tz) From 7e434edd3080ce638555f2ede89d5e75ee0be46f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 19:02:18 +0000 Subject: [PATCH 4/7] fixing FunctionModel arguments --- docs/concepts/message-history.md | 124 ++++++++++++++++++++++-- pydantic_ai/models/function.py | 55 +++++++---- pydantic_ai_examples/pydantic_model.py | 2 +- pydantic_ai_examples/stream_markdown.py | 2 +- pyproject.toml | 3 + tests/models/test_model_function.py | 29 +++--- tests/test_examples.py | 35 +++++-- tests/test_retrievers.py | 2 +- tests/test_streaming.py | 17 ++-- uv.lock | 25 +++-- 10 files changed, 234 insertions(+), 60 deletions(-) diff --git a/docs/concepts/message-history.md b/docs/concepts/message-history.md index 2d63de15f4..1af820eb88 100644 --- a/docs/concepts/message-history.md +++ b/docs/concepts/message-history.md @@ -38,12 +38,50 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result = agent.run_sync('Tell me a joke.') print(result.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. # all messages from the run print(result.all_messages()) +""" +[ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), +] +""" # messages excluding system prompts print(result.new_messages()) +""" +[ + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? 
They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), +] +""" ``` _(This example is complete, it can be run "as is")_ @@ -54,15 +92,17 @@ from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') -async with agent.run_stream('Tell me a joke.') as result: - # incomplete messages before the stream finishes - print(result.all_messages()) - async for text in result.stream(): - print(text) +async def main(): + async with agent.run_stream('Tell me a joke.') as result: + # incomplete messages before the stream finishes + print(result.all_messages()) - # complete messages once the stream finishes - print(result.all_messages()) + async for text in result.stream(): + print(text) + + # complete messages once the stream finishes + print(result.all_messages()) ``` _(This example is complete, it can be run "as is" inside an async context)_ @@ -92,11 +132,46 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') print(result1.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. result2 = agent.run_sync('Explain?', message_history=result1.new_messages()) print(result2.data) +#> This is an excellent joke invent by Samuel Colvin, it needs no explanation. print(result2.all_messages()) +""" +[ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), + UserPrompt( + content='Explain?', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), +] +""" ``` _(This example is complete, it can be run "as is")_ @@ -113,13 +188,48 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') print(result1.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. result2 = agent.run_sync( 'Explain?', model='gemini-1.5-pro', message_history=result1.new_messages() ) print(result2.data) +#> This is an excellent joke invent by Samuel Colvin, it needs no explanation. print(result2.all_messages()) +""" +[ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? 
They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), + UserPrompt( + content='Explain?', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), +] +""" ``` ## Last Run Messages diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index f81945effb..88869c4a96 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -8,11 +8,11 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence +import inspect +from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from itertools import chain from typing import Callable, Union, cast from typing_extensions import TypeAlias, overload @@ -114,11 +114,17 @@ class DeltaToolCall: """A mapping of tool call IDs to incremental changes.""" # TODO these should allow coroutines -FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], ModelAnyResponse] +FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]] """A function used to generate a non-streamed response.""" -StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[Iterable[str], Iterable[DeltaToolCalls]]] -"""A function used to generate a streamed response.""" +StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]] +"""A function used to generate a streamed response. + +While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should +really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`, + +E.g. you need to yield all text or all `DeltaToolCalls`, not mix them. 
+""" @dataclass @@ -131,38 +137,46 @@ class FunctionAgentModel(AgentModel): async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]: assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' - return self.function(messages, self.agent_info), result.Cost() + if inspect.iscoroutinefunction(self.function): + return await self.function(messages, self.agent_info), result.Cost() + else: + response = await _utils.run_in_executor(self.function, messages, self.agent_info) + return cast(ModelAnyResponse, response), result.Cost() @asynccontextmanager async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]: assert ( self.stream_function is not None ), 'FunctionModel must receive a `stream_function` to support streamed requests' - response_data = iter(self.stream_function(messages, self.agent_info)) + response_stream = self.stream_function(messages, self.agent_info) try: - first = next(response_data) + first = await response_stream.__anext__() except StopIteration as e: raise ValueError('Stream function must return at least one item') from e if isinstance(first, str): - text_stream = cast(Iterable[str], response_data) - yield FunctionStreamTextResponse(iter(chain([first], text_stream))) + text_stream = cast(AsyncIterator[str], response_stream) + yield FunctionStreamTextResponse(first, text_stream) else: - structured_stream = cast(Iterable[DeltaToolCalls], response_data) - # noinspection PyTypeChecker - yield FunctionStreamStructuredResponse(iter(chain([first], structured_stream)), {}) + structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream) + yield FunctionStreamStructuredResponse(first, structured_stream) @dataclass class FunctionStreamTextResponse(StreamTextResponse): """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - _iter: Iterator[str] + _next: str | None + _iter: AsyncIterator[str] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: - self._buffer.append(_utils.sync_anext(self._iter)) + if self._next is not None: + self._buffer.append(self._next) + self._next = None + else: + self._buffer.append(await self._iter.__anext__()) def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer @@ -179,12 +193,17 @@ def timestamp(self) -> datetime: class FunctionStreamStructuredResponse(StreamStructuredResponse): """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - _iter: Iterator[DeltaToolCalls] - _delta_tool_calls: dict[int, DeltaToolCall] + _next: DeltaToolCalls | None + _iter: AsyncIterator[DeltaToolCalls] + _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict) _timestamp: datetime = field(default_factory=_utils.now_utc) async def __anext__(self) -> None: - tool_call = _utils.sync_anext(self._iter) + if self._next is not None: + tool_call = self._next + self._next = None + else: + tool_call = await self._iter.__anext__() for key, new in tool_call.items(): if current := self._delta_tool_calls.get(key): diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index 92a3258da6..83804f2287 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from pydantic_ai import 
Agent -from pydantic_ai.agent import KnownModelName +from pydantic_ai.models import KnownModelName # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') diff --git a/pydantic_ai_examples/stream_markdown.py b/pydantic_ai_examples/stream_markdown.py index ccfee6848e..da57b5f3ee 100644 --- a/pydantic_ai_examples/stream_markdown.py +++ b/pydantic_ai_examples/stream_markdown.py @@ -16,7 +16,7 @@ from rich.text import Text from pydantic_ai import Agent -from pydantic_ai.agent import KnownModelName +from pydantic_ai.models import KnownModelName # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') diff --git a/pyproject.toml b/pyproject.toml index a3886577d1..0a16d6502b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,3 +178,6 @@ ignore_no_config = true [tool.inline-snapshot.shortcuts] fix=["create", "fix"] + +[tool.uv.sources] +pytest-examples = { path = "../pytest-examples" } diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 0446ab1023..1475a93d4d 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -1,6 +1,6 @@ import json import re -from collections.abc import Iterable +from collections.abc import AsyncIterator from dataclasses import asdict from datetime import timezone @@ -29,7 +29,7 @@ pytestmark = pytest.mark.anyio -def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: +async def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: last = messages[-1] response = asdict(last) response.pop('timestamp', None) @@ -84,7 +84,7 @@ def test_simple(): ) -def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover +async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover assert info.allow_text_result assert info.retrievers.keys() == {'get_location', 'get_weather'} last = messages[-1] @@ -177,7 +177,7 @@ def test_weather(): assert result.data == 'Sunny in Ipswich' -def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: # pragma: no cover +async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: # pragma: no cover last = messages[-1] if last.role == 'user': if last.content.startswith('{'): @@ -221,7 +221,7 @@ def test_var_args(): ) -def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: +async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: if len(messages) == 1: assert len(info.retrievers) == 1 retriever_id = next(iter(info.retrievers.keys())) @@ -308,7 +308,7 @@ def spam() -> str: def test_register_all(): - def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: + async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: return ModelTextResponse( f'messages={len(messages)} allow_text_result={info.allow_text_result} retrievers={len(info.retrievers)}' ) @@ -349,7 +349,7 @@ def test_call_all(): def test_retry_str(): call_count = 0 - def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: + async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: nonlocal call_count call_count += 1 @@ -371,7 +371,7 @@ async def validate_result(r: str) -> str: def test_retry_result_type(): 
call_count = 0 - def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: + async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: nonlocal call_count call_count += 1 @@ -393,7 +393,7 @@ async def validate_result(r: Foo) -> Foo: assert result.data == snapshot(Foo(x=2)) -def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str]: +async def stream_text_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]: yield 'hello ' yield 'world' @@ -416,7 +416,9 @@ class Foo(BaseModel): async def test_stream_structure(): - def stream_structured_function(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]: + async def stream_structured_function( + _messages: list[Message], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls]: assert agent_info.result_tools is not None assert len(agent_info.result_tools) == 1 name = agent_info.result_tools[0].name @@ -439,8 +441,13 @@ async def test_pass_both(): Agent(FunctionModel(return_last, stream_function=stream_text_function)) +async def stream_text_function_empty(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]: + if False: + yield 'hello ' + + async def test_return_empty(): - agent = Agent(FunctionModel(stream_function=lambda _, __: [])) + agent = Agent(FunctionModel(stream_function=stream_text_function_empty)) with pytest.raises(ValueError, match='Stream function must return at least one item'): async with agent.run_stream(''): pass diff --git a/tests/test_examples.py b/tests/test_examples.py index 9d1d201b2a..04d647d62e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,8 +1,11 @@ from __future__ import annotations as _annotations +import asyncio import sys +from collections.abc import AsyncIterator from datetime import datetime from types import ModuleType +from typing import Any import httpx import pytest @@ -12,7 +15,7 @@ from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse from pydantic_ai.models import KnownModelName, Model -from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel from tests.conftest import ClientWithHandler @@ -24,12 +27,12 @@ def test_docs_examples( mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) - def get_request(url, **kwargs): + def get_request(url: str, **kwargs: Any) -> httpx.Response: # sys.stdout.write(f'GET {args=} {kwargs=}\n') request = httpx.Request('GET', url, **kwargs) return httpx.Response(status_code=202, content='', request=request) - async def async_get_request(url, **kwargs): + async def async_get_request(url: str, **kwargs: Any) -> httpx.Response: return get_request(url, **kwargs) mocker.patch('httpx.Client.get', side_effect=get_request) @@ -58,22 +61,40 @@ async def async_get_request(url, **kwargs): module.__dict__.update(module_dict) -def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: +async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: m = messages[-1] if m.role == 'user' and m.content == 'What is the weather like in West London and in Wiltshire?': return ModelTextResponse(content='The weather in West London is raining, while in Wiltshire it is sunny.') if m.role == 'user' and m.content == 'Tell me a joke.': return ModelTextResponse(content='Did you hear about the toothpaste scandal? 
They called it Colgate.') + if m.role == 'user' and m.content == 'Explain?': + return ModelTextResponse(content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.') else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') -def mock_infer_model(model: Model | KnownModelName) -> Model: - return FunctionModel(model_logic) +async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + m = messages[-1] + if m.role == 'user' and m.content == 'Tell me a joke.': + *words, last_word = 'Did you hear about the toothpaste scandal? They called it Colgate.'.split(' ') + for work in words: + yield f'{work} ' + await asyncio.sleep(0.05) + yield last_word + + # if m.role == 'user' and m.content == 'Explain?': + # return ['This is an excellent joke invent by Samuel Colvin, it needs no explanation.'] + else: + sys.stdout.write(str(debug.format(messages, info))) + raise RuntimeError(f'Unexpected message: {m}') + + +def mock_infer_model(_model: Model | KnownModelName) -> Model: + return FunctionModel(model_logic, stream_function=stream_model_logic) class MockedDatetime(datetime): @classmethod - def now(cls, *args, tz=None, **kwargs): + def now(cls, tz: Any = None) -> datetime: # type: ignore return datetime(2032, 1, 2, 3, 4, 5, 6, tzinfo=tz) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 465956e705..11d9d5fe46 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -67,7 +67,7 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover return f'{foo} {bar}' -def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse: +async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse: assert len(info.retrievers) == 1 r = next(iter(info.retrievers.values())) return ModelTextResponse(json.dumps(r.json_schema)) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cb8a3385e2..4f60397158 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import Iterable +from collections.abc import AsyncIterator from datetime import timezone import pytest @@ -81,7 +81,7 @@ async def test_streamed_structured_response(): async def test_structured_response_iter(): - def text_stream(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]: + async def text_stream(_messages: list[Message], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: assert agent_info.result_tools is not None assert len(agent_info.result_tools) == 1 name = agent_info.result_tools[0].name @@ -145,11 +145,12 @@ async def test_streamed_text_stream(): async def test_plain_response(): call_index = 0 - def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]: + async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]: nonlocal call_index call_index += 1 - return ['hello ', 'world'] + yield 'hello ' + yield 'world' agent = Agent(FunctionModel(stream_function=text_stream), result_type=tuple[str, str]) @@ -161,9 +162,9 @@ def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]: async def test_call_retriever(): - def stream_structured_function( + async def stream_structured_function( messages: list[Message], agent_info: AgentInfo - ) -> Iterable[DeltaToolCalls] | Iterable[str]: + ) -> AsyncIterator[DeltaToolCalls | str]: if len(messages) 
== 1: assert agent_info.retrievers is not None assert len(agent_info.retrievers) == 1 @@ -226,7 +227,7 @@ async def ret_a(x: str) -> str: async def test_call_retriever_empty(): - def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {} agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) @@ -237,7 +238,7 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab async def test_call_retriever_wrong_name(): - def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {0: DeltaToolCall(name='foobar', json_args='{}')} agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) diff --git a/uv.lock b/uv.lock index 7adfb412ed..5c1e256ec1 100644 --- a/uv.lock +++ b/uv.lock @@ -1488,7 +1488,7 @@ dev = [ { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.14" }, { name = "pytest", specifier = ">=8.3.3" }, - { name = "pytest-examples", specifier = ">=0.0.13" }, + { name = "pytest-examples", directory = "../pytest-examples" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, ] @@ -1654,16 +1654,29 @@ wheels = [ [[package]] name = "pytest-examples" -version = "0.0.13" -source = { registry = "https://pypi.org/simple" } +version = "0.0.14" +source = { directory = "../pytest-examples" } dependencies = [ { name = "black" }, { name = "pytest" }, { name = "ruff" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/67/9d3486460bde76a5fd3709de38eea6cbeab7f41b1fe2c657f52a71b19575/pytest_examples-0.0.13.tar.gz", hash = "sha256:4d6fd78154953e84444f58f193eb6cc8d853bca7f0ee9f44ea75db043a2c19b5", size = 20445 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/17/c4/9b639b46dc0aa246847c5d82171d4167218bcceb2108f566e25afdfd229f/pytest_examples-0.0.13-py3-none-any.whl", hash = "sha256:8c05cc66459199963c7970ff417300ae7e5b7b4c9d5bd859f1e74a35db75555b", size = 17313 }, + +[package.metadata] +requires-dist = [ + { name = "black", specifier = ">=23" }, + { name = "pytest", specifier = ">=7" }, + { name = "ruff", specifier = ">=0.5.0" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, + { name = "pytest-pretty", specifier = ">=1.2.0" }, +] +lint = [ + { name = "pre-commit", specifier = ">=3.5.0" }, + { name = "ruff", specifier = ">=0.7.4" }, ] [[package]] From 92824a2e61e08efbe09bd9048bb5be8d882510d4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 20:55:33 +0000 Subject: [PATCH 5/7] all examples passing --- docs/concepts/message-history.md | 37 ++++++++++++++++++++++++++++++++ pydantic_ai/models/function.py | 2 +- pyproject.toml | 5 +---- tests/test_examples.py | 29 ++++++++++++------------- uv.lock | 23 +++++--------------- 5 files changed, 58 insertions(+), 38 deletions(-) diff --git a/docs/concepts/message-history.md b/docs/concepts/message-history.md index 1af820eb88..697b15d98f 100644 --- a/docs/concepts/message-history.md +++ b/docs/concepts/message-history.md @@ -97,12 +97,49 @@ async def main(): async with agent.run_stream('Tell me a joke.') as 
result: # incomplete messages before the stream finishes print(result.all_messages()) + """ + [ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ] + """ async for text in result.stream(): print(text) + #> Did you + #> Did you hear about + #> Did you hear about the toothpaste + #> Did you hear about the toothpaste scandal? They + #> Did you hear about the toothpaste scandal? They called it + #> Did you hear about the toothpaste scandal? They called it Colgate. # complete messages once the stream finishes print(result.all_messages()) + """ + [ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), + ] + """ ``` _(This example is complete, it can be run "as is" inside an async context)_ diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index 88869c4a96..c0dbcc196d 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -151,7 +151,7 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS response_stream = self.stream_function(messages, self.agent_info) try: first = await response_stream.__anext__() - except StopIteration as e: + except StopAsyncIteration as e: raise ValueError('Stream function must return at least one item') from e if isinstance(first, str): diff --git a/pyproject.toml b/pyproject.toml index 0a16d6502b..7c340d561a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ dev = [ "dirty-equals>=0.8.0", "inline-snapshot>=0.14", "pytest>=8.3.3", - "pytest-examples>=0.0.13", + "pytest-examples>=0.0.14", "pytest-mock>=3.14.0", "pytest-pretty>=1.2.0", ] @@ -178,6 +178,3 @@ ignore_no_config = true [tool.inline-snapshot.shortcuts] fix=["create", "fix"] - -[tool.uv.sources] -pytest-examples = { path = "../pytest-examples" } diff --git a/tests/test_examples.py b/tests/test_examples.py index 04d647d62e..56daee1a47 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -27,18 +27,10 @@ def test_docs_examples( mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) - def get_request(url: str, **kwargs: Any) -> httpx.Response: - # sys.stdout.write(f'GET {args=} {kwargs=}\n') - request = httpx.Request('GET', url, **kwargs) - return httpx.Response(status_code=202, content='', request=request) - - async def async_get_request(url: str, **kwargs: Any) -> httpx.Response: - return get_request(url, **kwargs) - - mocker.patch('httpx.Client.get', side_effect=get_request) - mocker.patch('httpx.Client.post', side_effect=get_request) - mocker.patch('httpx.AsyncClient.get', side_effect=async_get_request) - mocker.patch('httpx.AsyncClient.post', side_effect=async_get_request) + mocker.patch('httpx.Client.get', side_effect=http_request) + mocker.patch('httpx.Client.post', side_effect=http_request) + mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request) + mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request) ruff_ignore: list[str] 
= ['D'] if str(example.path).endswith('docs/index.md'): @@ -61,6 +53,16 @@ async def async_get_request(url: str, **kwargs: Any) -> httpx.Response: module.__dict__.update(module_dict) +def http_request(url: str, **kwargs: Any) -> httpx.Response: + # sys.stdout.write(f'GET {args=} {kwargs=}\n') + request = httpx.Request('GET', url, **kwargs) + return httpx.Response(status_code=202, content='', request=request) + + +async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: + return http_request(url, **kwargs) + + async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: m = messages[-1] if m.role == 'user' and m.content == 'What is the weather like in West London and in Wiltshire?': @@ -82,9 +84,6 @@ async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncI yield f'{work} ' await asyncio.sleep(0.05) yield last_word - - # if m.role == 'user' and m.content == 'Explain?': - # return ['This is an excellent joke invent by Samuel Colvin, it needs no explanation.'] else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') diff --git a/uv.lock b/uv.lock index 5c1e256ec1..755a4ad682 100644 --- a/uv.lock +++ b/uv.lock @@ -1488,7 +1488,7 @@ dev = [ { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.14" }, { name = "pytest", specifier = ">=8.3.3" }, - { name = "pytest-examples", directory = "../pytest-examples" }, + { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, ] @@ -1655,28 +1655,15 @@ wheels = [ [[package]] name = "pytest-examples" version = "0.0.14" -source = { directory = "../pytest-examples" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "black" }, { name = "pytest" }, { name = "ruff" }, ] - -[package.metadata] -requires-dist = [ - { name = "black", specifier = ">=23" }, - { name = "pytest", specifier = ">=7" }, - { name = "ruff", specifier = ">=0.5.0" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, - { name = "pytest-pretty", specifier = ">=1.2.0" }, -] -lint = [ - { name = "pre-commit", specifier = ">=3.5.0" }, - { name = "ruff", specifier = ">=0.7.4" }, +sdist = { url = "https://files.pythonhosted.org/packages/d2/a7/b81d5cf26e9713a2d4c8e6863ee009360c5c07a0cfb880456ec8b09adab7/pytest_examples-0.0.14.tar.gz", hash = "sha256:776d1910709c0c5ce01b29bfe3651c5312d5cfe5c063e23ca6f65aed9af23f09", size = 20767 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/99/f418071551ff2b5e8c06bd8b82b1f4fd472b5e4162f018773ba4ef52b6e8/pytest_examples-0.0.14-py3-none-any.whl", hash = "sha256:867a7ea105635d395df712a4b8d0df3bda4c3d78ae97a57b4f115721952b5e25", size = 17919 }, ] [[package]] From 28aef4ca1c926852092055992dcff303a330edbe Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 21:18:02 +0000 Subject: [PATCH 6/7] running code examples in docstrings --- mkdocs.yml | 1 - pydantic_ai/agent.py | 19 ++++++++++++++++++- tests/test_examples.py | 14 ++++++++++++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index d1934ebb19..2b2c858805 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -147,7 +147,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - show_source: false heading_level: 2 import: - url: https://docs.python.org/3/objects.inv diff --git 
a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 36e8fdb2b4..96212534b9 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -32,7 +32,24 @@ @final @dataclass(init=False) class Agent(Generic[AgentDeps, ResultData]): - """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.""" + """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. + + Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps] + and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData]. + + By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + + Minimal usage example: + + ```py + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + result = agent.run_sync('What is the capital of France?') + print(result.data) + # > Paris + ``` + """ # dataclass fields mostly for my sanity — knowing what attributes are available model: models.Model | models.KnownModelName | None diff --git a/tests/test_examples.py b/tests/test_examples.py index 56daee1a47..972d46a6c6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -2,7 +2,7 @@ import asyncio import sys -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from datetime import datetime from types import ModuleType from typing import Any @@ -19,10 +19,18 @@ from tests.conftest import ClientWithHandler -@pytest.mark.parametrize('example', find_examples('docs'), ids=str) +def find_filter_examples() -> Iterable[CodeExample]: + for ex in find_examples('docs', 'pydantic_ai'): + if ex.path.name != '_utils.py': + yield ex + + +@pytest.mark.parametrize('example', find_filter_examples(), ids=str) def test_docs_examples( example: CodeExample, eval_example: EvalExample, mocker: MockerFixture, client_with_handler: ClientWithHandler ): + if example.path.name == '_utils.py': + return # debug(example) mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) @@ -71,6 +79,8 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo return ModelTextResponse(content='Did you hear about the toothpaste scandal? 
They called it Colgate.') if m.role == 'user' and m.content == 'Explain?': return ModelTextResponse(content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.') + if m.role == 'user' and m.content == 'What is the capital of France?': + return ModelTextResponse(content='Paris') else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') From a5736f091b8d71376758281c0797ce877417b4cf Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 15 Nov 2024 22:46:04 +0000 Subject: [PATCH 7/7] fix example formatting --- pydantic_ai/agent.py | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 96212534b9..aa81d3d326 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -47,7 +47,7 @@ class Agent(Generic[AgentDeps, ResultData]): agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.data) - # > Paris + #> Paris ``` """ diff --git a/pyproject.toml b/pyproject.toml index 7c340d561a..d69f9d60a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,8 @@ ignore = [ convention = "google" [tool.ruff.format] -docstring-code-format = true +# don't format python in docstrings, pytest-examples takes care of it +docstring-code-format = false quote-style = "single" [tool.ruff.lint.per-file-ignores]
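
A note on the `StopIteration` → `StopAsyncIteration` change in `pydantic_ai/models/function.py` above: awaiting `__anext__()` on an exhausted async generator raises `StopAsyncIteration`, never `StopIteration`, so the old `except StopIteration` branch could not fire and an empty stream function would have escaped as an unhandled `StopAsyncIteration` instead of the intended `ValueError`. A minimal, self-contained sketch of that plain-Python behaviour (not part of the patch itself):

```python
import asyncio


async def empty_stream():
    # an "empty" async generator: the unreachable yield is what makes this
    # an async generator function rather than a plain coroutine
    return
    yield


async def main():
    stream = empty_stream()
    try:
        await stream.__anext__()
    except StopAsyncIteration:
        # exhausted async generators raise StopAsyncIteration, not StopIteration
        print('stream produced no items')
        #> stream produced no items


asyncio.run(main())
```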
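
The `http_request` / `async_http_request` helpers moved to module level in `tests/test_examples.py` give every `httpx` call made by a docs example a canned `202` response, so examples that appear to hit the network still run offline and deterministically. A standalone sketch of the same patching pattern, using only `httpx` and `pytest-mock` from the dev group (the file name and test function are hypothetical, not part of the patch):

```python
# test_httpx_patching_sketch.py — hypothetical standalone test, not part of the patch
from typing import Any

import httpx
from pytest_mock import MockerFixture


def http_request(url: str, **kwargs: Any) -> httpx.Response:
    # every patched call gets a canned 202 response tied to a dummy request
    request = httpx.Request('GET', url, **kwargs)
    return httpx.Response(status_code=202, content='', request=request)


def test_patched_client(mocker: MockerFixture):
    # the mock replaces the unbound method on the class, so `self` is not
    # passed through and http_request only sees the url and kwargs
    mocker.patch('httpx.Client.get', side_effect=http_request)
    response = httpx.Client().get('https://example.com')
    assert response.status_code == 202
```

In the real test the same helpers are also wired to `httpx.Client.post` and to both methods on `httpx.AsyncClient`, as shown in the diff above.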
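
`tests/test_examples.py` also patches `pydantic_ai._utils.datetime` with a `MockedDatetime` class that is defined elsewhere in the test suite and is not shown in this series. A hypothetical stand-in, assuming all it needs to do is pin `now()` to the fixed 2032-01-02 03:04:05.000006 UTC timestamp that the message-history output above expects — the real class may well do more:

```python
# hypothetical reconstruction of MockedDatetime — the real implementation
# lives in the test suite and is not part of this patch series
from datetime import datetime, timezone


class MockedDatetime(datetime):
    @classmethod
    def now(cls, tz=None):
        # fixed timestamp so docs examples produce deterministic message history
        return cls(2032, 1, 2, 3, 4, 5, 6, tzinfo=timezone.utc)


print(MockedDatetime.now())
#> 2032-01-02 03:04:05.000006+00:00
```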