Skip to content

Commit

Permalink
Assure pools are closed on loop close in core (django#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
sevdog committed Feb 15, 2023
1 parent a7094c5 commit 2fca31c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
63 changes: 37 additions & 26 deletions channels_redis/core.py
Expand Up @@ -15,7 +15,7 @@
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .utils import _consistent_hash
from .utils import _consistent_hash, _wrap_close

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,6 +69,27 @@ def put_nowait(self, item):
return super(BoundedQueue, self).put_nowait(item)


class RedisLoopLayer:
    """Per-event-loop cache of Redis connections for one channel layer.

    Connections are created lazily per shard index and torn down together
    by ``flush`` (which also releases the underlying connection pools).
    """

    def __init__(self, channel_layer):
        # Serializes flush so concurrent flushers don't double-close.
        self._lock = asyncio.Lock()
        self.channel_layer = channel_layer
        # Maps shard index -> aioredis.Redis client bound to this loop.
        self._connections = {}

    def get_connection(self, index):
        """Return the cached connection for ``index``, creating it on first use."""
        try:
            return self._connections[index]
        except KeyError:
            pool = self.channel_layer.create_pool(index)
            connection = aioredis.Redis(connection_pool=pool)
            self._connections[index] = connection
            return connection

    async def flush(self):
        """Close and forget every cached connection, pools included."""
        async with self._lock:
            for key in list(self._connections):
                connection = self._connections.pop(key)
                # close_connection_pool=True also disconnects the pool we created.
                await connection.close(close_connection_pool=True)


class RedisChannelLayer(BaseChannelLayer):
"""
Redis channel layer.
Expand Down Expand Up @@ -101,8 +122,7 @@ def __init__(
self.hosts = self.decode_hosts(hosts)
self.ring_size = len(self.hosts)
# Cached redis connection pools and the event loop they are from
self.pools = {}
self.pools_loop = None
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
Expand Down Expand Up @@ -331,7 +351,7 @@ async def receive(self, channel):

raise

message, token, exception = None, None, None
message = token = exception = None
for task in done:
try:
result = task.result()
Expand Down Expand Up @@ -367,7 +387,7 @@ async def receive(self, channel):
message_channel, message = await self.receive_single(
real_channel
)
if type(message_channel) is list:
if isinstance(message_channel, list):
for chan in message_channel:
self.receive_buffer[chan].put_nowait(message)
else:
Expand Down Expand Up @@ -459,11 +479,7 @@ async def new_channel(self, prefix="specific"):
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
return "%s.%s!%s" % (
prefix,
self.client_prefix,
uuid.uuid4().hex,
)
return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"

### Flush extension ###

Expand Down Expand Up @@ -496,9 +512,8 @@ async def close_pools(self):
# Flush all cleaners, in case somebody just wanted to close the
# pools without flushing first.
await self.wait_received()

for index in self.pools:
await self.pools[index].disconnect()
for layer in self._layers.values():
await layer.flush()

async def wait_received(self):
"""
Expand Down Expand Up @@ -667,7 +682,7 @@ def _group_key(self, group):
"""
Common function to make the storage key for the group.
"""
return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
return f"{self.prefix}:group:{group}".encode("utf8")

### Serialization ###

Expand Down Expand Up @@ -711,7 +726,7 @@ def make_fernet(self, key):
return Fernet(formatted_key)

def __str__(self):
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
return f"{self.__class__.__name__}(hosts={self.hosts})"

### Connection handling ###

Expand All @@ -723,18 +738,14 @@ def connection(self, index):
# Catch bad indexes
if not 0 <= index < self.ring_size:
raise ValueError(
"There are only %s hosts - you asked for %s!" % (self.ring_size, index)
f"There are only {self.ring_size} hosts - you asked for {index}!"
)

loop = asyncio.get_running_loop()
try:
loop = asyncio.get_running_loop()
if self.pools_loop != loop:
self.pools = {}
self.pools_loop = loop
except RuntimeError:
pass

if index not in self.pools:
self.pools[index] = self.create_pool(index)
layer = self._layers[loop]
except KeyError:
_wrap_close(self, loop)
layer = self._layers[loop] = RedisLoopLayer(self)

return aioredis.Redis(connection_pool=self.pools[index])
return layer.get_connection(index)
18 changes: 1 addition & 17 deletions channels_redis/pubsub.py
@@ -1,32 +1,16 @@
import asyncio
import functools
import logging
import types
import uuid

import msgpack
from redis import asyncio as aioredis

from .utils import _consistent_hash
from .utils import _consistent_hash, _wrap_close

logger = logging.getLogger(__name__)


def _wrap_close(proxy, loop):
    # Monkey-patch ``loop.close`` so that, just before the event loop is
    # closed, any connections the layer ``proxy`` cached for that loop are
    # flushed. Without this, pools created on the loop would leak when the
    # loop goes away.
    original_impl = loop.close

    def _wrapper(self, *args, **kwargs):
        # Only flush if this proxy still tracks a layer for the loop;
        # remove it first so a re-entrant close cannot flush twice.
        if loop in proxy._layers:
            layer = proxy._layers[loop]
            del proxy._layers[loop]
            # The loop is not running at this point, so we can drive the
            # async flush to completion on it before it closes.
            loop.run_until_complete(layer.flush())

        # Restore the original close and delegate to it.
        self.close = original_impl
        return self.close(*args, **kwargs)

    loop.close = types.MethodType(_wrapper, loop)


async def _async_proxy(obj, name, *args, **kwargs):
# Must be defined as a function and not a method due to
# https://bugs.python.org/issue38364
Expand Down
16 changes: 16 additions & 0 deletions channels_redis/utils.py
@@ -1,4 +1,5 @@
import binascii
import types


def _consistent_hash(value, ring_size):
Expand All @@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
bigval = binascii.crc32(value) & 0xFFF
ring_divisor = 4096 / float(ring_size)
return int(bigval / ring_divisor)


def _wrap_close(proxy, loop):
original_impl = loop.close

def _wrapper(self, *args, **kwargs):
if loop in proxy._layers:
layer = proxy._layers[loop]
del proxy._layers[loop]
loop.run_until_complete(layer.flush())

self.close = original_impl
return self.close(*args, **kwargs)

loop.close = types.MethodType(_wrapper, loop)

0 comments on commit 2fca31c

Please sign in to comment.