Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations as _annotations

import asyncio
from collections.abc import AsyncIterator, Awaitable, Sequence
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, final, overload

Expand Down Expand Up @@ -413,18 +413,20 @@ async def _handle_model_response(
raise exceptions.UnexpectedModelBehaviour('Received empty tool call message')

# otherwise we run all retriever functions in parallel
coros: list[Awaitable[_messages.Message]] = []
names: list[str] = []
for call in model_response.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
# should this be a retry error?
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
coros.append(retriever.run(deps, call))
names.append(call.tool_name)

with _logfire.span('running {tools=}', tools=names):
new_messages = await asyncio.gather(*coros)
tasks: list[asyncio.Task[_messages.Message]] = []
try:
for call in model_response.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
# should this be a retry error?
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
except BaseException:
await _cancel_tasks(tasks)
raise

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
new_messages = await asyncio.gather(*tasks)
return _utils.Either(right=new_messages)
else:
assert_never(model_response)
Expand Down Expand Up @@ -476,17 +478,20 @@ async def _handle_streamed_model_response(
messages: list[_messages.Message] = [tool_call_msg]

# we now run all retriever functions in parallel
coros: list[Awaitable[_messages.Message]] = []
names: list[str] = []
for call in tool_call_msg.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
coros.append(retriever.run(deps, call))
names.append(call.tool_name)

with _logfire.span('running {tools=}', tools=names):
messages += await asyncio.gather(*coros)
tasks: list[asyncio.Task[_messages.Message]] = []
try:
for call in tool_call_msg.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
except BaseException:
# otherwise we'll get warnings about coroutines not awaited
await _cancel_tasks(tasks)
raise

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
messages += await asyncio.gather(*tasks)
return _utils.Either(right=messages)

async def _validate_result(
Expand All @@ -510,3 +515,10 @@ async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
prompt = await sys_prompt_runner.run(deps)
messages.append(_messages.SystemPrompt(prompt))
return messages


async def _cancel_tasks(tasks: list[asyncio.Task[_messages.Message]]) -> None:
for task in tasks:
task.cancel()
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks)
29 changes: 23 additions & 6 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations as _annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Mapping, Sequence
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from datetime import datetime
from functools import cache
Expand Down Expand Up @@ -70,13 +70,27 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS
class StreamTextResponse(ABC):
"""Streamed response from an LLM when returning text."""

def __aiter__(self) -> AsyncIterator[str]:
"""Stream the response as an async iterable of string chunks."""
def __aiter__(self) -> AsyncIterator[None]:
"""Stream the response as an async iterable, building up the text as it goes.

This is an async iterator that yields `None` to avoid doing the work of validating the input and
extracting the text field when it will often be thrown away.
"""
return self

@abstractmethod
async def __anext__(self) -> str:
"""Process the next chunk of the response, and return the string delta."""
async def __anext__(self) -> None:
"""Process the next chunk of the response, see above for why this returns `None`."""
raise NotImplementedError()

@abstractmethod
def get(self, *, final: bool = False) -> Iterable[str]:
"""Returns an iterable of text since the last call to `get()` — e.g. the text delta.

Args:
final: If True, this is the final call, after iteration is complete, the response should be fully validated
and all text extracted.
"""
raise NotImplementedError()

@abstractmethod
Expand Down Expand Up @@ -110,10 +124,13 @@ async def __anext__(self) -> None:
raise NotImplementedError()

@abstractmethod
def get(self) -> LLMToolCalls:
def get(self, *, final: bool = False) -> LLMToolCalls:
"""Get the `LLMToolCalls` at this point.

The `LLMToolCalls` may or may not be complete, depending on whether the stream is finished.

Args:
final: If True, this is the final call, after iteration is complete, the response should be fully validated.
"""
raise NotImplementedError()

Expand Down
13 changes: 9 additions & 4 deletions pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,15 @@ async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherS
@dataclass
class FunctionStreamTextResponse(StreamTextResponse):
_iter: Iterator[str]
_timestamp: datetime = field(default_factory=_utils.now_utc)
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
_buffer: list[str] = field(default_factory=list, init=False)

async def __anext__(self) -> None:
self._buffer.append(_utils.sync_anext(self._iter))

async def __anext__(self) -> str:
return _utils.sync_anext(self._iter)
def get(self, *, final: bool = False) -> Iterable[str]:
yield from self._buffer
self._buffer.clear()

def cost(self) -> result.Cost:
return result.Cost()
Expand All @@ -151,7 +156,7 @@ async def __anext__(self) -> None:
else:
self._delta_tool_calls[key] = new

def get(self) -> LLMToolCalls:
def get(self, *, final: bool = False) -> LLMToolCalls:
"""Map tool call deltas to a `LLMToolCalls`."""
calls: list[ToolCall] = []
for c in self._delta_tool_calls.values():
Expand Down
96 changes: 54 additions & 42 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@

import os
import re
from collections.abc import AsyncIterator, Mapping, Sequence
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Annotated, Any, Literal, Union

import pydantic_core
from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
from pydantic import Field
from pydantic import Discriminator, Field, Tag
from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never

from .. import UnexpectedModelBehaviour, _pydantic, _utils, exceptions, result
Expand Down Expand Up @@ -196,7 +197,7 @@ async def _process_streamed_response(http_response: HTTPResponse) -> EitherStrea
if _extract_response_parts(start_response).is_left():
return GeminiStreamToolCallResponse(_content=content, _stream=aiter_bytes)
else:
return GeminiStreamTextResponse(_first=True, _content=content, _stream=aiter_bytes)
return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)

@staticmethod
def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
Expand Down Expand Up @@ -225,60 +226,57 @@ def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiCont

@dataclass
class GeminiStreamTextResponse(StreamTextResponse):
_first: bool
_content: bytearray
_json_content: bytearray
_stream: AsyncIterator[bytes]
_position: int = 0
_timestamp: datetime = field(default_factory=_utils.now_utc)
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
_cost: result.Cost = field(default_factory=result.Cost, init=False)

async def __anext__(self) -> str:
if self._first:
self._first = False
async def __anext__(self) -> None:
chunk = await self._stream.__anext__()
self._json_content.extend(chunk)

def get(self, *, final: bool = False) -> Iterable[str]:
if final:
all_items = pydantic_core.from_json(self._json_content)
new_items = all_items[self._position :]
self._position = len(all_items)
new_responses = _gemini_streamed_response_ta.validate_python(new_items)
else:
chunk = await self._stream.__anext__()
self._content.extend(chunk)

responses = self._responses()
new_responses = responses[self._position :]
self._position = len(responses)
new_text: list[str] = []
all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
new_items = all_items[self._position : -1]
self._position = len(all_items) - 1
new_responses = _gemini_streamed_response_ta.validate_python(new_items, experimental_allow_partial=True)
for r in new_responses:
self._cost += _metadata_as_cost(r['usage_metadata'])
parts = r['candidates'][0]['content']['parts']
if all_text_parts(parts):
new_text.extend(part['text'] for part in parts)
for part in parts:
yield part['text']
else:
raise UnexpectedModelBehaviour(
'Streamed response with unexpected content, expected all parts to be text'
)
return ''.join(new_text)

def cost(self) -> result.Cost:
cost = result.Cost()
for response in self._responses():
cost += _metadata_as_cost(response['usage_metadata'])
return cost
return self._cost

def timestamp(self) -> datetime:
return self._timestamp

def _responses(self) -> list[_GeminiResponse]:
return _gemini_streamed_response_ta.validate_json(
self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802
experimental_allow_partial=True,
)


@dataclass
class GeminiStreamToolCallResponse(StreamToolCallResponse):
_content: bytearray
_stream: AsyncIterator[bytes]
_timestamp: datetime = field(default_factory=_utils.now_utc)
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
_cost: result.Cost = field(default_factory=result.Cost, init=False)

async def __anext__(self) -> None:
chunk = await self._stream.__anext__()
self._content.extend(chunk)

def get(self) -> LLMToolCalls:
def get(self, *, final: bool = False) -> LLMToolCalls:
"""Get the `LLMToolCalls` at this point.

NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
Expand All @@ -287,9 +285,14 @@ def get(self) -> LLMToolCalls:
I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from
separate parts.
"""
responses = self._responses()
responses = _gemini_streamed_response_ta.validate_json(
self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802
experimental_allow_partial=not final,
)
combined_parts: list[_GeminiFunctionCallPart] = []
self._cost = result.Cost()
for r in responses:
self._cost += _metadata_as_cost(r['usage_metadata'])
candidate = r['candidates'][0]
parts = candidate['content']['parts']
if all_function_call_parts(parts):
Expand All @@ -302,20 +305,11 @@ def get(self) -> LLMToolCalls:
return _tool_call_from_parts(combined_parts, timestamp=self._timestamp)

def cost(self) -> result.Cost:
cost = result.Cost()
for response in self._responses():
cost += _metadata_as_cost(response['usage_metadata'])
return cost
return self._cost

def timestamp(self) -> datetime:
return self._timestamp

def _responses(self) -> list[_GeminiResponse]:
return _gemini_streamed_response_ta.validate_json(
self._content, # type: ignore # see https://github.com/pydantic/pydantic/pull/10802
experimental_allow_partial=True,
)


# We use typed dicts to define the Gemini API response schema
# once Pydantic partial validation supports, dataclasses, we could revert to using them
Expand Down Expand Up @@ -413,10 +407,28 @@ class _GeminiFunctionResponse(TypedDict):
response: dict[str, Any]


def _part_discriminator(v: Any) -> str:
if isinstance(v, dict):
if 'text' in v:
return 'text'
elif 'functionCall' in v or 'function_call' in v:
return 'function_call'
elif 'functionResponse' in v or 'function_response' in v:
return 'function_response'
return 'text'


# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# TODO discriminator
_GeminiPartUnion = Union[_GeminiTextPart, _GeminiFunctionCallPart, _GeminiFunctionResponsePart]
_GeminiPartUnion = Annotated[
Union[
Annotated[_GeminiTextPart, Tag('text')],
Annotated[_GeminiFunctionCallPart, Tag('function_call')],
Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
],
Discriminator(_part_discriminator),
]


class _GeminiTextContent(TypedDict):
Expand Down
Loading
Loading