From edbed0ace8af26f05268aa0d57d33c8e169a308e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 10 Nov 2024 11:00:09 +0000 Subject: [PATCH 1/3] improvements to text stream efficiency --- pydantic_ai/models/__init__.py | 19 ++++++++++++++----- pydantic_ai/models/function.py | 11 ++++++++--- pydantic_ai/models/gemini.py | 34 ++++++++++++++++++++++++++-------- pydantic_ai/models/openai.py | 22 +++++++++++++--------- pydantic_ai/models/test.py | 11 ++++++++--- pydantic_ai/result.py | 14 ++++++++------ 6 files changed, 77 insertions(+), 34 deletions(-) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 9cc93159f0..e17b05024d 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -7,7 +7,7 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from datetime import datetime from functools import cache @@ -70,13 +70,22 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS class StreamTextResponse(ABC): """Streamed response from an LLM when returning text.""" - def __aiter__(self) -> AsyncIterator[str]: - """Stream the response as an async iterable of string chunks.""" + def __aiter__(self) -> AsyncIterator[None]: + """Stream the response as an async iterable, building up the text as it goes. + + This is an async iterator that yields `None` to avoid doing the work of validating the input and + extracting the text field when it will often be thrown away. + """ return self @abstractmethod - async def __anext__(self) -> str: - """Process the next chunk of the response, and return the string delta.""" + async def __anext__(self) -> None: + """Process the next chunk of the response, see above for why this returns `None`.""" + raise NotImplementedError() + + @abstractmethod + def get(self) -> Iterable[str]: + """Returns an iterable of text since the last call to `get()` — e.g. 
the text delta.""" raise NotImplementedError() @abstractmethod diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index ae2437b725..6015e2af2a 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -123,10 +123,15 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS @dataclass class FunctionStreamTextResponse(StreamTextResponse): _iter: Iterator[str] - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _buffer: list[str] = field(default_factory=list, init=False) + + async def __anext__(self) -> None: + self._buffer.append(_utils.sync_anext(self._iter)) - async def __anext__(self) -> str: - return _utils.sync_anext(self._iter) + def get(self) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> result.Cost: return result.Cost() diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index b1fee6be2f..355cc2090f 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -18,7 +18,7 @@ import os import re -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from copy import deepcopy from dataclasses import dataclass, field @@ -26,7 +26,7 @@ from typing import Annotated, Any, Literal, Union from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse -from pydantic import Field +from pydantic import Discriminator, Field, Tag from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never from .. import UnexpectedModelBehaviour, _pydantic, _utils, exceptions, result @@ -229,28 +229,28 @@ class GeminiStreamTextResponse(StreamTextResponse): _content: bytearray _stream: AsyncIterator[bytes] _position: int = 0 - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - async def __anext__(self) -> str: + async def __anext__(self) -> None: if self._first: self._first = False else: chunk = await self._stream.__anext__() self._content.extend(chunk) + def get(self) -> Iterable[str]: responses = self._responses() new_responses = responses[self._position :] self._position = len(responses) - new_text: list[str] = [] for r in new_responses: parts = r['candidates'][0]['content']['parts'] if all_text_parts(parts): - new_text.extend(part['text'] for part in parts) + for part in parts: + yield part['text'] else: raise UnexpectedModelBehaviour( 'Streamed response with unexpected content, expected all parts to be text' ) - return ''.join(new_text) def cost(self) -> result.Cost: cost = result.Cost() @@ -413,10 +413,28 @@ class _GeminiFunctionResponse(TypedDict): response: dict[str, Any] +def _part_discriminator(v: Any) -> str: + if isinstance(v, dict): + if 'text' in v: + return 'text' + elif 'functionCall' in v or 'function_call' in v: + return 'function_call' + elif 'functionResponse' in v or 'function_response' in v: + return 'function_response' + return 'text' + + # See # we don't currently support other part types # TODO discriminator -_GeminiPartUnion = Union[_GeminiTextPart, _GeminiFunctionCallPart, _GeminiFunctionResponsePart] +_GeminiPartUnion = Annotated[ + Union[ + Annotated[_GeminiTextPart, Tag('text')], + Annotated[_GeminiFunctionCallPart, Tag('function_call')], + Annotated[_GeminiFunctionResponsePart, Tag('function_response')], + ], 
+ Discriminator(_part_discriminator), +] class _GeminiTextContent(TypedDict): diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 6f5c7c08d9..390f318f37 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -230,12 +230,13 @@ class OpenAIStreamTextResponse(StreamTextResponse): _response: AsyncStream[ChatCompletionChunk] _timestamp: datetime _cost: result.Cost + _buffer: list[str] = field(default_factory=list, init=False) - async def __anext__(self) -> str: + async def __anext__(self) -> None: if self._first is not None: - first = self._first + self._buffer.append(self._first) self._first = None - return first + return None chunk = await self._response.__anext__() self._cost += _map_cost(chunk) @@ -244,12 +245,15 @@ async def __anext__(self) -> str: except IndexError: raise StopAsyncIteration() - if choice.finish_reason is not None: - # we don't raise StopAsyncIteration on the last chunk because usage comes after this - return choice.delta.content or '' + # we don't raise StopAsyncIteration on the last chunk because usage comes after this + if choice.finish_reason is None: + assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}' + if choice.delta.content is not None: + self._buffer.append(choice.delta.content) - assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}' - return choice.delta.content + def get(self) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> Cost: return self._cost diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 792a3dfcf7..78c9ec7e3d 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -8,7 +8,7 @@ import re import string -from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime @@ -177,6 +177,7 @@ class TestStreamTextResponse(StreamTextResponse): _cost: Cost _iter: Iterator[str] = field(init=False) _timestamp: datetime = field(default_factory=_utils.now_utc) + _buffer: list[str] = field(default_factory=list, init=False) def __post_init__(self): *words, last_word = self._text.split(' ') @@ -187,8 +188,12 @@ def __post_init__(self): words = [self._text[:mid], self._text[mid:]] self._iter = iter(words) - async def __anext__(self) -> str: - return _utils.sync_anext(self._iter) + async def __anext__(self) -> None: + self._buffer.append(_utils.sync_anext(self._iter)) + + def get(self) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> Cost: return self._cost diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 2fa84e5d4f..449489b5b5 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -98,7 +98,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat cost_so_far: Cost """Cost up until the last request.""" - _stream_response: models.StreamTextResponse | models.StreamToolCallResponse + _stream_response: models.EitherStreamedResponse 
_result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -135,15 +135,15 @@ async def stream( with _logfire.span('response stream text') as lf_span: if text_delta: async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for chunks in group_iter: - yield ''.join(chunks) # pyright: ignore[reportReturnType] + async for _ in group_iter: + yield cast(ResultData, ''.join(self._stream_response.get())) else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step combined = '' async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for chunks in group_iter: - combined += ''.join(chunks) + async for _ in group_iter: + combined += ''.join(self._stream_response.get()) combined = await self._validate_text_result(combined) yield cast(ResultData, combined) lf_span.set_attribute('combined_text', combined) @@ -184,7 +184,9 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt async def get_response(self) -> ResultData: """Stream the whole response, validate and return it.""" if isinstance(self._stream_response, models.StreamTextResponse): - text = ''.join([chunk async for chunk in self._stream_response]) + async for _ in self._stream_response: + pass + text = ''.join(self._stream_response.get()) text = await self._validate_text_result(text) self._marked_completed(text=text) return cast(ResultData, text) From 5e0aa25ea7a5218f913a311b539519917a8e5cfb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 10 Nov 2024 13:25:38 +0000 Subject: [PATCH 2/3] fix gemini text streaming --- pydantic_ai/agent.py | 34 ++++++----- pydantic_ai/models/__init__.py | 9 ++- pydantic_ai/models/function.py | 2 +- pydantic_ai/models/gemini.py | 41 +++++++------ pydantic_ai/models/openai.py | 2 +- pydantic_ai/models/test.py | 2 +- pydantic_ai/result.py | 24 ++++++-- pyproject.toml | 1 + tests/models/test_gemini.py | 101 +++++++++++++++++++++++++++------ uv.lock | 2 + 10 files changed, 156 insertions(+), 62 deletions(-) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 5357a8aaf4..ac2c027e3e 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -1,8 +1,8 @@ from __future__ import annotations as _annotations import asyncio -from collections.abc import AsyncIterator, Awaitable, Sequence -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Sequence +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from typing import Any, Callable, Generic, Literal, cast, final, overload @@ -413,18 +413,17 @@ async def _handle_model_response( raise exceptions.UnexpectedModelBehaviour('Received empty tool call message') # otherwise we run all retriever functions in parallel - coros: list[Awaitable[_messages.Message]] = [] - names: list[str] = [] + tasks: list[asyncio.Task[_messages.Message]] = [] for call in model_response.calls: retriever = self._retrievers.get(call.tool_name) if retriever is None: + await _cancel_tasks(tasks) # should this be a retry error? 
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - coros.append(retriever.run(deps, call)) - names.append(call.tool_name) + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) - with _logfire.span('running {tools=}', tools=names): - new_messages = await asyncio.gather(*coros) + with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): + new_messages = await asyncio.gather(*tasks) return _utils.Either(right=new_messages) else: assert_never(model_response) @@ -476,17 +475,17 @@ async def _handle_streamed_model_response( messages: list[_messages.Message] = [tool_call_msg] # we now run all retriever functions in parallel - coros: list[Awaitable[_messages.Message]] = [] - names: list[str] = [] + tasks: list[asyncio.Task[_messages.Message]] = [] for call in tool_call_msg.calls: retriever = self._retrievers.get(call.tool_name) if retriever is None: + # otherwise we'll get warnings about coroutines not awaited + await _cancel_tasks(tasks) raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - coros.append(retriever.run(deps, call)) - names.append(call.tool_name) + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) - with _logfire.span('running {tools=}', tools=names): - messages += await asyncio.gather(*coros) + with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): + messages += await asyncio.gather(*tasks) return _utils.Either(right=messages) async def _validate_result( @@ -510,3 +509,10 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]: prompt = await sys_prompt_runner.run(deps) messages.append(_messages.SystemPrompt(prompt)) return messages + + +async def _cancel_tasks(tasks: list[asyncio.Task[_messages.Message]]) -> None: + for task in tasks: + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index e17b05024d..4404c8abbf 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -84,8 +84,13 @@ async def __anext__(self) -> None: raise NotImplementedError() @abstractmethod - def get(self) -> Iterable[str]: - """Returns an iterable of text since the last call to `get()` — e.g. the text delta.""" + def get(self, *, final: bool = False) -> Iterable[str]: + """Returns an iterable of text since the last call to `get()` — e.g. the text delta. + + Args: + final: If True, this is the final call, after iteration is complete, JSON should be fully validated + and all items extracted. 
+ """ raise NotImplementedError() @abstractmethod diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index 6015e2af2a..0b83d5ae4d 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -129,7 +129,7 @@ class FunctionStreamTextResponse(StreamTextResponse): async def __anext__(self) -> None: self._buffer.append(_utils.sync_anext(self._iter)) - def get(self) -> Iterable[str]: + def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 355cc2090f..39fbf91a1f 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -25,6 +25,7 @@ from datetime import datetime from typing import Annotated, Any, Literal, Union +import pydantic_core from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse from pydantic import Discriminator, Field, Tag from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never @@ -196,7 +197,7 @@ async def _process_streamed_response(http_response: HTTPResponse) -> EitherStrea if _extract_response_parts(start_response).is_left(): return GeminiStreamToolCallResponse(_content=content, _stream=aiter_bytes) else: - return GeminiStreamTextResponse(_first=True, _content=content, _stream=aiter_bytes) + return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes) @staticmethod def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]: @@ -225,24 +226,29 @@ def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiCont @dataclass class GeminiStreamTextResponse(StreamTextResponse): - _first: bool - _content: bytearray + _json_content: bytearray _stream: AsyncIterator[bytes] _position: int = 0 _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) async def __anext__(self) -> None: - if self._first: - self._first = False + chunk = await self._stream.__anext__() + self._json_content.extend(chunk) + + def get(self, *, final: bool = False) -> Iterable[str]: + if final: + all_items = pydantic_core.from_json(self._json_content) + new_items = all_items[self._position :] + self._position = len(all_items) + new_responses = _gemini_streamed_response_ta.validate_python(new_items) else: - chunk = await self._stream.__anext__() - self._content.extend(chunk) - - def get(self) -> Iterable[str]: - responses = self._responses() - new_responses = responses[self._position :] - self._position = len(responses) + all_items = pydantic_core.from_json(self._json_content, allow_partial=True) + new_items = all_items[self._position : -1] + self._position = len(all_items) - 1 + new_responses = _gemini_streamed_response_ta.validate_python(new_items, experimental_allow_partial=True) for r in new_responses: + self._cost += _metadata_as_cost(r['usage_metadata']) parts = r['candidates'][0]['content']['parts'] if all_text_parts(parts): for part in parts: @@ -253,20 +259,11 @@ def get(self) -> Iterable[str]: ) def cost(self) -> result.Cost: - cost = result.Cost() - for response in self._responses(): - cost += _metadata_as_cost(response['usage_metadata']) - return cost + return self._cost def timestamp(self) -> datetime: return self._timestamp - def _responses(self) -> list[_GeminiResponse]: - return _gemini_streamed_response_ta.validate_json( - self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 - 
experimental_allow_partial=True, - ) - @dataclass class GeminiStreamToolCallResponse(StreamToolCallResponse): diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 390f318f37..06c560de12 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -251,7 +251,7 @@ async def __anext__(self) -> None: if choice.delta.content is not None: self._buffer.append(choice.delta.content) - def get(self) -> Iterable[str]: + def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 78c9ec7e3d..2db46f6b2a 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -191,7 +191,7 @@ def __post_init__(self): async def __anext__(self) -> None: self._buffer.append(_utils.sync_anext(self._iter)) - def get(self) -> Iterable[str]: + def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 449489b5b5..df9d4b77f2 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -137,15 +137,31 @@ async def stream( async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _ in group_iter: yield cast(ResultData, ''.join(self._stream_response.get())) + final_delta = ''.join(self._stream_response.get(final=True)) + if final_delta: + yield cast(ResultData, final_delta) else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step + chunks: list[str] = [] combined = '' async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _ in group_iter: - combined += ''.join(self._stream_response.get()) - combined = await self._validate_text_result(combined) - yield cast(ResultData, combined) + new = False + for chunk in self._stream_response.get(): + chunks.append(chunk) + new = True + if new: + combined = await self._validate_text_result(''.join(chunks)) + yield cast(ResultData, combined) + + new = False + for chunk in self._stream_response.get(final=True): + chunks.append(chunk) + new = True + if new: + combined = await self._validate_text_result(''.join(chunks)) + yield cast(ResultData, combined) lf_span.set_attribute('combined_text', combined) self._marked_completed(text=combined) else: @@ -186,7 +202,7 @@ async def get_response(self) -> ResultData: if isinstance(self._stream_response, models.StreamTextResponse): async for _ in self._stream_response: pass - text = ''.join(self._stream_response.get()) + text = ''.join(self._stream_response.get(final=True)) text = await self._validate_text_result(text) self._marked_completed(text=text) return cast(ResultData, text) diff --git a/pyproject.toml b/pyproject.toml index 1b7f8b93bd..72e3e8cc54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "eval-type-backport>=0.2.0", "griffe>=1.3.2", "httpx>=0.27.2", + "jiter>=0.6.1", "logfire-api>=1.2.0", "openai>=1.54.3", "pydantic>=2.10.0b1", diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index e1c2fca93a..963133ddfa 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2,7 +2,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator, Callable, Sequence from dataclasses import dataclass from datetime import timezone @@ -330,28 +330,33 @@ async def 
__aiter__(self) -> AsyncIterator[bytes]: yield chunk +ResOrList: TypeAlias = '_GeminiResponse | httpx.AsyncByteStream | Sequence[_GeminiResponse | httpx.AsyncByteStream]' +GetGeminiClient: TypeAlias = 'Callable[[ResOrList], httpx.AsyncClient]' + + @pytest.fixture async def get_gemini_client(client_with_handler: ClientWithHandler, env: TestEnv): env.set('GEMINI_API_KEY', 'via-env-var') - def create_client( - response_data: _GeminiResponse | httpx.AsyncByteStream | list[_GeminiResponse], - ) -> httpx.AsyncClient: + def create_client(response_or_list: ResOrList) -> httpx.AsyncClient: index = 0 def handler(_request: httpx.Request) -> httpx.Response: nonlocal index - content: bytes | None = None - stream: httpx.AsyncByteStream | None = None - if isinstance(response_data, list): - content = _gemini_response_ta.dump_json(response_data[index], by_alias=True) - elif isinstance(response_data, httpx.AsyncByteStream): - stream = response_data + if isinstance(response_or_list, Sequence): + response = response_or_list[index] + index += 1 + else: + response = response_or_list + + if isinstance(response, httpx.AsyncByteStream): + content: bytes | None = None + stream: httpx.AsyncByteStream | None = response else: - content = _gemini_response_ta.dump_json(response_data, by_alias=True) + content = _gemini_response_ta.dump_json(response, by_alias=True) + stream = None - index += 1 return httpx.Response( 200, content=content, @@ -364,11 +369,6 @@ def handler(_request: httpx.Request) -> httpx.Response: return create_client -GetGeminiClient: TypeAlias = ( - 'Callable[[_GeminiResponse | httpx.AsyncByteStream | list[_GeminiResponse]], httpx.AsyncClient]' -) - - def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse: candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[]) if finish_reason: # pragma: no cover @@ -551,6 +551,7 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(text_delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) + assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -580,3 +581,69 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot([(1, 2), (1, 2)]) + assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) + + +async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): + first_responses = [ + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('foo', {'x': 'a'})])), + ), + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('bar', {'y': 'b'})])), + ), + ] + first_json_data = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True) + first_stream = AsyncByteStreamList([first_json_data[:100], first_json_data[100:200], first_json_data[200:]]) + + second_responses = [ + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('final_result', {'response': [1, 2]})])), + ), + ] + second_json_data = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True) + second_stream = AsyncByteStreamList([second_json_data[:100], second_json_data[100:200], 
second_json_data[200:]])
+
+    gemini_client = get_gemini_client([first_stream, second_stream])
+    model = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    agent = Agent(model, result_type=tuple[int, int], deps=None)
+    retriever_calls: list[str] = []
+
+    @agent.retriever_plain
+    async def foo(x: str) -> str:
+        retriever_calls.append(f'foo({x=!r})')
+        return x
+
+    @agent.retriever_plain
+    async def bar(y: str) -> str:
+        retriever_calls.append(f'bar({y=!r})')
+        return y
+
+    async with agent.run_stream('Hello') as result:
+        response = await result.get_response()
+        assert response == snapshot((1, 2))
+    assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9))
+    assert result.all_messages() == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+            LLMToolCalls(
+                calls=[
+                    ToolCall(tool_name='foo', args=ArgsObject(args_object={'x': 'a'})),
+                    ToolCall(tool_name='bar', args=ArgsObject(args_object={'y': 'b'})),
+                ],
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ToolReturn(tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc)),
+            ToolReturn(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
+            LLMToolCalls(
+                calls=[
+                    ToolCall(
+                        tool_name='final_result',
+                        args=ArgsObject(args_object={'response': [1, 2]}),
+                    )
+                ],
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+        ]
+    )
+    assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"])
diff --git a/uv.lock b/uv.lock
index a77390ae83..a6c600b871 100644
--- a/uv.lock
+++ b/uv.lock
@@ -961,6 +961,7 @@ dependencies = [
     { name = "eval-type-backport" },
     { name = "griffe" },
     { name = "httpx" },
+    { name = "jiter" },
     { name = "logfire-api" },
     { name = "openai" },
     { name = "pydantic" },
@@ -1000,6 +1001,7 @@ requires-dist = [
     { name = "fastapi", marker = "extra == 'examples'", specifier = ">=0.115.4" },
     { name = "griffe", specifier = ">=1.3.2" },
     { name = "httpx", specifier = ">=0.27.2" },
+    { name = "jiter", specifier = ">=0.6.1" },
     { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2" },
     { name = "logfire", extras = ["asyncpg", "fastapi"], marker = "extra == 'examples'", specifier = ">=2" },
     { name = "logfire-api", specifier = ">=1.2.0" },

From 52fa9c2746c333dba6e0e938896f98524f298816 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Sun, 10 Nov 2024 14:06:13 +0000
Subject: [PATCH 3/3] add final to structured response .get()

---
 pydantic_ai/agent.py                  | 34 ++++++++++++++++++++------------
 pydantic_ai/models/__init__.py        |  9 ++++++---
 pydantic_ai/models/function.py        |  2 +-
 pydantic_ai/models/gemini.py          | 23 ++++++++++-------------
 pydantic_ai/models/openai.py          |  2 +-
 pydantic_ai/models/test.py            |  4 ++--
 pydantic_ai/result.py                 | 20 +++++++++++++-------
 pydantic_ai_examples/stream_whales.py |  6 +++---
 pyproject.toml                        |  1 -
 tests/models/test_gemini.py           | 10 +++++-----
 tests/models/test_openai.py           | 14 ++++++++++++--
 tests/test_streaming.py               |  6 +++---
 uv.lock                               |  2 --
 13 files changed, 76 insertions(+), 57 deletions(-)

diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index ac2c027e3e..2ea21ca7c3 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -414,13 +414,16 @@ async def _handle_model_response(

         # otherwise we run all retriever functions in parallel
         tasks: list[asyncio.Task[_messages.Message]] = []
-        for call in model_response.calls:
-            retriever = self._retrievers.get(call.tool_name)
-            if retriever is None:
-                await _cancel_tasks(tasks)
-                # should this be a retry error?
- raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + try: + for call in model_response.calls: + retriever = self._retrievers.get(call.tool_name) + if retriever is None: + # should this be a retry error? + raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + except BaseException: + await _cancel_tasks(tasks) + raise with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): new_messages = await asyncio.gather(*tasks) @@ -476,13 +479,16 @@ async def _handle_streamed_model_response( # we now run all retriever functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] - for call in tool_call_msg.calls: - retriever = self._retrievers.get(call.tool_name) - if retriever is None: - # otherwise we'll get warnings about coroutines not awaited - await _cancel_tasks(tasks) - raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + try: + for call in tool_call_msg.calls: + retriever = self._retrievers.get(call.tool_name) + if retriever is None: + raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + except BaseException: + # otherwise we'll get warnings about coroutines not awaited + await _cancel_tasks(tasks) + raise with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): messages += await asyncio.gather(*tasks) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 4404c8abbf..5bd62e5feb 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -88,8 +88,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: """Returns an iterable of text since the last call to `get()` — e.g. the text delta. Args: - final: If True, this is the final call, after iteration is complete, JSON should be fully validated - and all items extracted. + final: If True, this is the final call, after iteration is complete, the response should be fully validated + and all text extracted. """ raise NotImplementedError() @@ -124,10 +124,13 @@ async def __anext__(self) -> None: raise NotImplementedError() @abstractmethod - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Get the `LLMToolCalls` at this point. The `LLMToolCalls` may or may not be complete, depending on whether the stream is finished. + + Args: + final: If True, this is the final call, after iteration is complete, the response should be fully validated. 
""" raise NotImplementedError() diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index 0b83d5ae4d..4485906324 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -156,7 +156,7 @@ async def __anext__(self) -> None: else: self._delta_tool_calls[key] = new - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Map tool call deltas to a `LLMToolCalls`.""" calls: list[ToolCall] = [] for c in self._delta_tool_calls.values(): diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 39fbf91a1f..85f07fc30d 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -269,13 +269,14 @@ def timestamp(self) -> datetime: class GeminiStreamToolCallResponse(StreamToolCallResponse): _content: bytearray _stream: AsyncIterator[bytes] - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) async def __anext__(self) -> None: chunk = await self._stream.__anext__() self._content.extend(chunk) - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Get the `LLMToolCalls` at this point. NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always @@ -284,9 +285,14 @@ def get(self) -> LLMToolCalls: I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from separate parts. """ - responses = self._responses() + responses = _gemini_streamed_response_ta.validate_json( + self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 + experimental_allow_partial=not final, + ) combined_parts: list[_GeminiFunctionCallPart] = [] + self._cost = result.Cost() for r in responses: + self._cost += _metadata_as_cost(r['usage_metadata']) candidate = r['candidates'][0] parts = candidate['content']['parts'] if all_function_call_parts(parts): @@ -299,20 +305,11 @@ def get(self) -> LLMToolCalls: return _tool_call_from_parts(combined_parts, timestamp=self._timestamp) def cost(self) -> result.Cost: - cost = result.Cost() - for response in self._responses(): - cost += _metadata_as_cost(response['usage_metadata']) - return cost + return self._cost def timestamp(self) -> datetime: return self._timestamp - def _responses(self) -> list[_GeminiResponse]: - return _gemini_streamed_response_ta.validate_json( - self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 - experimental_allow_partial=True, - ) - # We use typed dicts to define the Gemini API response schema # once Pydantic partial validation supports, dataclasses, we could revert to using them diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 06c560de12..115e13fc50 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -292,7 +292,7 @@ async def __anext__(self) -> None: else: self._delta_tool_calls[new.index] = new - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Map tool call deltas to a `LLMToolCalls`.""" calls: list[ToolCall] = [] for c in self._delta_tool_calls.values(): diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 2db46f6b2a..bddce76037 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -207,12 +207,12 @@ class TestStreamToolCallResponse(StreamToolCallResponse): 
_structured_response: LLMToolCalls _cost: Cost _iter: Iterator[None] = field(default_factory=lambda: iter([None])) - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) async def __anext__(self) -> None: return _utils.sync_anext(self._iter) - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: return self._structured_response def cost(self) -> Cost: diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index df9d4b77f2..8a4b7f4a00 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -166,10 +166,12 @@ async def stream( self._marked_completed(text=combined) else: assert not text_delta, 'Cannot use `text_delta=True` for structured responses' - async for tool_message in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_result(tool_message, allow_partial=True) + async for tool_message, is_last in self.stream_structured(debounce_by=debounce_by): + yield await self.validate_structured_result(tool_message, allow_partial=not is_last) - async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[messages.LLMToolCalls]: + async def stream_structured( + self, *, debounce_by: float | None = 0.1 + ) -> AsyncIterator[tuple[messages.LLMToolCalls, bool]]: """Stream the response as an async iterable of Structured LLM Messages. !!! note @@ -180,6 +182,8 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt the response stream is debounced by 0.2 seconds unless `text_delta` is `True`, in which case it doesn't make sense to debounce. `None` means no debouncing. Debouncing is important particularly for long structured responses to reduce the overhead of performing validation as each token is received. + + Returns: An async iterable of the structured response message and whether it's the last message. 
""" with _logfire.span('response stream structured') as lf_span: if isinstance(self._stream_response, models.StreamTextResponse): @@ -188,14 +192,16 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt # we should already have a message at this point, yield that first if it has any content msg = self._stream_response.get() if any(call.has_content() for call in msg.calls): - yield msg + yield msg, False async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _ in group_iter: msg = self._stream_response.get() if any(call.has_content() for call in msg.calls): - yield msg + yield msg, False + msg = self._stream_response.get(final=True) + yield msg, True lf_span.set_attribute('structured_response', msg) - self._marked_completed(structured_message=msg) + self._marked_completed(structured_message=msg) async def get_response(self) -> ResultData: """Stream the whole response, validate and return it.""" @@ -209,7 +215,7 @@ async def get_response(self) -> ResultData: else: async for _ in self._stream_response: pass - tool_message = self._stream_response.get() + tool_message = self._stream_response.get(final=True) self._marked_completed(structured_message=tool_message) return await self.validate_structured_result(tool_message) diff --git a/pydantic_ai_examples/stream_whales.py b/pydantic_ai_examples/stream_whales.py index 91f0a4e034..2b18c44522 100644 --- a/pydantic_ai_examples/stream_whales.py +++ b/pydantic_ai_examples/stream_whales.py @@ -43,12 +43,12 @@ async def main(): console = Console() with Live('\n' * 36, console=console) as live: console.print('Requesting data...', style='cyan') - async with agent.run_stream('Generate me details of 20 species of Whale.') as result: + async with agent.run_stream('Generate me details of 5 species of Whale.') as result: console.print('Response:', style='green') - async for message in result.stream_structured(debounce_by=0.01): + async for message, last in result.stream_structured(debounce_by=0.01): try: - whales = await result.validate_structured_result(message, allow_partial=True) + whales = await result.validate_structured_result(message, allow_partial=not last) except ValidationError as exc: if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()): continue diff --git a/pyproject.toml b/pyproject.toml index 72e3e8cc54..1b7f8b93bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ "eval-type-backport>=0.2.0", "griffe>=1.3.2", "httpx>=0.27.2", - "jiter>=0.6.1", "logfire-api>=1.2.0", "openai>=1.54.3", "pydantic>=2.10.0b1", diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 963133ddfa..13d982dc29 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -580,7 +580,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] - assert chunks == snapshot([(1, 2), (1, 2)]) + assert chunks == snapshot([(1, 2), (1, 2), (1, 2)]) assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) @@ -593,16 +593,16 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('bar', {'y': 'b'})])), ), ] - first_json_data = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True) - first_stream = AsyncByteStreamList([first_json_data[:100], 
first_json_data[100:200], first_json_data[200:]]) + d1 = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True) + first_stream = AsyncByteStreamList([d1[:100], d1[100:200], d1[200:300], d1[300:]]) second_responses = [ gemini_response( _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('final_result', {'response': [1, 2]})])), ), ] - second_json_data = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True) - second_stream = AsyncByteStreamList([second_json_data[:100], second_json_data[100:200], second_json_data[200:]]) + d2 = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True) + second_stream = AsyncByteStreamList([d2[:100], d2[100:]]) gemini_client = get_gemini_client([first_stream, second_stream]) model = GeminiModel('gemini-1.5-flash', http_client=gemini_client) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 1ba2b1d4c7..3ebacd73e8 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -379,7 +379,12 @@ async def test_stream_structured(): assert result.is_structured() assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] ) assert result.is_complete assert result.cost() == snapshot(Cost(request_tokens=20, response_tokens=10, total_tokens=30)) @@ -403,7 +408,12 @@ async def test_stream_structured_finish_reason(): assert result.is_structured() assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] ) assert result.is_complete diff --git a/tests/test_streaming.py b/tests/test_streaming.py index a14735d699..7e50945a37 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -90,11 +90,11 @@ def text_stream(_messages: list[Message], agent_info: AgentInfo) -> Iterable[Del chunks: list[list[int]] = [] async with agent.run_stream('') as result: - async for structured_response in result.stream_structured(debounce_by=None): - response_data = await result.validate_structured_result(structured_response, allow_partial=True) + async for structured_response, last in result.stream_structured(debounce_by=None): + response_data = await result.validate_structured_result(structured_response, allow_partial=not last) chunks.append(response_data) - assert chunks == snapshot([[1], [1, 2, 3, 4]]) + assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4]]) async def test_streamed_text_stream(): diff --git a/uv.lock b/uv.lock index a6c600b871..a77390ae83 100644 --- a/uv.lock +++ b/uv.lock @@ -961,7 +961,6 @@ dependencies = [ { name = "eval-type-backport" }, { name = "griffe" }, { name = "httpx" }, - { name = "jiter" }, { name = "logfire-api" }, { name = "openai" }, { name = "pydantic" }, @@ -1001,7 +1000,6 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'examples'", specifier = ">=0.115.4" }, { name = "griffe", specifier = ">=1.3.2" }, { name = "httpx", specifier = ">=0.27.2" }, - { name = "jiter", specifier = ">=0.6.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2" }, { name = "logfire", 
extras = ["asyncpg", "fastapi"], marker = "extra == 'examples'", specifier = ">=2" }, { name = "logfire-api", specifier = ">=1.2.0" },