diff --git a/docker-compose.yml b/docker-compose.yml index 46c70ba5a9..ef7fd813b3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ x-client-libs-stack-image: &client-libs-stack-image image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.2}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-RC1-pre.2}" services: @@ -58,7 +58,7 @@ services: - TLS_ENABLED=yes - PORT=16379 - TLS_PORT=27379 - command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save ""} + command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save "" --tls-cluster yes} ports: - "16379-16384:16379-16384" - "27379-27384:27379-27384" diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index b5109252ae..9ec50a240f 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -1,12 +1,50 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from enum import Enum +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Optional, Tuple, Union -from redis.exceptions import RedisError, ResponseError +from redis.exceptions import IncorrectPolicyType, RedisError, ResponseError from redis.utils import str_if_bytes if TYPE_CHECKING: from redis.asyncio.cluster import ClusterNode +class RequestPolicy(Enum): + ALL_NODES = "all_nodes" + ALL_SHARDS = "all_shards" + ALL_REPLICAS = "all_replicas" + MULTI_SHARD = "multi_shard" + SPECIAL = "special" + DEFAULT_KEYLESS = "default_keyless" + DEFAULT_KEYED = "default_keyed" + DEFAULT_NODE = "default_node" + + +class ResponsePolicy(Enum): + ONE_SUCCEEDED = "one_succeeded" + ALL_SUCCEEDED = "all_succeeded" + AGG_LOGICAL_AND = "agg_logical_and" + AGG_LOGICAL_OR = "agg_logical_or" + AGG_MIN = "agg_min" + AGG_MAX = "agg_max" + AGG_SUM = 
"agg_sum" + SPECIAL = "special" + DEFAULT_KEYLESS = "default_keyless" + DEFAULT_KEYED = "default_keyed" + + +class CommandPolicies: + def __init__( + self, + request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, + response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS, + ): + self.request_policy = request_policy + self.response_policy = response_policy + + +PolicyRecords = dict[str, dict[str, CommandPolicies]] + + class AbstractCommandsParser: def _get_pubsub_keys(self, *args): """ @@ -64,7 +102,8 @@ class CommandsParser(AbstractCommandsParser): def __init__(self, redis_connection): self.commands = {} - self.initialize(redis_connection) + self.redis_connection = redis_connection + self.initialize(self.redis_connection) def initialize(self, r): commands = r.command() @@ -129,7 +168,9 @@ def get_keys(self, redis_conn, *args): for subcmd in command["subcommands"]: if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - is_subcmd = True + + if command["first_key_pos"] > 0: + is_subcmd = True # The command doesn't have keys in it if not is_subcmd: @@ -169,6 +210,190 @@ def _get_moveable_keys(self, redis_conn, *args): raise e return keys + def _is_keyless_command( + self, command_name: str, subcommand_name: Optional[str] = None + ) -> bool: + """ + Determines whether a given command or subcommand is considered "keyless". + + A keyless command does not operate on specific keys, which is determined based + on the first key position in the command or subcommand details. If the command + or subcommand's first key position is zero or negative, it is treated as keyless. + + Parameters: + command_name: str + The name of the command to check. + subcommand_name: Optional[str], default=None + The name of the subcommand to check, if applicable. If not provided, + the check is performed only on the command. + + Returns: + bool + True if the specified command or subcommand is considered keyless, + False otherwise. 
+ + Raises: + ValueError + If the specified subcommand is not found within the command or the + specified command does not exist in the available commands. + """ + if subcommand_name: + for subcommand in self.commands.get(command_name)["subcommands"]: + if str_if_bytes(subcommand[0]) == subcommand_name: + parsed_subcmd = self.parse_subcommand(subcommand) + return parsed_subcmd["first_key_pos"] <= 0 + raise ValueError( + f"Subcommand {subcommand_name} not found in command {command_name}" + ) + else: + command_details = self.commands.get(command_name, None) + if command_details is not None: + return command_details["first_key_pos"] <= 0 + + raise ValueError(f"Command {command_name} not found in commands") + + def get_command_policies(self) -> PolicyRecords: + """ + Retrieve and process the command policies for all commands and subcommands. + + This method traverses through commands and subcommands, extracting policy details + from associated data structures and constructing a dictionary of commands with their + associated policies. It supports nested data structures and handles both main commands + and their subcommands. + + Returns: + PolicyRecords: A collection of commands and subcommands associated with their + respective policies. + + Raises: + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. + """ + command_with_policies = {} + + def extract_policies(data, module_name, command_name): + """ + Recursively extract policies from nested data structures. + + Args: + data: The data structure to search (can be list, dict, str, bytes, etc.) 
+ command_name: The command name to associate with found policies + """ + if isinstance(data, (str, bytes)): + # Decode bytes to string if needed + policy = str_if_bytes(data.decode()) + + # Check if this is a policy string + if policy.startswith("request_policy") or policy.startswith( + "response_policy" + ): + if policy.startswith("request_policy"): + policy_type = policy.split(":")[1] + + try: + command_with_policies[module_name][ + command_name + ].request_policy = RequestPolicy(policy_type) + except ValueError: + raise IncorrectPolicyType( + f"Incorrect request policy type: {policy_type}" + ) + + if policy.startswith("response_policy"): + policy_type = policy.split(":")[1] + + try: + command_with_policies[module_name][ + command_name + ].response_policy = ResponsePolicy(policy_type) + except ValueError: + raise IncorrectPolicyType( + f"Incorrect response policy type: {policy_type}" + ) + + elif isinstance(data, list): + # For lists, recursively process each element + for item in data: + extract_policies(item, module_name, command_name) + + elif isinstance(data, dict): + # For dictionaries, recursively process each value + for value in data.values(): + extract_policies(value, module_name, command_name) + + for command, details in self.commands.items(): + # Check whether the command has keys + is_keyless = self._is_keyless_command(command) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + # Check if it's a core or module command + split_name = command.split(".") + + if len(split_name) > 1: + module_name = split_name[0] + command_name = split_name[1] + else: + module_name = "core" + command_name = split_name[0] + + # Create a CommandPolicies object with default policies on the new command. 
+ if command_with_policies.get(module_name, None) is None: + command_with_policies[module_name] = { + command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + } + else: + command_with_policies[module_name][command_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + + tips = details.get("tips") + subcommands = details.get("subcommands") + + # Process tips for the main command + if tips: + extract_policies(tips, module_name, command_name) + + # Process subcommands + if subcommands: + for subcommand_details in subcommands: + # Get the subcommand name (first element) + subcmd_name = subcommand_details[0] + if isinstance(subcmd_name, bytes): + subcmd_name = subcmd_name.decode() + + # Check whether the subcommand has keys + is_keyless = self._is_keyless_command(command, subcmd_name) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + subcmd_name = subcmd_name.replace("|", " ") + + # Create a CommandPolicies object with default policies on the new command. 
+ command_with_policies[module_name][subcmd_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + + # Recursively extract policies from the rest of the subcommand details + for subcommand_detail in subcommand_details[1:]: + extract_policies(subcommand_detail, module_name, subcmd_name) + + return command_with_policies + class AsyncCommandsParser(AbstractCommandsParser): """ @@ -251,7 +476,9 @@ async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: for subcmd in command["subcommands"]: if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - is_subcmd = True + + if command["first_key_pos"] > 0: + is_subcmd = True # The command doesn't have keys in it if not is_subcmd: @@ -279,3 +506,187 @@ async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: else: raise e return keys + + async def _is_keyless_command( + self, command_name: str, subcommand_name: Optional[str] = None + ) -> bool: + """ + Determines whether a given command or subcommand is considered "keyless". + + A keyless command does not operate on specific keys, which is determined based + on the first key position in the command or subcommand details. If the command + or subcommand's first key position is zero or negative, it is treated as keyless. + + Parameters: + command_name: str + The name of the command to check. + subcommand_name: Optional[str], default=None + The name of the subcommand to check, if applicable. If not provided, + the check is performed only on the command. + + Returns: + bool + True if the specified command or subcommand is considered keyless, + False otherwise. + + Raises: + ValueError + If the specified subcommand is not found within the command or the + specified command does not exist in the available commands. 
+ """ + if subcommand_name: + for subcommand in self.commands.get(command_name)["subcommands"]: + if str_if_bytes(subcommand[0]) == subcommand_name: + parsed_subcmd = self.parse_subcommand(subcommand) + return parsed_subcmd["first_key_pos"] <= 0 + raise ValueError( + f"Subcommand {subcommand_name} not found in command {command_name}" + ) + else: + command_details = self.commands.get(command_name, None) + if command_details is not None: + return command_details["first_key_pos"] <= 0 + + raise ValueError(f"Command {command_name} not found in commands") + + async def get_command_policies(self) -> Awaitable[PolicyRecords]: + """ + Retrieve and process the command policies for all commands and subcommands. + + This method traverses through commands and subcommands, extracting policy details + from associated data structures and constructing a dictionary of commands with their + associated policies. It supports nested data structures and handles both main commands + and their subcommands. + + Returns: + PolicyRecords: A collection of commands and subcommands associated with their + respective policies. + + Raises: + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. + """ + command_with_policies = {} + + def extract_policies(data, module_name, command_name): + """ + Recursively extract policies from nested data structures. + + Args: + data: The data structure to search (can be list, dict, str, bytes, etc.) 
+ command_name: The command name to associate with found policies + """ + if isinstance(data, (str, bytes)): + # Decode bytes to string if needed + policy = str_if_bytes(data.decode()) + + # Check if this is a policy string + if policy.startswith("request_policy") or policy.startswith( + "response_policy" + ): + if policy.startswith("request_policy"): + policy_type = policy.split(":")[1] + + try: + command_with_policies[module_name][ + command_name + ].request_policy = RequestPolicy(policy_type) + except ValueError: + raise IncorrectPolicyType( + f"Incorrect request policy type: {policy_type}" + ) + + if policy.startswith("response_policy"): + policy_type = policy.split(":")[1] + + try: + command_with_policies[module_name][ + command_name + ].response_policy = ResponsePolicy(policy_type) + except ValueError: + raise IncorrectPolicyType( + f"Incorrect response policy type: {policy_type}" + ) + + elif isinstance(data, list): + # For lists, recursively process each element + for item in data: + extract_policies(item, module_name, command_name) + + elif isinstance(data, dict): + # For dictionaries, recursively process each value + for value in data.values(): + extract_policies(value, module_name, command_name) + + for command, details in self.commands.items(): + # Check whether the command has keys + is_keyless = await self._is_keyless_command(command) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + # Check if it's a core or module command + split_name = command.split(".") + + if len(split_name) > 1: + module_name = split_name[0] + command_name = split_name[1] + else: + module_name = "core" + command_name = split_name[0] + + # Create a CommandPolicies object with default policies on the new command. 
+ if command_with_policies.get(module_name, None) is None: + command_with_policies[module_name] = { + command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + } + else: + command_with_policies[module_name][command_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + + tips = details.get("tips") + subcommands = details.get("subcommands") + + # Process tips for the main command + if tips: + extract_policies(tips, module_name, command_name) + + # Process subcommands + if subcommands: + for subcommand_details in subcommands: + # Get the subcommand name (first element) + subcmd_name = subcommand_details[0] + if isinstance(subcmd_name, bytes): + subcmd_name = subcmd_name.decode() + + # Check whether the subcommand has keys + is_keyless = await self._is_keyless_command(command, subcmd_name) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + subcmd_name = subcmd_name.replace("|", " ") + + # Create a CommandPolicies object with default policies on the new command. 
+ command_with_policies[module_name][subcmd_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + + # Recursively extract policies from the rest of the subcommand details + for subcommand_detail in subcommand_details[1:]: + extract_policies(subcommand_detail, module_name, subcmd_name) + + return command_with_policies diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 225fd3b79f..d70569bb95 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -26,6 +26,7 @@ ) from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -51,6 +52,7 @@ parse_cluster_slots, ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands +from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher @@ -310,6 +312,7 @@ def __init__( protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, event_dispatcher: Optional[EventDispatcher] = None, + policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(), ) -> None: if db: raise RedisClusterException( @@ -423,7 +426,36 @@ def __init__( self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps self.reinitialize_counter = 0 + + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: 
RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], + RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot, + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.SPECIAL: self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver self.commands_parser = AsyncCommandsParser() + self._aggregate_nodes = None self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] @@ -619,6 +651,45 @@ def get_node_from_key( return slot_cache[node_idx] + def get_random_primary_or_all_nodes(self, command_name): + """ + Returns random primary or all nodes depends on READONLY mode. + """ + if self.read_from_replicas and command_name in READ_COMMANDS: + return self.get_random_node() + + return self.get_random_primary_node() + + def get_random_primary_node(self) -> "ClusterNode": + """ + Returns a random primary node + """ + return random.choice(self.get_primaries()) + + async def get_nodes_from_slot(self, command: str, *args): + """ + Returns a list of nodes that hold the specified keys' slots. 
+ """ + # get the node that holds the key's slot + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(command, *args), + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, + ) + ] + + def get_special_nodes(self) -> Optional[list["ClusterNode"]]: + """ + Returns a list of nodes for commands with a special policy. + """ + if not self._aggregate_nodes: + raise RedisClusterException( + "Cannot execute FT.CURSOR commands without FT.AGGREGATE" + ) + + return self._aggregate_nodes + def keyslot(self, key: EncodableT) -> int: """ Find the keyslot for a given key. @@ -643,7 +714,11 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No self.response_callbacks[command] = callback async def _determine_nodes( - self, command: str, *args: Any, node_flag: Optional[str] = None + self, + command: str, + *args: Any, + request_policy: RequestPolicy, + node_flag: Optional[str] = None, ) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. 
@@ -651,31 +726,22 @@ async def _determine_nodes( # get the nodes group for this command if it was predefined node_flag = self.command_flags.get(command) - if node_flag in self.node_flags: - if node_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - if node_flag == self.__class__.PRIMARIES: - # return all primaries - return self.nodes_manager.get_nodes_by_server_type(PRIMARY) - if node_flag == self.__class__.REPLICAS: - # return all replicas - return self.nodes_manager.get_nodes_by_server_type(REPLICA) - if node_flag == self.__class__.ALL_NODES: - # return all nodes - return list(self.nodes_manager.nodes_cache.values()) - if node_flag == self.__class__.RANDOM: - # return a random node - return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + if node_flag in self._command_flags_mapping: + request_policy = self._command_flags_mapping[node_flag] - # get the node that holds the key's slot - return [ - self.nodes_manager.get_node_from_slot( - await self._determine_slot(command, *args), - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, - ) - ] + policy_callback = self._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = await policy_callback(command, *args) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(command) + else: + nodes = policy_callback() + + if command.lower() == "ft.aggregate": + self._aggregate_nodes = nodes + + return nodes async def _determine_slot(self, command: str, *args: Any) -> int: if self.command_flags.get(command) == SLOT_ID: @@ -780,6 +846,33 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: target_nodes_specified = True retry_attempts = 0 + command_policies = await self._policy_resolver.resolve(args[0].lower()) + + if not command_policies and not target_nodes_specified: + 
command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self.get_default_node(): + slot = None + else: + slot = await self._determine_slot(*args) + if not slot: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=self._command_flags_mapping[command_flag] + ) + else: + command_policies = CommandPolicies() + elif not command_policies and target_nodes_specified: + command_policies = CommandPolicies() + # Add one for the first execution execute_attempts = 1 + retry_attempts for _ in range(execute_attempts): @@ -795,7 +888,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = await self._determine_nodes( - *args, node_flag=passed_targets + *args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -806,10 +901,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: # Return the processed result ret = await self._execute_command(target_nodes[0], *args, **kwargs) if command in self.result_callbacks: - return self.result_callbacks[command]( + ret = self.result_callbacks[command]( command, {target_nodes[0].name: ret}, **kwargs ) - return ret + return self._policies_callback_mapping[ + command_policies.response_policy + ](ret) else: keys = [node.name for node in target_nodes] values = await asyncio.gather( @@ -824,7 +921,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: return self.result_callbacks[command]( command, dict(zip(keys, values)), **kwargs ) - return dict(zip(keys, values)) + return self._policies_callback_mapping[ + 
command_policies.response_policy + ](dict(zip(keys, values))) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were should be reinitialized. @@ -1740,6 +1839,7 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: self.kwargs = kwargs self.position = position self.result: Union[Any, Exception] = None + self.command_policies: Optional[CommandPolicies] = None def __repr__(self) -> str: return f"[{self.position}] {self.args} ({self.kwargs})" @@ -1980,16 +2080,51 @@ async def _execute( nodes = {} for cmd in todo: passed_targets = cmd.kwargs.pop("target_nodes", None) + command_policies = await client._policy_resolver.resolve( + cmd.args[0].lower() + ) + if passed_targets and not client._is_node_flag(passed_targets): target_nodes = client._parse_target_nodes(passed_targets) + + if not command_policies: + command_policies = CommandPolicies() else: + if not command_policies: + command_flag = client.command_flags.get(cmd.args[0]) + if not command_flag: + # Fallback to default policy + if not client.get_default_node(): + slot = None + else: + slot = await client._determine_slot(*cmd.args) + if not slot: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in client._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=client._command_flags_mapping[ + command_flag + ] + ) + else: + command_policies = CommandPolicies() + target_nodes = await client._determine_nodes( - *cmd.args, node_flag=passed_targets + *cmd.args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( f"No targets were found to execute {cmd.args} command on" ) + cmd.command_policies = command_policies if len(target_nodes) > 1: raise RedisClusterException(f"Too many 
targets for command {cmd.args}") node = target_nodes[0] @@ -2010,9 +2145,9 @@ async def _execute( for cmd in todo: if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): try: - cmd.result = await client.execute_command( - *cmd.args, **cmd.kwargs - ) + cmd.result = client._policies_callback_mapping[ + cmd.command_policies.response_policy + ](await client.execute_command(*cmd.args, **cmd.kwargs)) except Exception as e: cmd.result = e diff --git a/redis/cluster.py b/redis/cluster.py index b34f3ea9da..33b54b1bed 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -11,12 +11,14 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from redis._parsers import CommandsParser, Encoder +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis._parsers.helpers import parse_scan from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface from redis.client import EMPTY_RESPONSE, CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args +from redis.commands.policies import PolicyResolver, StaticPolicyResolver from redis.connection import ( Connection, ConnectionPool, @@ -532,6 +534,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + policy_resolver: PolicyResolver = StaticPolicyResolver(), **kwargs, ): """ @@ -714,7 +717,40 @@ def __init__( ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: 
RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], + RequestPolicy.DEFAULT_KEYED: lambda command, + *args: self.get_nodes_from_slot(command, *args), + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.MULTI_SHARD: lambda *args, + **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.SPECIAL: self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver self.commands_parser = CommandsParser(self) + + # Node where FT.AGGREGATE command is executed. + self._aggregate_nodes = None self._lock = threading.RLock() def __enter__(self): @@ -777,6 +813,15 @@ def get_replicas(self): def get_random_node(self): return random.choice(list(self.nodes_manager.nodes_cache.values())) + def get_random_primary_or_all_nodes(self, command_name): + """ + Returns random primary or all nodes depends on READONLY mode. + """ + if self.read_from_replicas and command_name in READ_COMMANDS: + return self.get_random_node() + + return self.get_random_primary_node() + def get_nodes(self): return list(self.nodes_manager.nodes_cache.values()) @@ -806,6 +851,77 @@ def get_default_node(self): """ return self.nodes_manager.default_node + def get_nodes_from_slot(self, command: str, *args): + """ + Returns a list of nodes that hold the specified keys' slots. 
+ """ + # get the node that holds the key's slot + slot = self.determine_slot(*args) + node = self.nodes_manager.get_node_from_slot( + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, + ) + return [node] + + def _split_multi_shard_command(self, *args, **kwargs) -> list[dict]: + """ + Splits the command with Multi-Shard policy, to the multiple commands + """ + keys = self._get_command_keys(*args) + commands = [] + + for key in keys: + commands.append( + { + "args": (args[0], key), + "kwargs": kwargs, + } + ) + + return commands + + def get_special_nodes(self) -> Optional[list["ClusterNode"]]: + """ + Returns a list of nodes for commands with a special policy. + """ + if not self._aggregate_nodes: + raise RedisClusterException( + "Cannot execute FT.CURSOR commands without FT.AGGREGATE" + ) + + return self._aggregate_nodes + + def get_random_primary_node(self) -> "ClusterNode": + """ + Returns a random primary node + """ + return random.choice(self.get_primaries()) + + def _evaluate_all_succeeded(self, res): + """ + Evaluate the result of a command with ResponsePolicy.ALL_SUCCEEDED + """ + first_successful_response = None + + if isinstance(res, dict): + for key, value in res.items(): + if value: + if first_successful_response is None: + first_successful_response = {key: value} + else: + return {key: False} + else: + for response in res: + if response: + if first_successful_response is None: + # Dynamically resolve type + first_successful_response = type(response)(response) + else: + return type(response)(False) + + return first_successful_response + def set_default_node(self, node): """ Set the default node of the cluster. 
@@ -955,9 +1071,12 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: - # Determine which nodes should be executed the command on. - # Returns a list of target nodes. + def _determine_nodes( + self, *args, request_policy: RequestPolicy, **kwargs + ) -> List["ClusterNode"]: + """ + Determines a nodes the command should be executed on. + """ command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -969,32 +1088,25 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the nodes group for this command if it was predefined command_flag = self.command_flags.get(command) - if command_flag == self.__class__.RANDOM: - # return a random node - return [self.get_random_node()] - elif command_flag == self.__class__.PRIMARIES: - # return all primaries - return self.get_primaries() - elif command_flag == self.__class__.REPLICAS: - # return all replicas - return self.get_replicas() - elif command_flag == self.__class__.ALL_NODES: - # return all nodes - return self.get_nodes() - elif command_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - elif command in self.__class__.SEARCH_COMMANDS[0]: - return [self.nodes_manager.default_node] + + if command_flag in self._command_flags_mapping: + request_policy = self._command_flags_mapping[command_flag] + + policy_callback = self._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = policy_callback(command, *args) + elif request_policy == RequestPolicy.MULTI_SHARD: + nodes = policy_callback(*args, **kwargs) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(args[0]) else: - # get the node that holds the key's slot 
- slot = self.determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( - slot, - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, - ) - return [node] + nodes = policy_callback() + + if args[0].lower() == "ft.aggregate": + self._aggregate_nodes = nodes + + return nodes def _should_reinitialized(self): # To reinitialize the cluster on every MOVED error, @@ -1144,9 +1256,43 @@ def _internal_execute_command(self, *args, **kwargs): is_default_node = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) + command_policies = self._policy_resolver.resolve(args[0].lower()) + if passed_targets is not None and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) target_nodes_specified = True + + if not command_policies and not target_nodes_specified: + command = args[0].upper() + if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: + command = f"{args[0]} {args[1]}".upper() + + # We only could resolve key properties if command is not + # in a list of pre-defined request policies + command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self.get_default_node(): + slot = None + else: + slot = self.determine_slot(*args) + if not slot: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=self._command_flags_mapping[command_flag] + ) + else: + command_policies = CommandPolicies() + elif not command_policies and target_nodes_specified: + command_policies = CommandPolicies() + # If an error that allows retrying was thrown, the nodes and slots # cache were reinitialized. 
We will retry executing the command with # the updated cluster setup only when the target nodes can be @@ -1164,7 +1310,9 @@ def _internal_execute_command(self, *args, **kwargs): if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = self._determine_nodes( - *args, **kwargs, nodes_flag=passed_targets + *args, + request_policy=command_policies.request_policy, + nodes_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -1177,8 +1325,17 @@ def _internal_execute_command(self, *args, **kwargs): is_default_node = True for node in target_nodes: res[node.name] = self._execute_command(node, *args, **kwargs) + + if command_policies.response_policy == ResponsePolicy.ONE_SUCCEEDED: + break + # Return the processed result - return self._process_result(args[0], res, **kwargs) + return self._process_result( + args[0], + res, + response_policy=command_policies.response_policy, + **kwargs, + ) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: if is_default_node: @@ -1318,7 +1475,7 @@ def close(self) -> None: # RedisCluster's __init__ can fail before nodes_manager is set pass - def _process_result(self, command, res, **kwargs): + def _process_result(self, command, res, response_policy: ResponsePolicy, **kwargs): """ Process the result of the executed command. The function would return a dict or a single value. 
@@ -1330,13 +1487,13 @@ def _process_result(self, command, res, **kwargs): Dict """ if command in self.result_callbacks: - return self.result_callbacks[command](command, res, **kwargs) + res = self.result_callbacks[command](command, res, **kwargs) elif len(res) == 1: # When we execute the command on a single node, we can # remove the dictionary and return a single response - return list(res.values())[0] - else: - return res + res = list(res.values())[0] + + return self._policies_callback_mapping[response_policy](res) def load_external_module(self, funcname, func): """ @@ -2162,6 +2319,7 @@ def __init__( retry: Optional[Retry] = None, lock=None, transaction=False, + policy_resolver: PolicyResolver = StaticPolicyResolver(), **kwargs, ): """ """ @@ -2200,6 +2358,37 @@ def __init__( PipelineStrategy(self) if not transaction else TransactionStrategy(self) ) + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], + RequestPolicy.DEFAULT_KEYED: lambda command, + *args: self.get_nodes_from_slot(command, *args), + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.MULTI_SHARD: lambda *args, + **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.SPECIAL: 
self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver + def __repr__(self): """ """ return f"{type(self).__name__}" @@ -2421,6 +2610,7 @@ def __init__(self, args, options=None, position=None): self.result = None self.node = None self.asking = False + self.command_policies: Optional[CommandPolicies] = None class NodeCommands: @@ -2779,6 +2969,8 @@ def _send_cluster_commands( # we figure out the slot number that command maps to, then from # the slot determine the node. for c in attempt: + command_policies = self._pipe._policy_resolver.resolve(c.args[0].lower()) + while True: # refer to our internal node -> slot table that # tells us where a given command should route to. @@ -2787,14 +2979,55 @@ passed_targets = c.options.pop("target_nodes", None) if passed_targets and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) + + if not command_policies: + command_policies = CommandPolicies() else: + if not command_policies: + command = c.args[0].upper() + if ( + len(c.args) >= 2 + and f"{c.args[0]} {c.args[1]}".upper() + in self._pipe.command_flags + ): + command = f"{c.args[0]} {c.args[1]}".upper() + + # We only could resolve key properties if command is not + # in a list of pre-defined request policies + command_flag = self._pipe.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self._pipe.get_default_node(): + keys = None + else: + keys = self._pipe._get_command_keys(*c.args) + if not keys or len(keys) == 0: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._pipe._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=self._pipe._command_flags_mapping[ + 
command_flag + ] + ) + else: + command_policies = CommandPolicies() + target_nodes = self._determine_nodes( - *c.args, node_flag=passed_targets + *c.args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( f"No targets were found to execute {c.args} command on" ) + c.command_policies = command_policies if len(target_nodes) > 1: raise RedisClusterException( f"Too many targets for command {c.args}" @@ -2919,8 +3152,12 @@ def _send_cluster_commands( if c.args[0] in self._pipe.cluster_response_callbacks: # Remove keys entry, it needs only for cache. c.options.pop("keys", None) - c.result = self._pipe.cluster_response_callbacks[c.args[0]]( - c.result, **c.options + c.result = self._pipe._policies_callback_mapping[ + c.command_policies.response_policy + ]( + self._pipe.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) ) response.append(c.result) @@ -2952,7 +3189,9 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + def _determine_nodes( + self, *args, request_policy: RequestPolicy, **kwargs + ) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. 
command = args[0].upper() @@ -2969,34 +3208,25 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the nodes group for this command if it was predefined command_flag = self._pipe.command_flags.get(command) - if command_flag == self._pipe.RANDOM: - # return a random node - return [self._pipe.get_random_node()] - elif command_flag == self._pipe.PRIMARIES: - # return all primaries - return self._pipe.get_primaries() - elif command_flag == self._pipe.REPLICAS: - # return all replicas - return self._pipe.get_replicas() - elif command_flag == self._pipe.ALL_NODES: - # return all nodes - return self._pipe.get_nodes() - elif command_flag == self._pipe.DEFAULT_NODE: - # return the cluster's default node - return [self._nodes_manager.default_node] - elif command in self._pipe.SEARCH_COMMANDS[0]: - return [self._nodes_manager.default_node] + + if command_flag in self._pipe._command_flags_mapping: + request_policy = self._pipe._command_flags_mapping[command_flag] + + policy_callback = self._pipe._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = policy_callback(command, *args) + elif request_policy == RequestPolicy.MULTI_SHARD: + nodes = policy_callback(*args, **kwargs) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(args[0]) else: - # get the node that holds the key's slot - slot = self._pipe.determine_slot(*args) - node = self._nodes_manager.get_node_from_slot( - slot, - self._pipe.read_from_replicas and command in READ_COMMANDS, - self._pipe.load_balancing_strategy - if command in READ_COMMANDS - else None, - ) - return [node] + nodes = policy_callback() + + if args[0].lower() == "ft.aggregate": + self._pipe._aggregate_nodes = nodes + + return nodes def multi(self): raise RedisClusterException( diff --git a/redis/commands/policies.py b/redis/commands/policies.py new file mode 100644 index 0000000000..c0c98d37f1 --- /dev/null +++ b/redis/commands/policies.py @@ 
-0,0 +1,330 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from redis._parsers.commands import ( + CommandPolicies, + CommandsParser, + PolicyRecords, + RequestPolicy, + ResponsePolicy, +) + +STATIC_POLICIES: PolicyRecords = { + "ft": { + "explaincli": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "suglen": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "profile": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "dropindex": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasupdate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "alter": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aggregate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "syndump": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "create": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "explain": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugget": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "dictdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + 
"dictadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "synupdate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "drop": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "info": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "dictdump": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "cursor": CommandPolicies( + request_policy=RequestPolicy.SPECIAL, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "search": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "tagvals": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "spellcheck": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + }, + "core": { + "command": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + }, +} + + +class PolicyResolver(ABC): + @abstractmethod + def resolve(self, command_name: str) -> Optional[CommandPolicies]: + """ + Resolves the command name and determines the associated command policies. + + Args: + command_name: The name of the command to resolve. 
+ + Returns: + CommandPolicies: The policies associated with the specified command. + """ + pass + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + """ + Factory method to instantiate a policy resolver with a fallback resolver. + + Args: + fallback: Fallback resolver + + Returns: + PolicyResolver: Returns a new policy resolver with the specified fallback resolver. + """ + pass + + +class AsyncPolicyResolver(ABC): + @abstractmethod + async def resolve(self, command_name: str) -> Optional[CommandPolicies]: + """ + Resolves the command name and determines the associated command policies. + + Args: + command_name: The name of the command to resolve. + + Returns: + CommandPolicies: The policies associated with the specified command. + """ + pass + + @abstractmethod + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + """ + Factory method to instantiate an async policy resolver with a fallback resolver. + + Args: + fallback: Fallback resolver + + Returns: + AsyncPolicyResolver: Returns a new policy resolver with the specified fallback resolver. + """ + pass + + +class BasePolicyResolver(PolicyResolver): + """ + Base class for policy resolvers. 
+ """ + + def __init__( + self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None + ) -> None: + self._policies = policies + self._fallback = fallback + + def resolve(self, command_name: str) -> Optional[CommandPolicies]: + parts = command_name.split(".") + + if len(parts) > 2: + raise ValueError(f"Wrong command or module name: {command_name}") + + module, command = parts if len(parts) == 2 else ("core", parts[0]) + + if self._policies.get(module, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + return None + + if self._policies.get(module).get(command, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + return None + + return self._policies.get(module).get(command) + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + pass + + +class AsyncBasePolicyResolver(AsyncPolicyResolver): + """ + Async base class for policy resolvers. 
+ """ + + def __init__( + self, policies: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None + ) -> None: + self._policies = policies + self._fallback = fallback + + async def resolve(self, command_name: str) -> Optional[CommandPolicies]: + parts = command_name.split(".") + + if len(parts) > 2: + raise ValueError(f"Wrong command or module name: {command_name}") + + module, command = parts if len(parts) == 2 else ("core", parts[0]) + + if self._policies.get(module, None) is None: + if self._fallback is not None: + return await self._fallback.resolve(command_name) + else: + return None + + if self._policies.get(module).get(command, None) is None: + if self._fallback is not None: + return await self._fallback.resolve(command_name) + else: + return None + + return self._policies.get(module).get(command) + + @abstractmethod + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + pass + + +class DynamicPolicyResolver(BasePolicyResolver): + """ + Resolves policy dynamically based on the COMMAND output. + """ + + def __init__( + self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None + ) -> None: + """ + Parameters: + commands_parser (CommandsParser): COMMAND output parser. + fallback (Optional[PolicyResolver]): An optional resolver to be used when the + primary policies cannot handle a specific request. + """ + self._commands_parser = commands_parser + super().__init__(commands_parser.get_command_policies(), fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return DynamicPolicyResolver(self._commands_parser, fallback) + + +class StaticPolicyResolver(BasePolicyResolver): + """ + Resolves policy from a static list of policy records. + """ + + def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: + """ + Parameters: + fallback (Optional[PolicyResolver]): An optional fallback policy resolver + used for resolving policies if static policies are inadequate. 
+ """ + super().__init__(STATIC_POLICIES, fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return StaticPolicyResolver(fallback) + + +class AsyncDynamicPolicyResolver(AsyncBasePolicyResolver): + """ + Async version of DynamicPolicyResolver. + """ + + def __init__( + self, + policy_records: PolicyRecords, + fallback: Optional[AsyncPolicyResolver] = None, + ) -> None: + """ + Parameters: + policy_records (PolicyRecords): Policy records. + fallback (Optional[AsyncPolicyResolver]): An optional resolver to be used when the + primary policies cannot handle a specific request. + """ + super().__init__(policy_records, fallback) + + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + return AsyncDynamicPolicyResolver(self._policies, fallback) + + +class AsyncStaticPolicyResolver(AsyncBasePolicyResolver): + """ + Async version of StaticPolicyResolver. + """ + + def __init__(self, fallback: Optional[AsyncPolicyResolver] = None) -> None: + """ + Parameters: + fallback (Optional[AsyncPolicyResolver]): An optional fallback policy resolver + used for resolving policies if static policies are inadequate. + """ + super().__init__(STATIC_POLICIES, fallback) + + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + return AsyncStaticPolicyResolver(fallback) diff --git a/redis/exceptions.py b/redis/exceptions.py index 1e21265524..dab17c5c1f 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -253,3 +253,11 @@ class ExternalAuthProviderError(ConnectionError): """ pass + + +class IncorrectPolicyType(Exception): + """ + Raised when a policy type isn't matching to any known policy types. 
+ """ + + pass diff --git a/tests/test_asyncio/test_command_parser.py b/tests/test_asyncio/test_command_parser.py new file mode 100644 index 0000000000..430a72f885 --- /dev/null +++ b/tests/test_asyncio/test_command_parser.py @@ -0,0 +1,161 @@ +import pytest + +from redis._parsers import AsyncCommandsParser +from redis._parsers.commands import RequestPolicy, ResponsePolicy +from tests.conftest import skip_if_server_version_lt + + +@pytest.mark.onlycluster +@skip_if_server_version_lt("8.0.0") +class TestAsyncCommandParser: + @pytest.mark.asyncio + async def test_get_command_policies(self, r): + commands_parser = AsyncCommandsParser() + await commands_parser.initialize(node=r.get_default_node()) + expected_command_policies = { + "core": { + "keys": [ + "keys", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "acl setuser": [ + "acl setuser", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "exists": ["exists", RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + "config resetstat": [ + "config resetstat", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "slowlog len": [ + "slowlog len", + RequestPolicy.ALL_NODES, + ResponsePolicy.AGG_SUM, + ], + "scan": ["scan", RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + "latency history": [ + "latency history", + RequestPolicy.ALL_NODES, + ResponsePolicy.SPECIAL, + ], + "memory doctor": [ + "memory doctor", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "randomkey": [ + "randomkey", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "mget": [ + "mget", + RequestPolicy.MULTI_SHARD, + ResponsePolicy.DEFAULT_KEYED, + ], + "function restore": [ + "function restore", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.ALL_SUCCEEDED, + ], + }, + "json": { + "debug": [ + "debug", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "get": [ + "get", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "ft": { + "search": [ 
+ "search", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "create": [ + "create", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + }, + "bf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "madd": [ + "madd", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "cf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "mexists": [ + "mexists", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "tdigest": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "min": [ + "min", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "ts": { + "create": [ + "create", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "info": [ + "info", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "topk": { + "list": [ + "list", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "query": [ + "query", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + } + + actual_policies = await commands_parser.get_command_policies() + assert len(actual_policies) > 0 + + for module_name, commands in expected_command_policies.items(): + for command, command_policies in commands.items(): + assert command in actual_policies[module_name] + assert command_policies == [ + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy, + ] diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py new file mode 100644 index 0000000000..7a52f256c9 --- /dev/null +++ b/tests/test_asyncio/test_command_policies.py @@ -0,0 +1,184 @@ +import random + +import pytest +from mock import patch + +from redis import ResponseError +from redis._parsers.commands import CommandPolicies, 
RequestPolicy, ResponsePolicy +from redis.asyncio import RedisCluster +from redis.commands.policies import ( + AsyncDynamicPolicyResolver, + AsyncStaticPolicyResolver, +) +from redis.commands.search.aggregation import AggregateRequest, Cursor +from redis.commands.search.field import NumericField, TextField +from tests.conftest import skip_if_server_version_lt, is_resp2_connection + + +@pytest.mark.asyncio +@pytest.mark.onlycluster +class TestBasePolicyResolver: + async def test_resolve(self): + zcount_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + rpoplpush_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + + dynamic_resolver = AsyncDynamicPolicyResolver( + { + "core": { + "zcount": zcount_policy, + "rpoplpush": rpoplpush_policy, + } + } + ) + assert await dynamic_resolver.resolve("zcount") == zcount_policy + assert await dynamic_resolver.resolve("rpoplpush") == rpoplpush_policy + + with pytest.raises( + ValueError, match="Wrong command or module name: foo.bar.baz" + ): + await dynamic_resolver.resolve("foo.bar.baz") + + assert await dynamic_resolver.resolve("foo.bar") is None + assert await dynamic_resolver.resolve("core.foo") is None + + # Test that policy fallback correctly + static_resolver = AsyncStaticPolicyResolver() + with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) + resolved_policies = await with_fallback_dynamic_resolver.resolve("ft.aggregate") + + assert resolved_policies.request_policy == RequestPolicy.DEFAULT_KEYLESS + assert resolved_policies.response_policy == ResponsePolicy.DEFAULT_KEYLESS + + # Extended chain with one more resolver + foo_bar_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ) + + another_dynamic_resolver = AsyncDynamicPolicyResolver( + { + "foo": { + "bar": foo_bar_policy, 
+ } + } + ) + with_fallback_static_resolver = static_resolver.with_fallback( + another_dynamic_resolver + ) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback( + with_fallback_static_resolver + ) + + assert ( + await with_double_fallback_dynamic_resolver.resolve("foo.bar") + == foo_bar_policy + ) + + +@pytest.mark.onlycluster +@pytest.mark.asyncio +@skip_if_server_version_lt("8.0.0") +class TestClusterWithPolicies: + async def test_resolves_correctly_policies(self, r: RedisCluster, monkeypatch): + # original nodes selection method + determine_nodes = r._determine_nodes + determined_nodes = [] + primary_nodes = r.get_primaries() + calls = iter(list(range(len(primary_nodes)))) + + async def wrapper(*args, request_policy: RequestPolicy, **kwargs): + nonlocal determined_nodes + determined_nodes = await determine_nodes( + *args, request_policy=request_policy, **kwargs + ) + return determined_nodes + + # Mock random.choice to always return a pre-defined sequence of nodes + monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) + + with patch.object(r, "_determine_nodes", side_effect=wrapper, autospec=True): + # Routed to a random primary node + await r.ft().create_index( + [ + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ] + ) + assert determined_nodes[0] == primary_nodes[0] + + # Routed to another random primary node + info = await r.ft().info() + + if is_resp2_connection(r): + assert info["index_name"] == "idx" + else: + assert info[b"index_name"] == b"idx" + + assert determined_nodes[0] == primary_nodes[1] + + expected_node = await r.get_nodes_from_slot("FT.SUGLEN", *["foo"]) + await r.ft().suglen("foo") + assert determined_nodes[0] == expected_node[0] + + # Indexing a document + await r.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) + await r.hset( + "ai", + 
mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + await r.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) + + req = AggregateRequest("redis").group_by("@parent").cursor(1) + res = await r.ft().aggregate(req) + + if is_resp2_connection(r): + cursor = res.cursor + else: + cursor = Cursor(res[1]) + + # Ensure that aggregate node was cached. + assert determined_nodes[0] == r._aggregate_nodes[0] + + await r.ft().aggregate(cursor) + + # Verify that FT.CURSOR dispatched to the same node. + assert determined_nodes[0] == r._aggregate_nodes[0] + + # Error propagates to a user + with pytest.raises(ResponseError, match="Cursor not found, id:"): + await r.ft().aggregate(cursor) + + assert determined_nodes[0] == primary_nodes[2] + + # Core commands also randomly distributed across masters + await r.randomkey() + assert determined_nodes[0] == primary_nodes[0] diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 2936bb0024..759c93ffc6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -207,7 +207,28 @@ def cmd_init_mock(self, r): "first_key_pos": 1, "last_key_pos": 1, "step_count": 1, - } + }, + "cluster delslots": { + "name": "cluster delslots", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, + "cluster delslotsrange": { + "name": "cluster delslotsrange", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, + "cluster addslots": { + "name": "cluster addslots", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, + } cmd_parser_initialize.side_effect = cmd_init_mock diff --git a/tests/test_command_parser.py 
b/tests/test_command_parser.py index e3b44a147f..26b0edc238 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,5 +1,6 @@ import pytest from redis._parsers import CommandsParser +from redis._parsers.commands import RequestPolicy, ResponsePolicy from .conftest import ( assert_resp_response, @@ -106,3 +107,155 @@ def test_get_pubsub_keys(self, r): assert commands_parser.get_keys(r, *args2) == ["foo1", "foo2", "foo3"] assert commands_parser.get_keys(r, *args3) == ["*"] assert commands_parser.get_keys(r, *args4) == ["foo1", "foo2", "foo3"] + + @skip_if_server_version_lt("8.0.0") + @pytest.mark.onlycluster + def test_get_command_policies(self, r): + commands_parser = CommandsParser(r) + expected_command_policies = { + "core": { + "keys": [ + "keys", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "acl setuser": [ + "acl setuser", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "exists": ["exists", RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + "config resetstat": [ + "config resetstat", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "slowlog len": [ + "slowlog len", + RequestPolicy.ALL_NODES, + ResponsePolicy.AGG_SUM, + ], + "scan": ["scan", RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + "latency history": [ + "latency history", + RequestPolicy.ALL_NODES, + ResponsePolicy.SPECIAL, + ], + "memory doctor": [ + "memory doctor", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "randomkey": [ + "randomkey", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "mget": [ + "mget", + RequestPolicy.MULTI_SHARD, + ResponsePolicy.DEFAULT_KEYED, + ], + "function restore": [ + "function restore", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.ALL_SUCCEEDED, + ], + }, + "json": { + "debug": [ + "debug", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "get": [ + "get", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "ft": 
{ + "search": [ + "search", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "create": [ + "create", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + }, + "bf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "madd": [ + "madd", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "cf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "mexists": [ + "mexists", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "tdigest": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "min": [ + "min", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "ts": { + "create": [ + "create", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "info": [ + "info", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "topk": { + "list": [ + "list", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "query": [ + "query", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + } + + actual_policies = commands_parser.get_command_policies() + assert len(actual_policies) > 0 + + for module_name, commands in expected_command_policies.items(): + for command, command_policies in commands.items(): + assert command in actual_policies[module_name] + assert command_policies == [ + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy, + ] diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py new file mode 100644 index 0000000000..633134a1b6 --- /dev/null +++ b/tests/test_command_policies.py @@ -0,0 +1,183 @@ +import random +from unittest.mock import Mock, patch + +import pytest + +from redis import ResponseError + +from redis._parsers import CommandsParser +from redis._parsers.commands 
import CommandPolicies, RequestPolicy, ResponsePolicy +from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver +from redis.commands.search.aggregation import AggregateRequest, Cursor +from redis.commands.search.field import TextField, NumericField +from tests.conftest import skip_if_server_version_lt, is_resp2_connection + + +@pytest.mark.onlycluster +class TestBasePolicyResolver: + def test_resolve(self): + mock_command_parser = Mock(spec=CommandsParser) + zcount_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + rpoplpush_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + + mock_command_parser.get_command_policies.return_value = { + "core": { + "zcount": zcount_policy, + "rpoplpush": rpoplpush_policy, + } + } + + dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + assert dynamic_resolver.resolve("zcount") == zcount_policy + assert dynamic_resolver.resolve("rpoplpush") == rpoplpush_policy + + with pytest.raises( + ValueError, match="Wrong command or module name: foo.bar.baz" + ): + dynamic_resolver.resolve("foo.bar.baz") + + assert dynamic_resolver.resolve("foo.bar") is None + assert dynamic_resolver.resolve("core.foo") is None + + # Test that policy fallback correctly + static_resolver = StaticPolicyResolver() + with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) + + assert ( + with_fallback_dynamic_resolver.resolve("ft.aggregate").request_policy + == RequestPolicy.DEFAULT_KEYLESS + ) + assert ( + with_fallback_dynamic_resolver.resolve("ft.aggregate").response_policy + == ResponsePolicy.DEFAULT_KEYLESS + ) + + # Extended chain with one more resolver + mock_command_parser = Mock(spec=CommandsParser) + foo_bar_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ) + + 
mock_command_parser.get_command_policies.return_value = { + "foo": { + "bar": foo_bar_policy, + } + } + another_dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + with_fallback_static_resolver = static_resolver.with_fallback( + another_dynamic_resolver + ) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback( + with_fallback_static_resolver + ) + + assert ( + with_double_fallback_dynamic_resolver.resolve("foo.bar") == foo_bar_policy + ) + + +@pytest.mark.onlycluster +@skip_if_server_version_lt("8.0.0") +class TestClusterWithPolicies: + def test_resolves_correctly_policies(self, r, monkeypatch): + # original nodes selection method + determine_nodes = r._determine_nodes + determined_nodes = [] + primary_nodes = r.get_primaries() + calls = iter(list(range(len(primary_nodes)))) + + def wrapper(*args, request_policy: RequestPolicy, **kwargs): + nonlocal determined_nodes + determined_nodes = determine_nodes( + *args, request_policy=request_policy, **kwargs + ) + return determined_nodes + + # Mock random.choice to always return a pre-defined sequence of nodes + monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) + + with patch.object(r, "_determine_nodes", side_effect=wrapper, autospec=True): + # Routed to a random primary node + r.ft().create_index( + ( + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ) + ) + assert determined_nodes[0] == primary_nodes[0] + + # Routed to another random primary node + info = r.ft().info() + if is_resp2_connection(r): + assert info["index_name"] == "idx" + else: + assert info[b"index_name"] == b"idx" + + assert determined_nodes[0] == primary_nodes[1] + + expected_node = r.get_nodes_from_slot("ft.suglen", *["FT.SUGLEN", "foo"]) + r.ft().suglen("foo") + assert determined_nodes[0] == expected_node[0] + + # Indexing a document + r.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of 
redis", + "parent": "redis", + "random_num": 10, + }, + ) + r.hset( + "ai", + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + r.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) + + req = AggregateRequest("redis").group_by("@parent").cursor(1) + + if is_resp2_connection(r): + cursor = r.ft().aggregate(req).cursor + else: + cursor = Cursor(r.ft().aggregate(req)[1]) + + # Ensure that aggregate node was cached. + assert determined_nodes[0] == r._aggregate_nodes[0] + + r.ft().aggregate(cursor) + + # Verify that FT.CURSOR dispatched to the same node. + assert determined_nodes[0] == r._aggregate_nodes[0] + + # Error propagates to a user + with pytest.raises(ResponseError, match="Cursor not found, id:"): + r.ft().aggregate(cursor) + + assert determined_nodes[0] == primary_nodes[2] + + # Core commands also randomly distributed across masters + r.randomkey() + assert determined_nodes[0] == primary_nodes[0]