Skip to content

Commit ce738e3

Browse files
committed
Adding Redis Proxy integration
1 parent af87ddf commit ce738e3

File tree

5 files changed

+342
-14
lines changed

5 files changed

+342
-14
lines changed

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
---
22
# image tag 8.0-RC2-pre is the one matching the 8.0 GA release
33
x-client-libs-stack-image: &client-libs-stack-image
4-
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.4-RC1-pre.2}"
4+
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.4-GA-pre.2}"
55

66
x-client-libs-image: &client-libs-image
7-
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-RC1-pre.2}"
7+
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-GA-pre.2}"
88

99
networks:
1010
redis-net:
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import base64
2+
import json
3+
import logging
4+
from typing import Union
5+
from urllib.request import Request, urlopen
6+
from urllib.error import URLError
7+
8+
9+
class RespTranslator:
    """Helper class to translate between RESP and other encodings."""

    @staticmethod
    def _encode_bulk_strings(text: str) -> str:
        """Encode each whitespace-separated token of *text* as a RESP bulk string.

        Shared by the array/push translators below so the bulk-string
        framing is written only once.
        """
        return "\r\n".join(f"${len(x)}\r\n{x}" for x in text.split()) + "\r\n"

    @staticmethod
    def cluster_slots_to_resp(resp: str) -> str:
        """Convert a space-separated command/reply string to a RESP array ("*") frame."""
        return f"*{len(resp.split())}\r\n" + RespTranslator._encode_bulk_strings(resp)

    @staticmethod
    def smigrating_to_resp(resp: str) -> str:
        """Convert a space-separated notification to a RESP3 push (">") frame."""
        return f">{len(resp.split())}\r\n" + RespTranslator._encode_bulk_strings(resp)
29+
30+
31+
class ProxyInterceptorHelper:
    """Helper class for intercepting socket calls and managing interceptor server."""

    def __init__(self, server_url: str = "http://localhost:4000"):
        # Base URL of the proxy/interceptor control HTTP endpoint.
        self.server_url = server_url
        self._resp_translator = RespTranslator()

    def cleanup_interceptors(self, *names: str):
        """
        Resets all the interceptors by providing empty pattern and returned response.

        Args:
            names: Names of the interceptors to reset
        """
        for name in names:
            self._reset_interceptor(name)

    @staticmethod
    def _build_cluster_slots_resp(nodes: list[tuple[str, int]]) -> str:
        """Build the raw RESP reply for CLUSTER SLOTS covering all 16384 slots.

        Slots are distributed evenly across *nodes*; the last node absorbs the
        remainder so the full 0-16383 range is always covered.

        Args:
            nodes: List of (host, port) tuples representing the cluster nodes

        Returns:
            The CLUSTER SLOTS reply as a raw RESP-encoded string

        Raises:
            ValueError: If *nodes* is empty (there is no valid empty topology).
        """
        if not nodes:
            raise ValueError("nodes must be non-empty")

        resp_parts = [f"*{len(nodes)}"]
        total_slots = 16384
        slots_per_node = total_slots // len(nodes)

        for i, (host, port) in enumerate(nodes):
            start_slot = i * slots_per_node
            end_slot = (
                (i + 1) * slots_per_node - 1 if i < len(nodes) - 1 else total_slots - 1
            )

            # Slot-range entry: *3 for (start, end, node-details)
            resp_parts.append("*3")
            resp_parts.append(f":{start_slot}")
            resp_parts.append(f":{end_slot}")

            # Node details: *3 for (host, port, id)
            node_id = f"proxy-id-{port}"
            resp_parts.append("*3")
            resp_parts.append(f"${len(host)}")
            resp_parts.append(host)
            resp_parts.append(f":{port}")
            # Bug fix: the node-id bulk length must be computed from the id.
            # The previous hard-coded "$13" was only valid for 4-digit ports
            # (len("proxy-id-") + 4); 5-digit test ports like 15379 need $14.
            resp_parts.append(f"${len(node_id)}")
            resp_parts.append(node_id)

        return "\r\n".join(resp_parts) + "\r\n"

    def set_cluster_nodes(self, name: str, nodes: list[tuple[str, int]]) -> str:
        """
        Set cluster nodes by intercepting CLUSTER SLOTS command.

        This method creates an interceptor that intercepts CLUSTER SLOTS commands
        and returns a modified topology with the provided nodes.

        Args:
            name: Name of the interceptor
            nodes: List of (host, port) tuples representing the cluster nodes

        Returns:
            The interceptor name that was created

        Example:
            interceptor = ProxyInterceptorHelper("http://localhost:4000")
            interceptor_name = interceptor.set_cluster_nodes(
                "test_topology",
                [("127.0.0.1", 6379), ("127.0.0.1", 6380), ("127.0.0.1", 6381)]
            )
        """
        response = self._build_cluster_slots_resp(nodes)

        # Register the interceptor; the match pattern is the RESP encoding of
        # the two-part command ["cluster", "slots"].
        self._add_interceptor(
            name=name,
            match="*2\r\n$7\r\ncluster\r\n$5\r\nslots\r\n",
            response=response,
            encoding="raw",
        )

        return name

    def get_stats(self) -> dict:
        """
        Get statistics from the interceptor server.

        Returns:
            Statistics dictionary containing connection information

        Raises:
            RuntimeError: If the interceptor server cannot be reached.
        """
        url = f"{self.server_url}/stats"
        request = Request(url, method="GET")

        try:
            with urlopen(request) as response:
                return json.loads(response.read().decode("utf-8"))
        except URLError as e:
            raise RuntimeError(f"Failed to get stats from interceptor server: {e}")

    def get_connections(self) -> dict:
        """
        Get all active connections from the server.

        Returns:
            Response from the server as a dictionary

        Raises:
            RuntimeError: If the interceptor server cannot be reached.
        """
        url = f"{self.server_url}/connections"
        request = Request(url, method="GET")

        try:
            with urlopen(request) as response:
                return json.loads(response.read().decode("utf-8"))
        except URLError as e:
            raise RuntimeError(f"Failed to get connections: {e}")

    def send_notification(
        self,
        connected_to_port: Union[int, str],
        notification: str,
    ) -> dict:
        """
        Send a notification to all connections connected to a specific node.

        This method:
        1. Fetches stats from the interceptor server
        2. Finds all connection IDs connected to the specified node port
        3. Sends the notification (base64-encoded) to each connection

        Args:
            connected_to_port: Port of the target node (e.g. 6379 or "6379")
            notification: The notification message to send (RESP format)

        Returns:
            Dictionary with the matched node port, the connection ids the
            notification was sent to, and the server response (or error).

        Raises:
            RuntimeError: If no connections exist for the requested port.

        Example:
            interceptor = ProxyInterceptorHelper("http://localhost:4000")
            result = interceptor.send_notification(
                "6379",
                ">1\\r\\n$4\\r\\nPING\\r\\n",
            )
        """
        # Get stats to find connection IDs for the node
        stats = self.get_stats()

        target_port = int(connected_to_port)

        # Extract connection IDs for the specified node.
        # Stats keys are assumed to look like "<host>@<port>" (this matches
        # the original parsing) — TODO confirm against the proxy server.
        conn_ids = []
        available_ports = []
        for node_key, node_info in stats.items():
            node_port = node_key.split("@")[1]
            available_ports.append(node_port)
            if int(node_port) == target_port:
                for conn in node_info.get("connections", []):
                    # Coerce to str so the comma-join below never fails on
                    # integer connection ids.
                    conn_ids.append(str(conn["id"]))

        if not conn_ids:
            # Bug fix: the previous message indexed stats["connections"],
            # which does not exist in the per-node stats layout iterated
            # above, and referenced the loop variable node_port (unbound
            # when stats is empty). Report the requested port and the ports
            # actually seen instead.
            raise RuntimeError(
                f"No connections found for node port {target_port}. "
                f"Available node ports: {available_ports}"
            )

        # Send notification to each connection
        results = {}
        logging.info(f"Sending notification to {len(conn_ids)} connections: {conn_ids}")
        connections_query = f"connectionIds={','.join(conn_ids)}"
        url = f"{self.server_url}/send-to-clients?{connections_query}&encoding=base64"
        # Encode notification to base64
        data = base64.b64encode(notification.encode("utf-8"))

        request = Request(url, data=data, method="POST")

        try:
            with urlopen(request) as response:
                results = json.loads(response.read().decode("utf-8"))
        except URLError as e:
            results = {"error": str(e)}

        return {
            # Bug fix: report the port that was actually matched, not the
            # port of whatever stats key happened to be iterated last.
            "node_address": str(target_port),
            "connection_ids": conn_ids,
            "results": results,
        }

    def _add_interceptor(
        self,
        name: str,
        match: str,
        response: str,
        encoding: str = "raw",
    ) -> dict:
        """
        Add an interceptor to the server.

        Args:
            name: Name of the interceptor
            match: Pattern to match (RESP format)
            response: Response to return when matched (RESP format)
            encoding: Encoding type - "base64" or "raw"

        Returns:
            Response from the server as a dictionary

        Raises:
            RuntimeError: If the interceptor server cannot be reached.
        """
        url = f"{self.server_url}/interceptors"
        payload = {
            "name": name,
            "match": match,
            "response": response,
            "encoding": encoding,
        }
        data = json.dumps(payload).encode("utf-8")
        request = Request(
            url, data=data, method="POST", headers={"Content-Type": "application/json"}
        )

        try:
            with urlopen(request) as response:
                return json.loads(response.read().decode("utf-8"))
        except URLError as e:
            raise RuntimeError(f"Failed to add interceptor: {e}")

    def _reset_interceptor(self, name: str):
        """
        Reset an interceptor by providing empty pattern and returned response.

        Args:
            name: Name of the interceptor to reset
        """
        self._add_interceptor(name, "", "")

tests/test_cluster_maint_notifications_handling.py renamed to tests/maint_notifications/test_cluster_maint_notifications_handling.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
from typing import cast
22

3-
from redis import RedisCluster
3+
from redis import ConnectionPool, RedisCluster
44
from redis.cluster import ClusterNode
5-
from redis.connection import BlockingConnectionPool
5+
from redis.connection import (
6+
BlockingConnectionPool,
7+
)
68
from redis.maint_notifications import MaintNotificationsConfig
7-
from redis.cache import CacheConfig # noqa: F401
9+
from redis.cache import CacheConfig
10+
from tests.maint_notifications.proxy_server_helpers import (
11+
ProxyInterceptorHelper,
12+
RespTranslator,
13+
)
14+
15+
# Initial cluster node configuration for proxy-based tests
16+
PROXY_CLUSTER_NODES = [
17+
ClusterNode("127.0.0.1", 15379),
18+
ClusterNode("127.0.0.1", 15380),
19+
ClusterNode("127.0.0.1", 15381),
20+
]
821

922

1023
class TestClusterMaintNotificationsConfig:
1124
"""Test the maint_notifications_config parameter of RedisCluster."""
1225

13-
# Real cluster node configuration
14-
CLUSTER_NODES = [
15-
ClusterNode("127.0.0.1", 15379),
16-
ClusterNode("127.0.0.1", 15380),
17-
ClusterNode("127.0.0.1", 15381),
18-
]
19-
2026
# Helper methods
2127
def _create_cluster_client(
2228
self,
@@ -27,7 +33,7 @@ def _create_cluster_client(
2733
):
2834
"""Create a RedisCluster instance with real cluster nodes."""
2935
kwargs = {
30-
"startup_nodes": self.CLUSTER_NODES,
36+
"startup_nodes": PROXY_CLUSTER_NODES,
3137
"protocol": 3,
3238
"skip_full_coverage_check": skip_full_coverage_check,
3339
}
@@ -266,3 +272,73 @@ def test_config_with_pipeline_operations(self):
266272
assert results[3] == b"value2" # GET returns value
267273
finally:
268274
cluster.close()
275+
276+
277+
class TestClusterMaintNotificationsHandlingBase:
    """Base class for maintenance notifications handling tests.

    Provides a proxy interceptor helper, a default maintenance-notifications
    configuration, and a cluster client pointed at the proxy nodes.
    """

    def setup_method(self):
        """Set up test fixtures with mocked sockets."""
        self.proxy_helper = ProxyInterceptorHelper()

        # Default maintenance-notifications configuration used by the tests.
        self.config = MaintNotificationsConfig(
            enabled="auto", proactive_reconnect=True, relaxed_timeout=30
        )
        self.cluster = self._create_cluster_client(maint_config=self.config)

    def _create_cluster_client(
        self,
        pool_class=ConnectionPool,
        enable_cache=False,
        max_connections=10,
        maint_config=None,
    ) -> RedisCluster:
        """Create a RedisCluster instance with mocked sockets."""
        # Fall back to the fixture-level config when none is supplied.
        effective_config = self.config if maint_config is None else maint_config
        extra_kwargs = {"cache_config": CacheConfig()} if enable_cache else {}

        return RedisCluster(
            protocol=3,
            startup_nodes=PROXY_CLUSTER_NODES,
            maint_notifications_config=effective_config,
            connection_pool_class=pool_class,
            max_connections=max_connections,
            **extra_kwargs,
        )

    def teardown_method(self):
        """Clean up test fixtures."""
        self.cluster.close()
        # NOTE(review): with no names passed, cleanup_interceptors() iterates
        # an empty tuple and resets nothing — supply interceptor names here
        # once tests start registering interceptors. TODO confirm intent.
        self.proxy_helper.cleanup_interceptors()
318+
319+
320+
class TestClusterMaintNotificationsHandling(TestClusterMaintNotificationsHandlingBase):
    """Test maintenance notifications handling with RedisCluster."""

    def test_receive_maint_notification(self):
        """Test receiving a maintenance notification.

        Subscribes via pub/sub, then asks the proxy to push a maintenance
        notification to all clients connected to the subscriber's node, and
        reads the next pub/sub message.
        """
        self.cluster.set("test", "VAL")
        pubsub = self.cluster.pubsub()
        pubsub.subscribe("test")
        test_msg = pubsub.get_message(ignore_subscribe_messages=True, timeout=10)
        print(test_msg)

        # Try to send a push notification to the clients of given server node
        # Server node is defined by its port with the local test environment
        # The message should be in the format:
        # >3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$4\r\neeee\r\n
        notification = RespTranslator.smigrating_to_resp(
            "TEST_NOTIFICATION 12182 127.0.0.1:15380"
        )
        self.proxy_helper.send_notification(pubsub.connection.port, notification)
        res = self.proxy_helper.get_connections()
        print(res)

        test_msg = pubsub.get_message(timeout=1)
        print(test_msg)
        # TODO(review): this test only prints and always passes — add
        # assertions on test_msg / res once the expected notification
        # handling behavior is pinned down. (Removed the redundant trailing
        # `pass` that followed executable statements.)

tests/test_maint_notifications_handling.py renamed to tests/maint_notifications/test_maint_notifications_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def _get_client(
441441
setup_pool_handler: Whether to set up pool handler for moving notifications (default: False)
442442
443443
Returns:
444-
tuple: (test_pool, test_redis_client)
444+
test_redis_client
445445
"""
446446
config = (
447447
maint_notifications_config

0 commit comments

Comments
 (0)