Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

from .logging import getLogger, RestateLoggingFilter


try:
from .harness import create_test_harness, test_harness # type: ignore
except ImportError:
Expand Down
15 changes: 15 additions & 0 deletions python/restate/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""This module contains internal extensions apis"""

from .server_context import current_context

__all__ = ["current_context"]
12 changes: 12 additions & 0 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import time

from restate.context import (
Context,
DurablePromise,
AttemptFinishedEvent,
HandlerType,
Expand Down Expand Up @@ -302,6 +303,14 @@ def update_restate_context_is_replaying(vm: VMWrapper):
restate_context_is_replaying.set(vm.is_replaying())


_restate_context_var = contextvars.ContextVar[Context]("restate_context")


def current_context() -> Context | None:
"""Get the current context."""
return _restate_context_var.get()


# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""
Expand Down Expand Up @@ -330,6 +339,7 @@ def __init__(
async def enter(self):
"""Invoke the user code."""
update_restate_context_is_replaying(self.vm)
token = _restate_context_var.set(self)
try:
in_buffer = self.invocation.input_buffer
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
Expand All @@ -356,6 +366,8 @@ async def enter(self):
stacktrace = "\n".join(traceback.format_exception(e))
self.vm.notify_error(repr(e), stacktrace)
raise e
finally:
_restate_context_var.reset(token)

async def leave(self):
"""Leave the context."""
Expand Down
64 changes: 64 additions & 0 deletions tests/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

import restate
from restate import (
Context,
Service,
HarnessEnvironment,
)
import pytest

# ----- Asyncio fixtures


@pytest.fixture(scope="session")
def anyio_backend():
return "asyncio"


pytestmark = [
pytest.mark.anyio,
]

# -------- Restate services and restate fixture

greeter = Service("greeter")


def magic_function():
from restate.extensions import current_context

ctx = current_context()
assert ctx is not None
return ctx.request().id


@greeter.handler()
async def greet(ctx: Context, name: str) -> str:
id = magic_function()
return f"Hello {id}!"


@pytest.fixture(scope="session")
async def restate_test_harness():
async with restate.create_test_harness(
restate.app([greeter]), restate_image="ghcr.io/restatedev/restate:latest"
) as harness:
yield harness


# ----- Tests


async def test_greeter(restate_test_harness: HarnessEnvironment):
greeting = await restate_test_harness.client.service_call(greet, arg="bob")
assert greeting.startswith("Hello ")