Skip to content

Commit

Permalink
[fix] add event handling (#20)
Browse files Browse the repository at this point in the history
* [fix] add event handling

* [fix] no create locks for events

* [fix] tests for emitter

* [fix] rename
  • Loading branch information
mosquito committed May 30, 2019
1 parent a01469e commit 18b8bdb
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 41 deletions.
23 changes: 23 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from wsrpc_aiohttp.testing import BaseTestCase, async_timeout
from wsrpc_aiohttp import WSRPCBase


async def emitter(socket: WSRPCBase):
await socket.emit({"Hello": "world"})


class TestServerEvents(BaseTestCase):

@async_timeout
async def test_emitter(self):
self.WebSocketHandler.add_route('emitter', emitter)
client = await self.get_ws_client()

future = self.loop.create_future()

client.add_event_listener(future.set_result)

await client.proxy.emitter()
result = await future

self.assertDictEqual(result, {"Hello": "world"})
109 changes: 68 additions & 41 deletions wsrpc_aiohttp/websocket/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class WSRPCBase:
_CLEAN_LOCK_TIMEOUT = 2

__slots__ = ('_handlers', '_loop', '_pending_tasks', '_locks',
'_futures', '_serial', '_timeout')
'_futures', '_serial', '_timeout', '_event_listeners')

def __init__(self, loop: asyncio.AbstractEventLoop = None, timeout=None):
self._loop = loop or asyncio.get_event_loop()
Expand All @@ -71,6 +71,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop = None, timeout=None):
self._timeout = timeout
self._locks = defaultdict(partial(asyncio.Lock, loop=self._loop))
self._futures = defaultdict(self._loop.create_future)
self._event_listeners = set()

def _create_task(self, coro):
task = self._loop.create_task(coro) # type: asyncio.Task
Expand Down Expand Up @@ -156,12 +157,56 @@ def _prepare_args(args):

return arguments, kwargs

def prepare_args(self, args):
return self._prepare_args(args)

@staticmethod
def is_route(func):
return (
hasattr(func, '__self__') and
isinstance(func.__self__, WebSocketRoute)
)

async def handle_method(self, method, serial, *args, **kwargs):
callee = self.resolver(method)

if not self.is_route(callee):
a = [self]
a.extend(args)
args = a

result = await self._executor(partial(callee, *args, **kwargs))
await self._send(result=result, id=serial)

async def handle_result(self, serial, result):
cb = self._futures.pop(serial, None)
cb.set_result(result)

async def handle_error(self, serial, error):
self._reject(serial, error)
log.error('Client return error: \n\t%r', error)

def __clean_lock(self, serial):
if serial not in self._locks:
return
log.debug("Release and delete lock for %s serial %s", self, serial)
self._locks.pop(serial)

async def handle_event(self, event):
for listener in self._event_listeners:
self._loop.call_soon(listener, event)

async def on_message(self, message: aiohttp.WSMessage):
# deserialize message
data = message.json(loads=loads)
# noinspection PyNoneFunctionAssignment, PyTypeChecker
data = message.json(loads=loads) # type: dict

log.debug("Response: %r", data)
serial = data.get('id')

if serial is None:
return await self.handle_event(data)

method = data.get('method')
result = data.get('result')
error = data.get('error')
Expand All @@ -170,55 +215,28 @@ async def on_message(self, message: aiohttp.WSMessage):
async with self._locks[serial]:
try:
if 'method' in data:
args, kwargs = self._prepare_args(
args, kwargs = self.prepare_args(
data.get('params', None)
)

callee = self.resolver(method)
calee_is_route = (
hasattr(callee, '__self__') and
isinstance(callee.__self__, WebSocketRoute)
return await self.handle_method(
method, serial, *args, **kwargs
)

if not calee_is_route:
a = [self]
a.extend(args)
args = a

result = await self._executor(
partial(callee, *args, **kwargs)
)

await self._send(result=result, id=serial)
elif 'result' in data:
cb = self._futures.pop(serial, None)
cb.set_result(result)

return await self.handle_result(serial, result)
elif 'error' in data:
self._reject(serial, error)
log.error('Client return error: \n\t%r', error)
return await self.handle_error(serial, error)

except Exception as e:
log.exception(e)

if not serial:
return

await self._send(error=self._format_error(e), id=serial)

if serial:
await self._send(error=self._format_error(e), id=serial)
finally:
def clean_lock():
log.debug(
"Release and delete lock for %s serial %s",
self, serial
)

if serial in self._locks:
self._locks.pop(serial)

self._call_later(self._CLEAN_LOCK_TIMEOUT, clean_lock)
self._call_later(
self._CLEAN_LOCK_TIMEOUT, self.__clean_lock, serial
)

@abc.abstractstaticmethod
@abc.abstractmethod
async def _send(self, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -311,6 +329,9 @@ async def make_something(self, foo, bar):
await self._send(**payload)
return await asyncio.wait_for(future, self._timeout, loop=self._loop)

async def emit(self, event):
await self._send(**event)

@classmethod
def add_route(cls, route: str, handler: Union[WebSocketRoute, Callable]):
""" Expose local function through RPC
Expand All @@ -335,6 +356,12 @@ def add_route(cls, route: str, handler: Union[WebSocketRoute, Callable]):
assert callable(handler) or isinstance(handler, WebSocketRoute)
cls.get_routes()[route] = handler

def add_event_listener(self, func: Callable[[dict], Any]):
self._event_listeners.add(func)

def remove_event_listeners(self, func):
return self._event_listeners.remove(func)

@classmethod
def remove_route(cls, route: str, fail=True):
""" Removes route by name. If `fail=True` an exception
Expand All @@ -351,7 +378,7 @@ def __repr__(self):
else:
return "<RPCWebsocket: {0} (waiting)>".format(self.__hash__())

@abc.abstractstaticmethod
@abc.abstractmethod
async def _executor(self, func):
raise NotImplementedError

Expand Down

0 comments on commit 18b8bdb

Please sign in to comment.