Skip to content

Commit

Permalink
[fix] memory leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Oct 12, 2017
1 parent 80d8e41 commit ef2eb54
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 59 deletions.
2 changes: 1 addition & 1 deletion wsrpc_aiohttp/version.py
Expand Up @@ -9,7 +9,7 @@

team_email = 'me@mosquito.su'

version_info = (0, 6, 3)
version_info = (0, 6, 4)


__author__ = ", ".join("{} <{}>".format(*info) for info in author_info)
Expand Down
17 changes: 11 additions & 6 deletions wsrpc_aiohttp/websocket/common.py
Expand Up @@ -82,11 +82,9 @@ def handler():
self._pending_tasks.add(self._loop.call_later(timer, handler))

async def close(self):
for task in tuple(self._pending_tasks):
task.cancel()

async def task_waiter(task):
if not (hasattr(task, '__iter__') or hasattr(task, '__aiter__')):
continue
return

try:
await task
Expand All @@ -95,6 +93,12 @@ async def close(self):
except Exception:
log.exception("Unhandled exception when closing client connection")

for task in tuple(self._pending_tasks):
task.cancel()

if not isinstance(task, asyncio.TimerHandle) and not task.cancelled():
self._loop.create_task(task_waiter(task))

def _log_call(self, start: float, *args):
end = self._loop.time()
log.info(end - start)
Expand All @@ -105,6 +109,7 @@ async def _handle_message(self, msg: aiohttp.WSMessage):
elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED):
self._create_task(self.close())
elif msg.type == aiohttp.WSMsgType.ERROR:
self._create_task(self.close())
raise aiohttp.WebSocketError
else:
log.warning("Unhandled message %r %r", msg.type, msg.data)
Expand Down Expand Up @@ -206,8 +211,8 @@ def _reject(self, serial, error):

future.set_exception(ClientException(error))

def _unresolvable(self, *args, **kwargs):
raise NotImplementedError('Callback function not implemented')
def _unresolvable(self, func_name, *args, **kwargs):
raise NotImplementedError('Callback function "%r" not implemented' % func_name)

def resolver(self, func_name):
class_name, method = func_name.split('.') if '.' in func_name else (func_name, 'init')
Expand Down
109 changes: 57 additions & 52 deletions wsrpc_aiohttp/websocket/handler.py
Expand Up @@ -23,6 +23,7 @@ class WebSocketBase(WSRPCBase, AbstractView):

_KEEPALIVE_PING_TIMEOUT = 30
_CLIENT_TIMEOUT = int(_KEEPALIVE_PING_TIMEOUT / 3)
_MAX_CONCURRENT_REQUESTS = 25

def __init__(self, request):
AbstractView.__init__(self, request)
Expand All @@ -32,12 +33,15 @@ def __init__(self, request):
self.id = uuid.uuid4()
self.protocol_version = None
self.serial = 0
self.socket = None # type: web.WebSocketResponse
self.semaphore = asyncio.Semaphore(self._MAX_CONCURRENT_REQUESTS, loop=self._loop)

@classmethod
def configure(cls, keepalive_timeout=_KEEPALIVE_PING_TIMEOUT, client_timeout=_CLIENT_TIMEOUT):
def configure(cls, keepalive_timeout=_KEEPALIVE_PING_TIMEOUT, client_timeout=_CLIENT_TIMEOUT,
max_concurrent_requests=_MAX_CONCURRENT_REQUESTS):

cls._KEEPALIVE_PING_TIMEOUT = keepalive_timeout
cls._CLIENT_TIMEOUT = client_timeout
cls._MAX_CONCURRENT_REQUESTS = max_concurrent_requests

@asyncio.coroutine
def __iter__(self):
Expand Down Expand Up @@ -82,75 +86,76 @@ def broadcast(cls, func, callback=WebSocketRoute.placebo, **kwargs):
loop.create_task(client.call, func, callback, **kwargs)

async def on_message(self, message: WSMessage):
log.debug('Client %s send message: "%s"', self.id, message)
async with self.semaphore:
log.debug('Client %s send message: "%s"', self.id, message)

start = self._loop.time()
start = self._loop.time()

# deserialize message
data = message.json(loads=json.loads)
serial = data.get('serial', -1)
msg_type = data.get('type', 'call')
# deserialize message
data = message.json(loads=json.loads)
serial = data.get('serial', -1)
msg_type = data.get('type', 'call')

message_repr = ''
message_repr = ''

assert serial >= 0
assert serial >= 0

log.debug("Acquiring lock for %s serial %s", self, serial)
async with self._locks[serial]:
try:
if msg_type == 'call':
args, kwargs = self._prepare_args(data.get('arguments', None))
callback = data.get('call', None)
log.debug("Acquiring lock for %s serial %s", self, serial)
async with self._locks[serial]:
try:
if msg_type == 'call':
args, kwargs = self._prepare_args(data.get('arguments', None))
callback = data.get('call', None)

message_repr = "call[%s]" % callback
message_repr = "call[%s]" % callback

if callback is None:
raise ValueError('Require argument "call" does\'t exist.')
if callback is None:
raise ValueError('Require argument "call" does\'t exist.')

callee = self.resolver(callback)
callee_is_route = hasattr(callee, '__self__') and isinstance(callee.__self__, WebSocketRoute)
if not callee_is_route:
a = [self]
a.extend(args)
args = a
callee = self.resolver(callback)
callee_is_route = hasattr(callee, '__self__') and isinstance(callee.__self__, WebSocketRoute)
if not callee_is_route:
a = [self]
a.extend(args)
args = a

result = await self._executor(partial(callee, *args, **kwargs))
self._send(data=result, serial=serial, type='callback')
result = await self._executor(partial(callee, *args, **kwargs))
self._send(data=result, serial=serial, type='callback')

elif msg_type == 'callback':
cb = self._futures.pop(serial, None)
payload = data.get('data', None)
cb.set_result(payload)
elif msg_type == 'callback':
cb = self._futures.pop(serial, None)
payload = data.get('data', None)
cb.set_result(payload)

message_repr = "callback[%r]" % payload
message_repr = "callback[%r]" % payload

elif msg_type == 'error':
self._reject(data.get('serial', -1), data.get('data', None))
log.error('Client return error: \n\t{0}'.format(data.get('data', None)))
elif msg_type == 'error':
self._reject(data.get('serial', -1), data.get('data', None))
log.error('Client return error: \n\t{0}'.format(data.get('data', None)))

message_repr = "error[%r]" % data
message_repr = "error[%r]" % data

except Exception as e:
log.exception(e)
self._send(data=self._format_error(e), serial=serial, type='error')
except Exception as e:
log.exception(e)
self._send(data=self._format_error(e), serial=serial, type='error')

finally:
def clean_lock():
log.debug("Release and delete lock for %s serial %s", self, serial)
if serial in self._locks:
self._locks.pop(serial)
finally:
def clean_lock():
log.debug("Release and delete lock for %s serial %s", self, serial)
if serial in self._locks:
self._locks.pop(serial)

self._call_later(self._CLIENT_TIMEOUT, clean_lock)
self._call_later(self._CLIENT_TIMEOUT, clean_lock)

response_time = self._loop.time() - start
response_time = self._loop.time() - start

if response_time > 0.001:
response_time = "%.3f" % response_time
else:
# loop.time() resolution is 1 ms.
response_time = "less then 1ms"
if response_time > 0.001:
response_time = "%.3f" % response_time
else:
# loop.time() resolution is 1 ms.
response_time = "less then 1ms"

log.info("Response for client \"%s\" #%d finished %s", message_repr, serial, response_time)
log.info("Response for client \"%s\" #%d finished %s", message_repr, serial, response_time)

def _send(self, **kwargs):
try:
Expand Down

0 comments on commit ef2eb54

Please sign in to comment.