Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
HandlerType = Union[Callable[[Any, I], Awaitable[O]], Callable[[Any], Awaitable[O]]]
RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]]

# pylint: disable=R0902
@dataclass
class RunOptions(typing.Generic[T]):
"""
Expand All @@ -40,15 +41,32 @@ class RunOptions(typing.Generic[T]):
serde: Serde[T] = DefaultSerde()
"""The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
See also 'type_hint'."""
max_attempts: Optional[int] = None
"""The maximum number of retry attempts, including the initial attempt, to complete the action.
If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError."""
max_retry_duration: Optional[timedelta] = None
"""The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError."""
type_hint: Optional[typing.Type[T]] = None
"""The type hint of the return value of the action. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer."""
max_attempts: Optional[int] = None
"""Max number of attempts (including the initial), before giving up.

When giving up, `ctx.run` will raise a `TerminalError` wrapping the original error message."""
max_duration: Optional[timedelta] = None
"""Max duration of retries, before giving up.

When giving up, `ctx.run` will raise a `TerminalError` wrapping the original error message."""
initial_retry_interval: Optional[timedelta] = None
"""Initial interval for the first retry attempt.
Retry interval will grow by a factor specified in `retry_interval_factor`.

If any of the other retry-related fields is specified, the default for this field is 50 milliseconds; otherwise Restate will fall back to the overall invocation retry policy."""
max_retry_interval: Optional[timedelta] = None
"""Max interval between retries.
Retry interval will grow by a factor specified in `retry_interval_factor`.

The default is 10 seconds."""
retry_interval_factor: Optional[float] = None
"""Exponentiation factor to use when computing the next retry delay.

If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval will double at each attempt; otherwise Restate will fall back to the overall invocation retry policy."""
max_retry_duration: Optional[timedelta] = None
"""Deprecated: Use max_duration instead."""

# pylint: disable=R0903
class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
Expand Down
57 changes: 36 additions & 21 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,18 +408,14 @@ async def must_take_notification(self, handle):

async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
"""Create a coroutine to poll the handle."""
await self.take_and_send_output()
while True:
await self.take_and_send_output()
do_progress_response = self.vm.do_progress(handles)
if isinstance(do_progress_response, Exception):
# We might need to write out something at this point.
await self.take_and_send_output()
if isinstance(do_progress_response, BaseException):
# Print this exception, might be relevant for the user
traceback.print_exception(do_progress_response)
await cancel_current_task()
if isinstance(do_progress_response, Suspended):
# We might need to write out something at this point.
await self.take_and_send_output()
await cancel_current_task()
if isinstance(do_progress_response, DoProgressAnyCompleted):
# One of the handles completed
Expand All @@ -432,9 +428,10 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
assert fn is not None

async def wrapper(f):
await f()
await self.take_and_send_output()
await self.receive.enqueue_restate_event({ 'type' : 'restate.run_completed', 'data': None})
try:
await f()
finally:
await self.receive.enqueue_restate_event({ 'type' : 'restate.run_completed', 'data': None})

task = asyncio.create_task(wrapper(fn))
self.tasks.add(task)
Expand Down Expand Up @@ -542,9 +539,13 @@ async def create_run_coroutine(self,
action: RunAction[T],
serde: Serde[T],
max_attempts: Optional[int] = None,
max_retry_duration: Optional[timedelta] = None,
max_duration: Optional[timedelta] = None,
initial_retry_interval: Optional[timedelta] = None,
max_retry_interval: Optional[timedelta] = None,
retry_interval_factor: Optional[float] = None,
):
"""Create a coroutine to poll the handle."""
start = time.time()
try:
if inspect.iscoroutinefunction(action):
action_result: T = await action() # type: ignore
Expand All @@ -565,15 +566,20 @@ async def create_run_coroutine(self,
raise e from None
# pylint: disable=W0718
except Exception as e:
if max_attempts is None and max_retry_duration is None:
# no retry policy
# todo: log the error
self.vm.notify_error(repr(e), traceback.format_exc())
else:
failure = Failure(code=500, message=str(e))
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000)
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms)
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config)
end = time.time()
attempt_duration = int((end - start) * 1000)
failure = Failure(code=500, message=str(e))
max_duration_ms = None if max_duration is None else int(max_duration.total_seconds() * 1000)
initial_retry_interval_ms = None if initial_retry_interval is None else int(initial_retry_interval.total_seconds() * 1000)
max_retry_interval_ms = None if max_retry_interval is None else int(max_retry_interval.total_seconds() * 1000)
config = RunRetryConfig(
max_attempts=max_attempts,
max_duration=max_duration_ms,
initial_interval=initial_retry_interval_ms,
max_interval=max_retry_interval_ms,
interval_factor=retry_interval_factor
)
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=attempt_duration, config=config)
# pylint: disable=W0236
# pylint: disable=R0914
def run(self,
Expand All @@ -600,7 +606,7 @@ def run(self,
else:
# todo: we can also verify by looking at the signature that there are no missing parameters
noargs_action = action # type: ignore
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration, None, None, None)
return self.create_future(handle, serde) # type: ignore

def run_typed(
Expand All @@ -623,7 +629,16 @@ def run_typed(
update_restate_context_is_replaying(self.vm)

func = functools.partial(action, *args, **kwargs)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(
handle,
func,
options.serde,
options.max_attempts,
options.max_duration,
options.initial_retry_interval,
options.max_retry_interval,
options.retry_interval_factor
)
return self.create_future(handle, options.serde)

def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
Expand Down
26 changes: 14 additions & 12 deletions python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ class Invocation:
@dataclass
class RunRetryConfig:
"""
Expo Retry Configuration
Exponential Retry Configuration

All duration/interval values are in milliseconds.
"""
initial_interval: typing.Optional[int] = None
max_attempts: typing.Optional[int] = None
max_duration: typing.Optional[int] = None
max_interval: typing.Optional[int] = None
interval_factor: typing.Optional[float] = None

@dataclass
class Failure:
Expand Down Expand Up @@ -394,22 +398,20 @@ def propose_run_completion_failure(self, handle: int, output: Failure) -> int:
return self.vm.propose_run_completion_failure(handle, res)

# pylint: disable=line-too-long
def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None:
def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig):
"""
Exit a side effect with a transient Error.
This requires a retry policy to be provided.
"""
py_failure = PyFailure(failure.code, failure.message)
py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration)
try:
handle = self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config)
# The VM decided not to retry, therefore we get back a handle that will be resolved
# with a terminal failure.
return handle
# pylint: disable=bare-except
except:
# The VM decided to retry, therefore we tear down the current execution
return None
py_config = PyExponentialRetryConfig(
config.initial_interval,
config.max_attempts,
config.max_duration,
config.max_interval,
config.interval_factor
)
self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config)

def sys_end(self):
"""
Expand Down
40 changes: 31 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::fmt;
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyNone, PyString};
use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter};
use restate_sdk_shared_core::{
CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue,
NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, TakeOutputResult, Target,
TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM,
};
use std::fmt;
use std::time::{Duration, SystemTime};
use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter};

// Current crate version
const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
Expand Down Expand Up @@ -115,33 +115,55 @@ struct PyExponentialRetryConfig {
max_attempts: Option<u32>,
#[pyo3(get, set)]
max_duration: Option<u64>,
#[pyo3(get, set)]
max_interval: Option<u64>,
#[pyo3(get, set)]
factor: Option<f64>,
}

#[pymethods]
impl PyExponentialRetryConfig {
#[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None))]
#[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None, max_interval=None, factor=None))]
#[new]
fn new(
initial_interval: Option<u64>,
max_attempts: Option<u32>,
max_duration: Option<u64>,
max_interval: Option<u64>,
factor: Option<f64>,
) -> Self {
Self {
initial_interval,
max_attempts,
max_duration,
max_interval,
factor,
}
}
}

impl From<PyExponentialRetryConfig> for RetryPolicy {
fn from(value: PyExponentialRetryConfig) -> Self {
RetryPolicy::Exponential {
initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(10)),
max_attempts: value.max_attempts,
max_duration: value.max_duration.map(Duration::from_millis),
factor: 2.0,
max_interval: None,
if value.initial_interval.is_some()
|| value.max_attempts.is_some()
|| value.max_duration.is_some()
|| value.max_interval.is_some()
|| value.factor.is_some()
{
// If any of the values are set, then let's create the exponential retry policy
RetryPolicy::Exponential {
initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(50)),
max_attempts: value.max_attempts,
max_duration: value.max_duration.map(Duration::from_millis),
factor: value.factor.unwrap_or(2.0) as f32,
max_interval: value
.max_interval
.map(Duration::from_millis)
.or_else(|| Some(Duration::from_secs(10))),
}
} else {
// No retry option is set: use RetryPolicy::Infinite, which gives control back to the invocation retry policy
RetryPolicy::Infinite
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions test-services/services/failing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""example.py"""
from datetime import timedelta

# pylint: disable=C0116
# pylint: disable=W0613
# pylint: disable=W0622
Expand Down Expand Up @@ -61,7 +63,7 @@ def side_effect():
return eventual_success_side_effects
raise ValueError(f"Failed at attempt: {eventual_success_side_effects}")

options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1)
options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0)
return await ctx.run_typed("sideEffect", side_effect, options)

eventual_failure_side_effects = 0
Expand All @@ -75,7 +77,7 @@ def side_effect():
raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}")

try:
options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count)
options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0)
await ctx.run_typed("sideEffect", side_effect, options)
raise ValueError("Side effect did not fail.")
except TerminalError as t:
Expand Down