Description
I would like to propose an alternative `StateDeps` implementation that supports calculating state deltas for AG-UI and makes it convenient to stream events from within tool calls. The agent can still be used outside an AG-UI context, because tool calls no longer have to return `BaseEvent`.
AG-UI has a limitation that forces us to send a full `StateSnapshot` in each run, even when the state was not empty to start with. My `emit_state` context manager at least optimizes this: if a run has multiple tool calls or multiple "emitting" steps, the first one sends a snapshot and each later one sends only the delta, calculated with `jsonpatch`. The context manager can also be used multiple times within a single tool, if the tool is multi-step.
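To make the snapshot-then-delta sequencing concrete, here is a stdlib-only sketch. It swaps `jsonpatch.make_patch` for a deliberately simplified diff that compares top-level keys only (a nested change becomes a whole-value `replace`), and it models events as plain dicts rather than AG-UI event classes. `top_level_patch` and `StateEmitter` are hypothetical names, not part of pydantic-ai or AG-UI.

```python
import copy
from typing import Any


def top_level_patch(old: dict[str, Any], new: dict[str, Any]) -> list[dict[str, Any]]:
    """Simplified stand-in for jsonpatch.make_patch (assumption: top-level keys only).

    Emits RFC 6902 style operations. Nested changes are sent as a whole-value
    "replace", and keys are assumed not to need JSON Pointer escaping.
    """
    ops: list[dict[str, Any]] = []
    for key in old:
        if key not in new:
            ops.append({"op": "remove", "path": f"/{key}"})
        elif old[key] != new[key]:
            ops.append({"op": "replace", "path": f"/{key}", "value": copy.deepcopy(new[key])})
    for key in new:
        if key not in old:
            ops.append({"op": "add", "path": f"/{key}", "value": copy.deepcopy(new[key])})
    return ops


class StateEmitter:
    """Hypothetical emitter: the first emission in a run is a snapshot, later ones are deltas."""

    def __init__(self) -> None:
        self.initialized = False
        self.events: list[dict[str, Any]] = []

    def emit(self, old_state: dict[str, Any], new_state: dict[str, Any]) -> None:
        if not self.initialized:
            # A full snapshot is sent once per run, even if the client already has state
            self.events.append({"type": "STATE_SNAPSHOT", "snapshot": copy.deepcopy(new_state)})
            self.initialized = True
        else:
            # Later emissions only carry the difference since the step started
            self.events.append({"type": "STATE_DELTA", "delta": top_level_patch(old_state, new_state)})


# Simulate two "emitting" steps within one run
state: dict[str, Any] = {"url": None, "website_content": None}
emitter = StateEmitter()

before = copy.deepcopy(state)
state["url"] = "https://example.com"
emitter.emit(before, state)

before = copy.deepcopy(state)
state["website_content"] = "<html>...</html>"
emitter.emit(before, state)

for event in emitter.events:
    print(event)
```

As in the proposal, the "old" state is captured when each step starts, so a delta only describes what that step changed.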
A few gotchas:
- The `handle_ag_ui_request` helper should (or could) be improved to simplify the router.
- With this router, `deps` are validated twice, because `run_ag_ui` validates them too.
- It introduces a new dependency, `jsonpatch`.
This would be the `StateDeps` enhancement:
```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Generic

import jsonpatch
from ag_ui.core import BaseEvent, EventType, StateDeltaEvent, StateSnapshotEvent
from ag_ui.encoder import EventEncoder
from anyio.streams.memory import MemoryObjectSendStream
from pydantic_ai.ag_ui import SSE_CONTENT_TYPE, StateT


@dataclass
class StreamStateDeps(Generic[StateT]):
    """Streaming-capable state deps base class for AG-UI agents.

    Handles calculating state deltas and streaming AG-UI events.
    """

    state: StateT
    # Optional, so the same agent can also be used without streaming
    stream: MemoryObjectSendStream[str] | None = None
    # Whether the initial `StateSnapshot` has been sent in this run
    _state_initialized: bool = field(default=False, init=False)

    async def send(self, event: BaseEvent) -> None:
        """Encode `event` as SSE and push it onto the stream, if there is one."""
        if self.stream is None:
            return
        encoder = EventEncoder(accept=SSE_CONTENT_TYPE)
        await self.stream.send(encoder.encode(event))

    @asynccontextmanager
    async def emit_state(self) -> AsyncIterator[StateT]:
        """Async context manager for AG-UI state delta calculation and emission."""
        # Take a JSON snapshot of the state before modification
        old_json = self.state.model_dump(mode="json")
        # Yield the state object so the caller can modify it in place
        yield self.state
        current_json = self.state.model_dump(mode="json")
        if not self._state_initialized:
            # AG-UI requires a full snapshot first in each run
            event = StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=current_json)
            self._state_initialized = True
        else:
            # Calculate an RFC 6902 JSON Patch between the old and the new state
            patch = jsonpatch.make_patch(old_json, current_json)
            event = StateDeltaEvent(type=EventType.STATE_DELTA, delta=list(patch))
        # Emit the AG-UI event
        await self.send(event)
```
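And this is how you set up the router, pieced together from several comments shared earlier:

```python
from http import HTTPStatus

from ag_ui.core import RunAgentInput
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import APIRouter, Request, Response
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from pydantic_ai.ag_ui import SSE_CONTENT_TYPE, run_ag_ui

router = APIRouter()
# `chat_agent`, `MyState` and `StreamStateDeps` are defined elsewhere


@router.post("/chat")
async def run_agent(request: Request) -> Response:
    try:
        run_input = RunAgentInput.model_validate(await request.json())
        state = MyState.model_validate(run_input.state)
    except ValidationError as e:
        return Response(
            content=e.json(),  # already a JSON string, so no json.dumps()
            media_type="application/json",
            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
        )

    send_stream, receive_stream = create_memory_object_stream[str]()
    deps = StreamStateDeps(state=state, stream=send_stream)

    async def produce_events(send_stream: MemoryObjectSendStream[str]) -> None:
        # Forward the agent's own events; tools push extra events via `deps.stream`.
        # Closing the send stream on exit ends iteration on the receive side.
        async with send_stream:
            async for event in run_ag_ui(chat_agent, run_input, deps=deps):
                await send_stream.send(event)

    async def stream_generator():
        async with create_task_group() as tg:
            tg.start_soon(produce_events, send_stream)
            async with receive_stream:
                async for event in receive_stream:
                    yield event

    # Both run_ag_ui (by default) and StreamStateDeps.send encode events as SSE
    return StreamingResponse(stream_generator(), media_type=SSE_CONTENT_TYPE)
```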
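The router is a fan-in: one task forwards the agent's own encoded events into the send side of a memory stream, tools push their extra state events into the same send side, and the response generator drains the receive side until the sender closes. Below is a stdlib-only sketch of that pattern, assuming an `asyncio.Queue` plus a sentinel as a stand-in for anyio's memory object streams (where closing the last send stream ends iteration on the receiver). `fake_agent_events`, `fake_tool`, `produce` and the `"data: ..."` strings are hypothetical placeholders, not pydantic-ai or AG-UI APIs.

```python
import asyncio
from collections.abc import AsyncIterator

_CLOSED = object()  # sentinel standing in for "send stream closed"


async def fake_tool(queue: asyncio.Queue) -> None:
    # A tool pushing its own event mid-run, like StreamStateDeps.send()
    await queue.put("data: STATE_SNAPSHOT\n\n")


async def fake_agent_events(queue: asyncio.Queue) -> AsyncIterator[str]:
    # Stand-in for run_ag_ui(): yields encoded events and calls a tool in between
    yield "data: RUN_STARTED\n\n"
    await fake_tool(queue)
    yield "data: RUN_FINISHED\n\n"


async def produce(queue: asyncio.Queue) -> None:
    try:
        async for event in fake_agent_events(queue):
            await queue.put(event)
    finally:
        await queue.put(_CLOSED)  # like leaving `async with send_stream:`


async def stream_generator() -> AsyncIterator[str]:
    queue: asyncio.Queue = asyncio.Queue()
    producer = asyncio.create_task(produce(queue))
    try:
        # Drain events in arrival order until the producer signals it is done
        while (event := await queue.get()) is not _CLOSED:
            yield event
    finally:
        await producer  # re-raise producer errors instead of silently dropping them


async def main() -> list[str]:
    return [event async for event in stream_generator()]


events = asyncio.run(main())
print(events)
```

Because agent events and tool events share one channel, the client sees the tool's state event between the agent events that surround the tool call.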
And lastly, the convenient interface to use from within tool calls:
```python
from pydantic_ai import RunContext


@agent.tool
async def fetch_website_content(ctx: RunContext[StreamStateDeps[MyState]], url: str) -> str:
    website_content = await _fetch_website_content(url)
    # When the context manager exits, the state delta is calculated and sent over the stream
    async with ctx.deps.emit_state() as state:
        state.website_content = website_content
    return website_content
```