diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de1eb4fefa..3d16317658 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: enable-cache: true - name: Install dependencies - run: uv sync --python 3.12 --frozen --all-extras + run: uv sync --python 3.12 --frozen --all-extras --no-dev --group lint - uses: pre-commit/action@v3.0.0 with: @@ -34,6 +34,19 @@ jobs: env: SKIP: no-commit-to-branch + # mypy and lint are a bit slower than other jobs, so we run them separately + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --python 3.12 --frozen --no-dev --group lint + - run: make typecheck-mypy docs: @@ -138,7 +151,7 @@ jobs: # https://github.com/marketplace/actions/alls-green#why used for branch protection checks check: if: always() - needs: [lint, docs, test-live, test, coverage] + needs: [lint, mypy, docs, test-live, test, coverage] runs-on: ubuntu-latest steps: diff --git a/Makefile b/Makefile index ac81b50190..35ccea20cb 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ .PHONY: install # Install the package, dependencies, and pre-commit for local development install: .uv .pre-commit - uv sync --frozen --all-extras --group docs + uv sync --frozen --all-extras --group lint --group docs pre-commit install --install-hooks .PHONY: format # Format the code @@ -67,12 +67,15 @@ docs: docs-serve: uv run mkdocs serve --no-strict -# install insiders packages for docs (to avoid running this on every build, `touch .docs-insiders-install`) +.PHONY: .docs-insiders-install # install insiders packages for docs if necessary .docs-insiders-install: - @echo 'installing insiders packages' - @uv pip install -U \ - --extra-index-url https://pydantic:${PPPR_TOKEN}@pppr.pydantic.dev/simple/ \ - mkdocs-material mkdocstrings-python +ifeq ($(shell uv pip show mkdocs-material | grep -q insiders && echo 'installed'), installed) + @echo 'insiders packages already installed' +else + @echo 'installing insiders packages...' + @uv pip install -U mkdocs-material mkdocstrings-python \ + --extra-index-url https://pydantic:${PPPR_TOKEN}@pppr.pydantic.dev/simple/ +endif .PHONY: docs-insiders # Build the documentation using insiders packages docs-insiders: .docs-insiders-install diff --git a/docs/api/agent.md b/docs/api/agent.md index 2dd569c96f..06de0fcc4d 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -15,8 +15,3 @@ - retriever_plain - retriever_context - result_validator - -::: pydantic_ai.agent - options: - members: - - KnownModelName diff --git a/docs/api/models/base.md b/docs/api/models/base.md index 48672830f6..ea2ed5f098 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -3,6 +3,7 @@ ::: pydantic_ai.models options: members: + - KnownModelName - Model - AgentModel - AbstractToolDefinition diff --git a/docs/concepts/index.md b/docs/concepts/agents.md similarity index 100% rename from docs/concepts/index.md rename to docs/concepts/agents.md diff --git a/docs/concepts/dependencies.md b/docs/concepts/dependencies.md index 843ddcfeca..f566113998 100644 --- a/docs/concepts/dependencies.md +++ b/docs/concepts/dependencies.md @@ -1,6 +1,6 @@ # Dependencies -PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](system-prompt.md), [retrievers](retrievers.md) and [result validators](result-validation.md#TODO). 
+PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](system-prompt.md), [retrievers](retrievers.md) and [result validators](results.md#TODO).
 
 Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic". This should make dependencies type-safe, understandable, easier to test and ultimately easier to deploy in production.
 
@@ -13,7 +13,7 @@
 Here's an example of defining an agent that requires dependencies.
 
 (**Note:** dependencies aren't actually used in this example, see [Accessing Dependencies](#accessing-dependencies) below)
 
-```python title="unused_dependencies.py"
+```py title="unused_dependencies.py"
 from dataclasses import dataclass
 
 import httpx
@@ -41,6 +41,7 @@ async def main():
            deps=deps,  # (3)!
        )
        print(result.data)
+        #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 1. Define a dataclass to hold dependencies.
@@ -78,7 +79,7 @@ agent = Agent(
 async def get_system_prompt(ctx: CallContext[MyDeps]) -> str:  # (2)!
     response = await ctx.deps.http_client.get(  # (3)!
         'https://example.com',
-        headers={'Authorization': f'Bearer {ctx.deps.api_key}'}  # (4)!
+        headers={'Authorization': f'Bearer {ctx.deps.api_key}'},  # (4)!
     )
     response.raise_for_status()
     return f'Prompt: {response.text}'
@@ -89,6 +90,7 @@ async def main():
        deps = MyDeps('foobar', client)
        result = await agent.run('Tell me a joke.', deps=deps)
        print(result.data)
+        #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 1. [`CallContext`][pydantic_ai.dependencies.CallContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument.
@@ -134,8 +136,7 @@ agent = Agent(
 @agent.system_prompt
 def get_system_prompt(ctx: CallContext[MyDeps]) -> str:  # (2)!
     response = ctx.deps.http_client.get(
-        'https://example.com',
-        headers={'Authorization': f'Bearer {ctx.deps.api_key}'}
+        'https://example.com', headers={'Authorization': f'Bearer {ctx.deps.api_key}'}
     )
     response.raise_for_status()
     return f'Prompt: {response.text}'
@@ -148,6 +149,7 @@ async def main():
        deps=deps,
    )
    print(result.data)
+    #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 1. Here we use a synchronous `httpx.Client` instead of an asynchronous `httpx.AsyncClient`.
@@ -157,7 +159,7 @@ _(This example is complete, it can be run "as is")_
 
 ## Full Example
 
-As well as system prompts, dependencies can be used in [retrievers](retrievers.md) and [result validators](result-validation.md#TODO).
+As well as system prompts, dependencies can be used in [retrievers](retrievers.md) and [result validators](results.md#TODO).
 
 ```python title="full_example.py" hl_lines="27-35 38-48"
 from dataclasses import dataclass
@@ -215,6 +217,7 @@ async def main():
        deps = MyDeps('foobar', client)
        result = await agent.run('Tell me a joke.', deps=deps)
        print(result.data)
+        #> Did you hear about the toothpaste scandal? They called it Colgate.
 ```
 
 1. To pass `CallContext` to a retriever, use the [`retriever_context`][pydantic_ai.Agent.retriever_context] decorator.
@@ -264,7 +267,6 @@ async def application_code(prompt: str) -> str:  # (3)!
        app_deps = MyDeps('foobar', client)
        result = await joke_agent.run(prompt, deps=app_deps)  # (4)!
    return result.data
-
 ```
 
 1. Define a method on the dependency to make the system prompt easier to customise.
@@ -273,7 +275,7 @@ async def application_code(prompt: str) -> str:  # (3)!
 4. 
Call the agent from within the application code, in a real application this call might be deep within a call stack. Note `app_deps` here will NOT be used when deps are overridden. ```py title="test_joke_app.py" hl_lines="10-12" -from joke_app import application_code, joke_agent, MyDeps +from joke_app import MyDeps, application_code, joke_agent class TestMyDeps(MyDeps): # (1)! @@ -284,8 +286,8 @@ class TestMyDeps(MyDeps): # (1)! async def test_application_code(): test_deps = TestMyDeps('test_key', None) # (2)! with joke_agent.override_deps(test_deps): # (3)! - joke = application_code('Tell me a joke.') # (4)! - assert joke == 'funny' + joke = await application_code('Tell me a joke.') # (4)! + assert joke.startswith('Did you hear about the toothpaste scandal?') ``` 1. Define a subclass of `MyDeps` in tests to customise the system prompt factory. @@ -314,7 +316,7 @@ joke_agent = Agent( system_prompt=( 'Use the "joke_factory" to generate some jokes, then choose the best. ' 'You must return just a single joke.' - ) + ), ) factory_agent = Agent('gemini-1.5-pro', result_type=list[str]) @@ -328,6 +330,7 @@ async def joke_factory(ctx: CallContext[MyDeps], count: int) -> str: result = joke_agent.run_sync('Tell me a joke.', deps=MyDeps(factory_agent)) print(result.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. ``` ## Examples diff --git a/docs/concepts/message-history.md b/docs/concepts/message-history.md index 8fff79a8c9..697b15d98f 100644 --- a/docs/concepts/message-history.md +++ b/docs/concepts/message-history.md @@ -38,12 +38,50 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result = agent.run_sync('Tell me a joke.') print(result.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. # all messages from the run print(result.all_messages()) +""" +[ + SystemPrompt(content='Be a helpful assistant.', role='system'), + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? They called it Colgate.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='model-text-response', + ), +] +""" # messages excluding system prompts print(result.new_messages()) +""" +[ + UserPrompt( + content='Tell me a joke.', + timestamp=datetime.datetime( + 2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc + ), + role='user', + ), + ModelTextResponse( + content='Did you hear about the toothpaste scandal? 
They called it Colgate.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='model-text-response',
+    ),
+]
+"""
 ```
 
 _(This example is complete, it can be run "as is")_
@@ -54,15 +92,54 @@ from pydantic_ai import Agent
 
 agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.')
 
-async with agent.run_stream('Tell me a joke.') as result:
-    # incomplete messages before the stream finishes
-    print(result.all_messages())
-    async for text in result.stream():
-        print(text)
-
-    # complete messages once the stream finishes
-    print(result.all_messages())
+async def main():
+    async with agent.run_stream('Tell me a joke.') as result:
+        # incomplete messages before the stream finishes
+        print(result.all_messages())
+        """
+        [
+            SystemPrompt(content='Be a helpful assistant.', role='system'),
+            UserPrompt(
+                content='Tell me a joke.',
+                timestamp=datetime.datetime(
+                    2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+                ),
+                role='user',
+            ),
+        ]
+        """
+
+        async for text in result.stream():
+            print(text)
+            #> Did you
+            #> Did you hear about
+            #> Did you hear about the toothpaste
+            #> Did you hear about the toothpaste scandal? They
+            #> Did you hear about the toothpaste scandal? They called it
+            #> Did you hear about the toothpaste scandal? They called it Colgate.
+
+        # complete messages once the stream finishes
+        print(result.all_messages())
+        """
+        [
+            SystemPrompt(content='Be a helpful assistant.', role='system'),
+            UserPrompt(
+                content='Tell me a joke.',
+                timestamp=datetime.datetime(
+                    2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+                ),
+                role='user',
+            ),
+            ModelTextResponse(
+                content='Did you hear about the toothpaste scandal? They called it Colgate.',
+                timestamp=datetime.datetime(
+                    2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+                ),
+                role='model-text-response',
+            ),
+        ]
+        """
 ```
 
 _(This example is complete, it can be run "as is" inside an async context)_
@@ -92,11 +169,46 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.')
 
 result1 = agent.run_sync('Tell me a joke.')
 print(result1.data)
+#> Did you hear about the toothpaste scandal? They called it Colgate.
 
 result2 = agent.run_sync('Explain?', message_history=result1.new_messages())
 print(result2.data)
+#> This is an excellent joke invented by Samuel Colvin, it needs no explanation.
 
 print(result2.all_messages())
+"""
+[
+    SystemPrompt(content='Be a helpful assistant.', role='system'),
+    UserPrompt(
+        content='Tell me a joke.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='user',
+    ),
+    ModelTextResponse(
+        content='Did you hear about the toothpaste scandal? They called it Colgate.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='model-text-response',
+    ),
+    UserPrompt(
+        content='Explain?',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='user',
+    ),
+    ModelTextResponse(
+        content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='model-text-response',
+    ),
+]
+"""
 ```
 
 _(This example is complete, it can be run "as is")_
@@ -113,11 +225,48 @@ agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.')
 
 result1 = agent.run_sync('Tell me a joke.')
 print(result1.data)
+#> Did you hear about the toothpaste scandal? They called it Colgate.
-result2 = agent.run_sync('Explain?', model='gemini-1.5-pro', message_history=result1.new_messages())
+result2 = agent.run_sync(
+    'Explain?', model='gemini-1.5-pro', message_history=result1.new_messages()
+)
 print(result2.data)
+#> This is an excellent joke invented by Samuel Colvin, it needs no explanation.
 
 print(result2.all_messages())
+"""
+[
+    SystemPrompt(content='Be a helpful assistant.', role='system'),
+    UserPrompt(
+        content='Tell me a joke.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='user',
+    ),
+    ModelTextResponse(
+        content='Did you hear about the toothpaste scandal? They called it Colgate.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='model-text-response',
+    ),
+    UserPrompt(
+        content='Explain?',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='user',
+    ),
+    ModelTextResponse(
+        content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
+        timestamp=datetime.datetime(
+            2032, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc
+        ),
+        role='model-text-response',
+    ),
+]
+"""
 ```
 
 ## Last Run Messages
diff --git a/docs/concepts/result-validation.md b/docs/concepts/results.md
similarity index 100%
rename from docs/concepts/result-validation.md
rename to docs/concepts/results.md
diff --git a/docs/index.md b/docs/index.md
index 5b93a77ce1..ee0d3991f5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -38,6 +38,7 @@ Partial example of using retrievers to help an LLM respond to a user's query abo
 
 ```py title="weather_agent.py"
 import httpx
+
 from pydantic_ai import Agent, CallContext
 
 weather_agent = Agent(  # (1)!
@@ -75,9 +76,9 @@ async def main():
        deps=client,
    )
    print(result.data)  # (9)!
-    # > 'The weather in West London is raining, while in Wiltshire it is sunny.'
+    #> The weather in West London is raining, while in Wiltshire it is sunny.
 
-    print(result.all_messages())  # (10)!
+    messages = result.all_messages()  # (10)!
 ```
 
 1. An agent that can tell users about the weather in a particular location. Agents combine a system prompt, a response type (here `str`) and "retrievers" (aka tools).
@@ -89,7 +90,7 @@ async def main():
 7. Multiple retrievers can be registered with the same agent, the LLM can choose which (if any) retrievers to call in order to respond to a user.
 8. Run the agent asynchronously, conducting a conversation with the LLM until a final response is reached. You can also run agents synchronously with `run_sync`. Internally agents are all async, so `run_sync` is a helper using `asyncio.run` to call `run()`.
 9. The response from the LLM, in this case a `str`. Agents are generic in both the type of `deps` and `result_type`, so calls are typed end-to-end.
-10. `result.all_messages()` includes details of messages exchanged, this is useful both to understand the conversation that took place and useful if you want to continue the conversation later — messages can be passed back to later `run/run_sync` calls.
+10. [`result.all_messages()`](concepts/message-history.md) includes details of messages exchanged; this is useful both to understand the conversation that took place and to continue the conversation later — messages can be passed back to later `run/run_sync` calls.
 
 !!! tip "Complete `weather_agent.py` example"
     This example is incomplete for the sake of brevity; you can find a complete `weather_agent.py` example [here](examples/weather-agent.md).
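To make footnote 10 concrete, here is a minimal, hypothetical sketch — editorial, not part of the diff — of feeding `all_messages()` back into a later run. It assumes the `weather_agent` defined in `weather_agent.py` above is importable, and the follow-up prompt is invented for illustration:

```py
import asyncio

import httpx

from weather_agent import weather_agent  # hypothetical import of the example agent above


async def main() -> None:
    async with httpx.AsyncClient() as client:
        first = await weather_agent.run(
            'What is the weather like in West London and in Wiltshire?',
            deps=client,
        )
        # pass the earlier messages back in so the model keeps the conversation context
        second = await weather_agent.run(
            'And what about in Cambridge?',  # hypothetical follow-up question
            deps=client,
            message_history=first.all_messages(),
        )
        print(second.data)


asyncio.run(main())
```

This uses only the `run(user_prompt, *, message_history=..., deps=...)` signature and `all_messages()` method shown elsewhere in this diff.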
diff --git a/mkdocs.yml b/mkdocs.yml index 27c227126c..2b2c858805 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,11 +14,11 @@ nav: - Introduction: index.md - install.md - Concepts: - - concepts/index.md + - concepts/agents.md - concepts/dependencies.md - concepts/retrievers.md - concepts/system-prompt.md - - concepts/result-validation.md + - concepts/results.md - concepts/message-history.md - concepts/streaming.md - concepts/testing-evals.md @@ -147,7 +147,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - show_source: false heading_level: 2 import: - url: https://docs.python.org/3/objects.inv diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index fb4211b06b..aa81d3d326 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -4,7 +4,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass -from typing import Any, Callable, Generic, Literal, cast, final, overload +from typing import Any, Callable, Generic, cast, final, overload import logfire_api from typing_extensions import assert_never @@ -22,24 +22,7 @@ from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc from .result import ResultData -__all__ = 'Agent', 'KnownModelName' - -KnownModelName = Literal[ - 'openai:gpt-4o', - 'openai:gpt-4o-mini', - 'openai:gpt-4-turbo', - 'openai:gpt-4', - 'openai:o1-preview', - 'openai:o1-mini', - 'openai:gpt-3.5-turbo', - 'gemini-1.5-flash', - 'gemini-1.5-pro', - 'test', -] -"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. - -`KnownModelName` is provided as a concise way to specify a model. -""" +__all__ = ('Agent',) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -49,10 +32,27 @@ @final @dataclass(init=False) class Agent(Generic[AgentDeps, ResultData]): - """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.""" + """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. + + Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps] + and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData]. + + By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + + Minimal usage example: + + ```py + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + result = agent.run_sync('What is the capital of France?') + print(result.data) + #> Paris + ``` + """ # dataclass fields mostly for my sanity — knowing what attributes are available - model: models.Model | KnownModelName | None + model: models.Model | models.KnownModelName | None """The default model configured for this agent.""" _result_schema: _result.ResultSchema[ResultData] | None _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -74,7 +74,7 @@ class Agent(Generic[AgentDeps, ResultData]): def __init__( self, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, result_type: type[ResultData] = str, *, system_prompt: str | Sequence[str] = (), @@ -101,7 +101,7 @@ def __init__( result_tool_name: The name of the tool to use for the final result. result_tool_description: The description of the final result tool. result_retries: The maximum number of retries to allow for result validation, defaults to `retries`. 
- defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model, + defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` to defer the evaluation until the first run. Useful if you want to @@ -132,7 +132,7 @@ async def run( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -201,7 +201,7 @@ def run_sync( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. @@ -225,7 +225,7 @@ async def run_stream( user_prompt: str, *, message_history: list[_messages.Message] | None = None, - model: models.Model | KnownModelName | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -310,7 +310,7 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]: self._override_deps = override_deps_before @contextmanager - def override_model(self, overriding_model: models.Model | KnownModelName) -> Iterator[None]: + def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]: """Context manager to temporarily override the model used by the agent. Args: @@ -411,7 +411,7 @@ def _register_retriever( return retriever async def _get_agent_model( - self, model: models.Model | KnownModelName | None + self, model: models.Model | models.KnownModelName | None ) -> tuple[models.Model, models.Model | None, models.AgentModel]: """Create a model configured for this agent. diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 0378ab4c16..27e8c91202 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Protocol, Union +from typing import TYPE_CHECKING, Literal, Protocol, Union import httpx @@ -19,10 +19,27 @@ if TYPE_CHECKING: from .._utils import ObjectJsonSchema - from ..agent import KnownModelName from ..result import Cost +KnownModelName = Literal[ + 'openai:gpt-4o', + 'openai:gpt-4o-mini', + 'openai:gpt-4-turbo', + 'openai:gpt-4', + 'openai:o1-preview', + 'openai:o1-mini', + 'openai:gpt-3.5-turbo', + 'gemini-1.5-flash', + 'gemini-1.5-pro', + 'test', +] +"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. + +`KnownModelName` is provided as a concise way to specify a model. 
+""" + + class Model(ABC): """Abstract class for a model.""" diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index f81945effb..c0dbcc196d 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -8,11 +8,11 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence +import inspect +from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from itertools import chain from typing import Callable, Union, cast from typing_extensions import TypeAlias, overload @@ -114,11 +114,17 @@ class DeltaToolCall: """A mapping of tool call IDs to incremental changes.""" # TODO these should allow coroutines -FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], ModelAnyResponse] +FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]] """A function used to generate a non-streamed response.""" -StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[Iterable[str], Iterable[DeltaToolCalls]]] -"""A function used to generate a streamed response.""" +StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]] +"""A function used to generate a streamed response. + +While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should +really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`, + +E.g. you need to yield all text or all `DeltaToolCalls`, not mix them. +""" @dataclass @@ -131,38 +137,46 @@ class FunctionAgentModel(AgentModel): async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]: assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' - return self.function(messages, self.agent_info), result.Cost() + if inspect.iscoroutinefunction(self.function): + return await self.function(messages, self.agent_info), result.Cost() + else: + response = await _utils.run_in_executor(self.function, messages, self.agent_info) + return cast(ModelAnyResponse, response), result.Cost() @asynccontextmanager async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]: assert ( self.stream_function is not None ), 'FunctionModel must receive a `stream_function` to support streamed requests' - response_data = iter(self.stream_function(messages, self.agent_info)) + response_stream = self.stream_function(messages, self.agent_info) try: - first = next(response_data) - except StopIteration as e: + first = await response_stream.__anext__() + except StopAsyncIteration as e: raise ValueError('Stream function must return at least one item') from e if isinstance(first, str): - text_stream = cast(Iterable[str], response_data) - yield FunctionStreamTextResponse(iter(chain([first], text_stream))) + text_stream = cast(AsyncIterator[str], response_stream) + yield FunctionStreamTextResponse(first, text_stream) else: - structured_stream = cast(Iterable[DeltaToolCalls], response_data) - # noinspection PyTypeChecker - yield FunctionStreamStructuredResponse(iter(chain([first], structured_stream)), {}) + structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream) + yield FunctionStreamStructuredResponse(first, structured_stream) @dataclass class 
FunctionStreamTextResponse(StreamTextResponse): """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - _iter: Iterator[str] + _next: str | None + _iter: AsyncIterator[str] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: - self._buffer.append(_utils.sync_anext(self._iter)) + if self._next is not None: + self._buffer.append(self._next) + self._next = None + else: + self._buffer.append(await self._iter.__anext__()) def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer @@ -179,12 +193,17 @@ def timestamp(self) -> datetime: class FunctionStreamStructuredResponse(StreamStructuredResponse): """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - _iter: Iterator[DeltaToolCalls] - _delta_tool_calls: dict[int, DeltaToolCall] + _next: DeltaToolCalls | None + _iter: AsyncIterator[DeltaToolCalls] + _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict) _timestamp: datetime = field(default_factory=_utils.now_utc) async def __anext__(self) -> None: - tool_call = _utils.sync_anext(self._iter) + if self._next is not None: + tool_call = self._next + self._next = None + else: + tool_call = await self._iter.__anext__() for key, new in tool_call.items(): if current := self._delta_tool_calls.get(key): diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index 92a3258da6..83804f2287 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from pydantic_ai import Agent -from pydantic_ai.agent import KnownModelName +from pydantic_ai.models import KnownModelName # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') diff --git a/pydantic_ai_examples/stream_markdown.py b/pydantic_ai_examples/stream_markdown.py index ccfee6848e..da57b5f3ee 100644 --- a/pydantic_ai_examples/stream_markdown.py +++ b/pydantic_ai_examples/stream_markdown.py @@ -16,7 +16,7 @@ from rich.text import Text from pydantic_ai import Agent -from pydantic_ai.agent import KnownModelName +from pydantic_ai.models import KnownModelName # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') diff --git a/pyproject.toml b/pyproject.toml index dab9f85a91..d69f9d60a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,16 +57,20 @@ examples = [ [dependency-groups] dev = [ + "anyio>=4.5.0", + "devtools>=0.12.2", + "coverage[toml]>=7.6.2", "dirty-equals>=0.8.0", - "mypy>=1.11.2", - "pyright>=1.1.388", + "inline-snapshot>=0.14", "pytest>=8.3.3", + "pytest-examples>=0.0.14", + "pytest-mock>=3.14.0", "pytest-pretty>=1.2.0", - "inline-snapshot>=0.14", +] +lint = [ + "mypy>=1.11.2", + "pyright>=1.1.388", "ruff>=0.6.9", - "coverage[toml]>=7.6.2", - "devtools>=0.12.2", - "anyio>=4.5.0", ] docs = [ "mkdocs", @@ -115,7 +119,8 @@ ignore = [ convention = "google" [tool.ruff.format] -docstring-code-format = true +# don't format python in docstrings, pytest-examples takes care of it +docstring-code-format = false quote-style = "single" [tool.ruff.lint.per-file-ignores] diff --git a/tests/conftest.py b/tests/conftest.py index 
75156fae02..7592d82fc3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import re import secrets import sys -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator from datetime import datetime from pathlib import Path from types import ModuleType @@ -82,7 +82,7 @@ def allow_model_requests(): @pytest.fixture -async def client_with_handler(): +async def client_with_handler() -> AsyncIterator[ClientWithHandler]: client: httpx.AsyncClient | None = None def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient: diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 0446ab1023..1475a93d4d 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -1,6 +1,6 @@ import json import re -from collections.abc import Iterable +from collections.abc import AsyncIterator from dataclasses import asdict from datetime import timezone @@ -29,7 +29,7 @@ pytestmark = pytest.mark.anyio -def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: +async def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: last = messages[-1] response = asdict(last) response.pop('timestamp', None) @@ -84,7 +84,7 @@ def test_simple(): ) -def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover +async def weather_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: # pragma: no cover assert info.allow_text_result assert info.retrievers.keys() == {'get_location', 'get_weather'} last = messages[-1] @@ -177,7 +177,7 @@ def test_weather(): assert result.data == 'Sunny in Ipswich' -def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: # pragma: no cover +async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: # pragma: no cover last = messages[-1] if last.role == 'user': if last.content.startswith('{'): @@ -221,7 +221,7 @@ def test_var_args(): ) -def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: +async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: if len(messages) == 1: assert len(info.retrievers) == 1 retriever_id = next(iter(info.retrievers.keys())) @@ -308,7 +308,7 @@ def spam() -> str: def test_register_all(): - def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: + async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: return ModelTextResponse( f'messages={len(messages)} allow_text_result={info.allow_text_result} retrievers={len(info.retrievers)}' ) @@ -349,7 +349,7 @@ def test_call_all(): def test_retry_str(): call_count = 0 - def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: + async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: nonlocal call_count call_count += 1 @@ -371,7 +371,7 @@ async def validate_result(r: str) -> str: def test_retry_result_type(): call_count = 0 - def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: + async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: nonlocal call_count call_count += 1 @@ -393,7 +393,7 @@ async def validate_result(r: Foo) -> Foo: assert result.data == snapshot(Foo(x=2)) -def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str]: +async def stream_text_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]: yield 'hello ' yield 
'world' @@ -416,7 +416,9 @@ class Foo(BaseModel): async def test_stream_structure(): - def stream_structured_function(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]: + async def stream_structured_function( + _messages: list[Message], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls]: assert agent_info.result_tools is not None assert len(agent_info.result_tools) == 1 name = agent_info.result_tools[0].name @@ -439,8 +441,13 @@ async def test_pass_both(): Agent(FunctionModel(return_last, stream_function=stream_text_function)) +async def stream_text_function_empty(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]: + if False: + yield 'hello ' + + async def test_return_empty(): - agent = Agent(FunctionModel(stream_function=lambda _, __: [])) + agent = Agent(FunctionModel(stream_function=stream_text_function_empty)) with pytest.raises(ValueError, match='Stream function must return at least one item'): async with agent.run_stream(''): pass diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 0000000000..972d46a6c6 --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,109 @@ +from __future__ import annotations as _annotations + +import asyncio +import sys +from collections.abc import AsyncIterator, Iterable +from datetime import datetime +from types import ModuleType +from typing import Any + +import httpx +import pytest +from devtools import debug +from pytest_examples import CodeExample, EvalExample, find_examples +from pytest_mock import MockerFixture + +from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse +from pydantic_ai.models import KnownModelName, Model +from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel +from tests.conftest import ClientWithHandler + + +def find_filter_examples() -> Iterable[CodeExample]: + for ex in find_examples('docs', 'pydantic_ai'): + if ex.path.name != '_utils.py': + yield ex + + +@pytest.mark.parametrize('example', find_filter_examples(), ids=str) +def test_docs_examples( + example: CodeExample, eval_example: EvalExample, mocker: MockerFixture, client_with_handler: ClientWithHandler +): + if example.path.name == '_utils.py': + return + # debug(example) + mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) + mocker.patch('pydantic_ai._utils.datetime', MockedDatetime) + + mocker.patch('httpx.Client.get', side_effect=http_request) + mocker.patch('httpx.Client.post', side_effect=http_request) + mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request) + mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request) + + ruff_ignore: list[str] = ['D'] + if str(example.path).endswith('docs/index.md'): + ruff_ignore.append('F841') + eval_example.set_config(ruff_ignore=ruff_ignore) + + call_name = 'main' + if 'def test_application_code' in example.source: + call_name = 'test_application_code' + + if eval_example.update_examples: + eval_example.format(example) + module_dict = eval_example.run_print_update(example, call=call_name) + else: + eval_example.lint(example) + module_dict = eval_example.run_print_check(example, call=call_name) + + if example.path.name == 'dependencies.md' and 'title="joke_app.py"' in example.prefix: + sys.modules['joke_app'] = module = ModuleType('joke_app') + module.__dict__.update(module_dict) + + +def http_request(url: str, **kwargs: Any) -> httpx.Response: + # sys.stdout.write(f'GET {args=} {kwargs=}\n') + request = httpx.Request('GET', url, **kwargs) + 
return httpx.Response(status_code=202, content='', request=request)
+
+
+async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
+    return http_request(url, **kwargs)
+
+
+async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
+    m = messages[-1]
+    if m.role == 'user' and m.content == 'What is the weather like in West London and in Wiltshire?':
+        return ModelTextResponse(content='The weather in West London is raining, while in Wiltshire it is sunny.')
+    if m.role == 'user' and m.content == 'Tell me a joke.':
+        return ModelTextResponse(content='Did you hear about the toothpaste scandal? They called it Colgate.')
+    if m.role == 'user' and m.content == 'Explain?':
+        return ModelTextResponse(content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.')
+    if m.role == 'user' and m.content == 'What is the capital of France?':
+        return ModelTextResponse(content='Paris')
+    else:
+        sys.stdout.write(str(debug.format(messages, info)))
+        raise RuntimeError(f'Unexpected message: {m}')
+
+
+async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]:
+    m = messages[-1]
+    if m.role == 'user' and m.content == 'Tell me a joke.':
+        *words, last_word = 'Did you hear about the toothpaste scandal? They called it Colgate.'.split(' ')
+        for word in words:
+            yield f'{word} '
+            await asyncio.sleep(0.05)
+        yield last_word
+    else:
+        sys.stdout.write(str(debug.format(messages, info)))
+        raise RuntimeError(f'Unexpected message: {m}')
+
+
+def mock_infer_model(_model: Model | KnownModelName) -> Model:
+    return FunctionModel(model_logic, stream_function=stream_model_logic)
+
+
+class MockedDatetime(datetime):
+    @classmethod
+    def now(cls, tz: Any = None) -> datetime:  # type: ignore
+        return datetime(2032, 1, 2, 3, 4, 5, 6, tzinfo=tz)
diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py
index 465956e705..11d9d5fe46 100644
--- a/tests/test_retrievers.py
+++ b/tests/test_retrievers.py
@@ -67,7 +67,7 @@ async def google_style_docstring(foo: int, bar: str) -> str:  # pragma: no cover
     return f'{foo} {bar}'
 
-def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
+async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
     assert len(info.retrievers) == 1
     r = next(iter(info.retrievers.values()))
     return ModelTextResponse(json.dumps(r.json_schema))
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index cb8a3385e2..4f60397158 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import json
-from collections.abc import Iterable
+from collections.abc import AsyncIterator
 from datetime import timezone
 
 import pytest
@@ -81,7 +81,7 @@ async def test_streamed_structured_response():
 
 async def test_structured_response_iter():
-    def text_stream(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]:
+    async def text_stream(_messages: list[Message], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
         assert agent_info.result_tools is not None
         assert len(agent_info.result_tools) == 1
         name = agent_info.result_tools[0].name
@@ -145,11 +145,12 @@ async def test_streamed_text_stream():
 async def test_plain_response():
     call_index = 0
 
-    def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]:
+    async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]:
         nonlocal call_index
         call_index += 1
-        return 
['hello ', 'world'] + yield 'hello ' + yield 'world' agent = Agent(FunctionModel(stream_function=text_stream), result_type=tuple[str, str]) @@ -161,9 +162,9 @@ def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]: async def test_call_retriever(): - def stream_structured_function( + async def stream_structured_function( messages: list[Message], agent_info: AgentInfo - ) -> Iterable[DeltaToolCalls] | Iterable[str]: + ) -> AsyncIterator[DeltaToolCalls | str]: if len(messages) == 1: assert agent_info.retrievers is not None assert len(agent_info.retrievers) == 1 @@ -226,7 +227,7 @@ async def ret_a(x: str) -> str: async def test_call_retriever_empty(): - def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {} agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) @@ -237,7 +238,7 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab async def test_call_retriever_wrong_name(): - def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {0: DeltaToolCall(name='foobar', json_args='{}')} agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) diff --git a/uv.lock b/uv.lock index db6f810631..755a4ad682 100644 --- a/uv.lock +++ b/uv.lock @@ -1447,17 +1447,21 @@ dev = [ { name = "devtools" }, { name = "dirty-equals" }, { name = "inline-snapshot" }, - { name = "mypy" }, - { name = "pyright" }, { name = "pytest" }, + { name = "pytest-examples" }, + { name = "pytest-mock" }, { name = "pytest-pretty" }, - { name = "ruff" }, ] docs = [ { name = "mkdocs" }, { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] +lint = [ + { name = "mypy" }, + { name = "pyright" }, + { name = "ruff" }, +] [package.metadata] requires-dist = [ @@ -1483,17 +1487,21 @@ dev = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.14" }, - { name = "mypy", specifier = ">=1.11.2" }, - { name = "pyright", specifier = ">=1.1.388" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-examples", specifier = ">=0.0.14" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, - { name = "ruff", specifier = ">=0.6.9" }, ] docs = [ { name = "mkdocs" }, { name = "mkdocs-material", extras = ["imaging"] }, { name = "mkdocstrings-python" }, ] +lint = [ + { name = "mypy", specifier = ">=1.11.2" }, + { name = "pyright", specifier = ">=1.1.388" }, + { name = "ruff", specifier = ">=0.6.9" }, +] [[package]] name = "pydantic-core" @@ -1644,6 +1652,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-examples" +version = "0.0.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "black" }, + { name = "pytest" }, + { name = "ruff" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d2/a7/b81d5cf26e9713a2d4c8e6863ee009360c5c07a0cfb880456ec8b09adab7/pytest_examples-0.0.14.tar.gz", hash = "sha256:776d1910709c0c5ce01b29bfe3651c5312d5cfe5c063e23ca6f65aed9af23f09", size = 20767 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/99/f418071551ff2b5e8c06bd8b82b1f4fd472b5e4162f018773ba4ef52b6e8/pytest_examples-0.0.14-py3-none-any.whl", hash = "sha256:867a7ea105635d395df712a4b8d0df3bda4c3d78ae97a57b4f115721952b5e25", size = 17919 }, +] + +[[package]] +name = "pytest-mock" +version = "3.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/90/a955c3ab35ccd41ad4de556596fa86685bf4fc5ffcc62d22d856cfd4e29a/pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0", size = 32814 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/3b/b26f90f74e2986a82df6e7ac7e319b8ea7ccece1caec9f8ab6104dc70603/pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f", size = 9863 }, +] + [[package]] name = "pytest-pretty" version = "1.2.0"