From 5f107133dd743c78bde6606d0a90003e74d88d5b Mon Sep 17 00:00:00 2001 From: igalshilman Date: Mon, 8 Dec 2025 11:09:34 +0100 Subject: [PATCH] Add few tests --- python/restate/handler.py | 147 +++++++++++++++++++++++++------ python/restate/server_context.py | 5 ++ python/restate/types.py | 21 +++++ tests/serde.py | 27 ++++++ tests/servercontext.py | 31 ++++++- 5 files changed, 201 insertions(+), 30 deletions(-) diff --git a/python/restate/handler.py b/python/restate/handler.py index 5edbae9..b3371b2 100644 --- a/python/restate/handler.py +++ b/python/restate/handler.py @@ -18,13 +18,25 @@ from dataclasses import dataclass from datetime import timedelta from inspect import Signature -from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar +from typing import ( + Any, + AsyncContextManager, + Callable, + Awaitable, + Dict, + Generic, + List, + Literal, + Optional, + TypeVar, +) from restate.retry_policy import InvocationRetryPolicy from restate.context import HandlerType from restate.exceptions import TerminalError from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, Msgspec +from restate.types import extract_core_type I = TypeVar("I") O = TypeVar("O") @@ -78,13 +90,55 @@ class HandlerIO(Generic[I, O]): output_type: Optional[TypeHint[O]] = None -def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature): +def _json_schema_wrap_as_optional(schema: Dict[str, Any]) -> Dict[str, Any]: """ - Augment handler_io with additional information about the input and output types. + modify the given JSON schema with its type wrapped as optional (nullable). + """ + t = schema.get("type") + + if t is None: + # If type is unspecified, leave it open by only adding "null" + schema["type"] = ["null"] + return schema + + if isinstance(t, list): + if "null" not in t: + t.append("null") + else: + if t != "null": + schema["type"] = [t, "null"] + + return schema + + +def _make_json_schema_generator( + original: Callable[[], Dict[str, Any]], type: Literal["optional", "simple"] +) -> Callable[[], Dict[str, Any]]: + """ + Create a JSON schema generator that handles optional types. + + If the type is optional, the generated schema will include "null" in the type. + """ + if type == "simple": + return original + + def generator() -> Dict[str, Any]: + schema = original() + if type == "optional": + return _json_schema_wrap_as_optional(schema) + + assert False, "unreachable" + + return generator + + +def update_handler_io_with_input_type_hints(handler_io: HandlerIO[I, O], signature: Signature): + """ + Augment handler_io with additional information about the input type. This function has a special check for msgspec Structs and Pydantic models when these are provided. This method will inspect the signature of an handler and will look for - the input and the return types of a function, and will: + the input type of a function, and will: * capture any msgspec Structs or Pydantic models (to be used later at discovery) * replace the default json serializer (is unchanged by a user) with the appropriate serde """ @@ -92,34 +146,73 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si if len(params) == 1: # if there is only one parameter, it is the context. handler_io.input_type = TypeHint(is_void=True) - else: - annotation = params[-1].annotation - handler_io.input_type = TypeHint(annotation=annotation) - if Msgspec.is_struct(annotation): - handler_io.input_type.generate_json_schema = lambda: Msgspec.json_schema(annotation) - if isinstance(handler_io.input_serde, DefaultSerde): - handler_io.input_serde = MsgspecJsonSerde(annotation) - elif is_pydantic(annotation): - handler_io.input_type.generate_json_schema = lambda: annotation.model_json_schema(mode="serialization") - if isinstance(handler_io.input_serde, DefaultSerde): - handler_io.input_serde = PydanticJsonSerde(annotation) + return + + annotation = params[-1].annotation + core_kind, core_type = extract_core_type(annotation) + handler_io.input_type = TypeHint(annotation=core_type) + if Msgspec.is_struct(core_type): + handler_io.input_type.generate_json_schema = _make_json_schema_generator( + lambda: Msgspec.json_schema(core_type), core_kind + ) + if isinstance(handler_io.input_serde, DefaultSerde): + handler_io.input_serde = MsgspecJsonSerde(core_type) + return + + if is_pydantic(core_type): + handler_io.input_type.generate_json_schema = _make_json_schema_generator( + lambda: core_type.model_json_schema(mode="serialization"), core_kind + ) + if isinstance(handler_io.input_serde, DefaultSerde): + handler_io.input_serde = PydanticJsonSerde(core_type) + + +def update_handler_io_with_return_type_hints(handler_io: HandlerIO[I, O], signature: Signature): + """ + Augment handler_io with additional information about the output type. + This function has a special check for msgspec Structs and Pydantic models when these are provided. + This method will inspect the signature of an handler and will look for + the return type of a function, and will: + * capture any msgspec Structs or Pydantic models (to be used later at discovery) + * replace the default json serializer (is unchanged by a user) with the appropriate serde + """ return_annotation = signature.return_annotation if return_annotation is None or return_annotation is Signature.empty: # if there is no return annotation, we assume it is void handler_io.output_type = TypeHint(is_void=True) - else: - handler_io.output_type = TypeHint(annotation=return_annotation) - if Msgspec.is_struct(return_annotation): - handler_io.output_type.generate_json_schema = lambda: Msgspec.json_schema(return_annotation) - if isinstance(handler_io.output_serde, DefaultSerde): - handler_io.output_serde = MsgspecJsonSerde(return_annotation) - elif is_pydantic(return_annotation): - handler_io.output_type.generate_json_schema = lambda: return_annotation.model_json_schema( - mode="serialization" - ) - if isinstance(handler_io.output_serde, DefaultSerde): - handler_io.output_serde = PydanticJsonSerde(return_annotation) + return + + core_kind, return_core_type = extract_core_type(return_annotation) + handler_io.output_type = TypeHint(annotation=return_core_type) + if Msgspec.is_struct(return_core_type): + handler_io.output_type.generate_json_schema = _make_json_schema_generator( + lambda: Msgspec.json_schema(return_core_type), core_kind + ) + if isinstance(handler_io.output_serde, DefaultSerde): + handler_io.output_serde = MsgspecJsonSerde(return_core_type) + return + + if is_pydantic(return_core_type): + handler_io.output_type.generate_json_schema = _make_json_schema_generator( + lambda: return_core_type.model_json_schema(mode="serialization"), core_kind + ) + if isinstance(handler_io.output_serde, DefaultSerde): + handler_io.output_serde = PydanticJsonSerde(return_core_type) + + +def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature): + """ + Augment handler_io with additional information about the input and output types. + + This function has a special check for msgspec Structs and Pydantic models when these are provided. + This method will inspect the signature of an handler and will look for + the input and the return types of a function, and will: + * capture any msgspec Structs or Pydantic models (to be used later at discovery) + * replace the default json serializer (is unchanged by a user) with the appropriate serde + """ + update_handler_io_with_input_type_hints(handler_io, signature) + update_handler_io_with_return_type_hints(handler_io, signature) # pylint: disable=R0902 diff --git a/python/restate/server_context.py b/python/restate/server_context.py index a8847d7..efa4e69 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -49,6 +49,7 @@ from restate.handler import Handler, handler_from_callable, invoke_handler from restate.serde import BytesSerde, DefaultSerde, Serde from restate.server_types import ReceiveChannel, Send +from restate.types import extract_core_type from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long from restate.vm import ( DoProgressAnyCompleted, @@ -697,6 +698,10 @@ def run_typed( if options.type_hint is None: signature = inspect.signature(action, eval_str=True) options.type_hint = signature.return_annotation + core_type_kind, core_type = extract_core_type(options.type_hint) + if core_type_kind == "simple" or core_type_kind == "optional": + # use core type as it is more specific. E.g. Optional[T] -> T + options.type_hint = core_type options.serde = typing.cast(DefaultSerde, options.serde).with_maybe_type(options.type_hint) handle = self.vm.sys_run(name) update_restate_context_is_replaying(self.vm) diff --git a/python/restate/types.py b/python/restate/types.py index 285add9..8dc798a 100644 --- a/python/restate/types.py +++ b/python/restate/types.py @@ -16,6 +16,8 @@ """ from dataclasses import dataclass +from types import UnionType +from typing import Any, Tuple, Literal, Union, get_args, get_origin from restate.client_types import RestateClient @@ -32,3 +34,22 @@ class HarnessEnvironment: client: RestateClient """The Restate client connected to the ingress URL""" + + +def extract_core_type(annotation: Any) -> Tuple[Literal["optional", "simple"], Any]: + """ + Extract the core type from a type annotation. + + Currently only supports Optional[T] types. + """ + if annotation is None: + return "simple", annotation + + origin = get_origin(annotation) + args = get_args(annotation) + + if (origin is UnionType or Union) and len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + return "optional", non_none_type + + return "simple", annotation diff --git a/tests/serde.py b/tests/serde.py index 506874e..d15c5f2 100644 --- a/tests/serde.py +++ b/tests/serde.py @@ -4,3 +4,30 @@ def test_bytes_serde(): s = BytesSerde() assert bytes(range(20)) == s.serialize(bytes(range(20))) + + +def extract_core_type_optional(): + from restate.types import extract_core_type + + from typing import Optional, Union + + kind, tpe = extract_core_type(Optional[int]) + assert kind == "optional" + assert tpe is int + + kind, tpe = extract_core_type(Union[int, None]) + assert kind == "optional" + assert tpe is int + + kind, tpe = extract_core_type(str | None) + assert kind == "optional" + assert tpe is str + + kind, tpe = extract_core_type(None | str) + assert kind == "optional" + assert tpe is str + + for t in [int, str, bytes, dict, list, None]: + kind, tpe = extract_core_type(t) + assert kind == "simple" + assert tpe is t diff --git a/tests/servercontext.py b/tests/servercontext.py index 3b27949..308adf0 100644 --- a/tests/servercontext.py +++ b/tests/servercontext.py @@ -85,13 +85,38 @@ async def test_promise_default_serde(): async def run(ctx: WorkflowContext) -> str: promise = ctx.promise("test.promise", type_hint=str) - assert isinstance(promise.serde, DefaultSerde), \ - f"Expected DefaultSerde but got {type(promise.serde).__name__}" + assert isinstance(promise.serde, DefaultSerde), f"Expected DefaultSerde but got {type(promise.serde).__name__}" await promise.resolve("success") return await promise.value() - async with simple_harness(workflow) as client: result = await client.workflow_call(run, key="test-key", arg=None) assert result == "success" + + +async def test_handler_with_union_none(): + greeter = Service("greeter") + + @greeter.handler() + async def greet(ctx: Context, name: str) -> str | None: + return "hi" + + async with simple_harness(greeter) as client: + res = await client.service_call(greet, arg="bob") + assert res == "hi" + + +async def test_handler_with_ctx_none(): + greeter = Service("greeter") + + async def maybe_something() -> str | None: + return "hi" + + @greeter.handler() + async def greet(ctx: Context, name: str) -> str | None: + return await ctx.run_typed("foo", maybe_something) + + async with simple_harness(greeter) as client: + res = await client.service_call(greet, arg="bob") + assert res == "hi"