diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 33c72d53..aad53e7f 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -130,9 +130,8 @@ async def run(self) -> None: raise RuntimeError(f"Unrecognized activity task: {task}") except temporalio.bridge.worker.PollShutdownError: return - except Exception: - # Should never happen - logger.exception(f"Activity runner failed") + except Exception as err: + raise RuntimeError("Activity worker failed") from err async def shutdown(self, after_graceful_timeout: timedelta) -> None: # Set event that we're shutting down (updates all activity tasks) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 459ef883..d6085f87 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -8,7 +8,7 @@ import logging import sys from datetime import timedelta -from typing import Any, Callable, List, Optional, Sequence, Type, cast +from typing import Any, Awaitable, Callable, List, Optional, Sequence, Type, cast from typing_extensions import TypedDict @@ -38,7 +38,9 @@ class Worker: """Worker to process workflows and/or activities. Once created, workers can be run and shutdown explicitly via :py:meth:`run` - and :py:meth:`shutdown`, or they can be used in an ``async with`` clause. + and :py:meth:`shutdown`. Alternatively workers can be used in an + ``async with`` clause. See :py:meth:`__aenter__` and :py:meth:`__aexit__` + for important details about fatal errors. """ def __init__( @@ -71,6 +73,7 @@ def __init__( graceful_shutdown_timeout: timedelta = timedelta(), shared_state_manager: Optional[SharedStateManager] = None, debug_mode: bool = False, + on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] = None, ) -> None: """Create a worker to process workflows and/or activities. @@ -163,6 +166,9 @@ def __init__( sandboxing in order to make using a debugger easier. 
If false but the environment variable ``TEMPORAL_DEBUG`` is truthy, this will be set to true. + on_fatal_error: An async function that can handle a failure before + the worker shutdown commences. This cannot stop the shutdown and + any exception raised is logged and ignored. """ if not activities and not workflows: raise ValueError("At least one activity or workflow must be specified") @@ -222,8 +228,14 @@ def __init__( graceful_shutdown_timeout=graceful_shutdown_timeout, shared_state_manager=shared_state_manager, debug_mode=debug_mode, + on_fatal_error=on_fatal_error, ) - self._task: Optional[asyncio.Task] = None + self._started = False + self._shutdown_event = asyncio.Event() + self._shutdown_complete_event = asyncio.Event() + self._async_context_inner_task: Optional[asyncio.Task] = None + self._async_context_run_task: Optional[asyncio.Task] = None + self._async_context_run_exception: Optional[BaseException] = None # Create activity and workflow worker self._activity_worker: Optional[_ActivityWorker] = None @@ -314,59 +326,174 @@ def task_queue(self) -> str: """Task queue this worker is on.""" return self._config["task_queue"] - async def __aenter__(self) -> Worker: - """Start the worker and return self for use by ``async with``. + @property + def is_running(self) -> bool: + """Whether the worker is running. - Returns: - Self. + This is only ``True`` if the worker has been started and not yet + shut down. """ - self._start() - return self + return self._started and not self.is_shutdown - async def __aexit__(self, *args) -> None: - """Same as :py:meth:`shutdown` for use by ``async with``.""" - await self.shutdown() + @property + def is_shutdown(self) -> bool: + """Whether the worker has run and shut down. - async def run(self) -> None: - """Run the worker and wait on it to be shutdown.""" - await self._start() + This is only ``True`` if the worker was once started and then shut down. 
+ This is not necessarily ``True`` after :py:meth:`shutdown` is first + called because the shutdown process can take a bit. + """ + return self._shutdown_complete_event.is_set() - def _start(self) -> asyncio.Task: - if self._task: - raise RuntimeError("Already started") - worker_tasks: List[asyncio.Task] = [] - if self._activity_worker: - worker_tasks.append(asyncio.create_task(self._activity_worker.run())) - if self._workflow_worker: - worker_tasks.append(asyncio.create_task(self._workflow_worker.run())) - self._task = asyncio.create_task(asyncio.wait(worker_tasks)) - return self._task + async def run(self) -> None: + """Run the worker and wait on it to be shut down. - async def shutdown(self) -> None: - """Shutdown the worker and wait until all activities have completed. + This will not return until shutdown is complete. This means that + activities have all completed after being told to cancel after the + graceful timeout period. - This will initiate a shutdown and optionally wait for a grace period - before sending cancels to all activities. + This method will raise if there is a worker fatal error. While + :py:meth:`shutdown` does not need to be invoked in this case, it is + harmless to do so. Otherwise, to shut down this worker, invoke + :py:meth:`shutdown`. - This worker should not be used in any way once this is called. + Technically this worker can be shut down by issuing a cancel to this + async function assuming that it is currently running. A cancel could + also cancel the shutdown process. Therefore users are encouraged to use + explicit shutdown instead. 
""" - if not self._task: - raise RuntimeError("Never started") + if self._started: + raise RuntimeError("Already started") + self._started = True + + # Create a task that raises when a shutdown is requested + async def raise_on_shutdown(): + try: + await self._shutdown_event.wait() + raise _ShutdownRequested() + except asyncio.CancelledError: + pass + + tasks: List[asyncio.Task] = [asyncio.create_task(raise_on_shutdown())] + # Create tasks for workers + if self._activity_worker: + tasks.append(asyncio.create_task(self._activity_worker.run())) + if self._workflow_worker: + tasks.append(asyncio.create_task(self._workflow_worker.run())) + + # Wait for either worker or shutdown requested + wait_task = asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + try: + await asyncio.shield(wait_task) + + # If any of the last two tasks failed, we want to re-raise that as + # the exception + exception = next((t.exception() for t in tasks[1:] if t.done()), None) + if exception: + logger.error("Worker failed, shutting down", exc_info=exception) + if self._config["on_fatal_error"]: + try: + await self._config["on_fatal_error"](exception) + except: + logger.warning("Fatal error handler failed") + + except asyncio.CancelledError as user_cancel_err: + # Represents user literally calling cancel + logger.info("Worker cancelled, shutting down") + exception = user_cancel_err + + # Cancel the shutdown task (safe if already done) + tasks[0].cancel() graceful_timeout = self._config["graceful_shutdown_timeout"] logger.info( - f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling workflows/activities" + f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities" ) # Start shutdown of the bridge bridge_shutdown_task = asyncio.create_task(self._bridge_worker.shutdown()) - # Wait for the poller loops to stop - await self._task + + # Wait for all tasks to complete (i.e. 
for poller loops to stop) + await asyncio.wait(tasks) + # Sometimes both workers throw an exception and since we only take the + # first, Python may complain with "Task exception was never retrieved" + # if we don't get the others. Therefore we call cancel on each task + # which suppresses this. + for task in tasks: + task.cancel() + # Shutdown the activity worker (there is no workflow worker shutdown) if self._activity_worker: await self._activity_worker.shutdown(graceful_timeout) # Wait for the bridge to report everything is completed await bridge_shutdown_task # Do final shutdown - await self._bridge_worker.finalize_shutdown() + try: + await self._bridge_worker.finalize_shutdown() + except: + # Ignore errors here that can arise in some tests where the bridge + # worker still has a reference + pass + + # Mark as shutdown complete and re-raise exception if present + self._shutdown_complete_event.set() + if exception: + raise exception + + async def shutdown(self) -> None: + """Initiate a worker shutdown and wait until complete. + + This can be called before the worker has even started and is safe for + repeated invocations. It simply sets a marker informing the worker to + shut down as it runs. + + This will not return until the worker has completed shutting down. + """ + self._shutdown_event.set() + await self._shutdown_complete_event.wait() + + async def __aenter__(self) -> Worker: + """Start the worker and return self for use by ``async with``. + + This is a wrapper around :py:meth:`run`. Please review that method. + + This takes a similar approach to :py:func:`asyncio.timeout` in that it + will cancel the current task if there is a fatal worker error and raise + that error out of the context manager. However, if the inner async code + swallows/wraps the :py:class:`asyncio.CancelledError`, the exiting + portion of the context manager will not raise the fatal worker error. 
+ """ + if self._async_context_inner_task: + raise RuntimeError("Already started") + self._async_context_inner_task = asyncio.current_task() + if not self._async_context_inner_task: + raise RuntimeError("Can only use async with inside a task") + + # Start a task that runs and if there's an error, cancels the current + # task and re-raises the error + async def run(): + try: + await self.run() + except BaseException as err: + self._async_context_run_exception = err + self._async_context_inner_task.cancel() + + self._async_context_run_task = asyncio.create_task(run()) + return self + + async def __aexit__(self, exc_type: Optional[Type[BaseException]], *args) -> None: + """Same as :py:meth:`shutdown` for use by ``async with``. + + Note, this will raise the worker fatal error if one occurred and the + inner task cancellation was not inadvertently swallowed/wrapped. + """ + # Wait for shutdown then run complete + if not self._async_context_run_task: + raise RuntimeError("Never started") + await self.shutdown() + # Cancel our run task + self._async_context_run_task.cancel() + # Only re-raise our exception if present and exc_type is cancel + if exc_type is asyncio.CancelledError and self._async_context_run_exception: + raise self._async_context_run_exception class WorkerConfig(TypedDict, total=False): @@ -399,6 +526,7 @@ class WorkerConfig(TypedDict, total=False): graceful_shutdown_timeout: timedelta shared_state_manager: Optional[SharedStateManager] debug_mode: bool + on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] _default_build_id: Optional[str] = None @@ -478,3 +606,7 @@ def _get_module_code(mod_name: str) -> Optional[bytes]: except Exception: pass return None + + +class _ShutdownRequested(RuntimeError): + pass diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 972a4f6f..5f99aaad 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -129,9 +129,8 @@ async def run(self) -> None: 
setattr(task, "__temporal_task_tag", task_tag) except temporalio.bridge.worker.PollShutdownError: pass - except Exception: - # Should never happen - logger.exception(f"Workflow runner failed") + except Exception as err: + raise RuntimeError("Workflow worker failed") from err finally: # Collect all tasks and wait for them to complete our_tasks = [ diff --git a/tests/helpers/worker.py b/tests/helpers/worker.py index a288d58c..1a53e17d 100644 --- a/tests/helpers/worker.py +++ b/tests/helpers/worker.py @@ -230,7 +230,7 @@ def __init__(self, env: WorkflowEnvironment) -> None: self.worker = Worker( env.client, task_queue=str(uuid.uuid4()), workflows=[KitchenSinkWorkflow] ) - self.worker._start() + self.run_task = asyncio.create_task(self.worker.run()) @property def task_queue(self) -> str: @@ -238,3 +238,4 @@ def task_queue(self) -> str: async def close(self): await self.worker.shutdown() + await self.run_task diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index c49d9048..a2a7f52f 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -1,4 +1,15 @@ +from __future__ import annotations + +import asyncio +import uuid +from typing import Any, Awaitable, Callable, Optional + +import pytest + import temporalio.worker._worker +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.worker import Worker def test_load_default_worker_binary_id(): @@ -6,3 +17,176 @@ def test_load_default_worker_binary_id(): val1 = temporalio.worker._worker.load_default_build_id(memoize=False) val2 = temporalio.worker._worker.load_default_build_id(memoize=False) assert val1 == val2 + + +@activity.defn +async def never_run_activity() -> None: + raise NotImplementedError + + +@workflow.defn +class NeverRunWorkflow: + @workflow.run + async def run(self) -> None: + raise NotImplementedError + + +async def test_worker_fatal_error_run(client: Client): + # Run worker with injected workflow poll error + worker = 
create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await worker.run() + assert str(err.value) == "Workflow worker failed" + assert err.value.__cause__ and str(err.value.__cause__) == "OH NO" + + # Run worker with injected activity poll error + worker = create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + inj.activity.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await worker.run() + assert str(err.value) == "Activity worker failed" + assert err.value.__cause__ and str(err.value.__cause__) == "OH NO" + + # Run worker with them both injected (was causing warning for not retrieving + # the second error) + worker = create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + inj.activity.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await worker.run() + assert str(err.value).endswith("worker failed") + assert err.value.__cause__ and str(err.value.__cause__) == "OH NO" + + +async def test_worker_fatal_error_with(client: Client): + # Start the worker, wait a short bit, fail it, wait for long time (will be + # cancelled) + worker = create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + async with worker: + await asyncio.sleep(0.1) + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await asyncio.sleep(1000) + assert str(err.value) == "Workflow worker failed" + assert err.value.__cause__ and str(err.value.__cause__) == "OH NO" + + # Raise inside the async with and confirm it works + worker = create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + async with worker: + raise RuntimeError("IN WITH") + assert str(err.value) == "IN WITH" + + # 
Demonstrate that inner re-thrown failure swallows worker failure + worker = create_worker(client) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + async with worker: + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + try: + await asyncio.sleep(1000) + except BaseException as inner_err: + raise RuntimeError("Caught cancel") from inner_err + assert str(err.value) == "Caught cancel" + assert err.value.__cause__ and type(err.value.__cause__) is asyncio.CancelledError + + +async def test_worker_fatal_error_callback(client: Client): + callback_err: Optional[BaseException] = None + + async def on_fatal_error(exc: BaseException) -> None: + nonlocal callback_err + callback_err = exc + + worker = create_worker(client, on_fatal_error) + with pytest.raises(RuntimeError) as err: + with WorkerFailureInjector(worker) as inj: + async with worker: + await asyncio.sleep(0.1) + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await asyncio.sleep(1000) + assert err.value is callback_err + + +async def test_worker_cancel_run(client: Client): + worker = create_worker(client) + assert not worker.is_running and not worker.is_shutdown + run_task = asyncio.create_task(worker.run()) + await asyncio.sleep(0.3) + assert worker.is_running and not worker.is_shutdown + run_task.cancel() + with pytest.raises(asyncio.CancelledError): + await run_task + assert not worker.is_running and worker.is_shutdown + + +def create_worker( + client: Client, + on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] = None, +) -> Worker: + return Worker( + client, + task_queue=f"task-queue-{uuid.uuid4()}", + activities=[never_run_activity], + workflows=[NeverRunWorkflow], + on_fatal_error=on_fatal_error, + ) + + +class WorkerFailureInjector: + def __init__(self, worker: Worker) -> None: + self.workflow = PollFailureInjector(worker, "poll_workflow_activation") + self.activity = PollFailureInjector(worker, "poll_activity_task") + 
+ def __enter__(self) -> WorkerFailureInjector: + return self + + def __exit__(self, *args, **kwargs) -> None: + self.workflow.shutdown() + self.activity.shutdown() + + +class PollFailureInjector: + def __init__(self, worker: Worker, attr: str) -> None: + self.worker = worker + self.attr = attr + self.poll_fail_queue: asyncio.Queue[Exception] = asyncio.Queue() + self.orig_poll_call = getattr(worker._bridge_worker, attr) + setattr(worker._bridge_worker, attr, self.patched_poll_call) + self.next_poll_task: Optional[asyncio.Task] = None + self.next_exception_task: Optional[asyncio.Task] = None + + async def patched_poll_call(self) -> Any: + if not self.next_poll_task: + self.next_poll_task = asyncio.create_task(self.orig_poll_call()) + if not self.next_exception_task: + self.next_exception_task = asyncio.create_task(self.poll_fail_queue.get()) + + await asyncio.wait( + [self.next_poll_task, self.next_exception_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # If activation came, return that and leave queue for next poll + if self.next_poll_task.done(): + ret = self.next_poll_task.result() + self.next_poll_task = None + return ret + + # Raise the error + exc = self.next_exception_task.result() + self.next_exception_task = None + raise exc + + def shutdown(self) -> None: + if self.next_poll_task: + self.next_poll_task.cancel() + if self.next_exception_task: + self.next_exception_task.cancel() + setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call)