diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py index 12d83331c..b4c4c8245 100644 --- a/temporalio/contrib/openai_agents/_model_parameters.py +++ b/temporalio/contrib/openai_agents/_model_parameters.py @@ -70,3 +70,6 @@ class ModelActivityParameters: priority: Priority = Priority.default """Priority for the activity execution.""" + + use_local_activity: bool = False + """Whether to use a local activity. If changed during a workflow execution, that would break determinism.""" diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index 3815904e2..bd69af414 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -154,20 +154,32 @@ def make_tool_info(tool: Tool) -> ToolInput: else: summary = None - return await workflow.execute_activity_method( - ModelActivity.invoke_model_activity, - activity_input, - summary=summary, - task_queue=self.model_params.task_queue, - schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, - schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, - start_to_close_timeout=self.model_params.start_to_close_timeout, - heartbeat_timeout=self.model_params.heartbeat_timeout, - retry_policy=self.model_params.retry_policy, - cancellation_type=self.model_params.cancellation_type, - versioning_intent=self.model_params.versioning_intent, - priority=self.model_params.priority, - ) + if self.model_params.use_local_activity: + return await workflow.execute_local_activity_method( + ModelActivity.invoke_model_activity, + activity_input, + summary=summary, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + ) + else: + return await workflow.execute_activity_method( + ModelActivity.invoke_model_activity, + activity_input, + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) def stream_response( self, diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 33e150cbf..5767156f2 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -2670,3 +2670,34 @@ async def test_model_conversion_loops(): triage_agent = seat_booking_agent.handoffs[0] assert isinstance(triage_agent, Agent) assert isinstance(triage_agent.model, _TemporalModelStub) + + +async def test_local_hello_world_agent(client: Client): + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + use_local_activity=True, + ), + model_provider=TestModelProvider(TestHelloModel()), + ) + ] + client = Client(**new_config) + + async with new_worker(client, HelloWorldAgent) as worker: + handle = await client.start_workflow( + HelloWorldAgent.run, + "Tell me about recursion in programming.", + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=5), + ) + result = await handle.result() + assert result == "test" + + local_activity_found = False + async for e in handle.fetch_history_events(): + if e.HasField("marker_recorded_event_attributes"): + local_activity_found = True + assert local_activity_found