diff --git a/python/restate/ext/adk/__init__.py b/python/restate/ext/adk/__init__.py index bc806f6..f86a973 100644 --- a/python/restate/ext/adk/__init__.py +++ b/python/restate/ext/adk/__init__.py @@ -8,11 +8,33 @@ # directory of this repository or package, or at # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # +import typing from .session import RestateSessionService from .plugin import RestatePlugin +from restate import ObjectContext, Context +from restate.extensions import current_context + + +def restate_object_context() -> ObjectContext: + """Get the current Restate ObjectContext.""" + ctx = current_context() + if ctx is None: + raise RuntimeError("No Restate context found.") + return typing.cast(ObjectContext, ctx) + + +def restate_context() -> Context: + """Get the current Restate Context.""" + ctx = current_context() + if ctx is None: + raise RuntimeError("No Restate context found.") + return ctx + __all__ = [ "RestateSessionService", "RestatePlugin", + "restate_object_context", + "restate_context", ] diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index e1863a8..0911190 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -22,6 +22,7 @@ from google.adk.agents import BaseAgent, LlmAgent from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.plugins import BasePlugin from google.adk.tools.base_tool import BaseTool from google.adk.tools.tool_context import ToolContext @@ -30,12 +31,10 @@ from google.adk.models import LLMRegistry from google.adk.models.base_llm import BaseLlm from google.adk.flows.llm_flows.functions import generate_client_function_call_id - +from restate.ext.adk import RestateSessionService from restate.extensions import current_context -from .session import flush_session_state - class RestatePlugin(BasePlugin): """A plugin to integrate Restate with the ADK framework.""" @@ -84,12 +83,13 @@ async def after_agent_callback( ) -> Optional[types.Content]: self._models.pop(callback_context.invocation_id, None) self._locks.pop(callback_context.invocation_id, None) - - ctx = cast(restate.ObjectContext, current_context()) - await flush_session_state(ctx, callback_context.session) - return None + async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: + if isinstance(invocation_context.session_service, RestateSessionService): + restate_session_service = cast(RestateSessionService, invocation_context.session_service) + await restate_session_service.flush_session_state(invocation_context.session) + async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: @@ -109,9 +109,10 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - tool_context.session.state["restate_context"] = current_context() lock = self._locks[tool_context.invocation_id] + ctx = current_context() await lock.acquire() + tool_context.session.state["restate_context"] = ctx # TODO: if we want we can also automatically wrap tools with ctx.run_typed here return None @@ -123,9 +124,9 @@ async def after_tool_callback( tool_context: ToolContext, result: dict, ) -> Optional[dict]: + tool_context.session.state.pop("restate_context", None) lock = self._locks[tool_context.invocation_id] lock.release() - tool_context.session.state.pop("restate_context", None) return None async def on_tool_error_callback( @@ -136,9 +137,9 @@ async def on_tool_error_callback( tool_context: ToolContext, error: Exception, ) -> Optional[dict]: + tool_context.session.state.pop("restate_context", None) lock = self._locks[tool_context.invocation_id] lock.release() - tool_context.session.state.pop("restate_context", None) return None async def close(self): diff --git a/python/restate/ext/adk/session.py b/python/restate/ext/adk/session.py index 26e0373..2b25346 100644 --- a/python/restate/ext/adk/session.py +++ b/python/restate/ext/adk/session.py @@ -17,12 +17,14 @@ from typing import Optional, Any, cast from google.adk.sessions import Session +from google.adk.sessions.state import State from google.adk.events.event import Event from google.adk.sessions.base_session_service import ( BaseSessionService, ListSessionsResponse, GetSessionConfig, ) +from restate import TerminalError from restate.extensions import current_context @@ -42,15 +44,23 @@ async def create_session( if session_id is None: session_id = str(self.ctx().uuid()) - session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session( + session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) + if session is not None: + raise TerminalError("Session with the given ID already exists.") + + session = Session( app_name=app_name, user_id=user_id, id=session_id, state=state or {}, ) - self.ctx().set(f"session_store::{session_id}", session) + + await self.flush_session_state(session) return session + async def has_session(self, *, session_id: str) -> bool: + return await self.ctx().get(f"session_store::{session_id}", type_hint=Session) is not None + async def get_session( self, *, @@ -89,12 +99,18 @@ async def append_event(self, session: Session, event: Event) -> Event: session.events.append(event) return event + async def flush_session_state(self, session: Session): + session_to_store = session.model_copy() -async def flush_session_state(ctx: restate.ObjectContext, session: Session): - session_to_store = session.model_copy() - # Remove restate-specific context that got added by the plugin before storing - session_to_store.state.pop("restate_context", None) - deterministic_session = await ctx.run_typed( - "store session", lambda: session_to_store, restate.RunOptions(type_hint=Session) - ) - ctx.set(f"session_store::{session.id}", deterministic_session) + # Remove temporary state keys before storing + for key in list(session_to_store.state.keys()): + if key.startswith(State.TEMP_PREFIX): + session_to_store.state.pop(key) + + # Remove restate-specific context that got added by the plugin before storing + session_to_store.state.pop("restate_context", None) + + deterministic_session = await self.ctx().run_typed( + "store session", lambda: session_to_store, restate.RunOptions(type_hint=Session) + ) + self.ctx().set(f"session_store::{session.id}", deterministic_session)