diff --git a/examples/msgspec_greeter.py b/examples/msgspec_greeter.py new file mode 100644 index 0000000..40e1077 --- /dev/null +++ b/examples/msgspec_greeter.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""msgspec_greeter.py - Example using msgspec.Struct with Restate""" +# pylint: disable=C0116 +# pylint: disable=W0613 +# pylint: disable=C0115 +# pylint: disable=R0903 + +import msgspec +from restate import Service, Context + + +# models +class GreetingRequest(msgspec.Struct): + name: str + + +class Greeting(msgspec.Struct): + message: str + + +# service + +msgspec_greeter = Service("msgspec_greeter") + + +@msgspec_greeter.handler() +async def greet(ctx: Context, req: GreetingRequest) -> Greeting: + return Greeting(message=f"Hello {req.name}!") + diff --git a/pyproject.toml b/pyproject.toml index 187280c..aeac742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ Source = "https://github.com/restatedev/sdk-python" test = ["pytest", "hypercorn", "anyio"] lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"] harness = ["testcontainers", "hypercorn", "httpx"] -serde = ["dacite", "pydantic"] +serde = ["dacite", "pydantic", "msgspec"] client = ["httpx[http2]"] [build-system] diff --git a/python/restate/discovery.py b/python/restate/discovery.py index 7fef373..6d5d750 100644 --- a/python/restate/discovery.py +++ b/python/restate/discovery.py @@ -218,6 +218,10 @@ def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any: return None if not type_hint.annotation: return None + if type_hint.is_msgspec: + import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel + + return msgspec.json.schema(type_hint.annotation) if type_hint.is_pydantic: return type_hint.annotation.model_json_schema(mode="serialization") return type_hint_to_json_schema(type_hint.annotation) diff --git a/python/restate/handler.py b/python/restate/handler.py index a78b472..af00ce0 100644 --- a/python/restate/handler.py +++ b/python/restate/handler.py @@ -24,7 +24,7 @@ from restate.context import HandlerType from restate.exceptions import TerminalError -from restate.serde import DefaultSerde, PydanticJsonSerde, Serde, is_pydantic +from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, is_msgspec I = TypeVar("I") O = TypeVar("O") @@ -54,6 +54,7 @@ class TypeHint(Generic[T]): annotation: Optional[T] = None is_pydantic: bool = False + is_msgspec: bool = False is_void: bool = False @@ -79,11 +80,11 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si """ Augment handler_io with additional information about the input and output types. - This function has a special check for Pydantic models when these are provided. + This function has a special check for msgspec Structs and Pydantic models when these are provided. This method will inspect the signature of an handler and will look for the input and the return types of a function, and will: - * capture any Pydantic models (to be used later at discovery) - * replace the default json serializer (is unchanged by a user) with a Pydantic serde + * capture any msgspec Structs or Pydantic models (to be used later at discovery) + * replace the default json serializer (is unchanged by a user) with the appropriate serde """ params = list(signature.parameters.values()) if len(params) == 1: @@ -91,8 +92,12 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si handler_io.input_type = TypeHint(is_void=True) else: annotation = params[-1].annotation - handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False) - if is_pydantic(annotation): + handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False) + if is_msgspec(annotation): + handler_io.input_type.is_msgspec = True + if isinstance(handler_io.input_serde, DefaultSerde): + handler_io.input_serde = MsgspecJsonSerde(annotation) + elif is_pydantic(annotation): handler_io.input_type.is_pydantic = True if isinstance(handler_io.input_serde, DefaultSerde): handler_io.input_serde = PydanticJsonSerde(annotation) @@ -102,8 +107,12 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si # if there is no return annotation, we assume it is void handler_io.output_type = TypeHint(is_void=True) else: - handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False) - if is_pydantic(annotation): + handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False) + if is_msgspec(annotation): + handler_io.output_type.is_msgspec = True + if isinstance(handler_io.output_serde, DefaultSerde): + handler_io.output_serde = MsgspecJsonSerde(annotation) + elif is_pydantic(annotation): handler_io.output_type.is_pydantic = True if isinstance(handler_io.output_serde, DefaultSerde): handler_io.output_serde = PydanticJsonSerde(annotation) diff --git a/python/restate/serde.py b/python/restate/serde.py index c415981..6e43736 100644 --- a/python/restate/serde.py +++ b/python/restate/serde.py @@ -74,7 +74,24 @@ def _from_dict(data_class: typing.Any, data: typing.Any) -> typing.Any: # pylin return _to_dict, _from_dict +def try_import_msgspec_struct(): + """ + Try to import Struct from msgspec. + """ + try: + from msgspec import Struct # type: ignore # pylint: disable=import-outside-toplevel + + return Struct + except ImportError: + + class Dummy: # pylint: disable=too-few-public-methods + """a dummy class to use when msgspec is not available""" + + return Dummy + + PydanticBaseModel = try_import_pydantic_base_model() +MsgspecStruct = try_import_msgspec_struct() # pylint: disable=C0103 DaciteToDict, DaciteFromDict = try_import_from_dacite() @@ -97,6 +114,17 @@ def is_pydantic(annotation) -> bool: return False +def is_msgspec(annotation) -> bool: + """ + Check if an object is a msgspec Struct. + """ + try: + return issubclass(annotation, MsgspecStruct) + except TypeError: + # annotation is not a class or a type + return False + + class Serde(typing.Generic[T], abc.ABC): """serializer/deserializer interface.""" @@ -227,6 +255,10 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]: """ if not buf: return None + if is_msgspec(self.type_hint): + import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel + + return msgspec.json.decode(buf, type=self.type_hint) if is_pydantic(self.type_hint): return self.type_hint.model_validate_json(buf) # type: ignore if is_dataclass(self.type_hint): @@ -237,7 +269,7 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]: def serialize(self, obj: typing.Optional[I]) -> bytes: """ Serializes a Python object into a byte array. - If the object is a Pydantic BaseModel, uses its model_dump_json method. + If the object is a msgspec Struct or Pydantic BaseModel, uses their respective methods. Args: obj (Optional[I]): The Python object to serialize. @@ -247,6 +279,10 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: """ if obj is None: return bytes() + if is_msgspec(self.type_hint): + import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel + + return msgspec.json.encode(obj) if is_pydantic(self.type_hint): return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined] if is_dataclass(obj): @@ -291,3 +327,44 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: return bytes() json_str = obj.model_dump_json() # type: ignore[attr-defined] return json_str.encode("utf-8") + + +class MsgspecJsonSerde(Serde[I]): + """ + Serde for msgspec Structs to/from JSON + """ + + def __init__(self, model): + self.model = model + + def deserialize(self, buf: bytes) -> typing.Optional[I]: + """ + Deserializes a bytearray to a msgspec Struct. + + Args: + buf (bytearray): The bytearray to deserialize. + + Returns: + typing.Optional[I]: The deserialized msgspec Struct. + """ + if not buf: + return None + import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel + + return msgspec.json.decode(buf, type=self.model) + + def serialize(self, obj: typing.Optional[I]) -> bytes: + """ + Serializes a msgspec Struct to a bytearray. + + Args: + obj (I): The msgspec Struct to serialize. + + Returns: + bytearray: The serialized bytearray. + """ + if obj is None: + return bytes() + import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel + + return msgspec.json.encode(obj)