
Helper for calculating state deltas and emitting state updates from within tool calls (AG-UI) #3078

@dorukgezici

Description

I would like to propose an alternative StateDeps implementation that calculates state deltas for AG-UI and makes it convenient to stream events from within tool calls. Because tools no longer have to return a BaseEvent, the same agent can still be used outside an AG-UI context.

AG-UI has a limitation that forces us to send a full StateSnapshot in each run, even when the state was not empty to begin with. My emit_state context manager at least optimizes the rest of the run: if a single run has multiple tool calls or multiple "emitting" steps, the first emission sends a snapshot and each later one sends only the delta, calculated with jsonpatch. The context manager can also be used several times within one tool if the tool is multi-step.
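For readers unfamiliar with JSON Patch, the following dependency-free sketch shows the kind of RFC 6902 operations a delta contains. It is a simplified stand-in for `jsonpatch.make_patch`, not its actual algorithm: it recurses only into dicts and replaces any other changed value (lists included) wholesale. `json_diff` is a hypothetical name used only for illustration.

```python
from typing import Any


def _escape(key: str) -> str:
    # RFC 6901 JSON Pointer escaping: "~" must be escaped before "/"
    return key.replace("~", "~0").replace("/", "~1")


def json_diff(old: Any, new: Any, path: str = "") -> list[dict[str, Any]]:
    """Return RFC 6902 operations that turn `old` into `new` (simplified)."""
    if isinstance(old, dict) and isinstance(new, dict):
        ops: list[dict[str, Any]] = []
        # Keys that disappeared become "remove" operations
        for key in old:
            if key not in new:
                ops.append({"op": "remove", "path": f"{path}/{_escape(key)}"})
        # New keys become "add"; shared keys are compared recursively
        for key, value in new.items():
            child = f"{path}/{_escape(key)}"
            if key not in old:
                ops.append({"op": "add", "path": child, "value": value})
            else:
                ops.extend(json_diff(old[key], value, child))
        return ops
    # Scalars and lists are replaced wholesale; the type check keeps 1 != True
    if type(old) is not type(new) or old != new:
        return [{"op": "replace", "path": path, "value": new}]
    return []


before = {"website_content": "", "meta": {"status": "idle"}}
after = {"website_content": "<html></html>", "meta": {"status": "done"}}
print(json_diff(before, after))
```

In the proposal, the first `emit_state` exit in a run would send the whole state as a snapshot, and later exits would send only an operation list like this one.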

A few gotchas:

  • The handle_ag_ui_request helper should (or could) be improved to simplify the router
  • The state in deps is validated twice with this router, because run_ag_ui validates it too
  • It adds a new dependency, jsonpatch

This would be the StateDeps enhancement:

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Generic

import jsonpatch
from ag_ui.core import BaseEvent, EventType, StateDeltaEvent, StateSnapshotEvent
from ag_ui.encoder import EventEncoder
from anyio.streams.memory import MemoryObjectSendStream
from pydantic_ai.ag_ui import SSE_CONTENT_TYPE, StateT


@dataclass
class StreamStateDeps(Generic[StateT]):
    """Streaming-capable state deps base class for AG-UI agents.

    Handles calculating state deltas and streaming AG-UI events.
    """

    state: StateT
    # Optional, so the same agent can also be used without streaming
    stream: MemoryObjectSendStream[str] | None = None

    # Whether the initial `StateSnapshot` has been sent (not a constructor argument)
    _state_initialized: bool = field(default=False, init=False)

    async def send(self, event: BaseEvent) -> None:
        if self.stream is None:
            return

        encoder = EventEncoder(accept=SSE_CONTENT_TYPE)
        await self.stream.send(encoder.encode(event))

    @asynccontextmanager
    async def emit_state(self) -> AsyncIterator[StateT]:
        """Async context manager for AG-UI state delta calculation and emission."""

        # Take a snapshot of the old state
        old_json = self.state.model_dump(mode="json")

        # Yield the state object for modification; if the block raises, nothing is emitted
        yield self.state

        current_json = self.state.model_dump(mode="json")
        if not self._state_initialized:
            # The first emission in a run must be a full snapshot
            event = StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=current_json)
            self._state_initialized = True
        else:
            # Calculate the JSON Patch (RFC 6902) delta; an empty patch yields []
            patch = jsonpatch.make_patch(old_json, current_json)
            event = StateDeltaEvent(type=EventType.STATE_DELTA, delta=list(patch))

        # Emit the AG-UI event
        await self.send(event)
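The control flow of `emit_state` can be tried without AG-UI, pydantic or anyio. This is only an approximation under stated assumptions: a dataclass stands in for the pydantic state model, an `asyncio.Queue` replaces the anyio send stream, plain dicts replace the AG-UI event classes, and a top-level-only diff replaces jsonpatch. `DemoState`, `DemoStreamStateDeps` and `top_level_diff` are hypothetical names.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class DemoState:
    # Hypothetical state, standing in for a pydantic model
    website_content: str = ""
    step: int = 0


def top_level_diff(old: dict[str, Any], new: dict[str, Any]) -> list[dict[str, Any]]:
    # Simplified stand-in for jsonpatch.make_patch: top-level keys only,
    # assuming both dicts have the same keys (as dataclass fields do)
    return [
        {"op": "replace", "path": f"/{key}", "value": new[key]}
        for key in new
        if old.get(key) != new[key]
    ]


@dataclass
class DemoStreamStateDeps:
    state: DemoState
    queue: asyncio.Queue[dict[str, Any]] | None = None  # stand-in for the anyio stream
    _state_initialized: bool = field(default=False, init=False)

    async def send(self, event: dict[str, Any]) -> None:
        if self.queue is None:
            return  # used outside AG-UI: nothing to stream
        await self.queue.put(event)

    @asynccontextmanager
    async def emit_state(self) -> AsyncIterator[DemoState]:
        old = asdict(self.state)
        yield self.state  # if the body raises, the code below never runs
        new = asdict(self.state)
        if not self._state_initialized:
            self._state_initialized = True
            await self.send({"type": "STATE_SNAPSHOT", "snapshot": new})
        else:
            await self.send({"type": "STATE_DELTA", "delta": top_level_diff(old, new)})


async def demo() -> list[dict[str, Any]]:
    queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    deps = DemoStreamStateDeps(state=DemoState(), queue=queue)
    async with deps.emit_state() as state:  # first use: full snapshot
        state.step = 1
    async with deps.emit_state() as state:  # later uses: deltas only
        state.website_content = "hello"
    events = []
    while not queue.empty():
        events.append(queue.get_nowait())
    return events


print(asyncio.run(demo()))
```

The first block produces a snapshot and the second a single `replace` operation. If the body of an `async with` raises, no event is sent, but any mutations made before the error stay on the state object, so they would not show up in a later delta either.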

This is how to set up the router, derived from many of the comments shared earlier:

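from collections.abc import AsyncIterator
from http import HTTPStatus

from ag_ui.core import RunAgentInput
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Request, Response
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from pydantic_ai.ag_ui import SSE_CONTENT_TYPE, run_ag_ui

# Assumes FastAPI; `router`, `chat_agent` and `MyState` are defined elsewhere in your app,
# and StreamStateDeps is the class above.


@router.post("/chat")
async def run_agent(request: Request) -> Response:
    accept = request.headers.get("accept", SSE_CONTENT_TYPE)
    try:
        run_input = RunAgentInput.model_validate(await request.json())
        state = MyState.model_validate(run_input.state)
    except ValidationError as e:
        return Response(
            content=e.json(),  # already a JSON string; json.dumps() would double-encode it
            media_type="application/json",
            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
        )

    send_stream, receive_stream = create_memory_object_stream[str]()
    deps = StreamStateDeps(state=state, stream=send_stream)

    async def forward_agent_events(send_stream: MemoryObjectSendStream[str]) -> None:
        # Closing the send stream on exit ends the receive loop below
        async with send_stream:
            async for event in run_ag_ui(chat_agent, run_input, deps=deps):
                await send_stream.send(event)

    async def stream_generator() -> AsyncIterator[str]:
        async with create_task_group() as tg:
            tg.start_soon(forward_agent_events, send_stream)
            async with receive_stream:
                async for event in receive_stream:
                    yield event

    return StreamingResponse(stream_generator(), media_type=accept)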

Finally, this is how the interface is used from within a tool call:

from pydantic_ai import RunContext


@chat_agent.tool
async def fetch_website_content(ctx: RunContext[StreamStateDeps[MyState]], url: str) -> str:
    website_content = await _fetch_website_content(url)  # your own fetching helper

    # When the context manager exits, the state delta is calculated and sent over the stream
    async with ctx.deps.emit_state() as state:
        state.website_content = website_content

    return website_content
