Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,8 @@ def _sharded_message_generator(self):

def _pubsubs_generator(self):
while True:
yield from self.node_pubsub_mapping.values()
current_nodes = list(self.node_pubsub_mapping.values())
yield from current_nodes

def get_sharded_message(
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
Expand Down
114 changes: 114 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,120 @@ def test_pubsub_shardnumsub(self, r):
channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_ssubscribe_multiple_channels_different_nodes(self, r):
"""
Test subscribing to multiple sharded channels on different nodes.
Validates that the generator properly handles multiple node_pubsub_mapping entries.
"""
pubsub = r.pubsub()
channel1 = "test-channel:{0}"
channel2 = "test-channel:{6}"

# Subscribe to first channel
pubsub.ssubscribe(channel1)
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "ssubscribe"

# Subscribe to second channel (likely different node)
pubsub.ssubscribe(channel2)
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "ssubscribe"

# Verify both channels are in shard_channels
assert channel1.encode() in pubsub.shard_channels
assert channel2.encode() in pubsub.shard_channels

pubsub.close()

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_ssubscribe_multiple_channels_publish_and_read(self, r):
"""
Test publishing to multiple sharded channels and reading messages.
Validates that _sharded_message_generator properly cycles through
multiple node_pubsub_mapping entries.
"""
pubsub = r.pubsub()
channel1 = "test-channel:{0}"
channel2 = "test-channel:{6}"
msg1_data = "message-1"
msg2_data = "message-2"

# Subscribe to both channels
pubsub.ssubscribe(channel1, channel2)

# Read subscription confirmations
for _ in range(2):
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "ssubscribe"

# Publish messages to both channels
r.spublish(channel1, msg1_data)
r.spublish(channel2, msg2_data)

# Read messages - should get both messages
messages = []
for _ in range(2):
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "smessage"
messages.append(msg)

# Verify we got messages from both channels
channels_received = {msg["channel"] for msg in messages}
assert channel1.encode() in channels_received
assert channel2.encode() in channels_received

pubsub.close()

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_generator_handles_concurrent_mapping_changes(self, r):
"""
Test that the generator properly handles mapping changes during iteration.
This validates the fix for the RuntimeError: dictionary changed size during iteration.
"""
pubsub = r.pubsub()
channel1 = "test-channel:{0}"
channel2 = "test-channel:{6}"

# Subscribe to first channel
pubsub.ssubscribe(channel1)
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "ssubscribe"

# Get initial mapping size (cluster pubsub only)
assert hasattr(pubsub, "node_pubsub_mapping"), "Test requires ClusterPubSub"
initial_size = len(pubsub.node_pubsub_mapping)

# Subscribe to second channel (modifies mapping during potential iteration)
pubsub.ssubscribe(channel2)
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
assert msg is not None
assert msg["type"] == "ssubscribe"

# Verify mapping was updated
assert len(pubsub.node_pubsub_mapping) >= initial_size

# Publish and read messages - should not raise RuntimeError
r.spublish(channel1, "msg1")
r.spublish(channel2, "msg2")

messages_received = 0
for _ in range(2):
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
if msg and msg["type"] == "smessage":
messages_received += 1

assert messages_received == 2
pubsub.close()


class TestPubSubPings:
@skip_if_server_version_lt("3.0.0")
Expand Down