diff --git a/examples/example.py b/examples/example.py index ad29396..e19c769 100644 --- a/examples/example.py +++ b/examples/example.py @@ -12,6 +12,7 @@ # pylint: disable=C0116 # pylint: disable=W0613 +import logging import restate from greeter import greeter @@ -21,6 +22,8 @@ from pydantic_greeter import pydantic_greeter from concurrent_greeter import concurrent_greeter +logging.basicConfig(level=logging.INFO) + app = restate.app(services=[greeter, random_greeter, counter, diff --git a/examples/greeter.py b/examples/greeter.py index 502d0b9..2d1af12 100644 --- a/examples/greeter.py +++ b/examples/greeter.py @@ -13,9 +13,15 @@ # pylint: disable=W0613 from restate import Service, Context +import restate + +# Use restate.getLogger to create a logger that hides logs on replay +# To configure logging, just use the usual std logging configuration (see example.py for an example) +logger = restate.getLogger() greeter = Service("greeter") @greeter.handler() async def greet(ctx: Context, name: str) -> str: + logger.info("Received greeting request: %s", name) return f"Hello {name}!" diff --git a/python/restate/__init__.py b/python/restate/__init__.py index 252f8ee..321d703 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -26,6 +26,8 @@ from .endpoint import app +from .logging import getLogger, RestateLoggingFilter + try: from .harness import test_harness # type: ignore except ImportError: @@ -57,5 +59,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore "gather", "as_completed", "wait_completed", - "select" + "select", + "logging", + "RestateLoggingFilter" ] diff --git a/python/restate/logging.py b/python/restate/logging.py new file mode 100644 index 0000000..6038594 --- /dev/null +++ b/python/restate/logging.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the logging utilities for restate handlers. +""" +import logging + +from .server_context import restate_context_is_replaying + +# pylint: disable=C0103 +def getLogger(name=None): + """ + Wrapper for logging.getLogger returning a logger configured with RestateLoggingFilter + + :param name: the logger name, if any + :return: the logger as returned by logging.getLogger, configured with the RestateLoggingFilter + """ + logger = logging.getLogger(name) + logger.addFilter(RestateLoggingFilter()) + return logger + +# pylint: disable=R0903 +class RestateLoggingFilter(logging.Filter): + """ + Restate logging filter. This filter will filter out logs on replay + """ + + def filter(self, record): + # First, apply the filter base logic + if not super().filter(record): + return False + + # Read the context variable, check if we're replaying + if restate_context_is_replaying.get(): + return False + + # We're not replaying, all good pass the event + return True diff --git a/python/restate/server_context.py b/python/restate/server_context.py index ae2582f..ff53ba9 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -187,6 +187,7 @@ def __init__(self, server_context: "ServerInvocationContext", name, serde) -> No def value(self) -> RestateDurableFuture[Any]: handle = self.server_context.vm.sys_get_promise(self.name) + update_restate_context_is_replaying(self.server_context.vm) return self.server_context.create_future(handle, self.serde) def resolve(self, value: Any) -> Awaitable[None]: @@ -194,6 +195,7 @@ def resolve(self, value: Any) -> Awaitable[None]: assert self.serde is not None value_buffer = self.serde.serialize(value) handle = vm.sys_complete_promise_success(self.name, value_buffer) + update_restate_context_is_replaying(self.server_context.vm) async def await_point(): if not self.server_context.vm.is_completed(handle): @@ -206,6 +208,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]: vm: VMWrapper = self.server_context.vm py_failure = Failure(code=code, message=message) handle = vm.sys_complete_promise_failure(self.name, py_failure) + update_restate_context_is_replaying(self.server_context.vm) async def await_point(): if not self.server_context.vm.is_completed(handle): @@ -217,6 +220,7 @@ async def await_point(): def peek(self) -> Awaitable[Any | None]: vm: VMWrapper = self.server_context.vm handle = vm.sys_peek_promise(self.name) + update_restate_context_is_replaying(self.server_context.vm) serde = self.serde assert serde is not None return self.server_context.create_future(handle, serde) @@ -263,6 +267,12 @@ def cancel(self): for task in to_cancel: task.cancel() +restate_context_is_replaying = contextvars.ContextVar('restate_context_is_replaying', default=False) + +def update_restate_context_is_replaying(vm: VMWrapper): + """Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*""" + restate_context_is_replaying.set(vm.is_replaying()) + # pylint: disable=R0902 class ServerInvocationContext(ObjectContext): """This class implements the context for the restate framework based on the server.""" @@ -289,13 +299,16 @@ def __init__(self, async def enter(self): """Invoke the user code.""" + update_restate_context_is_replaying(self.vm) try: in_buffer = self.invocation.input_buffer out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer) + restate_context_is_replaying.set(False) self.vm.sys_write_output_success(bytes(out_buffer)) self.vm.sys_end() except TerminalError as t: failure = Failure(code=t.status_code, message=t.message) + restate_context_is_replaying.set(False) self.vm.sys_write_output_failure(failure) self.vm.sys_end() # pylint: disable=W0718 @@ -341,6 +354,7 @@ async def leave(self): def on_attempt_finished(self): """Notify the attempt finished event.""" + restate_context_is_replaying.set(False) self.request_finished_event.set() try: self.tasks.cancel() @@ -446,12 +460,15 @@ def get(self, name: str, type_hint: Optional[typing.Type[T]] = None ) -> Awaitable[Optional[T]]: handle = self.vm.sys_get_state(name) + update_restate_context_is_replaying(self.vm) if isinstance(serde, DefaultSerde): serde = serde.with_maybe_type(type_hint) return self.create_future(handle, serde) # type: ignore def state_keys(self) -> Awaitable[List[str]]: - return self.create_future(self.vm.sys_get_state_keys()) + handle = self.vm.sys_get_state_keys() + update_restate_context_is_replaying(self.vm) + return self.create_future(handle) def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: """Set the value associated with the given name.""" @@ -459,12 +476,15 @@ def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: serde = serde.with_maybe_type(type(value)) buffer = serde.serialize(value) self.vm.sys_set_state(name, bytes(buffer)) + update_restate_context_is_replaying(self.vm) def clear(self, name: str) -> None: self.vm.sys_clear_state(name) + update_restate_context_is_replaying(self.vm) def clear_all(self) -> None: self.vm.sys_clear_all_state() + update_restate_context_is_replaying(self.vm) def request(self) -> Request: return Request( @@ -542,6 +562,7 @@ def run(self, serde = serde.with_maybe_type(type_hint) handle = self.vm.sys_run(name) + update_restate_context_is_replaying(self.vm) if args is not None: noargs_action = functools.partial(action, *args) @@ -566,6 +587,7 @@ def run_typed( options.type_hint = signature.return_annotation options.serde = options.serde.with_maybe_type(options.type_hint) handle = self.vm.sys_run(name) + update_restate_context_is_replaying(self.vm) func = functools.partial(action, *args, **kwargs) self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration) @@ -574,7 +596,9 @@ def run_typed( def sleep(self, delta: timedelta) -> RestateDurableSleepFuture: # convert timedelta to milliseconds millis = int(delta.total_seconds() * 1000) - return self.create_sleep_future(self.vm.sys_sleep(millis)) # type: ignore + handle = self.vm.sys_sleep(millis) + update_restate_context_is_replaying(self.vm) + return self.create_sleep_future(handle) # type: ignore def do_call(self, tpe: HandlerType[I, O], @@ -615,9 +639,11 @@ def do_raw_call(self, if send_delay: ms = int(send_delay.total_seconds() * 1000) send_handle = self.vm.sys_send(service, handler, parameter, key, delay=ms, idempotency_key=idempotency_key, headers=headers_kvs) + update_restate_context_is_replaying(self.vm) return ServerSendHandle(self, send_handle) if send: send_handle = self.vm.sys_send(service, handler, parameter, key, idempotency_key=idempotency_key, headers=headers_kvs) + update_restate_context_is_replaying(self.vm) return ServerSendHandle(self, send_handle) handle = self.vm.sys_call(service=service, @@ -626,6 +652,7 @@ def do_raw_call(self, key=key, idempotency_key=idempotency_key, headers=headers_kvs) + update_restate_context_is_replaying(self.vm) return self.create_call_future(handle=handle.result_handle, invocation_id_handle=handle.invocation_id_handle, @@ -712,6 +739,7 @@ def awakeable(self, if isinstance(serde, DefaultSerde): serde = serde.with_maybe_type(type_hint) name, handle = self.vm.sys_awakeable() + update_restate_context_is_replaying(self.vm) return name, self.create_future(handle, serde) def resolve_awakeable(self, @@ -722,9 +750,11 @@ def resolve_awakeable(self, serde = serde.with_maybe_type(type(value)) buf = serde.serialize(value) self.vm.sys_resolve_awakeable(name, buf) + update_restate_context_is_replaying(self.vm) def reject_awakeable(self, name: str, failure_message: str, failure_code: int = 500) -> None: - return self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message)) + self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message)) + update_restate_context_is_replaying(self.vm) def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]: """Create a durable promise.""" @@ -740,6 +770,7 @@ def cancel_invocation(self, invocation_id: str): if invocation_id is None: raise ValueError("invocation_id cannot be None") self.vm.sys_cancel(invocation_id) + update_restate_context_is_replaying(self.vm) def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None @@ -749,4 +780,5 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde() if isinstance(serde, DefaultSerde): serde = serde.with_maybe_type(type_hint) handle = self.vm.attach_invocation(invocation_id) + update_restate_context_is_replaying(self.vm) return self.create_future(handle, serde) diff --git a/python/restate/vm.py b/python/restate/vm.py index 1ce98ab..b8a5fd3 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -423,3 +423,7 @@ def attach_invocation(self, invocation_id: str) -> int: Attach to an invocation """ return self.vm.attach_invocation(invocation_id) + + def is_replaying(self) -> bool: + """Returns true if the state machine is replaying.""" + return self.vm.is_replaying() diff --git a/src/lib.rs b/src/lib.rs index 4c58141..ba7f555 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -703,6 +703,10 @@ impl PyVM { fn sys_end(mut self_: PyRefMut<'_, Self>) -> Result<(), PyVMError> { self_.vm.sys_end().map_err(Into::into) } + + fn is_replaying(self_: PyRef<'_, Self>) -> bool { + self_.vm.is_replaying() + } } #[pyclass]