Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
from restate.server_types import ReceiveChannel, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
import typing_extensions

Expand Down Expand Up @@ -160,7 +160,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
async def coro() -> str:
if not context.vm.is_completed(handle):
await context.create_poll_or_cancel_coroutine([handle])
invocation_id = context.must_take_notification(handle)
invocation_id = await context.must_take_notification(handle)
return typing.cast(str, invocation_id)

self.future = LazyFuture(coro)
Expand Down Expand Up @@ -200,7 +200,7 @@ def resolve(self, value: Any) -> Awaitable[None]:
async def await_point():
if not self.server_context.vm.is_completed(handle):
await self.server_context.create_poll_or_cancel_coroutine([handle])
self.server_context.must_take_notification(handle)
await self.server_context.must_take_notification(handle)

return ServerDurableFuture(self.server_context, handle, await_point)

Expand All @@ -213,7 +213,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
async def await_point():
if not self.server_context.vm.is_completed(handle):
await self.server_context.create_poll_or_cancel_coroutine([handle])
self.server_context.must_take_notification(handle)
await self.server_context.must_take_notification(handle)

return ServerDurableFuture(self.server_context, handle, await_point)

Expand Down Expand Up @@ -273,6 +273,19 @@ def update_restate_context_is_replaying(vm: VMWrapper):
"""Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*"""
restate_context_is_replaying.set(vm.is_replaying())

async def cancel_current_task():
"""Cancel the current task"""
current_task = asyncio.current_task()
if current_task is not None:
# Cancel through asyncio API
current_task.cancel(
"Cancelled by Restate SDK, you should not call any Context method after this exception is thrown."
)
# Sleep 0 will pop up the cancellation
await asyncio.sleep(0)
else:
raise asyncio.CancelledError("Cancelled by Restate SDK, you should not call any Context method after this exception is thrown.")

# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""
Expand Down Expand Up @@ -312,7 +325,7 @@ async def enter(self):
self.vm.sys_write_output_failure(failure)
self.vm.sys_end()
# pylint: disable=W0718
except SuspendedException:
except asyncio.CancelledError:
pass
except DisconnectedException:
raise
Expand Down Expand Up @@ -372,9 +385,19 @@ async def take_and_send_output(self):
'more_body': True,
})

def must_take_notification(self, handle):
async def must_take_notification(self, handle):
"""Take notification, which must be present"""
res = self.vm.take_notification(handle)
if isinstance(res, Exception):
# We might need to write out something at this point.
await self.take_and_send_output()
# Print this exception, might be relevant for the user
traceback.print_exception(res)
await cancel_current_task()
if isinstance(res, Suspended):
# We might need to write out something at this point.
await self.take_and_send_output()
await cancel_current_task()
if isinstance(res, NotReady):
raise ValueError(f"Unexpected value error: {handle}")
if res is None:
Expand All @@ -383,12 +406,21 @@ def must_take_notification(self, handle):
raise TerminalError(res.message, res.code)
return res


async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
"""Create a coroutine to poll the handle."""
await self.take_and_send_output()
while True:
do_progress_response = self.vm.do_progress(handles)
if isinstance(do_progress_response, Exception):
# We might need to write out something at this point.
await self.take_and_send_output()
# Print this exception, might be relevant for the user
traceback.print_exception(do_progress_response)
await cancel_current_task()
if isinstance(do_progress_response, Suspended):
# We might need to write out something at this point.
await self.take_and_send_output()
await cancel_current_task()
if isinstance(do_progress_response, DoProgressAnyCompleted):
# One of the handles completed
return
Expand Down Expand Up @@ -425,7 +457,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N
async def fetch_result():
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
res = await self.must_take_notification(handle)
if res is None or serde is None:
return res
if isinstance(res, bytes):
Expand All @@ -443,15 +475,15 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture:
async def transform():
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
self.must_take_notification(handle)
await self.must_take_notification(handle)
return ServerDurableSleepFuture(self, handle, transform)

def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
"""Create a durable future."""
async def inv_id_factory():
if not self.vm.is_completed(invocation_id_handle):
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
return self.must_take_notification(invocation_id_handle)
return await self.must_take_notification(invocation_id_handle)

return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)

Expand Down
45 changes: 26 additions & 19 deletions python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from dataclasses import dataclass
import typing
from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
from restate._internal import PyVM, PyHeader, PyFailure, VMException, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long

@dataclass
class Invocation:
Expand Down Expand Up @@ -53,19 +53,18 @@ class NotReady:
NotReady
"""

class SuspendedException(Exception):
"""
Suspended Exception
"""
def __init__(self, *args: object) -> None:
super().__init__(*args)

NOT_READY = NotReady()
SUSPENDED = SuspendedException()
CANCEL_HANDLE = CANCEL_NOTIFICATION_HANDLE

NotificationType = typing.Optional[typing.Union[bytes, Failure, NotReady, list[str], str]]

class Suspended:
"""
Represents a suspended error
"""

SUSPENDED = Suspended()

class DoProgressAnyCompleted:
"""
Represents a notification that any of the handles has completed.
Expand Down Expand Up @@ -151,11 +150,16 @@ def is_completed(self, handle: int) -> bool:
"""Returns true when the notification handle is completed and hasn't been taken yet."""
return self.vm.is_completed(handle)

def do_progress(self, handles: list[int]) -> DoProgressResult:
# pylint: disable=R0911
def do_progress(self, handles: list[int]) \
-> typing.Union[DoProgressResult, Exception, Suspended]:
"""Do progress with notifications."""
result = self.vm.do_progress(handles)
try:
result = self.vm.do_progress(handles)
except VMException as e:
return e
if isinstance(result, PySuspended):
raise SUSPENDED
return SUSPENDED
if isinstance(result, PyDoProgressAnyCompleted):
return DO_PROGRESS_ANY_COMPLETED
if isinstance(result, PyDoProgressReadFromInput):
Expand All @@ -166,11 +170,17 @@ def do_progress(self, handles: list[int]) -> DoProgressResult:
return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
if isinstance(result, PyDoWaitForPendingRun):
return DO_WAIT_PENDING_RUN
raise ValueError(f"Unknown progress type: {result}")
return ValueError(f"Unknown progress type: {result}")

def take_notification(self, handle: int) -> NotificationType:
def take_notification(self, handle: int) \
-> typing.Union[NotificationType, Exception, Suspended]:
"""Take the result of an asynchronous operation."""
result = self.vm.take_notification(handle)
try:
result = self.vm.take_notification(handle)
except VMException as e:
return e
if isinstance(result, PySuspended):
return SUSPENDED
if result is None:
return NOT_READY
if isinstance(result, PyVoid):
Expand All @@ -190,10 +200,7 @@ def take_notification(self, handle: int) -> NotificationType:
code = result.code
message = result.message
return Failure(code, message)
if isinstance(result, PySuspended):
# the state machine had suspended
raise SUSPENDED
raise ValueError(f"Unknown result type: {result}")
return ValueError(f"Unknown result type: {result}")

def sys_input(self) -> Invocation:
"""
Expand Down
19 changes: 12 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ create_exception!(
restate_sdk_python_core,
VMException,
pyo3::exceptions::PyException,
"Restate VM exception."
"Protocol state machine exception."
);

impl From<PyVMError> for PyErr {
Expand Down Expand Up @@ -761,21 +761,26 @@ impl ErrorFormatter for PythonErrorFormatter {
fn display_closed_error(&self, f: &mut fmt::Formatter<'_>, event: &str) -> fmt::Result {
write!(f, "Execution is suspended, but the handler is still attempting to make progress (calling '{event}'). This can happen:

* If the SuspendedException is caught. Make sure you NEVER catch the SuspendedException, e.g. avoid:
* If you don't need to handle task cancellation, just avoid catch all statements. Don't do:
try:
# Code
except:
# This catches all exceptions, including the SuspendedException!
# This catches all exceptions, including the asyncio.CancelledError!
# '{event}' <- This operation prints this exception

And use instead:
Do instead:
try:
# Code
except TerminalException:
# In Restate handlers you typically want to catch TerminalException
# In Restate handlers you typically want to catch TerminalException only

Check https://docs.restate.dev/develop/python/durable-steps#run for more details on run error handling.
* To catch ctx.run/ctx.run_typed errors, check https://docs.restate.dev/develop/python/durable-steps#run for more details.

* If you use the context after the handler completed, e.g. moving the context to another thread. Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.")
* If the asyncio.CancelledError is caught, you must not run any Context operation in the except arm.
Check https://docs.python.org/3/library/asyncio-task.html#task-cancellation for more details on task cancellation.

* If you use the context after the handler completed, e.g. moving the context to another thread.
Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.")
}
}

Expand Down