Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ harness = ["testcontainers", "hypercorn", "httpx"]
serde = ["dacite", "pydantic", "msgspec"]
client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]
openai = ["openai-agents>=0.6.1"]

[build-system]
requires = ["maturin>=1.6,<2.0"]
Expand Down
22 changes: 22 additions & 0 deletions python/restate/ext/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-python/blob/main/LICENSE
#
"""
This module contains the optional OpenAI integration for Restate.
"""

from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors

# Public API of the Restate OpenAI-agents integration.
__all__ = [
    "DurableModelCalls",
    "continue_on_terminal_errors",
    "raise_terminal_errors",
    "Runner",
]
283 changes: 283 additions & 0 deletions python/restate/ext/openai/runner_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-python/blob/main/LICENSE
#
"""
This module contains the optional OpenAI integration for Restate.
"""

import asyncio
import dataclasses
import typing

from agents import (
Tool,
Usage,
Model,
RunContextWrapper,
AgentsException,
Runner as OpenAIRunner,
RunConfig,
TContext,
RunResult,
Agent,
ModelBehaviorError,
)

from agents.models.multi_provider import MultiProvider
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
from agents.memory.session import SessionABC
from agents.items import TResponseInputItem
from typing import List, Any
from typing import AsyncIterator

from agents.tool import FunctionTool
from agents.tool_context import ToolContext
from pydantic import BaseModel
from restate.exceptions import SdkInternalBaseException
from restate.extensions import current_context

from restate import RunOptions, ObjectContext, TerminalError


# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
class RestateModelResponse(BaseModel):
    """Serializable (Pydantic) mirror of the agents SDK's `ModelResponse`."""

    output: list[TResponseOutputItem]
    """A list of outputs (messages, tool calls, etc) generated by the model"""

    usage: Usage
    """The usage information for the response."""

    response_id: str | None
    """An ID for the response which can be used to refer to the response in subsequent calls to the
    model. Not supported by all model providers.
    If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
    be passed to `Runner.run`.
    """

    def to_input_items(self) -> list[TResponseInputItem]:
        """Convert the stored outputs into input items for a follow-up model call."""
        return [it.model_dump(exclude_unset=True) for it in self.output]  # type: ignore


class DurableModelCalls(MultiProvider):
    """
    Model provider that makes every resolved model durable.

    Model lookup is delegated to the OpenAI SDK's default ``MultiProvider``;
    the resolved model is then wrapped in a ``RestateModelWrapper`` so that
    each LLM call is recorded in the Restate journal.
    """

    def __init__(self, max_retries: int | None = 3):
        super().__init__()
        # Forwarded to each RestateModelWrapper created by get_model().
        self.max_retries = max_retries

    def get_model(self, model_name: str | None) -> Model:
        """Resolve ``model_name`` via MultiProvider and wrap it for durability."""
        resolved = super().get_model(model_name or None)
        return RestateModelWrapper(resolved, self.max_retries)


class RestateModelWrapper(Model):
    """
    A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.

    Each `get_response` call is executed inside `ctx.run_typed(...)`, so on retry
    or replay the recorded response is served from the journal instead of
    re-invoking the LLM.
    """

    def __init__(self, model: Model, max_retries: int | None = 3):
        # Underlying provider model that performs the actual LLM call.
        self.model = model
        self.model_name = "RestateModelWrapper"
        # Passed as RunOptions(max_attempts=...); presumably None means unbounded
        # retries — TODO confirm against the Restate SDK's RunOptions semantics.
        self.max_retries = max_retries

    async def get_response(self, *args, **kwargs) -> ModelResponse:
        """Call the wrapped model once and journal the result in Restate."""

        async def call_llm() -> RestateModelResponse:
            resp = await self.model.get_response(*args, **kwargs)
            # Convert to a Pydantic model so the Restate SDK can serialize it
            # into the journal (the SDK cannot serialize ModelResponse directly).
            return RestateModelResponse(
                output=resp.output,
                usage=resp.usage,
                response_id=resp.response_id,
            )

        ctx = current_context()
        if ctx is None:
            raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
        result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
        # Convert back to the original ModelResponse type expected by the agents SDK.
        return ModelResponse(
            output=result.output,
            usage=result.usage,
            response_id=result.response_id,
        )

    def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent]:
        # Streaming is intentionally rejected; only journaled, non-streaming
        # calls via get_response are supported.
        raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.")


class RestateSession(SessionABC):
    """Session implementation backed by Restate virtual-object state.

    Conversation items are stored under the ``"items"`` state key of the
    ambient Restate object context, making the history durable and scoped to
    the object's key.
    """

    def _ctx(self) -> ObjectContext:
        # Only valid inside an object handler, where the ambient context is an
        # ObjectContext with K/V state access — TODO confirm callers guarantee this.
        return typing.cast(ObjectContext, current_context())

    async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
        """Retrieve conversation history for this session.

        Args:
            limit: If given, return at most the ``limit`` most recent items.

        Returns:
            Stored items (most recent last); ``[]`` when nothing is stored.
        """
        current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
        if limit is None:
            return current_items
        # A plain `current_items[-limit:]` would return the WHOLE list for
        # limit == 0 (since [-0:] == [0:]), so non-positive limits are handled
        # explicitly.
        return current_items[-limit:] if limit > 0 else []

    async def add_items(self, items: List[TResponseInputItem]) -> None:
        """Append new items to this session's stored history."""
        current_items = await self.get_items()
        self._ctx().set("items", current_items + items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item, or None if the session is empty."""
        current_items = await self.get_items()
        if not current_items:
            return None
        item = current_items.pop()
        self._ctx().set("items", current_items)
        return item

    async def clear_session(self) -> None:
        """Clear all items for this session."""
        self._ctx().clear("items")


class AgentsTerminalException(AgentsException, TerminalError):
    """Exception that is both an AgentsException and a restate.TerminalError.

    The agents SDK needs an AgentsException while Restate needs a
    TerminalError (see raise_terminal_errors), so a terminal failure is raised
    as this hybrid type to satisfy both layers.
    """

    def __init__(self, *args: object) -> None:
        # Cooperative multiple inheritance: initialize both bases via the MRO.
        super().__init__(*args)


class AgentsSuspension(AgentsException, SdkInternalBaseException):
    """Exception that is both an AgentsException and a restate SdkInternalBaseException.

    NOTE(review): not referenced in this file — presumably lets Restate's
    internal control-flow signals (e.g. suspensions) propagate through agents
    SDK code; confirm at the usage sites.
    """

    def __init__(self, *args: object) -> None:
        # Cooperative multiple inheritance: initialize both bases via the MRO.
        super().__init__(*args)


def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
"""A custom function to provide a user-friendly error message."""
# Raise terminal errors and cancellations
if isinstance(error, TerminalError):
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
# so we create a new exception that inherits from both
raise AgentsTerminalException(error.message)

if isinstance(error, ModelBehaviorError):
return f"An error occurred while calling the tool: {str(error)}"

raise error


def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
"""A custom function to provide a user-friendly error message."""
# Raise terminal errors and cancellations
if isinstance(error, TerminalError):
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
# so we create a new exception that inherits from both
return f"An error occurred while running the tool: {str(error)}"

if isinstance(error, ModelBehaviorError):
return f"An error occurred while calling the tool: {str(error)}"

raise error


class Runner:
    """
    A wrapper around the agents SDK's ``Runner.run`` that automatically
    configures the run for Restate.

    It installs ``DurableModelCalls`` as the model provider (taking over any
    model provider in the supplied ``RunConfig``) and wraps the agent's tools
    — and, recursively, its handoff agents' tools — so tool invocations are
    journaled and executed sequentially.
    """

    @staticmethod
    async def run(
        starting_agent: Agent[TContext],
        *args: typing.Any,
        disable_tool_autowrapping: bool = False,
        run_config: RunConfig | None = None,
        **kwargs,
    ) -> RunResult:
        """
        Run an agent with automatic Restate configuration.

        Args:
            starting_agent: The agent to run.
            *args: Positional arguments forwarded to ``OpenAIRunner.run``
                (typically the user input).
            disable_tool_autowrapping: Keyword-only. When True, tools still run
                sequentially but are not journaled via ``ctx.run_typed``.
                (Keyword-only on purpose: as a positional parameter before
                ``*args`` it would silently swallow the ``input`` argument of
                ``OpenAIRunner.run`` and disable journaling.)
            run_config: Optional RunConfig; its model provider is replaced
                with ``DurableModelCalls``.
            **kwargs: Keyword arguments forwarded to ``OpenAIRunner.run``.

        Returns:
            The result from ``OpenAIRunner.run``.
        """
        new_run_config = dataclasses.replace(
            run_config or RunConfig(),
            model_provider=DurableModelCalls(),
        )
        restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
        return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)


def sequentialize_and_wrap_tools(
    agent: Agent[TContext],
    disable_tool_autowrapping: bool,
) -> Agent[TContext]:
    """
    Wrap an agent's function tools for Restate and force sequential execution.

    Each FunctionTool is replaced by a wrapper that — unless
    ``disable_tool_autowrapping`` is set — runs the tool inside
    ``ctx.run_typed`` so its result is journaled. Handoff agents are wrapped
    recursively; non-function tools are left untouched.

    Args:
        agent: The agent whose tools should be wrapped.
        disable_tool_autowrapping: Skip the ``ctx.run_typed`` wrapping (tools
            still run sequentially under the lock).

    Returns:
        A clone of the agent with wrapped tools and handoffs.
    """

    # Restate does not allow parallel tool calls, so a shared lock forces this
    # agent's tools to run one at a time. Handoff agents get their own lock
    # through the recursive call below.
    sequential_tools_lock = asyncio.Lock()

    def create_wrapper(captured_tool: FunctionTool):
        # Factory so each wrapper closes over its own tool, avoiding the
        # late-binding-closure pitfall of defining the coroutine in the loop.
        async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
            async def invoke():
                result = await captured_tool.on_invoke_tool(tool_context, tool_input)
                # Ensure Pydantic (v2 or v1) objects are serialized to dicts
                # for LLM compatibility.
                if hasattr(result, "model_dump"):
                    return result.model_dump()
                if hasattr(result, "dict"):
                    return result.dict()
                return result

            # `async with` guarantees the lock is released even on failure
            # (the original manual acquire/finally-release, idiomatically).
            async with sequential_tools_lock:
                if disable_tool_autowrapping:
                    return await invoke()

                ctx = current_context()
                if ctx is None:
                    raise RuntimeError(
                        "No current Restate context found, make sure to run inside a Restate handler"
                    )
                # Journal the tool call under the tool's name.
                return await ctx.run_typed(captured_tool.name, invoke)

        return on_invoke_tool_wrapper

    wrapped_tools: list[Tool] = [
        dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)) if isinstance(tool, FunctionTool) else tool
        for tool in agent.tools
    ]

    # Recursively wrap tools in handoff agents.
    handoffs_with_wrapped_tools = [
        sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping)  # type: ignore
        for handoff in agent.handoffs
    ]

    return agent.clone(
        tools=wrapped_tools,
        handoffs=handoffs_with_wrapped_tools,
    )
Loading