From b26a03ef543f022b6dd8b0db8293006a549bff2c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 27 Aug 2025 16:42:32 +0300 Subject: [PATCH 01/20] Extract additional interfaces and abstract classes --- redis/multidb/circuit.py | 82 ++++++----- redis/multidb/client.py | 25 ++-- redis/multidb/command_executor.py | 152 ++++++++++---------- redis/multidb/config.py | 8 +- redis/multidb/database.py | 100 +++++++------ redis/multidb/event.py | 13 +- redis/multidb/failover.py | 11 +- redis/multidb/failure_detector.py | 1 - tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 +- tests/test_multidb/test_config.py | 10 +- tests/test_multidb/test_failure_detector.py | 12 +- 13 files changed, 225 insertions(+), 209 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 79c8a5f379..221dc556a3 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,8 +45,49 @@ def database(self, database): """Set database associated with this circuit.""" pass +class BaseCircuitBreaker(CircuitBreaker): + """ + Base implementation of Circuit Breaker interface. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + +class SyncCircuitBreaker(CircuitBreaker): + """ + Synchronous implementation of Circuit Breaker interface. + """ @abstractmethod - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -54,7 +95,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[CircuitBreaker, State, State], None], + cb: Callable[[SyncCircuitBreaker, State, State], None], database, ): """ @@ -75,8 +116,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) - -class PBCircuitBreakerAdapter(CircuitBreaker): +class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -87,38 +127,8 @@ def __init__(self, cb: pybreaker.CircuitBreaker): Args: cb: A pybreaker CircuitBreaker instance to be adapted. 
""" - self._cb = cb - self._state_pb_mapper = { - State.CLOSED: self._cb.close, - State.OPEN: self._cb.open, - State.HALF_OPEN: self._cb.half_open, - } - self._database = None - - @property - def grace_period(self) -> float: - return self._cb.reset_timeout - - @grace_period.setter - def grace_period(self, grace_period: float): - self._cb.reset_timeout = grace_period - - @property - def state(self) -> State: - return State(value=self._cb.state.name) - - @state.setter - def state(self, state: State): - self._state_pb_mapper[state]() - - @property - def database(self): - return self._database - - @database.setter - def database(self, database): - self._database = database + super().__init__(cb) - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 56342a7a53..8a0e006977 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,15 +1,12 @@ import threading -import socket from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler -from redis.client import PubSubWorkerThread -from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -92,7 +89,7 @@ def get_databases(self) -> Databases: """ return self._databases - def set_active_database(self, database: AbstractDatabase) -> None: + def set_active_database(self, database: SyncDatabase) -> None: """ Promote one of the existing databases to become an active. """ @@ -115,7 +112,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: raise NoValidDatabaseException('Cannot set active database, database is unhealthy') - def add_database(self, database: AbstractDatabase): + def add_database(self, database: SyncDatabase): """ Adds a new database to the database list. 
""" @@ -129,7 +126,7 @@ def add_database(self, database: AbstractDatabase): self._databases.add(database, database.weight) self._change_active_database(database, highest_weighted_db) - def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: self.command_executor.active_database = new_database @@ -143,7 +140,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: self.command_executor.active_database = highest_weighted_db - def update_database_weight(self, database: AbstractDatabase, weight: float): + def update_database_weight(self, database: SyncDatabase, weight: float): """ Updates a database from the database list. """ @@ -210,7 +207,7 @@ def pubsub(self, **kwargs): return PubSub(self, **kwargs) - def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: + def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None: """ Runs health checks on the given database until first failure. """ @@ -247,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -255,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: CircuitBreaker): +def _half_open_circuit(circuit: SyncCircuitBreaker): circuit.state = CBState.HALF_OPEN @@ -450,8 +447,8 @@ def run_in_thread( exception_handler: Optional[Callable] = None, sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": - return self._client.command_executor.execute_pubsub_run_in_thread( - sleep_time=sleep_time, + return self._client.command_executor.execute_pubsub_run( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=self, diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 094230a31d..364c0a07ea 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Any from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy @@ -17,15 +17,40 @@ 
class CommandExecutor(ABC): @property @abstractmethod - def failure_detectors(self) -> List[FailureDetector]: - """Returns a list of failure detectors.""" + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" pass + @auto_fallback_interval.setter @abstractmethod - def add_failure_detector(self, failure_detector: FailureDetector) -> None: - """Adds new failure detector to the list of failure detectors.""" + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" pass +class BaseCommandExecutor(CommandExecutor): + def __init__( + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + +class SyncCommandExecutor(CommandExecutor): + @property @abstractmethod def databases(self) -> Databases: @@ -34,19 +59,25 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Optional[Database]: - """Returns currently active database.""" + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AbstractDatabase) -> None: - """Sets currently active database.""" + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" pass + @property @abstractmethod - def pubsub(self, **kwargs): - """Initializes a PubSub object on a currently active database""" + def active_database(self) -> Optional[Database]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: SyncDatabase) -> None: + """Sets the currently active database.""" pass @property @@ -69,30 +100,41 @@ def failover_strategy(self) -> FailoverStrategy: @property @abstractmethod - def auto_fallback_interval(self) -> float: - """Returns auto-fallback interval.""" + def command_retry(self) -> Retry: + """Returns command retry object.""" pass - @auto_fallback_interval.setter @abstractmethod - def auto_fallback_interval(self, auto_fallback_interval: float) -> None: - """Sets auto-fallback interval.""" + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" pass - @property @abstractmethod - def command_retry(self) -> Retry: - """Returns command retry object.""" + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" pass @abstractmethod - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" + def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" pass + @abstractmethod + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass -class DefaultCommandExecutor(CommandExecutor): + @abstractmethod + 
def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + +class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): def __init__( self, failure_detectors: List[FailureDetector], @@ -113,22 +155,26 @@ def __init__( event_dispatcher: Interface for dispatching events auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ + super().__init__(auto_fallback_interval) + for fd in failure_detectors: fd.set_command_executor(command_executor=self) - self._failure_detectors = failure_detectors self._databases = databases + self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy = failover_strategy self._event_dispatcher = event_dispatcher - self._auto_fallback_interval = auto_fallback_interval - self._next_fallback_attempt: datetime self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() + @property + def databases(self) -> Databases: + return self._databases + @property def failure_detectors(self) -> List[FailureDetector]: return self._failure_detectors @@ -136,20 +182,16 @@ def failure_detectors(self) -> List[FailureDetector]: def add_failure_detector(self, failure_detector: FailureDetector) -> None: self._failure_detectors.append(failure_detector) - @property - def databases(self) -> Databases: - return self._databases - @property def command_retry(self) -> Retry: return self._command_retry @property - def active_database(self) -> Optional[AbstractDatabase]: + def active_database(self) -> Optional[SyncDatabase]: return self._active_database @active_database.setter - def active_database(self, database: AbstractDatabase) -> None: + def active_database(self, database: SyncDatabase) -> None: old_active = self._active_database self._active_database = database @@ -170,25 +212,13 @@ def active_pubsub(self, pubsub: PubSub) -> None: def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy - @property - def auto_fallback_interval(self) -> float: - return self._auto_fallback_interval - - @auto_fallback_interval.setter - def auto_fallback_interval(self, auto_fallback_interval: int) -> None: - self._auto_fallback_interval = auto_fallback_interval - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) return self._execute_with_failure_detection(callback, args) def execute_pipeline(self, command_stack: tuple): - """ - Executes a stack of commands in pipeline. - """ def callback(): with self._active_database.client.pipeline() as pipe: for command, options in command_stack: @@ -199,18 +229,12 @@ def callback(): return self._execute_with_failure_detection(callback, command_stack) def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): - """ - Executes a transaction block wrapped in callback. - """ def callback(): return self._active_database.client.transaction(transaction, *watches, **options) return self._execute_with_failure_detection(callback) def pubsub(self, **kwargs): - """ - Initializes a PubSub object on a currently active database. 
- """ def callback(): if self._active_pubsub is None: self._active_pubsub = self._active_database.client.pubsub(**kwargs) @@ -220,31 +244,15 @@ def callback(): return self._execute_with_failure_detection(callback) def execute_pubsub_method(self, method_name: str, *args, **kwargs): - """ - Executes given method on active pub/sub. - """ def callback(): method = getattr(self.active_pubsub, method_name) return method(*args, **kwargs) return self._execute_with_failure_detection(callback, *args) - def execute_pubsub_run_in_thread( - self, - pubsub, - sleep_time: float = 0.0, - daemon: bool = False, - exception_handler: Optional[Callable] = None, - sharded_pubsub: bool = False, - ) -> "PubSubWorkerThread": + def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread": def callback(): - return self._active_pubsub.run_in_thread( - sleep_time, - daemon=daemon, - exception_handler=exception_handler, - pubsub=pubsub, - sharded_pubsub=sharded_pubsub - ) + return self._active_pubsub.run_in_thread(sleep_time, **kwargs) return self._execute_with_failure_detection(callback) @@ -280,12 +288,6 @@ def _check_active_database(self): self.active_database = self._failover_strategy.database self._schedule_next_fallback() - def _schedule_next_fallback(self) -> None: - if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: - return - - self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) - def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 5555baec44..a966ec329a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. 
@@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[CircuitBreaker] = None + circuit: Optional[SyncCircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> CircuitBreaker: + def default_circuit_breaker(self) -> SyncCircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index b03e77bd70..75a662d904 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,65 +5,92 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import SyncCircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @property @abstractmethod - def client(self) -> Union[redis.Redis, RedisCluster]: - """The underlying redis client.""" + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" pass - @client.setter + @weight.setter @abstractmethod - def client(self, client: Union[redis.Redis, RedisCluster]): - """Set the underlying redis client.""" + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" pass @property @abstractmethod - def weight(self) -> float: - """The weight of this database in compare to others. Used to determine the database failover to.""" + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" pass - @weight.setter + @health_check_url.setter @abstractmethod - def weight(self, weight: float): - """Set the weight of this database in compare to others.""" + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" pass +class BaseDatabase(AbstractDatabase): + def __init__( + self, + weight: float, + health_check_url: Optional[str] = None, + ): + self._weight = weight + self._health_check_url = health_check_url + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url + +class SyncDatabase(AbstractDatabase): + """Database with an underlying synchronous redis client.""" @property @abstractmethod - def circuit(self) -> CircuitBreaker: - """Circuit breaker for the current database.""" + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" pass - @circuit.setter + @client.setter @abstractmethod - def circuit(self, circuit: CircuitBreaker): - """Set the circuit breaker for the current database.""" + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" pass @property @abstractmethod - def health_check_url(self) -> Optional[str]: - """Health check URL associated with the current database.""" + def circuit(self) -> SyncCircuitBreaker: + """Circuit breaker for the current database.""" pass - @health_check_url.setter + @circuit.setter @abstractmethod - def health_check_url(self, 
health_check_url: Optional[str]): - """Set the health check URL associated with the current database.""" + def circuit(self, circuit: SyncCircuitBreaker): + """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, Number]] +Databases = WeightedList[tuple[SyncDatabase, Number]] -class Database(AbstractDatabase): +class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: CircuitBreaker, + circuit: SyncCircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -79,8 +106,7 @@ def __init__( self._client = client self._cb = circuit self._cb.database = self - self._weight = weight - self._health_check_url = health_check_url + super().__init__(weight, health_check_url) @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,25 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def weight(self) -> float: - return self._weight - - @weight.setter - def weight(self, weight: float): - self._weight = weight - - @property - def circuit(self) -> CircuitBreaker: + def circuit(self) -> SyncCircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: CircuitBreaker): - self._cb = circuit - - @property - def health_check_url(self) -> Optional[str]: - return self._health_check_url - - @health_check_url.setter - def health_check_url(self, health_check_url: Optional[str]): - self._health_check_url = health_check_url + def circuit(self, circuit: SyncCircuitBreaker): + self._cb = circuit \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 2598bc4d06..bca9482347 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,7 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent -from redis.multidb.config import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector class ActiveDatabaseChanged: @@ -11,8 +10,8 @@ class ActiveDatabaseChanged: """ def __init__( self, - old_database: AbstractDatabase, - new_database: AbstractDatabase, + old_database: SyncDatabase, + new_database: SyncDatabase, command_executor, **kwargs ): @@ -22,11 +21,11 @@ def __init__( self._kwargs = kwargs @property - def old_database(self) -> AbstractDatabase: + def old_database(self) -> SyncDatabase: return self._old_database @property - def new_database(self) -> AbstractDatabase: + def new_database(self) -> SyncDatabase: return self._new_database @property @@ -39,7 +38,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe currently active pub/sub to a new active database. + Re-subscribe the currently active pub / sub to a new active database. 
""" def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index d6cf198678..fd08b77ecd 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from redis.data_structure import WeightedList -from redis.multidb.database import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry @@ -13,13 +12,13 @@ class FailoverStrategy(ABC): @property @abstractmethod - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: """Select the database according to the strategy.""" pass @abstractmethod def set_databases(self, databases: Databases) -> None: - """Set the databases strategy operates on.""" + """Set the database strategy operates on.""" pass class WeightBasedFailoverStrategy(FailoverStrategy): @@ -35,7 +34,7 @@ def __init__( self._databases = WeightedList() @property - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), lambda _: dummy_fail() @@ -44,7 +43,7 @@ def database(self) -> AbstractDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases - def _get_active_database(self) -> AbstractDatabase: + def _get_active_database(self) -> SyncDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 3280fa6c32..ef4bd35f69 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -24,7 +24,6 @@ class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. 
""" - def __init__( self, threshold: int, diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index a34ef01476..9503d79d9b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> CircuitBreaker: - return Mock(spec=CircuitBreaker) +def mock_cb() -> SyncCircuitBreaker: + return Mock(spec=SyncCircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 7dc642373b..f5f39c3f6b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 193980d37c..c7c15fe684 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -458,7 +458,7 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK' with pytest.raises(ValueError, match='Given database is not a member of database list'): - 
client.set_active_database(Mock(spec=AbstractDatabase)) + client.set_active_database(Mock(spec=SyncDatabase)) mock_hc.check_health.return_value = False diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 87aae701a9..e428b3ce7a 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1 = Mock(spec=SyncCircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2 = Mock(spec=SyncCircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3 = Mock(spec=SyncCircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=CircuitBreaker) + mock_circuit = Mock(spec=SyncCircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 86d6e1cd82..28687f2a11 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -3,7 +3,7 @@ import pytest -from redis.multidb.command_executor import CommandExecutor +from redis.multidb.command_executor import SyncCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -19,7 +19,7 @@ class TestCommandFailureDetector: ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -41,7 +41,7 @@ def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exce ) def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -62,7 +62,7 @@ def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interv ) def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) 
mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -96,7 +96,7 @@ def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_e ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -128,7 +128,7 @@ def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED From bad9bcc32a69265cf1c5709b41b4867362e8007b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 29 Aug 2025 11:37:36 +0300 Subject: [PATCH 02/20] Added base async components --- redis/asyncio/multidb/__init__.py | 0 redis/asyncio/multidb/circuit.py | 26 +++ redis/asyncio/multidb/command_executor.py | 95 +++++++++++ redis/asyncio/multidb/database.py | 67 ++++++++ redis/asyncio/multidb/event.py | 65 ++++++++ redis/asyncio/multidb/failover.py | 51 ++++++ redis/asyncio/multidb/failure_detector.py | 29 ++++ redis/asyncio/multidb/healthcheck.py | 75 +++++++++ redis/event.py | 3 + redis/multidb/circuit.py | 4 +- redis/utils.py | 6 + tests/test_asyncio/test_multidb/__init__.py | 0 tests/test_asyncio/test_multidb/conftest.py | 59 +++++++ .../test_asyncio/test_multidb/test_circuit.py | 58 +++++++ .../test_multidb/test_failover.py | 121 ++++++++++++++ .../test_multidb/test_failure_detector.py | 153 ++++++++++++++++++ .../test_multidb/test_healthcheck.py | 48 ++++++ 17 files changed, 858 insertions(+), 2 deletions(-) create mode 100644 redis/asyncio/multidb/__init__.py create mode 100644 redis/asyncio/multidb/circuit.py create mode 100644 redis/asyncio/multidb/command_executor.py create mode 100644 redis/asyncio/multidb/database.py create mode 100644 redis/asyncio/multidb/event.py create mode 100644 redis/asyncio/multidb/failover.py create mode 100644 redis/asyncio/multidb/failure_detector.py create mode 100644 redis/asyncio/multidb/healthcheck.py create mode 100644 tests/test_asyncio/test_multidb/__init__.py create mode 100644 tests/test_asyncio/test_multidb/conftest.py create mode 100644 tests/test_asyncio/test_multidb/test_circuit.py create mode 100644 tests/test_asyncio/test_multidb/test_failover.py create mode 100644 tests/test_asyncio/test_multidb/test_failure_detector.py create mode 100644 tests/test_asyncio/test_multidb/test_healthcheck.py diff --git a/redis/asyncio/multidb/__init__.py b/redis/asyncio/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/multidb/circuit.py b/redis/asyncio/multidb/circuit.py new file mode 100644 index 0000000000..97411e6e42 --- /dev/null +++ b/redis/asyncio/multidb/circuit.py @@ -0,0 +1,26 @@ +from abc import abstractmethod +from typing import Callable + +import pybreaker + +from redis.multidb.circuit import CircuitBreaker, State, BaseCircuitBreaker, PBCircuitBreakerAdapter + + +class AsyncCircuitBreaker(CircuitBreaker): + """Async implementation of Circuit Breaker interface.""" + + @abstractmethod + async def on_state_changed(self, cb: 
Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + +class AsyncPBCircuitBreakerAdapter(BaseCircuitBreaker, AsyncCircuitBreaker): + """ + Async adapter for pybreaker's CircuitBreaker implementation. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + super().__init__(cb) + self._sync_cb = PBCircuitBreakerAdapter(cb) + + async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + self._sync_cb.on_state_changed(cb) \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py new file mode 100644 index 0000000000..18117160ee --- /dev/null +++ b/redis/asyncio/multidb/command_executor.py @@ -0,0 +1,95 @@ +from abc import abstractmethod +from typing import List, Optional, Callable, Any + +from redis.asyncio.client import PubSub, Pipeline +from redis.asyncio.multidb.database import Databases, AsyncDatabase +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.retry import Retry +from redis.multidb.command_executor import CommandExecutor + + +class AsyncCommandExecutor(CommandExecutor): + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[AsyncFailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def active_database(self) -> Optional[AsyncDatabase]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: AsyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> AsyncFailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + async def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + async def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + async def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass \ No newline at end of file diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py new file mode 100644 index 0000000000..85320f3aaa --- 
/dev/null +++ b/redis/asyncio/multidb/database.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Union, Optional + +from redis.asyncio import Redis, RedisCluster +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker +from redis.data_structure import WeightedList +from redis.multidb.database import AbstractDatabase, BaseDatabase +from redis.typing import Number + + +class AsyncDatabase(AbstractDatabase): + """Database with an underlying asynchronous redis client.""" + @property + @abstractmethod + def client(self) -> Union[Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def circuit(self) -> AsyncCircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: AsyncCircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AsyncDatabase, Number]] + +class Database(BaseDatabase, AsyncDatabase): + def __init__( + self, + client: Union[Redis, RedisCluster], + circuit: AsyncCircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> AsyncCircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: AsyncCircuitBreaker): + self._cb = circuit + diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py new file mode 100644 index 0000000000..ea5534ce86 --- /dev/null +++ b/redis/asyncio/multidb/event.py @@ -0,0 +1,65 @@ +from typing import List + +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent + + +class AsyncActiveDatabaseChanged: + """ + Event fired when an async active database has been changed. + """ + def __init__( + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AsyncDatabase: + return self._old_database + + @property + def new_database(self) -> AsyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Re-subscribe the currently active pub / sub to a new active database. + """ + async def listen(self, event: AsyncActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. 
+ new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + await new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + await old_pubsub.close() + +class RegisterCommandFailure(AsyncEventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. + """ + def __init__(self, failure_detectors: List[AsyncFailureDetector]): + self._failure_detectors = failure_detectors + + async def listen(self, event: AsyncOnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + await failure_detector.register_failure(event.exception, event.commands) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py new file mode 100644 index 0000000000..ad7f25ce41 --- /dev/null +++ b/redis/asyncio/multidb/failover.py @@ -0,0 +1,51 @@ +from abc import abstractmethod, ABC + +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.data_structure import WeightedList +from redis.multidb.exception import NoValidDatabaseException +from redis.utils import dummy_fail_async + + +class AsyncFailoverStrategy(ABC): + + @property + @abstractmethod + async def database(self) -> AsyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(AsyncFailoverStrategy): + """ + Failover strategy based on database weights. + """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + @property + async def database(self) -> AsyncDatabase: + return await self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: dummy_fail_async() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + async def _get_active_database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') \ No newline at end of file diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py new file mode 100644 index 0000000000..8aa4752924 --- /dev/null +++ b/redis/asyncio/multidb/failure_detector.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from redis.multidb.failure_detector import FailureDetector + + +class AsyncFailureDetector(ABC): + + @abstractmethod + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + +class FailureDetectorAsyncWrapper(AsyncFailureDetector): + """ + Async wrapper for the failure detector. 
+ """ + def __init__(self, failure_detector: FailureDetector) -> None: + self._failure_detector = failure_detector + + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + self._failure_detector.register_failure(exception, cmd) + + def set_command_executor(self, command_executor) -> None: + self._failure_detector.set_command_executor(command_executor) \ No newline at end of file diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py new file mode 100644 index 0000000000..7ae7bf34de --- /dev/null +++ b/redis/asyncio/multidb/healthcheck.py @@ -0,0 +1,75 @@ +import logging +from abc import ABC, abstractmethod + +from redis.asyncio import Redis +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff +from redis.utils import dummy_fail_async + +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + async def check_health(self, database) -> bool: + pass + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + """ + Check database healthiness by sending an echo request. + """ + super().__init__( + retry=retry, + ) + async def check_health(self, database) -> bool: + return await self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: dummy_fail_async() + ) + + async def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = await database.client.execute_command("ECHO" ,"healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = await node.redis_connection.execute_command("ECHO" ,"healthcheck") + + if actual_message not in expected_message: + return False + + return True \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 1fa66f0587..4d167442eb 100644 --- a/redis/event.py +++ b/redis/event.py @@ -271,6 +271,9 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception +class AsyncOnCommandsFailEvent(OnCommandsFailEvent): + pass + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 221dc556a3..576ee27fab 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -87,7 +87,7 @@ class SyncCircuitBreaker(CircuitBreaker): Synchronous implementation of Circuit Breaker interface. 
""" @abstractmethod - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -129,6 +129,6 @@ def __init__(self, cb: pybreaker.CircuitBreaker): """ super().__init__(cb) - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/utils.py b/redis/utils.py index 94bfab61bb..1800582e46 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -314,3 +314,9 @@ def dummy_fail(): Fake function for a Retry object if you don't need to handle each failure. """ pass + +async def dummy_fail_async(): + """ + Async fake function for a Retry object if you don't need to handle each failure. + """ + pass \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/__init__.py b/tests/test_asyncio/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py new file mode 100644 index 0000000000..1f67e3c63c --- /dev/null +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock + +import pytest + +from redis.multidb.circuit import State as CBState +from redis.asyncio import Redis +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker +from redis.asyncio.multidb.database import Database + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> AsyncCircuitBreaker: + return Mock(spec=AsyncCircuitBreaker) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_circuit.py b/tests/test_asyncio/test_multidb/test_circuit.py new file mode 100644 index 0000000000..b1080cfc7d --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_circuit.py @@ -0,0 +1,58 @@ +import pybreaker +import pytest + +from redis.asyncio.multidb.circuit import ( + AsyncPBCircuitBreakerAdapter, + State as CbState, +) +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter + + +class TestAsyncPBCircuitBreaker: + 
@pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + ], + indirect=True, + ) + async def test_cb_correctly_configured(self, mock_db): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) + assert adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + adapter.database = mock_db + assert adapter.database == mock_db + + @pytest.mark.asyncio + async def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + await adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py new file mode 100644 index 0000000000..d7bc4411b6 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -0,0 +1,121 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.retry import Retry + + +class TestAsyncWeightBasedFailoverStrategy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + strategy = WeightBasedFailoverStrategy(retry=retry) + strategy.set_databases(databases) + + assert await strategy.database == mock_db1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 
3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert await failover_strategy.database == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database + + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..3c1eb4fabd --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -0,0 +1,153 @@ +import asyncio +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector + + +class TestFailureDetectorAsyncWrapper: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert 
mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + # 4 more failures as the last one already refreshed timer + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.4) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + 
@pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..fd5c8ec3f0 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -0,0 +1,48 @@ +import pytest +from mock.mock import AsyncMock + +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestEchoHealthCheck: + + @pytest.mark.asyncio + async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. + """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. 
+ """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == False + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 \ No newline at end of file From ae42bea09a097855e05fbf62ca757a85df53af5a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 2 Sep 2025 11:45:57 +0300 Subject: [PATCH 03/20] Added command executor --- redis/asyncio/multidb/command_executor.py | 184 +++++++++++++++++- redis/asyncio/multidb/failover.py | 2 - redis/event.py | 7 +- tests/test_asyncio/test_multidb/conftest.py | 29 ++- .../test_multidb/test_command_executor.py | 165 ++++++++++++++++ .../test_multidb/test_failover.py | 8 +- 6 files changed, 378 insertions(+), 17 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_command_executor.py diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 18117160ee..af10a00988 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,12 +1,18 @@ from abc import abstractmethod +from datetime import datetime from typing import List, Optional, Callable, Any from redis.asyncio.client import PubSub, Pipeline -from redis.asyncio.multidb.database import Databases, AsyncDatabase +from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ + ResubscribeOnActiveDatabaseChanged from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry -from redis.multidb.command_executor import CommandExecutor +from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent +from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL class AsyncCommandExecutor(CommandExecutor): @@ -34,9 +40,8 @@ def active_database(self) -> Optional[AsyncDatabase]: """Returns currently active database.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AsyncDatabase) -> None: + async def set_active_database(self, database: AsyncDatabase) -> None: """Sets the currently active database.""" pass @@ -85,11 +90,176 @@ async def execute_transaction(self, transaction: Callable[[Pipeline], None], *wa pass @abstractmethod - def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): """Executes a given method on active pub/sub.""" pass @abstractmethod - def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: """Executes pub/sub run in a thread.""" - pass \ No newline at 
end of file + pass + + +class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): + def __init__( + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. + + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[AsyncFailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def active_database(self) -> Optional[AsyncDatabase]: + return self._active_database + + async def set_active_database(self, database: AsyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + await self._event_dispatcher.dispatch_async( + AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy(self) -> AsyncFailoverStrategy: + return self._failover_strategy + + @property + def command_retry(self) -> Retry: + return self._command_retry + + async def pubsub(self, **kwargs): + async def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return await self._execute_with_failure_detection(callback) + + async def execute_command(self, *args, **options): + async def callback(): + return await self._active_database.client.execute_command(*args, **options) + + return await self._execute_with_failure_detection(callback, args) + + async def execute_pipeline(self, command_stack: tuple): + async def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + await pipe.execute_command(*command, **options) + + return await pipe.execute() + + return await self._execute_with_failure_detection(callback, command_stack) + + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def callback(): + return 
await self._active_database.client.transaction(transaction, *watches, **options) + + return await self._execute_with_failure_detection(callback) + + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def callback(): + method = getattr(self.active_pubsub, method_name) + return await method(*args, **kwargs) + + return await self._execute_with_failure_detection(callback, *args) + + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def callback(): + return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs) + + return await self._execute_with_failure_detection(callback) + + async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a commands execution callback with failure detection. + """ + async def wrapper(): + # On each retry we need to check active database as it might change. + await self._check_active_database() + return await callback() + + return await self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + async def _check_active_database(self): + """ + Checks if active a database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + await self.set_active_database(await self._failover_strategy.database()) + self._schedule_next_fallback() + + async def _on_command_fail(self, error, *args): + await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + + def _setup_event_dispatcher(self): + """ + Registers necessary listeners. + """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners({ + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [resubscribe_listener], + }) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index ad7f25ce41..a2ed427e05 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -10,7 +10,6 @@ class AsyncFailoverStrategy(ABC): - @property @abstractmethod async def database(self) -> AsyncDatabase: """Select the database according to the strategy.""" @@ -33,7 +32,6 @@ def __init__( self._retry.update_supported_errors([NoValidDatabaseException]) self._databases = WeightedList() - @property async def database(self) -> AsyncDatabase: return await self._retry.call_with_retry( lambda: self._get_active_database(), diff --git a/redis/event.py b/redis/event.py index 4d167442eb..8327ec5f76 100644 --- a/redis/event.py +++ b/redis/event.py @@ -43,7 +43,10 @@ async def dispatch_async(self, event: object): pass @abstractmethod - def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): """Register additional listeners.""" pass @@ -99,7 +102,7 @@ def dispatch(self, event: object): listener.listen(event) async def dispatch_async(self, event: object): - with self._async_lock: + async with self._async_lock: listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: diff --git a/tests/test_asyncio/test_multidb/conftest.py 
b/tests/test_asyncio/test_multidb/conftest.py index 1f67e3c63c..0c4e427264 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,10 +2,14 @@ import pytest +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.asyncio import Redis from redis.asyncio.multidb.circuit import AsyncCircuitBreaker -from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.database import Database, Databases @pytest.fixture() @@ -16,6 +20,18 @@ def mock_client() -> Redis: def mock_cb() -> AsyncCircuitBreaker: return Mock(spec=AsyncCircuitBreaker) +@pytest.fixture() +def mock_fd() -> AsyncFailureDetector: + return Mock(spec=AsyncFailureDetector) + +@pytest.fixture() +def mock_fs() -> AsyncFailoverStrategy: + return Mock(spec=AsyncFailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + @pytest.fixture() def mock_db(request) -> Database: db = Mock(spec=Database) @@ -56,4 +72,13 @@ def mock_db2(request) -> Database: mock_cb.state = cb.get("state", CBState.CLOSED) db.circuit = mock_cb - return db \ No newline at end of file + return db + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..3f64e6aa0b --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -0,0 +1,165 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + await executor.set_active_database(mock_db1) + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + + await executor.set_active_database(mock_db2) + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert 
mock_ed.register_listeners.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command = AsyncMock(side_effect=['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1']) + mock_db2.client.execute_command = AsyncMock(side_effect=['OK2', ConnectionError, ConnectionError, ConnectionError]) + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(threshold, 1)) + ed = EventDispatcher() + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index d7bc4411b6..f692c40643 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -39,7 +39,7 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): strategy = WeightBasedFailoverStrategy(retry=retry) strategy.set_databases(databases) - assert await strategy.database == mock_db1 + assert await strategy.database() == mock_db1 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -67,7 +67,7 @@ async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2 failover_strategy = WeightBasedFailoverStrategy(retry=retry) failover_strategy.set_databases(databases) - assert await failover_strategy.database == mock_db + assert await failover_strategy.database() == mock_db assert state_mock.call_count == 4 @pytest.mark.asyncio @@ -97,7 +97,7 @@ async def test_get_valid_database_throws_exception_with_retries(self, mock_db, m failover_strategy.set_databases(databases) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database + assert await failover_strategy.database() assert state_mock.call_count == 4 @@ -118,4 +118,4 @@ async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock failover_strategy = WeightBasedFailoverStrategy(retry=retry) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database \ No newline at end of file + assert await failover_strategy.database() \ No newline at end of file From 8fc74b96c20e2cfcd4160217afbf05dced520375 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 2 Sep 2025 12:53:16 +0300 Subject: [PATCH 04/20] Added recurring background tasks with event loop only --- redis/asyncio/multidb/config.py | 169 ++++++++++++++++++++++++++++++++ redis/background.py | 52 +++++++++- tests/test_background.py | 33 +++++++ 3 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 redis/asyncio/multidb/config.py diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py new file mode 100644 index 0000000000..9b3588aa06 --- /dev/null +++ b/redis/asyncio/multidb/config.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Type, Union + +import pybreaker + +from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker, AsyncPBCircuitBreakerAdapter +from redis.asyncio.multidb.database import Databases, Database +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper +from redis.asyncio.multidb.healthcheck import 
HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ + EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.failure_detector import CommandFailureDetector + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[AsyncCircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> AsyncCircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return AsyncPBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + failure_threshold: Threshold for determining database failure. + failures_interval: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_retries: Number of retry attempts for performing health checks. + health_check_backoff: Backoff strategy for health check retries. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_retries: Number of retries allowed for failover operations. + failover_backoff: Backoff strategy for failover retries. 
+ auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. + """ + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[AsyncFailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[AsyncFailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. 
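+            # All three construction paths below end up with a zero-retry client:
+            # from_url and the plain constructor pick up the overridden client_kwargs,
+            # while a pre-configured pool has its retry replaced via set_retry().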
+ database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + + if database_config.from_url: + client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + elif database_config.from_pool: + database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) + client = self.client_class.from_pool(connection_pool=database_config.from_pool) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), + database_config.weight + ) + + return databases + + def default_failure_detectors(self) -> List[AsyncFailureDetector]: + return [ + FailureDetectorAsyncWrapper( + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval) + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> AsyncFailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) \ No newline at end of file diff --git a/redis/background.py b/redis/background.py index 6466649859..ce43cbfa7a 100644 --- a/redis/background.py +++ b/redis/background.py @@ -1,6 +1,7 @@ import asyncio import threading -from typing import Callable +from typing import Callable, Coroutine, Any + class BackgroundScheduler: """ @@ -45,7 +46,35 @@ def run_recurring( ) thread.start() - def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + async def run_recurring_async( + self, + interval: float, + coro: Callable[..., Coroutine[Any, Any, Any]], + *args + ): + """ + Runs recurring coroutine with given interval in seconds in the current event loop. + To be used only from an async context. No additional threads are created. + """ + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, coro, *args) + + def tick(): + # Schedule the coroutine + wrapped() + # Schedule next tick + self._next_timer = loop.call_later(interval, tick) + + # Schedule first tick + self._next_timer = loop.call_later(interval, tick) + + def _call_later( + self, + loop: asyncio.AbstractEventLoop, + delay: float, + callback: Callable, + *args + ): self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( @@ -86,4 +115,21 @@ def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon """ asyncio.set_event_loop(event_loop) event_loop.call_soon(call_soon_cb, event_loop, *args) - event_loop.run_forever() \ No newline at end of file + event_loop.run_forever() + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. 
+ """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped \ No newline at end of file diff --git a/tests/test_background.py b/tests/test_background.py index 4b3a5377c1..ba62e5bdd9 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,3 +1,4 @@ +import asyncio from time import sleep import pytest @@ -57,4 +58,36 @@ def callback(arg1: str, arg2: int): sleep(timeout) + assert execute_counter == call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + async def test_run_recurring_async(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + async def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + await scheduler.run_recurring_async(interval, callback, one, two) + assert execute_counter == 0 + + await asyncio.sleep(timeout) + assert execute_counter == call_count \ No newline at end of file From 97c3cde72dad63a46e00f6e6425e690206a2daaa Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 3 Sep 2025 12:21:00 +0300 Subject: [PATCH 05/20] Added MultiDBClient --- redis/asyncio/multidb/circuit.py | 26 - redis/asyncio/multidb/client.py | 237 +++++++++ redis/asyncio/multidb/command_executor.py | 1 + redis/asyncio/multidb/config.py | 10 +- redis/asyncio/multidb/database.py | 12 +- redis/multidb/circuit.py | 13 +- redis/multidb/client.py | 6 +- redis/multidb/config.py | 8 +- redis/multidb/database.py | 12 +- tests/test_asyncio/test_multidb/conftest.py | 38 +- .../test_asyncio/test_multidb/test_circuit.py | 58 --- .../test_asyncio/test_multidb/test_client.py | 471 ++++++++++++++++++ tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 - tests/test_multidb/test_config.py | 10 +- 16 files changed, 784 insertions(+), 138 deletions(-) delete mode 100644 redis/asyncio/multidb/circuit.py create mode 100644 redis/asyncio/multidb/client.py delete mode 100644 tests/test_asyncio/test_multidb/test_circuit.py create mode 100644 tests/test_asyncio/test_multidb/test_client.py diff --git a/redis/asyncio/multidb/circuit.py b/redis/asyncio/multidb/circuit.py deleted file mode 100644 index 97411e6e42..0000000000 --- a/redis/asyncio/multidb/circuit.py +++ /dev/null @@ -1,26 +0,0 @@ -from abc import abstractmethod -from typing import Callable - -import pybreaker - -from redis.multidb.circuit import CircuitBreaker, State, BaseCircuitBreaker, PBCircuitBreakerAdapter - - -class AsyncCircuitBreaker(CircuitBreaker): - """Async implementation of Circuit Breaker interface.""" - - @abstractmethod - async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): - """Callback called when the state of the circuit changes.""" - pass - -class AsyncPBCircuitBreakerAdapter(BaseCircuitBreaker, AsyncCircuitBreaker): - """ - Async adapter for pybreaker's CircuitBreaker implementation. 
- """ - def __init__(self, cb: pybreaker.CircuitBreaker): - super().__init__(cb) - self._sync_cb = PBCircuitBreakerAdapter(cb) - - async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): - self._sync_cb.on_state_changed(cb) \ No newline at end of file diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py new file mode 100644 index 0000000000..dbf03a3ef4 --- /dev/null +++ b/redis/asyncio/multidb/client.py @@ -0,0 +1,237 @@ +import asyncio +from typing import Callable, Optional, Coroutine, Any + +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.background import BackgroundScheduler +from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands +from redis.multidb.exception import NoValidDatabaseException + + +class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors([ConnectionRefusedError]) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = asyncio.Lock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + + async def initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + async def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + await self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + await self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. 
+ database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + await self.command_executor.set_active_database(database) + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + async def set_active_database(self, database: AsyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + await self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + await self.command_executor.set_active_database(database) + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + async def add_database(self, database: AsyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + await self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + await self._change_active_database(database, highest_weighted_db) + + async def _change_active_database(self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(new_database) + + async def remove_database(self, database: AsyncDatabase): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(highest_weighted_db) + + async def update_database_weight(self, database: AsyncDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + await self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: AsyncFailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + async def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + async with self._hc_lock: + self._health_checks.append(healthcheck) + + async def execute_command(self, *args, **options): + """ + Executes a single command and return its result. 
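+        Initializes the client lazily on the first call if it has not been initialized yet.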
+ """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_command(*args, **options) + + async def _check_databases_health( + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + for database, _ in self._databases: + async with self._hc_lock: + await self._check_db_health(database, on_error) + + async def _check_db_health( + self, + database: AsyncDatabase, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ) -> None: + """ + Runs health checks on the given database until first failure. + """ + is_healthy = True + + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = await health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except Exception as e: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + if on_error: + await on_error(e) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + loop = asyncio.get_running_loop() + + if new_state == CBState.HALF_OPEN: + asyncio.create_task(self._check_db_health(circuit.database)) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index af10a00988..22aef83118 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -248,6 +248,7 @@ async def _check_active_database(self): ) ): await self.set_active_database(await self._failover_strategy.database()) + print("Active database now with weight {}", format(self._active_database.weight)) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index 9b3588aa06..b5f4a0658d 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -4,7 +4,6 @@ import pybreaker from redis.asyncio import ConnectionPool, Redis, RedisCluster -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker, AsyncPBCircuitBreakerAdapter from redis.asyncio.multidb.database import Databases, Database from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper @@ -14,6 +13,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.failure_detector import CommandFailureDetector DEFAULT_GRACE_PERIOD = 5.0 @@ -43,7 +43,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. 
from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -56,13 +56,13 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[AsyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> AsyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) - return AsyncPBCircuitBreakerAdapter(circuit_breaker) + return PBCircuitBreakerAdapter(circuit_breaker) @dataclass class MultiDbConfig: diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py index 85320f3aaa..6afbbbf5ea 100644 --- a/redis/asyncio/multidb/database.py +++ b/redis/asyncio/multidb/database.py @@ -2,8 +2,8 @@ from typing import Union, Optional from redis.asyncio import Redis, RedisCluster -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker from redis.multidb.database import AbstractDatabase, BaseDatabase from redis.typing import Number @@ -24,13 +24,13 @@ def client(self, client: Union[Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> AsyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: AsyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -40,7 +40,7 @@ class Database(BaseDatabase, AsyncDatabase): def __init__( self, client: Union[Redis, RedisCluster], - circuit: AsyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -58,10 +58,10 @@ def client(self, client: Union[Redis, RedisCluster]): self._client = client @property - def circuit(self) -> AsyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: AsyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 576ee27fab..8f904c0e4b 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,6 +45,11 @@ def database(self, database): """Set database associated with this circuit.""" pass + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + class BaseCircuitBreaker(CircuitBreaker): """ Base implementation of Circuit Breaker interface. @@ -82,10 +87,6 @@ def database(self): def database(self, database): self._database = database -class SyncCircuitBreaker(CircuitBreaker): - """ - Synchronous implementation of Circuit Breaker interface. 
- """ @abstractmethod def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" @@ -95,7 +96,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[SyncCircuitBreaker, State, State], None], + cb: Callable[[CircuitBreaker, State, State], None], database, ): """ @@ -116,7 +117,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) -class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): +class PBCircuitBreakerAdapter(BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8a0e006977..71e079346a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -5,7 +5,7 @@ from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector @@ -244,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -252,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_sta if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: SyncCircuitBreaker): +def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/config.py b/redis/multidb/config.py index a966ec329a..fc349ed04b 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. 
+ circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[SyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> SyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 75a662d904..9c2ffe3552 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,7 +5,7 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import SyncCircuitBreaker +from redis.multidb.circuit import CircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @@ -74,13 +74,13 @@ def client(self, client: Union[redis.Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -90,7 +90,7 @@ class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: SyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -117,9 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index 0c4e427264..0ac231cf52 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,13 +2,14 @@ import pytest +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ + DatabaseConfig from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.asyncio.multidb.healthcheck import HealthCheck from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio import Redis -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker from redis.asyncio.multidb.database import Database, Databases @@ -17,8 +18,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> AsyncCircuitBreaker: - return Mock(spec=AsyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> 
AsyncFailureDetector: @@ -39,7 +40,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -53,7 +54,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -67,13 +68,36 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) db.circuit = mock_cb return db +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + def create_weighted_list(*databases) -> Databases: dbs = WeightedList() diff --git a/tests/test_asyncio/test_multidb/test_circuit.py b/tests/test_asyncio/test_multidb/test_circuit.py deleted file mode 100644 index b1080cfc7d..0000000000 --- a/tests/test_asyncio/test_multidb/test_circuit.py +++ /dev/null @@ -1,58 +0,0 @@ -import pybreaker -import pytest - -from redis.asyncio.multidb.circuit import ( - AsyncPBCircuitBreakerAdapter, - State as CbState, -) -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter - - -class TestAsyncPBCircuitBreaker: - @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, - ], - indirect=True, - ) - async def test_cb_correctly_configured(self, mock_db): - pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) - adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) - assert adapter.state == CbState.CLOSED - - adapter.state = CbState.OPEN - assert adapter.state == CbState.OPEN - - adapter.state = CbState.HALF_OPEN - assert adapter.state == CbState.HALF_OPEN - - adapter.state = CbState.CLOSED - assert adapter.state == CbState.CLOSED - - assert adapter.grace_period == 5 - adapter.grace_period = 10 - - assert adapter.grace_period == 10 - - adapter.database = mock_db - assert adapter.database == mock_db - - @pytest.mark.asyncio - async def test_cb_executes_callback_on_state_changed(self): - pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) - adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) - called_count = 0 - - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): - nonlocal called_count - assert old_state == CbState.CLOSED - assert new_state == CbState.HALF_OPEN - assert isinstance(cb, PBCircuitBreakerAdapter) - called_count += 1 - - await adapter.on_state_changed(callback) - adapter.state = CbState.HALF_OPEN - - assert called_count == 1 \ 
No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py new file mode 100644 index 0000000000..c2fe914e9f --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -0,0 +1,471 @@ +import asyncio +from unittest.mock import patch, AsyncMock, Mock + +import pybreaker +import pytest + +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck +from redis.asyncio.retry import Retry +from redis.event import EventDispatcher, AsyncOnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.exception import NoValidDatabaseException +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED 
+ assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + 
client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.22) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + await client.set('key', 'value') + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + await client.add_database(mock_db) + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK2' + assert mock_hc.check_health.call_count == 2 + + await client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 3 + + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio 
+ @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.remove_database(mock_db1) + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = AsyncOnCommandsFailEvent( + commands=('SET', 'key', 'value'), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') 
== 'OK1' + assert mock_hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 4 + assert another_hc.check_health.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.set_active_database(mock_db) + assert await client.set('key', 'value') == 'OK' + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + await client.set_active_database(Mock(spec=AsyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + await client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 9503d79d9b..0c082f0f17 100644 --- 
a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> SyncCircuitBreaker: - return Mock(spec=SyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index f5f39c3f6b..7dc642373b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c7c15fe684..d352c1da92 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -166,13 +166,9 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) - assert client.set('key', 'value') == 'OK2' - sleep(0.22) - assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index e428b3ce7a..1ea63a0e14 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.config import MultiDbConfig, 
DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=SyncCircuitBreaker) + mock_cb1 = Mock(spec=CircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=SyncCircuitBreaker) + mock_cb2 = Mock(spec=CircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=SyncCircuitBreaker) + mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=SyncCircuitBreaker) + mock_circuit = Mock(spec=CircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit From e376544c55f97d7b3ca336402b42920c59afc33a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 10:24:20 +0300 Subject: [PATCH 06/20] Added scenario and config tests --- redis/asyncio/multidb/client.py | 4 +- redis/asyncio/multidb/command_executor.py | 1 - redis/event.py | 5 +- .../test_asyncio/test_multidb/test_config.py | 125 ++++++++++++++++++ tests/test_asyncio/test_scenario/__init__.py | 0 tests/test_asyncio/test_scenario/conftest.py | 88 ++++++++++++ .../test_scenario/test_active_active.py | 59 +++++++++ 7 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_config.py create mode 100644 tests/test_asyncio/test_scenario/__init__.py create mode 100644 tests/test_asyncio/test_scenario/conftest.py create mode 100644 tests/test_asyncio/test_scenario/test_active_active.py diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index dbf03a3ef4..73eafd9026 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -61,10 +61,10 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. 
- await self._bg_scheduler.run_recurring_async( + asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, - ) + )) is_active_db_found = False diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 22aef83118..af10a00988 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -248,7 +248,6 @@ async def _check_active_database(self): ) ): await self.set_active_database(await self._failover_strategy.database()) - print("Active database now with weight {}", format(self._active_database.weight)) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/event.py b/redis/event.py index 8327ec5f76..de38e1a069 100644 --- a/redis/event.py +++ b/redis/event.py @@ -108,7 +108,10 @@ async def dispatch_async(self, event: object): for listener in listeners: await listener.listen(event) - def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + event_listeners: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): with self._lock: for event_type in event_listeners: if event_type in self._event_listeners_mapping: diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py new file mode 100644 index 0000000000..64760740a1 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -0,0 +1,125 @@ +from unittest.mock import Mock + +from redis.asyncio import ConnectionPool +from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_GRACE_PERIOD, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper, AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.retry import Retry +from redis.multidb.circuit import CircuitBreaker + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i+=1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], FailureDetectorAsyncWrapper) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period 
= 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=AsyncFailureDetector), Mock(spec=AsyncFailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), CircuitBreaker) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/__init__.py b/tests/test_asyncio/test_scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py new file mode 100644 index 0000000000..312712ba05 --- /dev/null +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.multidb.client import 
MultiDBClient
+from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \
+    MultiDbConfig
+from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged
+from redis.asyncio.retry import Retry
+from redis.backoff import ExponentialBackoff
+from redis.event import AsyncEventListenerInterface, EventDispatcher
+from tests.test_scenario.conftest import get_endpoint_config, extract_cluster_fqdn
+from tests.test_scenario.fault_injector_client import FaultInjectorClient
+
+
+class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface):
+    def __init__(self):
+        self.is_changed_flag = False
+
+    async def listen(self, event: AsyncActiveDatabaseChanged):
+        self.is_changed_flag = True
+
+@pytest.fixture()
+def fault_injector_client():
+    url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
+    return FaultInjectorClient(url)
+
+@pytest.fixture()
+def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]:
+    client_class = request.param.get('client_class', Redis)
+
+    if client_class == Redis:
+        endpoint_config = get_endpoint_config('re-active-active')
+    else:
+        endpoint_config = get_endpoint_config('re-active-active-oss-cluster')
+
+    username = endpoint_config.get('username', None)
+    password = endpoint_config.get('password', None)
+    failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD)
+    command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10))
+
+    # Retry configuration differs for health checks, as the initial health check requires more time
+    # if the infrastructure wasn't restored from the previous test.
+    health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL)
+    event_dispatcher = EventDispatcher()
+    listener = CheckActiveDatabaseChangedListener()
+    event_dispatcher.register_listeners({
+        AsyncActiveDatabaseChanged: [listener],
+    })
+    db_configs = []
+
+    db_config = DatabaseConfig(
+        weight=1.0,
+        from_url=endpoint_config['endpoints'][0],
+        client_kwargs={
+            'username': username,
+            'password': password,
+            'decode_responses': True,
+        },
+        health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0])
+    )
+    db_configs.append(db_config)
+
+    db_config1 = DatabaseConfig(
+        weight=0.9,
+        from_url=endpoint_config['endpoints'][1],
+        client_kwargs={
+            'username': username,
+            'password': password,
+            'decode_responses': True,
+        },
+        health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1])
+    )
+    db_configs.append(db_config1)
+
+    config = MultiDbConfig(
+        client_class=client_class,
+        databases_config=db_configs,
+        command_retry=command_retry,
+        failure_threshold=failure_threshold,
+        health_check_retries=3,
+        health_check_interval=health_check_interval,
+        event_dispatcher=event_dispatcher,
+        health_check_backoff=ExponentialBackoff(cap=5, base=0.5),
+    )
+
+    return MultiDBClient(config), listener, endpoint_config
\ No newline at end of file
diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py
new file mode 100644
index 0000000000..833bb0776f
--- /dev/null
+++ b/tests/test_asyncio/test_scenario/test_active_active.py
@@ -0,0 +1,59 @@
+import asyncio
+import logging
+from time import sleep
+
+import pytest
+
+from tests.test_scenario.fault_injector_client import ActionRequest, ActionType
+
+logger = logging.getLogger(__name__)
+
+async def
trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result['action_id']) + + while status_result['status'] != "success": + await asyncio.sleep(0.1) + status_result = fault_injector_client.get_action_status(result['action_id']) + logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + + if event: + event.set() + + logger.info(f"Action completed. Status: {status_result['status']}") + +class TestActiveActive: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(5) + + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on the first command. + await r_multi_db.set('key', 'value') + + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) \ No newline at end of file From 57f6d8bb82bbb50213e8f3f238264f30436cc6a0 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 12:11:39 +0300 Subject: [PATCH 07/20] Added pipeline and transaction support for MultiDBClient --- redis/asyncio/multidb/client.py | 114 ++++++- redis/asyncio/multidb/command_executor.py | 24 +- .../test_multidb/test_pipeline.py | 321 ++++++++++++++++++ tests/test_asyncio/test_scenario/conftest.py | 8 +- .../test_scenario/test_active_active.py | 130 +++++++ 5 files changed, 585 insertions(+), 12 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_pipeline.py diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 73eafd9026..1025c4b37b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,5 +1,5 @@ import asyncio -from typing import Callable, Optional, Coroutine, Any +from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases @@ -10,6 +10,7 @@ from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands from redis.multidb.exception import NoValidDatabaseException +from redis.typing import KeyT class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): @@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig): self._hc_lock = asyncio.Lock() self._bg_scheduler = BackgroundScheduler() self._config = config + self._hc_task = None + self._half_open_state_task = None + + async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": + if not self.initialized: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._hc_task: + self._hc_task.cancel() + if self._half_open_state_task: 
+ self._half_open_state_task.cancel() async def initialize(self): """ @@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. - asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, )) @@ -180,6 +194,34 @@ async def execute_command(self, *args, **options): return await self.command_executor.execute_command(*args, **options) + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Executes callable as transaction. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay, + ) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: loop = asyncio.get_running_loop() if new_state == CBState.HALF_OPEN: - asyncio.create_task(self._check_db_health(circuit.database)) + self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database)) return if old_state == CBState.CLOSED and new_state == CBState.OPEN: loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN \ No newline at end of file + circuit.state = CBState.HALF_OPEN + +class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. + """ + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + async def __aenter__(self: "Pipeline") -> "Pipeline": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + await self._client.__aexit__(exc_type, exc_value, traceback) + + def __await__(self): + return self._async_self().__await__() + + async def _async_self(self): + return self + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + async def reset(self) -> None: + self._command_stack = [] + + async def aclose(self) -> None: + """Close the pipeline""" + await self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. 
+ """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" + return self.pipeline_execute_command(*args, **kwargs) + + async def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" + if not self._client.initialized: + await self._client.initialize() + + try: + return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + finally: + await self.reset() \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index af10a00988..4133dba394 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,6 +1,6 @@ from abc import abstractmethod from datetime import datetime -from typing import List, Optional, Callable, Any +from typing import List, Optional, Callable, Any, Union, Awaitable from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database @@ -13,6 +13,7 @@ from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.typing import KeyT class AsyncCommandExecutor(CommandExecutor): @@ -194,17 +195,30 @@ async def callback(): async def execute_pipeline(self, command_stack: tuple): async def callback(): - with self._active_database.client.pipeline() as pipe: + async with self._active_database.client.pipeline() as pipe: for command, options in command_stack: - await pipe.execute_command(*command, **options) + pipe.execute_command(*command, **options) return await pipe.execute() return await self._execute_with_failure_detection(callback, command_stack) - async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def execute_transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): async def callback(): - return await self._active_database.client.transaction(transaction, *watches, **options) + return await self._active_database.client.transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay + ) return await self._execute_with_failure_detection(callback) diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..5af2e3e864 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -0,0 +1,321 @@ +import asyncio +from unittest.mock import Mock, AsyncMock, patch + +import pybreaker +import pytest + +from redis.asyncio.client import Pipeline +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF +from redis.asyncio.retry import Retry +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +def mock_pipe() -> 
Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_pipe.__aexit__ = AsyncMock(return_value=None) + return mock_pipe + +class TestPipeline: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async with client.pipeline() as pipe: + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + 
mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + pipe = mock_pipe() + pipe.execute.return_value = ['OK', 'value'] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ['OK1', 'value'] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ['OK2', 'value'] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + + await asyncio.sleep(0.15) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK2', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + +class TestTransaction: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + mock_db.client.transaction.return_value = ['OK', 'value'] + mock_db1.client.transaction.return_value = ['OK1', 'value'] + mock_db2.client.transaction.return_value = ['OK2', 'value'] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value'] + await asyncio.sleep(0.15) + assert await client.transaction(callback) == ['OK2', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 312712ba05..18bc8f1417 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -1,6 +1,7 @@ import os import pytest +import pytest_asyncio from redis.asyncio 
import Redis from redis.asyncio.multidb.client import MultiDBClient @@ -26,8 +27,8 @@ def fault_injector_client(): url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") return FaultInjectorClient(url) -@pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: +@pytest_asyncio.fixture() +async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -85,4 +86,5 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + async with MultiDBClient(config) as client: + return client, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 833bb0776f..76db322253 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -4,6 +4,7 @@ import pytest +from redis.asyncio.client import Pipeline from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -33,6 +34,7 @@ def teardown_method(self, method): # Timeout so the cluster could recover from network failure. sleep(5) + @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", [{"failure_threshold": 2}], @@ -56,4 +58,132 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in # Execute commands until database failover while not listener.is_changed_flag: assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. 
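            # Illustrative sketch of what the block below relies on, assuming an
            # initialized MultiDBClient `client`: the multi-DB Pipeline only buffers
            # commands locally and replays the whole stack against the currently
            # active database when execute() is awaited, resetting the stack afterwards.
            #
            #     async with client.pipeline() as pipe:
            #         pipe.set('key', 'value')        # buffered, nothing sent yet
            #         pipe.get('key')                 # buffered
            #         results = await pipe.execute()  # replayed on the active database
            #         # e.g. results == [True, 'value'] when the database is healthy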
+ async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + async def callback(pipe: Pipeline): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + + # Client initialized on first transaction execution. 
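        # Illustrative sketch, assuming an initialized MultiDBClient `client`:
        # transaction() hands the callback (plus any watched keys and keyword
        # arguments such as value_from_callable or watch_delay) to the active
        # database's transaction() helper, so the callback receives a plain
        # asyncio Pipeline.
        #
        #     async def txn(pipe: Pipeline):
        #         pipe.set('{hash}counter', 1)
        #         pipe.get('{hash}counter')
        #
        #     results = await client.transaction(txn)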
+ await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute transaction before network failure + while not event.is_set(): + await r_multi_db.transaction(callback) + await asyncio.sleep(0.5) + + # Execute transaction until database failover + while not listener.is_changed_flag: + await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] await asyncio.sleep(0.5) \ No newline at end of file From 25eebb96824fdbbb56edc0957cd910408315b11c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 16:01:58 +0300 Subject: [PATCH 08/20] Added pub/sub support for MultiDBClient --- redis/asyncio/client.py | 12 +- redis/asyncio/multidb/client.py | 135 +++++++++++++++++- redis/asyncio/multidb/command_executor.py | 12 +- redis/multidb/client.py | 8 +- .../test_scenario/test_active_active.py | 42 +++++- 5 files changed, 191 insertions(+), 18 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index aac409073f..4c000bd2e7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1191,6 +1191,7 @@ async def run( *, exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, + pubsub = None ) -> None: """Process pub/sub messages using registered callbacks. @@ -1215,9 +1216,14 @@ async def run( await self.connect() while True: try: - await self.get_message( - ignore_subscribe_messages=True, timeout=poll_timeout - ) + if pubsub is None: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + else: + await pubsub.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) except asyncio.CancelledError: raise except BaseException as e: diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 1025c4b37b..7c0bef4f6e 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,6 +1,7 @@ import asyncio from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable +from redis.asyncio.client import PubSubHandler from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.asyncio.multidb.failure_detector import AsyncFailureDetector @@ -10,7 +11,7 @@ from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands from redis.multidb.exception import NoValidDatabaseException -from redis.typing import KeyT +from redis.typing import KeyT, EncodableT, ChannelT class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): @@ -222,6 +223,17 @@ async def transaction( watch_delay=watch_delay, ) + async def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + if not self.initialized: + await self.initialize() + + return PubSub(self, **kwargs) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -340,4 +352,123 @@ async def execute(self) -> List[Any]: try: return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) finally: - await self.reset() \ No newline at end of file + await self.reset() + +class PubSub: + """ + PubSub object for multi database client. + """ + def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. 
+ + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + + self._client = client + self._client.command_executor.pubsub(**kwargs) + + async def __aenter__(self) -> "PubSub": + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.aclose() + + async def aclose(self): + return await self._client.command_executor.execute_pubsub_method('aclose') + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + async def execute_command(self, *args: EncodableT): + return await self._client.command_executor.execute_pubsub_method('execute_command', *args) + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'psubscribe', + *args, + **kwargs + ) + + async def punsubscribe(self, *args: ChannelT): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return await self._client.command_executor.execute_pubsub_method( + 'punsubscribe', + *args + ) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'subscribe', + *args, + **kwargs + ) + + async def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return await self._client.command_executor.execute_pubsub_method( + 'unsubscribe', + *args + ) + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number or None to wait indefinitely. + """ + return await self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. 
To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + return await self._client.command_executor.execute_pubsub_run( + exception_handler=exception_handler, + sleep_time=poll_timeout, + pubsub=self + ) \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 4133dba394..f7ae0e717b 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -178,14 +178,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy: def command_retry(self) -> Retry: return self._command_retry - async def pubsub(self, **kwargs): - async def callback(): - if self._active_pubsub is None: - self._active_pubsub = self._active_database.client.pubsub(**kwargs) - self._active_pubsub_kwargs = kwargs - return None - - return await self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs async def execute_command(self, *args, **options): async def callback(): diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 71e079346a..e6b815c76f 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs): def __enter__(self) -> "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.reset() - def __del__(self) -> None: try: # if this object went out of scope prior to shutting down @@ -350,7 +347,7 @@ def __del__(self) -> None: pass def reset(self) -> None: - pass + return self._client.command_executor.execute_pubsub_method('reset') def close(self) -> None: self.reset() @@ -359,6 +356,9 @@ def close(self) -> None: def subscribed(self) -> bool: return self._client.command_executor.active_pubsub.subscribed + def execute_command(self, *args): + return self._client.command_executor.execute_pubsub_method('execute_command', *args) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. 
Patterns supplied as keyword arguments diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 76db322253..93068f6756 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from time import sleep @@ -186,4 +187,43 @@ async def callback(pipe: Pipeline): # Execute transaction until database failover while not listener.is_changed_flag: await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) \ No newline at end of file + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + data = json.dumps({'message': 'test'}) + messages_count = 0 + + async def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = await r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. + await pubsub.subscribe(**{'test-channel': handler}) + task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) + + # Execute publish before network failure + while not event.is_set(): + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + task.cancel() + assert messages_count > 1 \ No newline at end of file From a82d8e7d0bc6552099c4eb691d0ac00011c3fc4c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 5 Sep 2025 10:43:16 +0300 Subject: [PATCH 09/20] Added check for couroutines methods for pub/sub --- redis/asyncio/multidb/command_executor.py | 6 +++++- tests/test_asyncio/test_scenario/test_active_active.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index f7ae0e717b..7133955740 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from asyncio import iscoroutinefunction from datetime import datetime from typing import List, Optional, Callable, Any, Union, Awaitable @@ -221,7 +222,10 @@ async def callback(): async def execute_pubsub_method(self, method_name: str, *args, **kwargs): async def callback(): method = getattr(self.active_pubsub, method_name) - return await method(*args, **kwargs) + if iscoroutinefunction(method): + return await method(*args, **kwargs) + else: + return method(*args, **kwargs) return await self._execute_with_failure_detection(callback, *args) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 93068f6756..4d61434d8a 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -226,4 +226,5 @@ async def handler(message): await asyncio.sleep(0.5) task.cancel() + await pubsub.unsubscribe('test-channel') is True assert messages_count > 1 \ No newline at end of file From d38fb0a2517c84d01dd328e5ba378eebbae846b7 
Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 10 Sep 2025 10:39:00 +0300 Subject: [PATCH 10/20] Added OSS Cluster API support for MultiDBCLient --- redis/asyncio/multidb/client.py | 4 + redis/asyncio/multidb/command_executor.py | 4 + redis/asyncio/multidb/healthcheck.py | 2 +- redis/multidb/client.py | 4 + tests/test_asyncio/test_scenario/conftest.py | 5 +- .../test_scenario/test_active_active.py | 229 +++++++++--------- 6 files changed, 131 insertions(+), 117 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 7c0bef4f6e..e098a4723b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable from redis.asyncio.client import PubSubHandler @@ -13,6 +14,7 @@ from redis.multidb.exception import NoValidDatabaseException from redis.typing import KeyT, EncodableT, ChannelT +logger = logging.getLogger(__name__) class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): """ @@ -274,6 +276,8 @@ async def _check_db_health( database.circuit.state = CBState.OPEN is_healthy = False + logger.exception('Health check failed, due to exception', exc_info=e) + if on_error: await on_error(e) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 7133955740..d63b19269d 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import List, Optional, Callable, Any, Union, Awaitable +from redis.asyncio import RedisCluster from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ @@ -181,6 +182,9 @@ def command_retry(self) -> Retry: def pubsub(self, **kwargs): if self._active_pubsub is None: + if isinstance(self._active_database.client, RedisCluster): + raise ValueError("PubSub is not supported for RedisCluster") + self._active_pubsub = self._active_database.client.pubsub(**kwargs) self._active_pubsub_kwargs = kwargs diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index 7ae7bf34de..b8037e261e 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -67,7 +67,7 @@ async def _returns_echoed_message(self, database) -> bool: # For a cluster checks if all nodes are healthy. 
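        # Illustrative restatement of the cluster branch below: each node is probed
        # individually, so a single node that fails to echo the probe marks the whole
        # database as unhealthy.
        #
        #     for node in database.client.get_nodes():
        #         # sends the probe to that specific node
        #         if await node.execute_command("ECHO", "healthcheck") not in expected_message:
        #             return False
        #     return True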
all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = await node.redis_connection.execute_command("ECHO" ,"healthcheck") + actual_message = await node.execute_command("ECHO" ,"healthcheck") if actual_message not in expected_message: return False diff --git a/redis/multidb/client.py b/redis/multidb/client.py index e6b815c76f..7ed1935c1c 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,3 +1,4 @@ +import logging import threading from typing import List, Any, Callable, Optional @@ -11,6 +12,7 @@ from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck +logger = logging.getLogger(__name__) class MultiDBClient(RedisModuleCommands, CoreCommands): """ @@ -232,6 +234,8 @@ def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception database.circuit.state = CBState.OPEN is_healthy = False + logger.exception('Health check failed, due to exception', exc_info=e) + if on_error: on_error(e) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 18bc8f1417..d6d96d0660 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -28,7 +28,7 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest_asyncio.fixture() -async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: +async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChangedListener, dict]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -86,5 +86,4 @@ async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChanged health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - async with MultiDBClient(config) as client: - return client, listener, endpoint_config \ No newline at end of file + return config, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 4d61434d8a..afc110d69b 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -5,7 +5,9 @@ import pytest -from redis.asyncio.client import Pipeline +from redis.asyncio import RedisCluster +from redis.asyncio.client import Pipeline, Redis +from redis.asyncio.multidb.client import MultiDBClient from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -33,60 +35,99 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(5) + sleep(6) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + client_config, listener, endpoint_config = r_multi_db event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) - - # Client initialized on the first command. 
- await r_multi_db.set('key', 'value') + asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - # Execute commands before network failure - while not event.is_set(): - assert await r_multi_db.get('key') == 'value' - await asyncio.sleep(0.5) + async with MultiDBClient(client_config) as r_multi_db: + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) - # Execute commands until database failover - while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' - await asyncio.sleep(0.5) + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + client_config, listener, endpoint_config = r_multi_db event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) + + async with MultiDBClient(client_config) as r_multi_db: + # Execute pipeline before network failure + while not event.is_set(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) - # Client initialized on first pipe execution. 
- async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + client_config, listener, endpoint_config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - # Execute pipeline before network failure - while not event.is_set(): - async with r_multi_db.pipeline() as pipe: + async with MultiDBClient(client_config) as r_multi_db: + # Execute pipeline before network failure + while not event.is_set(): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -94,11 +135,11 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, pipe.get('{hash}key2') pipe.get('{hash}key3') assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await asyncio.sleep(0.5) - # Execute pipeline until database failover - for _ in range(5): - async with r_multi_db.pipeline() as pipe: + # Execute pipeline until database failover + for _ in range(5): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -111,62 +152,19 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], - indirect=True - ) - @pytest.mark.timeout(50) - async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db - - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) - - # Client initialized on first pipe execution. 
- pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - # Execute pipeline before network failure - while not event.is_set(): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) - - # Execute pipeline until database failover - for _ in range(5): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + client_config, listener, endpoint_config = r_multi_db event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) async def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') @@ -176,18 +174,16 @@ async def callback(pipe: Pipeline): pipe.get('{hash}key2') pipe.get('{hash}key3') - # Client initialized on first transaction execution. 
- await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + async with MultiDBClient(client_config) as r_multi_db: + # Execute transaction before network failure + while not event.is_set(): + await r_multi_db.transaction(callback) + await asyncio.sleep(0.5) - # Execute transaction before network failure - while not event.is_set(): - await r_multi_db.transaction(callback) - await asyncio.sleep(0.5) - - # Execute transaction until database failover - while not listener.is_changed_flag: - await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + # Execute transaction until database failover + while not listener.is_changed_flag: + await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -197,10 +193,10 @@ async def callback(pipe: Pipeline): ) @pytest.mark.timeout(50) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + client_config, listener, endpoint_config = r_multi_db event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) data = json.dumps({'message': 'test'}) messages_count = 0 @@ -209,22 +205,29 @@ async def handler(message): nonlocal messages_count messages_count += 1 - pubsub = await r_multi_db.pubsub() + async with MultiDBClient(client_config) as r_multi_db: + pubsub = await r_multi_db.pubsub() - # Assign a handler and run in a separate thread. - await pubsub.subscribe(**{'test-channel': handler}) - task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) + # Assign a handler and run in a separate thread. + await pubsub.subscribe(**{'test-channel': handler}) + task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) - # Execute publish before network failure - while not event.is_set(): - await r_multi_db.publish('test-channel', data) - await asyncio.sleep(0.5) + # Execute publish before network failure + while not event.is_set(): + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) - # Execute publish until database failover - while not listener.is_changed_flag: - await r_multi_db.publish('test-channel', data) - await asyncio.sleep(0.5) + # Execute publish until database failover + while not listener.is_changed_flag: + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + # After db changed still generates some traffic. + for _ in range(5): + await r_multi_db.publish('test-channel', data) - task.cancel() - await pubsub.unsubscribe('test-channel') is True - assert messages_count > 1 \ No newline at end of file + # A timeout to ensure that an async handler will handle all previous messages. 
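        # Illustrative recap of the pub/sub flow exercised above, assuming an
        # initialized MultiDBClient `client`: handlers registered at subscribe time
        # are invoked by run(), which keeps polling get_message() on the active
        # database's PubSub until the task is cancelled.
        #
        #     async def handler(message):
        #         ...  # react to the message
        #
        #     pubsub = await client.pubsub()
        #     await pubsub.subscribe(**{'test-channel': handler})
        #     task = asyncio.create_task(pubsub.run(poll_timeout=0.1))
        #     await client.publish('test-channel', 'payload')
        #     task.cancel()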
+ await asyncio.sleep(0.1) + task.cancel() + await pubsub.unsubscribe('test-channel') is True + assert messages_count >= 5 \ No newline at end of file From 54db16b062a527ebc314d9480cadb1f9ae963dd9 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 11 Sep 2025 09:49:25 +0300 Subject: [PATCH 11/20] Added support for Lag-Aware Healthcheck and OSS Cluster API --- redis/asyncio/cluster.py | 4 +- redis/asyncio/http/__init__.py | 0 redis/asyncio/http/http_client.py | 216 ++++++++++++++++++ redis/asyncio/multidb/healthcheck.py | 104 ++++++++- redis/http/http_client.py | 10 +- redis/multidb/healthcheck.py | 2 +- .../test_multidb/test_healthcheck.py | 143 +++++++++++- tests/test_asyncio/test_scenario/conftest.py | 2 + .../test_scenario/test_active_active.py | 68 ++++-- tests/test_multidb/test_healthcheck.py | 4 +- 10 files changed, 521 insertions(+), 32 deletions(-) create mode 100644 redis/asyncio/http/__init__.py create mode 100644 redis/asyncio/http/http_client.py diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 956262696a..f957baa319 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -404,6 +404,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes, require_full_coverage, @@ -2199,7 +2200,8 @@ async def _reinitialize_on_error(self, error): await self._pipe.cluster_client.nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._pipe.cluster_client.nodes_manager.update_moved_exception(error) + if type(error) == MovedError: + self._pipe.cluster_client.nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/asyncio/http/__init__.py b/redis/asyncio/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/http/http_client.py b/redis/asyncio/http/http_client.py new file mode 100644 index 0000000000..8f746b0a8b --- /dev/null +++ b/redis/asyncio/http/http_client.py @@ -0,0 +1,216 @@ +import asyncio +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Mapping, Union, Any +from redis.http.http_client import HttpResponse, HttpClient + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + +class AsyncHTTPClient(ABC): + @abstractmethod + async def get( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP GET request.""" + pass + + @abstractmethod + async def delete( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP DELETE request.""" + pass + + @abstractmethod + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP POST request.""" + pass + + @abstractmethod + async def put( + self, + path: str, 
+ json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PUT request.""" + pass + + @abstractmethod + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PATCH request.""" + pass + + @abstractmethod + async def request( + self, + method: str, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + """ + Invoke HTTP request with given method.""" + pass + +class AsyncHTTPClientWrapper(AsyncHTTPClient): + """ + An async wrapper around sync HTTP client with thread pool execution. + """ + def __init__( + self, + client: HttpClient, + max_workers: int = 10 + ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + client: Sync HTTP client instance. + max_workers: Maximum number of concurrent requests. + + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. 
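        Example (illustrative; the base URL and path are placeholders):

            http = AsyncHTTPClientWrapper(HttpClient(base_url="https://cluster.example:9443"))
            bdbs = await http.get("/v1/bdbs")  # the sync client runs in a worker thread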
+ """ + self.client = client + self._executor = ThreadPoolExecutor(max_workers=max_workers) + + async def get( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.get, + path, params, headers, timeout, expect_json + ) + + async def delete( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.delete, + path, params, headers, timeout, expect_json + ) + + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.post, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.put, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.patch, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def request( + self, + method: str, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.request, + method, path, params, headers, body, timeout + ) diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index b8037e261e..974e39d226 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -1,9 +1,13 @@ import logging from abc import ABC, abstractmethod +from typing import Optional, Tuple, Union from redis.asyncio import Redis +from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT from redis.asyncio.retry import Retry +from redis.retry import Retry as SyncRetry from redis.backoff import 
ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient from redis.utils import dummy_fail_async DEFAULT_HEALTH_CHECK_RETRIES = 3 @@ -72,4 +76,102 @@ async def _returns_echoed_message(self, database) -> bool: if actual_message not in expected_message: return False - return True \ No newline at end of file + return True + +class LagAwareHealthCheck(AbstractHealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verify via REST API that the database is healthy based on different lags. + """ + def __init__( + self, + retry: SyncRetry = SyncRetry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + rest_api_port: int = 9443, + lag_aware_tolerance: int = 100, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + ): + """ + Initialize LagAwareHealthCheck with the specified parameters. + + Args: + retry: Retry configuration for health checks + rest_api_port: Port number for Redis Enterprise REST API (default: 9443) + lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100) + timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) + auth_basic: Tuple of (username, password) for basic authentication + verify_tls: Whether to verify TLS certificates (default: True) + ca_file: Path to CA certificate file for TLS verification + ca_path: Path to CA certificates directory for TLS verification + ca_data: CA certificate data as string or bytes + client_cert_file: Path to client certificate file for mutual TLS + client_key_file: Path to client private key file for mutual TLS + client_key_password: Password for encrypted client private key + """ + super().__init__( + retry=retry, + ) + self._http_client = AsyncHTTPClientWrapper( + HttpClient( + timeout=timeout, + auth_basic=auth_basic, + retry=self.retry, + verify_tls=verify_tls, + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + client_key_password=client_key_password + ) + ) + self._rest_api_port = rest_api_port + self._lag_aware_tolerance = lag_aware_tolerance + + async def check_health(self, database) -> bool: + if database.health_check_url is None: + raise ValueError( + "Database health check url is not set. Please check DatabaseConfig for the current database." 
+ ) + + if isinstance(database.client, Redis): + db_host = database.client.get_connection_kwargs()["host"] + else: + db_host = database.client.startup_nodes[0].host + + base_url = f"{database.health_check_url}:{self._rest_api_port}" + self._http_client.client.base_url = base_url + + # Find bdb matching to the current database host + matching_bdb = None + for bdb in await self._http_client.get("/v1/bdbs"): + for endpoint in bdb["endpoints"]: + if endpoint['dns_name'] == db_host: + matching_bdb = bdb + break + + # In case if the host was set as public IP + for addr in endpoint['addr']: + if addr == db_host: + matching_bdb = bdb + break + + if matching_bdb is None: + logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") + raise ValueError("Could not find a matching bdb") + + url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" + f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") + await self._http_client.get(url, expect_json=False) + + # Status checked in an http client, otherwise HttpError will be raised + return True \ No newline at end of file diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 0a2de2e44c..986e773915 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -68,7 +68,6 @@ class HttpClient: def __init__( self, base_url: str = "", - *, headers: Optional[Mapping[str, str]] = None, timeout: float = DEFAULT_TIMEOUT, retry: Retry = Retry( @@ -131,7 +130,6 @@ def __init__( def get( self, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -150,7 +148,6 @@ def get( def delete( self, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -169,7 +166,6 @@ def delete( def post( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -190,7 +186,6 @@ def post( def put( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -211,7 +206,6 @@ def put( def patch( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -234,7 +228,6 @@ def request( self, method: str, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, @@ -319,7 +312,6 @@ def _json_call( self, method: str, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -340,7 +332,7 @@ def _json_call( return resp.json() return resp - def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + def _prepare_body(self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: if json_body is not None and data is not None: raise ValueError("Provide either json_body or data, not both.") if json_body is not None: diff 
--git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 9818d06e28..5a21918513 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -166,7 +166,7 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = (f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability" + url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") self._http_client.get(url, expect_json=False) diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index fd5c8ec3f0..ba6e8c2b7c 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -1,10 +1,11 @@ import pytest -from mock.mock import AsyncMock +from mock.mock import AsyncMock, MagicMock from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff +from redis.http.http_client import HttpError from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -45,4 +46,140 @@ async def test_database_close_circuit_on_successful_healthcheck(self, mock_clien db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 \ No newline at end of file + assert mock_client.execute_command.call_count == 3 + +class TestLagAwareHealthCheck: + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + """ + Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches database host, and availability endpoint returns success. + """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = AsyncMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + rest_api_port=1234, lag_aware_tolerance=150 + ) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.client.base_url == f"https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert second_call.kwargs.get("expect_json") is False + + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures health check succeeds when endpoint addr list contains the database host. 
+ """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = AsyncMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + + @pytest.mark.asyncio + async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, + {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + await hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + @pytest.mark.asyncio + async def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. + """ + host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(HttpError, match="busy") as e: + await hc.check_health(db) + assert e.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index d6d96d0660..735af7fed6 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -44,6 +44,7 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + health_checks = request.param.get('health_checks', []) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() event_dispatcher.register_listeners({ @@ -80,6 +81,7 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, + health_checks=health_checks, health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index afc110d69b..b95b868789 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import os from time import sleep import pytest @@ -8,6 +9,7 @@ from redis.asyncio import RedisCluster from redis.asyncio.client import Pipeline, Redis from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -51,10 +53,12 @@ def teardown_method(self, method): async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + + await r_multi_db.set('key', 'value') + # Execute commands before network failure while not event.is_set(): assert await r_multi_db.get('key') == 'value' @@ -65,6 +69,40 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in assert await r_multi_db.get('key') == 'value' await asyncio.sleep(0.5) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2, "health_checks": + [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + }, + {"client_class": RedisCluster, "failure_threshold": 2, "health_checks": + [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + }, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + client_config, listener, endpoint_config = r_multi_db + + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + + await r_multi_db.set('key', 'value') + + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands after network failure + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", @@ -79,10 +117,10 @@ async def test_multi_db_client_failover_to_another_db(self, 
r_multi_db, fault_in async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + # Execute pipeline before network failure while not event.is_set(): async with r_multi_db.pipeline() as pipe: @@ -121,10 +159,10 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + # Execute pipeline before network failure while not event.is_set(): pipe = r_multi_db.pipeline() @@ -163,9 +201,6 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - async def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -175,6 +210,9 @@ async def callback(pipe: Pipeline): pipe.get('{hash}key3') async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + # Execute transaction before network failure while not event.is_set(): await r_multi_db.transaction(callback) @@ -195,9 +233,6 @@ async def callback(pipe: Pipeline): async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,endpoint_config,event)) - data = json.dumps({'message': 'test'}) messages_count = 0 @@ -206,6 +241,9 @@ async def handler(message): messages_count += 1 async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + pubsub = await r_multi_db.pubsub() # Assign a handler and run in a separate thread. 
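Illustrative sketch (not part of the patch): how a caller might wire the new health_checks option with LagAwareHealthCheck, mirroring the asyncio scenario test above. The db_configs list and the ENV0_USERNAME / ENV0_PASSWORD credentials are assumptions carried over from the test fixtures, and the asyncio MultiDbConfig import path is assumed to mirror the sync one.

    import os

    from redis.asyncio.multidb.client import MultiDBClient
    from redis.asyncio.multidb.config import MultiDbConfig
    from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck

    def make_client(db_configs):
        # db_configs: list of DatabaseConfig entries, built as in
        # tests/test_asyncio/test_scenario/conftest.py (assumed to exist).
        config = MultiDbConfig(
            databases_config=db_configs,
            health_checks=[
                LagAwareHealthCheck(
                    verify_tls=False,
                    auth_basic=(os.getenv("ENV0_USERNAME"), os.getenv("ENV0_PASSWORD")),
                )
            ],
        )
        # The client runs the configured health checks on initialization and
        # periodically in the background, failing over when a database is unhealthy.
        return MultiDBClient(config)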
diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 18bfe5f23b..77886832e7 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -88,7 +88,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" assert second_call.kwargs.get("expect_json") is False def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): @@ -120,7 +120,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ From 443186165c9b7bcac3588c88562d63b56df1eb3d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 11 Sep 2025 10:04:19 +0300 Subject: [PATCH 12/20] Increased timeouts between tests --- tests/test_asyncio/test_scenario/test_active_active.py | 2 +- tests/test_scenario/test_active_active.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 770cfab692..c054d17dc2 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -37,7 +37,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(6) + sleep(10) @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 44c57e6b99..c87ad903b1 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -36,7 +36,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(5) + sleep(10) @pytest.mark.parametrize( "r_multi_db", From 9ed190199b73bb572a214e46494baa91fa8fd2cb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 12 Sep 2025 09:41:01 +0300 Subject: [PATCH 13/20] [Sync] Refactored healthcheck --- redis/multidb/client.py | 68 ++++---- redis/multidb/config.py | 18 +- redis/multidb/exception.py | 10 +- redis/multidb/healthcheck.py | 184 +++++++++++++++----- tests/test_multidb/conftest.py | 18 +- tests/test_multidb/test_client.py | 56 +++--- tests/test_multidb/test_healthcheck.py | 197 +++++++++++++++++++--- tests/test_multidb/test_pipeline.py | 43 +++-- tests/test_scenario/conftest.py | 7 +- tests/test_scenario/test_active_active.py | 20 +-- 10 files changed, 439 insertions(+), 182 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 7ed1935c1c..748cef4855 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,5 +1,7 @@ import logging import threading +from concurrent.futures import as_completed +from concurrent.futures.thread import ThreadPoolExecutor from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler @@ -8,9 +10,9 @@ from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import Database, Databases, SyncDatabase -from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck +from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy logger = logging.getLogger(__name__) @@ -27,6 +29,10 @@ def __init__(self, config: MultiDbConfig): self._health_checks.extend(config.health_checks) self._health_check_interval = config.health_check_interval + self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( + config.health_check_probes, + config.health_check_delay + ) self._failure_detectors = config.default_failure_detectors() if config.failure_detectors is not None: @@ -209,44 +215,48 @@ def pubsub(self, **kwargs): return PubSub(self, **kwargs) - def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None: + def _check_db_health(self, database: SyncDatabase) -> bool: """ Runs health checks on the given database until first failure. 
""" - is_healthy = True - - with self._hc_lock: - # Health check will setup circuit state - for health_check in self._health_checks: - if not is_healthy: - # If one of the health checks failed, it's considered unhealthy - break + # Health check will setup circuit state + is_healthy = self._health_check_policy.execute(self._health_checks, database) - try: - is_healthy = health_check.check_health(database) - - if not is_healthy and database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - elif is_healthy and database.circuit.state != CBState.CLOSED: - database.circuit.state = CBState.CLOSED - except Exception as e: - if database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - is_healthy = False - - logger.exception('Health check failed, due to exception', exc_info=e) - - if on_error: - on_error(e) + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + return is_healthy def _check_databases_health(self, on_error: Callable[[Exception], None] = None): """ Runs health checks as a recurring task. Runs health checks against all databases. """ - for database, _ in self._databases: - self._check_db_health(database, on_error) + with ThreadPoolExecutor(max_workers=len(self._databases)) as executor: + # Submit all health checks + futures = { + executor.submit(self._check_db_health, database) + for database, _ in self._databases + } + + for future in as_completed(futures, timeout=self._health_check_interval): + try: + future.result() + except UnhealthyDatabaseException as e: + unhealthy_db = e.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + 'Health check failed, due to exception', + exc_info=e.original_exception + ) + + if on_error: + on_error(e.original_exception) def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: diff --git a/redis/multidb/config.py b/redis/multidb/config.py index fc349ed04b..e44b3eaae7 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -12,14 +12,14 @@ from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL DEFAULT_FAILURES_DURATION = 2 DEFAULT_FAILOVER_RETRIES = 3 DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) @@ -79,8 +79,9 @@ class MultiDbConfig: failures_interval: Time interval for tracking database failures. health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. - health_check_retries: Number of retry attempts for performing health checks. 
- health_check_backoff: Backoff strategy for health check retries. + health_check_probes: Number of attempts to evaluate the health of a database. + health_check_delay: Delay between health check attempts. + health_check_policy: Policy for determining database health based on health checks. failover_strategy: Optional strategy for handling database failover scenarios. failover_retries: Number of retries allowed for failover operations. failover_backoff: Backoff strategy for failover retries. @@ -114,8 +115,9 @@ class MultiDbConfig: failures_interval: float = DEFAULT_FAILURES_DURATION health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL - health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES - health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES + health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[FailoverStrategy] = None failover_retries: int = DEFAULT_FAILOVER_RETRIES failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF @@ -159,7 +161,7 @@ def default_failure_detectors(self) -> List[FailureDetector]: def default_health_checks(self) -> List[HealthCheck]: return [ - EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + EchoHealthCheck(), ] def default_failover_strategy(self) -> FailoverStrategy: diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py index 80fdb9409a..b49896c34d 100644 --- a/redis/multidb/exception.py +++ b/redis/multidb/exception.py @@ -1,2 +1,10 @@ class NoValidDatabaseException(Exception): - pass \ No newline at end of file + pass + +class UnhealthyDatabaseException(Exception): + """Exception raised when a database is unhealthy due to an underlying exception.""" + + def __init__(self, message, database, original_exception): + super().__init__(message) + self.database = database + self.original_exception = original_exception \ No newline at end of file diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 5a21918513..eadfdf0e6f 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,66 +1,164 @@ import logging from abc import abstractmethod, ABC -from typing import Optional, Tuple, Union +from enum import Enum +from time import sleep +from typing import Optional, Tuple, Union, List from redis import Redis -from redis.backoff import ExponentialWithJitterBackoff +from redis.backoff import NoBackoff from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient +from redis.multidb.exception import UnhealthyDatabaseException from redis.retry import Retry -from redis.utils import dummy_fail - -DEFAULT_HEALTH_CHECK_RETRIES = 3 -DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +DEFAULT_HEALTH_CHECK_PROBES = 3 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_DELAY = 0.5 logger = logging.getLogger(__name__) class HealthCheck(ABC): + @abstractmethod + def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class HealthCheckPolicy(ABC): + """ + Health checks execution policy. 
+ """ @property @abstractmethod - def retry(self) -> Retry: - """The retry object to use for health checks.""" + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" pass + @property @abstractmethod - def check_health(self, database) -> bool: - """Function to determine the health status.""" + def health_check_delay(self) -> float: + """Delay between health check probes.""" pass -class AbstractHealthCheck(HealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - self._retry = retry - self._retry.update_supported_errors([ConnectionRefusedError]) + @abstractmethod + def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay @property - def retry(self) -> Retry: - return self._retry + def health_check_probes(self) -> int: + return self._health_check_probes - @abstractmethod - def check_health(self, database) -> bool: + @property + def health_check_delay(self) -> float: + return self._health_check_delay + + def execute(self, health_checks: List[HealthCheck], database) -> bool: pass +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) -class EchoHealthCheck(AbstractHealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - """ - Check database healthiness by sending an echo request. - """ - super().__init__( - retry=retry, - ) - def check_health(self, database) -> bool: - return self._retry.call_with_retry( - lambda: self._returns_echoed_message(database), - lambda _: dummy_fail() - ) + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. 
+ """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + unsuccessful_probes = self.health_check_probes / 2 + else: + unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + unsuccessful_probes -= 1 + if unsuccessful_probes <= 0: + return False + except Exception as e: + unsuccessful_probes -= 1 + if unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) - def _returns_echoed_message(self, database) -> bool: + def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False + + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): @@ -77,14 +175,14 @@ def _returns_echoed_message(self, database) -> bool: return True -class LagAwareHealthCheck(AbstractHealthCheck): + +class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ def __init__( self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, lag_aware_tolerance: int = 100, timeout: float = DEFAULT_TIMEOUT, @@ -103,7 +201,6 @@ def __init__( Initialize LagAwareHealthCheck with the specified parameters. 
Args: - retry: Retry configuration for health checks rest_api_port: Port number for Redis Enterprise REST API (default: 9443) lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100) timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) @@ -116,13 +213,10 @@ def __init__( client_key_file: Path to client private key file for mutual TLS client_key_password: Password for encrypted client private key """ - super().__init__( - retry=retry, - ) self._http_client = HttpClient( timeout=timeout, auth_basic=auth_basic, - retry=self.retry, + retry=Retry(NoBackoff(), retries=0), verify_tls=verify_tls, ca_file=ca_file, ca_path=ca_path, diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 0c082f0f17..f47da3b174 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -6,11 +6,11 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL + DEFAULT_AUTO_FALLBACK_INTERVAL, DEFAULT_HEALTH_CHECK_POLICY from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck +from redis.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES from tests.conftest import mock_ed @@ -80,18 +80,18 @@ def mock_db2(request) -> Database: def mock_multi_db_config( request, mock_fd, mock_fs, mock_hc, mock_ed ) -> MultiDbConfig: - hc_interval = request.param.get('hc_interval', None) - if hc_interval is None: - hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL - - auto_fallback_interval = request.param.get('auto_fallback_interval', None) - if auto_fallback_interval is None: - auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL) + health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY) + health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES) config = MultiDbConfig( databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, failover_strategy=mock_fs, auto_fallback_interval=auto_fallback_interval, event_dispatcher=mock_ed diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index d352c1da92..4cac5c51ec 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -13,8 +13,7 @@ from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list @@ -46,7 +45,7 @@ def test_execute_command_against_correct_db_on_successful_initialization( client = MultiDBClient(mock_multi_db_config) assert 
mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -73,12 +72,12 @@ def test_execute_command_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -88,7 +87,7 @@ def test_execute_command_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -114,31 +113,30 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] - mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) + sleep(0.3) assert client.set('key', 'value') == 'OK2' - sleep(0.1) + sleep(0.2) assert client.set('key', 'value') == 'OK' - sleep(0.1) + sleep(0.2) assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -152,23 +150,21 @@ def test_execute_command_auto_fallback_to_highest_weight_db( databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): 
mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] - mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.health_check_interval = 0.2 + mock_multi_db_config.auto_fallback_interval = 0.4 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) + sleep(0.30) assert client.set('key', 'value') == 'OK2' - sleep(0.22) + sleep(0.44) assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( @@ -256,10 +252,10 @@ def test_add_database_makes_new_database_active( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK2' - assert mock_hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 6 client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert client.set('key', 'value') == 'OK1' @@ -291,7 +287,7 @@ def test_remove_highest_weighted_database( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 client.remove_database(mock_db1) @@ -325,7 +321,7 @@ def test_update_database_weight_to_be_highest( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 @@ -366,7 +362,7 @@ def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -410,7 +406,7 @@ def test_add_new_health_check( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -418,8 +414,8 @@ def test_add_new_health_check( client.add_health_check(another_hc) client._check_db_health(mock_db1) - assert mock_hc.check_health.call_count == 4 - assert another_hc.check_health.call_count == 1 + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -448,7 +444,7 @@ def test_set_active_database( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert 
mock_hc.check_health.call_count == 9 client.set_active_database(mock_db) assert client.set('key', 'value') == 'OK' diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 77886832e7..5f8be6add5 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,15 +1,171 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest -from redis.backoff import ExponentialBackoff from redis.multidb.database import Database from redis.http.http_client import HttpError -from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ + UnhealthyDatabaseException, HealthyMajorityPolicy, HealthyAnyPolicy from redis.multidb.circuit import State as CBState -from redis.exceptions import ConnectionError -from redis.retry import Retry +class TestHealthyAllPolicy: + def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == False + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + def test_policy_raise_unhealthy_database_exception(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, ConnectionError] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + +class TestHealthyMajorityPolicy: + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + (3, [True, False, False], [True, True, True], 3, 0, False), + (3, [True, True, True], [True, False, False], 3, 3, False), + (3, [True, False, True], [True, True, True], 3, 3, True), + (3, [True, True, True], [True, False, True], 3, 3, True), + (3, [True, True, False], [True, False, True], 3, 3, True), + (4, [True, True, False, False], [True, True, True, True], 4, 0, False), + (4, [True, True, True, True], [True, True, False, False], 4, 4, False), + (4, [False, True, True, True], [True, True, True, True], 4, 4, True), + (4, [True, True, True, True], [True, False, True, True], 4, 4, True), + (4, [False, True, True, True], [True, True, False, True], 4, 4, True), + ], + ids=[ + 'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority- odd', + 'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even', + 'HC2 - no majority - even','HC1 - majority - even', 'HC2 - majority - even', + 'HC1 + HC2 - 
majority - even' + ] + ) + def test_policy_returns_true_for_majority_successful_probes( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyMajorityPolicy(probes, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count", + [ + (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), + (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), + (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0), + (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3), + ], + ids=[ + 'HC1 - majority- odd', 'HC2 - majority - odd', + 'HC1 - majority - even', 'HC2 - majority - even', + ] + ) + def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + +class TestHealthyAnyPolicy: + @pytest.mark.parametrize( + "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + ([False, False, False], [True, True, True], 3, 0, False), + ([False, False, True], [False, False, False], 3, 3, False), + ([False, True, True], [False, False, True], 2, 3, True), + ([True, True, True], [False, True, False], 1, 2, True), + ], + ids=[ + 'HC1 - no successful', 'HC2 - no successful', + 'HC1 - successful', 'HC2 - successful', + ] + ) + def test_policy_returns_true_for_any_successful_probe( + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [False, False, ConnectionError] + mock_hc2.check_health.side_effect = [True, True, True] + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert 
mock_hc2.check_health.call_count == 0 class TestEchoHealthCheck: def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): @@ -17,33 +173,33 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command.return_value = 'healthcheck' + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): """ Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command.return_value = 'wrong' + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + mock_client.execute_command.return_value = 'healthcheck' mock_cb.state = CBState.HALF_OPEN - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 class TestLagAwareHealthCheck: @@ -72,7 +228,6 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc ] hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), rest_api_port=1234, lag_aware_tolerance=150 ) # Inject our mocked http client @@ -111,9 +266,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb None, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -136,9 +289,7 @@ def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -164,9 +315,7 @@ def test_propagates_http_error_from_availability(self, mock_client, mock_cb): HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") diff --git a/tests/test_multidb/test_pipeline.py 
b/tests/test_multidb/test_pipeline.py index 6e7c344d85..608cc3373d 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -10,7 +10,7 @@ from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF +from redis.multidb.healthcheck import EchoHealthCheck from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list @@ -54,7 +54,7 @@ def test_executes_pipeline_against_correct_db( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -79,7 +79,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -89,7 +89,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -99,7 +99,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -125,9 +125,8 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -144,7 +143,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin pipe2.execute.return_value = ['OK2', 'value'] mock_db2.client.pipeline.return_value = pipe2 - mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -157,7 +156,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK1', 'value'] - sleep(0.15) + sleep(0.3) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -165,7 +164,7 @@ def 
test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK2', 'value'] - sleep(0.1) + sleep(0.2) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -173,7 +172,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK', 'value'] - sleep(0.1) + sleep(0.2) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -214,7 +213,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -237,7 +236,7 @@ def test_execute_transaction_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -247,7 +246,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -257,7 +256,7 @@ def callback(pipe: Pipeline): 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -283,9 +282,7 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -294,7 +291,7 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_db1.client.transaction.return_value = ['OK1', 'value'] mock_db2.client.transaction.return_value = ['OK2', 'value'] - mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -306,9 +303,9 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value'] - sleep(0.15) + sleep(0.3) assert client.transaction(callback) == ['OK2', 'value'] - sleep(0.1) + sleep(0.2) assert client.transaction(callback) == ['OK', 'value'] - sleep(0.1) + sleep(0.2) assert client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git 
a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index a0f19e1a87..54325e58a5 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -12,7 +12,7 @@ from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_FAILURES_THRESHOLD from redis.multidb.event import ActiveDatabaseChanged -from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY from redis.retry import Retry from tests.test_scenario.fault_injector_client import FaultInjectorClient @@ -61,6 +61,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + health_check_delay = request.param.get('health_check_delay', DEFAULT_HEALTH_CHECK_DELAY) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() event_dispatcher.register_listeners({ @@ -97,10 +98,10 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, - health_check_retries=3, + health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, - health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + health_check_delay=health_check_delay, ) return MultiDBClient(config), listener, endpoint_config diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index c87ad903b1..23a75886da 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -36,7 +36,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(10) + sleep(15) @pytest.mark.parametrize( "r_multi_db", @@ -47,7 +47,7 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -75,13 +75,13 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "failure_threshold": 2, "health_check_interval": 10}, + {"client_class": RedisCluster, "failure_threshold": 2, "health_check_interval": 10}, ], ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -123,7 +123,7 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -179,7 +179,7 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -235,7 +235,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -277,7 +277,7 @@ def callback(pipe: Pipeline): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db @@ -323,7 +323,7 @@ def handler(message): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db From 9f08d388df82bf5ac38f44977bf56d3df5cb843d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 12 Sep 2025 10:58:06 +0300 Subject: [PATCH 14/20] [Async] Refactored healthcheck --- redis/asyncio/multidb/client.py | 69 +++--- redis/asyncio/multidb/config.py | 17 +- redis/asyncio/multidb/healthcheck.py | 183 ++++++++++++---- redis/multidb/healthcheck.py | 1 + tests/test_asyncio/test_multidb/conftest.py | 19 +- .../test_asyncio/test_multidb/test_client.py | 43 ++-- .../test_multidb/test_healthcheck.py | 205 ++++++++++++++++-- .../test_multidb/test_pipeline.py | 27 +-- tests/test_asyncio/test_scenario/conftest.py | 4 +- .../test_scenario/test_active_active.py | 14 +- 10 files changed, 417 insertions(+), 165 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index e098a4723b..a7f591bda0 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -6,12 +6,12 @@ from redis.asyncio.multidb.command_executor 
import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands -from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException from redis.typing import KeyT, EncodableT, ChannelT logger = logging.getLogger(__name__) @@ -29,6 +29,10 @@ def __init__(self, config: MultiDbConfig): self._health_checks.extend(config.health_checks) self._health_check_interval = config.health_check_interval + self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( + config.health_check_probes, + config.health_check_delay + ) self._failure_detectors = config.default_failure_detectors() if config.failure_detectors is not None: @@ -244,42 +248,45 @@ async def _check_databases_health( Runs health checks as a recurring task. Runs health checks against all databases. """ - for database, _ in self._databases: - async with self._hc_lock: - await self._check_db_health(database, on_error) + results = await asyncio.wait_for( + asyncio.gather( + *( + asyncio.create_task(self._check_db_health(database)) + for database, _ in self._databases + ), + return_exceptions=True, + ), + timeout=self._health_check_interval, + ) - async def _check_db_health( - self, - database: AsyncDatabase, - on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, - ) -> None: + for result in results: + if isinstance(result, UnhealthyDatabaseException): + unhealthy_db = result.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + 'Health check failed, due to exception', + exc_info=result.original_exception + ) + + if on_error: + on_error(result.original_exception) + + async def _check_db_health(self, database: AsyncDatabase,) -> bool: """ Runs health checks on the given database until first failure. 
""" - is_healthy = True - # Health check will setup circuit state - for health_check in self._health_checks: - if not is_healthy: - # If one of the health checks failed, it's considered unhealthy - break + is_healthy = await self._health_check_policy.execute(self._health_checks, database) - try: - is_healthy = await health_check.check_health(database) + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED - if not is_healthy and database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - elif is_healthy and database.circuit.state != CBState.CLOSED: - database.circuit.state = CBState.CLOSED - except Exception as e: - if database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - is_healthy = False - - logger.exception('Health check failed, due to exception', exc_info=e) - - if on_error: - await on_error(e) + return is_healthy def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): loop = asyncio.get_running_loop() diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index b5f4a0658d..1ec3c3498c 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -7,8 +7,8 @@ from redis.asyncio.multidb.database import Databases, Database from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper -from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ - EchoHealthCheck +from redis.asyncio.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies from redis.asyncio.retry import Retry from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList @@ -17,9 +17,9 @@ from redis.multidb.failure_detector import CommandFailureDetector DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_FAILURES_THRESHOLD = 3 DEFAULT_FAILURES_DURATION = 2 +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL DEFAULT_FAILOVER_RETRIES = 3 DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) DEFAULT_AUTO_FALLBACK_INTERVAL = -1 @@ -78,8 +78,8 @@ class MultiDbConfig: failures_interval: Time interval for tracking database failures. health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. - health_check_retries: Number of retry attempts for performing health checks. - health_check_backoff: Backoff strategy for health check retries. + health_check_probes: Number of attempts to evaluate the health of a database. + health_check_delay: Delay between health check attempts. failover_strategy: Optional strategy for handling database failover scenarios. failover_retries: Number of retries allowed for failover operations. failover_backoff: Backoff strategy for failover retries. 
@@ -113,8 +113,9 @@ class MultiDbConfig: failures_interval: float = DEFAULT_FAILURES_DURATION health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL - health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES - health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES + health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[AsyncFailoverStrategy] = None failover_retries: int = DEFAULT_FAILOVER_RETRIES failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF @@ -160,7 +161,7 @@ def default_failure_detectors(self) -> List[AsyncFailureDetector]: def default_health_checks(self) -> List[HealthCheck]: return [ - EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + EchoHealthCheck(), ] def default_failover_strategy(self) -> AsyncFailoverStrategy: diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index 51302cb7ba..b20c0d1500 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -1,67 +1,167 @@ +import asyncio import logging from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union +from enum import Enum +from typing import Optional, Tuple, Union, List from redis.asyncio import Redis from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT from redis.asyncio.retry import Retry -from redis.retry import Retry as SyncRetry -from redis.backoff import ExponentialWithJitterBackoff +from redis.backoff import NoBackoff from redis.http.http_client import HttpClient -from redis.utils import dummy_fail_async +from redis.multidb.exception import UnhealthyDatabaseException -DEFAULT_HEALTH_CHECK_RETRIES = 3 -DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +DEFAULT_HEALTH_CHECK_PROBES = 3 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_DELAY = 0.5 logger = logging.getLogger(__name__) class HealthCheck(ABC): + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class HealthCheckPolicy(ABC): + """ + Health checks execution policy. 
+ """ @property @abstractmethod - def retry(self) -> Retry: - """The retry object to use for health checks.""" + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" pass + @property @abstractmethod - async def check_health(self, database) -> bool: - """Function to determine the health status.""" + def health_check_delay(self) -> float: + """Delay between health check probes.""" pass -class AbstractHealthCheck(HealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - self._retry = retry - self._retry.update_supported_errors([ConnectionRefusedError]) + @abstractmethod + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay + + @property + def health_check_probes(self) -> int: + return self._health_check_probes @property - def retry(self) -> Retry: - return self._retry + def health_check_delay(self) -> float: + return self._health_check_delay @abstractmethod - async def check_health(self, database) -> bool: + async def execute(self, health_checks: List[HealthCheck], database) -> bool: pass -class EchoHealthCheck(AbstractHealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - """ - Check database healthiness by sending an echo request. - """ - super().__init__( - retry=retry, - ) - async def check_health(self, database) -> bool: - return await self._retry.call_with_retry( - lambda: self._returns_echoed_message(database), - lambda _: dummy_fail_async() - ) +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. 
+ """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + unsuccessful_probes = self.health_check_probes / 2 + else: + unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + unsuccessful_probes -= 1 + if unsuccessful_probes <= 0: + return False + except Exception as e: + unsuccessful_probes -= 1 + if unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) - async def _returns_echoed_message(self, database) -> bool: + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False + + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if await health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + async def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): @@ -78,14 +178,13 @@ async def _returns_echoed_message(self, database) -> bool: return True -class LagAwareHealthCheck(AbstractHealthCheck): +class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ def __init__( self, - retry: SyncRetry = SyncRetry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, lag_aware_tolerance: int = 100, timeout: float = DEFAULT_TIMEOUT, @@ -104,7 +203,6 @@ def __init__( Initialize LagAwareHealthCheck with the specified parameters. 
Args: - retry: Retry configuration for health checks rest_api_port: Port number for Redis Enterprise REST API (default: 9443) lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100) timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) @@ -117,14 +215,11 @@ def __init__( client_key_file: Path to client private key file for mutual TLS client_key_password: Password for encrypted client private key """ - super().__init__( - retry=retry, - ) self._http_client = AsyncHTTPClientWrapper( HttpClient( timeout=timeout, auth_basic=auth_basic, - retry=self.retry, + retry=Retry(NoBackoff(), retries=0), verify_tls=verify_tls, ca_file=ca_file, ca_path=ca_path, diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index eadfdf0e6f..d96e39a1da 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -58,6 +58,7 @@ def health_check_probes(self) -> int: def health_check_delay(self) -> float: return self._health_check_delay + @abstractmethod def execute(self, health_checks: List[HealthCheck], database) -> bool: pass diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index 0ac231cf52..a62c30b51a 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -3,16 +3,15 @@ import pytest from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ - DatabaseConfig + DatabaseConfig, DEFAULT_HEALTH_CHECK_POLICY from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio import Redis from redis.asyncio.multidb.database import Database, Databases - @pytest.fixture() def mock_client() -> Redis: return Mock(spec=Redis) @@ -79,18 +78,18 @@ def mock_db2(request) -> Database: def mock_multi_db_config( request, mock_fd, mock_fs, mock_hc, mock_ed ) -> MultiDbConfig: - hc_interval = request.param.get('hc_interval', None) - if hc_interval is None: - hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL - - auto_fallback_interval = request.param.get('auto_fallback_interval', None) - if auto_fallback_interval is None: - auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL) + health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY) + health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES) config = MultiDbConfig( databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, failover_strategy=mock_fs, auto_fallback_interval=auto_fallback_interval, event_dispatcher=mock_ed diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index c2fe914e9f..875e11d6e5 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -9,8 +9,7 @@ from 
redis.asyncio.multidb.database import AsyncDatabase from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck from redis.asyncio.retry import Retry from redis.event import EventDispatcher, AsyncOnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter @@ -46,7 +45,7 @@ async def test_execute_command_against_correct_db_on_successful_initialization( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -74,12 +73,12 @@ async def test_execute_command_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command = AsyncMock(return_value='OK1') - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -90,7 +89,7 @@ async def test_execute_command_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -116,9 +115,7 @@ async def test_execute_command_against_correct_db_on_background_health_check_det databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] @@ -141,7 +138,7 @@ async def test_execute_command_against_correct_db_on_background_health_check_det 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -155,9 +152,7 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ 
- patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] @@ -201,7 +196,7 @@ async def test_execute_command_throws_exception_on_failed_initialization( with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): await client.set('key', 'value') - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -230,7 +225,7 @@ async def test_add_database_throws_exception_on_same_database( with pytest.raises(ValueError, match='Given database already exists'): await client.add_database(mock_db) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -261,10 +256,10 @@ async def test_add_database_makes_new_database_active( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK2' - assert mock_hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 6 await client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert await client.set('key', 'value') == 'OK1' @@ -297,7 +292,7 @@ async def test_remove_highest_weighted_database( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.remove_database(mock_db1) assert await client.set('key', 'value') == 'OK2' @@ -331,7 +326,7 @@ async def test_update_database_weight_to_be_highest( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 @@ -373,7 +368,7 @@ async def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -418,7 +413,7 @@ async def test_add_new_health_check( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -426,8 +421,8 @@ async def test_add_new_health_check( await client.add_health_check(another_hc) await client._check_db_health(mock_db1) - assert 
mock_hc.check_health.call_count == 4 - assert another_hc.check_health.call_count == 1 + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -457,7 +452,7 @@ async def test_set_active_database( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.set_active_database(mock_db) assert await client.set('key', 'value') == 'OK' diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index ba6e8c2b7c..4924914769 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -1,15 +1,181 @@ import pytest -from mock.mock import AsyncMock, MagicMock +from mock.mock import AsyncMock, Mock from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck -from redis.asyncio.retry import Retry -from redis.backoff import ExponentialBackoff +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ + HealthyMajorityPolicy, HealthyAnyPolicy from redis.http.http_client import HttpError from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError +from redis.multidb.exception import UnhealthyDatabaseException +class TestHealthyAllPolicy: + @pytest.mark.asyncio + async def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + @pytest.mark.asyncio + async def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == False + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + @pytest.mark.asyncio + async def test_policy_raise_unhealthy_database_exception(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, ConnectionError] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + +class TestHealthyMajorityPolicy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + (3, [True, False, False], [True, True, True], 3, 0, False), + (3, [True, True, True], [True, False, False], 3, 3, False), + (3, [True, False, True], [True, True, True], 3, 
3, True), + (3, [True, True, True], [True, False, True], 3, 3, True), + (3, [True, True, False], [True, False, True], 3, 3, True), + (4, [True, True, False, False], [True, True, True, True], 4, 0, False), + (4, [True, True, True, True], [True, True, False, False], 4, 4, False), + (4, [False, True, True, True], [True, True, True, True], 4, 4, True), + (4, [True, True, True, True], [True, False, True, True], 4, 4, True), + (4, [False, True, True, True], [True, True, False, True], 4, 4, True), + ], + ids=[ + 'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority- odd', + 'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even', + 'HC2 - no majority - even','HC1 - majority - even', 'HC2 - majority - even', + 'HC1 + HC2 - majority - even' + ] + ) + async def test_policy_returns_true_for_majority_successful_probes( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyMajorityPolicy(probes, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count", + [ + (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), + (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), + (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0), + (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3), + ], + ids=[ + 'HC1 - majority- odd', 'HC2 - majority - odd', + 'HC1 - majority - even', 'HC2 - majority - even', + ] + ) + async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + +class TestHealthyAnyPolicy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + ([False, False, False], [True, True, True], 3, 0, False), + ([False, False, True], [False, False, False], 3, 3, False), + ([False, True, True], [False, False, True], 2, 3, True), + ([True, True, True], [False, True, False], 1, 2, True), + ], + ids=[ + 'HC1 - no successful', 'HC2 - no successful', + 'HC1 - successful', 'HC2 - successful', + ] + ) + async def test_policy_returns_true_for_any_successful_probe( + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = 
hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.asyncio + async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [False, False, ConnectionError] + mock_hc2.check_health.side_effect = [True, True, True] + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + class TestEchoHealthCheck: @pytest.mark.asyncio @@ -18,12 +184,12 @@ async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): @@ -31,22 +197,22 @@ async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_clien Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. 
""" - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command = AsyncMock(side_effect=['wrong']) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == False - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) mock_cb.state = CBState.HALF_OPEN - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 class TestLagAwareHealthCheck: @pytest.mark.asyncio @@ -75,7 +241,6 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_clien ] hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), rest_api_port=1234, lag_aware_tolerance=150 ) # Inject our mocked http client @@ -115,9 +280,7 @@ async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, m None, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -141,9 +304,7 @@ async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_c {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -170,9 +331,7 @@ async def test_propagates_http_error_from_availability(self, mock_client, mock_c HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 5af2e3e864..a9e0b8083d 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -8,8 +8,7 @@ from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF +from redis.asyncio.multidb.healthcheck import EchoHealthCheck from redis.asyncio.retry import Retry from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF @@ -57,7 +56,7 @@ async def test_executes_pipeline_against_correct_db( pipe.get('key1') assert await pipe.execute() == ['OK1', 'value1'] - assert 
mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -83,7 +82,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -93,7 +92,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert await pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -104,7 +103,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -130,9 +129,7 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -219,7 +216,7 @@ async def callback(pipe: Pipeline): pipe.get('key1') assert await client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -243,7 +240,7 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -253,7 +250,7 @@ async def callback(pipe: Pipeline): pipe.get('key1') assert await client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -264,7 +261,7 @@ async def callback(pipe: Pipeline): 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -290,9 +287,7 @@ async def 
test_execute_transaction_against_correct_db_on_background_health_check databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 735af7fed6..7992983a17 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -82,10 +82,10 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged command_retry=command_retry, failure_threshold=failure_threshold, health_checks=health_checks, - health_check_retries=3, + health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, - health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + health_check_delay=ExponentialBackoff(cap=5, base=0.5), ) return config, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index c054d17dc2..dd3deba7b1 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -37,7 +37,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(10) + sleep(15) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -49,7 +49,7 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db @@ -83,7 +83,7 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db @@ -113,7 +113,7 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db @@ -155,7 +155,7 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db @@ -197,7 +197,7 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db @@ -229,7 +229,7 @@ async def callback(pipe: Pipeline): [{"failure_threshold": 2}], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db From 3c45f6e1c61d65121b2fa6fe0c7f3ac30685ac6b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 15 Sep 2025 14:27:05 +0300 Subject: [PATCH 15/20] [Sync] Refactored Failover Strategy --- redis/multidb/circuit.py | 2 + redis/multidb/config.py | 25 ++- redis/multidb/exception.py | 6 +- redis/multidb/failover.py | 53 ++++-- redis/multidb/failure_detector.py | 6 +- redis/multidb/healthcheck.py | 2 + tests/test_multidb/conftest.py | 4 +- tests/test_multidb/test_client.py | 13 +- tests/test_multidb/test_config.py | 4 +- tests/test_multidb/test_failover.py | 99 ++++++++-- tests/test_multidb/test_pipeline.py | 11 +- tests/test_scenario/conftest.py | 4 +- tests/test_scenario/test_active_active.py | 214 +++++++++++++++------- 13 files changed, 309 insertions(+), 134 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 8f904c0e4b..5796840e27 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -4,6 +4,8 @@ import pybreaker +DEFAULT_GRACE_PERIOD = 5.0 + class State(Enum): CLOSED = 'closed' OPEN = 'open' diff --git a/redis/multidb/config.py b/redis/multidb/config.py index e44b3eaae7..6e990db328 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -6,23 +6,19 @@ from redis import Redis, ConnectionPool from redis.asyncio import RedisCluster -from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.backoff import ExponentialWithJitterBackoff, 
NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database, Databases -from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector +from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector, DEFAULT_FAILURES_THRESHOLD, \ + DEFAULT_FAILURES_DURATION from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES, \ - DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies -from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY +from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, \ + DEFAULT_FAILOVER_DELAY from redis.retry import Retry -DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_FAILURES_THRESHOLD = 3 -DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL -DEFAULT_FAILURES_DURATION = 2 -DEFAULT_FAILOVER_RETRIES = 3 -DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) DEFAULT_AUTO_FALLBACK_INTERVAL = -1 def default_event_dispatcher() -> EventDispatcherInterface: @@ -119,8 +115,8 @@ class MultiDbConfig: health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[FailoverStrategy] = None - failover_retries: int = DEFAULT_FAILOVER_RETRIES - failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + failover_retries: int = DEFAULT_FAILOVER_ATTEMPTS + failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -166,5 +162,6 @@ def default_health_checks(self) -> List[HealthCheck]: def default_failover_strategy(self) -> FailoverStrategy: return WeightBasedFailoverStrategy( - retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + failover_delay=self.failover_delay, + failover_attempts=self.failover_retries, ) diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py index b49896c34d..3d4e9bf0ba 100644 --- a/redis/multidb/exception.py +++ b/redis/multidb/exception.py @@ -7,4 +7,8 @@ class UnhealthyDatabaseException(Exception): def __init__(self, message, database, original_exception): super().__init__(message) self.database = database - self.original_exception = original_exception \ No newline at end of file + self.original_exception = original_exception + +class TemporaryUnavailableException(Exception): + """Exception raised when all databases in setup is temporary unavailable.""" + pass \ No newline at end of file diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index fd08b77ecd..6f7ac8fd17 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,12 +1,13 @@ +import time from abc import ABC, abstractmethod from redis.data_structure import WeightedList from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException -from redis.retry import Retry -from redis.utils import dummy_fail +from 
redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 class FailoverStrategy(ABC): @@ -27,25 +28,45 @@ class WeightBasedFailoverStrategy(FailoverStrategy): """ def __init__( self, - retry: Retry - ): - self._retry = retry - self._retry.update_supported_errors([NoValidDatabaseException]) + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + ) -> None: self._databases = WeightedList() + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 @property def database(self) -> SyncDatabase: - return self._retry.call_with_retry( - lambda: self._get_active_database(), - lambda _: dummy_fail() - ) + try: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + self._reset() + return database + + raise NoValidDatabaseException('No valid database available for communication') + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) def set_databases(self, databases: Databases) -> None: self._databases = databases - def _get_active_database(self) -> SyncDatabase: - for database, _ in self._databases: - if database.circuit.state == CBState.CLOSED: - return database + def _reset(self) -> None: + self._next_attempt_ts = 0 + self._failover_counter = 0 - raise NoValidDatabaseException('No valid database available for communication') diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index ef4bd35f69..09e9274e8a 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -7,6 +7,8 @@ from redis.multidb.circuit import State as CBState +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 class FailureDetector(ABC): @@ -26,8 +28,8 @@ class CommandFailureDetector(FailureDetector): """ def __init__( self, - threshold: int, - duration: float, + threshold: int = DEFAULT_FAILURES_THRESHOLD, + duration: float = DEFAULT_FAILURES_DURATION, error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index d96e39a1da..86d3983444 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -155,6 +155,8 @@ class HealthCheckPolicies(Enum): HEALTHY_MAJORITY = HealthyMajorityPolicy HEALTHY_ANY = HealthyAnyPolicy +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + class EchoHealthCheck(HealthCheck): """ Health check based on ECHO command. 
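For illustration only (not part of the patch): a minimal caller-side sketch of the failover semantics introduced above. WeightBasedFailoverStrategy now raises TemporaryUnavailableException while failover attempts remain and NoValidDatabaseException once failover_attempts is exceeded; the class, exception, and default values come from this patch, while the polling helper and its interval are an assumed usage pattern.

import time

from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException
from redis.multidb.failover import WeightBasedFailoverStrategy


def wait_for_active_database(strategy: WeightBasedFailoverStrategy, poll_interval: float = 1.0):
    # Return the highest-weighted database whose circuit is CLOSED, polling while the
    # strategy reports a temporary outage and re-raising once it has given up.
    while True:
        try:
            return strategy.database
        except TemporaryUnavailableException:
            time.sleep(poll_interval)  # all circuits are currently OPEN; retry shortly
        except NoValidDatabaseException:
            raise  # failover_attempts exhausted


# Usage (the strategy is normally created and fed databases by MultiDBClient):
# strategy = WeightBasedFailoverStrategy(failover_attempts=10, failover_delay=12)
# strategy.set_databases(databases)
# db = wait_for_active_database(strategy)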
diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index f47da3b174..3b1f7f369b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -6,11 +6,11 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL, DEFAULT_HEALTH_CHECK_POLICY + DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES +from redis.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_POLICY from tests.conftest import mock_ed diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 4cac5c51ec..a818b90eb9 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -6,15 +6,12 @@ from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException -from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failover import WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES -from redis.retry import Retry +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -120,7 +117,8 @@ def test_execute_command_against_correct_db_on_background_health_check_determine mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, + failover_delay=DEFAULT_FAILOVER_DELAY ) client = MultiDBClient(mock_multi_db_config) @@ -157,7 +155,8 @@ def test_execute_command_auto_fallback_to_highest_weight_db( mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.auto_fallback_interval = 0.4 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, + failover_delay=DEFAULT_FAILOVER_DELAY ) client = MultiDBClient(mock_multi_db_config) diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 1ea63a0e14..abed8ec2fa 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,8 +1,8 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - 
DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD + DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 06390c4e2e..0759df88d6 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -1,14 +1,12 @@ +from time import sleep from unittest.mock import PropertyMock import pytest -from redis.backoff import NoBackoff, ExponentialBackoff from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.retry import Retry - class TestWeightBasedFailoverStrategy: @pytest.mark.parametrize( @@ -29,13 +27,12 @@ class TestWeightBasedFailoverStrategy: indirect=True, ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy() failover_strategy.set_databases(databases) assert failover_strategy.database == mock_db1 @@ -51,21 +48,35 @@ def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1, mock_db2): state_mock = PropertyMock( side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] ) type(mock_db.circuit).state = state_mock + failover_attempts = 3 - retry = Retry(ExponentialBackoff(cap=1), 3) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) failover_strategy.set_databases(databases) - assert failover_strategy.database == mock_db + for i in range(failover_attempts + 1): + try: + database = failover_strategy.database + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + sleep(0.11) + pass + assert state_mock.call_count == 4 @pytest.mark.parametrize( @@ -79,22 +90,79 @@ def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock_db, mock_db1, mock_db2): state_mock = PropertyMock( side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] ) type(mock_db.circuit).state = state_mock + failover_attempts = 3 - retry = Retry(ExponentialBackoff(cap=1), 3) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) failover_strategy.set_databases(databases) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert failover_strategy.database + for i in range(failover_attempts + 1): + try: + database = failover_strategy.database + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + sleep(0.11) + pass + + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + failover_attempts = 3 + + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) + failover_strategy.set_databases(databases) + + with pytest.raises(TemporaryUnavailableException, match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + )): + for i in range(failover_attempts + 1): + try: + database = failover_strategy.database + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + if i == failover_attempts: + raise e assert state_mock.call_count == 4 @@ -110,8 +178,7 @@ def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db indirect=True, ) def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy(failover_attempts=0, failover_delay=0) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): assert failover_strategy.database \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 608cc3373d..0176581d20 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -7,11 +7,8 @@ from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF -from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failover import WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.multidb.healthcheck import EchoHealthCheck -from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list def mock_pipe() -> Pipeline: @@ -145,7 +142,8 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, + failover_delay=DEFAULT_FAILOVER_DELAY ) client = MultiDBClient(mock_multi_db_config) @@ -293,7 +291,8 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, + failover_delay=DEFAULT_FAILOVER_DELAY ) client = MultiDBClient(mock_multi_db_config) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 54325e58a5..d9c1c17aff 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -9,9 +9,9 @@ from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient -from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_FAILURES_THRESHOLD +from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL from redis.multidb.event import ActiveDatabaseChanged +from redis.multidb.failure_detector import DEFAULT_FAILURES_THRESHOLD from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY from redis.retry import Retry from tests.test_scenario.fault_injector_client import FaultInjectorClient diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 23a75886da..a3056323a5 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -7,8 +7,13 @@ import pytest from redis import Redis, RedisCluster +from redis.backoff import 
ConstantBackoff from redis.client import Pipeline +from redis.multidb.exception import TemporaryUnavailableException +from redis.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.multidb.healthcheck import LagAwareHealthCheck +from redis.retry import Retry +from redis.utils import dummy_fail from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -47,10 +52,17 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, @@ -59,17 +71,26 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector ) # Client initialized on the first command. - r_multi_db.set('key', 'value') + retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _ : dummy_fail() + ) thread.start() # Execute commands before network failure while not event.is_set(): - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) @pytest.mark.parametrize( @@ -81,9 +102,14 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -101,17 +127,26 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj ) # Client initialized on the first command. 
- r_multi_db.set('key', 'value') + retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _ : dummy_fail() + ) thread.start() # Execute commands before network failure while not event.is_set(): - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) @pytest.mark.parametrize( @@ -123,9 +158,14 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -134,20 +174,7 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault args=(fault_injector_client,config,event) ) - # Client initialized on first pipe execution. - with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - thread.start() - - # Execute pipeline before network failure - while not event.is_set(): + def callback(): with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -156,18 +183,28 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Client initialized on first pipe execution. 
+ retry.call_with_retry( + lambda : callback(), + lambda _ : dummy_fail() + ) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -179,9 +216,14 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -190,20 +232,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject args=(fault_injector_client,config,event) ) - # Client initialized on first pipe execution. - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - thread.start() - - # Execute pipeline before network failure - while not event.is_set(): + def callback(): pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -212,18 +241,29 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Client initialized on first pipe execution. 
+ retry.call_with_retry( + lambda : callback(), + lambda _ : dummy_fail() + ) + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -235,9 +275,14 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -255,17 +300,26 @@ def callback(pipe: Pipeline): pipe.get('{hash}key3') # Client initialized on first transaction execution. - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda : r_multi_db.transaction(callback), + lambda _ : dummy_fail() + ) thread.start() # Execute transaction before network failure while not event.is_set(): - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail() + ) sleep(0.5) # Execute transaction until database failover while not listener.is_changed_flag: - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -277,9 +331,14 @@ def callback(pipe: Pipeline): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -297,18 +356,27 @@ def handler(message): pubsub = r_multi_db.pubsub() # Assign a handler and run in a separate thread. 
- pubsub.subscribe(**{'test-channel': handler}) + retry.call_with_retry( + lambda: pubsub.subscribe(**{'test-channel': handler}), + lambda _: dummy_fail() + ) pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) thread.start() # Execute publish before network failure while not event.is_set(): - r_multi_db.publish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - r_multi_db.publish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) pubsub_thread.stop() @@ -323,9 +391,14 @@ def handler(message): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -343,7 +416,10 @@ def handler(message): pubsub = r_multi_db.pubsub() # Assign a handler and run in a separate thread. - pubsub.ssubscribe(**{'test-channel': handler}) + retry.call_with_retry( + lambda: pubsub.ssubscribe(**{'test-channel': handler}), + lambda _: dummy_fail() + ) pubsub_thread = pubsub.run_in_thread( sleep_time=0.1, daemon=True, @@ -353,12 +429,18 @@ def handler(message): # Execute publish before network failure while not event.is_set(): - r_multi_db.spublish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.spublish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - r_multi_db.spublish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.spublish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) pubsub_thread.stop() From bc598d01b69ef2a303548c8e568df05eb59f94df Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 15 Sep 2025 15:07:41 +0300 Subject: [PATCH 16/20] [Async] Refactored Failover Strategy --- redis/asyncio/multidb/config.py | 25 ++- redis/asyncio/multidb/failover.py | 52 +++-- redis/asyncio/multidb/failure_detector.py | 2 + redis/asyncio/multidb/healthcheck.py | 2 + tests/test_asyncio/test_multidb/conftest.py | 6 +- .../test_asyncio/test_multidb/test_client.py | 10 +- .../test_multidb/test_failover.py | 97 ++++++++-- .../test_multidb/test_pipeline.py | 10 +- .../test_scenario/test_active_active.py | 182 +++++++++++++----- 9 files changed, 272 insertions(+), 114 deletions(-) diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index 1ec3c3498c..eff7c994e4 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -5,23 +5,19 @@ from redis.asyncio import ConnectionPool, Redis, RedisCluster from redis.asyncio.multidb.database import Databases, Database -from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy -from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_DELAY, \ + DEFAULT_FAILOVER_ATTEMPTS +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, 
FailureDetectorAsyncWrapper, \ + DEFAULT_FAILURES_THRESHOLD, DEFAULT_FAILURES_DURATION from redis.asyncio.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies + DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY from redis.asyncio.retry import Retry -from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter, DEFAULT_GRACE_PERIOD from redis.multidb.failure_detector import CommandFailureDetector -DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_FAILURES_THRESHOLD = 3 -DEFAULT_FAILURES_DURATION = 2 -DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL -DEFAULT_FAILOVER_RETRIES = 3 -DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) DEFAULT_AUTO_FALLBACK_INTERVAL = -1 def default_event_dispatcher() -> EventDispatcherInterface: @@ -117,8 +113,8 @@ class MultiDbConfig: health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[AsyncFailoverStrategy] = None - failover_retries: int = DEFAULT_FAILOVER_RETRIES - failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + failover_retries: int = DEFAULT_FAILOVER_ATTEMPTS + failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -166,5 +162,6 @@ def default_health_checks(self) -> List[HealthCheck]: def default_failover_strategy(self) -> AsyncFailoverStrategy: return WeightBasedFailoverStrategy( - retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + failover_delay=self.failover_delay, + failover_attempts=self.failover_retries, ) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index a2ed427e05..62e415d397 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -1,12 +1,13 @@ +import time from abc import abstractmethod, ABC from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.multidb.circuit import State as CBState -from redis.asyncio.retry import Retry from redis.data_structure import WeightedList -from redis.multidb.exception import NoValidDatabaseException -from redis.utils import dummy_fail_async +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 class AsyncFailoverStrategy(ABC): @@ -26,24 +27,43 @@ class WeightBasedFailoverStrategy(AsyncFailoverStrategy): """ def __init__( self, - retry: Retry + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, ): - self._retry = retry - self._retry.update_supported_errors([NoValidDatabaseException]) self._databases = WeightedList() + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 async def database(self) -> AsyncDatabase: - return 
await self._retry.call_with_retry( - lambda: self._get_active_database(), - lambda _: dummy_fail_async() - ) + try: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + self._reset() + return database + + raise NoValidDatabaseException('No valid database available for communication') + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) def set_databases(self, databases: Databases) -> None: self._databases = databases - async def _get_active_database(self) -> AsyncDatabase: - for database, _ in self._databases: - if database.circuit.state == CBState.CLOSED: - return database - - raise NoValidDatabaseException('No valid database available for communication') \ No newline at end of file + def _reset(self) -> None: + self._next_attempt_ts = 0 + self._failover_counter = 0 \ No newline at end of file diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py index 8aa4752924..687e294c6d 100644 --- a/redis/asyncio/multidb/failure_detector.py +++ b/redis/asyncio/multidb/failure_detector.py @@ -2,6 +2,8 @@ from redis.multidb.failure_detector import FailureDetector +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 class AsyncFailureDetector(ABC): diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index b20c0d1500..0b3f57702d 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -157,6 +157,8 @@ class HealthCheckPolicies(Enum): HEALTHY_MAJORITY = HealthyMajorityPolicy HEALTHY_ANY = HealthyAnyPolicy +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + class EchoHealthCheck(HealthCheck): """ Health check based on ECHO command. 
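Because TemporaryUnavailableException is retryable by design, callers of the async MultiDBClient wrap commands in a Retry that catches only that error type. The sketch below mirrors the pattern used by the scenario tests later in this series; client_config is assumed to be a MultiDbConfig instance built elsewhere and is not defined by this patch.

    from redis.asyncio.multidb.client import MultiDBClient
    from redis.asyncio.multidb.failover import (
        DEFAULT_FAILOVER_ATTEMPTS,
        DEFAULT_FAILOVER_DELAY,
    )
    from redis.asyncio.retry import Retry
    from redis.backoff import ConstantBackoff
    from redis.multidb.exception import TemporaryUnavailableException
    from redis.utils import dummy_fail_async

    async def main(client_config):
        # Retry only on the temporary, retryable condition raised during failover.
        retry = Retry(
            supported_errors=(TemporaryUnavailableException,),
            retries=DEFAULT_FAILOVER_ATTEMPTS,
            backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY),
        )

        async with MultiDBClient(client_config) as r_multi_db:
            await retry.call_with_retry(
                lambda: r_multi_db.set('key', 'value'),
                lambda _: dummy_fail_async(),
            )
            value = await retry.call_with_retry(
                lambda: r_multi_db.get('key'),
                lambda _: dummy_fail_async(),
            )
            assert value == 'value'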
diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index a62c30b51a..f5ea12d9b0 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,11 +2,11 @@ import pytest -from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ - DatabaseConfig, DEFAULT_HEALTH_CHECK_POLICY +from redis.asyncio.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_AUTO_FALLBACK_INTERVAL from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES +from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_HEALTH_CHECK_POLICY from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio import Redis diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 875e11d6e5..e2ebb89bca 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -5,12 +5,10 @@ import pytest from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF from redis.asyncio.multidb.database import AsyncDatabase from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck -from redis.asyncio.retry import Retry from redis.event import EventDispatcher, AsyncOnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.exception import NoValidDatabaseException @@ -120,9 +118,7 @@ async def test_execute_command_against_correct_db_on_background_health_check_det mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set('key', 'value') == 'OK1' @@ -158,9 +154,7 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set('key', 'value') == 'OK1' diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index f692c40643..7319ffc9bd 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ 
b/tests/test_asyncio/test_multidb/test_failover.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import PropertyMock import pytest @@ -5,7 +6,7 @@ from redis.backoff import NoBackoff, ExponentialBackoff from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy from redis.asyncio.retry import Retry @@ -30,13 +31,12 @@ class TestAsyncWeightBasedFailoverStrategy: indirect=True, ) async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - strategy = WeightBasedFailoverStrategy(retry=retry) + strategy = WeightBasedFailoverStrategy() strategy.set_databases(databases) assert await strategy.database() == mock_db1 @@ -53,21 +53,35 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + async def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1, mock_db2): state_mock = PropertyMock( side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] ) type(mock_db.circuit).state = state_mock + failover_attempts = 3 - retry = Retry(ExponentialBackoff(cap=1), 3) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) failover_strategy.set_databases(databases) - assert await failover_strategy.database() == mock_db + for i in range(failover_attempts + 1): + try: + database = await failover_strategy.database() + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + await asyncio.sleep(0.11) + pass + assert state_mock.call_count == 4 @pytest.mark.asyncio @@ -82,22 +96,80 @@ async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2 ], indirect=True, ) - async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + async def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock_db, mock_db1, mock_db2): state_mock = PropertyMock( side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] ) type(mock_db.circuit).state = state_mock + failover_attempts = 3 - retry = Retry(ExponentialBackoff(cap=1), 3) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) failover_strategy.set_databases(databases) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database() + for i in range(failover_attempts + 1): + try: + database = await failover_strategy.database() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + await asyncio.sleep(0.11) + pass + + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + failover_attempts = 3 + + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy( + failover_attempts=failover_attempts, + failover_delay=0.1 + ) + failover_strategy.set_databases(databases) + + with pytest.raises(TemporaryUnavailableException, match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + )): + for i in range(failover_attempts + 1): + try: + database = await failover_strategy.database() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + if i == failover_attempts: + raise e assert state_mock.call_count == 4 @@ -114,8 +186,7 @@ async def test_get_valid_database_throws_exception_with_retries(self, mock_db, m indirect=True, ) async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy(failover_attempts=0, failover_delay=0) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): assert await failover_strategy.database() \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index a9e0b8083d..492919cdac 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -6,12 +6,10 @@ from redis.asyncio.client import Pipeline from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy from redis.asyncio.multidb.healthcheck import EchoHealthCheck from redis.asyncio.retry import Retry from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF from tests.test_asyncio.test_multidb.conftest import create_weighted_list @@ -147,9 +145,7 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) @@ -297,9 +293,7 @@ async def test_execute_transaction_against_correct_db_on_background_health_check mock_db2.client.transaction.return_value = ['OK2', 'value'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index dd3deba7b1..2540c8a99d 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -9,7 +9,12 @@ from redis.asyncio import RedisCluster from redis.asyncio.client import Pipeline, Redis from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ConstantBackoff +from redis.multidb.exception import TemporaryUnavailableException +from redis.utils import dummy_fail_async from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -49,24 +54,40 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) async def test_multi_db_client_failover_to_another_db(self, 
r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - await r_multi_db.set('key', 'value') + await retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _: dummy_fail_async() + ) # Execute commands before network failure while not event.is_set(): - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key') , + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -83,24 +104,38 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - await r_multi_db.set('key', 'value') + await retry.call_with_retry( + lambda: r_multi_db.set('key', 'value'), + lambda _: dummy_fail_async() + ) # Execute commands before network failure while not event.is_set(): - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -113,9 +148,24 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + + async def callback(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() @@ -123,26 +173,18 @@ 
async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, # Execute pipeline before network failure while not event.is_set(): - async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -155,9 +197,24 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + + async def callback(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() @@ -165,27 +222,19 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ # Execute pipeline before network failure while not event.is_set(): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) # Execute pipeline until database failover while not listener.is_changed_flag: - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -197,10 +246,16 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(100) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = 
r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + async def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -215,12 +270,18 @@ async def callback(pipe: Pipeline): # Execute transaction before network failure while not event.is_set(): - await r_multi_db.transaction(callback) + await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # Execute transaction until database failover while not listener.is_changed_flag: - await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + assert await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async() + ) == [True, True, True, 'value1', 'value2', 'value3'] await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -232,6 +293,11 @@ async def callback(pipe: Pipeline): @pytest.mark.timeout(60) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) data = json.dumps({'message': 'test'}) messages_count = 0 @@ -247,22 +313,34 @@ async def handler(message): pubsub = await r_multi_db.pubsub() # Assign a handler and run in a separate thread. - await pubsub.subscribe(**{'test-channel': handler}) + await retry.call_with_retry( + lambda: pubsub.subscribe(**{'test-channel': handler}), + lambda _: dummy_fail_async() + ) task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) # Execute publish before network failure while not event.is_set(): - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # After db changed still generates some traffic. for _ in range(5): - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) # A timeout to ensure that an async handler will handle all previous messages. 
await asyncio.sleep(0.1) From d41440aadc88b6f23877ed5cf50432ff18417d2c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 16 Sep 2025 10:53:57 +0300 Subject: [PATCH 17/20] Changed default values according to a design doc --- redis/asyncio/multidb/config.py | 2 +- redis/asyncio/multidb/failure_detector.py | 2 +- redis/asyncio/multidb/healthcheck.py | 3 ++- redis/multidb/circuit.py | 2 +- redis/multidb/config.py | 2 +- redis/multidb/failure_detector.py | 2 +- redis/multidb/healthcheck.py | 4 +++- tests/test_asyncio/test_multidb/test_healthcheck.py | 2 +- tests/test_multidb/test_healthcheck.py | 2 +- 9 files changed, 12 insertions(+), 9 deletions(-) diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index eff7c994e4..421255223e 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -18,7 +18,7 @@ from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter, DEFAULT_GRACE_PERIOD from redis.multidb.failure_detector import CommandFailureDetector -DEFAULT_AUTO_FALLBACK_INTERVAL = -1 +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py index 687e294c6d..cdfcc6ff1e 100644 --- a/redis/asyncio/multidb/failure_detector.py +++ b/redis/asyncio/multidb/failure_detector.py @@ -2,7 +2,7 @@ from redis.multidb.failure_detector import FailureDetector -DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_THRESHOLD = 1000 DEFAULT_FAILURES_DURATION = 2 class AsyncFailureDetector(ABC): diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index 0b3f57702d..fe60d4a4bb 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -14,6 +14,7 @@ DEFAULT_HEALTH_CHECK_PROBES = 3 DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_HEALTH_CHECK_DELAY = 0.5 +DEFAULT_LAG_AWARE_TOLERANCE = 5000 logger = logging.getLogger(__name__) @@ -188,7 +189,7 @@ class LagAwareHealthCheck(HealthCheck): def __init__( self, rest_api_port: int = 9443, - lag_aware_tolerance: int = 100, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 5796840e27..5757f3e6d9 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -4,7 +4,7 @@ import pybreaker -DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_GRACE_PERIOD = 60 class State(Enum): CLOSED = 'closed' diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 6e990db328..db07cb8748 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -19,7 +19,7 @@ DEFAULT_FAILOVER_DELAY from redis.retry import Retry -DEFAULT_AUTO_FALLBACK_INTERVAL = -1 +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 09e9274e8a..6b918b152a 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -7,7 +7,7 @@ from redis.multidb.circuit import State as CBState -DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_THRESHOLD = 1000 DEFAULT_FAILURES_DURATION = 2 class FailureDetector(ABC): diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 86d3983444..48f67b6746 100644 --- a/redis/multidb/healthcheck.py +++ 
b/redis/multidb/healthcheck.py @@ -13,6 +13,8 @@ DEFAULT_HEALTH_CHECK_PROBES = 3 DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_HEALTH_CHECK_DELAY = 0.5 +DEFAULT_LAG_AWARE_TOLERANCE = 5000 + logger = logging.getLogger(__name__) class HealthCheck(ABC): @@ -187,7 +189,7 @@ class LagAwareHealthCheck(HealthCheck): def __init__( self, rest_api_port: int = 9443, - lag_aware_tolerance: int = 100, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index 4924914769..72da0ef737 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -287,7 +287,7 @@ async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, m assert await hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" @pytest.mark.asyncio async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 5f8be6add5..43ad1ac888 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -273,7 +273,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ From 9550299721e0ffdba31c9bbb107faafc73f6b9e3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 16 Sep 2025 11:46:57 +0300 Subject: [PATCH 18/20] [Async] Added Strategy Executor --- redis/asyncio/multidb/client.py | 2 + redis/asyncio/multidb/command_executor.py | 23 ++- redis/asyncio/multidb/config.py | 11 +- redis/asyncio/multidb/failover.py | 71 +++++++-- .../test_multidb/test_failover.py | 137 +++++++----------- 5 files changed, 132 insertions(+), 112 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index a7f591bda0..b9925ea928 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -50,6 +50,8 @@ def __init__(self, config: MultiDbConfig): databases=self._databases, command_retry=self._command_retry, failover_strategy=self._failover_strategy, + failover_attempts=config.failover_attempts, + failover_delay=config.failover_delay, event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index d63b19269d..c08d33a8f3 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -8,7 +8,8 @@ from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ 
ResubscribeOnActiveDatabaseChanged -from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, StrategyExecutor, DefaultStrategyExecutor, \ + DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry @@ -62,8 +63,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def failover_strategy(self) -> AsyncFailoverStrategy: - """Returns failover strategy.""" + def strategy_executor(self) -> StrategyExecutor: + """Returns failover strategy executor.""" pass @property @@ -111,6 +112,8 @@ def __init__( command_retry: Retry, failover_strategy: AsyncFailoverStrategy, event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ @@ -122,6 +125,8 @@ def __init__( command_retry: Retry policy for failed command execution failover_strategy: Strategy for handling database failover event_dispatcher: Interface for dispatching events + failover_attempts: Number of failover attempts + failover_delay: Delay between failover attempts auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ super().__init__(auto_fallback_interval) @@ -132,7 +137,11 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._failover_strategy = failover_strategy + self._strategy_executor = DefaultStrategyExecutor( + failover_strategy, + failover_attempts, + failover_delay + ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None @@ -173,8 +182,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def failover_strategy(self) -> AsyncFailoverStrategy: - return self._failover_strategy + def strategy_executor(self) -> StrategyExecutor: + return self._strategy_executor @property def command_retry(self) -> Retry: @@ -265,7 +274,7 @@ async def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - await self.set_active_database(await self._failover_strategy.database()) + await self.set_active_database(await self._strategy_executor.execute()) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index 421255223e..354bbcf5c7 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -77,8 +77,8 @@ class MultiDbConfig: health_check_probes: Number of attempts to evaluate the health of a database. health_check_delay: Delay between health check attempts. failover_strategy: Optional strategy for handling database failover scenarios. - failover_retries: Number of retries allowed for failover operations. - failover_backoff: Backoff strategy for failover retries. + failover_attempts: Number of retries allowed for failover operations. + failover_delay: Delay between failover attempts. auto_fallback_interval: Time interval to trigger automatic fallback. event_dispatcher: Interface for dispatching events related to database operations. 
@@ -113,7 +113,7 @@ class MultiDbConfig: health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[AsyncFailoverStrategy] = None - failover_retries: int = DEFAULT_FAILOVER_ATTEMPTS + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -161,7 +161,4 @@ def default_health_checks(self) -> List[HealthCheck]: ] def default_failover_strategy(self) -> AsyncFailoverStrategy: - return WeightBasedFailoverStrategy( - failover_delay=self.failover_delay, - failover_attempts=self.failover_retries, - ) \ No newline at end of file + return WeightBasedFailoverStrategy() \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index 62e415d397..a91f527057 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -21,29 +21,81 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass +class StrategyExecutor(ABC): + + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> AsyncFailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + async def execute(self) -> AsyncDatabase: + """Execute the failover strategy.""" + pass + class WeightBasedFailoverStrategy(AsyncFailoverStrategy): """ Failover strategy based on database weights. """ + def __init__(self): + self._databases = WeightedList() + + async def database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + +class DefaultStrategyExecutor(StrategyExecutor): + """ + Executes given failover strategy. + """ def __init__( self, + strategy: AsyncFailoverStrategy, failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, failover_delay: float = DEFAULT_FAILOVER_DELAY, ): - self._databases = WeightedList() + self._strategy = strategy self._failover_attempts = failover_attempts self._failover_delay = failover_delay self._next_attempt_ts: int = 0 self._failover_counter: int = 0 - async def database(self) -> AsyncDatabase: - try: - for database, _ in self._databases: - if database.circuit.state == CBState.CLOSED: - self._reset() - return database + @property + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> AsyncFailoverStrategy: + return self._strategy - raise NoValidDatabaseException('No valid database available for communication') + async def execute(self) -> AsyncDatabase: + try: + database = await self._strategy.database() + self._reset() + return database except NoValidDatabaseException as e: if self._next_attempt_ts == 0: self._next_attempt_ts = time.time() + self._failover_delay @@ -61,9 +113,6 @@ async def database(self) -> AsyncDatabase: "This is a temporary condition - please retry the operation." 
) - def set_databases(self, databases: Databases) -> None: - self._databases = databases - def _reset(self) -> None: self._next_attempt_ts = 0 self._failover_counter = 0 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index 7319ffc9bd..02aaf8e461 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -1,14 +1,11 @@ import asyncio -from unittest.mock import PropertyMock import pytest -from redis.backoff import NoBackoff, ExponentialBackoff from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy -from redis.asyncio.retry import Retry +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, DefaultStrategyExecutor class TestAsyncWeightBasedFailoverStrategy: @@ -53,26 +50,38 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - async def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] - ) - type(mock_db.circuit).state = state_mock - failover_attempts = 3 + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + failover_strategy = WeightBasedFailoverStrategy() - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() + +class TestDefaultStrategyExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) for i in range(failover_attempts + 1): try: - database = await failover_strategy.database() + database = await executor.execute() assert database == mock_db except TemporaryUnavailableException as e: assert e.args[0] == ( @@ -82,41 +91,27 @@ async def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1 await asyncio.sleep(0.11) pass - assert state_mock.call_count == 4 + assert mock_fs.database.call_count == 4 @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - async def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] - ) - type(mock_db.circuit).state = state_mock + async def test_execute_throws_exception_on_attempts_exceed(self, 
mock_fs): failover_attempts = 3 - - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + with pytest.raises(NoValidDatabaseException): for i in range(failover_attempts + 1): try: - database = await failover_strategy.database() + await executor.execute() except TemporaryUnavailableException as e: assert e.args[0] == ( "No database connections currently available. " @@ -125,36 +120,22 @@ async def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock await asyncio.sleep(0.11) pass - assert state_mock.call_count == 4 + assert mock_fs.database.call_count == 4 @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - async def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] - ) - type(mock_db.circuit).state = state_mock + async def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): failover_attempts = 3 - - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) with pytest.raises(TemporaryUnavailableException, match=( "No database connections currently available. " @@ -162,7 +143,7 @@ async def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_d )): for i in range(failover_attempts + 1): try: - database = await failover_strategy.database() + await executor.execute() except TemporaryUnavailableException as e: assert e.args[0] == ( "No database connections currently available. 
" @@ -171,22 +152,4 @@ async def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_d if i == failover_attempts: raise e - assert state_mock.call_count == 4 - - @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - failover_strategy = WeightBasedFailoverStrategy(failover_attempts=0, failover_delay=0) - - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database() \ No newline at end of file + assert mock_fs.database.call_count == 4 \ No newline at end of file From 0b4b4f095d29f3085eca5ec2acdb5f2cd5246c57 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 16 Sep 2025 12:46:21 +0300 Subject: [PATCH 19/20] [Sync] Added Strategy Executor --- redis/multidb/client.py | 2 + redis/multidb/command_executor.py | 23 +++- redis/multidb/config.py | 13 +- redis/multidb/failover.py | 73 +++++++++-- tests/test_multidb/test_client.py | 12 +- tests/test_multidb/test_command_executor.py | 15 +-- tests/test_multidb/test_failover.py | 136 ++++++++------------ tests/test_multidb/test_pipeline.py | 26 ++-- 8 files changed, 153 insertions(+), 147 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 748cef4855..19f846bd29 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -50,6 +50,8 @@ def __init__(self, config: MultiDbConfig): databases=self._databases, command_retry=self._command_retry, failover_strategy=self._failover_strategy, + failover_attempts=config.failover_attempts, + failover_delay=config.failover_delay, event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 364c0a07ea..92f1a20fee 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -8,7 +8,8 @@ from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged -from redis.multidb.failover import FailoverStrategy +from redis.multidb.failover import FailoverStrategy, StrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ + DEFAULT_FAILOVER_DELAY, DefaultStrategyExecutor from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -94,8 +95,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def failover_strategy(self) -> FailoverStrategy: - """Returns failover strategy.""" + def strategy_executor(self) -> StrategyExecutor: + """Returns failover strategy executor.""" pass @property @@ -142,6 +143,8 @@ def __init__( command_retry: Retry, failover_strategy: FailoverStrategy, event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ @@ -153,6 +156,8 @@ def __init__( command_retry: Retry policy for failed command execution failover_strategy: Strategy for handling database failover event_dispatcher: Interface for dispatching events + failover_attempts: 
Number of failover attempts + failover_delay: Delay between failover attempts auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ super().__init__(auto_fallback_interval) @@ -163,7 +168,11 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._failover_strategy = failover_strategy + self._strategy_executor = DefaultStrategyExecutor( + failover_strategy, + failover_attempts, + failover_delay + ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None @@ -209,8 +218,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def failover_strategy(self) -> FailoverStrategy: - return self._failover_strategy + def strategy_executor(self) -> StrategyExecutor: + return self._strategy_executor def execute_command(self, *args, **options): def callback(): @@ -285,7 +294,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self.active_database = self._failover_strategy.database + self.active_database = self._strategy_executor.execute() self._schedule_next_fallback() def _setup_event_dispatcher(self): diff --git a/redis/multidb/config.py b/redis/multidb/config.py index db07cb8748..ff9872ffd4 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -5,7 +5,7 @@ from typing_extensions import Optional from redis import Redis, ConnectionPool -from redis.asyncio import RedisCluster +from redis import RedisCluster from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface @@ -79,8 +79,8 @@ class MultiDbConfig: health_check_delay: Delay between health check attempts. health_check_policy: Policy for determining database health based on health checks. failover_strategy: Optional strategy for handling database failover scenarios. - failover_retries: Number of retries allowed for failover operations. - failover_backoff: Backoff strategy for failover retries. + failover_attempts: Number of retries allowed for failover operations. + failover_delay: Delay between failover attempts. auto_fallback_interval: Time interval to trigger automatic fallback. event_dispatcher: Interface for dispatching events related to database operations. 
@@ -115,7 +115,7 @@ class MultiDbConfig: health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[FailoverStrategy] = None - failover_retries: int = DEFAULT_FAILOVER_ATTEMPTS + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -161,7 +161,4 @@ def default_health_checks(self) -> List[HealthCheck]: ] def default_failover_strategy(self) -> FailoverStrategy: - return WeightBasedFailoverStrategy( - failover_delay=self.failover_delay, - failover_attempts=self.failover_retries, - ) + return WeightBasedFailoverStrategy() diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 6f7ac8fd17..7faf6c8a5f 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -11,7 +11,6 @@ class FailoverStrategy(ABC): - @property @abstractmethod def database(self) -> SyncDatabase: """Select the database according to the strategy.""" @@ -22,30 +21,81 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass +class StrategyExecutor(ABC): + + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> FailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + def execute(self) -> SyncDatabase: + """Execute the failover strategy.""" + pass + class WeightBasedFailoverStrategy(FailoverStrategy): """ Failover strategy based on database weights. """ + def __init__(self) -> None: + self._databases = WeightedList() + + def database(self) -> SyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + +class DefaultStrategyExecutor(StrategyExecutor): + """ + Executes given failover strategy. 
+ """ def __init__( self, + strategy: FailoverStrategy, failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, failover_delay: float = DEFAULT_FAILOVER_DELAY, - ) -> None: - self._databases = WeightedList() + ): + self._strategy = strategy self._failover_attempts = failover_attempts self._failover_delay = failover_delay self._next_attempt_ts: int = 0 self._failover_counter: int = 0 @property - def database(self) -> SyncDatabase: - try: - for database, _ in self._databases: - if database.circuit.state == CBState.CLOSED: - self._reset() - return database + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> FailoverStrategy: + return self._strategy - raise NoValidDatabaseException('No valid database available for communication') + def execute(self) -> SyncDatabase: + try: + database = self._strategy.database() + self._reset() + return database except NoValidDatabaseException as e: if self._next_attempt_ts == 0: self._next_attempt_ts = time.time() + self._failover_delay @@ -63,9 +113,6 @@ def database(self) -> SyncDatabase: "This is a temporary condition - please retry the operation." ) - def set_databases(self, databases: Databases) -> None: - self._databases = databases - def _reset(self) -> None: self._next_attempt_ts = 0 self._failover_counter = 0 diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index a818b90eb9..5e710f23c2 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -9,7 +9,7 @@ from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException -from redis.multidb.failover import WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -116,10 +116,7 @@ def test_execute_command_against_correct_db_on_background_health_check_determine mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, - failover_delay=DEFAULT_FAILOVER_DELAY - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' @@ -154,10 +151,7 @@ def test_execute_command_auto_fallback_to_highest_weight_db( mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.auto_fallback_interval = 0.4 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, - failover_delay=DEFAULT_FAILOVER_DELAY - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' diff --git a/tests/test_multidb/test_command_executor.py 
b/tests/test_multidb/test_command_executor.py index 675f9d442f..044fef0f8c 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -61,8 +61,7 @@ def test_execute_command_automatically_select_active_database( ): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2] databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -78,7 +77,7 @@ def test_execute_command_automatically_select_active_database( assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 - assert mock_selector.call_count == 2 + assert mock_fs.database.call_count == 2 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -96,8 +95,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( ): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -119,7 +117,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert mock_ed.register_listeners.call_count == 1 - assert mock_selector.call_count == 3 + assert mock_fs.database.call_count == 3 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -137,8 +135,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( ): mock_db1.client.execute_command.side_effect = ['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1'] mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] threshold = 3 fd = CommandFailureDetector(threshold, 1) ed = EventDispatcher() @@ -157,4 +154,4 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert executor.execute_command('SET', 'key', 'value') == 'OK1' - assert mock_selector.call_count == 3 \ No newline at end of file + assert mock_fs.database.call_count == 3 \ No newline at end of file diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 0759df88d6..88a19b6a2e 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -1,12 +1,12 @@ from time import sleep -from unittest.mock import PropertyMock import pytest from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failover import WeightBasedFailoverStrategy, DefaultStrategyExecutor + class TestWeightBasedFailoverStrategy: @pytest.mark.parametrize( @@ -35,7 +35,7 @@ def test_get_valid_database(self, mock_db, 
mock_db1, mock_db2): failover_strategy = WeightBasedFailoverStrategy() failover_strategy.set_databases(databases) - assert failover_strategy.database == mock_db1 + assert failover_strategy.database() == mock_db1 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -48,26 +48,37 @@ def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] - ) - type(mock_db.circuit).state = state_mock - failover_attempts = 3 + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + failover_strategy = WeightBasedFailoverStrategy() - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database() + +class TestDefaultStrategyExecutor: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) for i in range(failover_attempts + 1): try: - database = failover_strategy.database + database = executor.execute() assert database == mock_db except TemporaryUnavailableException as e: assert e.args[0] == ( @@ -77,40 +88,26 @@ def test_get_valid_database_with_failover_attempts(self, mock_db, mock_db1, mock sleep(0.11) pass - assert state_mock.call_count == 4 + assert mock_fs.database.call_count == 4 - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] - ) - type(mock_db.circuit).state = state_mock + def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): failover_attempts = 3 - - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + with pytest.raises(NoValidDatabaseException): for i in range(failover_attempts + 1): try: - database = failover_strategy.database + executor.execute() except TemporaryUnavailableException as e: assert e.args[0] == ( "No database connections 
currently available. " @@ -119,35 +116,21 @@ def test_get_valid_database_throws_exception_on_attempts_exceed(self, mock_db, m sleep(0.11) pass - assert state_mock.call_count == 4 + assert mock_fs.database.call_count == 4 - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] - ) - type(mock_db.circuit).state = state_mock + def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): failover_attempts = 3 - - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy( + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - failover_strategy.set_databases(databases) with pytest.raises(TemporaryUnavailableException, match=( "No database connections currently available. " @@ -155,7 +138,7 @@ def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(s )): for i in range(failover_attempts + 1): try: - database = failover_strategy.database + executor.execute() except TemporaryUnavailableException as e: assert e.args[0] == ( "No database connections currently available. 
" @@ -164,21 +147,4 @@ def test_get_valid_database_throws_exception_on_attempts_does_not_exceed_delay(s if i == failover_attempts: raise e - assert state_mock.call_count == 4 - - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - failover_strategy = WeightBasedFailoverStrategy(failover_attempts=0, failover_delay=0) - - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert failover_strategy.database \ No newline at end of file + assert mock_fs.database.call_count == 4 \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 0176581d20..54f6a4df17 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -140,11 +140,8 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin pipe2.execute.return_value = ['OK2', 'value'] mock_db2.client.pipeline.return_value = pipe2 - mock_multi_db_config.health_check_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, - failover_delay=DEFAULT_FAILOVER_DELAY - ) + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) @@ -154,7 +151,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK1', 'value'] - sleep(0.3) + sleep(0.15) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -162,7 +159,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK2', 'value'] - sleep(0.2) + sleep(0.1) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -170,7 +167,7 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin assert pipe.execute() == ['OK', 'value'] - sleep(0.2) + sleep(0.1) with client.pipeline() as pipe: pipe.set('key1', 'value') @@ -289,11 +286,8 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_db1.client.transaction.return_value = ['OK1', 'value'] mock_db2.client.transaction.return_value = ['OK2', 'value'] - mock_multi_db_config.health_check_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - failover_attempts=DEFAULT_FAILOVER_ATTEMPTS, - failover_delay=DEFAULT_FAILOVER_DELAY - ) + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) @@ -302,9 +296,9 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value'] - sleep(0.3) + sleep(0.15) assert client.transaction(callback) == ['OK2', 'value'] - sleep(0.2) + sleep(0.1) assert client.transaction(callback) == ['OK', 'value'] - sleep(0.2) + sleep(0.1) assert client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file From 4dd4f0ada4aa58de4801d97b6d2e4dbee728b803 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 18 Sep 2025 10:50:57 +0300 Subject: [PATCH 20/20] Apply comments --- redis/asyncio/multidb/command_executor.py | 12 ++++++------ 
redis/asyncio/multidb/failover.py | 4 ++-- redis/asyncio/multidb/healthcheck.py | 16 +++++++++------- redis/multidb/command_executor.py | 14 +++++++------- redis/multidb/exception.py | 2 +- redis/multidb/failover.py | 4 ++-- redis/multidb/healthcheck.py | 18 ++++++++++-------- .../test_asyncio/test_multidb/test_failover.py | 8 ++++---- tests/test_multidb/test_failover.py | 8 ++++---- 9 files changed, 45 insertions(+), 41 deletions(-) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index c08d33a8f3..7e622d6260 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -8,7 +8,7 @@ from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ ResubscribeOnActiveDatabaseChanged -from redis.asyncio.multidb.failover import AsyncFailoverStrategy, StrategyExecutor, DefaultStrategyExecutor, \ +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, FailoverStrategyExecutor, DefaultFailoverStrategyExecutor, \ DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.multidb.circuit import State as CBState @@ -63,7 +63,7 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def strategy_executor(self) -> StrategyExecutor: + def failover_strategy_executor(self) -> FailoverStrategyExecutor: """Returns failover strategy executor.""" pass @@ -137,7 +137,7 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._strategy_executor = DefaultStrategyExecutor( + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( failover_strategy, failover_attempts, failover_delay @@ -182,8 +182,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def strategy_executor(self) -> StrategyExecutor: - return self._strategy_executor + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return self._failover_strategy_executor @property def command_retry(self) -> Retry: @@ -274,7 +274,7 @@ async def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - await self.set_active_database(await self._strategy_executor.execute()) + await self.set_active_database(await self._failover_strategy_executor.execute()) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index a91f527057..997b7941c4 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -21,7 +21,7 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass -class StrategyExecutor(ABC): +class FailoverStrategyExecutor(ABC): @property @abstractmethod @@ -63,7 +63,7 @@ async def database(self) -> AsyncDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases -class DefaultStrategyExecutor(StrategyExecutor): +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): """ Executes given failover strategy. 
""" diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index fe60d4a4bb..b5bf695380 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -4,6 +4,8 @@ from enum import Enum from typing import Optional, Tuple, Union, List +from pygments.lexers.julia import allowed_variable + from redis.asyncio import Redis from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT from redis.asyncio.retry import Retry @@ -97,19 +99,19 @@ def __init__(self, health_check_probes: int, health_check_delay: float): async def execute(self, health_checks: List[HealthCheck], database) -> bool: for health_check in health_checks: if self.health_check_probes % 2 == 0: - unsuccessful_probes = self.health_check_probes / 2 + allowed_unsuccessful_probes = self.health_check_probes / 2 else: - unsuccessful_probes = (self.health_check_probes + 1) / 2 + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 for attempt in range(self.health_check_probes): try: if not await health_check.check_health(database): - unsuccessful_probes -= 1 - if unsuccessful_probes <= 0: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: return False except Exception as e: - unsuccessful_probes -= 1 - if unsuccessful_probes <= 0: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: raise UnhealthyDatabaseException( f"Unhealthy database", database, e ) @@ -174,7 +176,7 @@ async def check_health(self, database) -> bool: # For a cluster checks if all nodes are healthy. all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = await node.execute_command("ECHO" ,"healthcheck") + actual_message = await node.execute_command("ECHO", "healthcheck") if actual_message not in expected_message: return False diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 92f1a20fee..7ca7d2ec52 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -8,8 +8,8 @@ from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged -from redis.multidb.failover import FailoverStrategy, StrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ - DEFAULT_FAILOVER_DELAY, DefaultStrategyExecutor +from redis.multidb.failover import FailoverStrategy, FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ + DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -95,7 +95,7 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def strategy_executor(self) -> StrategyExecutor: + def failover_strategy_executor(self) -> FailoverStrategyExecutor: """Returns failover strategy executor.""" pass @@ -168,7 +168,7 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._strategy_executor = DefaultStrategyExecutor( + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( failover_strategy, failover_attempts, failover_delay @@ -218,8 +218,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def strategy_executor(self) -> StrategyExecutor: - return self._strategy_executor + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return 
self._failover_strategy_executor def execute_command(self, *args, **options): def callback(): @@ -294,7 +294,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self.active_database = self._strategy_executor.execute() + self.active_database = self._failover_strategy_executor.execute() self._schedule_next_fallback() def _setup_event_dispatcher(self): diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py index 3d4e9bf0ba..f54632cae7 100644 --- a/redis/multidb/exception.py +++ b/redis/multidb/exception.py @@ -10,5 +10,5 @@ def __init__(self, message, database, original_exception): self.original_exception = original_exception class TemporaryUnavailableException(Exception): - """Exception raised when all databases in setup is temporary unavailable.""" + """Exception raised when all databases in setup are temporary unavailable.""" pass \ No newline at end of file diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 7faf6c8a5f..fbbd254252 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -21,7 +21,7 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass -class StrategyExecutor(ABC): +class FailoverStrategyExecutor(ABC): @property @abstractmethod @@ -63,7 +63,7 @@ def database(self) -> SyncDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases -class DefaultStrategyExecutor(StrategyExecutor): +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): """ Executes given failover strategy. """ diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 48f67b6746..fcfd7e44a8 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -4,6 +4,8 @@ from time import sleep from typing import Optional, Tuple, Union, List +from pygments.lexers.julia import allowed_variable + from redis import Redis from redis.backoff import NoBackoff from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient @@ -96,19 +98,19 @@ def __init__(self, health_check_probes: int, health_check_delay: float): def execute(self, health_checks: List[HealthCheck], database) -> bool: for health_check in health_checks: if self.health_check_probes % 2 == 0: - unsuccessful_probes = self.health_check_probes / 2 + allowed_unsuccessful_probes = self.health_check_probes / 2 else: - unsuccessful_probes = (self.health_check_probes + 1) / 2 + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 for attempt in range(self.health_check_probes): try: if not health_check.check_health(database): - unsuccessful_probes -= 1 - if unsuccessful_probes <= 0: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: return False except Exception as e: - unsuccessful_probes -= 1 - if unsuccessful_probes <= 0: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: raise UnhealthyDatabaseException( f"Unhealthy database", database, e ) @@ -167,13 +169,13 @@ def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): - actual_message = database.client.execute_command("ECHO" ,"healthcheck") + actual_message = database.client.execute_command("ECHO", "healthcheck") return actual_message in expected_message else: # For a cluster checks if all nodes are healthy. 
all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = node.redis_connection.execute_command("ECHO" ,"healthcheck") + actual_message = node.redis_connection.execute_command("ECHO", "healthcheck") if actual_message not in expected_message: return False diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index 02aaf8e461..0275969d03 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -5,7 +5,7 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, DefaultStrategyExecutor +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor class TestAsyncWeightBasedFailoverStrategy: @@ -73,7 +73,7 @@ async def test_execute_returns_valid_database_with_failover_attempts(self, mock_ NoValidDatabaseException, mock_db ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 @@ -102,7 +102,7 @@ async def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): NoValidDatabaseException, NoValidDatabaseException ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 @@ -131,7 +131,7 @@ async def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, NoValidDatabaseException, NoValidDatabaseException ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 88a19b6a2e..6ae6a9610c 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -5,7 +5,7 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.multidb.failover import WeightBasedFailoverStrategy, DefaultStrategyExecutor +from redis.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor class TestWeightBasedFailoverStrategy: @@ -70,7 +70,7 @@ def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mo NoValidDatabaseException, mock_db ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 @@ -98,7 +98,7 @@ def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): NoValidDatabaseException, NoValidDatabaseException ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 @@ -126,7 +126,7 @@ def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_f NoValidDatabaseException, NoValidDatabaseException ] - executor = DefaultStrategyExecutor( + executor = DefaultFailoverStrategyExecutor( mock_fs, failover_attempts=failover_attempts, failover_delay=0.1
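
For reference, a minimal sketch of how the failover strategy executor introduced in patches 18-20 is wired together on the sync side. The `databases` weighted list is assumed to be built elsewhere (the client normally populates it from MultiDbConfig); everything else uses only names introduced by these patches.

    from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException
    from redis.multidb.failover import (
        DefaultFailoverStrategyExecutor,
        WeightBasedFailoverStrategy,
    )

    # The strategy returns the highest-weight database whose circuit breaker is CLOSED.
    strategy = WeightBasedFailoverStrategy()
    strategy.set_databases(databases)  # `databases` is an assumed, pre-populated WeightedList of databases

    # The executor wraps the strategy with bounded retry semantics.
    executor = DefaultFailoverStrategyExecutor(
        strategy,
        failover_attempts=3,
        failover_delay=0.5,
    )

    try:
        active = executor.execute()
    except TemporaryUnavailableException:
        # No healthy database yet, but the attempt budget is not exhausted - retry after failover_delay.
        pass
    except NoValidDatabaseException:
        # failover_attempts was exceeded within the delay window - treat as a hard failure.
        raise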