Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Source = "https://github.com/restatedev/sdk-python"
"Bug Tracker" = "https://github.com/restatedev/sdk-python/issues"

[project.optional-dependencies]
test = ["pytest", "hypercorn"]
test = ["pytest", "hypercorn", "anyio"]
lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"]
harness = ["testcontainers", "hypercorn", "httpx"]
serde = ["dacite", "pydantic"]
Expand Down Expand Up @@ -53,3 +53,4 @@ ignore = ["E741"]
filterwarnings = [
"ignore:The @wait_container_is_ready decorator is deprecated:DeprecationWarning",
]
anyio_mode = "auto"
9 changes: 5 additions & 4 deletions python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import typing

from restate.server_types import RestateAppT
from restate.types import TestHarnessEnvironment
from restate.types import HarnessEnvironment

from .service import Service
from .object import VirtualObject
Expand Down Expand Up @@ -55,18 +55,18 @@
def create_test_harness(
app: RestateAppT,
follow_logs: bool = False,
restate_image: str = "restatedev/restate:latest",
restate_image: str = "docker.io/restatedev/restate:latest",
always_replay: bool = False,
disable_retries: bool = False,
) -> typing.AsyncGenerator[TestHarnessEnvironment, None]:
) -> typing.AsyncGenerator[HarnessEnvironment, None]:
"""a dummy harness constructor to raise ImportError. Install restate-sdk[harness] to use this feature"""
raise ImportError("Install restate-sdk[harness] to use this feature")

@typing.no_type_check
def test_harness(
app: RestateAppT,
follow_logs: bool = False,
restate_image: str = "restatedev/restate:latest",
restate_image: str = "docker.io/restatedev/restate:latest",
always_replay: bool = False,
disable_retries: bool = False,
):
Expand Down Expand Up @@ -107,6 +107,7 @@ async def create_client(
"app",
"create_test_harness",
"test_harness",
"HarnessEnvironment",
"gather",
"as_completed",
"wait_completed",
Expand Down
10 changes: 5 additions & 5 deletions python/restate/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from hypercorn.asyncio import serve
from restate.client import create_client
from restate.server_types import RestateAppT
from restate.types import TestHarnessEnvironment
from restate.types import HarnessEnvironment
from testcontainers.core.container import DockerContainer # type: ignore
from testcontainers.core.wait_strategies import CompositeWaitStrategy, HttpWaitStrategy

Expand Down Expand Up @@ -314,7 +314,7 @@ def create_restate_container(
def test_harness(
app: RestateAppT,
follow_logs: bool = False,
restate_image: str = "restatedev/restate:latest",
restate_image: str = "docker.io/restatedev/restate:latest",
always_replay: bool = False,
disable_retries: bool = False,
) -> RestateTestHarness:
Expand All @@ -334,10 +334,10 @@ def test_harness(
async def create_test_harness(
app: RestateAppT,
follow_logs: bool = False,
restate_image: str = "restatedev/restate:latest",
restate_image: str = "docker.io/restatedev/restate:latest",
always_replay: bool = False,
disable_retries: bool = False,
) -> typing.AsyncGenerator[TestHarnessEnvironment, None]:
) -> typing.AsyncGenerator[HarnessEnvironment, None]:
"""
Creates a test harness for running Restate services together with restate-server.

Expand Down Expand Up @@ -377,6 +377,6 @@ async def create_test_harness(
raise AssertionError(msg)

async with create_client(runtime.ingress_url()) as client:
yield TestHarnessEnvironment(
yield HarnessEnvironment(
ingress_url=runtime.ingress_url(), admin_api_url=runtime.admin_url(), client=client
)
2 changes: 1 addition & 1 deletion python/restate/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


@dataclass
class TestHarnessEnvironment:
class HarnessEnvironment:
"""Information about the test environment"""

ingress_url: str
Expand Down
142 changes: 142 additions & 0 deletions tests/harness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

import uuid
import restate
from restate import (
Context,
Service,
HarnessEnvironment,
VirtualObject,
ObjectContext,
ObjectSharedContext,
Workflow,
WorkflowContext,
getLogger,
WorkflowSharedContext,
)
import pytest
import asyncio

# ----- Asyncio fixtures


@pytest.fixture(scope="session")
def anyio_backend():
return "asyncio"


pytestmark = [
pytest.mark.anyio,
]

# -------- Restate services and restate fixture

greeter = Service("greeter")


@greeter.handler()
async def greet(ctx: Context, name: str) -> str:
return f"Hello {name}!"


counter = VirtualObject("counter")


@counter.handler()
async def increment(ctx: ObjectContext, value: int) -> int:
n = await ctx.get("counter", type_hint=int) or 0
n += value
ctx.set("counter", n)
return n


@counter.handler(kind="shared")
async def count(ctx: ObjectSharedContext) -> int:
return await ctx.get("counter") or 0


payment = Workflow("payment")
payment_logger = getLogger("payment")


@payment.main()
async def pay(ctx: WorkflowContext):
ctx.set("status", "verifying payment")

def payment_gateway():
payment_logger.info("Doing payment work")

await ctx.run_typed("payment", payment_gateway)

ctx.set("status", "waiting for the payment provider to approve")

# Wait for the payment to be verified
result = await ctx.promise("verify.payment", type_hint=str).value()
return f"Verified {result}!"


@payment.handler()
async def payment_verified(ctx: WorkflowSharedContext, result: str):
promise = ctx.promise("verify.payment", type_hint=str)
await promise.resolve(result)


@pytest.fixture(scope="session")
async def restate_test_harness():
async with restate.create_test_harness(
restate.app([greeter, counter, payment]), restate_image="ghcr.io/restatedev/restate:latest"
) as harness:
yield harness


# ----- Tests


async def test_greeter(restate_test_harness: HarnessEnvironment):
greeting = await restate_test_harness.client.service_call(greet, arg="Pippo")

assert greeting == "Hello Pippo!"


async def test_counter(restate_test_harness: HarnessEnvironment):
random_key = str(uuid.uuid4())
initial_count = await restate_test_harness.client.object_call(count, key=random_key, arg=None)
await restate_test_harness.client.object_call(increment, key=random_key, arg=1)
new_count = await restate_test_harness.client.object_call(count, key=random_key, arg=None)

assert new_count == initial_count + 1


async def test_idempotency_key(restate_test_harness: HarnessEnvironment):
random_key = str(uuid.uuid4())
initial_count = await restate_test_harness.client.object_call(count, key=random_key, arg=None)
await restate_test_harness.client.object_call(increment, key=random_key, arg=1, idempotency_key=random_key)
await restate_test_harness.client.object_call(increment, key=random_key, arg=1, idempotency_key=random_key)
new_count = await restate_test_harness.client.object_call(count, key=random_key, arg=None)

assert new_count == initial_count + 1


async def test_workflow(restate_test_harness: HarnessEnvironment):
random_key = str(uuid.uuid4())
call_task = asyncio.create_task(restate_test_harness.client.workflow_call(pay, key=random_key, arg=None))

await restate_test_harness.client.workflow_call(payment_verified, key=random_key, arg="Done")

assert await call_task == "Verified Done!"


async def test_send(restate_test_harness: HarnessEnvironment):
invocation_handle = await restate_test_harness.client.service_send(greet, arg="Pippo")

assert invocation_handle.status_code == 200
assert len(invocation_handle.invocation_id) > 0
Loading