From c5ef2f9e06437fa3803aad2b9d9c79490f937712 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 22 Sep 2025 17:00:38 +0200 Subject: [PATCH] Use task cancellation instead of custom suspension exception --- python/restate/server_context.py | 52 ++++++++++++++++++++++++++------ python/restate/vm.py | 45 +++++++++++++++------------ src/lib.rs | 19 +++++++----- 3 files changed, 80 insertions(+), 36 deletions(-) diff --git a/python/restate/server_context.py b/python/restate/server_context.py index ff53ba9..8342817 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -33,7 +33,7 @@ from restate.handler import Handler, handler_from_callable, invoke_handler from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde from restate.server_types import ReceiveChannel, Send -from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long +from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun import typing_extensions @@ -160,7 +160,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None: async def coro() -> str: if not context.vm.is_completed(handle): await context.create_poll_or_cancel_coroutine([handle]) - invocation_id = context.must_take_notification(handle) + invocation_id = await context.must_take_notification(handle) return typing.cast(str, invocation_id) self.future = LazyFuture(coro) @@ -200,7 +200,7 @@ def resolve(self, value: Any) -> Awaitable[None]: async def await_point(): if not self.server_context.vm.is_completed(handle): await self.server_context.create_poll_or_cancel_coroutine([handle]) - self.server_context.must_take_notification(handle) + await self.server_context.must_take_notification(handle) return ServerDurableFuture(self.server_context, handle, await_point) @@ -213,7 +213,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]: async def await_point(): if not self.server_context.vm.is_completed(handle): await self.server_context.create_poll_or_cancel_coroutine([handle]) - self.server_context.must_take_notification(handle) + await self.server_context.must_take_notification(handle) return ServerDurableFuture(self.server_context, handle, await_point) @@ -273,6 +273,19 @@ def update_restate_context_is_replaying(vm: VMWrapper): """Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*""" restate_context_is_replaying.set(vm.is_replaying()) +async def cancel_current_task(): + """Cancel the current task""" + current_task = asyncio.current_task() + if current_task is not None: + # Cancel through asyncio API + current_task.cancel( + "Cancelled by Restate SDK, you should not call any Context method after this exception is thrown." + ) + # Sleep 0 will pop up the cancellation + await asyncio.sleep(0) + else: + raise asyncio.CancelledError("Cancelled by Restate SDK, you should not call any Context method after this exception is thrown.") + # pylint: disable=R0902 class ServerInvocationContext(ObjectContext): """This class implements the context for the restate framework based on the server.""" @@ -312,7 +325,7 @@ async def enter(self): self.vm.sys_write_output_failure(failure) self.vm.sys_end() # pylint: disable=W0718 - except SuspendedException: + except asyncio.CancelledError: pass except DisconnectedException: raise @@ -372,9 +385,19 @@ async def take_and_send_output(self): 'more_body': True, }) - def must_take_notification(self, handle): + async def must_take_notification(self, handle): """Take notification, which must be present""" res = self.vm.take_notification(handle) + if isinstance(res, Exception): + # We might need to write out something at this point. + await self.take_and_send_output() + # Print this exception, might be relevant for the user + traceback.print_exception(res) + await cancel_current_task() + if isinstance(res, Suspended): + # We might need to write out something at this point. + await self.take_and_send_output() + await cancel_current_task() if isinstance(res, NotReady): raise ValueError(f"Unexpected value error: {handle}") if res is None: @@ -383,12 +406,21 @@ def must_take_notification(self, handle): raise TerminalError(res.message, res.code) return res - async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None: """Create a coroutine to poll the handle.""" await self.take_and_send_output() while True: do_progress_response = self.vm.do_progress(handles) + if isinstance(do_progress_response, Exception): + # We might need to write out something at this point. + await self.take_and_send_output() + # Print this exception, might be relevant for the user + traceback.print_exception(do_progress_response) + await cancel_current_task() + if isinstance(do_progress_response, Suspended): + # We might need to write out something at this point. + await self.take_and_send_output() + await cancel_current_task() if isinstance(do_progress_response, DoProgressAnyCompleted): # One of the handles completed return @@ -425,7 +457,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N async def fetch_result(): if not self.vm.is_completed(handle): await self.create_poll_or_cancel_coroutine([handle]) - res = self.must_take_notification(handle) + res = await self.must_take_notification(handle) if res is None or serde is None: return res if isinstance(res, bytes): @@ -443,7 +475,7 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture: async def transform(): if not self.vm.is_completed(handle): await self.create_poll_or_cancel_coroutine([handle]) - self.must_take_notification(handle) + await self.must_take_notification(handle) return ServerDurableSleepFuture(self, handle, transform) def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]: @@ -451,7 +483,7 @@ def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serd async def inv_id_factory(): if not self.vm.is_completed(invocation_id_handle): await self.create_poll_or_cancel_coroutine([invocation_id_handle]) - return self.must_take_notification(invocation_id_handle) + return await self.must_take_notification(invocation_id_handle) return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory) diff --git a/python/restate/vm.py b/python/restate/vm.py index b8a5fd3..a001ceb 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import typing -from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long +from restate._internal import PyVM, PyHeader, PyFailure, VMException, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long @dataclass class Invocation: @@ -53,19 +53,18 @@ class NotReady: NotReady """ -class SuspendedException(Exception): - """ - Suspended Exception - """ - def __init__(self, *args: object) -> None: - super().__init__(*args) - NOT_READY = NotReady() -SUSPENDED = SuspendedException() CANCEL_HANDLE = CANCEL_NOTIFICATION_HANDLE NotificationType = typing.Optional[typing.Union[bytes, Failure, NotReady, list[str], str]] +class Suspended: + """ + Represents a suspended error + """ + +SUSPENDED = Suspended() + class DoProgressAnyCompleted: """ Represents a notification that any of the handles has completed. @@ -151,11 +150,16 @@ def is_completed(self, handle: int) -> bool: """Returns true when the notification handle is completed and hasn't been taken yet.""" return self.vm.is_completed(handle) - def do_progress(self, handles: list[int]) -> DoProgressResult: + # pylint: disable=R0911 + def do_progress(self, handles: list[int]) \ + -> typing.Union[DoProgressResult, Exception, Suspended]: """Do progress with notifications.""" - result = self.vm.do_progress(handles) + try: + result = self.vm.do_progress(handles) + except VMException as e: + return e if isinstance(result, PySuspended): - raise SUSPENDED + return SUSPENDED if isinstance(result, PyDoProgressAnyCompleted): return DO_PROGRESS_ANY_COMPLETED if isinstance(result, PyDoProgressReadFromInput): @@ -166,11 +170,17 @@ def do_progress(self, handles: list[int]) -> DoProgressResult: return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED if isinstance(result, PyDoWaitForPendingRun): return DO_WAIT_PENDING_RUN - raise ValueError(f"Unknown progress type: {result}") + return ValueError(f"Unknown progress type: {result}") - def take_notification(self, handle: int) -> NotificationType: + def take_notification(self, handle: int) \ + -> typing.Union[NotificationType, Exception, Suspended]: """Take the result of an asynchronous operation.""" - result = self.vm.take_notification(handle) + try: + result = self.vm.take_notification(handle) + except VMException as e: + return e + if isinstance(result, PySuspended): + return SUSPENDED if result is None: return NOT_READY if isinstance(result, PyVoid): @@ -190,10 +200,7 @@ def take_notification(self, handle: int) -> NotificationType: code = result.code message = result.message return Failure(code, message) - if isinstance(result, PySuspended): - # the state machine had suspended - raise SUSPENDED - raise ValueError(f"Unknown result type: {result}") + return ValueError(f"Unknown result type: {result}") def sys_input(self) -> Invocation: """ diff --git a/src/lib.rs b/src/lib.rs index 45eef66..992753d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -249,7 +249,7 @@ create_exception!( restate_sdk_python_core, VMException, pyo3::exceptions::PyException, - "Restate VM exception." + "Protocol state machine exception." ); impl From for PyErr { @@ -761,21 +761,26 @@ impl ErrorFormatter for PythonErrorFormatter { fn display_closed_error(&self, f: &mut fmt::Formatter<'_>, event: &str) -> fmt::Result { write!(f, "Execution is suspended, but the handler is still attempting to make progress (calling '{event}'). This can happen: -* If the SuspendedException is caught. Make sure you NEVER catch the SuspendedException, e.g. avoid: +* If you don't need to handle task cancellation, just avoid catch all statements. Don't do: try: # Code except: - # This catches all exceptions, including the SuspendedException! + # This catches all exceptions, including the asyncio.CancelledError! + # '{event}' <- This operation prints this exception -And use instead: +Do instead: try: # Code except TerminalException: - # In Restate handlers you typically want to catch TerminalException + # In Restate handlers you typically want to catch TerminalException only -Check https://docs.restate.dev/develop/python/durable-steps#run for more details on run error handling. +* To catch ctx.run/ctx.run_typed errors, check https://docs.restate.dev/develop/python/durable-steps#run for more details. -* If you use the context after the handler completed, e.g. moving the context to another thread. Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.") +* If the asyncio.CancelledError is caught, you must not run any Context operation in the except arm. + Check https://docs.python.org/3/library/asyncio-task.html#task-cancellation for more details on task cancellation. + +* If you use the context after the handler completed, e.g. moving the context to another thread. + Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.") } }