Skip to content

Commit

Permalink
[fix] add event handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed May 29, 2019
1 parent a01469e commit 4042e98
Showing 1 changed file with 67 additions and 42 deletions.
109 changes: 67 additions & 42 deletions wsrpc_aiohttp/websocket/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class WSRPCBase:
_CLEAN_LOCK_TIMEOUT = 2

__slots__ = ('_handlers', '_loop', '_pending_tasks', '_locks',
'_futures', '_serial', '_timeout')
'_futures', '_serial', '_timeout', '_event_listeners')

def __init__(self, loop: asyncio.AbstractEventLoop = None, timeout=None):
self._loop = loop or asyncio.get_event_loop()
Expand All @@ -71,6 +71,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop = None, timeout=None):
self._timeout = timeout
self._locks = defaultdict(partial(asyncio.Lock, loop=self._loop))
self._futures = defaultdict(self._loop.create_future)
self._event_listeners = set()

def _create_task(self, coro):
task = self._loop.create_task(coro) # type: asyncio.Task
Expand Down Expand Up @@ -156,9 +157,49 @@ def _prepare_args(args):

return arguments, kwargs

def prepare_args(self, args):
return self._prepare_args(args)

@staticmethod
def is_route(func):
return (
hasattr(func, '__self__') and
isinstance(func.__self__, WebSocketRoute)
)

async def handle_method(self, method, serial, *args, **kwargs):
callee = self.resolver(method)

if not self.is_route(callee):
a = [self]
a.extend(args)
args = a

result = await self._executor(partial(callee, *args, **kwargs))
await self._send(result=result, id=serial)

async def handle_result(self, serial, result):
cb = self._futures.pop(serial, None)
cb.set_result(result)

async def handle_error(self, serial, error):
self._reject(serial, error)
log.error('Client return error: \n\t%r', error)

def __clean_lock(self, serial):
if serial not in self._locks:
return
log.debug("Release and delete lock for %s serial %s", self, serial)
self._locks.pop(serial)

async def handle_event(self, event):
for listener in self._event_listeners:
self._loop.call_soon(listener, event)

async def on_message(self, message: aiohttp.WSMessage):
# deserialize message
data = message.json(loads=loads)
# noinspection PyNoneFunctionAssignment, PyTypeChecker
data = message.json(loads=loads) # type: dict

log.debug("Response: %r", data)
serial = data.get('id')
Expand All @@ -169,56 +210,31 @@ async def on_message(self, message: aiohttp.WSMessage):
log.debug("Acquiring lock for %s serial %s", self, serial)
async with self._locks[serial]:
try:
if 'method' in data:
args, kwargs = self._prepare_args(
if serial is None:
await self.handle_event(data)
elif 'method' in data:
args, kwargs = self.prepare_args(
data.get('params', None)
)

callee = self.resolver(method)
calee_is_route = (
hasattr(callee, '__self__') and
isinstance(callee.__self__, WebSocketRoute)
await self.handle_method(
method, serial, *args, **kwargs
)

if not calee_is_route:
a = [self]
a.extend(args)
args = a

result = await self._executor(
partial(callee, *args, **kwargs)
)

await self._send(result=result, id=serial)
elif 'result' in data:
cb = self._futures.pop(serial, None)
cb.set_result(result)

await self.handle_result(serial, result)
elif 'error' in data:
self._reject(serial, error)
log.error('Client return error: \n\t%r', error)
await self.handle_error(serial, error)

except Exception as e:
log.exception(e)

if not serial:
return

await self._send(error=self._format_error(e), id=serial)

if serial:
await self._send(error=self._format_error(e), id=serial)
finally:
def clean_lock():
log.debug(
"Release and delete lock for %s serial %s",
self, serial
)

if serial in self._locks:
self._locks.pop(serial)

self._call_later(self._CLEAN_LOCK_TIMEOUT, clean_lock)
self._call_later(
self._CLEAN_LOCK_TIMEOUT, self.__clean_lock, serial
)

@abc.abstractstaticmethod
@abc.abstractmethod
async def _send(self, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -311,6 +327,9 @@ async def make_something(self, foo, bar):
await self._send(**payload)
return await asyncio.wait_for(future, self._timeout, loop=self._loop)

async def emit(self, event):
await self._send(**event)

@classmethod
def add_route(cls, route: str, handler: Union[WebSocketRoute, Callable]):
""" Expose local function through RPC
Expand All @@ -335,6 +354,12 @@ def add_route(cls, route: str, handler: Union[WebSocketRoute, Callable]):
assert callable(handler) or isinstance(handler, WebSocketRoute)
cls.get_routes()[route] = handler

def add_event_listener(self, func: Callable[[dict], Any]):
self._event_listeners.add(func)

def remove_event_listeners(self, func):
return self._event_listeners.remove(func)

@classmethod
def remove_route(cls, route: str, fail=True):
""" Removes route by name. If `fail=True` an exception
Expand All @@ -351,7 +376,7 @@ def __repr__(self):
else:
return "<RPCWebsocket: {0} (waiting)>".format(self.__hash__())

@abc.abstractstaticmethod
@abc.abstractmethod
async def _executor(self, func):
raise NotImplementedError

Expand Down

0 comments on commit 4042e98

Please sign in to comment.