diff --git a/python/restate/context.py b/python/restate/context.py index 37ab74d..fd51626 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -31,6 +31,7 @@ HandlerType = Union[Callable[[Any, I], Awaitable[O]], Callable[[Any], Awaitable[O]]] RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]] +# pylint: disable=R0902 @dataclass class RunOptions(typing.Generic[T]): """ @@ -40,15 +41,32 @@ class RunOptions(typing.Generic[T]): serde: Serde[T] = DefaultSerde() """The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. See also 'type_hint'.""" - max_attempts: Optional[int] = None - """The maximum number of retry attempts, including the initial attempt, to complete the action. - If None, the action will be retried indefinitely, until it succeeds. - Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError.""" - max_retry_duration: Optional[timedelta] = None - """The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds. - Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError.""" type_hint: Optional[typing.Type[T]] = None """The type hint of the return value of the action. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.""" + max_attempts: Optional[int] = None + """Max number of attempts (including the initial), before giving up. + + When giving up, `ctx.run` will throw a `TerminalError` wrapping the original error message.""" + max_duration: Optional[timedelta] = None + """Max duration of retries, before giving up. + + When giving up, `ctx.run` will throw a `TerminalError` wrapping the original error message.""" + initial_retry_interval: Optional[timedelta] = None + """Initial interval for the first retry attempt. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy.""" + max_retry_interval: Optional[timedelta] = None + """Max interval between retries. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + The default is 10 seconds.""" + retry_interval_factor: Optional[float] = None + """Exponentiation factor to use when computing the next retry delay. + + If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy.""" + max_retry_duration: Optional[timedelta] = None + """Deprecated: Use max_duration instead.""" # pylint: disable=R0903 class RestateDurableFuture(typing.Generic[T], Awaitable[T]): diff --git a/python/restate/server_context.py b/python/restate/server_context.py index d6dabee..8469891 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -408,18 +408,14 @@ async def must_take_notification(self, handle): async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None: """Create a coroutine to poll the handle.""" - await self.take_and_send_output() while True: + await self.take_and_send_output() do_progress_response = self.vm.do_progress(handles) - if isinstance(do_progress_response, Exception): - # We might need to write out something at this point. - await self.take_and_send_output() + if isinstance(do_progress_response, BaseException): # Print this exception, might be relevant for the user traceback.print_exception(do_progress_response) await cancel_current_task() if isinstance(do_progress_response, Suspended): - # We might need to write out something at this point. - await self.take_and_send_output() await cancel_current_task() if isinstance(do_progress_response, DoProgressAnyCompleted): # One of the handles completed @@ -432,9 +428,10 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No assert fn is not None async def wrapper(f): - await f() - await self.take_and_send_output() - await self.receive.enqueue_restate_event({ 'type' : 'restate.run_completed', 'data': None}) + try: + await f() + finally: + await self.receive.enqueue_restate_event({ 'type' : 'restate.run_completed', 'data': None}) task = asyncio.create_task(wrapper(fn)) self.tasks.add(task) @@ -542,9 +539,13 @@ async def create_run_coroutine(self, action: RunAction[T], serde: Serde[T], max_attempts: Optional[int] = None, - max_retry_duration: Optional[timedelta] = None, + max_duration: Optional[timedelta] = None, + initial_retry_interval: Optional[timedelta] = None, + max_retry_interval: Optional[timedelta] = None, + retry_interval_factor: Optional[float] = None, ): """Create a coroutine to poll the handle.""" + start = time.time() try: if inspect.iscoroutinefunction(action): action_result: T = await action() # type: ignore @@ -565,15 +566,20 @@ async def create_run_coroutine(self, raise e from None # pylint: disable=W0718 except Exception as e: - if max_attempts is None and max_retry_duration is None: - # no retry policy - # todo: log the error - self.vm.notify_error(repr(e), traceback.format_exc()) - else: - failure = Failure(code=500, message=str(e)) - max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000) - config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms) - self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config) + end = time.time() + attempt_duration = int((end - start) * 1000) + failure = Failure(code=500, message=str(e)) + max_duration_ms = None if max_duration is None else int(max_duration.total_seconds() * 1000) + initial_retry_interval_ms = None if initial_retry_interval is None else int(initial_retry_interval.total_seconds() * 1000) + max_retry_interval_ms = None if max_retry_interval is None else int(max_retry_interval.total_seconds() * 1000) + config = RunRetryConfig( + max_attempts=max_attempts, + max_duration=max_duration_ms, + initial_interval=initial_retry_interval_ms, + max_interval=max_retry_interval_ms, + interval_factor=retry_interval_factor + ) + self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=attempt_duration, config=config) # pylint: disable=W0236 # pylint: disable=R0914 def run(self, @@ -600,7 +606,7 @@ def run(self, else: # todo: we can also verify by looking at the signature that there are no missing parameters noargs_action = action # type: ignore - self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration) + self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration, None, None, None) return self.create_future(handle, serde) # type: ignore def run_typed( @@ -623,7 +629,16 @@ def run_typed( update_restate_context_is_replaying(self.vm) func = functools.partial(action, *args, **kwargs) - self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration) + self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine( + handle, + func, + options.serde, + options.max_attempts, + options.max_duration, + options.initial_retry_interval, + options.max_retry_interval, + options.retry_interval_factor + ) return self.create_future(handle, options.serde) def sleep(self, delta: timedelta) -> RestateDurableSleepFuture: diff --git a/python/restate/vm.py b/python/restate/vm.py index a001ceb..ef8532f 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -33,11 +33,15 @@ class Invocation: @dataclass class RunRetryConfig: """ - Expo Retry Configuration + Exponential Retry Configuration + + All duration/interval values are in milliseconds. """ initial_interval: typing.Optional[int] = None max_attempts: typing.Optional[int] = None max_duration: typing.Optional[int] = None + max_interval: typing.Optional[int] = None + interval_factor: typing.Optional[float] = None @dataclass class Failure: @@ -394,22 +398,20 @@ def propose_run_completion_failure(self, handle: int, output: Failure) -> int: return self.vm.propose_run_completion_failure(handle, res) # pylint: disable=line-too-long - def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None: + def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig): """ Exit a side effect with a transient Error. This requires a retry policy to be provided. """ py_failure = PyFailure(failure.code, failure.message) - py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration) - try: - handle = self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config) - # The VM decided not to retry, therefore we get back an handle that will be resolved - # with a terminal failure. - return handle - # pylint: disable=bare-except - except: - # The VM decided to retry, therefore we tear down the current execution - return None + py_config = PyExponentialRetryConfig( + config.initial_interval, + config.max_attempts, + config.max_duration, + config.max_interval, + config.interval_factor + ) + self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config) def sys_end(self): """ diff --git a/src/lib.rs b/src/lib.rs index 992753d..a5da5f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,14 @@ -use std::fmt; use pyo3::create_exception; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyNone, PyString}; +use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter}; use restate_sdk_shared_core::{ CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue, NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, }; +use std::fmt; use std::time::{Duration, SystemTime}; -use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter}; // Current crate version const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -115,33 +115,55 @@ struct PyExponentialRetryConfig { max_attempts: Option, #[pyo3(get, set)] max_duration: Option, + #[pyo3(get, set)] + max_interval: Option, + #[pyo3(get, set)] + factor: Option, } #[pymethods] impl PyExponentialRetryConfig { - #[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None))] + #[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None, max_interval=None, factor=None))] #[new] fn new( initial_interval: Option, max_attempts: Option, max_duration: Option, + max_interval: Option, + factor: Option, ) -> Self { Self { initial_interval, max_attempts, max_duration, + max_interval, + factor, } } } impl From for RetryPolicy { fn from(value: PyExponentialRetryConfig) -> Self { - RetryPolicy::Exponential { - initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(10)), - max_attempts: value.max_attempts, - max_duration: value.max_duration.map(Duration::from_millis), - factor: 2.0, - max_interval: None, + if value.initial_interval.is_some() + || value.max_attempts.is_some() + || value.max_duration.is_some() + || value.max_interval.is_some() + || value.factor.is_some() + { + // If any of the values are set, then let's create the exponential retry policy + RetryPolicy::Exponential { + initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(50)), + max_attempts: value.max_attempts, + max_duration: value.max_duration.map(Duration::from_millis), + factor: value.factor.unwrap_or(2.0) as f32, + max_interval: value + .max_interval + .map(Duration::from_millis) + .or_else(|| Some(Duration::from_secs(10))), + } + } else { + // Let's use retry policy infinite here, which will give back control to the invocation retry policy + RetryPolicy::Infinite } } } diff --git a/test-services/services/failing.py b/test-services/services/failing.py index 37711e8..9856d82 100644 --- a/test-services/services/failing.py +++ b/test-services/services/failing.py @@ -9,6 +9,8 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # """example.py""" +from datetime import timedelta + # pylint: disable=C0116 # pylint: disable=W0613 # pylint: disable=W0622 @@ -61,7 +63,7 @@ def side_effect(): return eventual_success_side_effects raise ValueError(f"Failed at attempt: {eventual_success_side_effects}") - options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1) + options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0) return await ctx.run_typed("sideEffect", side_effect, options) eventual_failure_side_effects = 0 @@ -75,7 +77,7 @@ def side_effect(): raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}") try: - options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count) + options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0) await ctx.run_typed("sideEffect", side_effect, options) raise ValueError("Side effect did not fail.") except TerminalError as t: