diff --git a/docker-compose.yml b/docker-compose.yml index 8e93cc132a..625fdec28a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -134,6 +134,8 @@ services: API_PORT: "4000" ENABLE_LOGGING: true SIMULATE_CLUSTER: true + DEFAULT_INTERCEPTORS: "cluster,hitless,logger" + ports: - "15379:15379" - "15380:15380" diff --git a/redis/_parsers/socket.py b/redis/_parsers/socket.py index 8147243bba..738d65a29c 100644 --- a/redis/_parsers/socket.py +++ b/redis/_parsers/socket.py @@ -62,7 +62,7 @@ def _read_from_socket( sock.settimeout(timeout) try: while True: - data = self._sock.recv(socket_read_size) + data = sock.recv(socket_read_size) # an empty string indicates the server shutdown the socket if isinstance(data, bytes) and len(data) == 0: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) diff --git a/redis/cluster.py b/redis/cluster.py index fb147faf5b..403116a5d4 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2027,10 +2027,10 @@ def initialize( # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + # For some cases we might not want to disconnect current pool and + # lose in flight commands responses if disconnect_startup_nodes_pools: # Disconnect the connection pool to avoid keeping the connection open - # For some cases we might not want to disconnect current pool and - # lose in flight commands responses r.connection_pool.disconnect() except ResponseError: raise RedisClusterException( diff --git a/redis/connection.py b/redis/connection.py index e8dc39a0d6..81f2a22b7d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -698,7 +698,13 @@ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None) conn_socket = self._get_socket() if conn_socket: timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout - conn_socket.settimeout(timeout) + # if the current timeout is 0 it means we are in the middle of a can_read call + # in this case we don't want to change the timeout because the operation + # is non-blocking and should return immediately + # Changing the state from non-blocking to blocking in the middle of a read operation + # will lead to a deadlock + if conn_socket.gettimeout() != 0: + conn_socket.settimeout(timeout) self.update_parser_timeout(timeout) def update_parser_timeout(self, timeout: Optional[float] = None): diff --git a/tests/conftest.py b/tests/conftest.py index 9d2f51795a..1d6cea5ae4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -163,6 +163,13 @@ def pytest_addoption(parser): help="Name of the Redis endpoint the tests should be executed on", ) + parser.addoption( + "--cluster-endpoint-name", + action="store", + default=None, + help="Name of the Redis endpoint with OSS API the tests should be executed on", + ) + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index 1b219f2aaf..7c0100d7d4 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -1,7 +1,6 @@ import base64 from dataclasses import dataclass import logging -import re from typing import Union from redis.http.http_client import HttpClient, HttpError @@ -10,6 +9,25 @@ class RespTranslator: """Helper class to translate between RESP and other encodings.""" + @staticmethod + def re_cluster_maint_notification_to_resp(txt: str) -> str: + """Convert query to RESP format.""" + parts = 
txt.split() + + match parts: + case ["MOVING", seq_id, time, new_host]: + return f">4\r\n+MOVING\r\n:{seq_id}\r\n:{time}\r\n+{new_host}\r\n" + case ["MIGRATING", seq_id, time, shards]: + return f">4\r\n+MIGRATING\r\n:{seq_id}\r\n:{time}\r\n+{shards}\r\n" + case ["MIGRATED", seq_id, shards]: + return f">3\r\n+MIGRATED\r\n:{seq_id}\r\n+{shards}\r\n" + case ["FAILING_OVER", seq_id, time, shards]: + return f">4\r\n+FAILING_OVER\r\n:{seq_id}\r\n:{time}\r\n+{shards}\r\n" + case ["FAILED_OVER", seq_id, shards]: + return f">3\r\n+FAILED_OVER\r\n:{seq_id}\r\n+{shards}\r\n" + case _: + raise NotImplementedError(f"Unknown notification: {txt}") + @staticmethod def oss_maint_notification_to_resp(txt: str) -> str: """Convert query to RESP format.""" @@ -232,8 +250,7 @@ def send_notification( if not conn_ids: raise RuntimeError( - f"No connections found for node {node_port}. " - f"Available nodes: {list(set(c.get('node') for c in stats.get('connections', {}).values()))}" + f"No connections found for node {connected_to_port}. \nStats: {stats}" ) # Send notification to each connection diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 803445f508..7ec1ebb16c 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -17,7 +17,7 @@ from redis.event import AsyncEventListenerInterface, EventDispatcher from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES from tests.test_scenario.conftest import get_endpoints_config, extract_cluster_fqdn -from tests.test_scenario.fault_injector_client import FaultInjectorClient +from tests.test_scenario.fault_injector_client import REFaultInjector class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface): @@ -31,7 +31,7 @@ async def listen(self, event: AsyncActiveDatabaseChanged): @pytest.fixture() def fault_injector_client(): url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") - return FaultInjectorClient(url) + return REFaultInjector(url) @pytest_asyncio.fixture() diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index a7bdb61b07..41eb74762e 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -6,6 +6,7 @@ from urllib.parse import urlparse import pytest +from redis import RedisCluster from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface @@ -22,12 +23,16 @@ from redis.client import Redis from redis.maint_notifications import EndpointType, MaintNotificationsConfig from redis.retry import Retry -from tests.test_scenario.fault_injector_client import FaultInjectorClient +from tests.test_scenario.fault_injector_client import ( + ProxyServerFaultInjector, + REFaultInjector, +) RELAXED_TIMEOUT = 30 CLIENT_TIMEOUT = 5 DEFAULT_ENDPOINT_NAME = "m-standard" +DEFAULT_OSS_API_ENDPOINT_NAME = "oss-api" class CheckActiveDatabaseChangedListener(EventListenerInterface): @@ -38,6 +43,10 @@ def listen(self, event: ActiveDatabaseChanged): self.is_changed_flag = True +def use_mock_proxy(): + return os.getenv("REDIS_ENTERPRISE_TESTS", "true").lower() == "false" + + @pytest.fixture() def endpoint_name(request): return request.config.getoption("--endpoint-name") or os.getenv( @@ -45,6 +54,13 @@ def endpoint_name(request): ) +@pytest.fixture() +def cluster_endpoint_name(request): + return request.config.getoption("--cluster-endpoint-name") or os.getenv( + "REDIS_CLUSTER_ENDPOINT_NAME", DEFAULT_OSS_API_ENDPOINT_NAME + ) + 
+ def get_endpoints_config(endpoint_name: str): endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None) @@ -67,10 +83,27 @@ def endpoints_config(endpoint_name: str): return get_endpoints_config(endpoint_name) +@pytest.fixture() +def cluster_endpoints_config(cluster_endpoint_name: str): + return get_endpoints_config(cluster_endpoint_name) + + @pytest.fixture() def fault_injector_client(): - url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") - return FaultInjectorClient(url) + if use_mock_proxy(): + return ProxyServerFaultInjector(oss_cluster=False) + else: + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return REFaultInjector(url) + + +@pytest.fixture() +def fault_injector_client_oss_api(): + if use_mock_proxy(): + return ProxyServerFaultInjector(oss_cluster=True) + else: + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return REFaultInjector(url) @pytest.fixture() @@ -208,8 +241,6 @@ def _get_client_maint_notifications( endpoint_type=endpoint_type, ) - # Create Redis client with maintenance notifications config - # This will automatically create the MaintNotificationsPoolHandler if disable_retries: retry = Retry(NoBackoff(), 0) else: @@ -218,6 +249,8 @@ def _get_client_maint_notifications( tls_enabled = True if parsed.scheme == "rediss" else False logging.info(f"TLS enabled: {tls_enabled}") + # Create Redis client with maintenance notifications config + # This will automatically create the MaintNotificationsPoolHandler client = Redis( host=host, port=port, @@ -235,3 +268,76 @@ def _get_client_maint_notifications( logging.info(f"Client uses Protocol: {client.connection_pool.get_protocol()}") return client + + +@pytest.fixture() +def cluster_client_maint_notifications(cluster_endpoints_config): + return _get_cluster_client_maint_notifications(cluster_endpoints_config) + + +def _get_cluster_client_maint_notifications( + endpoints_config, + protocol: int = 3, + enable_maintenance_notifications: bool = True, + endpoint_type: Optional[EndpointType] = None, + enable_relaxed_timeout: bool = True, + enable_proactive_reconnect: bool = True, + disable_retries: bool = False, + socket_timeout: Optional[float] = None, + host_config: Optional[str] = None, +): + """Create Redis cluster client with maintenance notifications enabled.""" + # Get credentials from the configuration + username = endpoints_config.get("username") + password = endpoints_config.get("password") + + # Parse host and port from endpoints URL + endpoints = endpoints_config.get("endpoints", []) + if not endpoints: + raise ValueError("No endpoints found in configuration") + + parsed = urlparse(endpoints[0]) + host = parsed.hostname + port = parsed.port + + if not host: + raise ValueError(f"Could not parse host from endpoint URL: {endpoints[0]}") + + logging.info(f"Connecting to Redis Enterprise: {host}:{port} with user: {username}") + + if disable_retries: + retry = Retry(NoBackoff(), 0) + else: + retry = Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3) + + tls_enabled = True if parsed.scheme == "rediss" else False + logging.info(f"TLS enabled: {tls_enabled}") + + # Configure maintenance notifications + maintenance_config = MaintNotificationsConfig( + enabled=enable_maintenance_notifications, + proactive_reconnect=enable_proactive_reconnect, + relaxed_timeout=RELAXED_TIMEOUT if enable_relaxed_timeout else -1, + endpoint_type=endpoint_type, + ) + + # Create Redis cluster client with maintenance notifications config + client = RedisCluster( + 
host=host, + port=port, + socket_timeout=CLIENT_TIMEOUT if socket_timeout is None else socket_timeout, + username=username, + password=password, + ssl=tls_enabled, + ssl_cert_reqs="none", + ssl_check_hostname=False, + protocol=protocol, # RESP3 required for push notifications + maint_notifications_config=maintenance_config, + retry=retry, + ) + logging.info("Redis cluster client created with maintenance notifications enabled") + logging.info( + f"Cluster working with the following nodes: {[(node.name, node.server_type) for node in client.get_nodes()]}" + ) + + return client diff --git a/tests/test_scenario/fault_injector_client.py b/tests/test_scenario/fault_injector_client.py index 8bce3a19e7..52ed7a599e 100644 --- a/tests/test_scenario/fault_injector_client.py +++ b/tests/test_scenario/fault_injector_client.py @@ -1,13 +1,22 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass import json import logging import time import urllib.request import urllib.error -from typing import Dict, Any, Optional, Union +from typing import Dict, Any, Optional, Tuple, Union from enum import Enum import pytest +from redis.cluster import ClusterNode +from tests.maint_notifications.proxy_server_helpers import ( + ProxyInterceptorHelper, + RespTranslator, + SlotsRange, +) + class TaskStatuses: """Class to hold completed statuses constants.""" @@ -56,7 +65,75 @@ def to_dict(self) -> Dict[str, Any]: } -class FaultInjectorClient: +@dataclass +class NodeInfo: + node_id: str + role: str + internal_address: str + external_address: str + hostname: str + port: int + + +class FaultInjectorClient(ABC): + @abstractmethod + def get_operation_result( + self, + action_id: str, + timeout: int = 60, + ) -> Dict[str, Any]: + pass + + @abstractmethod + def find_target_node_and_empty_node( + self, + endpoint_config: Dict[str, Any], + ) -> Tuple[NodeInfo, NodeInfo]: + pass + + @abstractmethod + def find_endpoint_for_bind( + self, + endpoint_config: Dict[str, Any], + endpoint_name: str, + ) -> str: + pass + + @abstractmethod + def execute_failover( + self, + endpoint_config: Dict[str, Any], + timeout: int = 60, + ) -> Dict[str, Any]: + pass + + @abstractmethod + def execute_migrate( + self, + endpoint_config: Dict[str, Any], + target_node: str, + empty_node: str, + ) -> str: + pass + + @abstractmethod + def execute_rebind( + self, + endpoint_config: Dict[str, Any], + endpoint_id: str, + ) -> str: + pass + + @abstractmethod + def get_moving_ttl(self) -> int: + pass + + +class REFaultInjector(FaultInjectorClient): + """Fault injector client for Redis Enterprise cluster setup.""" + + MOVING_TTL = 15 + def __init__(self, base_url: str): self.base_url = base_url.rstrip("/") @@ -148,3 +225,510 @@ def get_operation_result( time.sleep(check_interval) else: pytest.fail(f"Timeout waiting for operation {action_id}") + + def get_cluster_nodes_info( + self, + endpoint_config: Dict[str, Any], + timeout: int = 60, + ) -> Dict[str, Any]: + """Get cluster nodes information from Redis Enterprise.""" + try: + # Use rladmin status to get node information + bdb_id = endpoint_config.get("bdb_id") + get_status_action = ActionRequest( + action_type=ActionType.EXECUTE_RLADMIN_COMMAND, + parameters={ + "rladmin_command": "status", + "bdb_id": bdb_id, + }, + ) + trigger_action_result = self.trigger_action(get_status_action) + action_id = trigger_action_result.get("action_id") + if not action_id: + raise ValueError( + f"Failed to trigger get cluster status action for bdb_id {bdb_id}: {trigger_action_result}" + ) + + 
action_status_check_response = self.get_operation_result( + action_id, timeout=timeout + ) + logging.info( + f"Completed cluster nodes info reading: {action_status_check_response}" + ) + return action_status_check_response + + except Exception as e: + pytest.fail(f"Failed to get cluster nodes info: {e}") + + def find_target_node_and_empty_node( + self, + endpoint_config: Dict[str, Any], + ) -> Tuple[NodeInfo, NodeInfo]: + """Find the node with master shards and the node with no shards. + + Returns: + tuple: (target_node, empty_node) where target_node has master shards + and empty_node has no shards + """ + db_port = int(endpoint_config.get("port", 0)) + cluster_info = self.get_cluster_nodes_info(endpoint_config) + output = cluster_info.get("output", {}).get("output", "") + + if not output: + raise ValueError("No cluster status output found") + + # Parse the sections to find nodes with master shards and nodes with no shards + lines = output.split("\n") + shards_section_started = False + nodes_section_started = False + + # Get all node IDs from CLUSTER NODES section + all_nodes = set() + all_nodes_details = {} + nodes_with_any_shards = set() # Nodes with shards from ANY database + nodes_with_target_db_shards = set() # Nodes with shards from target database + master_nodes = set() # Master nodes for target database only + + for line in lines: + line = line.strip() + + # Start of CLUSTER NODES section + if line.startswith("CLUSTER NODES:"): + nodes_section_started = True + continue + elif line.startswith("DATABASES:"): + nodes_section_started = False + continue + elif nodes_section_started and line and not line.startswith("NODE:ID"): + # Parse node line: node:1 master 10.0.101.206 ... (ignore the role) + parts = line.split() + if len(parts) >= 1: + node_id = parts[0].replace("*", "") # Remove * prefix if present + node_role = parts[1] + node_internal_address = parts[2] + node_external_address = parts[3] + node_hostname = parts[4] + + node = NodeInfo( + node_id.split(":")[1], + node_role, + node_internal_address, + node_external_address, + node_hostname, + db_port, + ) + all_nodes.add(node_id) + all_nodes_details[node_id.split(":")[1]] = node + + # Start of SHARDS section - only care about shard roles here + if line.startswith("SHARDS:"): + shards_section_started = True + continue + elif shards_section_started and line.startswith("DB:ID"): + continue + elif shards_section_started and line and not line.startswith("ENDPOINTS:"): + # Parse shard line: db:1 m-standard redis:1 node:2 master 0-8191 1.4MB OK + parts = line.split() + if len(parts) >= 5: + db_id = parts[0] # db:1, db:2, etc. 
+ node_id = parts[3] # node:2 + shard_role = parts[4] # master/slave - this is what matters + + # Track ALL nodes with shards (for finding truly empty nodes) + nodes_with_any_shards.add(node_id) + + # Only track master nodes for the specific database we're testing + bdb_id = endpoint_config.get("bdb_id") + if db_id == f"db:{bdb_id}": + nodes_with_target_db_shards.add(node_id) + if shard_role == "master": + master_nodes.add(node_id) + elif line.startswith("ENDPOINTS:") or not line: + shards_section_started = False + + # Find empty node (node with no shards from ANY database) + nodes_with_no_shards_target_bdb = all_nodes - nodes_with_target_db_shards + + logging.debug(f"All nodes: {all_nodes}") + logging.debug(f"Nodes with shards from any database: {nodes_with_any_shards}") + logging.debug( + f"Nodes with target database shards: {nodes_with_target_db_shards}" + ) + logging.debug(f"Master nodes (target database only): {master_nodes}") + logging.debug( + f"Nodes with no shards from target database: {nodes_with_no_shards_target_bdb}" + ) + + if not nodes_with_no_shards_target_bdb: + raise ValueError("All nodes have shards from target database") + + if not master_nodes: + raise ValueError("No nodes with master shards from target database found") + + # Return the first available empty node and master node (numeric part only) + empty_node = next(iter(nodes_with_no_shards_target_bdb)).split(":")[ + 1 + ] # node:1 -> 1 + target_node = next(iter(master_nodes)).split(":")[1] # node:2 -> 2 + + return all_nodes_details[target_node], all_nodes_details[empty_node] + + def find_endpoint_for_bind( + self, + endpoint_config: Dict[str, Any], + endpoint_name: str, + timeout: int = 60, + ) -> str: + """Find the endpoint ID from cluster status. + + Returns: + str: The endpoint ID (e.g., "1:1") + """ + cluster_info = self.get_cluster_nodes_info(endpoint_config, timeout) + output = cluster_info.get("output", {}).get("output", "") + + if not output: + raise ValueError("No cluster status output found") + + # Parse the ENDPOINTS section to find endpoint ID + lines = output.split("\n") + endpoints_section_started = False + + for line in lines: + line = line.strip() + + # Start of ENDPOINTS section + if line.startswith("ENDPOINTS:"): + endpoints_section_started = True + continue + elif line.startswith("SHARDS:"): + break + elif endpoints_section_started and line and not line.startswith("DB:ID"): + # Parse endpoint line: db:1 m-standard endpoint:1:1 node:2 single No + parts = line.split() + if len(parts) >= 3 and parts[1] == endpoint_name: + endpoint_full = parts[2] # endpoint:1:1 + if endpoint_full.startswith("endpoint:"): + endpoint_id = endpoint_full.replace("endpoint:", "") # 1:1 + return endpoint_id + + raise ValueError(f"No endpoint ID for {endpoint_name} found in cluster status") + + def execute_failover( + self, + endpoint_config: Dict[str, Any], + timeout: int = 60, + ) -> Dict[str, Any]: + """Execute failover command and wait for completion.""" + + try: + bdb_id = endpoint_config.get("bdb_id") + failover_action = ActionRequest( + action_type=ActionType.FAILOVER, + parameters={ + "bdb_id": bdb_id, + }, + ) + trigger_action_result = self.trigger_action(failover_action) + action_id = trigger_action_result.get("action_id") + if not action_id: + raise ValueError( + f"Failed to trigger fail over action for bdb_id {bdb_id}: {trigger_action_result}" + ) + + action_status_check_response = self.get_operation_result( + action_id, timeout=timeout + ) + logging.info( + f"Completed cluster nodes info reading: 
{action_status_check_response}" + ) + return action_status_check_response + + except Exception as e: + pytest.fail(f"Failed to get cluster nodes info: {e}") + + def execute_migrate( + self, + endpoint_config: Dict[str, Any], + target_node: str, + empty_node: str, + ) -> str: + """Execute rladmin migrate command and wait for completion.""" + command = f"migrate node {target_node} all_shards target_node {empty_node}" + + # Get bdb_id from endpoint configuration + bdb_id = endpoint_config.get("bdb_id") + + try: + # Correct parameter format for fault injector + parameters = { + "bdb_id": bdb_id, + "rladmin_command": command, # Just the command without "rladmin" prefix + } + + logging.debug(f"Executing rladmin_command with parameter: {parameters}") + + action = ActionRequest( + action_type=ActionType.EXECUTE_RLADMIN_COMMAND, parameters=parameters + ) + result = self.trigger_action(action) + + logging.debug(f"Migrate command action result: {result}") + + action_id = result.get("action_id") + + if not action_id: + raise Exception(f"Failed to trigger migrate action: {result}") + return action_id + except Exception as e: + raise Exception(f"Failed to execute rladmin migrate: {e}") + + def execute_rebind( + self, + endpoint_config: Dict[str, Any], + endpoint_id: str, + ) -> str: + """Execute rladmin bind endpoint command and wait for completion.""" + command = f"bind endpoint {endpoint_id} policy single" + + bdb_id = endpoint_config.get("bdb_id") + + try: + parameters = { + "rladmin_command": command, # Just the command without "rladmin" prefix + "bdb_id": bdb_id, + } + + logging.info(f"Executing rladmin_command with parameter: {parameters}") + action = ActionRequest( + action_type=ActionType.EXECUTE_RLADMIN_COMMAND, parameters=parameters + ) + result = self.trigger_action(action) + logging.info( + f"Migrate command {command} with parameters {parameters} trigger result: {result}" + ) + + action_id = result.get("action_id") + + if not action_id: + raise Exception(f"Failed to trigger bind endpoint action: {result}") + return action_id + except Exception as e: + raise Exception(f"Failed to execute rladmin bind endpoint: {e}") + + def get_moving_ttl(self) -> int: + return self.MOVING_TTL + + +class ProxyServerFaultInjector(FaultInjectorClient): + """Fault injector client for proxy server setup.""" + + NODE_PORT_1 = 15379 + NODE_PORT_2 = 15380 + NODE_PORT_3 = 15381 + + # Initial cluster node configuration for proxy-based tests + PROXY_CLUSTER_NODES = [ + ClusterNode("127.0.0.1", NODE_PORT_1), + ClusterNode("127.0.0.1", NODE_PORT_2), + ] + + DEFAULT_CLUSTER_SLOTS = [ + SlotsRange("127.0.0.1", NODE_PORT_1, 0, 8191), + SlotsRange("127.0.0.1", NODE_PORT_2, 8192, 16383), + ] + + CLUSTER_SLOTS_INTERCEPTOR_NAME = "test_topology" + + SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS = 2 + MOVING_TTL = 4 + + def __init__(self, oss_cluster: bool = False): + self.oss_cluster = oss_cluster + self.proxy_helper = ProxyInterceptorHelper() + + # set the initial state of the proxy server + logging.info( + f"Setting up initial cluster slots -> {self.DEFAULT_CLUSTER_SLOTS}" + ) + self.proxy_helper.set_cluster_slots( + self.CLUSTER_SLOTS_INTERCEPTOR_NAME, self.DEFAULT_CLUSTER_SLOTS + ) + logging.info("Sleeping for 2 seconds to allow proxy to apply the changes...") + time.sleep(2) + + self.seq_id = 0 + + def _get_seq_id(self): + self.seq_id += 1 + return self.seq_id + + def find_target_node_and_empty_node( + self, + endpoint_config: Dict[str, Any], + ) -> Tuple[NodeInfo, NodeInfo]: + target_node = NodeInfo( + "1", "master", 
"0.0.0.0", "127.0.0.1", "localhost", self.NODE_PORT_1 + ) + empty_node = NodeInfo( + "3", "master", "0.0.0.0", "127.0.0.1", "localhost", self.NODE_PORT_3 + ) + return target_node, empty_node + + def find_endpoint_for_bind( + self, + endpoint_config: Dict[str, Any], + endpoint_name: str, + ) -> str: + return "1:1" + + def execute_failover( + self, endpoint_config: Dict[str, Any], timeout: int = 60 + ) -> Dict[str, Any]: + """ + Simulates a failover operation and waits for completion. + This method does not create or manage threads; if asynchronous execution is required, + it should be called from a separate thread by the caller. + This will always run for the same nodes - node 1 to node 3! + Assumes that the initial state is the DEFAULT_CLUSTER_SLOTS - shard 1 on node 1 and shard 2 on node 2. + In a real RE cluster, a replica would exist on another node, which is simulated here with node 3. + """ + + # send smigrating + if self.oss_cluster: + start_maint_notif = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATING {self._get_seq_id()} 0-8191" + ) + else: + # send failing over + start_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( + f"FAILING_OVER {self._get_seq_id()} 2 [1]" + ) + + self.proxy_helper.send_notification(self.NODE_PORT_1, start_maint_notif) + + # sleep to allow the client to receive the notification + time.sleep(self.SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS) + + if self.oss_cluster: + # intercept cluster slots + self.proxy_helper.set_cluster_slots( + self.CLUSTER_SLOTS_INTERCEPTOR_NAME, + [ + SlotsRange("127.0.0.1", self.NODE_PORT_3, 0, 8191), + SlotsRange("127.0.0.1", self.NODE_PORT_2, 8192, 16383), + ], + ) + # send smigrated + end_maint_notif = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED {self._get_seq_id()} 127.0.0.1:{self.NODE_PORT_3} 0-8191" + ) + else: + # send failed over + end_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( + f"FAILED_OVER {self._get_seq_id()} [1]" + ) + self.proxy_helper.send_notification(self.NODE_PORT_1, end_maint_notif) + + return {"status": "done"} + + def execute_migrate( + self, endpoint_config: Dict[str, Any], target_node: str, empty_node: str + ) -> str: + """ + Simulate migrate command execution. + This method does not create or manage threads; it simulates the migration process synchronously. + If asynchronous execution is desired, the caller should run this method in a separate thread. + This will run always for the same nodes - node 1 to node 2! + Assuming that the initial state is the DEFAULT_CLUSTER_SLOTS - shard 1 on node 1 and shard 2 on node 2. 
+ + """ + + if self.oss_cluster: + # send smigrating + start_maint_notif = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATING {self._get_seq_id()} 0-200" + ) + else: + # send migrating + start_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( + f"MIGRATING {self._get_seq_id()} 2 [1]" + ) + + self.proxy_helper.send_notification(self.NODE_PORT_1, start_maint_notif) + + # sleep to allow the client to receive the notification + time.sleep(self.SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS) + + if self.oss_cluster: + # intercept cluster slots + self.proxy_helper.set_cluster_slots( + self.CLUSTER_SLOTS_INTERCEPTOR_NAME, + [ + SlotsRange("127.0.0.1", self.NODE_PORT_2, 0, 200), + SlotsRange("127.0.0.1", self.NODE_PORT_1, 201, 8191), + SlotsRange("127.0.0.1", self.NODE_PORT_2, 8192, 16383), + ], + ) + # send smigrated + end_maint_notif = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED {self._get_seq_id()} 127.0.0.1:{self.NODE_PORT_2} 0-200" + ) + else: + # send migrated + end_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( + f"MIGRATED {self._get_seq_id()} [1]" + ) + self.proxy_helper.send_notification(self.NODE_PORT_1, end_maint_notif) + + return "done" + + def execute_rebind(self, endpoint_config: Dict[str, Any], endpoint_id: str) -> str: + """ + Execute rladmin bind endpoint command and wait for completion. + This method simulates the actual bind process. It does not create or manage threads; + if you wish to run it in a separate thread, you must do so from the caller. + This will run always for the same nodes - node 1 to node 3! + Assuming that the initial state is the DEFAULT_CLUSTER_SLOTS - shard 1 on node 1 + and shard 2 on node 2. + + """ + sleep_time = self.SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS + if self.oss_cluster: + # send smigrating + maint_start_notif = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATING {self._get_seq_id()} 0-8191" + ) + else: + # send moving + sleep_time = self.MOVING_TTL + maint_start_notif = RespTranslator.re_cluster_maint_notification_to_resp( + f"MOVING {self._get_seq_id()} {sleep_time} 127.0.0.1:{self.NODE_PORT_3}" + ) + self.proxy_helper.send_notification(self.NODE_PORT_1, maint_start_notif) + + # sleep to allow the client to receive the notification + time.sleep(sleep_time) + + if self.oss_cluster: + # intercept cluster slots + self.proxy_helper.set_cluster_slots( + self.CLUSTER_SLOTS_INTERCEPTOR_NAME, + [ + SlotsRange("127.0.0.1", self.NODE_PORT_3, 0, 8191), + SlotsRange("127.0.0.1", self.NODE_PORT_2, 8192, 16383), + ], + ) + # send smigrated + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED {self._get_seq_id()} 127.0.0.1:{self.NODE_PORT_3} 0-8191" + ) + self.proxy_helper.send_notification(self.NODE_PORT_1, smigrated_node_1) + else: + # TODO drop connections to node 1 + pass + + return "done" + + def get_moving_ttl(self) -> int: + return self.MOVING_TTL diff --git a/tests/test_scenario/maint_notifications_helpers.py b/tests/test_scenario/maint_notifications_helpers.py index f7fb640274..505dfb8631 100644 --- a/tests/test_scenario/maint_notifications_helpers.py +++ b/tests/test_scenario/maint_notifications_helpers.py @@ -1,21 +1,48 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import pytest +from redis import RedisCluster from redis.client import Redis from redis.connection import Connection from tests.test_scenario.fault_injector_client import ( - ActionRequest, - 
ActionType, FaultInjectorClient, + NodeInfo, ) class ClientValidations: + @staticmethod + def get_default_connection(redis_client: Union[Redis, RedisCluster]) -> Connection: + """Get a random connection from the pool.""" + if isinstance(redis_client, RedisCluster): + return redis_client.get_default_node().redis_connection.connection_pool.get_connection() + if isinstance(redis_client, Redis): + return redis_client.connection_pool.get_connection() + raise ValueError(f"Unsupported redis client type: {type(redis_client)}") + + @staticmethod + def release_connection( + redis_client: Union[Redis, RedisCluster], connection: Connection + ): + """Release a connection back to the pool.""" + if isinstance(redis_client, RedisCluster): + node_address = connection.host + ":" + str(connection.port) + node = redis_client.get_node(node_address) + if node is None: + raise ValueError( + f"Node not found in cluster for address: {node_address}" + ) + node.redis_connection.connection_pool.release(connection) + elif isinstance(redis_client, Redis): + redis_client.connection_pool.release(connection) + else: + raise ValueError(f"Unsupported redis client type: {type(redis_client)}") + @staticmethod def wait_push_notification( - redis_client: Redis, + redis_client: Union[Redis, RedisCluster], timeout: int = 120, fail_on_timeout: bool = True, connection: Optional[Connection] = None, @@ -24,8 +51,11 @@ def wait_push_notification( start_time = time.time() check_interval = 0.2 # Check more frequently during operations test_conn = ( - connection if connection else redis_client.connection_pool.get_connection() + connection + if connection + else ClientValidations.get_default_connection(redis_client) ) + logging.info(f"Waiting for push notification on connection: {test_conn}") try: while time.time() - start_time < timeout: @@ -49,146 +79,24 @@ def wait_push_notification( # Release the connection back to the pool try: if not connection: - redis_client.connection_pool.release(test_conn) + ClientValidations.release_connection(redis_client, test_conn) except Exception as e: logging.error(f"Error releasing connection: {e}") class ClusterOperations: - @staticmethod - def get_cluster_nodes_info( - fault_injector: FaultInjectorClient, - endpoint_config: Dict[str, Any], - timeout: int = 60, - ) -> Dict[str, Any]: - """Get cluster nodes information from Redis Enterprise.""" - try: - # Use rladmin status to get node information - bdb_id = endpoint_config.get("bdb_id") - get_status_action = ActionRequest( - action_type=ActionType.EXECUTE_RLADMIN_COMMAND, - parameters={ - "rladmin_command": "status", - "bdb_id": bdb_id, - }, - ) - trigger_action_result = fault_injector.trigger_action(get_status_action) - action_id = trigger_action_result.get("action_id") - if not action_id: - raise ValueError( - f"Failed to trigger get cluster status action for bdb_id {bdb_id}: {trigger_action_result}" - ) - - action_status_check_response = fault_injector.get_operation_result( - action_id, timeout=timeout - ) - logging.info( - f"Completed cluster nodes info reading: {action_status_check_response}" - ) - return action_status_check_response - - except Exception as e: - pytest.fail(f"Failed to get cluster nodes info: {e}") - @staticmethod def find_target_node_and_empty_node( fault_injector: FaultInjectorClient, endpoint_config: Dict[str, Any], - ) -> Tuple[str, str]: + ) -> Tuple[NodeInfo, NodeInfo]: """Find the node with master shards and the node with no shards. 
Returns: tuple: (target_node, empty_node) where target_node has master shards and empty_node has no shards """ - cluster_info = ClusterOperations.get_cluster_nodes_info( - fault_injector, endpoint_config - ) - output = cluster_info.get("output", {}).get("output", "") - - if not output: - raise ValueError("No cluster status output found") - - # Parse the sections to find nodes with master shards and nodes with no shards - lines = output.split("\n") - shards_section_started = False - nodes_section_started = False - - # Get all node IDs from CLUSTER NODES section - all_nodes = set() - nodes_with_any_shards = set() # Nodes with shards from ANY database - nodes_with_target_db_shards = set() # Nodes with shards from target database - master_nodes = set() # Master nodes for target database only - - for line in lines: - line = line.strip() - - # Start of CLUSTER NODES section - if line.startswith("CLUSTER NODES:"): - nodes_section_started = True - continue - elif line.startswith("DATABASES:"): - nodes_section_started = False - continue - elif nodes_section_started and line and not line.startswith("NODE:ID"): - # Parse node line: node:1 master 10.0.101.206 ... (ignore the role) - parts = line.split() - if len(parts) >= 1: - node_id = parts[0].replace("*", "") # Remove * prefix if present - all_nodes.add(node_id) - - # Start of SHARDS section - only care about shard roles here - if line.startswith("SHARDS:"): - shards_section_started = True - continue - elif shards_section_started and line.startswith("DB:ID"): - continue - elif shards_section_started and line and not line.startswith("ENDPOINTS:"): - # Parse shard line: db:1 m-standard redis:1 node:2 master 0-8191 1.4MB OK - parts = line.split() - if len(parts) >= 5: - db_id = parts[0] # db:1, db:2, etc. - node_id = parts[3] # node:2 - shard_role = parts[4] # master/slave - this is what matters - - # Track ALL nodes with shards (for finding truly empty nodes) - nodes_with_any_shards.add(node_id) - - # Only track master nodes for the specific database we're testing - bdb_id = endpoint_config.get("bdb_id") - if db_id == f"db:{bdb_id}": - nodes_with_target_db_shards.add(node_id) - if shard_role == "master": - master_nodes.add(node_id) - elif line.startswith("ENDPOINTS:") or not line: - shards_section_started = False - - # Find empty node (node with no shards from ANY database) - nodes_with_no_shards_target_bdb = all_nodes - nodes_with_target_db_shards - - logging.debug(f"All nodes: {all_nodes}") - logging.debug(f"Nodes with shards from any database: {nodes_with_any_shards}") - logging.debug( - f"Nodes with target database shards: {nodes_with_target_db_shards}" - ) - logging.debug(f"Master nodes (target database only): {master_nodes}") - logging.debug( - f"Nodes with no shards from target database: {nodes_with_no_shards_target_bdb}" - ) - - if not nodes_with_no_shards_target_bdb: - raise ValueError("All nodes have shards from target database") - - if not master_nodes: - raise ValueError("No nodes with master shards from target database found") - - # Return the first available empty node and master node (numeric part only) - empty_node = next(iter(nodes_with_no_shards_target_bdb)).split(":")[ - 1 - ] # node:1 -> 1 - target_node = next(iter(master_nodes)).split(":")[1] # node:2 -> 2 - - return target_node, empty_node + return fault_injector.find_target_node_and_empty_node(endpoint_config) @staticmethod def find_endpoint_for_bind( @@ -202,38 +110,7 @@ def find_endpoint_for_bind( Returns: str: The endpoint ID (e.g., "1:1") """ - cluster_info = 
ClusterOperations.get_cluster_nodes_info( - fault_injector, endpoint_config, timeout - ) - output = cluster_info.get("output", {}).get("output", "") - - if not output: - raise ValueError("No cluster status output found") - - # Parse the ENDPOINTS section to find endpoint ID - lines = output.split("\n") - endpoints_section_started = False - - for line in lines: - line = line.strip() - - # Start of ENDPOINTS section - if line.startswith("ENDPOINTS:"): - endpoints_section_started = True - continue - elif line.startswith("SHARDS:"): - endpoints_section_started = False - break - elif endpoints_section_started and line and not line.startswith("DB:ID"): - # Parse endpoint line: db:1 m-standard endpoint:1:1 node:2 single No - parts = line.split() - if len(parts) >= 3 and parts[1] == endpoint_name: - endpoint_full = parts[2] # endpoint:1:1 - if endpoint_full.startswith("endpoint:"): - endpoint_id = endpoint_full.replace("endpoint:", "") # 1:1 - return endpoint_id - - raise ValueError(f"No endpoint ID for {endpoint_name} found in cluster status") + return fault_injector.find_endpoint_for_bind(endpoint_config, endpoint_name) @staticmethod def execute_failover( @@ -242,100 +119,23 @@ def execute_failover( timeout: int = 60, ) -> Dict[str, Any]: """Execute failover command and wait for completion.""" - - try: - bdb_id = endpoint_config.get("bdb_id") - failover_action = ActionRequest( - action_type=ActionType.FAILOVER, - parameters={ - "bdb_id": bdb_id, - }, - ) - trigger_action_result = fault_injector.trigger_action(failover_action) - action_id = trigger_action_result.get("action_id") - if not action_id: - raise ValueError( - f"Failed to trigger fail over action for bdb_id {bdb_id}: {trigger_action_result}" - ) - - action_status_check_response = fault_injector.get_operation_result( - action_id, timeout=timeout - ) - logging.info( - f"Completed cluster nodes info reading: {action_status_check_response}" - ) - return action_status_check_response - - except Exception as e: - pytest.fail(f"Failed to get cluster nodes info: {e}") + return fault_injector.execute_failover(endpoint_config, timeout) @staticmethod - def execute_rladmin_migrate( + def execute_migrate( fault_injector: FaultInjectorClient, endpoint_config: Dict[str, Any], target_node: str, empty_node: str, ) -> str: """Execute rladmin migrate command and wait for completion.""" - command = f"migrate node {target_node} all_shards target_node {empty_node}" - - # Get bdb_id from endpoint configuration - bdb_id = endpoint_config.get("bdb_id") - - try: - # Correct parameter format for fault injector - parameters = { - "bdb_id": bdb_id, - "rladmin_command": command, # Just the command without "rladmin" prefix - } - - logging.debug(f"Executing rladmin_command with parameter: {parameters}") - - action = ActionRequest( - action_type=ActionType.EXECUTE_RLADMIN_COMMAND, parameters=parameters - ) - result = fault_injector.trigger_action(action) - - logging.debug(f"Migrate command action result: {result}") - - action_id = result.get("action_id") - - if not action_id: - raise Exception(f"Failed to trigger migrate action: {result}") - return action_id - except Exception as e: - raise Exception(f"Failed to execute rladmin migrate: {e}") + return fault_injector.execute_migrate(endpoint_config, target_node, empty_node) @staticmethod - def execute_rladmin_bind_endpoint( + def execute_rebind( fault_injector: FaultInjectorClient, endpoint_config: Dict[str, Any], endpoint_id: str, ) -> str: """Execute rladmin bind endpoint command and wait for completion.""" - command = 
f"bind endpoint {endpoint_id} policy single" - - bdb_id = endpoint_config.get("bdb_id") - - try: - parameters = { - "rladmin_command": command, # Just the command without "rladmin" prefix - "bdb_id": bdb_id, - } - - logging.info(f"Executing rladmin_command with parameter: {parameters}") - action = ActionRequest( - action_type=ActionType.EXECUTE_RLADMIN_COMMAND, parameters=parameters - ) - result = fault_injector.trigger_action(action) - logging.info( - f"Migrate command {command} with parameters {parameters} trigger result: {result}" - ) - - action_id = result.get("action_id") - - if not action_id: - raise Exception(f"Failed to trigger bind endpoint action: {result}") - return action_id - except Exception as e: - raise Exception(f"Failed to execute rladmin bind endpoint: {e}") + return fault_injector.execute_rebind(endpoint_config, endpoint_id) diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 59524ab5c1..7fc7a14d99 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -3,6 +3,7 @@ import os import threading from time import sleep +from typing import Optional import pytest @@ -20,7 +21,7 @@ def trigger_network_failure_action( - fault_injector_client, config, event: threading.Event = None + fault_injector_client, config, event: Optional[threading.Event] = None ): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index 7d99bfe8ae..6ac3d5f394 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -10,7 +10,7 @@ import pytest -from redis import Redis +from redis import Redis, RedisCluster from redis.connection import ConnectionInterface from redis.maint_notifications import ( EndpointType, @@ -21,9 +21,13 @@ CLIENT_TIMEOUT, RELAXED_TIMEOUT, _get_client_maint_notifications, + _get_cluster_client_maint_notifications, + use_mock_proxy, ) from tests.test_scenario.fault_injector_client import ( FaultInjectorClient, + NodeInfo, + ProxyServerFaultInjector, ) from tests.test_scenario.maint_notifications_helpers import ( ClientValidations, @@ -39,78 +43,18 @@ BIND_TIMEOUT = 60 MIGRATE_TIMEOUT = 60 FAILOVER_TIMEOUT = 15 +SMIGRATING_TIMEOUT = 15 +SMIGRATED_TIMEOUT = 15 DEFAULT_BIND_TTL = 15 -class TestPushNotifications: +class TestPushNotificationsBase: """ Test Redis Enterprise maintenance push notifications with real cluster operations. 
""" - @pytest.fixture(autouse=True) - def setup_and_cleanup( - self, - client_maint_notifications: Redis, - fault_injector_client: FaultInjectorClient, - endpoints_config: Dict[str, Any], - endpoint_name: str, - ): - # Initialize cleanup flags first to ensure they exist even if setup fails - self._failover_executed = False - self._migration_executed = False - self._bind_executed = False - self.target_node = None - self.empty_node = None - self.endpoint_id = None - - try: - self.target_node, self.empty_node = ( - ClusterOperations.find_target_node_and_empty_node( - fault_injector_client, endpoints_config - ) - ) - logging.info( - f"Using target_node: {self.target_node}, empty_node: {self.empty_node}" - ) - except Exception as e: - pytest.fail(f"Failed to find target and empty nodes: {e}") - - try: - self.endpoint_id = ClusterOperations.find_endpoint_for_bind( - fault_injector_client, endpoints_config, endpoint_name - ) - logging.info(f"Using endpoint: {self.endpoint_id}") - except Exception as e: - pytest.fail(f"Failed to find endpoint for bind operation: {e}") - - # Ensure setup completed successfully - if not self.target_node or not self.empty_node: - pytest.fail("Setup failed: target_node or empty_node not available") - if not self.endpoint_id: - pytest.fail("Setup failed: endpoint_id not available") - - # Yield control to the test - yield - - # Cleanup code - this will run even if the test fails - logging.info("Starting cleanup...") - try: - client_maint_notifications.close() - except Exception as e: - logging.error(f"Failed to close client: {e}") - - # Only attempt cleanup if we have the necessary attributes and they were executed - if self._failover_executed: - try: - self._execute_failover(fault_injector_client, endpoints_config) - logging.info("Failover cleanup completed") - except Exception as e: - logging.error(f"Failed to revert failover: {e}") - - logging.info("Cleanup finished") - def _execute_failover( self, fault_injector_client: FaultInjectorClient, @@ -130,7 +74,7 @@ def _execute_migration( target_node: str, empty_node: str, ): - migrate_action_id = ClusterOperations.execute_rladmin_migrate( + migrate_action_id = ClusterOperations.execute_migrate( fault_injector=fault_injector_client, endpoint_config=endpoints_config, target_node=target_node, @@ -150,7 +94,7 @@ def _execute_bind( endpoints_config: Dict[str, Any], endpoint_id: str, ): - bind_action_id = ClusterOperations.execute_rladmin_bind_endpoint( + bind_action_id = ClusterOperations.execute_rebind( fault_injector_client, endpoints_config, endpoint_id ) @@ -217,6 +161,7 @@ def _validate_moving_state( configured_endpoint_type: EndpointType, expected_matching_connected_conns_count: int, expected_matching_disconnected_conns_count: int, + fault_injector_client: FaultInjectorClient, ): """Validate the client connections are in the expected state after migration.""" matching_connected_conns_count = 0 @@ -236,7 +181,11 @@ def _validate_moving_state( == MaintNotificationsConfig().get_endpoint_type(conn.host, conn) ) ) + or isinstance( + fault_injector_client, ProxyServerFaultInjector + ) # we should not validate the endpoint type when using proxy server ) + if ( conn._sock is not None and conn._sock.gettimeout() == RELAXED_TIMEOUT @@ -260,7 +209,10 @@ def _validate_moving_state( ) def _validate_default_state( - self, client: Redis, expected_matching_conns_count: int + self, + client: Redis, + expected_matching_conns_count: int, + configured_timeout: float = CLIENT_TIMEOUT, ): """Validate the client connections are in the 
expected state after migration.""" matching_conns_count = 0 @@ -270,12 +222,12 @@ def _validate_default_state( if conn._sock is None: if ( conn.maintenance_state == MaintenanceState.NONE - and conn.socket_timeout == CLIENT_TIMEOUT + and conn.socket_timeout == configured_timeout and conn.host == conn.orig_host_address ): matching_conns_count += 1 elif ( - conn._sock.gettimeout() == CLIENT_TIMEOUT + conn._sock.gettimeout() == configured_timeout and conn.maintenance_state == MaintenanceState.NONE and conn.host == conn.orig_host_address ): @@ -305,6 +257,70 @@ def _validate_default_notif_disabled_state( matching_conns_count += 1 assert matching_conns_count == expected_matching_conns_count + +class TestStandaloneClientPushNotifications(TestPushNotificationsBase): + @pytest.fixture(autouse=True) + def setup_and_cleanup( + self, + client_maint_notifications: Redis, + fault_injector_client: FaultInjectorClient, + endpoints_config: Dict[str, Any], + endpoint_name: str, + ): + # Initialize cleanup flags first to ensure they exist even if setup fails + self._failover_executed = False + self._migration_executed = False + self._bind_executed = False + self.endpoint_id = None + + try: + target_node, empty_node = ClusterOperations.find_target_node_and_empty_node( + fault_injector_client, endpoints_config + ) + logging.info(f"Using target_node: {target_node}, empty_node: {empty_node}") + except Exception as e: + pytest.fail(f"Failed to find target and empty nodes: {e}") + + try: + self.endpoint_id = ClusterOperations.find_endpoint_for_bind( + fault_injector_client, endpoints_config, endpoint_name + ) + logging.info(f"Using endpoint: {self.endpoint_id}") + except Exception as e: + pytest.fail(f"Failed to find endpoint for bind operation: {e}") + + # Ensure setup completed successfully + if not target_node or not empty_node: + pytest.fail("Setup failed: target_node or empty_node not available") + if not self.endpoint_id: + pytest.fail("Setup failed: endpoint_id not available") + + self.target_node: NodeInfo = target_node + self.empty_node: NodeInfo = empty_node + + # Yield control to the test + yield + + # Cleanup code - this will run even if the test fails + logging.info("Starting cleanup...") + try: + client_maint_notifications.close() + except Exception as e: + logging.error(f"Failed to close client: {e}") + + # Only attempt cleanup if we have the necessary attributes and they were executed + if ( + not isinstance(fault_injector_client, ProxyServerFaultInjector) + and self._failover_executed + ): + try: + self._execute_failover(fault_injector_client, endpoints_config) + logging.info("Failover cleanup completed") + except Exception as e: + logging.error(f"Failed to revert failover: {e}") + + logging.info("Cleanup finished") + @pytest.mark.timeout(300) # 5 minutes timeout for this test def test_receive_failing_over_and_failed_over_push_notification( self, @@ -361,6 +377,9 @@ def test_receive_migrating_and_moving_push_notification( Test the push notifications are received when executing cluster operations. 
""" + # create one connection and release it back to the pool + conn = client_maint_notifications.connection_pool.get_connection() + client_maint_notifications.connection_pool.release(conn) logging.info("Executing rladmin migrate command...") migrate_thread = Thread( @@ -369,8 +388,8 @@ def test_receive_migrating_and_moving_push_notification( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -468,8 +487,8 @@ def test_timeout_handling_during_migrating_and_moving( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -510,7 +529,7 @@ def test_timeout_handling_during_migrating_and_moving( logging.info( "Waiting for moving ttl/2 to expire to validate proactive reconnection" ) - time.sleep(8) + time.sleep(fault_injector_client.get_moving_ttl() / 2) logging.info("Validating connections states...") self._validate_moving_state( @@ -518,6 +537,7 @@ def test_timeout_handling_during_migrating_and_moving( endpoint_type, expected_matching_connected_conns_count=0, expected_matching_disconnected_conns_count=3, + fault_injector_client=fault_injector_client, ) # during get_connection() the connection will be reconnected # either to the address provided in the moving notification or to the original address @@ -529,6 +549,7 @@ def test_timeout_handling_during_migrating_and_moving( endpoint_type, expected_matching_connected_conns_count=1, expected_matching_disconnected_conns_count=2, + fault_injector_client=fault_injector_client, ) client.connection_pool.release(conn) @@ -569,8 +590,8 @@ def test_connection_handling_during_moving( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -613,13 +634,14 @@ def test_connection_handling_during_moving( logging.info( "Waiting for moving ttl/2 to expire to validate proactive reconnection" ) - time.sleep(8) + time.sleep(fault_injector_client.get_moving_ttl() / 2) # validate that new connections will also receive the moving notification connections = [] for _ in range(3): connections.append(client.connection_pool.get_connection()) for conn in connections: + logging.debug(f"Releasing connection {conn}. 
{conn.maintenance_state}") client.connection_pool.release(conn) logging.info("Validating connections states during MOVING ...") @@ -633,10 +655,11 @@ def test_connection_handling_during_moving( endpoint_type, expected_matching_connected_conns_count=3, expected_matching_disconnected_conns_count=0, + fault_injector_client=fault_injector_client, ) logging.info("Waiting for moving ttl to expire") - time.sleep(BIND_TIMEOUT) + time.sleep(fault_injector_client.get_moving_ttl()) logging.info("Validating connection states after MOVING has expired ...") self._validate_default_state(client, expected_matching_conns_count=3) @@ -657,6 +680,10 @@ def test_old_connection_shutdown_during_moving( endpoints_config=endpoints_config, endpoint_type=endpoint_type ) + # create one connection and release it back to the pool + conn = client.connection_pool.get_connection() + client.connection_pool.release(conn) + logging.info("Starting migration ...") migrate_thread = Thread( target=self._execute_migration, @@ -664,8 +691,8 @@ def test_old_connection_shutdown_during_moving( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -691,6 +718,15 @@ def execute_commands(moving_event: threading.Event, errors: Queue): f"Command failed in thread {threading.current_thread().name}: {e}" ) + # get the connection here because in case of proxy server + # new connections will not receive the notification and there is a chance + # that the existing connections in the pool that are used in the multiple + # threads might have already consumed the notification + # even with re clusters we might end up with an existing connection that has been + # freed up in the pool that will not receive the notification while we are waiting + # for it because it has already received and processed it + conn_to_check_moving = client.connection_pool.get_connection() + logging.info("Starting rebind...") bind_thread = Thread( target=self._execute_bind, @@ -716,14 +752,23 @@ def execute_commands(moving_event: threading.Event, errors: Queue): logging.info("Waiting for MOVING push notification ...") # this will consume the notification in one of the connections # and will handle the states of the rest - ClientValidations.wait_push_notification(client, timeout=BIND_TIMEOUT) + ClientValidations.wait_push_notification( + client, timeout=BIND_TIMEOUT, connection=conn_to_check_moving + ) # set the event to stop the command execution threads + logging.info("Setting moving event...") moving_event.set() + # release the connection back to the pool so that it can be disconnected + # as part of the flow + client.connection_pool.release(conn_to_check_moving) # Wait for all workers to finish and propagate any exceptions for f in futures: f.result() + logging.info( + "All command execution threads finished. Validating connections states..." 
+ ) # validate that all connections are either disconnected # or connected to the new address connections = self._get_all_connections_in_pool(client) @@ -732,22 +777,27 @@ def execute_commands(moving_event: threading.Event, errors: Queue): assert conn.get_resolved_ip() == conn.host assert conn.maintenance_state == MaintenanceState.MOVING assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.host != conn.orig_host_address + if not isinstance(fault_injector_client, ProxyServerFaultInjector): + assert conn.host != conn.orig_host_address assert not conn.should_reconnect() else: assert conn.maintenance_state == MaintenanceState.MOVING assert conn.socket_timeout == RELAXED_TIMEOUT - assert conn.host != conn.orig_host_address + if not isinstance(fault_injector_client, ProxyServerFaultInjector): + assert conn.host != conn.orig_host_address assert not conn.should_reconnect() # validate no errors were raised in the command execution threads assert errors.empty(), f"Errors occurred in threads: {errors.queue}" logging.info("Waiting for moving ttl to expire") - time.sleep(DEFAULT_BIND_TTL) bind_thread.join() @pytest.mark.timeout(300) # 5 minutes timeout + @pytest.mark.skipif( + use_mock_proxy(), + reason="Mock proxy doesn't support sending notifications to new connections.", + ) def test_new_connections_receive_moving( self, client_maint_notifications: Redis, @@ -764,8 +814,8 @@ def test_new_connections_receive_moving( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -829,6 +879,7 @@ def test_new_connections_receive_moving( endpoint_type, expected_matching_connected_conns_count=1, expected_matching_disconnected_conns_count=0, + fault_injector_client=fault_injector_client, ) logging.info("Waiting for moving thread to be completed ...") @@ -840,6 +891,10 @@ def test_new_connections_receive_moving( client_maint_notifications.connection_pool.release(first_conn) @pytest.mark.timeout(300) # 5 minutes timeout + @pytest.mark.skipif( + use_mock_proxy(), + reason="Mock proxy doesn't support sending notifications to new connections.", + ) def test_new_connections_receive_migrating( self, client_maint_notifications: Redis, @@ -856,8 +911,8 @@ def test_new_connections_receive_migrating( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -927,8 +982,8 @@ def test_disabled_handling_during_migrating_and_moving( args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, ), ) migrate_thread.start() @@ -1040,7 +1095,11 @@ def test_command_execution_during_migrating_and_moving( 3. 
Commands are executed successfully """ errors = Queue() - execution_duration = 180 + if isinstance(fault_injector_client, ProxyServerFaultInjector): + execution_duration = 20 + else: + execution_duration = 180 + socket_timeout = 0.5 client = _get_client_maint_notifications( @@ -1051,35 +1110,274 @@ def test_command_execution_during_migrating_and_moving( enable_maintenance_notifications=True, ) + def execute_commands(duration: int, errors: Queue): + start = time.time() + while time.time() - start < duration: + try: + client.set("key", "value") + client.get("key") + except Exception as e: + logging.error( + f"Error in thread {threading.current_thread().name}: {e}" + ) + errors.put( + f"Command failed in thread {threading.current_thread().name}: {e}" + ) + logging.debug(f"{threading.current_thread().name}: Thread ended") + + threads = [] + for i in range(10): + thread = Thread( + target=execute_commands, + name=f"command_execution_thread_{i}", + args=( + execution_duration, + errors, + ), + ) + thread.start() + threads.append(thread) + migrate_and_bind_thread = Thread( target=self._execute_migrate_bind_flow, name="migrate_and_bind_thread", args=( fault_injector_client, endpoints_config, - self.target_node, - self.empty_node, + self.target_node.node_id, + self.empty_node.node_id, self.endpoint_id, ), ) migrate_and_bind_thread.start() + for thread in threads: + thread.join() + + migrate_and_bind_thread.join() + + # validate connections settings + self._validate_default_state( + client, expected_matching_conns_count=10, configured_timeout=socket_timeout + ) + + assert errors.empty(), f"Errors occurred in threads: {errors.queue}" + + +class TestClusterClientPushNotifications(TestPushNotificationsBase): + @pytest.fixture(autouse=True) + def setup_and_cleanup( + self, + fault_injector_client_oss_api: FaultInjectorClient, + cluster_client_maint_notifications: RedisCluster, + cluster_endpoints_config: Dict[str, Any], + cluster_endpoint_name: str, + ): + # Initialize cleanup flags first to ensure they exist even if setup fails + self._failover_executed = False + self._migration_executed = False + self._bind_executed = False + self.endpoint_id = None + self.target_node = None + self.empty_node = None + + try: + target_node, empty_node = ClusterOperations.find_target_node_and_empty_node( + fault_injector_client_oss_api, cluster_endpoints_config + ) + logging.info(f"Using target_node: {target_node}, empty_node: {empty_node}") + except Exception as e: + pytest.fail(f"Failed to find target and empty nodes: {e}") + + try: + self.endpoint_id = ClusterOperations.find_endpoint_for_bind( + fault_injector_client_oss_api, + cluster_endpoints_config, + cluster_endpoint_name, + ) + logging.info(f"Using endpoint: {self.endpoint_id}") + except Exception as e: + pytest.fail(f"Failed to find endpoint for bind operation: {e}") + + # Ensure setup completed successfully + if not target_node or not empty_node: + pytest.fail("Setup failed: target_node or empty_node not available") + if not self.endpoint_id: + pytest.fail("Setup failed: endpoint_id not available") + + self.target_node = target_node + self.empty_node = empty_node + + # get the cluster topology for the test + cluster_client_maint_notifications.nodes_manager.initialize() + + # Yield control to the test + yield + + # Cleanup code - this will run even if the test fails + logging.info("Starting cleanup...") + try: + cluster_client_maint_notifications.close() + except Exception as e: + logging.error(f"Failed to close client: {e}") + + # Only attempt cleanup if we 
have the necessary attributes and the corresponding operations were actually executed + if ( + not isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector) + and self._failover_executed + ): + try: + self._execute_failover( + fault_injector_client_oss_api, cluster_endpoints_config + ) + logging.info("Failover cleanup completed") + except Exception as e: + logging.error(f"Failed to revert failover: {e}") + + logging.info("Cleanup finished") + + @pytest.mark.timeout(300) # 5 minutes timeout for this test + def test_notification_handling_during_node_fail_over( + self, + cluster_client_maint_notifications: RedisCluster, + fault_injector_client_oss_api: FaultInjectorClient, + cluster_endpoints_config: Dict[str, Any], + ): + """ + Test that SMIGRATING and SMIGRATED push notifications are received and + handled when a primary node of the RE cluster fails over. + """ + logging.info("Creating one connection in the pool.") + # get the node covering the first shard - this is the node that will fail over + target_node = ( + cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) + ) + logging.info(f"Target node for slot 0: {target_node.name}") + conn = target_node.redis_connection.connection_pool.get_connection() + cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) + + logging.info("Executing failover command...") + failover_thread = Thread( + target=self._execute_failover, + name="failover_thread", + args=(fault_injector_client_oss_api, cluster_endpoints_config), + ) + failover_thread.start() + + logging.info("Waiting for SMIGRATING push notifications...") + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATING_TIMEOUT, + connection=conn, + ) + + logging.info("Validating connection maintenance state...") + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False + + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + for node_key in cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + logging.info("Waiting for SMIGRATED push notifications...") + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATED_TIMEOUT, + connection=conn, + ) + + logging.info("Validating connection state after SMIGRATED ...") + # the connection will be dropped: it is marked for disconnect before being + # released back to the pool, and its timeouts and state are intentionally + # not refreshed at that point, so there is no value in asserting on those + # settings here + assert conn.should_reconnect() is True + + # validate that the failed-over node was removed from the cluster: + # RE clusters do not expose replica nodes, so after a failover the old + # primary disappears from the topology and the replica that was promoted + # to primary shows up as a new node + + # the overall number of nodes should be the same - one removed and one added + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + assert ( + target_node.name + not in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + logging.info("Releasing connection back to the pool...") + target_node.redis_connection.connection_pool.release(conn) + + failover_thread.join() +
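The failover test above (like the MOVING tests earlier) relies on ClientValidations.wait_push_notification(), whose implementation is not part of this diff. A minimal sketch of what such a helper could look like, assuming the redis-py Connection.can_read()/read_response() API and treating handler dispatch as a side effect of reading, is shown below; the real helper in the test suite may differ.

import time

def wait_push_notification_sketch(connection, timeout):
    # Illustrative only: poll one explicitly checked-out connection until a
    # message is available or the timeout expires. Reading the response lets
    # the parser run the maintenance push-notification handler as a side effect.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if connection.can_read(timeout=0.1):
            connection.read_response()
            return
        time.sleep(0.1)
    raise TimeoutError("no push notification received within the timeout")

Polling a single, known connection is what allows the tests to pass connection=conn and be certain which connection consumed the notification.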
+ @pytest.mark.timeout(300) # 5 minutes timeout for this test + def test_command_execution_during_node_fail_over( + self, + fault_injector_client_oss_api: FaultInjectorClient, + cluster_endpoints_config: Dict[str, Any], + ): + """ + Test that commands keep executing successfully, and that connections return + to their default settings, while a primary node of the RE cluster fails over. + """ + + errors = Queue() + if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): + execution_duration = 20 + else: + execution_duration = 180 + + socket_timeout = 0.5 + + cluster_client_maint_notifications = _get_cluster_client_maint_notifications( + endpoints_config=cluster_endpoints_config, + disable_retries=True, + socket_timeout=socket_timeout, + enable_maintenance_notifications=True, + ) + def execute_commands(duration: int, errors: Queue): start = time.time() while time.time() - start < duration: try: - client.set("key", "value") - client.get("key") + # this key's slot is covered by the first shard - the shard that will fail over + cluster_client_maint_notifications.set("key:{3}", "value") + cluster_client_maint_notifications.get("key:{3}") + # also execute commands that will run on the second shard + cluster_client_maint_notifications.set("key:{0}", "value") + cluster_client_maint_notifications.get("key:{0}") except Exception as e: + logging.error( + f"Error in thread {threading.current_thread().name}: {e}" + ) errors.put( f"Command failed in thread {threading.current_thread().name}: {e}" ) + logging.debug(f"{threading.current_thread().name}: Thread ended") + + logging.info("Resolving the node covering slot 0 - the node that will fail over.") + # get the node covering the first shard - this is the node that will fail over + target_node = ( + cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) + ) + cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) threads = [] - for _ in range(10): + for i in range(10): thread = Thread( target=execute_commands, - name="command_execution_thread", + name=f"command_execution_thread_{i}", args=( execution_duration, errors, @@ -1088,9 +1386,222 @@ def execute_commands(duration: int, errors: Queue): thread.start() threads.append(thread) + logging.info("Executing failover command...") + failover_thread = Thread( + target=self._execute_failover, + name="failover_thread", + args=(fault_injector_client_oss_api, cluster_endpoints_config), + ) + failover_thread.start() + for thread in threads: thread.join() - migrate_and_bind_thread.join() + failover_thread.join() + # validate that the failed-over primary node was removed from the cluster: + # RE clusters do not expose replica nodes, so after a failover the old + # primary disappears from the topology and the replica that was promoted + # to primary shows up as a new node + + # the overall number of nodes should be the same - one removed and one added + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + assert ( + target_node.name + not in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + for ( + node + ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): + # validate connection settings + self._validate_default_state( + node.redis_connection, + expected_matching_conns_count=10, + configured_timeout=socket_timeout, + ) + + # validate no errors were raised in the command execution threads + assert errors.empty(), f"Errors occurred in threads: {errors.queue}" +
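The hash-tagged keys key:{3} and key:{0} used in execute_commands are what spread the traffic across both shards: only the text inside the {...} hash tag is hashed, so the two keys map to the slots of "3" and "0" respectively. A quick way to confirm the routing, assuming redis-py's redis.crc.key_slot helper and the nodes_manager.get_node_from_slot call already used above:

from redis.crc import key_slot

# Only the hash-tag content is hashed, so "key:{3}" shares a slot with "3".
assert key_slot(b"key:{3}") == key_slot(b"3")
assert key_slot(b"key:{0}") == key_slot(b"0")

# With a cluster client at hand, the owning node can be resolved the same way
# the test resolves its target node (illustrative):
# node_for_3 = cluster_client.nodes_manager.get_node_from_slot(key_slot(b"key:{3}"))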
+ @pytest.mark.timeout(300) # 5 minutes timeout for this test + def test_notification_handling_during_migration_without_node_replacement( + self, + cluster_client_maint_notifications: RedisCluster, + fault_injector_client_oss_api: FaultInjectorClient, + cluster_endpoints_config: Dict[str, Any], + ): + """ + Test that SMIGRATING and SMIGRATED push notifications are received and + handled during a slot migration that does not replace a node. + """ + logging.info("Creating one connection in the pool.") + # get the node covering the first shard - this is the node whose slots will be migrated + target_node = ( + cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) + ) + conn = target_node.redis_connection.connection_pool.get_connection() + cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) + + logging.info("Executing migration command...") + migration_thread = Thread( + target=self._execute_migration, + name="migration_thread", + args=( + fault_injector_client_oss_api, + cluster_endpoints_config, + self.target_node.node_id, + self.empty_node.node_id, + ), + ) + migration_thread.start() + + logging.info("Waiting for SMIGRATING push notifications...") + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATING_TIMEOUT, + connection=conn, + ) + + logging.info("Validating connection maintenance state...") + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False + + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + for node_key in cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + logging.info("Waiting for SMIGRATED push notifications...") + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATED_TIMEOUT, + connection=conn, + ) + + logging.info("Validating connection state after SMIGRATED ...") + + assert conn.should_reconnect() is True + + # the cluster topology should be unchanged - no node is removed or added + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + for node_key in cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + logging.info("Releasing connection back to the pool...") + target_node.redis_connection.connection_pool.release(conn) + + migration_thread.join() +
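Several of the assertions above boil down to comparing two snapshots of nodes_manager.nodes_cache, which is keyed by "host:port" node names. A small illustrative helper (not part of this diff) makes the expected outcomes explicit:

def nodes_cache_diff(before, after):
    # Compare two nodes_cache snapshots by their node-name keys.
    removed = set(before) - set(after)
    added = set(after) - set(before)
    return removed, added

For a slot migration without node replacement both sets should be empty; for a failover, exactly one node should be removed (the old primary) and one added (the promoted replica), keeping the total count unchanged.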
+ """ + + errors = Queue() + if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): + execution_duration = 20 + else: + execution_duration = 180 + + socket_timeout = 0.5 + + cluster_client_maint_notifications = _get_cluster_client_maint_notifications( + endpoints_config=cluster_endpoints_config, + disable_retries=True, + socket_timeout=socket_timeout, + enable_maintenance_notifications=True, + ) + + def execute_commands(duration: int, errors: Queue): + start = time.time() + while time.time() - start < duration: + try: + # the slot is covered by the first shard - this one will have slots migrated + cluster_client_maint_notifications.set("key:{3}", "value") + cluster_client_maint_notifications.get("key:{3}") + # execute also commands that will run on the second shard + cluster_client_maint_notifications.set("key:{0}", "value") + cluster_client_maint_notifications.get("key:{0}") + except Exception as e: + logging.error( + f"Error in thread {threading.current_thread().name}: {e}" + ) + errors.put( + f"Command failed in thread {threading.current_thread().name}: {e}" + ) + logging.debug(f"{threading.current_thread().name}: Thread ended") + + cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) + + threads = [] + for i in range(10): + thread = Thread( + target=execute_commands, + name=f"command_execution_thread_{i}", + args=( + execution_duration, + errors, + ), + ) + thread.start() + threads.append(thread) + + logging.info("Executing failover command...") + migration_thread = Thread( + target=self._execute_migration, + name="migration_thread", + args=( + fault_injector_client_oss_api, + cluster_endpoints_config, + self.target_node.node_id, + self.empty_node.node_id, + ), + ) + migration_thread.start() + + for thread in threads: + thread.join() + + migration_thread.join() + + # validate cluster nodes + assert len(cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + for node_key in cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) + + for ( + node + ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): + # validate connections settings + self._validate_default_state( + node.redis_connection, + expected_matching_conns_count=10, + configured_timeout=socket_timeout, + ) + + # validate no errors were raised in the command execution threads assert errors.empty(), f"Errors occurred in threads: {errors.queue}"