Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/msgspec_greeter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,3 @@ class Greeting(msgspec.Struct):
@msgspec_greeter.handler()
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
return Greeting(message=f"Hello {req.name}!")

8 changes: 2 additions & 6 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,8 @@ def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
return None
if not type_hint.annotation:
return None
if type_hint.is_msgspec:
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.schema(type_hint.annotation)
if type_hint.is_pydantic:
return type_hint.annotation.model_json_schema(mode="serialization")
if type_hint.generate_json_schema is not None:
return type_hint.generate_json_schema()
return type_hint_to_json_schema(type_hint.annotation)


Expand Down
36 changes: 20 additions & 16 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from restate.context import HandlerType
from restate.exceptions import TerminalError
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, is_msgspec
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, Msgspec

I = TypeVar("I")
O = TypeVar("O")
Expand Down Expand Up @@ -53,9 +53,11 @@ class TypeHint(Generic[T]):
"""

annotation: Optional[T] = None
is_pydantic: bool = False
is_msgspec: bool = False
"""The type annotation."""
is_void: bool = False
"""Whether the type is void (i.e., None)."""
generate_json_schema: Callable[[], Dict[str, Any]] | None = None
"""A callable that generates the JSON schema for the type."""


@dataclass
Expand Down Expand Up @@ -92,30 +94,32 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si
handler_io.input_type = TypeHint(is_void=True)
else:
annotation = params[-1].annotation
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
if is_msgspec(annotation):
handler_io.input_type.is_msgspec = True
handler_io.input_type = TypeHint(annotation=annotation)
if Msgspec.is_struct(annotation):
handler_io.input_type.generate_json_schema = lambda: Msgspec.json_schema(annotation)
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = MsgspecJsonSerde(annotation)
elif is_pydantic(annotation):
handler_io.input_type.is_pydantic = True
handler_io.input_type.generate_json_schema = lambda: annotation.model_json_schema(mode="serialization")
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = PydanticJsonSerde(annotation)

annotation = signature.return_annotation
if annotation is None or annotation is Signature.empty:
return_annotation = signature.return_annotation
if return_annotation is None or return_annotation is Signature.empty:
# if there is no return annotation, we assume it is void
handler_io.output_type = TypeHint(is_void=True)
else:
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
if is_msgspec(annotation):
handler_io.output_type.is_msgspec = True
handler_io.output_type = TypeHint(annotation=return_annotation)
if Msgspec.is_struct(return_annotation):
handler_io.output_type.generate_json_schema = lambda: Msgspec.json_schema(return_annotation)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = MsgspecJsonSerde(annotation)
elif is_pydantic(annotation):
handler_io.output_type.is_pydantic = True
handler_io.output_serde = MsgspecJsonSerde(return_annotation)
elif is_pydantic(return_annotation):
handler_io.output_type.generate_json_schema = lambda: return_annotation.model_json_schema(
mode="serialization"
)
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = PydanticJsonSerde(annotation)
handler_io.output_serde = PydanticJsonSerde(return_annotation)


# pylint: disable=R0902
Expand Down
105 changes: 64 additions & 41 deletions python/restate/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

from dataclasses import asdict, is_dataclass

T = typing.TypeVar("T")
I = typing.TypeVar("I")
O = typing.TypeVar("O")


def try_import_pydantic_base_model():
"""
Expand Down Expand Up @@ -74,31 +78,56 @@ def _from_dict(data_class: typing.Any, data: typing.Any) -> typing.Any: # pylin
return _to_dict, _from_dict


def try_import_msgspec_struct():
class MsgspecJsonAPI:
def is_struct(self, annotation: typing.Any) -> bool:
return False

def decode(self, buf: bytes, type: typing.Type[T]) -> T:
raise NotImplementedError("Please use msgspec as a conditional dependency to use msgspec features.")

def encode(self, obj: typing.Any) -> bytes:
raise NotImplementedError("Please use msgspec as a conditional dependency to use msgspec features.")

def json_schema(self, type: typing.Type[T]) -> dict[str, typing.Any]:
raise NotImplementedError("Please use msgspec as a conditional dependency to use msgspec features.")


def try_import_msgspec_api():
"""
Try to import Struct from msgspec.
Try to import msgspec API.
"""
try:
from msgspec import Struct # type: ignore # pylint: disable=import-outside-toplevel
import msgspec

return Struct
except ImportError:
class MsgspecImpl(MsgspecJsonAPI):
def is_struct(self, annotation: typing.Any) -> bool:
try:
return issubclass(annotation, Struct)
except TypeError:
# annotation is not a class or a type
return False

class Dummy: # pylint: disable=too-few-public-methods
"""a dummy class to use when msgspec is not available"""
def decode(self, buf: bytes, type: typing.Type[T]) -> T:
return msgspec.json.decode(buf, type=type)

return Dummy
def encode(self, obj: typing.Any) -> bytes:
return msgspec.json.encode(obj)

def json_schema(self, type: typing.Type[T]) -> dict[str, typing.Any]:
return msgspec.json.schema(type)

return MsgspecImpl()

except ImportError:
return MsgspecJsonAPI()


PydanticBaseModel = try_import_pydantic_base_model()
MsgspecStruct = try_import_msgspec_struct()
Msgspec = try_import_msgspec_api()
# pylint: disable=C0103
DaciteToDict, DaciteFromDict = try_import_from_dacite()

T = typing.TypeVar("T")
I = typing.TypeVar("I")
O = typing.TypeVar("O")

# disable to few parameters
# pylint: disable=R0903

Expand All @@ -114,17 +143,6 @@ def is_pydantic(annotation) -> bool:
return False


def is_msgspec(annotation) -> bool:
"""
Check if an object is a msgspec Struct.
"""
try:
return issubclass(annotation, MsgspecStruct)
except TypeError:
# annotation is not a class or a type
return False


class Serde(typing.Generic[T], abc.ABC):
"""serializer/deserializer interface."""

Expand Down Expand Up @@ -255,15 +273,19 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
if not buf:
return None
if is_msgspec(self.type_hint):
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.decode(buf, type=self.type_hint)
if is_pydantic(self.type_hint):
return self.type_hint.model_validate_json(buf) # type: ignore
if is_dataclass(self.type_hint):
hint = self.type_hint
if not hint:
return json.loads(buf)
if Msgspec.is_struct(hint):
return Msgspec.decode(buf, type=hint)
if is_pydantic(hint):
return hint.model_validate_json(buf) # type: ignore
if is_dataclass(hint):
data = json.loads(buf)
return DaciteFromDict(self.type_hint, data)
return DaciteFromDict(hint, data)
# although we have a type hint, we fall back to json.loads because we were not able to
# identify a specific deserialization method, perhaps the user specified a default type
# for another reason than serialization/deserialization.
return json.loads(buf)

def serialize(self, obj: typing.Optional[I]) -> bytes:
Expand All @@ -279,15 +301,19 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
if obj is None:
return bytes()
if is_msgspec(self.type_hint):
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.encode(obj)
if is_pydantic(self.type_hint):
hint = self.type_hint
if not hint:
return json.dumps(obj).encode("utf-8")
if Msgspec.is_struct(hint):
return Msgspec.encode(obj)
if is_pydantic(hint):
return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined]
if is_dataclass(obj):
data = DaciteToDict(obj) # type: ignore
return json.dumps(data).encode("utf-8")
# although we have a type hint, we fall back to json.dumps because we were not able to
# identify a specific serialization method, perhaps the user specified a default type
# for another reason than serialization/deserialization.
return json.dumps(obj).encode("utf-8")


Expand Down Expand Up @@ -349,9 +375,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
if not buf:
return None
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.decode(buf, type=self.model)
return Msgspec.decode(buf, type=self.model)

def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
Expand All @@ -365,6 +390,4 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
if obj is None:
return bytes()
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel

return msgspec.json.encode(obj)
return Msgspec.encode(obj)
Loading