diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 5357a8aaf4..2ea21ca7c3 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -1,8 +1,8 @@ from __future__ import annotations as _annotations import asyncio -from collections.abc import AsyncIterator, Awaitable, Sequence -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Sequence +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from typing import Any, Callable, Generic, Literal, cast, final, overload @@ -413,18 +413,20 @@ async def _handle_model_response( raise exceptions.UnexpectedModelBehaviour('Received empty tool call message') # otherwise we run all retriever functions in parallel - coros: list[Awaitable[_messages.Message]] = [] - names: list[str] = [] - for call in model_response.calls: - retriever = self._retrievers.get(call.tool_name) - if retriever is None: - # should this be a retry error? - raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - coros.append(retriever.run(deps, call)) - names.append(call.tool_name) - - with _logfire.span('running {tools=}', tools=names): - new_messages = await asyncio.gather(*coros) + tasks: list[asyncio.Task[_messages.Message]] = [] + try: + for call in model_response.calls: + retriever = self._retrievers.get(call.tool_name) + if retriever is None: + # should this be a retry error? + raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + except BaseException: + await _cancel_tasks(tasks) + raise + + with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): + new_messages = await asyncio.gather(*tasks) return _utils.Either(right=new_messages) else: assert_never(model_response) @@ -476,17 +478,20 @@ async def _handle_streamed_model_response( messages: list[_messages.Message] = [tool_call_msg] # we now run all retriever functions in parallel - coros: list[Awaitable[_messages.Message]] = [] - names: list[str] = [] - for call in tool_call_msg.calls: - retriever = self._retrievers.get(call.tool_name) - if retriever is None: - raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') - coros.append(retriever.run(deps, call)) - names.append(call.tool_name) - - with _logfire.span('running {tools=}', tools=names): - messages += await asyncio.gather(*coros) + tasks: list[asyncio.Task[_messages.Message]] = [] + try: + for call in tool_call_msg.calls: + retriever = self._retrievers.get(call.tool_name) + if retriever is None: + raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') + tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name)) + except BaseException: + # otherwise we'll get warnings about coroutines not awaited + await _cancel_tasks(tasks) + raise + + with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): + messages += await asyncio.gather(*tasks) return _utils.Either(right=messages) async def _validate_result( @@ -510,3 +515,10 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]: prompt = await sys_prompt_runner.run(deps) messages.append(_messages.SystemPrompt(prompt)) return messages + + +async def _cancel_tasks(tasks: list[asyncio.Task[_messages.Message]]) -> None: + for task in tasks: + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 9cc93159f0..5bd62e5feb 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -7,7 +7,7 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from datetime import datetime from functools import cache @@ -70,13 +70,27 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS class StreamTextResponse(ABC): """Streamed response from an LLM when returning text.""" - def __aiter__(self) -> AsyncIterator[str]: - """Stream the response as an async iterable of string chunks.""" + def __aiter__(self) -> AsyncIterator[None]: + """Stream the response as an async iterable, building up the text as it goes. + + This is an async iterator that yields `None` to avoid doing the work of validating the input and + extracting the text field when it will often be thrown away. + """ return self @abstractmethod - async def __anext__(self) -> str: - """Process the next chunk of the response, and return the string delta.""" + async def __anext__(self) -> None: + """Process the next chunk of the response, see above for why this returns `None`.""" + raise NotImplementedError() + + @abstractmethod + def get(self, *, final: bool = False) -> Iterable[str]: + """Returns an iterable of text since the last call to `get()` — e.g. the text delta. + + Args: + final: If True, this is the final call, after iteration is complete, the response should be fully validated + and all text extracted. + """ raise NotImplementedError() @abstractmethod @@ -110,10 +124,13 @@ async def __anext__(self) -> None: raise NotImplementedError() @abstractmethod - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Get the `LLMToolCalls` at this point. The `LLMToolCalls` may or may not be complete, depending on whether the stream is finished. + + Args: + final: If True, this is the final call, after iteration is complete, the response should be fully validated. """ raise NotImplementedError() diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index ae2437b725..4485906324 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -123,10 +123,15 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS @dataclass class FunctionStreamTextResponse(StreamTextResponse): _iter: Iterator[str] - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _buffer: list[str] = field(default_factory=list, init=False) + + async def __anext__(self) -> None: + self._buffer.append(_utils.sync_anext(self._iter)) - async def __anext__(self) -> str: - return _utils.sync_anext(self._iter) + def get(self, *, final: bool = False) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> result.Cost: return result.Cost() @@ -151,7 +156,7 @@ async def __anext__(self) -> None: else: self._delta_tool_calls[key] = new - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Map tool call deltas to a `LLMToolCalls`.""" calls: list[ToolCall] = [] for c in self._delta_tool_calls.values(): diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index b1fee6be2f..85f07fc30d 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -18,15 +18,16 @@ import os import re -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime from typing import Annotated, Any, Literal, Union +import pydantic_core from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse -from pydantic import Field +from pydantic import Discriminator, Field, Tag from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never from .. import UnexpectedModelBehaviour, _pydantic, _utils, exceptions, result @@ -196,7 +197,7 @@ async def _process_streamed_response(http_response: HTTPResponse) -> EitherStrea if _extract_response_parts(start_response).is_left(): return GeminiStreamToolCallResponse(_content=content, _stream=aiter_bytes) else: - return GeminiStreamTextResponse(_first=True, _content=content, _stream=aiter_bytes) + return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes) @staticmethod def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]: @@ -225,60 +226,57 @@ def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiCont @dataclass class GeminiStreamTextResponse(StreamTextResponse): - _first: bool - _content: bytearray + _json_content: bytearray _stream: AsyncIterator[bytes] _position: int = 0 - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) - async def __anext__(self) -> str: - if self._first: - self._first = False + async def __anext__(self) -> None: + chunk = await self._stream.__anext__() + self._json_content.extend(chunk) + + def get(self, *, final: bool = False) -> Iterable[str]: + if final: + all_items = pydantic_core.from_json(self._json_content) + new_items = all_items[self._position :] + self._position = len(all_items) + new_responses = _gemini_streamed_response_ta.validate_python(new_items) else: - chunk = await self._stream.__anext__() - self._content.extend(chunk) - - responses = self._responses() - new_responses = responses[self._position :] - self._position = len(responses) - new_text: list[str] = [] + all_items = pydantic_core.from_json(self._json_content, allow_partial=True) + new_items = all_items[self._position : -1] + self._position = len(all_items) - 1 + new_responses = _gemini_streamed_response_ta.validate_python(new_items, experimental_allow_partial=True) for r in new_responses: + self._cost += _metadata_as_cost(r['usage_metadata']) parts = r['candidates'][0]['content']['parts'] if all_text_parts(parts): - new_text.extend(part['text'] for part in parts) + for part in parts: + yield part['text'] else: raise UnexpectedModelBehaviour( 'Streamed response with unexpected content, expected all parts to be text' ) - return ''.join(new_text) def cost(self) -> result.Cost: - cost = result.Cost() - for response in self._responses(): - cost += _metadata_as_cost(response['usage_metadata']) - return cost + return self._cost def timestamp(self) -> datetime: return self._timestamp - def _responses(self) -> list[_GeminiResponse]: - return _gemini_streamed_response_ta.validate_json( - self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 - experimental_allow_partial=True, - ) - @dataclass class GeminiStreamToolCallResponse(StreamToolCallResponse): _content: bytearray _stream: AsyncIterator[bytes] - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) async def __anext__(self) -> None: chunk = await self._stream.__anext__() self._content.extend(chunk) - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Get the `LLMToolCalls` at this point. NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always @@ -287,9 +285,14 @@ def get(self) -> LLMToolCalls: I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from separate parts. """ - responses = self._responses() + responses = _gemini_streamed_response_ta.validate_json( + self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 + experimental_allow_partial=not final, + ) combined_parts: list[_GeminiFunctionCallPart] = [] + self._cost = result.Cost() for r in responses: + self._cost += _metadata_as_cost(r['usage_metadata']) candidate = r['candidates'][0] parts = candidate['content']['parts'] if all_function_call_parts(parts): @@ -302,20 +305,11 @@ def get(self) -> LLMToolCalls: return _tool_call_from_parts(combined_parts, timestamp=self._timestamp) def cost(self) -> result.Cost: - cost = result.Cost() - for response in self._responses(): - cost += _metadata_as_cost(response['usage_metadata']) - return cost + return self._cost def timestamp(self) -> datetime: return self._timestamp - def _responses(self) -> list[_GeminiResponse]: - return _gemini_streamed_response_ta.validate_json( - self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802 - experimental_allow_partial=True, - ) - # We use typed dicts to define the Gemini API response schema # once Pydantic partial validation supports, dataclasses, we could revert to using them @@ -413,10 +407,28 @@ class _GeminiFunctionResponse(TypedDict): response: dict[str, Any] +def _part_discriminator(v: Any) -> str: + if isinstance(v, dict): + if 'text' in v: + return 'text' + elif 'functionCall' in v or 'function_call' in v: + return 'function_call' + elif 'functionResponse' in v or 'function_response' in v: + return 'function_response' + return 'text' + + # See # we don't currently support other part types # TODO discriminator -_GeminiPartUnion = Union[_GeminiTextPart, _GeminiFunctionCallPart, _GeminiFunctionResponsePart] +_GeminiPartUnion = Annotated[ + Union[ + Annotated[_GeminiTextPart, Tag('text')], + Annotated[_GeminiFunctionCallPart, Tag('function_call')], + Annotated[_GeminiFunctionResponsePart, Tag('function_response')], + ], + Discriminator(_part_discriminator), +] class _GeminiTextContent(TypedDict): diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 6f5c7c08d9..115e13fc50 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -230,12 +230,13 @@ class OpenAIStreamTextResponse(StreamTextResponse): _response: AsyncStream[ChatCompletionChunk] _timestamp: datetime _cost: result.Cost + _buffer: list[str] = field(default_factory=list, init=False) - async def __anext__(self) -> str: + async def __anext__(self) -> None: if self._first is not None: - first = self._first + self._buffer.append(self._first) self._first = None - return first + return None chunk = await self._response.__anext__() self._cost += _map_cost(chunk) @@ -244,12 +245,15 @@ async def __anext__(self) -> str: except IndexError: raise StopAsyncIteration() - if choice.finish_reason is not None: - # we don't raise StopAsyncIteration on the last chunk because usage comes after this - return choice.delta.content or '' + # we don't raise StopAsyncIteration on the last chunk because usage comes after this + if choice.finish_reason is None: + assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}' + if choice.delta.content is not None: + self._buffer.append(choice.delta.content) - assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}' - return choice.delta.content + def get(self, *, final: bool = False) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> Cost: return self._cost @@ -288,7 +292,7 @@ async def __anext__(self) -> None: else: self._delta_tool_calls[new.index] = new - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: """Map tool call deltas to a `LLMToolCalls`.""" calls: list[ToolCall] = [] for c in self._delta_tool_calls.values(): diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 792a3dfcf7..bddce76037 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -8,7 +8,7 @@ import re import string -from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime @@ -177,6 +177,7 @@ class TestStreamTextResponse(StreamTextResponse): _cost: Cost _iter: Iterator[str] = field(init=False) _timestamp: datetime = field(default_factory=_utils.now_utc) + _buffer: list[str] = field(default_factory=list, init=False) def __post_init__(self): *words, last_word = self._text.split(' ') @@ -187,8 +188,12 @@ def __post_init__(self): words = [self._text[:mid], self._text[mid:]] self._iter = iter(words) - async def __anext__(self) -> str: - return _utils.sync_anext(self._iter) + async def __anext__(self) -> None: + self._buffer.append(_utils.sync_anext(self._iter)) + + def get(self, *, final: bool = False) -> Iterable[str]: + yield from self._buffer + self._buffer.clear() def cost(self) -> Cost: return self._cost @@ -202,12 +207,12 @@ class TestStreamToolCallResponse(StreamToolCallResponse): _structured_response: LLMToolCalls _cost: Cost _iter: Iterator[None] = field(default_factory=lambda: iter([None])) - _timestamp: datetime = field(default_factory=_utils.now_utc) + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) async def __anext__(self) -> None: return _utils.sync_anext(self._iter) - def get(self) -> LLMToolCalls: + def get(self, *, final: bool = False) -> LLMToolCalls: return self._structured_response def cost(self) -> Cost: diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 2fa84e5d4f..8a4b7f4a00 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -98,7 +98,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat cost_so_far: Cost """Cost up until the last request.""" - _stream_response: models.StreamTextResponse | models.StreamToolCallResponse + _stream_response: models.EitherStreamedResponse _result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -135,25 +135,43 @@ async def stream( with _logfire.span('response stream text') as lf_span: if text_delta: async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for chunks in group_iter: - yield ''.join(chunks) # pyright: ignore[reportReturnType] + async for _ in group_iter: + yield cast(ResultData, ''.join(self._stream_response.get())) + final_delta = ''.join(self._stream_response.get(final=True)) + if final_delta: + yield cast(ResultData, final_delta) else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step + chunks: list[str] = [] combined = '' async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for chunks in group_iter: - combined += ''.join(chunks) - combined = await self._validate_text_result(combined) - yield cast(ResultData, combined) + async for _ in group_iter: + new = False + for chunk in self._stream_response.get(): + chunks.append(chunk) + new = True + if new: + combined = await self._validate_text_result(''.join(chunks)) + yield cast(ResultData, combined) + + new = False + for chunk in self._stream_response.get(final=True): + chunks.append(chunk) + new = True + if new: + combined = await self._validate_text_result(''.join(chunks)) + yield cast(ResultData, combined) lf_span.set_attribute('combined_text', combined) self._marked_completed(text=combined) else: assert not text_delta, 'Cannot use `text_delta=True` for structured responses' - async for tool_message in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_result(tool_message, allow_partial=True) + async for tool_message, is_last in self.stream_structured(debounce_by=debounce_by): + yield await self.validate_structured_result(tool_message, allow_partial=not is_last) - async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[messages.LLMToolCalls]: + async def stream_structured( + self, *, debounce_by: float | None = 0.1 + ) -> AsyncIterator[tuple[messages.LLMToolCalls, bool]]: """Stream the response as an async iterable of Structured LLM Messages. !!! note @@ -164,6 +182,8 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt the response stream is debounced by 0.2 seconds unless `text_delta` is `True`, in which case it doesn't make sense to debounce. `None` means no debouncing. Debouncing is important particularly for long structured responses to reduce the overhead of performing validation as each token is received. + + Returns: An async iterable of the structured response message and whether it's the last message. """ with _logfire.span('response stream structured') as lf_span: if isinstance(self._stream_response, models.StreamTextResponse): @@ -172,26 +192,30 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt # we should already have a message at this point, yield that first if it has any content msg = self._stream_response.get() if any(call.has_content() for call in msg.calls): - yield msg + yield msg, False async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _ in group_iter: msg = self._stream_response.get() if any(call.has_content() for call in msg.calls): - yield msg + yield msg, False + msg = self._stream_response.get(final=True) + yield msg, True lf_span.set_attribute('structured_response', msg) - self._marked_completed(structured_message=msg) + self._marked_completed(structured_message=msg) async def get_response(self) -> ResultData: """Stream the whole response, validate and return it.""" if isinstance(self._stream_response, models.StreamTextResponse): - text = ''.join([chunk async for chunk in self._stream_response]) + async for _ in self._stream_response: + pass + text = ''.join(self._stream_response.get(final=True)) text = await self._validate_text_result(text) self._marked_completed(text=text) return cast(ResultData, text) else: async for _ in self._stream_response: pass - tool_message = self._stream_response.get() + tool_message = self._stream_response.get(final=True) self._marked_completed(structured_message=tool_message) return await self.validate_structured_result(tool_message) diff --git a/pydantic_ai_examples/stream_whales.py b/pydantic_ai_examples/stream_whales.py index 91f0a4e034..2b18c44522 100644 --- a/pydantic_ai_examples/stream_whales.py +++ b/pydantic_ai_examples/stream_whales.py @@ -43,12 +43,12 @@ async def main(): console = Console() with Live('\n' * 36, console=console) as live: console.print('Requesting data...', style='cyan') - async with agent.run_stream('Generate me details of 20 species of Whale.') as result: + async with agent.run_stream('Generate me details of 5 species of Whale.') as result: console.print('Response:', style='green') - async for message in result.stream_structured(debounce_by=0.01): + async for message, last in result.stream_structured(debounce_by=0.01): try: - whales = await result.validate_structured_result(message, allow_partial=True) + whales = await result.validate_structured_result(message, allow_partial=not last) except ValidationError as exc: if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()): continue diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index e1c2fca93a..13d982dc29 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2,7 +2,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator, Callable, Sequence from dataclasses import dataclass from datetime import timezone @@ -330,28 +330,33 @@ async def __aiter__(self) -> AsyncIterator[bytes]: yield chunk +ResOrList: TypeAlias = '_GeminiResponse | httpx.AsyncByteStream | Sequence[_GeminiResponse | httpx.AsyncByteStream]' +GetGeminiClient: TypeAlias = 'Callable[[ResOrList], httpx.AsyncClient]' + + @pytest.fixture async def get_gemini_client(client_with_handler: ClientWithHandler, env: TestEnv): env.set('GEMINI_API_KEY', 'via-env-var') - def create_client( - response_data: _GeminiResponse | httpx.AsyncByteStream | list[_GeminiResponse], - ) -> httpx.AsyncClient: + def create_client(response_or_list: ResOrList) -> httpx.AsyncClient: index = 0 def handler(_request: httpx.Request) -> httpx.Response: nonlocal index - content: bytes | None = None - stream: httpx.AsyncByteStream | None = None - if isinstance(response_data, list): - content = _gemini_response_ta.dump_json(response_data[index], by_alias=True) - elif isinstance(response_data, httpx.AsyncByteStream): - stream = response_data + if isinstance(response_or_list, Sequence): + response = response_or_list[index] + index += 1 + else: + response = response_or_list + + if isinstance(response, httpx.AsyncByteStream): + content: bytes | None = None + stream: httpx.AsyncByteStream | None = response else: - content = _gemini_response_ta.dump_json(response_data, by_alias=True) + content = _gemini_response_ta.dump_json(response, by_alias=True) + stream = None - index += 1 return httpx.Response( 200, content=content, @@ -364,11 +369,6 @@ def handler(_request: httpx.Request) -> httpx.Response: return create_client -GetGeminiClient: TypeAlias = ( - 'Callable[[_GeminiResponse | httpx.AsyncByteStream | list[_GeminiResponse]], httpx.AsyncClient]' -) - - def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse: candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[]) if finish_reason: # pragma: no cover @@ -551,6 +551,7 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(text_delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) + assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -579,4 +580,70 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] - assert chunks == snapshot([(1, 2), (1, 2)]) + assert chunks == snapshot([(1, 2), (1, 2), (1, 2)]) + assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) + + +async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): + first_responses = [ + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('foo', {'x': 'a'})])), + ), + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('bar', {'y': 'b'})])), + ), + ] + d1 = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True) + first_stream = AsyncByteStreamList([d1[:100], d1[100:200], d1[200:300], d1[300:]]) + + second_responses = [ + gemini_response( + _content_function_call(LLMToolCalls(calls=[ToolCall.from_object('final_result', {'response': [1, 2]})])), + ), + ] + d2 = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True) + second_stream = AsyncByteStreamList([d2[:100], d2[100:]]) + + gemini_client = get_gemini_client([first_stream, second_stream]) + model = GeminiModel('gemini-1.5-flash', http_client=gemini_client) + agent = Agent(model, result_type=tuple[int, int], deps=None) + retriever_calls: list[str] = [] + + @agent.retriever_plain + async def foo(x: str) -> str: + retriever_calls.append(f'foo({x=!r})') + return x + + @agent.retriever_plain + async def bar(y: str) -> str: + retriever_calls.append(f'bar({y=!r})') + return y + + async with agent.run_stream('Hello') as result: + response = await result.get_response() + assert response == snapshot((1, 2)) + assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.all_messages() == snapshot( + [ + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ + ToolCall(tool_name='foo', args=ArgsObject(args_object={'x': 'a'})), + ToolCall(tool_name='bar', args=ArgsObject(args_object={'y': 'b'})), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc)), + ToolReturn(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ + ToolCall( + tool_name='final_result', + args=ArgsObject(args_object={'response': [1, 2]}), + ) + ], + timestamp=IsNow(tz=timezone.utc), + ), + ] + ) + assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"]) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 1ba2b1d4c7..3ebacd73e8 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -379,7 +379,12 @@ async def test_stream_structured(): assert result.is_structured() assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] ) assert result.is_complete assert result.cost() == snapshot(Cost(request_tokens=20, response_tokens=10, total_tokens=30)) @@ -403,7 +408,12 @@ async def test_stream_structured_finish_reason(): assert result.is_structured() assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] ) assert result.is_complete diff --git a/tests/test_streaming.py b/tests/test_streaming.py index a14735d699..7e50945a37 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -90,11 +90,11 @@ def text_stream(_messages: list[Message], agent_info: AgentInfo) -> Iterable[Del chunks: list[list[int]] = [] async with agent.run_stream('') as result: - async for structured_response in result.stream_structured(debounce_by=None): - response_data = await result.validate_structured_result(structured_response, allow_partial=True) + async for structured_response, last in result.stream_structured(debounce_by=None): + response_data = await result.validate_structured_result(structured_response, allow_partial=not last) chunks.append(response_data) - assert chunks == snapshot([[1], [1, 2, 3, 4]]) + assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4]]) async def test_streamed_text_stream():