Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"]
harness = ["testcontainers", "hypercorn", "httpx"]
serde = ["dacite", "pydantic", "msgspec"]
client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]

[build-system]
requires = ["maturin>=1.6,<2.0"]
Expand Down
10 changes: 10 additions & 0 deletions python/restate/ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
18 changes: 18 additions & 0 deletions python/restate/ext/adk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from .session import RestateSessionService
from .plugin import RestatePlugin

__all__ = [
"RestateSessionService",
"RestatePlugin",
]
180 changes: 180 additions & 0 deletions python/restate/ext/adk/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
ADK plugin implementation for restate.
"""

import asyncio
import restate

from datetime import timedelta
from typing import Optional, Any, cast

from google.genai import types

from google.adk.agents import BaseAgent, LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.plugins import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.models import LLMRegistry
from google.adk.models.base_llm import BaseLlm
from google.adk.flows.llm_flows.functions import generate_client_function_call_id


from restate.extensions import current_context

from .session import flush_session_state


class RestatePlugin(BasePlugin):
"""A plugin to integrate Restate with the ADK framework."""

_models: dict[str, BaseLlm]
_locks: dict[str, asyncio.Lock]

def __init__(self, *, max_model_call_retries: int = 10):
super().__init__(name="restate_plugin")
self._models = {}
self._locks = {}
self._max_model_call_retries = max_model_call_retries

async def before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
if not isinstance(agent, LlmAgent):
raise restate.TerminalError("RestatePlugin only supports LlmAgent agents.")
ctx = current_context() # Ensure we have a Restate context
if ctx is None:
raise restate.TerminalError(
"""No Restate context found for RestatePlugin.
Ensure that the agent is invoked within a restate handler and,
using a ```with restate_overrides(ctx):``` block. around your agent use."""
)
model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
self._models[callback_context.invocation_id] = model
self._locks[callback_context.invocation_id] = asyncio.Lock()

id = callback_context.invocation_id
event = ctx.request().attempt_finished_event

async def release_task():
"""make sure to release resources when the agent finishes"""
try:
await event.wait()
finally:
self._models.pop(id, None)
self._locks.pop(id, None)

_ = asyncio.create_task(release_task())
return None

async def after_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
self._models.pop(callback_context.invocation_id, None)
self._locks.pop(callback_context.invocation_id, None)

ctx = cast(restate.ObjectContext, current_context())
await flush_session_state(ctx, callback_context.session)

return None

async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
model = self._models[callback_context.invocation_id]
ctx = current_context()
if ctx is None:
raise RuntimeError(
"No Restate context found, the restate plugin must be used from within a restate handler."
)
response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
return response

async def before_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
tool_context.session.state["restate_context"] = current_context()
lock = self._locks[tool_context.invocation_id]
await lock.acquire()
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here
return None

async def after_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
result: dict,
) -> Optional[dict]:
lock = self._locks[tool_context.invocation_id]
lock.release()
tool_context.session.state.pop("restate_context", None)
return None

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
lock = self._locks[tool_context.invocation_id]
lock.release()
tool_context.session.state.pop("restate_context", None)
return None

async def close(self):
self._models.clear()
self._locks.clear()


def _generate_client_function_call_id(s: LlmResponse) -> None:
"""Generate client function call IDs for function calls in the LlmResponse.
It is important for the function call IDs to be stable across retries, as they
are used to correlate function call results with their invocations.
"""
if s.content and s.content.parts:
for part in s.content.parts:
if part.function_call:
if not part.function_call.id:
id = generate_client_function_call_id()
part.function_call.id = id


async def _generate_content_async(
ctx: restate.Context, max_attempts: int, model: BaseLlm, llm_request: LlmRequest
) -> LlmResponse:
"""Generate content using Restate's context."""

async def call_llm() -> LlmResponse:
a_gen = model.generate_content_async(llm_request, stream=False)
try:
result = await anext(a_gen)
_generate_client_function_call_id(result)
return result
finally:
await a_gen.aclose()

return await ctx.run_typed(
"call LLM",
call_llm,
restate.RunOptions(max_attempts=max_attempts, initial_retry_interval=timedelta(seconds=1)),
)
100 changes: 100 additions & 0 deletions python/restate/ext/adk/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
ADK session service implementation using Restate Virtual Objects as the backing store.
"""

import restate

from typing import Optional, Any, cast

from google.adk.sessions import Session
from google.adk.events.event import Event
from google.adk.sessions.base_session_service import (
BaseSessionService,
ListSessionsResponse,
GetSessionConfig,
)

from restate.extensions import current_context


class RestateSessionService(BaseSessionService):
def ctx(self) -> restate.ObjectContext:
return cast(restate.ObjectContext, current_context())

async def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
if session_id is None:
session_id = str(self.ctx().uuid())

session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
)
self.ctx().set(f"session_store::{session_id}", session)
return session

async def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
# TODO : Handle config options
return await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
app_name=app_name,
user_id=user_id,
id=session_id,
)

async def list_sessions(self, *, app_name: str, user_id: Optional[str] = None) -> ListSessionsResponse:
state_keys = await self.ctx().state_keys()
sessions = []
for key in state_keys:
if key.startswith("session_store::"):
session = await self.ctx().get(key, type_hint=Session)
if session is not None:
sessions.append(session)
return ListSessionsResponse(sessions=sessions)

async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None:
self.ctx().clear(f"session_store::{session_id}")

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event
# For now, we also store temp state
event = self._trim_temp_delta_state(event)
self._update_session_state(session, event)
session.events.append(event)
return event


async def flush_session_state(ctx: restate.ObjectContext, session: Session):
session_to_store = session.model_copy()
# Remove restate-specific context that got added by the plugin before storing
session_to_store.state.pop("restate_context", None)
deterministic_session = await ctx.run_typed(
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
)
ctx.set(f"session_store::{session.id}", deterministic_session)
Loading