diff --git a/temporalio/client.py b/temporalio/client.py index 46a30cfa7..6c26d41ef 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -70,11 +70,13 @@ WorkflowSerializationContext, ) from temporalio.service import ( + ConnectConfig, HttpConnectProxyConfig, KeepAliveConfig, RetryConfig, RPCError, RPCStatusCode, + ServiceClient, TLSConfig, ) @@ -198,12 +200,14 @@ async def connect( http_connect_proxy_config=http_connect_proxy_config, ) - root_plugin: Plugin = _RootPlugin() + def make_lambda(plugin, next): + return lambda config: plugin.connect_service_client(config, next) + + next_function = ServiceClient.connect for plugin in reversed(plugins): - plugin.init_client_plugin(root_plugin) - root_plugin = plugin + next_function = make_lambda(plugin, next_function) - service_client = await root_plugin.connect_service_client(connect_config) + service_client = await next_function(connect_config) return Client( service_client, @@ -243,12 +247,10 @@ def __init__( plugins=plugins, ) - root_plugin: Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_client_plugin(root_plugin) - root_plugin = plugin + for plugin in plugins: + config = plugin.configure_client(config) - self._init_from_config(root_plugin.configure_client(config)) + self._init_from_config(config) def _init_from_config(self, config: ClientConfig): self._config = config @@ -7541,20 +7543,6 @@ def name(self) -> str: """ return type(self).__module__ + "." + type(self).__qualname__ - @abstractmethod - def init_client_plugin(self, next: Plugin) -> None: - """Initialize this plugin in the plugin chain. - - This method sets up the chain of responsibility pattern by providing a reference - to the next plugin in the chain. It is called during client creation to build - the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`. - Implementations should store this reference and call the corresponding method - of the next plugin on method calls. - - Args: - next: The next plugin in the chain to delegate to. - """ - @abstractmethod def configure_client(self, config: ClientConfig) -> ClientConfig: """Hook called when creating a client to allow modification of configuration. @@ -7572,8 +7560,10 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: @abstractmethod async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: """Hook called when connecting to the Temporal service. This method is called during service client connection and allows plugins @@ -7586,16 +7576,3 @@ async def connect_service_client( Returns: The connected service client. 
""" - - -class _RootPlugin(Plugin): - def init_client_plugin(self, next: Plugin) -> None: - raise NotImplementedError() - - def configure_client(self, config: ClientConfig) -> ClientConfig: - return config - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - return await temporalio.service.ServiceClient.connect(config) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 6409f0bb1..ed2f16d3d 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -36,7 +36,7 @@ from pydantic_core import to_json from typing_extensions import Required, TypedDict -from temporalio import activity +from temporalio import activity, workflow from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater from temporalio.exceptions import ApplicationError diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 2334ec08c..9df481d00 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -22,11 +22,7 @@ from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider -from openai.types.responses import ResponsePromptParam -import temporalio.client -import temporalio.worker -from temporalio.client import ClientConfig from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import ( @@ -47,13 +43,8 @@ DataConverter, DefaultPayloadConverter, ) -from temporalio.worker import ( - Replayer, - ReplayerConfig, - Worker, - WorkerConfig, - WorkflowReplayResult, -) +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner # Unsupported on python 3.9 @@ -172,7 +163,21 @@ def __init__(self) -> None: super().__init__(ToJsonOptions(exclude_unset=True)) -class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): +def _data_converter(converter: Optional[DataConverter]) -> DataConverter: + if converter is None: + return DataConverter(payload_converter_class=OpenAIPayloadConverter) + elif converter.payload_converter_class is DefaultPayloadConverter: + return dataclasses.replace( + converter, payload_converter_class=OpenAIPayloadConverter + ) + elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): + raise ValueError( + "The payload converter must be of type OpenAIPayloadConverter." + ) + return converter + + +class OpenAIAgentsPlugin(SimplePlugin): """Temporal plugin for integrating OpenAI agents with Temporal workflows. .. 
warning:: @@ -278,127 +283,48 @@ def __init__( "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" ) - self._model_params = model_params - self._model_provider = model_provider - self._mcp_server_providers = mcp_server_providers - self._register_activities = register_activities - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - """Set the next client plugin""" - self.next_client_plugin = next - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - """No modifications to service client""" - return await self.next_client_plugin.connect_service_client(config) - - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - """Set the next worker plugin""" - self.next_worker_plugin = next - - @staticmethod - def _data_converter(converter: Optional[DataConverter]) -> DataConverter: - if converter is None: - return DataConverter(payload_converter_class=OpenAIPayloadConverter) - elif converter.payload_converter_class is DefaultPayloadConverter: - return dataclasses.replace( - converter, payload_converter_class=OpenAIPayloadConverter - ) - elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): - raise ValueError( - "The payload converter must be of type OpenAIPayloadConverter." - ) - return converter - - def configure_client(self, config: ClientConfig) -> ClientConfig: - """Configure the Temporal client for OpenAI agents integration. - - This method sets up the Pydantic data converter to enable proper - serialization of OpenAI agent objects and responses. - - Args: - config: The client configuration to modify. - - Returns: - The modified client configuration. - """ - config["data_converter"] = self._data_converter(config["data_converter"]) - return self.next_client_plugin.configure_client(config) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - """Configure the Temporal worker for OpenAI agents integration. - - This method adds the necessary interceptors and activities for OpenAI - agent execution: - - Adds tracing interceptors for OpenAI agent interactions - - Registers model execution activities - - Args: - config: The worker configuration to modify. - - Returns: - The modified worker configuration. - """ - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] + # Delay activity construction until they are actually needed + def add_activities( + activities: Optional[Sequence[Callable]], + ) -> Sequence[Callable]: + if not register_activities: + return activities or [] - if self._register_activities: - new_activities = [ModelActivity(self._model_provider).invoke_model_activity] + new_activities = [ModelActivity(model_provider).invoke_model_activity] - server_names = [server.name for server in self._mcp_server_providers] + server_names = [server.name for server in mcp_server_providers] if len(server_names) != len(set(server_names)): raise ValueError( f"More than one mcp server registered with the same name. Please provide unique names." 
) - for mcp_server in self._mcp_server_providers: + for mcp_server in mcp_server_providers: new_activities.extend(mcp_server._get_activities()) - config["activities"] = list(config.get("activities") or []) + new_activities - - runner = config.get("workflow_runner") - if isinstance(runner, SandboxedWorkflowRunner): - config["workflow_runner"] = dataclasses.replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules("mcp"), - ) - - config["workflow_failure_exception_types"] = list( - config.get("workflow_failure_exception_types") or [] - ) + [AgentsWorkflowError] - return self.next_worker_plugin.configure_worker(config) + return list(activities or []) + new_activities - async def run_worker(self, worker: Worker) -> None: - """Run the worker with OpenAI agents temporal overrides. + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") - This method sets up the necessary runtime overrides for OpenAI agents - to work within the Temporal worker context, including custom runners - and trace providers. - - Args: - worker: The worker instance to run. - """ - with set_open_ai_agent_temporal_overrides(self._model_params): - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - """Configure the replayer for OpenAI Agents.""" - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] - config["data_converter"] = self._data_converter(config.get("data_converter")) - return self.next_worker_plugin.configure_replayer(config) - - @asynccontextmanager - async def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - """Set the OpenAI Overrides during replay""" - with set_open_ai_agent_temporal_overrides(self._model_params): - async with self.next_worker_plugin.run_replayer( - replayer, histories - ) as results: - yield results + # If in sandbox, add additional passthrough + if isinstance(runner, SandboxedWorkflowRunner): + return dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules("mcp"), + ) + return runner + + @asynccontextmanager + async def run_context() -> AsyncIterator[None]: + with set_open_ai_agent_temporal_overrides(model_params): + yield + + super().__init__( + name="OpenAIAgentsPlugin", + data_converter=_data_converter, + worker_interceptors=[OpenAIAgentsTracingInterceptor()], + activities=add_activities, + workflow_runner=workflow_runner, + workflow_failure_exception_types=[AgentsWorkflowError], + run_context=lambda: run_context(), + ) diff --git a/temporalio/plugin.py b/temporalio/plugin.py new file mode 100644 index 000000000..04435c7e8 --- /dev/null +++ b/temporalio/plugin.py @@ -0,0 +1,257 @@ +"""Plugin module for Temporal SDK. + +This module provides plugin functionality that allows customization of both client +and worker behavior in the Temporal SDK through configurable parameters. 
+""" + +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) + +import temporalio.client +import temporalio.converter +import temporalio.worker +from temporalio.client import ClientConfig, WorkflowHistory +from temporalio.service import ConnectConfig, ServiceClient +from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, + WorkflowRunner, +) + +T = TypeVar("T") + +PluginParameter = Union[None, T, Callable[[Optional[T]], T]] + + +class SimplePlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + """A simple plugin definition which has a limited set of configurations but makes it easier to produce + a plugin which needs to configure them. + """ + + def __init__( + self, + name: str, + *, + data_converter: PluginParameter[temporalio.converter.DataConverter] = None, + client_interceptors: PluginParameter[ + Sequence[temporalio.client.Interceptor] + ] = None, + activities: PluginParameter[Sequence[Callable]] = None, + nexus_service_handlers: PluginParameter[Sequence[Any]] = None, + workflows: PluginParameter[Sequence[Type]] = None, + workflow_runner: PluginParameter[WorkflowRunner] = None, + worker_interceptors: PluginParameter[ + Sequence[temporalio.worker.Interceptor] + ] = None, + workflow_failure_exception_types: PluginParameter[ + Sequence[Type[BaseException]] + ] = None, + run_context: Optional[Callable[[], AbstractAsyncContextManager[None]]] = None, + ) -> None: + """Create a simple plugin with configurable parameters. Each of the parameters will be applied to any + component for which they are applicable. All arguments are optional, and all but run_context can also + be callables for more complex modification. See the type PluginParameter above. + For details on each argument, see below. + + Args: + name: The name of the plugin. + data_converter: Data converter for serialization, or callable to customize existing one. + Applied to the Client and Replayer. + client_interceptors: Client interceptors to append, or callable to customize existing ones. + Applied to the Client. Note, if the provided interceptor is also a worker.Interceptor, + it will be added to any worker which uses that client. + activities: Activity functions to append, or callable to customize existing ones. + Applied to the Worker. + nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones. + Applied to the Worker. + workflows: Workflow classes to append, or callable to customize existing ones. + Applied to the Worker and Replayer. + workflow_runner: Workflow runner, or callable to customize existing one. + Applied to the Worker and Replayer. + worker_interceptors: Worker interceptors to append, or callable to customize existing ones. + Applied to the Worker and Replayer. + workflow_failure_exception_types: Exception types for workflow failures to append, + or callable to customize existing ones. Applied to the Worker and Replayer. + run_context: A place to run custom code to wrap around the Worker (or Replayer) execution. + Specifically, it's an async context manager producer. Applied to the Worker and Replayer. + + Returns: + A configured Plugin instance. 
+ """ + self._name = name + self.data_converter = data_converter + self.client_interceptors = client_interceptors + self.activities = activities + self.nexus_service_handlers = nexus_service_handlers + self.workflows = workflows + self.workflow_runner = workflow_runner + self.worker_interceptors = worker_interceptors + self.workflow_failure_exception_types = workflow_failure_exception_types + self.run_context = run_context + + def name(self) -> str: + """See base class.""" + return self._name + + def configure_client(self, config: ClientConfig) -> ClientConfig: + """See base class.""" + data_converter = _resolve_parameter( + config.get("data_converter"), self.data_converter + ) + if data_converter: + config["data_converter"] = data_converter + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.client_interceptors + ) + if interceptors is not None: + config["interceptors"] = interceptors + + return config + + async def connect_service_client( + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> temporalio.service.ServiceClient: + """See base class.""" + return await next(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + """See base class.""" + activities = _resolve_append_parameter( + config.get("activities"), self.activities + ) + if activities: + config["activities"] = activities + + nexus_service_handlers = _resolve_append_parameter( + config.get("nexus_service_handlers"), self.nexus_service_handlers + ) + if nexus_service_handlers is not None: + config["nexus_service_handlers"] = nexus_service_handlers + + workflows = _resolve_append_parameter(config.get("workflows"), self.workflows) + if workflows is not None: + config["workflows"] = workflows + + workflow_runner = _resolve_parameter( + config.get("workflow_runner"), self.workflow_runner + ) + if workflow_runner: + config["workflow_runner"] = workflow_runner + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors + ) + if interceptors is not None: + config["interceptors"] = interceptors + + failure_exception_types = _resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, + ) + if failure_exception_types is not None: + config["workflow_failure_exception_types"] = failure_exception_types + + return config + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + """See base class.""" + data_converter = _resolve_parameter( + config.get("data_converter"), self.data_converter + ) + if data_converter: + config["data_converter"] = data_converter + + workflows = _resolve_append_parameter(config.get("workflows"), self.workflows) + if workflows is not None: + config["workflows"] = workflows + + workflow_runner = _resolve_parameter( + config.get("workflow_runner"), self.workflow_runner + ) + if workflow_runner: + config["workflow_runner"] = workflow_runner + + interceptors = _resolve_append_parameter( + config.get("interceptors"), self.worker_interceptors + ) + if interceptors is not None: + config["interceptors"] = interceptors + + failure_exception_types = _resolve_append_parameter( + config.get("workflow_failure_exception_types"), + self.workflow_failure_exception_types, + ) + if failure_exception_types is not None: + config["workflow_failure_exception_types"] = failure_exception_types + + return config + + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + """See 
base class.""" + if self.run_context: + async with self.run_context(): + await next(worker) + else: + await next(worker) + + @asynccontextmanager + async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + """See base class.""" + if self.run_context: + async with self.run_context(): + async with next(replayer, histories) as results: + yield results + else: + async with next(replayer, histories) as results: + yield results + + +def _resolve_parameter( + existing: Optional[T], parameter: PluginParameter[T] +) -> Optional[T]: + if parameter is None: + return existing + elif callable(parameter): + return cast(Callable[[Optional[T]], Optional[T]], parameter)(existing) + else: + return parameter + + +def _resolve_append_parameter( + existing: Optional[Sequence[T]], parameter: PluginParameter[Sequence[T]] +) -> Optional[Sequence[T]]: + if parameter is None: + return existing + elif callable(parameter): + return cast( + Callable[[Optional[Sequence[T]]], Optional[Sequence[T]]], parameter + )(existing) + else: + return list(existing or []) + list(parameter) diff --git a/temporalio/worker/_plugin.py b/temporalio/worker/_plugin.py index 0e696a2dd..2f4c04a1e 100644 --- a/temporalio/worker/_plugin.py +++ b/temporalio/worker/_plugin.py @@ -2,7 +2,7 @@ import abc from contextlib import AbstractAsyncContextManager -from typing import TYPE_CHECKING, AsyncIterator +from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Callable from temporalio.client import WorkflowHistory @@ -34,19 +34,6 @@ def name(self) -> str: """ return type(self).__module__ + "." + type(self).__qualname__ - @abc.abstractmethod - def init_worker_plugin(self, next: Plugin) -> None: - """Initialize this plugin in the plugin chain. - - This method sets up the chain of responsibility pattern by providing a reference - to the next plugin in the chain. It is called during worker creation to build - the plugin chain. Implementations should store this reference and call the corresponding method - of the next plugin on method calls. - - Args: - next: The next plugin in the chain to delegate to. - """ - @abc.abstractmethod def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """Hook called when creating a worker to allow modification of configuration. @@ -64,7 +51,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """ @abc.abstractmethod - async def run_worker(self, worker: Worker) -> None: + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: """Hook called when running a worker to allow interception of execution. This method is called when the worker is started and allows plugins to @@ -73,6 +62,7 @@ async def run_worker(self, worker: Worker) -> None: Args: worker: The worker instance to run. + next: Callable to continue the worker execution. 
""" @abc.abstractmethod @@ -94,26 +84,9 @@ def run_replayer( self, replayer: Replayer, histories: AsyncIterator[WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: """Hook called when running a replayer to allow interception of execution.""" - - -class _RootPlugin(Plugin): - def init_worker_plugin(self, next: Plugin) -> None: - raise NotImplementedError() - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - return config - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return config - - async def run_worker(self, worker: Worker) -> None: - await worker._run() - - def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return replayer._workflow_replay_iterator(histories) diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 664602219..72ffa4b56 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -21,7 +21,6 @@ from ..common import HeaderCodecBehavior from ._interceptor import Interceptor -from ._plugin import _RootPlugin from ._worker import load_default_build_id from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -84,12 +83,9 @@ def __init__( ) # Apply plugin configuration - root_plugin: temporalio.worker.Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_worker_plugin(root_plugin) - root_plugin = plugin - self._config = root_plugin.configure_replayer(self._config) - self._plugin = root_plugin + self.plugins = plugins + for plugin in plugins: + self._config = plugin.configure_replayer(self._config) # Validate workflows after plugin configuration if not self._config["workflows"]: @@ -176,7 +172,15 @@ def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. 
""" - return self._plugin.run_replayer(self, histories) + + def make_lambda(plugin, next): + return lambda r, hs: plugin.run_replayer(r, hs, next) + + next_function = lambda r, hs: r._workflow_replay_iterator(hs) + for plugin in reversed(self.plugins): + next_function = make_lambda(plugin, next_function) + + return next_function(self, histories) @asynccontextmanager async def _workflow_replay_iterator( diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index dea516da8..ad238507c 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -38,7 +38,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor from ._nexus import _NexusWorker -from ._plugin import Plugin, _RootPlugin +from ._plugin import Plugin from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -377,12 +377,9 @@ def __init__( ) plugins = plugins_from_client + list(plugins) - root_plugin: Plugin = _RootPlugin() - for plugin in reversed(plugins): - plugin.init_worker_plugin(root_plugin) - root_plugin = plugin - config = root_plugin.configure_worker(config) - self._plugin = root_plugin + self.plugins = plugins + for plugin in plugins: + config = plugin.configure_worker(config) self._init_from_config(client, config) @@ -690,7 +687,15 @@ async def run(self) -> None: also cancel the shutdown process. Therefore users are encouraged to use explicit shutdown instead. """ - await self._plugin.run_worker(self) + + def make_lambda(plugin, next): + return lambda w: plugin.run_worker(w, next) + + next_function = lambda w: w._run() + for plugin in reversed(self.plugins): + next_function = make_lambda(plugin, next_function) + + await next_function(self) async def _run(self): # Eagerly validate which will do a namespace check in Core diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index d625343b8..2d76cf765 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,6 +1,8 @@ from pathlib import Path import pytest +from agents import OpenAIProvider +from openai import AsyncOpenAI from temporalio.client import WorkflowHistory from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin diff --git a/tests/test_plugins.py b/tests/test_plugins.py index eb08bba2d..5571841b4 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -2,15 +2,19 @@ import uuid import warnings from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import AsyncIterator, cast +from typing import AsyncIterator, Awaitable, Callable, Optional, cast import pytest import temporalio.client +import temporalio.converter import temporalio.worker from temporalio import workflow -from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin +from temporalio.client import Client, ClientConfig, OutboundInterceptor, WorkflowHistory from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.converter import DataConverter +from temporalio.plugin import SimplePlugin +from temporalio.service import ConnectConfig, ServiceClient from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( Replayer, @@ -37,21 +41,20 @@ class MyClientPlugin(temporalio.client.Plugin): def __init__(self): self.interceptor = TestClientInterceptor() - def 
init_client_plugin(self, next: Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: config["namespace"] = "replaced_namespace" config["interceptors"] = list(config.get("interceptors") or []) + [ self.interceptor ] - return self.next_client_plugin.configure_client(config) + return config async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: config.api_key = "replaced key" - return await self.next_client_plugin.connect_service_client(config) + return await next(config) async def test_client_plugin(client: Client, env: WorkflowEnvironment): @@ -73,42 +76,41 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment): class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: - return self.next_client_plugin.configure_client(config) + return config def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "combined" - return self.next_worker_plugin.configure_worker(config) + return config async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - return await self.next_client_plugin.connect_service_client(config) + self, + config: ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], + ) -> ServiceClient: + return await next(config) - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return self.next_worker_plugin.configure_replayer(config) + return config def run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return self.next_worker_plugin.run_replayer(replayer, histories) + return next(replayer, histories) class MyWorkerPlugin(temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" runner = config.get("workflow_runner") @@ -117,20 +119,26 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: runner, restrictions=runner.restrictions.with_passthrough_modules("my_module"), ) - return self.next_worker_plugin.configure_worker(config) + return config - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return self.next_worker_plugin.configure_replayer(config) + return config def 
run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return self.next_worker_plugin.run_replayer(replayer, histories) + return next(replayer, histories) async def test_worker_plugin_basic_config(client: Client) -> None: @@ -193,40 +201,42 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: - self.next_worker_plugin = next - - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: - self.next_client_plugin = next - def configure_client(self, config: ClientConfig) -> ClientConfig: config["data_converter"] = pydantic_data_converter - return self.next_client_plugin.configure_client(config) + return config def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] - return self.next_worker_plugin.configure_worker(config) + return config def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: config["data_converter"] = pydantic_data_converter config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] - return self.next_worker_plugin.configure_replayer(config) + return config - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) + async def run_worker( + self, worker: Worker, next: Callable[[Worker], Awaitable[None]] + ) -> None: + await next(worker) async def connect_service_client( - self, config: temporalio.service.ConnectConfig + self, + config: temporalio.service.ConnectConfig, + next: Callable[[ConnectConfig], Awaitable[ServiceClient]], ) -> temporalio.service.ServiceClient: - return await self.next_client_plugin.connect_service_client(config) + return await next(config) @asynccontextmanager async def run_replayer( self, replayer: Replayer, histories: AsyncIterator[temporalio.client.WorkflowHistory], + next: Callable[ + [Replayer, AsyncIterator[WorkflowHistory]], + AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]], + ], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - async with self.next_worker_plugin.run_replayer(replayer, histories) as result: + async with next(replayer, histories) as result: yield result @@ -237,6 +247,13 @@ async def run(self, name: str) -> str: return f"Hello, {name}!" +@workflow.defn +class HelloWorkflow2: + @workflow.run + async def run(self, name: str) -> str: + return f"Hello, {name}!" 
+ + async def test_replay(client: Client) -> None: plugin = ReplayCheckPlugin() new_config = client.config() @@ -256,3 +273,92 @@ async def test_replay(client: Client) -> None: assert replayer.config().get("data_converter") == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) + + +async def test_simple_plugins(client: Client) -> None: + plugin = SimplePlugin( + "MyPlugin", + data_converter=pydantic_data_converter, + workflows=[HelloWorkflow2], + ) + config = client.config() + config["plugins"] = [plugin] + new_client = Client(**config) + + assert new_client.data_converter == pydantic_data_converter + + # Test without plugin registered in client + worker = Worker( + client, + task_queue="queue", + activities=[never_run_activity], + workflows=[HelloWorkflow], + plugins=[plugin], + ) + # On a sequence, a value is appended + assert worker.config().get("workflows") == [HelloWorkflow, HelloWorkflow2] + + # Test with plugin registered in client + worker = Worker( + new_client, + task_queue="queue", + activities=[never_run_activity], + ) + assert worker.config().get("workflows") == [HelloWorkflow2] + + replayer = Replayer(workflows=[HelloWorkflow], plugins=[plugin]) + assert replayer.config().get("data_converter") == pydantic_data_converter + assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2] + + +async def test_simple_plugins_callables(client: Client) -> None: + def converter(old: Optional[DataConverter]): + if old != temporalio.converter.default(): + raise ValueError("Can't override non-default converter") + return pydantic_data_converter + + plugin = SimplePlugin( + "MyPlugin", + data_converter=converter, + ) + config = client.config() + config["plugins"] = [plugin] + new_client = Client(**config) + + assert new_client.data_converter == pydantic_data_converter + + with pytest.raises(ValueError): + config["data_converter"] = pydantic_data_converter + Client(**config) + + # On a sequence, the lambda overrides the existing values + plugin = SimplePlugin( + "MyPlugin", + workflows=lambda workflows: [], + ) + worker = Worker( + client, + task_queue="queue", + workflows=[HelloWorkflow], + activities=[never_run_activity], + plugins=[plugin], + ) + assert worker.config().get("workflows") == [] + + +class MediumPlugin(SimplePlugin): + def __init__(self): + super().__init__("MediumPlugin", data_converter=pydantic_data_converter) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + config = super().configure_worker(config) + config["task_queue"] = "override" + return config + + +async def test_medium_plugin(client: Client) -> None: + plugin = MediumPlugin() + worker = Worker( + client, task_queue="queue", plugins=[plugin], workflows=[HelloWorkflow] + ) + assert worker.config().get("task_queue") == "override"
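
A minimal usage sketch (not part of the patch above) illustrating the two plugin styles this change introduces: a hand-written plugin whose hooks receive a `next` callable instead of storing the next plugin via `init_client_plugin`, and the declarative `SimplePlugin`. It is based only on the API shown in this diff; the namespace, API key, task queue, and `MyWorkflow` names are illustrative placeholders.

    from typing import Awaitable, Callable

    import temporalio.client
    from temporalio import workflow
    from temporalio.client import Client, ClientConfig
    from temporalio.plugin import SimplePlugin
    from temporalio.service import ConnectConfig, ServiceClient
    from temporalio.worker import Worker


    @workflow.defn
    class MyWorkflow:  # placeholder workflow for illustration
        @workflow.run
        async def run(self, name: str) -> str:
            return f"Hello, {name}!"


    class ApiKeyPlugin(temporalio.client.Plugin):
        """Hand-written plugin using the new next-callable hooks."""

        def configure_client(self, config: ClientConfig) -> ClientConfig:
            # Configure hooks simply return the (possibly modified) config;
            # the client applies plugins in order, so no `next` is needed here.
            config["namespace"] = "my-namespace"  # placeholder
            return config

        async def connect_service_client(
            self,
            config: ConnectConfig,
            next: Callable[[ConnectConfig], Awaitable[ServiceClient]],
        ) -> ServiceClient:
            # Adjust the connect config, then delegate to the rest of the chain.
            config.api_key = "my-api-key"  # placeholder
            return await next(config)


    # Declarative plugin: sequence parameters (workflows, activities,
    # interceptors, ...) are appended to existing values; scalar parameters
    # replace them, and any parameter may instead be a callable that
    # transforms the existing value.
    simple = SimplePlugin("MySimplePlugin", workflows=[MyWorkflow])


    async def main() -> None:
        client = await Client.connect(
            "localhost:7233", plugins=[ApiKeyPlugin(), simple]
        )
        # Client plugins that are also worker plugins (such as SimplePlugin)
        # are propagated to workers created from that client, so MyWorkflow
        # is registered here by the plugin.
        worker = Worker(client, task_queue="my-task-queue")
        await worker.run()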