Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/restate/ext/adk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,33 @@
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
import typing

from .session import RestateSessionService
from .plugin import RestatePlugin
from restate import ObjectContext, Context
from restate.extensions import current_context


def restate_object_context() -> ObjectContext:
"""Get the current Restate ObjectContext."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return typing.cast(ObjectContext, ctx)


def restate_context() -> Context:
"""Get the current Restate Context."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return ctx


__all__ = [
"RestateSessionService",
"RestatePlugin",
"restate_object_context",
"restate_context",
]
21 changes: 11 additions & 10 deletions python/restate/ext/adk/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from google.adk.agents import BaseAgent, LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.plugins import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
Expand All @@ -30,12 +31,10 @@
from google.adk.models import LLMRegistry
from google.adk.models.base_llm import BaseLlm
from google.adk.flows.llm_flows.functions import generate_client_function_call_id

from restate.ext.adk import RestateSessionService

from restate.extensions import current_context

from .session import flush_session_state


class RestatePlugin(BasePlugin):
"""A plugin to integrate Restate with the ADK framework."""
Expand Down Expand Up @@ -84,12 +83,13 @@ async def after_agent_callback(
) -> Optional[types.Content]:
self._models.pop(callback_context.invocation_id, None)
self._locks.pop(callback_context.invocation_id, None)

ctx = cast(restate.ObjectContext, current_context())
await flush_session_state(ctx, callback_context.session)

return None

async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
if isinstance(invocation_context.session_service, RestateSessionService):
restate_session_service = cast(RestateSessionService, invocation_context.session_service)
await restate_session_service.flush_session_state(invocation_context.session)

async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
Expand All @@ -109,9 +109,10 @@ async def before_tool_callback(
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
tool_context.session.state["restate_context"] = current_context()
lock = self._locks[tool_context.invocation_id]
ctx = current_context()
await lock.acquire()
tool_context.session.state["restate_context"] = ctx
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here
return None

Expand All @@ -123,9 +124,9 @@ async def after_tool_callback(
tool_context: ToolContext,
result: dict,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
lock = self._locks[tool_context.invocation_id]
lock.release()
tool_context.session.state.pop("restate_context", None)
return None

async def on_tool_error_callback(
Expand All @@ -136,9 +137,9 @@ async def on_tool_error_callback(
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
lock = self._locks[tool_context.invocation_id]
lock.release()
tool_context.session.state.pop("restate_context", None)
return None

async def close(self):
Expand Down
36 changes: 26 additions & 10 deletions python/restate/ext/adk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from typing import Optional, Any, cast

from google.adk.sessions import Session
from google.adk.sessions.state import State
from google.adk.events.event import Event
from google.adk.sessions.base_session_service import (
BaseSessionService,
ListSessionsResponse,
GetSessionConfig,
)
from restate import TerminalError

from restate.extensions import current_context

Expand All @@ -42,15 +44,23 @@ async def create_session(
if session_id is None:
session_id = str(self.ctx().uuid())

session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session)
if session is not None:
raise TerminalError("Session with the given ID already exists.")

session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
)
self.ctx().set(f"session_store::{session_id}", session)

await self.flush_session_state(session)
return session

async def has_session(self, *, session_id: str) -> bool:
return await self.ctx().get(f"session_store::{session_id}", type_hint=Session) is not None

async def get_session(
self,
*,
Expand Down Expand Up @@ -89,12 +99,18 @@ async def append_event(self, session: Session, event: Event) -> Event:
session.events.append(event)
return event

async def flush_session_state(self, session: Session):
session_to_store = session.model_copy()

async def flush_session_state(ctx: restate.ObjectContext, session: Session):
session_to_store = session.model_copy()
# Remove restate-specific context that got added by the plugin before storing
session_to_store.state.pop("restate_context", None)
deterministic_session = await ctx.run_typed(
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
)
ctx.set(f"session_store::{session.id}", deterministic_session)
# Remove temporary state keys before storing
for key in list(session_to_store.state.keys()):
if key.startswith(State.TEMP_PREFIX):
session_to_store.state.pop(key)

# Remove restate-specific context that got added by the plugin before storing
session_to_store.state.pop("restate_context", None)

deterministic_session = await self.ctx().run_typed(
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
)
self.ctx().set(f"session_store::{session.id}", deterministic_session)