diff --git a/redis/asyncio/multidb/__init__.py b/redis/asyncio/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py new file mode 100644 index 0000000000..73eafd9026 --- /dev/null +++ b/redis/asyncio/multidb/client.py @@ -0,0 +1,237 @@ +import asyncio +from typing import Callable, Optional, Coroutine, Any + +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.background import BackgroundScheduler +from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands +from redis.multidb.exception import NoValidDatabaseException + + +class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors([ConnectionRefusedError]) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = asyncio.Lock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + + async def initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + async def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + await self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + )) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + await self.command_executor.set_active_database(database) + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + async def set_active_database(self, database: AsyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + await self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + await self.command_executor.set_active_database(database) + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + async def add_database(self, database: AsyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + await self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + await self._change_active_database(database, highest_weighted_db) + + async def _change_active_database(self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(new_database) + + async def remove_database(self, database: AsyncDatabase): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(highest_weighted_db) + + async def update_database_weight(self, database: AsyncDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + await self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: AsyncFailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + async def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + async with self._hc_lock: + self._health_checks.append(healthcheck) + + async def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_command(*args, **options) + + async def _check_databases_health( + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + for database, _ in self._databases: + async with self._hc_lock: + await self._check_db_health(database, on_error) + + async def _check_db_health( + self, + database: AsyncDatabase, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ) -> None: + """ + Runs health checks on the given database until first failure. + """ + is_healthy = True + + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = await health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except Exception as e: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + if on_error: + await on_error(e) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + loop = asyncio.get_running_loop() + + if new_state == CBState.HALF_OPEN: + asyncio.create_task(self._check_db_health(circuit.database)) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py new file mode 100644 index 0000000000..af10a00988 --- /dev/null +++ b/redis/asyncio/multidb/command_executor.py @@ -0,0 +1,265 @@ +from abc import abstractmethod +from datetime import datetime +from typing import List, Optional, Callable, Any + +from redis.asyncio.client import PubSub, Pipeline +from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ + ResubscribeOnActiveDatabaseChanged +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent +from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL + + +class AsyncCommandExecutor(CommandExecutor): + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[AsyncFailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def active_database(self) -> Optional[AsyncDatabase]: + """Returns currently active database.""" + pass + + @abstractmethod + async def set_active_database(self, database: AsyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> AsyncFailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + async def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + async def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + async def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + + +class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): + def __init__( + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. + + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[AsyncFailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def active_database(self) -> Optional[AsyncDatabase]: + return self._active_database + + async def set_active_database(self, database: AsyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + await self._event_dispatcher.dispatch_async( + AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy(self) -> AsyncFailoverStrategy: + return self._failover_strategy + + @property + def command_retry(self) -> Retry: + return self._command_retry + + async def pubsub(self, **kwargs): + async def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return await self._execute_with_failure_detection(callback) + + async def execute_command(self, *args, **options): + async def callback(): + return await self._active_database.client.execute_command(*args, **options) + + return await self._execute_with_failure_detection(callback, args) + + async def execute_pipeline(self, command_stack: tuple): + async def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + await pipe.execute_command(*command, **options) + + return await pipe.execute() + + return await self._execute_with_failure_detection(callback, command_stack) + + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def callback(): + return await self._active_database.client.transaction(transaction, *watches, **options) + + return await self._execute_with_failure_detection(callback) + + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def callback(): + method = getattr(self.active_pubsub, method_name) + return await method(*args, **kwargs) + + return await self._execute_with_failure_detection(callback, *args) + + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def callback(): + return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs) + + return await self._execute_with_failure_detection(callback) + + async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a commands execution callback with failure detection. + """ + async def wrapper(): + # On each retry we need to check active database as it might change. + await self._check_active_database() + return await callback() + + return await self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + async def _check_active_database(self): + """ + Checks if active a database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + await self.set_active_database(await self._failover_strategy.database()) + self._schedule_next_fallback() + + async def _on_command_fail(self, error, *args): + await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + + def _setup_event_dispatcher(self): + """ + Registers necessary listeners. + """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners({ + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [resubscribe_listener], + }) \ No newline at end of file diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py new file mode 100644 index 0000000000..b5f4a0658d --- /dev/null +++ b/redis/asyncio/multidb/config.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Type, Union + +import pybreaker + +from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio.multidb.database import Databases, Database +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper +from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ + EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.failure_detector import CommandFailureDetector + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + failure_threshold: Threshold for determining database failure. + failures_interval: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_retries: Number of retry attempts for performing health checks. + health_check_backoff: Backoff strategy for health check retries. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_retries: Number of retries allowed for failover operations. + failover_backoff: Backoff strategy for failover retries. + auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. + """ + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[AsyncFailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[AsyncFailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. + database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + + if database_config.from_url: + client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + elif database_config.from_pool: + database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) + client = self.client_class.from_pool(connection_pool=database_config.from_pool) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), + database_config.weight + ) + + return databases + + def default_failure_detectors(self) -> List[AsyncFailureDetector]: + return [ + FailureDetectorAsyncWrapper( + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval) + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> AsyncFailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) \ No newline at end of file diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py new file mode 100644 index 0000000000..6afbbbf5ea --- /dev/null +++ b/redis/asyncio/multidb/database.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Union, Optional + +from redis.asyncio import Redis, RedisCluster +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.database import AbstractDatabase, BaseDatabase +from redis.typing import Number + + +class AsyncDatabase(AbstractDatabase): + """Database with an underlying asynchronous redis client.""" + @property + @abstractmethod + def client(self) -> Union[Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AsyncDatabase, Number]] + +class Database(BaseDatabase, AsyncDatabase): + def __init__( + self, + client: Union[Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit + diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py new file mode 100644 index 0000000000..ea5534ce86 --- /dev/null +++ b/redis/asyncio/multidb/event.py @@ -0,0 +1,65 @@ +from typing import List + +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent + + +class AsyncActiveDatabaseChanged: + """ + Event fired when an async active database has been changed. + """ + def __init__( + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AsyncDatabase: + return self._old_database + + @property + def new_database(self) -> AsyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Re-subscribe the currently active pub / sub to a new active database. + """ + async def listen(self, event: AsyncActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + await new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + await old_pubsub.close() + +class RegisterCommandFailure(AsyncEventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. + """ + def __init__(self, failure_detectors: List[AsyncFailureDetector]): + self._failure_detectors = failure_detectors + + async def listen(self, event: AsyncOnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + await failure_detector.register_failure(event.exception, event.commands) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py new file mode 100644 index 0000000000..a2ed427e05 --- /dev/null +++ b/redis/asyncio/multidb/failover.py @@ -0,0 +1,49 @@ +from abc import abstractmethod, ABC + +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.data_structure import WeightedList +from redis.multidb.exception import NoValidDatabaseException +from redis.utils import dummy_fail_async + + +class AsyncFailoverStrategy(ABC): + + @abstractmethod + async def database(self) -> AsyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(AsyncFailoverStrategy): + """ + Failover strategy based on database weights. + """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + async def database(self) -> AsyncDatabase: + return await self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: dummy_fail_async() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + async def _get_active_database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') \ No newline at end of file diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py new file mode 100644 index 0000000000..8aa4752924 --- /dev/null +++ b/redis/asyncio/multidb/failure_detector.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from redis.multidb.failure_detector import FailureDetector + + +class AsyncFailureDetector(ABC): + + @abstractmethod + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + +class FailureDetectorAsyncWrapper(AsyncFailureDetector): + """ + Async wrapper for the failure detector. + """ + def __init__(self, failure_detector: FailureDetector) -> None: + self._failure_detector = failure_detector + + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + self._failure_detector.register_failure(exception, cmd) + + def set_command_executor(self, command_executor) -> None: + self._failure_detector.set_command_executor(command_executor) \ No newline at end of file diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py new file mode 100644 index 0000000000..ccaf285ade --- /dev/null +++ b/redis/asyncio/multidb/healthcheck.py @@ -0,0 +1,75 @@ +import logging +from abc import ABC, abstractmethod + +from redis.asyncio import Redis +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff +from redis.utils import dummy_fail_async + +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + async def check_health(self, database) -> bool: + pass + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + """ + Check database healthiness by sending an echo request. + """ + super().__init__( + retry=retry, + ) + async def check_health(self, database) -> bool: + return await self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: dummy_fail_async() + ) + + async def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = await database.client.execute_command("ECHO", "healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = await node.redis_connection.execute_command("ECHO", "healthcheck") + + if actual_message not in expected_message: + return False + + return True \ No newline at end of file diff --git a/redis/background.py b/redis/background.py index 6466649859..ce43cbfa7a 100644 --- a/redis/background.py +++ b/redis/background.py @@ -1,6 +1,7 @@ import asyncio import threading -from typing import Callable +from typing import Callable, Coroutine, Any + class BackgroundScheduler: """ @@ -45,7 +46,35 @@ def run_recurring( ) thread.start() - def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + async def run_recurring_async( + self, + interval: float, + coro: Callable[..., Coroutine[Any, Any, Any]], + *args + ): + """ + Runs recurring coroutine with given interval in seconds in the current event loop. + To be used only from an async context. No additional threads are created. + """ + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, coro, *args) + + def tick(): + # Schedule the coroutine + wrapped() + # Schedule next tick + self._next_timer = loop.call_later(interval, tick) + + # Schedule first tick + self._next_timer = loop.call_later(interval, tick) + + def _call_later( + self, + loop: asyncio.AbstractEventLoop, + delay: float, + callback: Callable, + *args + ): self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( @@ -86,4 +115,21 @@ def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon """ asyncio.set_event_loop(event_loop) event_loop.call_soon(call_soon_cb, event_loop, *args) - event_loop.run_forever() \ No newline at end of file + event_loop.run_forever() + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. + """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 1fa66f0587..de38e1a069 100644 --- a/redis/event.py +++ b/redis/event.py @@ -43,7 +43,10 @@ async def dispatch_async(self, event: object): pass @abstractmethod - def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): """Register additional listeners.""" pass @@ -99,13 +102,16 @@ def dispatch(self, event: object): listener.listen(event) async def dispatch_async(self, event: object): - with self._async_lock: + async with self._async_lock: listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: await listener.listen(event) - def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + event_listeners: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): with self._lock: for event_type in event_listeners: if event_type in self._event_listeners_mapping: @@ -271,6 +277,9 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception +class AsyncOnCommandsFailEvent(OnCommandsFailEvent): + pass + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 221dc556a3..8f904c0e4b 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,6 +45,11 @@ def database(self, database): """Set database associated with this circuit.""" pass + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + class BaseCircuitBreaker(CircuitBreaker): """ Base implementation of Circuit Breaker interface. @@ -82,12 +87,8 @@ def database(self): def database(self, database): self._database = database -class SyncCircuitBreaker(CircuitBreaker): - """ - Synchronous implementation of Circuit Breaker interface. - """ @abstractmethod - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -95,7 +96,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[SyncCircuitBreaker, State, State], None], + cb: Callable[[CircuitBreaker, State, State], None], database, ): """ @@ -116,7 +117,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) -class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): +class PBCircuitBreakerAdapter(BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -129,6 +130,6 @@ def __init__(self, cb: pybreaker.CircuitBreaker): """ super().__init__(cb) - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8a0e006977..71e079346a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -5,7 +5,7 @@ from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector @@ -244,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -252,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_sta if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: SyncCircuitBreaker): +def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/config.py b/redis/multidb/config.py index a966ec329a..fc349ed04b 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[SyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> SyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 75a662d904..9c2ffe3552 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,7 +5,7 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import SyncCircuitBreaker +from redis.multidb.circuit import CircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @@ -74,13 +74,13 @@ def client(self, client: Union[redis.Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -90,7 +90,7 @@ class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: SyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -117,9 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit \ No newline at end of file diff --git a/redis/utils.py b/redis/utils.py index 94bfab61bb..1800582e46 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -314,3 +314,9 @@ def dummy_fail(): Fake function for a Retry object if you don't need to handle each failure. """ pass + +async def dummy_fail_async(): + """ + Async fake function for a Retry object if you don't need to handle each failure. + """ + pass \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/__init__.py b/tests/test_asyncio/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py new file mode 100644 index 0000000000..0ac231cf52 --- /dev/null +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -0,0 +1,108 @@ +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ + DatabaseConfig +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio import Redis +from redis.asyncio.multidb.database import Database, Databases + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + +@pytest.fixture() +def mock_fd() -> AsyncFailureDetector: + return Mock(spec=AsyncFailureDetector) + +@pytest.fixture() +def mock_fs() -> AsyncFailoverStrategy: + return Mock(spec=AsyncFailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py new file mode 100644 index 0000000000..c2fe914e9f --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -0,0 +1,471 @@ +import asyncio +from unittest.mock import patch, AsyncMock, Mock + +import pybreaker +import pytest + +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck +from redis.asyncio.retry import Retry +from redis.event import EventDispatcher, AsyncOnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.exception import NoValidDatabaseException +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.22) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + await client.set('key', 'value') + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + await client.add_database(mock_db) + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK2' + assert mock_hc.check_health.call_count == 2 + + await client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 3 + + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.remove_database(mock_db1) + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = AsyncOnCommandsFailEvent( + commands=('SET', 'key', 'value'), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 4 + assert another_hc.check_health.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.set_active_database(mock_db) + assert await client.set('key', 'value') == 'OK' + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + await client.set_active_database(Mock(spec=AsyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + await client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..3f64e6aa0b --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -0,0 +1,165 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + await executor.set_active_database(mock_db1) + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + + await executor.set_active_database(mock_db2) + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command = AsyncMock(side_effect=['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1']) + mock_db2.client.execute_command = AsyncMock(side_effect=['OK2', ConnectionError, ConnectionError, ConnectionError]) + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(threshold, 1)) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py new file mode 100644 index 0000000000..64760740a1 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -0,0 +1,125 @@ +from unittest.mock import Mock + +from redis.asyncio import ConnectionPool +from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_GRACE_PERIOD, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper, AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.retry import Retry +from redis.multidb.circuit import CircuitBreaker + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i+=1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], FailureDetectorAsyncWrapper) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=AsyncFailureDetector), Mock(spec=AsyncFailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), CircuitBreaker) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py new file mode 100644 index 0000000000..f692c40643 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -0,0 +1,121 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.retry import Retry + + +class TestAsyncWeightBasedFailoverStrategy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + strategy = WeightBasedFailoverStrategy(retry=retry) + strategy.set_databases(databases) + + assert await strategy.database() == mock_db1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert await failover_strategy.database() == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() + + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..3c1eb4fabd --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -0,0 +1,153 @@ +import asyncio +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector + + +class TestFailureDetectorAsyncWrapper: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + # 4 more failures as the last one already refreshed timer + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.4) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..fd5c8ec3f0 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -0,0 +1,48 @@ +import pytest +from mock.mock import AsyncMock + +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestEchoHealthCheck: + + @pytest.mark.asyncio + async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. + """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. + """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == False + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/__init__.py b/tests/test_asyncio/test_scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py new file mode 100644 index 0000000000..312712ba05 --- /dev/null +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ + MultiDbConfig +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.event import AsyncEventListenerInterface, EventDispatcher +from tests.test_scenario.conftest import get_endpoint_config, extract_cluster_fqdn +from tests.test_scenario.fault_injector_client import FaultInjectorClient + + +class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + async def listen(self, event: AsyncActiveDatabaseChanged): + self.is_changed_flag = True + +@pytest.fixture() +def fault_injector_client(): + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return FaultInjectorClient(url) + +@pytest.fixture() +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + + username = endpoint_config.get('username', None) + password = endpoint_config.get('password', None) + failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure wasn't restored from the previous test. + health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + AsyncActiveDatabaseChanged: [listener], + }) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config['endpoints'][0], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config['endpoints'][1], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + failure_threshold=failure_threshold, + health_check_retries=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + ) + + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py new file mode 100644 index 0000000000..518e9561d9 --- /dev/null +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -0,0 +1,59 @@ +import asyncio +import logging +from time import sleep + +import pytest + +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + +async def trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result['action_id']) + + while status_result['status'] != "success": + await asyncio.sleep(0.1) + status_result = fault_injector_client.get_action_status(result['action_id']) + logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + + if event: + event.set() + + logger.info(f"Action completed. Status: {status_result['status']}") + +class TestActiveActive: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(5) + + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, config, event)) + + # Client initialized on the first command. + await r_multi_db.set('key', 'value') + + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) \ No newline at end of file diff --git a/tests/test_background.py b/tests/test_background.py index 4b3a5377c1..ba62e5bdd9 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,3 +1,4 @@ +import asyncio from time import sleep import pytest @@ -57,4 +58,36 @@ def callback(arg1: str, arg2: int): sleep(timeout) + assert execute_counter == call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + async def test_run_recurring_async(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + async def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + await scheduler.run_recurring_async(interval, callback, one, two) + assert execute_counter == 0 + + await asyncio.sleep(timeout) + assert execute_counter == call_count \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 9503d79d9b..0c082f0f17 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> SyncCircuitBreaker: - return Mock(spec=SyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index f5f39c3f6b..7dc642373b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c7c15fe684..d352c1da92 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -166,13 +166,9 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) - assert client.set('key', 'value') == 'OK2' - sleep(0.22) - assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index e428b3ce7a..1ea63a0e14 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=SyncCircuitBreaker) + mock_cb1 = Mock(spec=CircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=SyncCircuitBreaker) + mock_cb2 = Mock(spec=CircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=SyncCircuitBreaker) + mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=SyncCircuitBreaker) + mock_circuit = Mock(spec=CircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit