Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support both Pydantic v1 and v2 #24

Merged
merged 17 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Websockets are ideal to create bi-directional realtime connections over the web.
- Server Endpoint:
- Based on [FAST-API](https://github.com/tiangolo/fastapi): enjoy all the benefits of a full ASGI platform, including Async-io and dependency injections (for example to authenticate connections)

- Based on [Pydnatic](https://pydantic-docs.helpmanual.io/): easily serialize structured data as part of RPC requests and responses (see 'tests/basic_rpc_test.py :: test_structured_response' for an example)
- Based on [Pydantic](https://pydantic-docs.helpmanual.io/): easily serialize structured data as part of RPC requests and responses (see 'tests/basic_rpc_test.py :: test_structured_response' for an example)

- Client :
- Based on [Tenacity](https://tenacity.readthedocs.io/en/latest/index.html): allowing configurable retries to keep to connection alive
Expand Down
96 changes: 71 additions & 25 deletions fastapi_websocket_rpc/rpc_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
"""
import asyncio
from inspect import _empty, getmembers, ismethod, signature
from typing import Any, Coroutine, Dict, List
from typing import Any, Dict, List

from pydantic import ValidationError

from .utils import gen_uid
from .logger import get_logger
from .rpc_methods import EXPOSED_BUILT_IN_METHODS, NoResponse, RpcMethodsBase
from .schemas import RpcMessage, RpcRequest, RpcResponse
from .utils import gen_uid, get_model_parser

from .logger import get_logger
logger = get_logger("RPC_CHANNEL")


Expand All @@ -31,6 +31,7 @@ class RpcChannelClosedException(Exception):
"""
Raised when the channel is closed mid-operation
"""

pass


Expand Down Expand Up @@ -92,11 +93,16 @@ class RpcCaller:

def __init__(self, channel, methods=None) -> None:
self._channel = channel
self._method_names = [method[0] for method in getmembers(
methods, lambda i: ismethod(i))] if methods is not None else None
self._method_names = (
[method[0] for method in getmembers(methods, lambda i: ismethod(i))]
if methods is not None
else None
)

def __getattribute__(self, name: str):
if (not name.startswith("_") or name in EXPOSED_BUILT_IN_METHODS) and (self._method_names is None or name in self._method_names):
if (not name.startswith("_") or name in EXPOSED_BUILT_IN_METHODS) and (
self._method_names is None or name in self._method_names
):
return RpcProxy(self._channel, name)
else:
return super().__getattribute__(name)
Expand Down Expand Up @@ -124,7 +130,15 @@ class RpcChannel:
e.g. answer = channel.other.add(a=1,b=1) will (For example) ask the other side to perform 1+1 and will return an RPC-response of 2
"""

def __init__(self, methods: RpcMethodsBase, socket, channel_id=None, default_response_timeout=None, sync_channel_id=False, **kwargs):
def __init__(
self,
methods: RpcMethodsBase,
socket,
channel_id=None,
default_response_timeout=None,
sync_channel_id=False,
**kwargs,
):
"""

Args:
Expand Down Expand Up @@ -177,12 +191,18 @@ async def get_other_channel_id(self) -> str:
The _channel_id_synced verify we have it
Timeout exception can be raised if the value isn't available
"""
await asyncio.wait_for(self._channel_id_synced.wait(), self.default_response_timeout)
await asyncio.wait_for(
self._channel_id_synced.wait(), self.default_response_timeout
)
return self._other_channel_id

def get_return_type(self, method):
method_signature = signature(method)
return method_signature.return_annotation if method_signature.return_annotation is not _empty else str
return (
method_signature.return_annotation
if method_signature.return_annotation is not _empty
else str
)

async def send(self, data):
"""
Expand Down Expand Up @@ -217,14 +237,14 @@ async def on_message(self, data):
This is the main function servers/clients using the channel need to call (upon reading a message on the wire)
"""
try:
message = RpcMessage.parse_obj(data)
parse_model = get_model_parser()
message = parse_model(RpcMessage, data)
if message.request is not None:
await self.on_request(message.request)
if message.response is not None:
await self.on_response(message.response)
except ValidationError as e:
logger.error(f"Failed to parse message", {
'message': data, 'error': e})
logger.error(f"Failed to parse message", {"message": data, "error": e})
await self.on_error(e)
except Exception as e:
await self.on_error(e)
Expand Down Expand Up @@ -267,7 +287,8 @@ async def on_connect(self):
"""
if self._sync_channel_id:
self._get_other_channel_id_task = asyncio.create_task(
self._get_other_channel_id())
self._get_other_channel_id()
)
await self.on_handler_event(self._connect_handlers, self)

async def _get_other_channel_id(self):
Expand All @@ -277,7 +298,11 @@ async def _get_other_channel_id(self):
"""
if self._other_channel_id is None:
other_channel_id = await self.other._get_channel_id_()
self._other_channel_id = other_channel_id.result if other_channel_id and other_channel_id.result else None
self._other_channel_id = (
other_channel_id.result
if other_channel_id and other_channel_id.result
else None
)
if self._other_channel_id is None:
raise RemoteValueError()
# update asyncio event that we received remote channel id
Expand All @@ -303,11 +328,14 @@ async def on_request(self, message: RpcRequest):
message (RpcRequest): the RPC request with the method to call
"""
# TODO add exception support (catch exceptions and pass to other side as response with errors)
logger.debug("Handling RPC request - %s",
{'request': message, 'channel': self.id})
logger.debug(
"Handling RPC request - %s", {"request": message, "channel": self.id}
)
method_name = message.method
# Ignore "_" prefixed methods (except the built in "_ping_")
if (isinstance(method_name, str) and (not method_name.startswith("_") or method_name in EXPOSED_BUILT_IN_METHODS)):
if isinstance(method_name, str) and (
not method_name.startswith("_") or method_name in EXPOSED_BUILT_IN_METHODS
):
method = getattr(self.methods, method_name)
if callable(method):
result = await method(**message.arguments)
Expand All @@ -317,8 +345,17 @@ async def on_request(self, message: RpcRequest):
# if no type given - try to convert to string
if result_type is str and type(result) is not str:
result = str(result)
response = RpcMessage(response=RpcResponse[result_type](
call_id=message.call_id, result=result, result_type=getattr(result_type, "__name__", getattr(result_type, "_name", "unknown-type"))))
response = RpcMessage(
response=RpcResponse[result_type](
call_id=message.call_id,
result=result,
result_type=getattr(
result_type,
"__name__",
getattr(result_type, "_name", "unknown-type"),
),
)
)
await self.send(response)

def get_saved_promise(self, call_id):
Expand All @@ -338,7 +375,7 @@ async def on_response(self, response: RpcResponse):
Args:
response (RpcResponse): the received response
"""
logger.debug("Handling RPC response - %s", {'response': response})
logger.debug("Handling RPC response - %s", {"response": response})
if response.call_id is not None and response.call_id in self.requests:
self.responses[response.call_id] = response
promise = self.requests[response.call_id]
Expand All @@ -360,15 +397,23 @@ async def wait_for_response(self, promise, timeout=DEFAULT_TIMEOUT) -> RpcRespon
if timeout is DEFAULT_TIMEOUT:
timeout = self.default_response_timeout
# wait for the promise or until the channel is terminated
_, pending = await asyncio.wait([asyncio.ensure_future(promise.wait()), asyncio.ensure_future(self._closed.wait())], timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
_, pending = await asyncio.wait(
[
asyncio.ensure_future(promise.wait()),
asyncio.ensure_future(self._closed.wait()),
],
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel all pending futures and then detect if close was the first done
for fut in pending:
fut.cancel()
response = self.responses.get(promise.call_id, NoResponse)
# if the channel was closed before we could finish
if response is NoResponse:
raise RpcChannelClosedException(
f"Channel Closed before RPC response for {promise.call_id} could be received")
f"Channel Closed before RPC response for {promise.call_id} could be received"
)
self.clear_saved_call(promise.call_id)
return response

Expand All @@ -382,9 +427,10 @@ async def async_call(self, name, args={}, call_id=None) -> RpcPromise:
call_id (string, optional): a UUID to use to track the call () - override only with true UUIDs
"""
call_id = call_id or gen_uid()
msg = RpcMessage(request=RpcRequest(
method=name, arguments=args, call_id=call_id))
logger.debug("Calling RPC method - %s", {'message': msg})
msg = RpcMessage(
request=RpcRequest(method=name, arguments=args, call_id=call_id)
)
logger.debug("Calling RPC method - %s", {"message": msg})
await self.send(msg)
promise = self.requests[msg.request.call_id] = RpcPromise(msg.request)
return promise
Expand Down
26 changes: 19 additions & 7 deletions fastapi_websocket_rpc/schemas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Dict, Generic, List, Optional, TypeVar
from enum import Enum
from typing import Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
from pydantic.generics import GenericModel

from .utils import is_pydantic_pre_v2

UUID = str

Expand All @@ -13,13 +14,24 @@ class RpcRequest(BaseModel):
call_id: Optional[UUID] = None


ResponseT = TypeVar('ResponseT')
ResponseT = TypeVar("ResponseT")


class RpcResponse(GenericModel, Generic[ResponseT]):
result: ResponseT
result_type: Optional[str]
call_id: Optional[UUID] = None
# Check pydantic version to handle deprecated GenericModel
if is_pydantic_pre_v2():
from pydantic.generics import GenericModel

class RpcResponse(GenericModel, Generic[ResponseT]):
result: ResponseT
result_type: Optional[str]
call_id: Optional[UUID] = None

else:

class RpcResponse(BaseModel, Generic[ResponseT]):
result: ResponseT
result_type: Optional[str]
call_id: Optional[UUID] = None


class RpcMessage(BaseModel):
Expand Down
9 changes: 6 additions & 3 deletions fastapi_websocket_rpc/simplewebsocket.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any
from abc import ABC, abstractmethod
import json
from abc import ABC, abstractmethod

from .utils import get_model_serializer


class SimpleWebSocket(ABC):
"""
Abstract base class for all websocket related wrappers.
"""

@abstractmethod
def send(self, msg):
pass
Expand All @@ -25,7 +27,8 @@ def __init__(self, websocket: SimpleWebSocket):
self._websocket = websocket

def _serialize(self, msg):
return msg.json()
serialize_model = get_model_serializer()
return serialize_model(msg)

def _deserialize(self, buffer):
return json.loads(buffer)
Expand Down
36 changes: 30 additions & 6 deletions fastapi_websocket_rpc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from datetime import timedelta
from random import SystemRandom, randrange

__author__ = 'OrW'
import pydantic
from packaging import version

__author__ = "OrW"


class RandomUtils(object):
@staticmethod
def gen_cookie_id():
return os.urandom(16).encode('hex')
return os.urandom(16).encode("hex")

@staticmethod
def gen_uid():
Expand All @@ -21,8 +24,10 @@ def gen_uid():
def gen_token(size=256):
if size % 2 != 0:
raise ValueError("Size in bits must be an even number.")
return uuid.UUID(int=SystemRandom().getrandbits(size/2)).hex + \
uuid.UUID(int=SystemRandom().getrandbits(size/2)).hex
return (
uuid.UUID(int=SystemRandom().getrandbits(size / 2)).hex
+ uuid.UUID(int=SystemRandom().getrandbits(size / 2)).hex
)

@staticmethod
def random_datetime(start=None, end=None):
Expand Down Expand Up @@ -52,9 +57,28 @@ def random_datetime(start=None, end=None):
class StringUtils(object):
@staticmethod
def convert_camelcase_to_underscore(name, lower=True):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
res = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1)
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
res = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1)
if lower:
return res.lower()
else:
return res.upper()


# Helper methods for supporting Pydantic v1 and v2
def is_pydantic_pre_v2():
return version.parse(pydantic.VERSION) < version.parse("2.0.0")


def get_model_serializer():
if is_pydantic_pre_v2():
return lambda model, **kwargs: model.json(**kwargs)
else:
return lambda model, **kwargs: model.model_dump_json(**kwargs)


def get_model_parser():
ff137 marked this conversation as resolved.
Show resolved Hide resolved
if is_pydantic_pre_v2():
return lambda model, data, **kwargs: model.parse_obj(data, **kwargs)
else:
return lambda model, data, **kwargs: model.model_validate(data, **kwargs)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
fastapi>=0.78.0,<1
pydantic>=1.9.1,<2
packaging>=20.4
pydantic>=1.9.1
uvicorn>=0.17.6,<1
websockets>=10.3,<11
tenacity>=8.0.1,<9
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_requirements(env=""):

setup(
name='fastapi_websocket_rpc',
version='0.1.24',
version='0.1.25',
ff137 marked this conversation as resolved.
Show resolved Hide resolved
author='Or Weis',
author_email="or@permit.io",
description="A fast and durable bidirectional JSON RPC channel over Websockets and FastApi.",
Expand Down
Loading
Loading