Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# pylint: disable=C0116
# pylint: disable=W0613

import logging
import restate

from greeter import greeter
Expand All @@ -21,6 +22,8 @@
from pydantic_greeter import pydantic_greeter
from concurrent_greeter import concurrent_greeter

logging.basicConfig(level=logging.INFO)

app = restate.app(services=[greeter,
random_greeter,
counter,
Expand Down
6 changes: 6 additions & 0 deletions examples/greeter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
# pylint: disable=W0613

from restate import Service, Context
import restate

# Use restate.getLogger to create a logger that hides logs on replay
# To configure logging, just use the usual std logging configuration (see example.py for an example)
logger = restate.getLogger()

greeter = Service("greeter")

@greeter.handler()
async def greet(ctx: Context, name: str) -> str:
logger.info("Received greeting request: %s", name)
return f"Hello {name}!"
6 changes: 5 additions & 1 deletion python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from .endpoint import app

from .logging import getLogger, RestateLoggingFilter

try:
from .harness import test_harness # type: ignore
except ImportError:
Expand Down Expand Up @@ -57,5 +59,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
"gather",
"as_completed",
"wait_completed",
"select"
"select",
"logging",
"RestateLoggingFilter"
]
46 changes: 46 additions & 0 deletions python/restate/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
This module contains the logging utilities for restate handlers.
"""
import logging

from .server_context import restate_context_is_replaying

# pylint: disable=C0103
def getLogger(name=None):
"""
Wrapper for logging.getLogger returning a logger configured with RestateLoggingFilter

:param name: the logger name, if any
:return: the logger as returned by logging.getLogger, configured with the RestateLoggingFilter
"""
logger = logging.getLogger(name)
logger.addFilter(RestateLoggingFilter())
return logger

# pylint: disable=R0903
class RestateLoggingFilter(logging.Filter):
"""
Restate logging filter. This filter will filter out logs on replay
"""

def filter(self, record):
# First, apply the filter base logic
if not super().filter(record):
return False

# Read the context variable, check if we're replaying
if restate_context_is_replaying.get():
return False

# We're not replaying, all good pass the event
return True
38 changes: 35 additions & 3 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,15 @@ def __init__(self, server_context: "ServerInvocationContext", name, serde) -> No

def value(self) -> RestateDurableFuture[Any]:
handle = self.server_context.vm.sys_get_promise(self.name)
update_restate_context_is_replaying(self.server_context.vm)
return self.server_context.create_future(handle, self.serde)

def resolve(self, value: Any) -> Awaitable[None]:
vm: VMWrapper = self.server_context.vm
assert self.serde is not None
value_buffer = self.serde.serialize(value)
handle = vm.sys_complete_promise_success(self.name, value_buffer)
update_restate_context_is_replaying(self.server_context.vm)

async def await_point():
if not self.server_context.vm.is_completed(handle):
Expand All @@ -206,6 +208,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
vm: VMWrapper = self.server_context.vm
py_failure = Failure(code=code, message=message)
handle = vm.sys_complete_promise_failure(self.name, py_failure)
update_restate_context_is_replaying(self.server_context.vm)

async def await_point():
if not self.server_context.vm.is_completed(handle):
Expand All @@ -217,6 +220,7 @@ async def await_point():
def peek(self) -> Awaitable[Any | None]:
vm: VMWrapper = self.server_context.vm
handle = vm.sys_peek_promise(self.name)
update_restate_context_is_replaying(self.server_context.vm)
serde = self.serde
assert serde is not None
return self.server_context.create_future(handle, serde)
Expand Down Expand Up @@ -263,6 +267,12 @@ def cancel(self):
for task in to_cancel:
task.cancel()

restate_context_is_replaying = contextvars.ContextVar('restate_context_is_replaying', default=False)

def update_restate_context_is_replaying(vm: VMWrapper):
"""Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*"""
restate_context_is_replaying.set(vm.is_replaying())

# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""
Expand All @@ -289,13 +299,16 @@ def __init__(self,

async def enter(self):
"""Invoke the user code."""
update_restate_context_is_replaying(self.vm)
try:
in_buffer = self.invocation.input_buffer
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
restate_context_is_replaying.set(False)
self.vm.sys_write_output_success(bytes(out_buffer))
self.vm.sys_end()
except TerminalError as t:
failure = Failure(code=t.status_code, message=t.message)
restate_context_is_replaying.set(False)
self.vm.sys_write_output_failure(failure)
self.vm.sys_end()
# pylint: disable=W0718
Expand Down Expand Up @@ -341,6 +354,7 @@ async def leave(self):

def on_attempt_finished(self):
"""Notify the attempt finished event."""
restate_context_is_replaying.set(False)
self.request_finished_event.set()
try:
self.tasks.cancel()
Expand Down Expand Up @@ -446,25 +460,31 @@ def get(self, name: str,
type_hint: Optional[typing.Type[T]] = None
) -> Awaitable[Optional[T]]:
handle = self.vm.sys_get_state(name)
update_restate_context_is_replaying(self.vm)
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
return self.create_future(handle, serde) # type: ignore

def state_keys(self) -> Awaitable[List[str]]:
return self.create_future(self.vm.sys_get_state_keys())
handle = self.vm.sys_get_state_keys()
update_restate_context_is_replaying(self.vm)
return self.create_future(handle)

def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None:
"""Set the value associated with the given name."""
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type(value))
buffer = serde.serialize(value)
self.vm.sys_set_state(name, bytes(buffer))
update_restate_context_is_replaying(self.vm)

def clear(self, name: str) -> None:
self.vm.sys_clear_state(name)
update_restate_context_is_replaying(self.vm)

def clear_all(self) -> None:
self.vm.sys_clear_all_state()
update_restate_context_is_replaying(self.vm)

def request(self) -> Request:
return Request(
Expand Down Expand Up @@ -542,6 +562,7 @@ def run(self,
serde = serde.with_maybe_type(type_hint)

handle = self.vm.sys_run(name)
update_restate_context_is_replaying(self.vm)

if args is not None:
noargs_action = functools.partial(action, *args)
Expand All @@ -566,6 +587,7 @@ def run_typed(
options.type_hint = signature.return_annotation
options.serde = options.serde.with_maybe_type(options.type_hint)
handle = self.vm.sys_run(name)
update_restate_context_is_replaying(self.vm)

func = functools.partial(action, *args, **kwargs)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration)
Expand All @@ -574,7 +596,9 @@ def run_typed(
def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
# convert timedelta to milliseconds
millis = int(delta.total_seconds() * 1000)
return self.create_sleep_future(self.vm.sys_sleep(millis)) # type: ignore
handle = self.vm.sys_sleep(millis)
update_restate_context_is_replaying(self.vm)
return self.create_sleep_future(handle) # type: ignore

def do_call(self,
tpe: HandlerType[I, O],
Expand Down Expand Up @@ -615,9 +639,11 @@ def do_raw_call(self,
if send_delay:
ms = int(send_delay.total_seconds() * 1000)
send_handle = self.vm.sys_send(service, handler, parameter, key, delay=ms, idempotency_key=idempotency_key, headers=headers_kvs)
update_restate_context_is_replaying(self.vm)
return ServerSendHandle(self, send_handle)
if send:
send_handle = self.vm.sys_send(service, handler, parameter, key, idempotency_key=idempotency_key, headers=headers_kvs)
update_restate_context_is_replaying(self.vm)
return ServerSendHandle(self, send_handle)

handle = self.vm.sys_call(service=service,
Expand All @@ -626,6 +652,7 @@ def do_raw_call(self,
key=key,
idempotency_key=idempotency_key,
headers=headers_kvs)
update_restate_context_is_replaying(self.vm)

return self.create_call_future(handle=handle.result_handle,
invocation_id_handle=handle.invocation_id_handle,
Expand Down Expand Up @@ -712,6 +739,7 @@ def awakeable(self,
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
name, handle = self.vm.sys_awakeable()
update_restate_context_is_replaying(self.vm)
return name, self.create_future(handle, serde)

def resolve_awakeable(self,
Expand All @@ -722,9 +750,11 @@ def resolve_awakeable(self,
serde = serde.with_maybe_type(type(value))
buf = serde.serialize(value)
self.vm.sys_resolve_awakeable(name, buf)
update_restate_context_is_replaying(self.vm)

def reject_awakeable(self, name: str, failure_message: str, failure_code: int = 500) -> None:
return self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message))
self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message))
update_restate_context_is_replaying(self.vm)

def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
"""Create a durable promise."""
Expand All @@ -740,6 +770,7 @@ def cancel_invocation(self, invocation_id: str):
if invocation_id is None:
raise ValueError("invocation_id cannot be None")
self.vm.sys_cancel(invocation_id)
update_restate_context_is_replaying(self.vm)

def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
Expand All @@ -749,4 +780,5 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde()
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
handle = self.vm.attach_invocation(invocation_id)
update_restate_context_is_replaying(self.vm)
return self.create_future(handle, serde)
4 changes: 4 additions & 0 deletions python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,7 @@ def attach_invocation(self, invocation_id: str) -> int:
Attach to an invocation
"""
return self.vm.attach_invocation(invocation_id)

def is_replaying(self) -> bool:
"""Returns true if the state machine is replaying."""
return self.vm.is_replaying()
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,10 @@ impl PyVM {
fn sys_end(mut self_: PyRefMut<'_, Self>) -> Result<(), PyVMError> {
self_.vm.sys_end().map_err(Into::into)
}

fn is_replaying(self_: PyRef<'_, Self>) -> bool {
self_.vm.is_replaying()
}
}

#[pyclass]
Expand Down