Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/msgspec_greeter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""msgspec_greeter.py - Example using msgspec.Struct with Restate"""
# pylint: disable=C0116
# pylint: disable=W0613
# pylint: disable=C0115
# pylint: disable=R0903

import msgspec
from restate import Service, Context


# models
class GreetingRequest(msgspec.Struct):
name: str


class Greeting(msgspec.Struct):
message: str


# service

msgspec_greeter = Service("msgspec_greeter")


@msgspec_greeter.handler()
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
return Greeting(message=f"Hello {req.name}!")

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Source = "https://github.com/restatedev/sdk-python"
test = ["pytest", "hypercorn", "anyio"]
lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"]
harness = ["testcontainers", "hypercorn", "httpx"]
serde = ["dacite", "pydantic"]
serde = ["dacite", "pydantic", "msgspec"]
client = ["httpx[http2]"]

[build-system]
Expand Down
4 changes: 4 additions & 0 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
return None
if not type_hint.annotation:
return None
if type_hint.is_msgspec:
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid this inner import. I'd rather to have all the conditional imports captured elsewhere.


return msgspec.json.schema(type_hint.annotation)
if type_hint.is_pydantic:
return type_hint.annotation.model_json_schema(mode="serialization")
return type_hint_to_json_schema(type_hint.annotation)
Expand Down
25 changes: 17 additions & 8 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from restate.context import HandlerType
from restate.exceptions import TerminalError
from restate.serde import DefaultSerde, PydanticJsonSerde, Serde, is_pydantic
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, is_msgspec

I = TypeVar("I")
O = TypeVar("O")
Expand Down Expand Up @@ -54,6 +54,7 @@ class TypeHint(Generic[T]):

annotation: Optional[T] = None
is_pydantic: bool = False
is_msgspec: bool = False
is_void: bool = False


Expand All @@ -79,20 +80,24 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si
"""
Augment handler_io with additional information about the input and output types.

This function has a special check for Pydantic models when these are provided.
This function has a special check for msgspec Structs and Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the input and the return types of a function, and will:
* capture any Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with the appropriate serde
"""
params = list(signature.parameters.values())
if len(params) == 1:
# if there is only one parameter, it is the context.
handler_io.input_type = TypeHint(is_void=True)
else:
annotation = params[-1].annotation
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False)
if is_pydantic(annotation):
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
if is_msgspec(annotation):
handler_io.input_type.is_msgspec = True
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = MsgspecJsonSerde(annotation)
elif is_pydantic(annotation):
handler_io.input_type.is_pydantic = True
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = PydanticJsonSerde(annotation)
Expand All @@ -102,8 +107,12 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si
# if there is no return annotation, we assume it is void
handler_io.output_type = TypeHint(is_void=True)
else:
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False)
if is_pydantic(annotation):
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
if is_msgspec(annotation):
handler_io.output_type.is_msgspec = True
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = MsgspecJsonSerde(annotation)
elif is_pydantic(annotation):
handler_io.output_type.is_pydantic = True
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = PydanticJsonSerde(annotation)
Expand Down
79 changes: 78 additions & 1 deletion python/restate/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,24 @@ def _from_dict(data_class: typing.Any, data: typing.Any) -> typing.Any: # pylin
return _to_dict, _from_dict


def try_import_msgspec_struct():
"""
Try to import Struct from msgspec.
"""
try:
from msgspec import Struct # type: ignore # pylint: disable=import-outside-toplevel

return Struct
except ImportError:

class Dummy: # pylint: disable=too-few-public-methods
"""a dummy class to use when msgspec is not available"""

return Dummy


PydanticBaseModel = try_import_pydantic_base_model()
MsgspecStruct = try_import_msgspec_struct()
# pylint: disable=C0103
DaciteToDict, DaciteFromDict = try_import_from_dacite()

Expand All @@ -97,6 +114,17 @@ def is_pydantic(annotation) -> bool:
return False


def is_msgspec(annotation) -> bool:
"""
Check if an object is a msgspec Struct.
"""
try:
return issubclass(annotation, MsgspecStruct)
except TypeError:
# annotation is not a class or a type
return False


class Serde(typing.Generic[T], abc.ABC):
"""serializer/deserializer interface."""

Expand Down Expand Up @@ -227,6 +255,10 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
if not buf:
return None
if is_msgspec(self.type_hint):
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.decode(buf, type=self.type_hint)
if is_pydantic(self.type_hint):
return self.type_hint.model_validate_json(buf) # type: ignore
if is_dataclass(self.type_hint):
Expand All @@ -237,7 +269,7 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
Serializes a Python object into a byte array.
If the object is a Pydantic BaseModel, uses its model_dump_json method.
If the object is a msgspec Struct or Pydantic BaseModel, uses their respective methods.

Args:
obj (Optional[I]): The Python object to serialize.
Expand All @@ -247,6 +279,10 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
if obj is None:
return bytes()
if is_msgspec(self.type_hint):
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.encode(obj)
if is_pydantic(self.type_hint):
return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined]
if is_dataclass(obj):
Expand Down Expand Up @@ -291,3 +327,44 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
return bytes()
json_str = obj.model_dump_json() # type: ignore[attr-defined]
return json_str.encode("utf-8")


class MsgspecJsonSerde(Serde[I]):
"""
Serde for msgspec Structs to/from JSON
"""

def __init__(self, model):
self.model = model

def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
Deserializes a bytearray to a msgspec Struct.

Args:
buf (bytearray): The bytearray to deserialize.

Returns:
typing.Optional[I]: The deserialized msgspec Struct.
"""
if not buf:
return None
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid these inline imports.


return msgspec.json.decode(buf, type=self.model)

def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
Serializes a msgspec Struct to a bytearray.

Args:
obj (I): The msgspec Struct to serialize.

Returns:
bytearray: The serialized bytearray.
"""
if obj is None:
return bytes()
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.encode(obj)