Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 120 additions & 27 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@
from dataclasses import dataclass
from datetime import timedelta
from inspect import Signature
from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar
from typing import (
Any,
AsyncContextManager,
Callable,
Awaitable,
Dict,
Generic,
List,
Literal,
Optional,
TypeVar,
)

from restate.retry_policy import InvocationRetryPolicy

from restate.context import HandlerType
from restate.exceptions import TerminalError
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, Msgspec
from restate.types import extract_core_type

I = TypeVar("I")
O = TypeVar("O")
Expand Down Expand Up @@ -78,48 +90,129 @@ class HandlerIO(Generic[I, O]):
output_type: Optional[TypeHint[O]] = None


def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
def _json_schema_wrap_as_optional(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Augment handler_io with additional information about the input and output types.
modify the given JSON schema with its type wrapped as optional (nullable).
"""
t = schema.get("type")

if t is None:
# If type is unspecified, leave it open by only adding "null"
schema["type"] = ["null"]
return schema

if isinstance(t, list):
if "null" not in t:
t.append("null")
else:
if t != "null":
schema["type"] = [t, "null"]

return schema


def _make_json_schema_generator(
original: Callable[[], Dict[str, Any]], type: Literal["optional", "simple"]
) -> Callable[[], Dict[str, Any]]:
"""
Create a JSON schema generator that handles optional types.

If the type is optional, the generated schema will include "null" in the type.
"""
if type == "simple":
return original

def generator() -> Dict[str, Any]:
schema = original()
if type == "optional":
return _json_schema_wrap_as_optional(schema)

assert False, "unreachable"

return generator


def update_handler_io_with_input_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
"""
Augment handler_io with additional information about the input type.

This function has a special check for msgspec Structs and Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the input and the return types of a function, and will:
the input type of a function, and will:
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with the appropriate serde
"""
params = list(signature.parameters.values())
if len(params) == 1:
# if there is only one parameter, it is the context.
handler_io.input_type = TypeHint(is_void=True)
else:
annotation = params[-1].annotation
handler_io.input_type = TypeHint(annotation=annotation)
if Msgspec.is_struct(annotation):
handler_io.input_type.generate_json_schema = lambda: Msgspec.json_schema(annotation)
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = MsgspecJsonSerde(annotation)
elif is_pydantic(annotation):
handler_io.input_type.generate_json_schema = lambda: annotation.model_json_schema(mode="serialization")
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = PydanticJsonSerde(annotation)
return

annotation = params[-1].annotation
core_kind, core_type = extract_core_type(annotation)
handler_io.input_type = TypeHint(annotation=core_type)
if Msgspec.is_struct(core_type):
handler_io.input_type.generate_json_schema = _make_json_schema_generator(
lambda: Msgspec.json_schema(core_type), core_kind
)
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = MsgspecJsonSerde(core_type)
return

if is_pydantic(core_type):
handler_io.input_type.generate_json_schema = _make_json_schema_generator(
lambda: core_type.model_json_schema(mode="serialization"), core_kind
)
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = PydanticJsonSerde(core_type)


def update_handler_io_with_return_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
"""
Augment handler_io with additional information about the output type.

This function has a special check for msgspec Structs and Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the return type of a function, and will:
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with the appropriate serde
"""
return_annotation = signature.return_annotation
if return_annotation is None or return_annotation is Signature.empty:
# if there is no return annotation, we assume it is void
handler_io.output_type = TypeHint(is_void=True)
else:
handler_io.output_type = TypeHint(annotation=return_annotation)
if Msgspec.is_struct(return_annotation):
handler_io.output_type.generate_json_schema = lambda: Msgspec.json_schema(return_annotation)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = MsgspecJsonSerde(return_annotation)
elif is_pydantic(return_annotation):
handler_io.output_type.generate_json_schema = lambda: return_annotation.model_json_schema(
mode="serialization"
)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = PydanticJsonSerde(return_annotation)
return

core_kind, return_core_type = extract_core_type(return_annotation)
handler_io.output_type = TypeHint(annotation=return_core_type)
if Msgspec.is_struct(return_core_type):
handler_io.output_type.generate_json_schema = _make_json_schema_generator(
lambda: Msgspec.json_schema(return_core_type), core_kind
)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = MsgspecJsonSerde(return_core_type)
return

if is_pydantic(return_core_type):
handler_io.output_type.generate_json_schema = _make_json_schema_generator(
lambda: return_core_type.model_json_schema(mode="serialization"), core_kind
)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = PydanticJsonSerde(return_core_type)


def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
"""
Augment handler_io with additional information about the input and output types.

This function has a special check for msgspec Structs and Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the input and the return types of a function, and will:
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with the appropriate serde
"""
update_handler_io_with_input_type_hints(handler_io, signature)
update_handler_io_with_return_type_hints(handler_io, signature)


# pylint: disable=R0902
Expand Down
5 changes: 5 additions & 0 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, Serde
from restate.server_types import ReceiveChannel, Send
from restate.types import extract_core_type
from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long
from restate.vm import (
DoProgressAnyCompleted,
Expand Down Expand Up @@ -697,6 +698,10 @@ def run_typed(
if options.type_hint is None:
signature = inspect.signature(action, eval_str=True)
options.type_hint = signature.return_annotation
core_type_kind, core_type = extract_core_type(options.type_hint)
if core_type_kind == "simple" or core_type_kind == "optional":
# use core type as it is more specific. E.g. Optional[T] -> T
options.type_hint = core_type
options.serde = typing.cast(DefaultSerde, options.serde).with_maybe_type(options.type_hint)
handle = self.vm.sys_run(name)
update_restate_context_is_replaying(self.vm)
Expand Down
21 changes: 21 additions & 0 deletions python/restate/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""

from dataclasses import dataclass
from types import UnionType
from typing import Any, Tuple, Literal, Union, get_args, get_origin

from restate.client_types import RestateClient

Expand All @@ -32,3 +34,22 @@ class HarnessEnvironment:

client: RestateClient
"""The Restate client connected to the ingress URL"""


def extract_core_type(annotation: Any) -> Tuple[Literal["optional", "simple"], Any]:
"""
Extract the core type from a type annotation.

Currently only supports Optional[T] types.
"""
if annotation is None:
return "simple", annotation

origin = get_origin(annotation)
args = get_args(annotation)

if (origin is UnionType or Union) and len(args) == 2 and type(None) in args:
non_none_type = args[0] if args[1] is type(None) else args[1]
return "optional", non_none_type

return "simple", annotation
27 changes: 27 additions & 0 deletions tests/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,30 @@
def test_bytes_serde():
s = BytesSerde()
assert bytes(range(20)) == s.serialize(bytes(range(20)))


def extract_core_type_optional():
from restate.types import extract_core_type

from typing import Optional, Union

kind, tpe = extract_core_type(Optional[int])
assert kind == "optional"
assert tpe is int

kind, tpe = extract_core_type(Union[int, None])
assert kind == "optional"
assert tpe is int

kind, tpe = extract_core_type(str | None)
assert kind == "optional"
assert tpe is str

kind, tpe = extract_core_type(None | str)
assert kind == "optional"
assert tpe is str

for t in [int, str, bytes, dict, list, None]:
kind, tpe = extract_core_type(t)
assert kind == "simple"
assert tpe is t
31 changes: 28 additions & 3 deletions tests/servercontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,38 @@ async def test_promise_default_serde():
async def run(ctx: WorkflowContext) -> str:
promise = ctx.promise("test.promise", type_hint=str)

assert isinstance(promise.serde, DefaultSerde), \
f"Expected DefaultSerde but got {type(promise.serde).__name__}"
assert isinstance(promise.serde, DefaultSerde), f"Expected DefaultSerde but got {type(promise.serde).__name__}"

await promise.resolve("success")
return await promise.value()


async with simple_harness(workflow) as client:
result = await client.workflow_call(run, key="test-key", arg=None)
assert result == "success"


async def test_handler_with_union_none():
greeter = Service("greeter")

@greeter.handler()
async def greet(ctx: Context, name: str) -> str | None:
return "hi"

async with simple_harness(greeter) as client:
res = await client.service_call(greet, arg="bob")
assert res == "hi"


async def test_handler_with_ctx_none():
greeter = Service("greeter")

async def maybe_something() -> str | None:
return "hi"

@greeter.handler()
async def greet(ctx: Context, name: str) -> str | None:
return await ctx.run_typed("foo", maybe_something)

async with simple_harness(greeter) as client:
res = await client.service_call(greet, arg="bob")
assert res == "hi"