diff --git a/src/agentex/lib/core/temporal/workers/worker.py b/src/agentex/lib/core/temporal/workers/worker.py index 464b44a4..04babb54 100644 --- a/src/agentex/lib/core/temporal/workers/worker.py +++ b/src/agentex/lib/core/temporal/workers/worker.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, overload from aiohttp import web from temporalio.client import Client @@ -99,10 +99,28 @@ def __init__( self.healthy = False self.health_check_port = health_check_port + @overload async def run( self, activities: list[Callable], + *, workflow: type, + ) -> None: ... + + @overload + async def run( + self, + activities: list[Callable], + *, + workflows: list[type], + ) -> None: ... + + async def run( + self, + activities: list[Callable], + *, + workflow: type | None = None, + workflows: list[type] | None = None, ): await self.start_health_check_server() await self._register_agent() @@ -115,11 +133,14 @@ async def run( if debug_enabled: logger.info("🐛 [WORKER] Temporal debug mode enabled - deadlock detection disabled") + if workflow is None and workflows is None: + raise ValueError("Either workflow or workflows must be provided") + worker = Worker( client=temporal_client, task_queue=self.task_queue, activity_executor=ThreadPoolExecutor(max_workers=self.max_workers), - workflows=[workflow], + workflows=[workflow] if workflows is None else workflows, activities=activities, workflow_runner=UnsandboxedWorkflowRunner(), max_concurrent_activities=self.max_concurrent_activities,