Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 15 additions & 38 deletions temporalio/client.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@
WorkflowSerializationContext, WorkflowSerializationContext,
) )
from temporalio.service import ( from temporalio.service import (
ConnectConfig,
HttpConnectProxyConfig, HttpConnectProxyConfig,
KeepAliveConfig, KeepAliveConfig,
RetryConfig, RetryConfig,
RPCError, RPCError,
RPCStatusCode, RPCStatusCode,
ServiceClient,
TLSConfig, TLSConfig,
) )


Expand Down Expand Up @@ -198,12 +200,14 @@ async def connect(
http_connect_proxy_config=http_connect_proxy_config, http_connect_proxy_config=http_connect_proxy_config,
) )


root_plugin: Plugin = _RootPlugin() def make_lambda(plugin, next):
return lambda config: plugin.connect_service_client(config, next)

next_function = ServiceClient.connect
for plugin in reversed(plugins): for plugin in reversed(plugins):
plugin.init_client_plugin(root_plugin) next_function = make_lambda(plugin, next_function)
root_plugin = plugin


service_client = await root_plugin.connect_service_client(connect_config) service_client = await next_function(connect_config)


return Client( return Client(
service_client, service_client,
Expand Down Expand Up @@ -243,12 +247,10 @@ def __init__(
plugins=plugins, plugins=plugins,
) )


root_plugin: Plugin = _RootPlugin() for plugin in plugins:
for plugin in reversed(plugins): config = plugin.configure_client(config)
plugin.init_client_plugin(root_plugin)
root_plugin = plugin


self._init_from_config(root_plugin.configure_client(config)) self._init_from_config(config)


def _init_from_config(self, config: ClientConfig): def _init_from_config(self, config: ClientConfig):
self._config = config self._config = config
Expand Down Expand Up @@ -7541,20 +7543,6 @@ def name(self) -> str:
""" """
return type(self).__module__ + "." + type(self).__qualname__ return type(self).__module__ + "." + type(self).__qualname__


@abstractmethod
def init_client_plugin(self, next: Plugin) -> None:
"""Initialize this plugin in the plugin chain.

This method sets up the chain of responsibility pattern by providing a reference
to the next plugin in the chain. It is called during client creation to build
the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`.
Implementations should store this reference and call the corresponding method
of the next plugin on method calls.

Args:
next: The next plugin in the chain to delegate to.
"""

@abstractmethod @abstractmethod
def configure_client(self, config: ClientConfig) -> ClientConfig: def configure_client(self, config: ClientConfig) -> ClientConfig:
"""Hook called when creating a client to allow modification of configuration. """Hook called when creating a client to allow modification of configuration.
Expand All @@ -7572,8 +7560,10 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:


@abstractmethod @abstractmethod
async def connect_service_client( async def connect_service_client(
self, config: temporalio.service.ConnectConfig self,
) -> temporalio.service.ServiceClient: config: ConnectConfig,
next: Callable[[ConnectConfig], Awaitable[ServiceClient]],
) -> ServiceClient:
"""Hook called when connecting to the Temporal service. """Hook called when connecting to the Temporal service.


This method is called during service client connection and allows plugins This method is called during service client connection and allows plugins
Expand All @@ -7586,16 +7576,3 @@ async def connect_service_client(
Returns: Returns:
The connected service client. The connected service client.
""" """


class _RootPlugin(Plugin):
def init_client_plugin(self, next: Plugin) -> None:
raise NotImplementedError()

def configure_client(self, config: ClientConfig) -> ClientConfig:
return config

async def connect_service_client(
self, config: temporalio.service.ConnectConfig
) -> temporalio.service.ServiceClient:
return await temporalio.service.ServiceClient.connect(config)
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pydantic_core import to_json from pydantic_core import to_json
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict


from temporalio import activity from temporalio import activity, workflow
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
from temporalio.exceptions import ApplicationError from temporalio.exceptions import ApplicationError


Expand Down
168 changes: 47 additions & 121 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@
from agents.run import get_default_agent_runner, set_default_agent_runner from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider from agents.tracing.provider import DefaultTraceProvider
from openai.types.responses import ResponsePromptParam


import temporalio.client
import temporalio.worker
from temporalio.client import ClientConfig
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import ( from temporalio.contrib.openai_agents._openai_runner import (
Expand All @@ -47,13 +43,8 @@
DataConverter, DataConverter,
DefaultPayloadConverter, DefaultPayloadConverter,
) )
from temporalio.worker import ( from temporalio.plugin import SimplePlugin
Replayer, from temporalio.worker import WorkflowRunner
ReplayerConfig,
Worker,
WorkerConfig,
WorkflowReplayResult,
)
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner


# Unsupported on python 3.9 # Unsupported on python 3.9
Expand Down Expand Up @@ -172,7 +163,21 @@ def __init__(self) -> None:
super().__init__(ToJsonOptions(exclude_unset=True)) super().__init__(ToJsonOptions(exclude_unset=True))




class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
if converter is None:
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
elif converter.payload_converter_class is DefaultPayloadConverter:
return dataclasses.replace(
converter, payload_converter_class=OpenAIPayloadConverter
)
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
raise ValueError(
"The payload converter must be of type OpenAIPayloadConverter."
)
return converter


class OpenAIAgentsPlugin(SimplePlugin):
"""Temporal plugin for integrating OpenAI agents with Temporal workflows. """Temporal plugin for integrating OpenAI agents with Temporal workflows.


.. warning:: .. warning::
Expand Down Expand Up @@ -278,127 +283,48 @@ def __init__(
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
) )


self._model_params = model_params # Delay activity construction until they are actually needed
self._model_provider = model_provider def add_activities(
self._mcp_server_providers = mcp_server_providers activities: Optional[Sequence[Callable]],
self._register_activities = register_activities ) -> Sequence[Callable]:

if not register_activities:
def init_client_plugin(self, next: temporalio.client.Plugin) -> None: return activities or []
"""Set the next client plugin"""
self.next_client_plugin = next

async def connect_service_client(
self, config: temporalio.service.ConnectConfig
) -> temporalio.service.ServiceClient:
"""No modifications to service client"""
return await self.next_client_plugin.connect_service_client(config)

def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
"""Set the next worker plugin"""
self.next_worker_plugin = next

@staticmethod
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
if converter is None:
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
elif converter.payload_converter_class is DefaultPayloadConverter:
return dataclasses.replace(
converter, payload_converter_class=OpenAIPayloadConverter
)
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
raise ValueError(
"The payload converter must be of type OpenAIPayloadConverter."
)
return converter

def configure_client(self, config: ClientConfig) -> ClientConfig:
"""Configure the Temporal client for OpenAI agents integration.

This method sets up the Pydantic data converter to enable proper
serialization of OpenAI agent objects and responses.

Args:
config: The client configuration to modify.

Returns:
The modified client configuration.
"""
config["data_converter"] = self._data_converter(config["data_converter"])
return self.next_client_plugin.configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
"""Configure the Temporal worker for OpenAI agents integration.

This method adds the necessary interceptors and activities for OpenAI
agent execution:
- Adds tracing interceptors for OpenAI agent interactions
- Registers model execution activities

Args:
config: The worker configuration to modify.

Returns:
The modified worker configuration.
"""
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]


if self._register_activities: new_activities = [ModelActivity(model_provider).invoke_model_activity]
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]


server_names = [server.name for server in self._mcp_server_providers] server_names = [server.name for server in mcp_server_providers]
if len(server_names) != len(set(server_names)): if len(server_names) != len(set(server_names)):
raise ValueError( raise ValueError(
f"More than one mcp server registered with the same name. Please provide unique names." f"More than one mcp server registered with the same name. Please provide unique names."
) )


for mcp_server in self._mcp_server_providers: for mcp_server in mcp_server_providers:
new_activities.extend(mcp_server._get_activities()) new_activities.extend(mcp_server._get_activities())
config["activities"] = list(config.get("activities") or []) + new_activities return list(activities or []) + new_activities

def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner:
if not runner:
raise ValueError("No WorkflowRunner provided to the OpenAI plugin.")


runner = config.get("workflow_runner") # If in sandbox, add additional passthrough
if isinstance(runner, SandboxedWorkflowRunner): if isinstance(runner, SandboxedWorkflowRunner):
config["workflow_runner"] = dataclasses.replace( return dataclasses.replace(
runner, runner,
restrictions=runner.restrictions.with_passthrough_modules("mcp"), restrictions=runner.restrictions.with_passthrough_modules("mcp"),
) )

return runner
config["workflow_failure_exception_types"] = list(
config.get("workflow_failure_exception_types") or []
) + [AgentsWorkflowError]
return self.next_worker_plugin.configure_worker(config)

async def run_worker(self, worker: Worker) -> None:
"""Run the worker with OpenAI agents temporal overrides.

This method sets up the necessary runtime overrides for OpenAI agents
to work within the Temporal worker context, including custom runners
and trace providers.

Args:
worker: The worker instance to run.
"""
with set_open_ai_agent_temporal_overrides(self._model_params):
await self.next_worker_plugin.run_worker(worker)

def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
"""Configure the replayer for OpenAI Agents."""
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
config["data_converter"] = self._data_converter(config.get("data_converter"))
return self.next_worker_plugin.configure_replayer(config)


@asynccontextmanager @asynccontextmanager
async def run_replayer( async def run_context() -> AsyncIterator[None]:
self, with set_open_ai_agent_temporal_overrides(model_params):
replayer: Replayer, yield
histories: AsyncIterator[temporalio.client.WorkflowHistory],
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: super().__init__(
"""Set the OpenAI Overrides during replay""" name="OpenAIAgentsPlugin",
with set_open_ai_agent_temporal_overrides(self._model_params): data_converter=_data_converter,
async with self.next_worker_plugin.run_replayer( worker_interceptors=[OpenAIAgentsTracingInterceptor()],
replayer, histories activities=add_activities,
) as results: workflow_runner=workflow_runner,
yield results workflow_failure_exception_types=[AgentsWorkflowError],
run_context=lambda: run_context(),
)
Loading
Loading