2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -134,6 +134,8 @@ services:
API_PORT: "4000"
ENABLE_LOGGING: true
SIMULATE_CLUSTER: true
DEFAULT_INTERCEPTORS: "cluster,hitless,logger"

ports:
- "15379:15379"
- "15380:15380"
2 changes: 1 addition & 1 deletion redis/_parsers/socket.py
@@ -62,7 +62,7 @@ def _read_from_socket(
sock.settimeout(timeout)
try:
while True:
data = self._sock.recv(socket_read_size)
data = sock.recv(socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
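The one-line fix above makes _read_from_socket receive from the same socket object whose timeout was just configured. A minimal sketch of the pattern, assuming illustrative names rather than redis-py internals:

import socket

def read_chunk(sock: socket.socket, timeout: float, read_size: int = 65536) -> bytes:
    # the timeout set here only governs reads performed on this same object,
    # which is why the hunk switches recv() from self._sock to the sock parameter
    sock.settimeout(timeout)
    return sock.recv(read_size)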
4 changes: 2 additions & 2 deletions redis/cluster.py
@@ -2027,10 +2027,10 @@ def initialize(
# Make sure cluster mode is enabled on this node
try:
cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS"))
# For some cases we might not want to disconnect current pool and
# lose in flight commands responses
if disconnect_startup_nodes_pools:
# Disconnect the connection pool to avoid keeping the connection open
# For some cases we might not want to disconnect current pool and
# lose in flight commands responses
r.connection_pool.disconnect()
except ResponseError:
raise RedisClusterException(
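This hunk only moves the explanatory comment so it sits above the disconnect_startup_nodes_pools check instead of inside it. A hedged one-line sketch of the behaviour the comment documents (only the parameter name is taken from the hunk; the call site is illustrative):

# Skip the startup-node pool disconnect so responses to in-flight commands are not lost.
nodes_manager.initialize(disconnect_startup_nodes_pools=False)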
8 changes: 7 additions & 1 deletion redis/connection.py
@@ -698,7 +698,13 @@ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None)
conn_socket = self._get_socket()
if conn_socket:
timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
conn_socket.settimeout(timeout)
# if the current timeout is 0 it means we are in the middle of a can_read call
# in this case we don't want to change the timeout because the operation
# is non-blocking and should return immediately
# Changing the state from non-blocking to blocking in the middle of a read operation
# will lead to a deadlock
if conn_socket.gettimeout() != 0:
conn_socket.settimeout(timeout)
self.update_parser_timeout(timeout)

def update_parser_timeout(self, timeout: Optional[float] = None):
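The new guard relies on the standard socket timeout states; a short sketch under that assumption (the helper name is illustrative, not redis-py API):

import socket

def relax_timeout_if_safe(sock: socket.socket, relaxed_timeout: float) -> None:
    # gettimeout() returns None for blocking mode, 0 for non-blocking mode
    # (what a can_read-style poll uses), and a positive float for a timed read.
    if sock.gettimeout() != 0:
        # safe to relax: we are not in the middle of a non-blocking poll
        sock.settimeout(relaxed_timeout)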
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -163,6 +163,13 @@ def pytest_addoption(parser):
help="Name of the Redis endpoint the tests should be executed on",
)

parser.addoption(
"--cluster-endpoint-name",
action="store",
default=None,
help="Name of the Redis endpoint with OSS API the tests should be executed on",
)


def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
23 changes: 20 additions & 3 deletions tests/maint_notifications/proxy_server_helpers.py
@@ -1,7 +1,6 @@
import base64
from dataclasses import dataclass
import logging
import re
from typing import Union

from redis.http.http_client import HttpClient, HttpError
@@ -10,6 +9,25 @@
class RespTranslator:
"""Helper class to translate between RESP and other encodings."""

@staticmethod
def re_cluster_maint_notification_to_resp(txt: str) -> str:
"""Convert query to RESP format."""
parts = txt.split()

match parts:
case ["MOVING", seq_id, time, new_host]:
return f">4\r\n+MOVING\r\n:{seq_id}\r\n:{time}\r\n+{new_host}\r\n"
case ["MIGRATING", seq_id, time, shards]:
return f">4\r\n+MIGRATING\r\n:{seq_id}\r\n:{time}\r\n+{shards}\r\n"
case ["MIGRATED", seq_id, shards]:
return f">3\r\n+MIGRATED\r\n:{seq_id}\r\n+{shards}\r\n"
case ["FAILING_OVER", seq_id, time, shards]:
return f">4\r\n+FAILING_OVER\r\n:{seq_id}\r\n:{time}\r\n+{shards}\r\n"
case ["FAILED_OVER", seq_id, shards]:
return f">3\r\n+FAILED_OVER\r\n:{seq_id}\r\n+{shards}\r\n"
case _:
raise NotImplementedError(f"Unknown notification: {txt}")

@staticmethod
def oss_maint_notification_to_resp(txt: str) -> str:
"""Convert query to RESP format."""
@@ -232,8 +250,7 @@ def send_notification(

if not conn_ids:
raise RuntimeError(
f"No connections found for node {node_port}. "
f"Available nodes: {list(set(c.get('node') for c in stats.get('connections', {}).values()))}"
f"No connections found for node {connected_to_port}. \nStats: {stats}"
)

# Send notification to each connection
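A hedged usage sketch of the new translator, with illustrative input values; the expected output follows directly from the MOVING branch above:

frame = RespTranslator.re_cluster_maint_notification_to_resp("MOVING 12 30 10.0.0.5:6379")
# ">4" opens a RESP3 push frame with four elements: the notification name,
# the sequence id, the time value, and the new endpoint.
assert frame == ">4\r\n+MOVING\r\n:12\r\n:30\r\n+10.0.0.5:6379\r\n"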
4 changes: 2 additions & 2 deletions tests/test_asyncio/test_scenario/conftest.py
@@ -17,7 +17,7 @@
from redis.event import AsyncEventListenerInterface, EventDispatcher
from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES
from tests.test_scenario.conftest import get_endpoints_config, extract_cluster_fqdn
from tests.test_scenario.fault_injector_client import FaultInjectorClient
from tests.test_scenario.fault_injector_client import REFaultInjector


class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface):
@@ -31,7 +31,7 @@ async def listen(self, event: AsyncActiveDatabaseChanged):
@pytest.fixture()
def fault_injector_client():
url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
return FaultInjectorClient(url)
return REFaultInjector(url)


@pytest_asyncio.fixture()
116 changes: 111 additions & 5 deletions tests/test_scenario/conftest.py
@@ -6,6 +6,7 @@
from urllib.parse import urlparse

import pytest
from redis import RedisCluster

from redis.backoff import NoBackoff, ExponentialBackoff
from redis.event import EventDispatcher, EventListenerInterface
@@ -22,12 +23,16 @@
from redis.client import Redis
from redis.maint_notifications import EndpointType, MaintNotificationsConfig
from redis.retry import Retry
from tests.test_scenario.fault_injector_client import FaultInjectorClient
from tests.test_scenario.fault_injector_client import (
ProxyServerFaultInjector,
REFaultInjector,
)

RELAXED_TIMEOUT = 30
CLIENT_TIMEOUT = 5

DEFAULT_ENDPOINT_NAME = "m-standard"
DEFAULT_OSS_API_ENDPOINT_NAME = "oss-api"


class CheckActiveDatabaseChangedListener(EventListenerInterface):
@@ -38,13 +43,24 @@ def listen(self, event: ActiveDatabaseChanged):
self.is_changed_flag = True


def use_mock_proxy():
return os.getenv("REDIS_ENTERPRISE_TESTS", "true").lower() == "false"


@pytest.fixture()
def endpoint_name(request):
return request.config.getoption("--endpoint-name") or os.getenv(
"REDIS_ENDPOINT_NAME", DEFAULT_ENDPOINT_NAME
)


@pytest.fixture()
def cluster_endpoint_name(request):
return request.config.getoption("--cluster-endpoint-name") or os.getenv(
"REDIS_CLUSTER_ENDPOINT_NAME", DEFAULT_OSS_API_ENDPOINT_NAME
)


def get_endpoints_config(endpoint_name: str):
endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None)

@@ -67,10 +83,27 @@ def endpoints_config(endpoint_name: str):
return get_endpoints_config(endpoint_name)


@pytest.fixture()
def cluster_endpoints_config(cluster_endpoint_name: str):
return get_endpoints_config(cluster_endpoint_name)


@pytest.fixture()
def fault_injector_client():
url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
return FaultInjectorClient(url)
if use_mock_proxy():
return ProxyServerFaultInjector(oss_cluster=False)
else:
url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
return REFaultInjector(url)


@pytest.fixture()
def fault_injector_client_oss_api():
if use_mock_proxy():
return ProxyServerFaultInjector(oss_cluster=True)
else:
url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
return REFaultInjector(url)


@pytest.fixture()
Expand Down Expand Up @@ -208,8 +241,6 @@ def _get_client_maint_notifications(
endpoint_type=endpoint_type,
)

# Create Redis client with maintenance notifications config
# This will automatically create the MaintNotificationsPoolHandler
if disable_retries:
retry = Retry(NoBackoff(), 0)
else:
@@ -218,6 +249,8 @@
tls_enabled = True if parsed.scheme == "rediss" else False
logging.info(f"TLS enabled: {tls_enabled}")

# Create Redis client with maintenance notifications config
# This will automatically create the MaintNotificationsPoolHandler
client = Redis(
host=host,
port=port,
@@ -235,3 +268,76 @@ def _get_client_maint_notifications(
logging.info(f"Client uses Protocol: {client.connection_pool.get_protocol()}")

return client


@pytest.fixture()
def cluster_client_maint_notifications(cluster_endpoints_config):
return _get_cluster_client_maint_notifications(cluster_endpoints_config)


def _get_cluster_client_maint_notifications(
endpoints_config,
protocol: int = 3,
enable_maintenance_notifications: bool = True,
endpoint_type: Optional[EndpointType] = None,
enable_relaxed_timeout: bool = True,
enable_proactive_reconnect: bool = True,
disable_retries: bool = False,
socket_timeout: Optional[float] = None,
host_config: Optional[str] = None,
):
"""Create Redis cluster client with maintenance notifications enabled."""
# Get credentials from the configuration
username = endpoints_config.get("username")
password = endpoints_config.get("password")

# Parse host and port from endpoints URL
endpoints = endpoints_config.get("endpoints", [])
if not endpoints:
raise ValueError("No endpoints found in configuration")

parsed = urlparse(endpoints[0])
host = parsed.hostname
port = parsed.port

if not host:
raise ValueError(f"Could not parse host from endpoint URL: {endpoints[0]}")

logging.info(f"Connecting to Redis Enterprise: {host}:{port} with user: {username}")

if disable_retries:
retry = Retry(NoBackoff(), 0)
else:
retry = Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3)

tls_enabled = True if parsed.scheme == "rediss" else False
logging.info(f"TLS enabled: {tls_enabled}")

# Configure maintenance notifications
maintenance_config = MaintNotificationsConfig(
enabled=enable_maintenance_notifications,
proactive_reconnect=enable_proactive_reconnect,
relaxed_timeout=RELAXED_TIMEOUT if enable_relaxed_timeout else -1,
endpoint_type=endpoint_type,
)

# Create Redis cluster client with maintenance notifications config
client = RedisCluster(
host=host,
port=port,
socket_timeout=CLIENT_TIMEOUT if socket_timeout is None else socket_timeout,
username=username,
password=password,
ssl=tls_enabled,
ssl_cert_reqs="none",
ssl_check_hostname=False,
protocol=protocol, # RESP3 required for push notifications
maint_notifications_config=maintenance_config,
retry=retry,
)
logging.info("Redis cluster client created with maintenance notifications enabled")
logging.info(
f"Cluster working with the following nodes: {[(node.name, node.server_type) for node in client.get_nodes()]}"
)

return client
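A hedged sketch of how the new fixture might be consumed in a scenario test; the test body is illustrative and not part of this change:

def test_cluster_roundtrip(cluster_client_maint_notifications):
    # protocol=3 without decode_responses, so reads come back as bytes
    client = cluster_client_maint_notifications
    assert client.set("maint-notifications-key", "value")
    assert client.get("maint-notifications-key") == b"value"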