Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sharded pubsub #2762

Merged
merged 10 commits into from May 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_requirements.txt
Expand Up @@ -15,4 +15,5 @@ pytest-cov>=4.0.0
vulture>=2.3.0
ujson>=4.2.0
wheel>=0.30.0
urllib3<2
uvloop
83 changes: 73 additions & 10 deletions redis/client.py
Expand Up @@ -833,6 +833,7 @@ class AbstractRedis:
"QUIT": bool_ok,
"STRALGO": parse_stralgo,
"PUBSUB NUMSUB": parse_pubsub_numsub,
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
"RANDOMKEY": lambda r: r and r or None,
"RESET": str_if_bytes,
"SCAN": parse_scan,
Expand Down Expand Up @@ -1440,8 +1441,8 @@ class PubSub:
will be returned and it's safe to start listening again.
"""

PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
HEALTH_CHECK_MESSAGE = "redis-py-health-check"

def __init__(
Expand Down Expand Up @@ -1493,9 +1494,11 @@ def reset(self):
self.connection.clear_connect_callbacks()
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
self.health_check_response_counter = 0
self.channels = {}
self.pending_unsubscribe_channels = set()
self.shard_channels = {}
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
self.pending_unsubscribe_shard_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()
Expand All @@ -1510,16 +1513,23 @@ def on_connect(self, connection):
# before passing them to [p]subscribe.
self.pending_unsubscribe_channels.clear()
self.pending_unsubscribe_patterns.clear()
self.pending_unsubscribe_shard_channels.clear()
if self.channels:
channels = {}
for k, v in self.channels.items():
channels[self.encoder.decode(k, force=True)] = v
channels = {
self.encoder.decode(k, force=True): v for k, v in self.channels.items()
}
self.subscribe(**channels)
if self.patterns:
patterns = {}
for k, v in self.patterns.items():
patterns[self.encoder.decode(k, force=True)] = v
patterns = {
self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
}
self.psubscribe(**patterns)
if self.shard_channels:
shard_channels = {
self.encoder.decode(k, force=True): v
for k, v in self.shard_channels.items()
}
self.ssubscribe(**shard_channels)

@property
def subscribed(self):
Expand Down Expand Up @@ -1728,6 +1738,45 @@ def unsubscribe(self, *args):
self.pending_unsubscribe_channels.update(channels)
return self.execute_command("UNSUBSCRIBE", *args)

def ssubscribe(self, *args, target_node=None, **kwargs):
"""
Subscribes the client to the specified shard channels.
Channels supplied as keyword arguments expect a channel name as the key
and a callable as the value. A channel's callable will be invoked automatically
when a message is received on that channel rather than producing a message via
``listen()`` or ``get_sharded_message()``.
"""
if args:
args = list_or_args(args[0], args[1:])
new_s_channels = dict.fromkeys(args)
new_s_channels.update(kwargs)
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
# update the s_channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
new_s_channels = self._normalize_keys(new_s_channels)
self.shard_channels.update(new_s_channels)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
return ret_val

def sunsubscribe(self, *args, target_node=None):
"""
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
all shard_channels
"""
if args:
args = list_or_args(args[0], args[1:])
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
s_channels = self._normalize_keys(dict.fromkeys(args))
else:
s_channels = self.shard_channels
self.pending_unsubscribe_shard_channels.update(s_channels)
return self.execute_command("SUNSUBSCRIBE", *args)

def listen(self):
"Listen for messages on channels this client has been subscribed to"
while self.subscribed:
Expand Down Expand Up @@ -1762,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
return self.handle_message(response, ignore_subscribe_messages)
return None

get_sharded_message = get_message

def ping(self, message=None):
"""
Ping the Redis server
Expand Down Expand Up @@ -1809,12 +1860,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
if pattern in self.pending_unsubscribe_patterns:
self.pending_unsubscribe_patterns.remove(pattern)
self.patterns.pop(pattern, None)
elif message_type == "sunsubscribe":
s_channel = response[1]
if s_channel in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(s_channel)
self.shard_channels.pop(s_channel, None)
else:
channel = response[1]
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
if not self.channels and not self.patterns:
if not self.channels and not self.patterns and not self.shard_channels:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
Expand All @@ -1823,6 +1879,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
# if there's a message handler, invoke it
if message_type == "pmessage":
handler = self.patterns.get(message["pattern"], None)
elif message_type == "smessage":
handler = self.shard_channels.get(message["channel"], None)
else:
handler = self.channels.get(message["channel"], None)
if handler:
Expand All @@ -1843,6 +1901,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
for pattern, handler in self.patterns.items():
if handler is None:
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
for s_channel, handler in self.shard_channels.items():
if handler is None:
raise PubSubError(
f"Shard Channel: '{s_channel}' has no handler registered"
)

thread = PubSubWorkerThread(
self, sleep_time, daemon=daemon, exception_handler=exception_handler
Expand Down
105 changes: 101 additions & 4 deletions redis/cluster.py
Expand Up @@ -9,6 +9,7 @@
from redis.backoff import default_backoff
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import READ_COMMANDS, RedisClusterCommands
from redis.commands.helpers import list_or_args
from redis.connection import ConnectionPool, DefaultParser, parse_url
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.exceptions import (
Expand Down Expand Up @@ -222,6 +223,8 @@ class AbstractRedisCluster:
"PUBSUB CHANNELS",
"PUBSUB NUMPAT",
"PUBSUB NUMSUB",
"PUBSUB SHARDCHANNELS",
"PUBSUB SHARDNUMSUB",
"PING",
"INFO",
"SHUTDOWN",
Expand Down Expand Up @@ -346,11 +349,13 @@ class AbstractRedisCluster:
}

RESULT_CALLBACKS = dict_merge(
list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub),
list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub),
list_keys_to_dict(
["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values()))
),
list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result),
list_keys_to_dict(
["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result
),
list_keys_to_dict(
[
"PING",
Expand Down Expand Up @@ -1625,6 +1630,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
else redis_cluster.get_redis_connection(self.node).connection_pool
)
self.cluster = redis_cluster
self.node_pubsub_mapping = {}
self._pubsubs_generator = self._pubsubs_generator()
super().__init__(
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
)
Expand Down Expand Up @@ -1678,9 +1685,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
f"Node {host}:{port} doesn't exist in the cluster"
)

def execute_command(self, *args, **kwargs):
def execute_command(self, *args):
"""
Execute a publish/subscribe command.
Execute a subscribe/unsubscribe command.

Taken code from redis-py and tweak to make it work within a cluster.
"""
Expand Down Expand Up @@ -1713,13 +1720,103 @@ def execute_command(self, *args, **kwargs):
connection = self.connection
self._execute(connection, connection.send_command, *args)

def _get_node_pubsub(self, node):
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
try:
return self.node_pubsub_mapping[node.name]
except KeyError:
pubsub = node.redis_connection.pubsub()
self.node_pubsub_mapping[node.name] = pubsub
return pubsub

def _sharded_message_generator(self):
for _ in range(len(self.node_pubsub_mapping)):
pubsub = next(self._pubsubs_generator)
message = pubsub.get_message()
if message is not None:
return message
return None

def _pubsubs_generator(self):
while True:
for pubsub in self.node_pubsub_mapping.values():
yield pubsub

def get_sharded_message(
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
):
if target_node:
message = self.node_pubsub_mapping[target_node.name].get_message(
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)
else:
message = self._sharded_message_generator()
if message is None:
return None
elif str_if_bytes(message["type"]) == "sunsubscribe":
if message["channel"] in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(message["channel"])
self.shard_channels.pop(message["channel"], None)
node = self.cluster.get_node_from_key(message["channel"])
if self.node_pubsub_mapping[node.name].subscribed is False:
self.node_pubsub_mapping.pop(node.name)
if not self.channels and not self.patterns and not self.shard_channels:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
if self.ignore_subscribe_messages or ignore_subscribe_messages:
return None
return message

def ssubscribe(self, *args, **kwargs):
if args:
args = list_or_args(args[0], args[1:])
s_channels = dict.fromkeys(args)
s_channels.update(kwargs)
for s_channel, handler in s_channels.items():
node = self.cluster.get_node_from_key(s_channel)
pubsub = self._get_node_pubsub(node)
if handler:
pubsub.ssubscribe(**{s_channel: handler})
else:
pubsub.ssubscribe(s_channel)
self.shard_channels.update(pubsub.shard_channels)
self.pending_unsubscribe_shard_channels.difference_update(
self._normalize_keys({s_channel: None})
)
if pubsub.subscribed and not self.subscribed:
self.subscribed_event.set()
self.health_check_response_counter = 0

def sunsubscribe(self, *args):
if args:
args = list_or_args(args[0], args[1:])
else:
args = self.shard_channels

for s_channel in args:
node = self.cluster.get_node_from_key(s_channel)
p = self._get_node_pubsub(node)
p.sunsubscribe(s_channel)
self.pending_unsubscribe_shard_channels.update(
p.pending_unsubscribe_shard_channels
)

def get_redis_connection(self):
"""
Get the Redis connection of the pubsub connected node.
"""
if self.node is not None:
return self.node.redis_connection

def disconnect(self):
"""
Disconnect the pubsub connection.
"""
if self.connection:
self.connection.disconnect()
for pubsub in self.node_pubsub_mapping.values():
pubsub.connection.disconnect()


class ClusterPipeline(RedisCluster):
"""
Expand Down
26 changes: 26 additions & 0 deletions redis/commands/core.py
Expand Up @@ -5103,6 +5103,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT
"""
return self.execute_command("PUBLISH", channel, message, **kwargs)

def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
"""
Posts a message to the given shard channel.
Returns the number of clients that received the message

For more information see https://redis.io/commands/spublish
"""
return self.execute_command("SPUBLISH", shard_channel, message)

def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
Return a list of channels that have at least one subscriber
Expand All @@ -5111,6 +5120,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)

def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
Return a list of shard_channels that have at least one subscriber

For more information see https://redis.io/commands/pubsub-shardchannels
"""
return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)

def pubsub_numpat(self, **kwargs) -> ResponseT:
"""
Returns the number of subscriptions to patterns
Expand All @@ -5128,6 +5145,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
"""
return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)

def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
"""
Return a list of (shard_channel, number of subscribers) tuples
for each channel given in ``*args``

For more information see https://redis.io/commands/pubsub-shardnumsub
"""
return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs)


AsyncPubSubCommands = PubSubCommands

Expand Down
4 changes: 2 additions & 2 deletions redis/parsers/commands.py
Expand Up @@ -155,13 +155,13 @@ def _get_pubsub_keys(self, *args):
# the second argument is a part of the command name, e.g.
# ['PUBSUB', 'NUMSUB', 'foo'].
pubsub_type = args[1].upper()
if pubsub_type in ["CHANNELS", "NUMSUB"]:
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
keys = args[2:]
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
# format example:
# SUBSCRIBE channel [channel ...]
keys = list(args[1:])
elif command == "PUBLISH":
elif command in ["PUBLISH", "SPUBLISH"]:
# format example:
# PUBLISH channel message
keys = [args[1]]
Expand Down