|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import logging |
| 4 | +from typing import Union |
| 5 | +from urllib.request import Request, urlopen |
| 6 | +from urllib.error import URLError |
| 7 | + |
| 8 | + |
| 9 | +class RespTranslator: |
| 10 | + """Helper class to translate between RESP and other encodings.""" |
| 11 | + |
| 12 | + @staticmethod |
| 13 | + def cluster_slots_to_resp(resp: str) -> str: |
| 14 | + """Convert query to RESP format.""" |
| 15 | + return ( |
| 16 | + f"*{len(resp.split())}\r\n" |
| 17 | + + "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split()) |
| 18 | + + "\r\n" |
| 19 | + ) |
| 20 | + |
| 21 | + @staticmethod |
| 22 | + def smigrating_to_resp(resp: str) -> str: |
| 23 | + """Convert query to RESP format.""" |
| 24 | + return ( |
| 25 | + f">{len(resp.split())}\r\n" |
| 26 | + + "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split()) |
| 27 | + + "\r\n" |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +class ProxyInterceptorHelper: |
| 32 | + """Helper class for intercepting socket calls and managing interceptor server.""" |
| 33 | + |
| 34 | + def __init__(self, server_url: str = "http://localhost:4000"): |
| 35 | + self.server_url = server_url |
| 36 | + self._resp_translator = RespTranslator() |
| 37 | + |
| 38 | + def cleanup_interceptors(self, *names: str): |
| 39 | + """ |
| 40 | + Resets all the interceptors by providing empty pattern and returned response. |
| 41 | +
|
| 42 | + Args: |
| 43 | + names: Names of the interceptors to reset |
| 44 | + """ |
| 45 | + for name in names: |
| 46 | + self._reset_interceptor(name) |
| 47 | + |
| 48 | + def set_cluster_nodes(self, name: str, nodes: list[tuple[str, int]]) -> str: |
| 49 | + """ |
| 50 | + Set cluster nodes by intercepting CLUSTER SLOTS command. |
| 51 | +
|
| 52 | + This method creates an interceptor that intercepts CLUSTER SLOTS commands |
| 53 | + and returns a modified topology with the provided nodes. |
| 54 | +
|
| 55 | + Args: |
| 56 | + name: Name of the interceptor |
| 57 | + nodes: List of (host, port) tuples representing the cluster nodes |
| 58 | +
|
| 59 | + Returns: |
| 60 | + The interceptor name that was created |
| 61 | +
|
| 62 | + Example: |
| 63 | + interceptor = InterceptorHelper(None, "http://localhost:4000") |
| 64 | + interceptor_name = interceptor.set_cluster_nodes( |
| 65 | + "test_topology", |
| 66 | + [("127.0.0.1", 6379), ("127.0.0.1", 6380), ("127.0.0.1", 6381)] |
| 67 | + ) |
| 68 | + """ |
| 69 | + # Build RESP response for CLUSTER SLOTS |
| 70 | + # Format: *<num_slots_ranges> for each range: *3 :start :end *3 $<host_len> <host> :<port> $<id_len> <id> |
| 71 | + resp_parts = [f"*{len(nodes)}"] |
| 72 | + |
| 73 | + # For simplicity, distribute slots evenly across nodes |
| 74 | + total_slots = 16384 |
| 75 | + slots_per_node = total_slots // len(nodes) |
| 76 | + |
| 77 | + for i, (host, port) in enumerate(nodes): |
| 78 | + start_slot = i * slots_per_node |
| 79 | + end_slot = ( |
| 80 | + (i + 1) * slots_per_node - 1 if i < len(nodes) - 1 else total_slots - 1 |
| 81 | + ) |
| 82 | + |
| 83 | + # Node info: *3 for (host, port, id) |
| 84 | + resp_parts.append("*3") |
| 85 | + resp_parts.append(f":{start_slot}") |
| 86 | + resp_parts.append(f":{end_slot}") |
| 87 | + |
| 88 | + # Node details: *3 for (host, port, id) |
| 89 | + resp_parts.append("*3") |
| 90 | + resp_parts.append(f"${len(host)}") |
| 91 | + resp_parts.append(host) |
| 92 | + resp_parts.append(f":{port}") |
| 93 | + resp_parts.append("$13") |
| 94 | + resp_parts.append(f"proxy-id-{port}") |
| 95 | + |
| 96 | + response = "\r\n".join(resp_parts) + "\r\n" |
| 97 | + |
| 98 | + # Add the interceptor |
| 99 | + self._add_interceptor( |
| 100 | + name=name, |
| 101 | + match="*2\r\n$7\r\ncluster\r\n$5\r\nslots\r\n", |
| 102 | + response=response, |
| 103 | + encoding="raw", |
| 104 | + ) |
| 105 | + |
| 106 | + return name |
| 107 | + |
| 108 | + def get_stats(self) -> dict: |
| 109 | + """ |
| 110 | + Get statistics from the interceptor server. |
| 111 | +
|
| 112 | + Returns: |
| 113 | + Statistics dictionary containing connection information |
| 114 | + """ |
| 115 | + url = f"{self.server_url}/stats" |
| 116 | + request = Request(url, method="GET") |
| 117 | + |
| 118 | + try: |
| 119 | + with urlopen(request) as response: |
| 120 | + return json.loads(response.read().decode("utf-8")) |
| 121 | + except URLError as e: |
| 122 | + raise RuntimeError(f"Failed to get stats from interceptor server: {e}") |
| 123 | + |
| 124 | + def get_connections(self) -> dict: |
| 125 | + """ |
| 126 | + Get all active connections from the server. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + Response from the server as a dictionary |
| 130 | + """ |
| 131 | + url = f"{self.server_url}/connections" |
| 132 | + request = Request(url, method="GET") |
| 133 | + |
| 134 | + try: |
| 135 | + with urlopen(request) as response: |
| 136 | + return json.loads(response.read().decode("utf-8")) |
| 137 | + except URLError as e: |
| 138 | + raise RuntimeError(f"Failed to get connections: {e}") |
| 139 | + |
| 140 | + def send_notification( |
| 141 | + self, |
| 142 | + connected_to_port: Union[int, str], |
| 143 | + notification: str, |
| 144 | + ) -> dict: |
| 145 | + """ |
| 146 | + Send a notification to all connections connected to a specific node. |
| 147 | +
|
| 148 | + This method: |
| 149 | + 1. Fetches stats from the interceptor server |
| 150 | + 2. Finds all connection IDs connected to the specified node |
| 151 | + 3. Sends the notification to each connection |
| 152 | +
|
| 153 | + Args: |
| 154 | + node_address: Node address in format "host:port" (e.g., "127.0.0.1:6379") |
| 155 | + notification: The notification message to send (RESP format) |
| 156 | + encoding: Encoding type - "base64" or "raw" |
| 157 | +
|
| 158 | + Returns: |
| 159 | + Response from the server as a dictionary |
| 160 | +
|
| 161 | + Example: |
| 162 | + interceptor = InterceptorHelper(None, "http://localhost:4000") |
| 163 | + result = interceptor.send_notification( |
| 164 | + "6379", |
| 165 | + "KjENCiQ0DQpQSU5HDQo=", # PING command in base64 |
| 166 | + encoding="base64" |
| 167 | + ) |
| 168 | + """ |
| 169 | + # Get stats to find connection IDs for the node |
| 170 | + stats = self.get_stats() |
| 171 | + |
| 172 | + # Extract connection IDs for the specified node |
| 173 | + conn_ids = [] |
| 174 | + for node_key, node_info in stats.items(): |
| 175 | + node_port = node_key.split("@")[1] |
| 176 | + if int(node_port) == int(connected_to_port): |
| 177 | + for conn in node_info.get("connections", []): |
| 178 | + conn_ids.append(conn["id"]) |
| 179 | + |
| 180 | + if not conn_ids: |
| 181 | + raise RuntimeError( |
| 182 | + f"No connections found for node {node_port}. " |
| 183 | + f"Available nodes: {list(set(c.get('node') for c in stats.get('connections', {}).values()))}" |
| 184 | + ) |
| 185 | + |
| 186 | + # Send notification to each connection |
| 187 | + results = {} |
| 188 | + logging.info(f"Sending notification to {len(conn_ids)} connections: {conn_ids}") |
| 189 | + connections_query = f"connectionIds={','.join(conn_ids)}" |
| 190 | + url = f"{self.server_url}/send-to-clients?{connections_query}&encoding=base64" |
| 191 | + # Encode notification to base64 |
| 192 | + data = base64.b64encode(notification.encode("utf-8")) |
| 193 | + |
| 194 | + request = Request(url, data=data, method="POST") |
| 195 | + |
| 196 | + try: |
| 197 | + with urlopen(request) as response: |
| 198 | + results = json.loads(response.read().decode("utf-8")) |
| 199 | + except URLError as e: |
| 200 | + results = {"error": str(e)} |
| 201 | + |
| 202 | + return { |
| 203 | + "node_address": node_port, |
| 204 | + "connection_ids": conn_ids, |
| 205 | + "results": results, |
| 206 | + } |
| 207 | + |
| 208 | + def _add_interceptor( |
| 209 | + self, |
| 210 | + name: str, |
| 211 | + match: str, |
| 212 | + response: str, |
| 213 | + encoding: str = "raw", |
| 214 | + ) -> dict: |
| 215 | + """ |
| 216 | + Add an interceptor to the server. |
| 217 | +
|
| 218 | + Args: |
| 219 | + name: Name of the interceptor |
| 220 | + match: Pattern to match (RESP format) |
| 221 | + response: Response to return when matched (RESP format) |
| 222 | + encoding: Encoding type - "base64" or "raw" |
| 223 | +
|
| 224 | + Returns: |
| 225 | + Response from the server as a dictionary |
| 226 | + """ |
| 227 | + url = f"{self.server_url}/interceptors" |
| 228 | + payload = { |
| 229 | + "name": name, |
| 230 | + "match": match, |
| 231 | + "response": response, |
| 232 | + "encoding": encoding, |
| 233 | + } |
| 234 | + data = json.dumps(payload).encode("utf-8") |
| 235 | + request = Request( |
| 236 | + url, data=data, method="POST", headers={"Content-Type": "application/json"} |
| 237 | + ) |
| 238 | + |
| 239 | + try: |
| 240 | + with urlopen(request) as response: |
| 241 | + return json.loads(response.read().decode("utf-8")) |
| 242 | + except URLError as e: |
| 243 | + raise RuntimeError(f"Failed to add interceptor: {e}") |
| 244 | + |
| 245 | + def _reset_interceptor(self, name: str): |
| 246 | + """ |
| 247 | + Reset an interceptor by providing empty pattern and returned response. |
| 248 | +
|
| 249 | + Args: |
| 250 | + name: Name of the interceptor to reset |
| 251 | + """ |
| 252 | + self._add_interceptor(name, "", "") |
0 commit comments