diff --git a/docs/agents.md b/docs/agents.md index c723dff546..1805510f19 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -66,7 +66,7 @@ There are three ways to run an agent: Here's a simple example demonstrating all three: -```python title="run_agent.py" +```py title="run_agent.py" from pydantic_ai import Agent agent = Agent('openai:gpt-4o') @@ -95,7 +95,7 @@ An agent **run** might represent an entire conversation — there's no limit to Here's an example of a conversation comprised of multiple runs: -```python title="conversation_example.py" hl_lines="13" +```py title="conversation_example.py" hl_lines="13" from pydantic_ai import Agent agent = Agent('openai:gpt-4o') @@ -131,7 +131,7 @@ You can add both to a single agent; they're concatenated in the order they're de Here's an example using both types of system prompts: -```python title="system_prompts.py" +```py title="system_prompts.py" from datetime import date from pydantic_ai import Agent, CallContext @@ -228,7 +228,7 @@ _(This example is complete, it can be run "as is")_ Let's print the messages from that game to see what happened: -```python title="dice_game_messages.py" +```py title="dice_game_messages.py" from dice_game import dice_result print(dice_result.all_messages()) @@ -417,7 +417,7 @@ If models behave unexpectedly (e.g., the retry limit is exceeded, or their API r In these cases, [`agent.last_run_messages`][pydantic_ai.Agent.last_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue. -```python +```py from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehaviour agent = Agent('openai:gpt-4o') diff --git a/docs/dependencies.md b/docs/dependencies.md index 79b40f6c26..c1f7aa1b2c 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -55,7 +55,7 @@ _(This example is complete, it can be run "as is")_ Dependencies are accessed through the [`CallContext`][pydantic_ai.dependencies.CallContext] type, this should be the first parameter of system prompt functions etc. -```python title="system_prompt_dependencies.py" hl_lines="20-27" +```py title="system_prompt_dependencies.py" hl_lines="20-27" from dataclasses import dataclass import httpx @@ -113,7 +113,7 @@ to use `async` methods where dependencies perform IO, although synchronous depen Here's the same example as above, but with a synchronous dependency: -```python title="sync_dependencies.py" +```py title="sync_dependencies.py" from dataclasses import dataclass import httpx @@ -161,7 +161,7 @@ _(This example is complete, it can be run "as is")_ As well as system prompts, dependencies can be used in [retrievers](agents.md#retrievers) and [result validators](results.md#result-validators-functions). -```python title="full_example.py" hl_lines="27-35 38-48" +```py title="full_example.py" hl_lines="27-35 38-48" from dataclasses import dataclass import httpx diff --git a/docs/index.md b/docs/index.md index e938e01e90..406485c422 100644 --- a/docs/index.md +++ b/docs/index.md @@ -136,7 +136,7 @@ async def main(): 10. The result will be validated with Pydantic to guarantee it is a `SupportResult`, since the agent is generic, it'll also be typed as a `SupportResult` to aid with static type checking. 11. In a real use case, you'd add many more retrievers and a longer system prompt to the agent to extend the context it's equipped with and support it can provide. 12. This is a simple sketch of a database connection, used to keep the example short and readable. In reality, you'd be connecting to an external database (e.g. 
PostgreSQL) to get information about customers. -13. This [Pydantic](https://docs.pydantic.dev) model is used to constrain the structured data returned by the agent. From this simple definition, Pydantic builds teh JSON Schema that tells the LLM how to return the data, and performs validation to guarantee the data is correct at the end of the conversation. +13. This [Pydantic](https://docs.pydantic.dev) model is used to constrain the structured data returned by the agent. From this simple definition, Pydantic builds the JSON Schema that tells the LLM how to return the data, and performs validation to guarantee the data is correct at the end of the conversation. !!! tip "Complete `bank_support.py` example" This example is incomplete for the sake of brevity (the definition of `DatabaseConn` is missing); you can find a complete `bank_support.py` example [here](examples/bank-support.md). diff --git a/docs/message-history.md b/docs/message-history.md index 46483da0fb..6cbac2b6ff 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -25,7 +25,7 @@ and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`A Example of accessing methods on a [`RunResult`][pydantic_ai.result.RunResult] : -```python title="run_result_messages.py" hl_lines="10 28" +```py title="run_result_messages.py" hl_lines="10 28" from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') @@ -73,7 +73,7 @@ _(This example is complete, it can be run "as is")_ Example of accessing methods on a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] : -```python title="streamed_run_result_messages.py" hl_lines="9 31" +```py title="streamed_run_result_messages.py" hl_lines="9 31" from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') @@ -96,11 +96,9 @@ async def main(): async for text in result.stream(): print(text) - #> Did you - #> Did you hear about + #> Did you hear #> Did you hear about the toothpaste - #> Did you hear about the toothpaste scandal? They - #> Did you hear about the toothpaste scandal? They called it + #> Did you hear about the toothpaste scandal? They called #> Did you hear about the toothpaste scandal? They called it Colgate. # complete messages once the stream finishes @@ -190,7 +188,7 @@ Since messages are defined by simple dataclasses, you can manually create and ma The message format is independent of the model used, so you can use messages in different agents, or the same agent with different models. -```python +```py from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') diff --git a/docs/results.md b/docs/results.md index d4b1610d3e..0c0a217677 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,41 +1,311 @@ -## Ending runs +Results are the final values returned from [running an agent](agents.md#running-agents). +The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [cost][pydantic_ai.result.Cost] of the run and [message history](message-history.md#accessing-messages-from-results) -**TODO** +Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. 
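+Because these wrappers are generic in the result data, static type checkers see `result.data` with the declared result type. Here's a minimal sketch of that (the `bool` result type and the prompt are just for illustration, they aren't part of the examples below):
+
+```py
+from pydantic_ai import Agent
+
+agent = Agent('openai:gpt-4o', result_type=bool)
+
+result = agent.run_sync('Is 7 a prime number?')
+data = result.data  # a type checker infers `bool` for `data` here
+```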
-* runs end when either a plain text response is received or the model calls a tool associated with one of the structured result types
-* example
-* we should add `message_limit` (number of model messages) and `cost_limit` to `run()` etc.
+```py title="olympics.py"
+from pydantic import BaseModel

-## Structured result validation
+from pydantic_ai import Agent

-**TODO**

-* structured results (like retrievers) use Pydantic, Pydantic builds the JSON schema and does the validation
-* PydanticAI tries hard to simplify the schema, this means:
-  * if the return type is `str` or a union including `str`, plain text responses are enabled
-  * if the schema is a union (after remove `str` from the members), each member is registered as its own tool call
-  * if the schema is not an object, the result type is wrapped in a single element object
+class CityLocation(BaseModel):
+    city: str
+    country: str

-## Result validators functions

-**TODO**
+agent = Agent('gemini-1.5-flash', result_type=CityLocation)
+result = agent.run_sync('Where the olympics held in 2012?')
+print(result.data)
+#> city='London' country='United Kingdom'
+print(result.cost())
+#> Cost(request_tokens=56, response_tokens=8, total_tokens=64, details=None)
+```

-* Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.result_validator`][pydantic_ai.Agent.result_validator] decorator.
-* example
+_(This example is complete, it can be run "as is")_
+
+Runs end when either a plain text response is received or the model calls a tool associated with one of the structured result types. We will add limits to make sure a run doesn't go on indefinitely, see [#70](https://github.com/pydantic/pydantic-ai/issues/70).
+
+## Result data {#structured-result-validation}
+
+When the result type is `str`, or a union including `str`, plain text responses are enabled on the model, and the raw text response from the model is used as the response data.
+
+If the result type is a union with multiple members (after removing `str` from the members), each member is registered as a separate tool with the model in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly.
+
+If the result type schema is not of type `"object"`, the result type is wrapped in a single element object, so the schemas of all tools registered with the model are object schemas.
+
+Structured results (like retrievers) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model.
+
+!!! note "Bring on PEP-747"
+    Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, unions are not valid as `type`s in Python.
+
+    When creating the agent we need to `# type: ignore` the `result_type` argument, and add a type hint to tell type checkers about the type of the agent.
+
+Here's an example of returning either text or a structured value:
+
+```py title="box_or_error.py"
+from typing import Union
+
+from pydantic import BaseModel
+
+from pydantic_ai import Agent
+
+
+class Box(BaseModel):
+    width: int
+    height: int
+    depth: int
+    units: str
+
+
+agent: Agent[None, Union[Box, str]] = Agent(
+    'openai:gpt-4o-mini',
+    result_type=Union[Box, str],  # type: ignore
+    system_prompt=(
+        "Extract me the dimensions of a box, "
+        "if you can't extract all data, ask the user to try again."
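+        # the two adjacent string literals above are implicitly concatenated into one prompt string;
+        # keeping `str` in the result union is what lets the model reply in plain text when data is missing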
+    ),
+)
+
+result = agent.run_sync('The box is 10x20x30')
+print(result.data)
+#> Please provide the units for the dimensions (e.g., cm, in, m).
+
+result = agent.run_sync('The box is 10x20x30 cm')
+print(result.data)
+#> width=10 height=20 depth=30 units='cm'
+```
+
+_(This example is complete, it can be run "as is")_
+
+Here's an example of using a union return type which registers multiple tools and wraps non-object schemas in an object:
+
+```py title="colors_or_sizes.py"
+from typing import Union
+
+from pydantic_ai import Agent
+
+agent: Agent[None, Union[list[str], list[int]]] = Agent(
+    'openai:gpt-4o-mini',
+    result_type=Union[list[str], list[int]],  # type: ignore
+    system_prompt='Extract either colors or sizes from the shapes provided.',
+)
+
+result = agent.run_sync('red square, blue circle, green triangle')
+print(result.data)
+#> ['red', 'blue', 'green']
+
+result = agent.run_sync('square size 10, circle size 20, triangle size 30')
+print(result.data)
+#> [10, 20, 30]
+```
+
+_(This example is complete, it can be run "as is")_
+
+### Result validators functions
+
+Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.result_validator`][pydantic_ai.Agent.result_validator] decorator.
+
+Here's a simplified variant of the [SQL Generation example](examples/sql-gen.md):
+
+```py title="sql_gen.py"
+from typing import Union
+
+from fake_database import DatabaseConn, QueryError
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, CallContext, ModelRetry
+
+
+class Success(BaseModel):
+    sql_query: str
+
+
+class InvalidRequest(BaseModel):
+    error_message: str
+
+
+Response = Union[Success, InvalidRequest]
+agent: Agent[DatabaseConn, Response] = Agent(
+    'gemini-1.5-flash',
+    result_type=Response,  # type: ignore
+    deps_type=DatabaseConn,
+    system_prompt='Generate PostgreSQL flavored SQL queries based on user input.',
+)
+
+
+@agent.result_validator
+async def validate_result(ctx: CallContext[DatabaseConn], result: Response) -> Response:
+    if isinstance(result, InvalidRequest):
+        return result
+    try:
+        await ctx.deps.execute(f'EXPLAIN {result.sql_query}')
+    except QueryError as e:
+        raise ModelRetry(f'Invalid query: {e}') from e
+    else:
+        return result
+
+
+result = agent.run_sync(
+    'get me uses who were last active yesterday.', deps=DatabaseConn()
+)
+print(result.data)
+#> sql_query='SELECT * FROM users WHERE last_active::date = today() - interval 1 day'
+```
+
+_(This example is complete, it can be run "as is")_

## Streamed Results

-**TODO**
+There are two main challenges with streamed results:
+
+1. Validating structured responses before they're complete; this is achieved by "partial validation", which was recently added to Pydantic in [pydantic/pydantic#10748](https://github.com/pydantic/pydantic/pull/10748).
+2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a retriever call or a result, then streams the whole thing and calls retrievers, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].
+
+### Streaming Text
+
+Example of streamed text result:
+
+```py title="streamed_hello_world.py"
+from pydantic_ai import Agent
+
+agent = Agent('gemini-1.5-flash')  # (1)!
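+# `main` below is a coroutine; outside the docs test harness you'd run it yourself,
+# e.g. with `asyncio.run(main())` (that call isn't part of the original example)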
+
+
+async def main():
+    async with agent.run_stream('Where does "hello world" come from?') as result:  # (2)!
+        async for message in result.stream():  # (3)!
+            print(message)
+            #> The first known
+            #> The first known use of "hello,
+            #> The first known use of "hello, world" was in
+            #> The first known use of "hello, world" was in a 1974 textbook
+            #> The first known use of "hello, world" was in a 1974 textbook about the C
+            #> The first known use of "hello, world" was in a 1974 textbook about the C programming language.
+```
+
+1. Streaming works with the standard [`Agent`][pydantic_ai.Agent] class and doesn't require any special setup, just a model that supports streaming (currently all models support streaming).
+2. The [`Agent.run_stream()`][pydantic_ai.Agent.run_stream] method is used to start a streamed run; it returns a context manager so the connection can be closed when the stream completes.
+3. Each item yielded by [`StreamedRunResult.stream()`][pydantic_ai.result.StreamedRunResult.stream] is the complete text response, extended as new data is received.
+
+_(This example is complete, it can be run "as is")_
+
+We can also stream text as deltas rather than the entire text in each item:
+
+```py title="streamed_delta_hello_world.py"
+from pydantic_ai import Agent
+
+agent = Agent('gemini-1.5-flash')
+
+
+async def main():
+    async with agent.run_stream('Where does "hello world" come from?') as result:
+        async for message in result.stream_text(delta=True):  # (1)!
+            print(message)
+            #> The first known
+            #> use of "hello,
+            #> world" was in
+            #> a 1974 textbook
+            #> about the C
+            #> programming language.
+```
+
+1. [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text] will error if the response is not text.
+
+_(This example is complete, it can be run "as is")_
+
+### Streaming Structured Responses
+
+Not all types are supported with partial validation in Pydantic (see [pydantic/pydantic#10748](https://github.com/pydantic/pydantic/pull/10748)); generally, for model-like structures it's currently best to use `TypedDict`.
+
+Here's an example of streaming a user profile as it's built:
+
+```py title="streamed_user_profile.py"
+from datetime import date
+
+from typing_extensions import TypedDict
+
+from pydantic_ai import Agent
+
+
+class UserProfile(TypedDict, total=False):
+    name: str
+    dob: date
+    bio: str
+
+
+agent = Agent(
+    'openai:gpt-4o',
+    result_type=UserProfile,
+    system_prompt='Extract a user profile from the input',
+)
+
+
+async def main():
+    user_input = 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.'
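+    # each iteration below yields the profile re-validated (partially) from everything
+    # received so far, so early items may be missing some `UserProfile` fields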
+ async with agent.run_stream(user_input) as result: + async for profile in result.stream(): + print(profile) + #> {'name': 'Ben'} + #> {'name': 'Ben'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the '} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyr'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'} +``` + +_(This example is complete, it can be run "as is")_ + +If you want fine-grained control of validation, particularly catching validation errors, you can use the following pattern: + +```py title="streamed_user_profile.py" +from datetime import date + +from pydantic import ValidationError +from typing_extensions import TypedDict + +from pydantic_ai import Agent + + +class UserProfile(TypedDict, total=False): + name: str + dob: date + bio: str + + +agent = Agent('openai:gpt-4o', result_type=UserProfile) + + +async def main(): + user_input = 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.' + async with agent.run_stream(user_input) as result: + async for message, last in result.stream_structured(debounce_by=0.01): # (1)! + try: + profile = await result.validate_structured_result( # (2)! + message, + allow_partial=not last, + ) + except ValidationError: + continue + print(profile) + #> {'name': 'Ben'} + #> {'name': 'Ben'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the '} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyr'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'} + #> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'} +``` + +1. [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] streams the data as [`ModelStructuredResponse`][pydantic_ai.messages.ModelStructuredResponse] objects, thus iteration can't fail with a `ValidationError`. +2. [`validate_structured_result`][pydantic_ai.result.StreamedRunResult.validate_structured_result] validates the data, `allow_partial=True` enables pydantic's [`experimental_allow_partial` flag on `TypeAdapter`][pydantic.type_adapter.TypeAdapter.validate_json]. 
-Streamed responses provide a unique challenge: -* validating the partial result is both practically and semantically complex, but pydantic can do this -* we don't know if a result will be the final result of a run until we start streaming it, so PydanticAI has to start streaming just enough of the response to sniff out if it's the final response, then either stream the rest of the response to call a retriever, or return an object that lets the rest of the response be streamed by the user -* examples including: streaming text, streaming validated data, streaming the raw data to do validation inside a try/except block when necessary -* explanation of how streamed responses are "debounced" +_(This example is complete, it can be run "as is")_ -## Cost +## Examples -**TODO** +The following examples demonstrate how to use streamed responses in PydanticAI: -* counts tokens, not dollars -* example +- [Stream markdown](examples/stream-markdown.md) +- [Stream Whales](examples/stream-whales.md) diff --git a/mkdocs.yml b/mkdocs.yml index 72d335cc44..3a81abc767 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -151,7 +151,8 @@ plugins: heading_level: 3 import: - url: https://docs.python.org/3/objects.inv - - url: https://docs.pydantic.dev/latest/objects.inv + # TODO use /latest/ not /dev/ when Pydantic 2.10 is released + - url: https://docs.pydantic.dev/dev/objects.inv - url: https://fastapi.tiangolo.com/objects.inv - url: https://typing-extensions.readthedocs.io/en/latest/objects.inv - url: https://rich.readthedocs.io/en/stable/objects.inv diff --git a/pydantic_ai/_result.py b/pydantic_ai/_result.py index 4e64495061..ac4512975d 100644 --- a/pydantic_ai/_result.py +++ b/pydantic_ai/_result.py @@ -101,8 +101,10 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat tools: dict[str, ResultTool[ResultData]] = {} if args := get_union_args(response_type): - for arg in args: + for i, arg in enumerate(args, start=1): tool_name = union_tool_name(name, arg) + while tool_name in tools: + tool_name = f'{tool_name}_{i}' tools[tool_name] = _build_tool(arg, tool_name, True) else: tools[name] = _build_tool(response_type, name, False) diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index c0dbcc196d..7d6fadd826 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -9,16 +9,19 @@ from __future__ import annotations as _annotations import inspect +import re from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime +from itertools import chain from typing import Callable, Union, cast -from typing_extensions import TypeAlias, overload +import pydantic_core +from typing_extensions import TypeAlias, assert_never, overload from .. import _utils, result -from ..messages import Message, ModelAnyResponse, ModelStructuredResponse, ToolCall +from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall from . 
import ( AbstractToolDefinition, AgentModel, @@ -113,7 +116,6 @@ class DeltaToolCall: DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] """A mapping of tool call IDs to incremental changes.""" -# TODO these should allow coroutines FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]] """A function used to generate a non-streamed response.""" @@ -138,10 +140,12 @@ class FunctionAgentModel(AgentModel): async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]: assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' if inspect.iscoroutinefunction(self.function): - return await self.function(messages, self.agent_info), result.Cost() + response = await self.function(messages, self.agent_info) else: - response = await _utils.run_in_executor(self.function, messages, self.agent_info) - return cast(ModelAnyResponse, response), result.Cost() + response_ = await _utils.run_in_executor(self.function, messages, self.agent_info) + response = cast(ModelAnyResponse, response_) + # TODO is `messages` right here? Should it just be new messages? + return response, _estimate_cost(chain(messages, [response])) @asynccontextmanager async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]: @@ -225,3 +229,40 @@ def cost(self) -> result.Cost: def timestamp(self) -> datetime: return self._timestamp + + +def _estimate_cost(messages: Iterable[Message]) -> result.Cost: + """Very rough guesstimate of the number of tokens associate with a series of messages. + + This is designed to be used solely to give plausible numbers for testing! + """ + # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯ + + request_tokens = 50 + response_tokens = 0 + for message in messages: + if message.role == 'system' or message.role == 'user': + request_tokens += _string_cost(message.content) + elif message.role == 'tool-return': + request_tokens += _string_cost(message.model_response_str()) + elif message.role == 'retry-prompt': + request_tokens += _string_cost(message.model_response()) + elif message.role == 'model-text-response': + response_tokens += _string_cost(message.content) + elif message.role == 'model-structured-response': + for call in message.calls: + if isinstance(call.args, ArgsJson): + args_str = call.args.args_json + else: + args_str = pydantic_core.to_json(call.args.args_object).decode() + + response_tokens += 1 + _string_cost(args_str) + else: + assert_never(message) + return result.Cost( + request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens + ) + + +def _string_cost(content: str) -> int: + return len(re.split(r'[\s",.:]+', content)) diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 5e2d0f26e3..3d48588461 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -60,7 +60,7 @@ class TestModel(Model): call_retrievers: list[str] | Literal['all'] = 'all' """List of retrievers to call. 
If `'all'`, all retrievers will be called.""" custom_result_text: str | None = None - """If set, this text is return as teh final result.""" + """If set, this text is return as the final result.""" custom_result_args: Any | None = None """If set, these args will be passed to the result tool.""" seed: int = 0 diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 73b327ed3a..5f68259ada 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -132,22 +132,14 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes. """ - async def stream(self, *, text_delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]: + async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]: """Stream the response as an async iterable. - Result validators are called on each iteration, if `text_delta=False` (the default) or for structured - responses. - - !!! note - Result validators will NOT be called on the text result if `text_delta=True`. - The pydantic validator for structured data will be called in [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation) on each iteration. Args: - text_delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text - up to the current point. debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. @@ -156,22 +148,24 @@ async def stream(self, *, text_delta: bool = False, debounce_by: float | None = An async iterable of the response data. """ if isinstance(self._stream_response, models.StreamTextResponse): - async for text in self.stream_text(text_delta=text_delta, debounce_by=debounce_by): + async for text in self.stream_text(debounce_by=debounce_by): yield cast(ResultData, text) else: - assert not text_delta, 'Cannot use `text_delta=True` for structured responses' async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): yield await self.validate_structured_result(structured_message, allow_partial=not is_last) - async def stream_text(self, *, text_delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: + async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. !!! note This method will fail if the response is structured, e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `True`. + !!! note + Result validators will NOT be called on the text result if `delta=True`. + Args: - text_delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text + delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text up to the current point. debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. 
Debouncing is particularly important for long structured responses to reduce the overhead of @@ -180,7 +174,7 @@ async def stream_text(self, *, text_delta: bool = False, debounce_by: float | No with _logfire.span('response stream text') as lf_span: if isinstance(self._stream_response, models.StreamStructuredResponse): raise exceptions.UserError('stream_text() can only be used with text responses') - if text_delta: + if delta: async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _ in group_iter: yield ''.join(self._stream_response.get()) diff --git a/pydantic_ai_examples/sql_gen.py b/pydantic_ai_examples/sql_gen.py index 34d9ec39fd..f5c4e28aa5 100644 --- a/pydantic_ai_examples/sql_gen.py +++ b/pydantic_ai_examples/sql_gen.py @@ -77,6 +77,7 @@ class InvalidRequest(BaseModel): 'gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else result_type=Response, # type: ignore + deps_type=Deps, ) diff --git a/pydantic_ai_examples/stream_markdown.py b/pydantic_ai_examples/stream_markdown.py index da57b5f3ee..3adea34a02 100644 --- a/pydantic_ai_examples/stream_markdown.py +++ b/pydantic_ai_examples/stream_markdown.py @@ -39,10 +39,10 @@ async def main(): if env_var in os.environ: console.log(f'Using model: {model}') with Live('', console=console, vertical_overflow='visible') as live: - async with agent.run_stream(prompt, model=model) as response: - async for message in response.stream(): + async with agent.run_stream(prompt, model=model) as result: + async for message in result.stream(): live.update(Markdown(message)) - console.log(response.cost()) + console.log(result.cost()) else: console.log(f'{model} requires {env_var} to be set.') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 7af5e9f48a..9b5dfc7c59 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -561,7 +561,7 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) async with agent.run_stream('Hello') as result: - chunks = [chunk async for chunk in result.stream(text_delta=True, debounce_by=None)] + chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) diff --git a/tests/test_examples.py b/tests/test_examples.py index 1ff0ed4eed..b7e874ab7c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,6 +1,5 @@ from __future__ import annotations as _annotations -import asyncio import re import sys from collections.abc import AsyncIterator, Iterable @@ -9,11 +8,13 @@ from typing import Any import httpx +import pydantic_core import pytest from devtools import debug from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture +from pydantic_ai._utils import group_by_temporal from pydantic_ai.messages import ( ArgsObject, Message, @@ -23,7 +24,7 @@ ToolCall, ) from pydantic_ai.models import KnownModelName, Model -from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel +from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from tests.conftest import ClientWithHandler @@ -39,9 +40,15 @@ def get(self, name: str) -> int | None: class DatabaseConn: users: FakeTable = 
field(default_factory=FakeTable) + async def execute(self, query: str) -> None: + pass + + class QueryError(RuntimeError): + pass + module_name = 'fake_database' sys.modules[module_name] = module = ModuleType(module_name) - module.__dict__.update({'DatabaseConn': DatabaseConn}) + module.__dict__.update({'DatabaseConn': DatabaseConn, 'QueryError': QueryError}) yield @@ -84,6 +91,7 @@ def test_docs_examples( ): # debug(example) mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model) + mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal) mocker.patch('httpx.Client.get', side_effect=http_request) mocker.patch('httpx.Client.post', side_effect=http_request) @@ -99,7 +107,11 @@ def test_docs_examples( if 'from bank_database import DatabaseConn' in example.source: ruff_ignore.append('I001') - eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39') + line_length = 88 + if prefix_settings.get('title') in ('streamed_hello_world.py', 'streamed_user_profile.py'): + line_length = 120 + + eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length) eval_example.print_callback = print_callback @@ -122,7 +134,8 @@ def test_docs_examples( def print_callback(s: str) -> str: - return re.sub(r'datetime.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL) + s = re.sub(r'datetime\.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL) + return re.sub(r'datetime.date\(', 'date(', s) def http_request(url: str, **kwargs: Any) -> httpx.Response: @@ -171,6 +184,37 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: } ), ), + 'Where the olympics held in 2012?': ToolCall( + tool_name='final_result', + args=ArgsObject({'city': 'London', 'country': 'United Kingdom'}), + ), + 'The box is 10x20x30': 'Please provide the units for the dimensions (e.g., cm, in, m).', + 'The box is 10x20x30 cm': ToolCall( + tool_name='final_result', + args=ArgsObject({'width': 10, 'height': 20, 'depth': 30, 'units': 'cm'}), + ), + 'red square, blue circle, green triangle': ToolCall( + tool_name='final_result_list', + args=ArgsObject({'response': ['red', 'blue', 'green']}), + ), + 'square size 10, circle size 20, triangle size 30': ToolCall( + tool_name='final_result_list_2', + args=ArgsObject({'response': [10, 20, 30]}), + ), + 'get me uses who were last active yesterday.': ToolCall( + tool_name='final_result_Success', + args=ArgsObject({'sql_query': 'SELECT * FROM users WHERE last_active::date = today() - interval 1 day'}), + ), + 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.': ToolCall( + tool_name='final_result', + args=ArgsObject( + { + 'name': 'Ben', + 'dob': '1990-01-28', + 'bio': 'Likes the chain the dog and the pyramid', + } + ), + ), } @@ -219,14 +263,27 @@ async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncI if m.role == 'user': if response := text_responses.get(m.content): if isinstance(response, str): - *words, last_word = response.split(' ') + words = response.split(' ') + chunk: list[str] = [] for work in words: - yield f'{work} ' - await asyncio.sleep(0.05) - yield last_word + chunk.append(work) + if len(chunk) == 3: + yield ' '.join(chunk) + ' ' + chunk.clear() + if chunk: + yield ' '.join(chunk) return else: - raise NotImplementedError('todo') + if isinstance(response.args, ArgsObject): + json_text = pydantic_core.to_json(response.args.args_object).decode() + else: + json_text = 
response.args.args_json + + yield {1: DeltaToolCall(name=response.tool_name)} + for chunk_index in range(0, len(json_text), 15): + text_chunk = json_text[chunk_index : chunk_index + 15] + yield {1: DeltaToolCall(json_args=text_chunk)} + return sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') @@ -239,3 +296,8 @@ def mock_infer_model(model: Model | KnownModelName) -> Model: return TestModel() else: return FunctionModel(model_logic, stream_function=stream_model_logic) + + +def mock_group_by_temporal(aiter: Any, soft_max_interval: float | None) -> Any: + """Mock group_by_temporal to avoid debouncing, since the iterators above have no delay.""" + return group_by_temporal(aiter, None) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 4f60397158..12cd8a9bdb 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -132,7 +132,7 @@ async def test_streamed_text_stream(): ) async with agent.run_stream('Hello') as result: - assert [c async for c in result.stream(text_delta=True, debounce_by=None)] == snapshot( + assert [c async for c in result.stream_text(delta=True, debounce_by=None)] == snapshot( ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] )