From 2a9cec3b64a172d641b787482ae5cd05f0285e30 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 2 Oct 2025 10:36:34 -0700 Subject: [PATCH 01/13] Add static plugin constructor --- temporalio/plugin.py | 199 ++++++++++++++++++++++++++++++++++++++++++ tests/test_plugins.py | 34 ++++++++ 2 files changed, 233 insertions(+) create mode 100644 temporalio/plugin.py diff --git a/temporalio/plugin.py b/temporalio/plugin.py new file mode 100644 index 000000000..1cbc78782 --- /dev/null +++ b/temporalio/plugin.py @@ -0,0 +1,199 @@ +import abc +import dataclasses +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Callable, Optional, Sequence, Set, Type + +import temporalio.client +import temporalio.converter +import temporalio.worker +from temporalio.client import ClientConfig, WorkflowHistory +from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, +) +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + + +class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC): + pass + + +def create_plugin( + *, + data_converter: Optional[temporalio.converter.DataConverter] = None, + client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None, + activities: Optional[Sequence[Callable]] = None, + nexus_service_handlers: Optional[Sequence[Any]] = None, + workflows: Optional[Sequence[Type]] = None, + passthrough_modules: Optional[Set[str]] = None, + worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None, + workflow_failure_exception_types: Optional[Sequence[Type[BaseException]]] = None, + run_context: Optional[AbstractAsyncContextManager[None]] = None, +) -> Plugin: + return _StaticPlugin( + data_converter=data_converter, + client_interceptors=client_interceptors, + activities=activities, + nexus_service_handlers=nexus_service_handlers, + workflows=workflows, + passthrough_modules=passthrough_modules, + worker_interceptors=worker_interceptors, + workflow_failure_exception_types=workflow_failure_exception_types, + run_context=run_context, + ) + + +class _StaticPlugin(Plugin): + def __init__( + self, + *, + data_converter: Optional[temporalio.converter.DataConverter] = None, + client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None, + activities: Optional[Sequence[Callable]] = None, + nexus_service_handlers: Optional[Sequence[Any]] = None, + workflows: Optional[Sequence[Type]] = None, + passthrough_modules: Optional[Set[str]] = None, + worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None, + workflow_failure_exception_types: Optional[ + Sequence[Type[BaseException]] + ] = None, + run_context: Optional[AbstractAsyncContextManager[None]] = None, + ) -> None: + self.data_converter = data_converter + self.client_interceptors = client_interceptors + self.activities = activities + self.nexus_service_handlers = nexus_service_handlers + self.workflows = workflows + self.passthrough_modules = passthrough_modules + self.worker_interceptors = worker_interceptors + self.workflow_failure_exception_types = workflow_failure_exception_types + self.run_context = run_context + + def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: + self.next_worker_plugin = next + + def init_client_plugin(self, next: temporalio.client.Plugin) -> None: + self.next_client_plugin = next + + def configure_client(self, config: ClientConfig) -> ClientConfig: + if self.data_converter: + if not config["data_converter"] == temporalio.converter.default(): + raise ValueError( + "Static Plugin was configured with a data converter, but the client was as well." + ) + else: + config["data_converter"] = self.data_converter + + if self.client_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + self.client_interceptors + ) + + return self.next_client_plugin.configure_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + return await self.next_client_plugin.connect_service_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + if self.activities: + config["activities"] = list(config.get("activities", [])) + list( + self.activities + ) + + if self.nexus_service_handlers: + config["nexus_service_handlers"] = list( + config.get("nexus_service_handlers", []) + ) + list(self.nexus_service_handlers) + + if self.workflows: + config["workflows"] = list(config.get("workflows", [])) + list( + self.workflows + ) + + if self.passthrough_modules: + runner = config.get("workflow_runner") + if runner and isinstance(runner, SandboxedWorkflowRunner): + config["workflow_runner"] = dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + *self.passthrough_modules + ), + ) + + if self.worker_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + self.worker_interceptors + ) + + if self.workflow_failure_exception_types: + config["workflow_failure_exception_types"] = list( + config.get("workflow_failure_exception_types", []) + ) + list(self.workflow_failure_exception_types) + + return config + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + if self.data_converter: + if not config["data_converter"] == temporalio.converter.default(): + raise ValueError( + "Static Plugin was configured with a data converter, but the client was as well." + ) + else: + config["data_converter"] = self.data_converter + + if self.workflows: + config["workflows"] = list(config.get("workflows", [])) + list( + self.workflows + ) + + if self.passthrough_modules: + runner = config.get("workflow_runner") + if runner and isinstance(runner, SandboxedWorkflowRunner): + config["workflow_runner"] = dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + *self.passthrough_modules + ), + ) + + if self.worker_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + self.worker_interceptors + ) + + if self.workflow_failure_exception_types: + config["workflow_failure_exception_types"] = list( + config.get("workflow_failure_exception_types", []) + ) + list(self.workflow_failure_exception_types) + + return config + + async def run_worker(self, worker: Worker) -> None: + if self.run_context: + async with self.run_context: + await self.next_worker_plugin.run_worker(worker) + else: + await self.next_worker_plugin.run_worker(worker) + + @asynccontextmanager + async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[WorkflowHistory], + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + if self.run_context: + async with self.run_context: + async with self.next_worker_plugin.run_replayer( + replayer, histories + ) as results: + yield results + else: + async with self.next_worker_plugin.run_replayer( + replayer, histories + ) as results: + yield results diff --git a/tests/test_plugins.py b/tests/test_plugins.py index eb08bba2d..98ecde3db 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -11,6 +11,7 @@ from temporalio import workflow from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.plugin import create_plugin from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( Replayer, @@ -256,3 +257,36 @@ async def test_replay(client: Client) -> None: assert replayer.config().get("data_converter") == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) + +async def test_static_plugins(client: Client) -> None: + plugin = create_plugin( + data_converter=pydantic_data_converter, + workflows=[HelloWorkflow], + ) + config = client.config() + config["plugins"] = [plugin] + new_client = Client(**config) + + assert new_client.data_converter == pydantic_data_converter + + # Test without plugin registered in client + worker = Worker( + client, + task_queue="queue", + activities=[never_run_activity], + plugins=[plugin], + ) + assert worker.config().get("workflows") == [HelloWorkflow] + + # Test with plugin registered in client + worker = Worker( + new_client, + task_queue="queue", + activities=[never_run_activity], + plugins=[plugin], + ) + assert worker.config().get("workflows") == [HelloWorkflow] + + replayer = Replayer(workflows=[], plugins=[plugin]) + assert replayer.config().get("data_converter") == pydantic_data_converter + assert replayer.config().get("workflows") == [HelloWorkflow] From 60eb145089f40ca9c1dc7db124eea002ef4ce694 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 2 Oct 2025 11:38:12 -0700 Subject: [PATCH 02/13] Allow callables as parameters --- temporalio/plugin.py | 261 ++++++++++++++++++++++++++---------------- tests/test_plugins.py | 59 ++++++++-- 2 files changed, 212 insertions(+), 108 deletions(-) diff --git a/temporalio/plugin.py b/temporalio/plugin.py index 1cbc78782..44c7827ac 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -1,7 +1,17 @@ import abc -import dataclasses from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import Any, AsyncIterator, Callable, Optional, Sequence, Set, Type +from typing import ( + Any, + AsyncIterator, + Callable, + Optional, + Sequence, + Set, + Type, + TypeVar, + Union, + cast, +) import temporalio.client import temporalio.converter @@ -13,24 +23,35 @@ Worker, WorkerConfig, WorkflowReplayResult, + WorkflowRunner, ) -from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC): pass +T = TypeVar("T") + +PluginParameter = Union[None, T, Callable[[Optional[T]], T]] + + def create_plugin( *, - data_converter: Optional[temporalio.converter.DataConverter] = None, - client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None, - activities: Optional[Sequence[Callable]] = None, - nexus_service_handlers: Optional[Sequence[Any]] = None, - workflows: Optional[Sequence[Type]] = None, - passthrough_modules: Optional[Set[str]] = None, - worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None, - workflow_failure_exception_types: Optional[Sequence[Type[BaseException]]] = None, + data_converter: PluginParameter[temporalio.converter.DataConverter] = None, + client_interceptors: PluginParameter[ + Sequence[temporalio.client.Interceptor] + ] = None, + activities: PluginParameter[Sequence[Callable]] = None, + nexus_service_handlers: PluginParameter[Sequence[Any]] = None, + workflows: PluginParameter[Sequence[Type]] = None, + workflow_runner: PluginParameter[WorkflowRunner] = None, + worker_interceptors: PluginParameter[ + Sequence[temporalio.worker.Interceptor] + ] = None, + workflow_failure_exception_types: PluginParameter[ + Sequence[Type[BaseException]] + ] = None, run_context: Optional[AbstractAsyncContextManager[None]] = None, ) -> Plugin: return _StaticPlugin( @@ -39,7 +60,7 @@ def create_plugin( activities=activities, nexus_service_handlers=nexus_service_handlers, workflows=workflows, - passthrough_modules=passthrough_modules, + workflow_runner=workflow_runner, worker_interceptors=worker_interceptors, workflow_failure_exception_types=workflow_failure_exception_types, run_context=run_context, @@ -50,14 +71,18 @@ class _StaticPlugin(Plugin): def __init__( self, *, - data_converter: Optional[temporalio.converter.DataConverter] = None, - client_interceptors: Optional[Sequence[temporalio.client.Interceptor]] = None, - activities: Optional[Sequence[Callable]] = None, - nexus_service_handlers: Optional[Sequence[Any]] = None, - workflows: Optional[Sequence[Type]] = None, - passthrough_modules: Optional[Set[str]] = None, - worker_interceptors: Optional[Sequence[temporalio.worker.Interceptor]] = None, - workflow_failure_exception_types: Optional[ + data_converter: PluginParameter[temporalio.converter.DataConverter] = None, + client_interceptors: PluginParameter[ + Sequence[temporalio.client.Interceptor] + ] = None, + activities: PluginParameter[Sequence[Callable]] = None, + nexus_service_handlers: PluginParameter[Sequence[Any]] = None, + workflows: PluginParameter[Sequence[Type]] = None, + workflow_runner: PluginParameter[WorkflowRunner] = None, + worker_interceptors: PluginParameter[ + Sequence[temporalio.worker.Interceptor] + ] = None, + workflow_failure_exception_types: PluginParameter[ Sequence[Type[BaseException]] ] = None, run_context: Optional[AbstractAsyncContextManager[None]] = None, @@ -67,7 +92,7 @@ def __init__( self.activities = activities self.nexus_service_handlers = nexus_service_handlers self.workflows = workflows - self.passthrough_modules = passthrough_modules + self.workflow_runner = workflow_runner self.worker_interceptors = worker_interceptors self.workflow_failure_exception_types = workflow_failure_exception_types self.run_context = run_context @@ -79,18 +104,18 @@ def init_client_plugin(self, next: temporalio.client.Plugin) -> None: self.next_client_plugin = next def configure_client(self, config: ClientConfig) -> ClientConfig: - if self.data_converter: - if not config["data_converter"] == temporalio.converter.default(): - raise ValueError( - "Static Plugin was configured with a data converter, but the client was as well." - ) - else: - config["data_converter"] = self.data_converter - - if self.client_interceptors: - config["interceptors"] = list(config.get("interceptors", [])) + list( - self.client_interceptors - ) + self._set_dict( + config, # type: ignore + "data_converter", + self._resolve_parameter(config.get("data_converter"), self.data_converter), + ) + self._set_dict( + config, # type: ignore + "interceptors", + self._resolve_append_parameter( + config.get("interceptors"), self.client_interceptors + ), + ) return self.next_client_plugin.configure_client(config) @@ -100,77 +125,82 @@ async def connect_service_client( return await self.next_client_plugin.connect_service_client(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - if self.activities: - config["activities"] = list(config.get("activities", [])) + list( - self.activities - ) - - if self.nexus_service_handlers: - config["nexus_service_handlers"] = list( - config.get("nexus_service_handlers", []) - ) + list(self.nexus_service_handlers) - - if self.workflows: - config["workflows"] = list(config.get("workflows", [])) + list( - self.workflows - ) - - if self.passthrough_modules: - runner = config.get("workflow_runner") - if runner and isinstance(runner, SandboxedWorkflowRunner): - config["workflow_runner"] = dataclasses.replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules( - *self.passthrough_modules - ), - ) - - if self.worker_interceptors: - config["interceptors"] = list(config.get("interceptors", [])) + list( - self.worker_interceptors - ) - - if self.workflow_failure_exception_types: - config["workflow_failure_exception_types"] = list( - config.get("workflow_failure_exception_types", []) - ) + list(self.workflow_failure_exception_types) + self._set_dict( + config, # type: ignore + "activities", + self._resolve_append_parameter(config.get("activities"), self.activities), + ) + self._set_dict( + config, # type: ignore + "nexus_service_handlers", + self._resolve_append_parameter( + config.get("nexus_service_handlers"), self.nexus_service_handlers + ), + ) + self._set_dict( + config, # type: ignore + "workflows", + self._resolve_append_parameter(config.get("workflows"), self.workflows), + ) + + self._set_dict( + config, # type: ignore + "workflow_runner", + self._resolve_parameter( + config.get("workflow_runner"), self.workflow_runner + ), + ) + self._set_dict( + config, # type: ignore + "interceptors", + self._resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors + ), + ) + self._set_dict( + config, # type: ignore + "workflow_failure_exception_types", + self._resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, + ), + ) return config def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - if self.data_converter: - if not config["data_converter"] == temporalio.converter.default(): - raise ValueError( - "Static Plugin was configured with a data converter, but the client was as well." - ) - else: - config["data_converter"] = self.data_converter - - if self.workflows: - config["workflows"] = list(config.get("workflows", [])) + list( - self.workflows - ) - - if self.passthrough_modules: - runner = config.get("workflow_runner") - if runner and isinstance(runner, SandboxedWorkflowRunner): - config["workflow_runner"] = dataclasses.replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules( - *self.passthrough_modules - ), - ) - - if self.worker_interceptors: - config["interceptors"] = list(config.get("interceptors", [])) + list( - self.worker_interceptors - ) - - if self.workflow_failure_exception_types: - config["workflow_failure_exception_types"] = list( - config.get("workflow_failure_exception_types", []) - ) + list(self.workflow_failure_exception_types) - + self._set_dict( + config, # type: ignore + "data_converter", + self._resolve_parameter(config.get("data_converter"), self.data_converter), + ) + self._set_dict( + config, # type: ignore + "workflows", + self._resolve_append_parameter(config.get("workflows"), self.workflows), + ) + self._set_dict( + config, # type: ignore + "workflow_runner", + self._resolve_parameter( + config.get("workflow_runner"), self.workflow_runner + ), + ) + self._set_dict( + config, # type: ignore + "interceptors", + self._resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors + ), + ) + self._set_dict( + config, # type: ignore + "workflow_failure_exception_types", + self._resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, + ), + ) return config async def run_worker(self, worker: Worker) -> None: @@ -197,3 +227,34 @@ async def run_replayer( replayer, histories ) as results: yield results + + @staticmethod + def _resolve_parameter( + existing: Optional[T], parameter: PluginParameter[T] + ) -> Optional[T]: + if parameter is None: + return existing + elif callable(parameter): + return cast(Callable[[Optional[T]], Optional[T]], parameter)(existing) + else: + return parameter + + @staticmethod + def _resolve_append_parameter( + existing: Optional[Sequence[T]], parameter: PluginParameter[Sequence[T]] + ) -> Optional[Sequence[T]]: + if parameter is None: + return existing + elif callable(parameter): + return cast( + Callable[[Optional[Sequence[T]]], Optional[Sequence[T]]], parameter + )(existing) + else: + return list(existing or []) + list(parameter) + + @staticmethod + def _set_dict(config: dict[str, Any], key: str, value: Optional[Any]) -> None: + if value is not None: + config[key] = value + else: + del config[key] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 98ecde3db..dd912ad0d 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -2,15 +2,17 @@ import uuid import warnings from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import AsyncIterator, cast +from typing import AsyncIterator, Optional, cast import pytest import temporalio.client +import temporalio.converter import temporalio.worker from temporalio import workflow from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.converter import DataConverter from temporalio.plugin import create_plugin from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( @@ -237,6 +239,11 @@ class HelloWorkflow: async def run(self, name: str) -> str: return f"Hello, {name}!" +@workflow.defn +class HelloWorkflow2: + @workflow.run + async def run(self, name: str) -> str: + return f"Hello, {name}!" async def test_replay(client: Client) -> None: plugin = ReplayCheckPlugin() @@ -258,10 +265,11 @@ async def test_replay(client: Client) -> None: await replayer.replay_workflow(await handle.fetch_history()) + async def test_static_plugins(client: Client) -> None: plugin = create_plugin( data_converter=pydantic_data_converter, - workflows=[HelloWorkflow], + workflows=[HelloWorkflow2], ) config = client.config() config["plugins"] = [plugin] @@ -270,23 +278,58 @@ async def test_static_plugins(client: Client) -> None: assert new_client.data_converter == pydantic_data_converter # Test without plugin registered in client - worker = Worker( + worker = Worker( client, task_queue="queue", activities=[never_run_activity], + workflows=[HelloWorkflow], plugins=[plugin], ) - assert worker.config().get("workflows") == [HelloWorkflow] + # On a sequence, a value is appended + assert worker.config().get("workflows") == [HelloWorkflow, HelloWorkflow2] # Test with plugin registered in client - worker = Worker( + worker = Worker( new_client, task_queue="queue", activities=[never_run_activity], plugins=[plugin], ) - assert worker.config().get("workflows") == [HelloWorkflow] + assert worker.config().get("workflows") == [HelloWorkflow2] - replayer = Replayer(workflows=[], plugins=[plugin]) + replayer = Replayer(workflows=[HelloWorkflow], plugins=[plugin]) assert replayer.config().get("data_converter") == pydantic_data_converter - assert replayer.config().get("workflows") == [HelloWorkflow] + assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2] + + +async def test_static_plugins_callables(client: Client) -> None: + def converter(old: Optional[DataConverter]): + if old != temporalio.converter.default(): + raise ValueError("Can't override non-default converter") + return pydantic_data_converter + + plugin = create_plugin( + data_converter=converter, + ) + config = client.config() + config["plugins"] = [plugin] + new_client = Client(**config) + + assert new_client.data_converter == pydantic_data_converter + + with pytest.raises(ValueError): + config["data_converter"] = pydantic_data_converter + Client(**config) + + # On a sequence, the lambda overrides the existing values + plugin = create_plugin( + workflows=lambda workflows: [], + ) + worker = Worker( + client, + task_queue="queue", + workflows=[HelloWorkflow], + activities=[never_run_activity], + plugins=[plugin], + ) + assert worker.config().get("workflows") == [] From aeac7e2641acdebba86d31d7d2d5682c19bf88db Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 2 Oct 2025 12:17:48 -0700 Subject: [PATCH 03/13] Convert OpenAI plugin as example --- .../openai_agents/_temporal_openai_agents.py | 268 +++++------------- temporalio/plugin.py | 58 ++-- tests/test_plugins.py | 2 + 3 files changed, 109 insertions(+), 219 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index d893affde..c86908020 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -47,12 +47,14 @@ DataConverter, DefaultPayloadConverter, ) +from temporalio.plugin import Plugin, create_plugin from temporalio.worker import ( Replayer, ReplayerConfig, Worker, WorkerConfig, WorkflowReplayResult, + WorkflowRunner, ) from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -172,24 +174,28 @@ def __init__(self) -> None: super().__init__(ToJsonOptions(exclude_unset=True)) -class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - """Temporal plugin for integrating OpenAI agents with Temporal workflows. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - - This plugin provides seamless integration between the OpenAI Agents SDK and - Temporal workflows. It automatically configures the necessary interceptors, - activities, and data converters to enable OpenAI agents to run within - Temporal workflows with proper tracing and model execution. - - The plugin: - 1. Configures the Pydantic data converter for type-safe serialization - 2. Sets up tracing interceptors for OpenAI agent interactions - 3. Registers model execution activities - 4. Automatically registers MCP server activities and manages their lifecycles - 5. Manages the OpenAI agent runtime overrides during worker execution +def _data_converter(converter: Optional[DataConverter]) -> DataConverter: + if converter is None: + return DataConverter(payload_converter_class=OpenAIPayloadConverter) + elif converter.payload_converter_class is DefaultPayloadConverter: + return dataclasses.replace( + converter, payload_converter_class=OpenAIPayloadConverter + ) + elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): + raise ValueError( + "The payload converter must be of type OpenAIPayloadConverter." + ) + return converter + + +def OpenAIAgentsPlugin( + model_params: Optional[ModelActivityParameters] = None, + model_provider: Optional[ModelProvider] = None, + mcp_server_providers: Sequence[ + Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] + ] = (), +) -> Plugin: + """Create an OpenAI agents plugin. Args: model_params: Configuration parameters for Temporal activity execution @@ -197,201 +203,59 @@ class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): model_provider: Optional model provider for custom model implementations. Useful for testing or custom model integrations. mcp_server_providers: Sequence of MCP servers to automatically register with the worker. - The plugin will wrap each server in a TemporalMCPServer if needed and - manage their connection lifecycles tied to the worker lifetime. This is - the recommended way to use MCP servers with Temporal workflows. - - Example: - >>> from temporalio.client import Client - >>> from temporalio.worker import Worker - >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider - >>> from agents.mcp import MCPServerStdio - >>> from datetime import timedelta - >>> - >>> # Configure model parameters - >>> model_params = ModelActivityParameters( - ... start_to_close_timeout=timedelta(seconds=30), - ... retry_policy=RetryPolicy(maximum_attempts=3) - ... ) - >>> - >>> # Create MCP servers - >>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio( - ... name="Filesystem Server", - ... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]} - ... )) - >>> - >>> # Create plugin with MCP servers - >>> plugin = OpenAIAgentsPlugin( - ... model_params=model_params, - ... mcp_server_providers=[filesystem_server] - ... ) - >>> - >>> # Use with client and worker - >>> client = await Client.connect( - ... "localhost:7233", - ... plugins=[plugin] - ... ) - >>> worker = Worker( - ... client, - ... task_queue="my-task-queue", - ... workflows=[MyWorkflow], - ... ) + Each server will be wrapped in a TemporalMCPServer if not already wrapped, + and their activities will be automatically registered with the worker. + The plugin manages the connection lifecycle of these servers. """ - - def __init__( - self, - model_params: Optional[ModelActivityParameters] = None, - model_provider: Optional[ModelProvider] = None, - mcp_server_providers: Sequence[ - Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] - ] = (), - ) -> None: - """Initialize the OpenAI agents plugin. - - Args: - model_params: Configuration parameters for Temporal activity execution - of model calls. If None, default parameters will be used. - model_provider: Optional model provider for custom model implementations. - Useful for testing or custom model integrations. - mcp_server_providers: Sequence of MCP servers to automatically register with the worker. - Each server will be wrapped in a TemporalMCPServer if not already wrapped, - and their activities will be automatically registered with the worker. - The plugin manages the connection lifecycle of these servers. - """ - if model_params is None: - model_params = ModelActivityParameters() - - # For the default provider, we provide a default start_to_close_timeout of 60 seconds. - # Other providers will need to define their own. - if ( - model_params.start_to_close_timeout is None - and model_params.schedule_to_close_timeout is None - ): - if model_provider is None: - model_params.start_to_close_timeout = timedelta(seconds=60) - else: - raise ValueError( - "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" - ) - - self._model_params = model_params - self._model_provider = model_provider - self._mcp_server_providers = mcp_server_providers - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - """Set the next client plugin""" - self.next_client_plugin = next - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - """No modifications to service client""" - return await self.next_client_plugin.connect_service_client(config) - - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - """Set the next worker plugin""" - self.next_worker_plugin = next - - @staticmethod - def _data_converter(converter: Optional[DataConverter]) -> DataConverter: - if converter is None: - return DataConverter(payload_converter_class=OpenAIPayloadConverter) - elif converter.payload_converter_class is DefaultPayloadConverter: - return dataclasses.replace( - converter, payload_converter_class=OpenAIPayloadConverter - ) - elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): + if model_params is None: + model_params = ModelActivityParameters() + + # For the default provider, we provide a default start_to_close_timeout of 60 seconds. + # Other providers will need to define their own. + if ( + model_params.start_to_close_timeout is None + and model_params.schedule_to_close_timeout is None + ): + if model_provider is None: + model_params.start_to_close_timeout = timedelta(seconds=60) + else: raise ValueError( - "The payload converter must be of type OpenAIPayloadConverter." + "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" ) - return converter - - def configure_client(self, config: ClientConfig) -> ClientConfig: - """Configure the Temporal client for OpenAI agents integration. - - This method sets up the Pydantic data converter to enable proper - serialization of OpenAI agent objects and responses. - - Args: - config: The client configuration to modify. - - Returns: - The modified client configuration. - """ - config["data_converter"] = self._data_converter(config["data_converter"]) - return self.next_client_plugin.configure_client(config) - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - """Configure the Temporal worker for OpenAI agents integration. + new_activities = [ModelActivity(model_provider).invoke_model_activity] - This method adds the necessary interceptors and activities for OpenAI - agent execution: - - Adds tracing interceptors for OpenAI agent interactions - - Registers model execution activities + server_names = [server.name for server in mcp_server_providers] + if len(server_names) != len(set(server_names)): + raise ValueError( + f"More than one mcp server registered with the same name. Please provide unique names." + ) - Args: - config: The worker configuration to modify. + for mcp_server in mcp_server_providers: + new_activities.extend(mcp_server._get_activities()) - Returns: - The modified worker configuration. - """ - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] - new_activities = [ModelActivity(self._model_provider).invoke_model_activity] + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") - server_names = [server.name for server in self._mcp_server_providers] - if len(server_names) != len(set(server_names)): - raise ValueError( - f"More than one mcp server registered with the same name. Please provide unique names." - ) - - for mcp_server in self._mcp_server_providers: - new_activities.extend(mcp_server._get_activities()) - config["activities"] = list(config.get("activities") or []) + new_activities - - runner = config.get("workflow_runner") + # If in sandbox, add additional passthrough if isinstance(runner, SandboxedWorkflowRunner): - config["workflow_runner"] = dataclasses.replace( + return dataclasses.replace( runner, restrictions=runner.restrictions.with_passthrough_modules("mcp"), ) - - config["workflow_failure_exception_types"] = list( - config.get("workflow_failure_exception_types") or [] - ) + [AgentsWorkflowError] - return self.next_worker_plugin.configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - """Run the worker with OpenAI agents temporal overrides. - - This method sets up the necessary runtime overrides for OpenAI agents - to work within the Temporal worker context, including custom runners - and trace providers. - - Args: - worker: The worker instance to run. - """ - with set_open_ai_agent_temporal_overrides(self._model_params): - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - """Configure the replayer for OpenAI Agents.""" - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] - config["data_converter"] = self._data_converter(config.get("data_converter")) - return self.next_worker_plugin.configure_replayer(config) + return runner @asynccontextmanager - async def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - """Set the OpenAI Overrides during replay""" - with set_open_ai_agent_temporal_overrides(self._model_params): - async with self.next_worker_plugin.run_replayer( - replayer, histories - ) as results: - yield results + async def run_context() -> AsyncIterator[None]: + with set_open_ai_agent_temporal_overrides(model_params): + yield + + return create_plugin( + data_converter=_data_converter, + worker_interceptors=[OpenAIAgentsTracingInterceptor()], + activities=new_activities, + workflow_runner=workflow_runner, + workflow_failure_exception_types=[AgentsWorkflowError], + run_context=lambda: run_context(), + ) diff --git a/temporalio/plugin.py b/temporalio/plugin.py index 44c7827ac..e2c0aec6e 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -1,3 +1,9 @@ +"""Plugin module for Temporal SDK. + +This module provides plugin functionality that allows customization of both client +and worker behavior in the Temporal SDK through configurable parameters. +""" + import abc from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import ( @@ -6,7 +12,6 @@ Callable, Optional, Sequence, - Set, Type, TypeVar, Union, @@ -28,6 +33,12 @@ class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC): + """Abstract base class for Temporal plugins. + + This class combines both client and worker plugin capabilities, + just used where multiple inheritance is not possible. + """ + pass @@ -52,8 +63,25 @@ def create_plugin( workflow_failure_exception_types: PluginParameter[ Sequence[Type[BaseException]] ] = None, - run_context: Optional[AbstractAsyncContextManager[None]] = None, + run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]] = None, ) -> Plugin: + """Create a static plugin with configurable parameters. + + Args: + data_converter: Data converter for serialization, or callable to customize existing one. + client_interceptors: Client interceptors to append, or callable to customize existing ones. + activities: Activity functions to append, or callable to customize existing ones. + nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones. + workflows: Workflow classes to append, or callable to customize existing ones. + workflow_runner: Workflow runner, or callable to customize existing one. + worker_interceptors: Worker interceptors to append, or callable to customize existing ones. + workflow_failure_exception_types: Exception types for workflow failures to append, + or callable to customize existing ones. + run_context: Optional async context manager producer to wrap worker/replayer execution. + + Returns: + A configured Plugin instance. + """ return _StaticPlugin( data_converter=data_converter, client_interceptors=client_interceptors, @@ -71,21 +99,17 @@ class _StaticPlugin(Plugin): def __init__( self, *, - data_converter: PluginParameter[temporalio.converter.DataConverter] = None, - client_interceptors: PluginParameter[ - Sequence[temporalio.client.Interceptor] - ] = None, - activities: PluginParameter[Sequence[Callable]] = None, - nexus_service_handlers: PluginParameter[Sequence[Any]] = None, - workflows: PluginParameter[Sequence[Type]] = None, - workflow_runner: PluginParameter[WorkflowRunner] = None, - worker_interceptors: PluginParameter[ - Sequence[temporalio.worker.Interceptor] - ] = None, + data_converter: PluginParameter[temporalio.converter.DataConverter], + client_interceptors: PluginParameter[Sequence[temporalio.client.Interceptor]], + activities: PluginParameter[Sequence[Callable]], + nexus_service_handlers: PluginParameter[Sequence[Any]], + workflows: PluginParameter[Sequence[Type]], + workflow_runner: PluginParameter[WorkflowRunner], + worker_interceptors: PluginParameter[Sequence[temporalio.worker.Interceptor]], workflow_failure_exception_types: PluginParameter[ Sequence[Type[BaseException]] - ] = None, - run_context: Optional[AbstractAsyncContextManager[None]] = None, + ], + run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]], ) -> None: self.data_converter = data_converter self.client_interceptors = client_interceptors @@ -205,7 +229,7 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: async def run_worker(self, worker: Worker) -> None: if self.run_context: - async with self.run_context: + async with self.run_context(): await self.next_worker_plugin.run_worker(worker) else: await self.next_worker_plugin.run_worker(worker) @@ -217,7 +241,7 @@ async def run_replayer( histories: AsyncIterator[WorkflowHistory], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: if self.run_context: - async with self.run_context: + async with self.run_context(): async with self.next_worker_plugin.run_replayer( replayer, histories ) as results: diff --git a/tests/test_plugins.py b/tests/test_plugins.py index dd912ad0d..71eaa3bce 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -239,12 +239,14 @@ class HelloWorkflow: async def run(self, name: str) -> str: return f"Hello, {name}!" + @workflow.defn class HelloWorkflow2: @workflow.run async def run(self, name: str) -> str: return f"Hello, {name}!" + async def test_replay(client: Client) -> None: plugin = ReplayCheckPlugin() new_config = client.config() From 984a50d28bd34b6d46681891f5a11c4e6e093ee3 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 2 Oct 2025 12:58:12 -0700 Subject: [PATCH 04/13] Delay openai client creation --- .../contrib/openai_agents/_invoke_model_activity.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 6409f0bb1..aec84350d 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -155,14 +155,16 @@ class ModelActivity: def __init__(self, model_provider: Optional[ModelProvider] = None): """Initialize the activity with a model provider.""" - self._model_provider = model_provider or OpenAIProvider( - openai_client=AsyncOpenAI(max_retries=0) - ) + self._model_provider = model_provider @activity.defn @_auto_heartbeater async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse: """Activity that invokes a model with the given input.""" + if not self._model_provider: + self._model_provider = OpenAIProvider( + openai_client=AsyncOpenAI(max_retries=0) + ) model = self._model_provider.get_model(input.get("model_name")) async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: From 56a92de675b287841642b179ec42bb1ffa2d7003 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 3 Oct 2025 13:39:34 -0700 Subject: [PATCH 05/13] Change plugin structure to remove initializers --- temporalio/client.py | 53 +++------ .../openai_agents/_temporal_openai_agents.py | 1 + temporalio/plugin.py | 42 ++++--- temporalio/worker/_plugin.py | 45 ++------ temporalio/worker/_replayer.py | 20 ++-- temporalio/worker/_worker.py | 21 ++-- tests/test_plugins.py | 103 ++++++++++-------- 7 files changed, 131 insertions(+), 154 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 20a9b3c6d..5208c99a0 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -70,11 +70,13 @@ WorkflowSerializationContext, ) from temporalio.service import ( + ConnectConfig, HttpConnectProxyConfig, KeepAliveConfig, RetryConfig, RPCError, RPCStatusCode, + ServiceClient, TLSConfig, ) @@ -198,12 +200,14 @@ async def connect( http_connect_proxy_config=http_connect_proxy_config, ) - root_plugin: Plugin = _RootPlugin() + def make_lambda(plugin, next): + return lambda config: plugin.connect_service_client(config, next) + + next_function = lambda config: ServiceClient.connect(config) for plugin in reversed(plugins): - plugin.init_client_plugin(root_plugin) - root_plugin = plugin + next_function = make_lambda(plugin, next_function) - service_client = await root_plugin.connect_service_client(connect_config) + service_client = await next_function(connect_config) return Client( service_client, @@ -243,12 +247,10 @@ def __init__( plugins=plugins, ) - root_plugin: Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_client_plugin(root_plugin) - root_plugin = plugin + for plugin in plugins: + config = plugin.configure_client(config) - self._init_from_config(root_plugin.configure_client(config)) + self._init_from_config(config) def _init_from_config(self, config: ClientConfig): self._config = config @@ -7541,20 +7543,6 @@ def name(self) -> str: """ return type(self).__module__ + "." + type(self).__qualname__ - @abstractmethod - def init_client_plugin(self, next: Plugin) -> None: - """Initialize this plugin in the plugin chain. - - This method sets up the chain of responsibility pattern by providing a reference - to the next plugin in the chain. It is called during client creation to build - the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`. - Implementations should store this reference and call the corresponding method - of the next plugin on method calls. - - Args: - next: The next plugin in the chain to delegate to. - """ - @abstractmethod def configure_client(self, config: ClientConfig) -> ClientConfig: """Hook called when creating a client to allow modification of configuration. @@ -7572,8 +7560,10 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: @abstractmethod async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: """Hook called when connecting to the Temporal service. This method is called during service client connection and allows plugins @@ -7586,16 +7576,3 @@ async def connect_service_client( Returns: The connected service client. """ - - -class _RootPlugin(Plugin): - def init_client_plugin(self, next: Plugin) -> None: - raise NotImplementedError() - - def configure_client(self, config: ClientConfig) -> ClientConfig: - return config - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - return await temporalio.service.ServiceClient.connect(config) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index c86908020..0b6ed9ea3 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -252,6 +252,7 @@ async def run_context() -> AsyncIterator[None]: yield return create_plugin( + name="OpenAIAgentsPlugin", data_converter=_data_converter, worker_interceptors=[OpenAIAgentsTracingInterceptor()], activities=new_activities, diff --git a/temporalio/plugin.py b/temporalio/plugin.py index e2c0aec6e..f103564d6 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -9,6 +9,7 @@ from typing import ( Any, AsyncIterator, + Awaitable, Callable, Optional, Sequence, @@ -22,6 +23,7 @@ import temporalio.converter import temporalio.worker from temporalio.client import ClientConfig, WorkflowHistory +from temporalio.service import ConnectConfig, ServiceClient from temporalio.worker import ( Replayer, ReplayerConfig, @@ -48,6 +50,7 @@ class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC): def create_plugin( + name: str, *, data_converter: PluginParameter[temporalio.converter.DataConverter] = None, client_interceptors: PluginParameter[ @@ -68,6 +71,7 @@ def create_plugin( """Create a static plugin with configurable parameters. Args: + name: The name of the plugin. data_converter: Data converter for serialization, or callable to customize existing one. client_interceptors: Client interceptors to append, or callable to customize existing ones. activities: Activity functions to append, or callable to customize existing ones. @@ -83,6 +87,7 @@ def create_plugin( A configured Plugin instance. """ return _StaticPlugin( + name=name, data_converter=data_converter, client_interceptors=client_interceptors, activities=activities, @@ -98,6 +103,7 @@ def create_plugin( class _StaticPlugin(Plugin): def __init__( self, + name: str, *, data_converter: PluginParameter[temporalio.converter.DataConverter], client_interceptors: PluginParameter[Sequence[temporalio.client.Interceptor]], @@ -111,6 +117,7 @@ def __init__( ], run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]], ) -> None: + self._name = name self.data_converter = data_converter self.client_interceptors = client_interceptors self.activities = activities @@ -121,11 +128,8 @@ def __init__( self.workflow_failure_exception_types = workflow_failure_exception_types self.run_context = run_context - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - self.next_client_plugin = next + def name(self) -> str: + return self._name def configure_client(self, config: ClientConfig) -> ClientConfig: self._set_dict( @@ -141,12 +145,14 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: ), ) - return self.next_client_plugin.configure_client(config) + return config async def connect_service_client( - self, config: temporalio.service.ConnectConfig + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], ) -> temporalio.service.ServiceClient: - return await self.next_client_plugin.connect_service_client(config) + return await next(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: self._set_dict( @@ -227,29 +233,31 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: ) return config - async def run_worker(self, worker: Worker) -> None: + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: if self.run_context: async with self.run_context(): - await self.next_worker_plugin.run_worker(worker) + await next(worker) else: - await self.next_worker_plugin.run_worker(worker) + await next(worker) @asynccontextmanager async def run_replayer( self, replayer: Replayer, histories: AsyncIterator[WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: if self.run_context: async with self.run_context(): - async with self.next_worker_plugin.run_replayer( - replayer, histories - ) as results: + async with next(replayer, histories) as results: yield results else: - async with self.next_worker_plugin.run_replayer( - replayer, histories - ) as results: + async with next(replayer, histories) as results: yield results @staticmethod diff --git a/temporalio/worker/_plugin.py b/temporalio/worker/_plugin.py index 0e696a2dd..2f4c04a1e 100644 --- a/temporalio/worker/_plugin.py +++ b/temporalio/worker/_plugin.py @@ -2,7 +2,7 @@ import abc from contextlib import AbstractAsyncContextManager -from typing import TYPE_CHECKING, AsyncIterator +from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Callable from temporalio.client import WorkflowHistory @@ -34,19 +34,6 @@ def name(self) -> str: """ return type(self).__module__ + "." + type(self).__qualname__ - @abc.abstractmethod - def init_worker_plugin(self, next: Plugin) -> None: - """Initialize this plugin in the plugin chain. - - This method sets up the chain of responsibility pattern by providing a reference - to the next plugin in the chain. It is called during worker creation to build - the plugin chain. Implementations should store this reference and call the corresponding method - of the next plugin on method calls. - - Args: - next: The next plugin in the chain to delegate to. - """ - @abc.abstractmethod def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """Hook called when creating a worker to allow modification of configuration. @@ -64,7 +51,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """ @abc.abstractmethod - async def run_worker(self, worker: Worker) -> None: + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: """Hook called when running a worker to allow interception of execution. This method is called when the worker is started and allows plugins to @@ -73,6 +62,7 @@ async def run_worker(self, worker: Worker) -> None: Args: worker: The worker instance to run. + next: Callable to continue the worker execution. """ @abc.abstractmethod @@ -94,26 +84,9 @@ def run_replayer( self, replayer: Replayer, histories: AsyncIterator[WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: """Hook called when running a replayer to allow interception of execution.""" - - -class _RootPlugin(Plugin): - def init_worker_plugin(self, next: Plugin) -> None: - raise NotImplementedError() - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - return config - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return config - - async def run_worker(self, worker: Worker) -> None: - await worker._run() - - def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return replayer._workflow_replay_iterator(histories) diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 664602219..72ffa4b56 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -21,7 +21,6 @@ from ..common import HeaderCodecBehavior from ._interceptor import Interceptor -from ._plugin import _RootPlugin from ._worker import load_default_build_id from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -84,12 +83,9 @@ def __init__( ) # Apply plugin configuration - root_plugin: temporalio.worker.Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_worker_plugin(root_plugin) - root_plugin = plugin - self._config = root_plugin.configure_replayer(self._config) - self._plugin = root_plugin + self.plugins = plugins + for plugin in plugins: + self._config = plugin.configure_replayer(self._config) # Validate workflows after plugin configuration if not self._config["workflows"]: @@ -176,7 +172,15 @@ def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ - return self._plugin.run_replayer(self, histories) + + def make_lambda(plugin, next): + return lambda r, hs: plugin.run_replayer(r, hs, next) + + next_function = lambda r, hs: r._workflow_replay_iterator(hs) + for plugin in reversed(self.plugins): + next_function = make_lambda(plugin, next_function) + + return next_function(self, histories) @asynccontextmanager async def _workflow_replay_iterator( diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index dea516da8..ad238507c 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -38,7 +38,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor from ._nexus import _NexusWorker -from ._plugin import Plugin, _RootPlugin +from ._plugin import Plugin from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -377,12 +377,9 @@ def __init__( ) plugins = plugins_from_client + list(plugins) - root_plugin: Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_worker_plugin(root_plugin) - root_plugin = plugin - config = root_plugin.configure_worker(config) - self._plugin = root_plugin + self.plugins = plugins + for plugin in plugins: + config = plugin.configure_worker(config) self._init_from_config(client, config) @@ -690,7 +687,15 @@ async def run(self) -> None: also cancel the shutdown process. Therefore users are encouraged to use explicit shutdown instead. """ - await self._plugin.run_worker(self) + + def make_lambda(plugin, next): + return lambda w: plugin.run_worker(w, next) + + next_function = lambda w: w._run() + for plugin in reversed(self.plugins): + next_function = make_lambda(plugin, next_function) + + await next_function(self) async def _run(self): # Eagerly validate which will do a namespace check in Core diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 71eaa3bce..363bebb59 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -2,7 +2,7 @@ import uuid import warnings from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import AsyncIterator, Optional, cast +from typing import AsyncIterator, Awaitable, Callable, Optional, cast import pytest @@ -10,10 +10,11 @@ import temporalio.converter import temporalio.worker from temporalio import workflow -from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin +from temporalio.client import Client, ClientConfig, OutboundInterceptor, WorkflowHistory from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter from temporalio.plugin import create_plugin +from temporalio.service import ConnectConfig, ServiceClient from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( Replayer, @@ -40,21 +41,20 @@ class MyClientPlugin(temporalio.client.Plugin): def __init__(self): self.interceptor = TestClientInterceptor() - def init_client_plugin(self, next: Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: config["namespace"] = "replaced_namespace" config["interceptors"] = list(config.get("interceptors") or []) + [ self.interceptor ] - return self.next_client_plugin.configure_client(config) + return config async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: config.api_key = "replaced key" - return await self.next_client_plugin.connect_service_client(config) + return await next(config) async def test_client_plugin(client: Client, env: WorkflowEnvironment): @@ -76,42 +76,41 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment): class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: - return self.next_client_plugin.configure_client(config) + return config def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "combined" - return self.next_worker_plugin.configure_worker(config) + return config async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - return await self.next_client_plugin.connect_service_client(config) + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: + return await next(config) - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return self.next_worker_plugin.configure_replayer(config) + return config def run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return self.next_worker_plugin.run_replayer(replayer, histories) + return next(replayer, histories) class MyWorkerPlugin(temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" runner = config.get("workflow_runner") @@ -120,20 +119,26 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: runner, restrictions=runner.restrictions.with_passthrough_modules("my_module"), ) - return self.next_worker_plugin.configure_worker(config) + return config - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return self.next_worker_plugin.configure_replayer(config) + return config def run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return self.next_worker_plugin.run_replayer(replayer, histories) + return next(replayer, histories) async def test_worker_plugin_basic_config(client: Client) -> None: @@ -196,40 +201,42 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: config["data_converter"] = pydantic_data_converter - return self.next_client_plugin.configure_client(config) + return config def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] - return self.next_worker_plugin.configure_worker(config) + return config def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: config["data_converter"] = pydantic_data_converter config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] - return self.next_worker_plugin.configure_replayer(config) + return config - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) async def connect_service_client( - self, config: temporalio.service.ConnectConfig + self, + config: temporalio.service.ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], ) -> temporalio.service.ServiceClient: - return await self.next_client_plugin.connect_service_client(config) + return await next(config) @asynccontextmanager async def run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - async with self.next_worker_plugin.run_replayer(replayer, histories) as result: + async with next(replayer, histories) as result: yield result @@ -270,6 +277,7 @@ async def test_replay(client: Client) -> None: async def test_static_plugins(client: Client) -> None: plugin = create_plugin( + "MyPlugin", data_converter=pydantic_data_converter, workflows=[HelloWorkflow2], ) @@ -295,7 +303,6 @@ async def test_static_plugins(client: Client) -> None: new_client, task_queue="queue", activities=[never_run_activity], - plugins=[plugin], ) assert worker.config().get("workflows") == [HelloWorkflow2] @@ -311,6 +318,7 @@ def converter(old: Optional[DataConverter]): return pydantic_data_converter plugin = create_plugin( + "MyPlugin", data_converter=converter, ) config = client.config() @@ -325,6 +333,7 @@ def converter(old: Optional[DataConverter]): # On a sequence, the lambda overrides the existing values plugin = create_plugin( + "MyPlugin", workflows=lambda workflows: [], ) worker = Worker( From 8b73fde1ce25240401e2e6bdd5b6edb2835af240 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 6 Oct 2025 08:31:28 -0700 Subject: [PATCH 06/13] PR feedback - exposing SimplePlugin type --- temporalio/client.py | 2 +- .../openai_agents/_temporal_openai_agents.py | 200 +++++++---- temporalio/plugin.py | 310 ++++++++---------- tests/test_plugins.py | 8 +- 4 files changed, 270 insertions(+), 250 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 898d9db18..6c26d41ef 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -203,7 +203,7 @@ async def connect( def make_lambda(plugin, next): return lambda config: plugin.connect_service_client(config, next) - next_function = lambda config: ServiceClient.connect(config) + next_function = ServiceClient.connect for plugin in reversed(plugins): next_function = make_lambda(plugin, next_function) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 0b6ed9ea3..dc1c5b704 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -22,11 +22,7 @@ from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider -from openai.types.responses import ResponsePromptParam -import temporalio.client -import temporalio.worker -from temporalio.client import ClientConfig from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import ( @@ -47,15 +43,8 @@ DataConverter, DefaultPayloadConverter, ) -from temporalio.plugin import Plugin, create_plugin -from temporalio.worker import ( - Replayer, - ReplayerConfig, - Worker, - WorkerConfig, - WorkflowReplayResult, - WorkflowRunner, -) +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner # Unsupported on python 3.9 @@ -188,14 +177,24 @@ def _data_converter(converter: Optional[DataConverter]) -> DataConverter: return converter -def OpenAIAgentsPlugin( - model_params: Optional[ModelActivityParameters] = None, - model_provider: Optional[ModelProvider] = None, - mcp_server_providers: Sequence[ - Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] - ] = (), -) -> Plugin: - """Create an OpenAI agents plugin. +class OpenAIAgentsPlugin(SimplePlugin): + """Temporal plugin for integrating OpenAI agents with Temporal workflows. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This plugin provides seamless integration between the OpenAI Agents SDK and + Temporal workflows. It automatically configures the necessary interceptors, + activities, and data converters to enable OpenAI agents to run within + Temporal workflows with proper tracing and model execution. + + The plugin: + 1. Configures the Pydantic data converter for type-safe serialization + 2. Sets up tracing interceptors for OpenAI agent interactions + 3. Registers model execution activities + 4. Automatically registers MCP server activities and manages their lifecycles + 5. Manages the OpenAI agent runtime overrides during worker execution Args: model_params: Configuration parameters for Temporal activity execution @@ -203,60 +202,117 @@ def OpenAIAgentsPlugin( model_provider: Optional model provider for custom model implementations. Useful for testing or custom model integrations. mcp_server_providers: Sequence of MCP servers to automatically register with the worker. - Each server will be wrapped in a TemporalMCPServer if not already wrapped, - and their activities will be automatically registered with the worker. - The plugin manages the connection lifecycle of these servers. + The plugin will wrap each server in a TemporalMCPServer if needed and + manage their connection lifecycles tied to the worker lifetime. This is + the recommended way to use MCP servers with Temporal workflows. + + Example: + >>> from temporalio.client import Client + >>> from temporalio.worker import Worker + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider + >>> from agents.mcp import MCPServerStdio + >>> from datetime import timedelta + >>> + >>> # Configure model parameters + >>> model_params = ModelActivityParameters( + ... start_to_close_timeout=timedelta(seconds=30), + ... retry_policy=RetryPolicy(maximum_attempts=3) + ... ) + >>> + >>> # Create MCP servers + >>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio( + ... name="Filesystem Server", + ... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]} + ... )) + >>> + >>> # Create plugin with MCP servers + >>> plugin = OpenAIAgentsPlugin( + ... model_params=model_params, + ... mcp_server_providers=[filesystem_server] + ... ) + >>> + >>> # Use with client and worker + >>> client = await Client.connect( + ... "localhost:7233", + ... plugins=[plugin] + ... ) + >>> worker = Worker( + ... client, + ... task_queue="my-task-queue", + ... workflows=[MyWorkflow], + ... ) """ - if model_params is None: - model_params = ModelActivityParameters() - - # For the default provider, we provide a default start_to_close_timeout of 60 seconds. - # Other providers will need to define their own. - if ( - model_params.start_to_close_timeout is None - and model_params.schedule_to_close_timeout is None + + def __init__( + self, + model_params: Optional[ModelActivityParameters] = None, + model_provider: Optional[ModelProvider] = None, + mcp_server_providers: Sequence[ + Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] + ] = (), ): - if model_provider is None: - model_params.start_to_close_timeout = timedelta(seconds=60) - else: + """Create an OpenAI agents plugin. + + Args: + model_params: Configuration parameters for Temporal activity execution + of model calls. If None, default parameters will be used. + model_provider: Optional model provider for custom model implementations. + Useful for testing or custom model integrations. + mcp_server_providers: Sequence of MCP servers to automatically register with the worker. + Each server will be wrapped in a TemporalMCPServer if not already wrapped, + and their activities will be automatically registered with the worker. + The plugin manages the connection lifecycle of these servers. + """ + if model_params is None: + model_params = ModelActivityParameters() + + # For the default provider, we provide a default start_to_close_timeout of 60 seconds. + # Other providers will need to define their own. + if ( + model_params.start_to_close_timeout is None + and model_params.schedule_to_close_timeout is None + ): + if model_provider is None: + model_params.start_to_close_timeout = timedelta(seconds=60) + else: + raise ValueError( + "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" + ) + + new_activities = [ModelActivity(model_provider).invoke_model_activity] + + server_names = [server.name for server in mcp_server_providers] + if len(server_names) != len(set(server_names)): raise ValueError( - "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" + f"More than one mcp server registered with the same name. Please provide unique names." ) - new_activities = [ModelActivity(model_provider).invoke_model_activity] - - server_names = [server.name for server in mcp_server_providers] - if len(server_names) != len(set(server_names)): - raise ValueError( - f"More than one mcp server registered with the same name. Please provide unique names." + for mcp_server in mcp_server_providers: + new_activities.extend(mcp_server._get_activities()) + + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") + + # If in sandbox, add additional passthrough + if isinstance(runner, SandboxedWorkflowRunner): + return dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules("mcp"), + ) + return runner + + @asynccontextmanager + async def run_context() -> AsyncIterator[None]: + with set_open_ai_agent_temporal_overrides(model_params): + yield + + super().__init__( + name="OpenAIAgentsPlugin", + data_converter=_data_converter, + worker_interceptors=[OpenAIAgentsTracingInterceptor()], + activities=new_activities, + workflow_runner=workflow_runner, + workflow_failure_exception_types=[AgentsWorkflowError], + run_context=lambda: run_context(), ) - - for mcp_server in mcp_server_providers: - new_activities.extend(mcp_server._get_activities()) - - def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: - if not runner: - raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") - - # If in sandbox, add additional passthrough - if isinstance(runner, SandboxedWorkflowRunner): - return dataclasses.replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules("mcp"), - ) - return runner - - @asynccontextmanager - async def run_context() -> AsyncIterator[None]: - with set_open_ai_agent_temporal_overrides(model_params): - yield - - return create_plugin( - name="OpenAIAgentsPlugin", - data_converter=_data_converter, - worker_interceptors=[OpenAIAgentsTracingInterceptor()], - activities=new_activities, - workflow_runner=workflow_runner, - workflow_failure_exception_types=[AgentsWorkflowError], - run_context=lambda: run_context(), - ) diff --git a/temporalio/plugin.py b/temporalio/plugin.py index f103564d6..032b0e8b0 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -4,7 +4,6 @@ and worker behavior in the Temporal SDK through configurable parameters. """ -import abc from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import ( Any, @@ -33,90 +32,65 @@ WorkflowRunner, ) - -class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin, abc.ABC): - """Abstract base class for Temporal plugins. - - This class combines both client and worker plugin capabilities, - just used where multiple inheritance is not possible. - """ - - pass - - T = TypeVar("T") PluginParameter = Union[None, T, Callable[[Optional[T]], T]] -def create_plugin( - name: str, - *, - data_converter: PluginParameter[temporalio.converter.DataConverter] = None, - client_interceptors: PluginParameter[ - Sequence[temporalio.client.Interceptor] - ] = None, - activities: PluginParameter[Sequence[Callable]] = None, - nexus_service_handlers: PluginParameter[Sequence[Any]] = None, - workflows: PluginParameter[Sequence[Type]] = None, - workflow_runner: PluginParameter[WorkflowRunner] = None, - worker_interceptors: PluginParameter[ - Sequence[temporalio.worker.Interceptor] - ] = None, - workflow_failure_exception_types: PluginParameter[ - Sequence[Type[BaseException]] - ] = None, - run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]] = None, -) -> Plugin: - """Create a static plugin with configurable parameters. - - Args: - name: The name of the plugin. - data_converter: Data converter for serialization, or callable to customize existing one. - client_interceptors: Client interceptors to append, or callable to customize existing ones. - activities: Activity functions to append, or callable to customize existing ones. - nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones. - workflows: Workflow classes to append, or callable to customize existing ones. - workflow_runner: Workflow runner, or callable to customize existing one. - worker_interceptors: Worker interceptors to append, or callable to customize existing ones. - workflow_failure_exception_types: Exception types for workflow failures to append, - or callable to customize existing ones. - run_context: Optional async context manager producer to wrap worker/replayer execution. - - Returns: - A configured Plugin instance. +class SimplePlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + """A simple plugin definition which limits has a limited set of configurations but makes it easier to produce + a simple plugin which needs to configure them. """ - return _StaticPlugin( - name=name, - data_converter=data_converter, - client_interceptors=client_interceptors, - activities=activities, - nexus_service_handlers=nexus_service_handlers, - workflows=workflows, - workflow_runner=workflow_runner, - worker_interceptors=worker_interceptors, - workflow_failure_exception_types=workflow_failure_exception_types, - run_context=run_context, - ) - - -class _StaticPlugin(Plugin): + def __init__( self, name: str, *, - data_converter: PluginParameter[temporalio.converter.DataConverter], - client_interceptors: PluginParameter[Sequence[temporalio.client.Interceptor]], - activities: PluginParameter[Sequence[Callable]], - nexus_service_handlers: PluginParameter[Sequence[Any]], - workflows: PluginParameter[Sequence[Type]], - workflow_runner: PluginParameter[WorkflowRunner], - worker_interceptors: PluginParameter[Sequence[temporalio.worker.Interceptor]], + data_converter: PluginParameter[temporalio.converter.DataConverter] = None, + client_interceptors: PluginParameter[ + Sequence[temporalio.client.Interceptor] + ] = None, + activities: PluginParameter[Sequence[Callable]] = None, + nexus_service_handlers: PluginParameter[Sequence[Any]] = None, + workflows: PluginParameter[Sequence[Type]] = None, + workflow_runner: PluginParameter[WorkflowRunner] = None, + worker_interceptors: PluginParameter[ + Sequence[temporalio.worker.Interceptor] + ] = None, workflow_failure_exception_types: PluginParameter[ Sequence[Type[BaseException]] - ], - run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]], + ] = None, + run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]] = None, ) -> None: + """Create a simple plugin with configurable parameters. Each of the parameters will be applied to any + component for which they are applicable. All arguments are optional, and all but run_context can also + be callables for more complex modification. See the type PluginParameter above. + For details on each argument, see below. + + Args: + name: The name of the plugin. + data_converter: Data converter for serialization, or callable to customize existing one. + Applied to the Client and Replayer. + client_interceptors: Client interceptors to append, or callable to customize existing ones. + Applied to the Client. + activities: Activity functions to append, or callable to customize existing ones. + Applied to the Worker. + nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones. + Applied to the Worker. + workflows: Workflow classes to append, or callable to customize existing ones. + Applied to the Worker and Replayer. + workflow_runner: Workflow runner, or callable to customize existing one. + Applied to the Worker and Replayer. + worker_interceptors: Worker interceptors to append, or callable to customize existing ones. + Applied to the Worker and Replayer. + workflow_failure_exception_types: Exception types for workflow failures to append, + or callable to customize existing ones. Applied to the Worker and Replayer. + run_context: Optional async context manager producer to wrap worker/replayer execution. + Applied to the Worker and Replayer. + + Returns: + A configured Plugin instance. + """ self._name = name self.data_converter = data_converter self.client_interceptors = client_interceptors @@ -129,21 +103,22 @@ def __init__( self.run_context = run_context def name(self) -> str: + """See base class.""" return self._name def configure_client(self, config: ClientConfig) -> ClientConfig: - self._set_dict( - config, # type: ignore - "data_converter", - self._resolve_parameter(config.get("data_converter"), self.data_converter), + """See base class.""" + data_converter = _resolve_parameter( + config.get("data_converter"), self.data_converter ) - self._set_dict( - config, # type: ignore - "interceptors", - self._resolve_append_parameter( - config.get("interceptors"), self.client_interceptors - ), + if data_converter: + config["data_converter"] = data_converter + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.client_interceptors ) + if interceptors is not None: + config["interceptors"] = interceptors return config @@ -152,90 +127,85 @@ async def connect_service_client( config: ConnectConfig, next: Callable[[ConnectConfig], Awaitable[ServiceClient]], ) -> temporalio.service.ServiceClient: + """See base class.""" return await next(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - self._set_dict( - config, # type: ignore - "activities", - self._resolve_append_parameter(config.get("activities"), self.activities), + """See base class.""" + activities = _resolve_append_parameter( + config.get("activities"), self.activities ) - self._set_dict( - config, # type: ignore - "nexus_service_handlers", - self._resolve_append_parameter( - config.get("nexus_service_handlers"), self.nexus_service_handlers - ), - ) - self._set_dict( - config, # type: ignore - "workflows", - self._resolve_append_parameter(config.get("workflows"), self.workflows), + if activities: + config["activities"] = activities + + nexus_service_handlers = _resolve_append_parameter( + config.get("nexus_service_handlers"), self.nexus_service_handlers ) + if nexus_service_handlers is not None: + config["nexus_service_handlers"] = nexus_service_handlers + + workflows = _resolve_append_parameter(config.get("workflows"), self.workflows) + if workflows is not None: + config["workflows"] = workflows - self._set_dict( - config, # type: ignore - "workflow_runner", - self._resolve_parameter( - config.get("workflow_runner"), self.workflow_runner - ), + workflow_runner = _resolve_parameter( + config.get("workflow_runner"), self.workflow_runner ) - self._set_dict( - config, # type: ignore - "interceptors", - self._resolve_append_parameter( - config.get("interceptors"), self.worker_interceptors - ), + if workflow_runner: + config["workflow_runner"] = workflow_runner + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors ) - self._set_dict( - config, # type: ignore - "workflow_failure_exception_types", - self._resolve_append_parameter( - config.get("workflow_failure_exception_types"), - self.workflow_failure_exception_types, - ), + if interceptors is not None: + config["interceptors"] = interceptors + + failure_exception_types = _resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, ) + if failure_exception_types is not None: + config["workflow_failure_exception_types"] = failure_exception_types return config def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - self._set_dict( - config, # type: ignore - "data_converter", - self._resolve_parameter(config.get("data_converter"), self.data_converter), + """See base class.""" + data_converter = _resolve_parameter( + config.get("data_converter"), self.data_converter ) - self._set_dict( - config, # type: ignore - "workflows", - self._resolve_append_parameter(config.get("workflows"), self.workflows), - ) - self._set_dict( - config, # type: ignore - "workflow_runner", - self._resolve_parameter( - config.get("workflow_runner"), self.workflow_runner - ), + if data_converter: + config["data_converter"] = data_converter + + workflows = _resolve_append_parameter(config.get("workflows"), self.workflows) + if workflows is not None: + config["workflows"] = workflows + + workflow_runner = _resolve_parameter( + config.get("workflow_runner"), self.workflow_runner ) - self._set_dict( - config, # type: ignore - "interceptors", - self._resolve_append_parameter( - config.get("interceptors"), self.worker_interceptors - ), + if workflow_runner: + config["workflow_runner"] = workflow_runner + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors ) - self._set_dict( - config, # type: ignore - "workflow_failure_exception_types", - self._resolve_append_parameter( - config.get("workflow_failure_exception_types"), - self.workflow_failure_exception_types, - ), + if interceptors is not None: + config["interceptors"] = interceptors + + failure_exception_types = _resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, ) + if failure_exception_types is not None: + config["workflow_failure_exception_types"] = failure_exception_types + return config async def run_worker( self, worker: Worker, next: Callable[[Worker], Awaitable[None]] ) -> None: + """See base class.""" if self.run_context: async with self.run_context(): await next(worker) @@ -252,6 +222,7 @@ async def run_replayer( AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], ], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + """See base class.""" if self.run_context: async with self.run_context(): async with next(replayer, histories) as results: @@ -260,33 +231,26 @@ async def run_replayer( async with next(replayer, histories) as results: yield results - @staticmethod - def _resolve_parameter( - existing: Optional[T], parameter: PluginParameter[T] - ) -> Optional[T]: - if parameter is None: - return existing - elif callable(parameter): - return cast(Callable[[Optional[T]], Optional[T]], parameter)(existing) - else: - return parameter - - @staticmethod - def _resolve_append_parameter( - existing: Optional[Sequence[T]], parameter: PluginParameter[Sequence[T]] - ) -> Optional[Sequence[T]]: - if parameter is None: - return existing - elif callable(parameter): - return cast( - Callable[[Optional[Sequence[T]]], Optional[Sequence[T]]], parameter - )(existing) - else: - return list(existing or []) + list(parameter) - @staticmethod - def _set_dict(config: dict[str, Any], key: str, value: Optional[Any]) -> None: - if value is not None: - config[key] = value - else: - del config[key] +def _resolve_parameter( + existing: Optional[T], parameter: PluginParameter[T] +) -> Optional[T]: + if parameter is None: + return existing + elif callable(parameter): + return cast(Callable[[Optional[T]], Optional[T]], parameter)(existing) + else: + return parameter + + +def _resolve_append_parameter( + existing: Optional[Sequence[T]], parameter: PluginParameter[Sequence[T]] +) -> Optional[Sequence[T]]: + if parameter is None: + return existing + elif callable(parameter): + return cast( + Callable[[Optional[Sequence[T]]], Optional[Sequence[T]]], parameter + )(existing) + else: + return list(existing or []) + list(parameter) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 363bebb59..bfd2c6fb8 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -13,7 +13,7 @@ from temporalio.client import Client, ClientConfig, OutboundInterceptor, WorkflowHistory from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter -from temporalio.plugin import create_plugin +from temporalio.plugin import SimplePlugin from temporalio.service import ConnectConfig, ServiceClient from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( @@ -276,7 +276,7 @@ async def test_replay(client: Client) -> None: async def test_static_plugins(client: Client) -> None: - plugin = create_plugin( + plugin = SimplePlugin( "MyPlugin", data_converter=pydantic_data_converter, workflows=[HelloWorkflow2], @@ -317,7 +317,7 @@ def converter(old: Optional[DataConverter]): raise ValueError("Can't override non-default converter") return pydantic_data_converter - plugin = create_plugin( + plugin = SimplePlugin( "MyPlugin", data_converter=converter, ) @@ -332,7 +332,7 @@ def converter(old: Optional[DataConverter]): Client(**config) # On a sequence, the lambda overrides the existing values - plugin = create_plugin( + plugin = SimplePlugin( "MyPlugin", workflows=lambda workflows: [], ) From dcd3fa033645a69fe8e81db04e218dd3731f9e5c Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 6 Oct 2025 16:40:55 -0700 Subject: [PATCH 07/13] Early check OpenAI connectivity if outside a workflow --- temporalio/contrib/openai_agents/_invoke_model_activity.py | 6 +++++- temporalio/plugin.py | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index aec84350d..dee97181b 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -36,7 +36,7 @@ from pydantic_core import to_json from typing_extensions import Required, TypedDict -from temporalio import activity +from temporalio import activity, workflow from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater from temporalio.exceptions import ApplicationError @@ -156,6 +156,10 @@ class ModelActivity: def __init__(self, model_provider: Optional[ModelProvider] = None): """Initialize the activity with a model provider.""" self._model_provider = model_provider + if model_provider is None and not workflow.in_workflow(): + self._model_provider = OpenAIProvider( + openai_client=AsyncOpenAI(max_retries=0) + ) @activity.defn @_auto_heartbeater diff --git a/temporalio/plugin.py b/temporalio/plugin.py index 032b0e8b0..01d7399f7 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -38,8 +38,8 @@ class SimplePlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - """A simple plugin definition which limits has a limited set of configurations but makes it easier to produce - a simple plugin which needs to configure them. + """A simple plugin definition which has a limited set of configurations but makes it easier to produce + a plugin which needs to configure them. """ def __init__( @@ -72,7 +72,8 @@ def __init__( data_converter: Data converter for serialization, or callable to customize existing one. Applied to the Client and Replayer. client_interceptors: Client interceptors to append, or callable to customize existing ones. - Applied to the Client. + Applied to the Client. Note, if the provided interceptor is also a worker.Interceptor, + it will be added to any worker which uses that client. activities: Activity functions to append, or callable to customize existing ones. Applied to the Worker. nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones. From 3549fa7912e16e834bab0c19ad587b68e6b9ac3f Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 07:41:03 -0700 Subject: [PATCH 08/13] PR Feedback --- .../openai_agents/_invoke_model_activity.py | 9 +++----- temporalio/plugin.py | 4 ++-- .../openai_agents/test_openai_replay.py | 10 ++++++++- tests/test_plugins.py | 22 +++++++++++++++++-- 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index dee97181b..2b75e2bd5 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -155,8 +155,9 @@ class ModelActivity: def __init__(self, model_provider: Optional[ModelProvider] = None): """Initialize the activity with a model provider.""" - self._model_provider = model_provider - if model_provider is None and not workflow.in_workflow(): + if model_provider: + self._model_provider = model_provider + else: self._model_provider = OpenAIProvider( openai_client=AsyncOpenAI(max_retries=0) ) @@ -165,10 +166,6 @@ def __init__(self, model_provider: Optional[ModelProvider] = None): @_auto_heartbeater async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse: """Activity that invokes a model with the given input.""" - if not self._model_provider: - self._model_provider = OpenAIProvider( - openai_client=AsyncOpenAI(max_retries=0) - ) model = self._model_provider.get_model(input.get("model_name")) async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: diff --git a/temporalio/plugin.py b/temporalio/plugin.py index 01d7399f7..04435c7e8 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -86,8 +86,8 @@ def __init__( Applied to the Worker and Replayer. workflow_failure_exception_types: Exception types for workflow failures to append, or callable to customize existing ones. Applied to the Worker and Replayer. - run_context: Optional async context manager producer to wrap worker/replayer execution. - Applied to the Worker and Replayer. + run_context: A place to run custom code to wrap around the Worker (or Replayer) execution. + Specifically, it's an async context manager producer. Applied to the Worker and Replayer. Returns: A configured Plugin instance. diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index d625343b8..b807af383 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,6 +1,8 @@ from pathlib import Path import pytest +from agents import OpenAIProvider +from openai import AsyncOpenAI from temporalio.client import WorkflowHistory from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin @@ -42,5 +44,11 @@ async def test_replay(file_name: str) -> None: InputGuardrailWorkflow, OutputGuardrailWorkflow, ], - plugins=[OpenAIAgentsPlugin()], + plugins=[ + OpenAIAgentsPlugin( + model_provider=OpenAIProvider( + openai_client=AsyncOpenAI(max_retries=0, api_key="PLACEHOLDER") + ) + ) + ], ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index bfd2c6fb8..5571841b4 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -275,7 +275,7 @@ async def test_replay(client: Client) -> None: await replayer.replay_workflow(await handle.fetch_history()) -async def test_static_plugins(client: Client) -> None: +async def test_simple_plugins(client: Client) -> None: plugin = SimplePlugin( "MyPlugin", data_converter=pydantic_data_converter, @@ -311,7 +311,7 @@ async def test_static_plugins(client: Client) -> None: assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2] -async def test_static_plugins_callables(client: Client) -> None: +async def test_simple_plugins_callables(client: Client) -> None: def converter(old: Optional[DataConverter]): if old != temporalio.converter.default(): raise ValueError("Can't override non-default converter") @@ -344,3 +344,21 @@ def converter(old: Optional[DataConverter]): plugins=[plugin], ) assert worker.config().get("workflows") == [] + + +class MediumPlugin(SimplePlugin): + def __init__(self): + super().__init__("MediumPlugin", data_converter=pydantic_data_converter) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + config = super().configure_worker(config) + config["task_queue"] = "override" + return config + + +async def test_medium_plugin(client: Client) -> None: + plugin = MediumPlugin() + worker = Worker( + client, task_queue="queue", plugins=[plugin], workflows=[HelloWorkflow] + ) + assert worker.config().get("task_queue") == "override" From c90ee811cfbbec7435af93f6836219562c432ab3 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 07:42:13 -0700 Subject: [PATCH 09/13] Don't register activities in replayer --- tests/contrib/openai_agents/test_openai_replay.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index b807af383..8f10fcf82 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -46,9 +46,8 @@ async def test_replay(file_name: str) -> None: ], plugins=[ OpenAIAgentsPlugin( - model_provider=OpenAIProvider( - openai_client=AsyncOpenAI(max_retries=0, api_key="PLACEHOLDER") - ) + # Activities won't be used by replayer + register_activities=False, ) ], ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) From 23124328421d75b269f4861518e197cd3adf4aaa Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 09:27:38 -0700 Subject: [PATCH 10/13] Delay activity construction until needed --- .../contrib/openai_agents/_temporal_openai_agents.py | 11 +++++++---- tests/contrib/openai_agents/test_openai_replay.py | 5 +---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 43aaddf39..a553367f9 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -283,7 +283,11 @@ def __init__( "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" ) - if register_activities: + # Delay activity construction until they are actually needed + def add_activities(activities: Optional[Sequence[Callable]]) -> Sequence[Callable]: + if not register_activities: + return activities or [] + new_activities = [ModelActivity(model_provider).invoke_model_activity] server_names = [server.name for server in mcp_server_providers] @@ -294,8 +298,7 @@ def __init__( for mcp_server in mcp_server_providers: new_activities.extend(mcp_server._get_activities()) - else: - new_activities = None + return (list(activities) or []) + new_activities def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: if not runner: @@ -318,7 +321,7 @@ async def run_context() -> AsyncIterator[None]: name="OpenAIAgentsPlugin", data_converter=_data_converter, worker_interceptors=[OpenAIAgentsTracingInterceptor()], - activities=new_activities, + activities=add_activities, workflow_runner=workflow_runner, workflow_failure_exception_types=[AgentsWorkflowError], run_context=lambda: run_context(), diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index 8f10fcf82..c8ca1f47f 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -45,9 +45,6 @@ async def test_replay(file_name: str) -> None: OutputGuardrailWorkflow, ], plugins=[ - OpenAIAgentsPlugin( - # Activities won't be used by replayer - register_activities=False, - ) + OpenAIAgentsPlugin() ], ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) From 259701c06fa8ec215e36c922b0eee7bcc645ff87 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 09:28:43 -0700 Subject: [PATCH 11/13] Linting --- temporalio/contrib/openai_agents/_temporal_openai_agents.py | 4 +++- tests/contrib/openai_agents/test_openai_replay.py | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index a553367f9..aadc2e9e5 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -284,7 +284,9 @@ def __init__( ) # Delay activity construction until they are actually needed - def add_activities(activities: Optional[Sequence[Callable]]) -> Sequence[Callable]: + def add_activities( + activities: Optional[Sequence[Callable]], + ) -> Sequence[Callable]: if not register_activities: return activities or [] diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index c8ca1f47f..2d76cf765 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -44,7 +44,5 @@ async def test_replay(file_name: str) -> None: InputGuardrailWorkflow, OutputGuardrailWorkflow, ], - plugins=[ - OpenAIAgentsPlugin() - ], + plugins=[OpenAIAgentsPlugin()], ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) From 7c9c4747ab7854433fbeb616b4d0dd3f8593fa41 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 09:30:50 -0700 Subject: [PATCH 12/13] Linting --- temporalio/contrib/openai_agents/_temporal_openai_agents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index aadc2e9e5..9df481d00 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -300,7 +300,7 @@ def add_activities( for mcp_server in mcp_server_providers: new_activities.extend(mcp_server._get_activities()) - return (list(activities) or []) + new_activities + return list(activities or []) + new_activities def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: if not runner: From ee99a6365f544239991442e6ed25c0449cd613c6 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 7 Oct 2025 09:31:53 -0700 Subject: [PATCH 13/13] Simplify statement --- .../contrib/openai_agents/_invoke_model_activity.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 2b75e2bd5..ed2f16d3d 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -155,12 +155,9 @@ class ModelActivity: def __init__(self, model_provider: Optional[ModelProvider] = None): """Initialize the activity with a model provider.""" - if model_provider: - self._model_provider = model_provider - else: - self._model_provider = OpenAIProvider( - openai_client=AsyncOpenAI(max_retries=0) - ) + self._model_provider = model_provider or OpenAIProvider( + openai_client=AsyncOpenAI(max_retries=0) + ) @activity.defn @_auto_heartbeater