Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ x-client-libs-stack-image: &client-libs-stack-image
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.2}"

x-client-libs-image: &client-libs-image
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.2}"
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-RC1-pre.2}"

services:

Expand Down Expand Up @@ -58,7 +58,7 @@ services:
- TLS_ENABLED=yes
- PORT=16379
- TLS_PORT=27379
command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save ""}
command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save "" --tls-cluster yes}
ports:
- "16379-16384:16379-16384"
- "27379-27384:27379-27384"
Expand Down
421 changes: 416 additions & 5 deletions redis/_parsers/commands.py

Large diffs are not rendered by default.

201 changes: 168 additions & 33 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

from redis._parsers import AsyncCommandsParser, Encoder
from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
Expand All @@ -51,6 +52,7 @@
parse_cluster_slots,
)
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.credentials import CredentialProvider
from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
Expand Down Expand Up @@ -310,6 +312,7 @@ def __init__(
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
event_dispatcher: Optional[EventDispatcher] = None,
policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -423,7 +426,36 @@ def __init__(
self.load_balancing_strategy = load_balancing_strategy
self.reinitialize_steps = reinitialize_steps
self.reinitialize_counter = 0

# For backward compatibility, mapping from existing policies to new one
self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
SLOT_ID: RequestPolicy.DEFAULT_KEYED,
}

self._policies_callback_mapping: dict[
Union[RequestPolicy, ResponsePolicy], Callable
] = {
RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
self.get_random_primary_or_all_nodes(command_name)
],
RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
RequestPolicy.ALL_SHARDS: self.get_primaries,
RequestPolicy.ALL_NODES: self.get_nodes,
RequestPolicy.ALL_REPLICAS: self.get_replicas,
RequestPolicy.SPECIAL: self.get_special_nodes,
ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
ResponsePolicy.DEFAULT_KEYED: lambda res: res,
}

self._policy_resolver = policy_resolver
self.commands_parser = AsyncCommandsParser()
self._aggregate_nodes = None
self.node_flags = self.__class__.NODE_FLAGS.copy()
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.response_callbacks = kwargs["response_callbacks"]
Expand Down Expand Up @@ -619,6 +651,45 @@ def get_node_from_key(

return slot_cache[node_idx]

def get_random_primary_or_all_nodes(self, command_name):
"""
Returns random primary or all nodes depends on READONLY mode.
"""
if self.read_from_replicas and command_name in READ_COMMANDS:
return self.get_random_node()

return self.get_random_primary_node()

def get_random_primary_node(self) -> "ClusterNode":
"""
Returns a random primary node
"""
return random.choice(self.get_primaries())

async def get_nodes_from_slot(self, command: str, *args):
"""
Returns a list of nodes that hold the specified keys' slots.
"""
# get the node that holds the key's slot
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
self.load_balancing_strategy if command in READ_COMMANDS else None,
)
]

def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
"""
Returns a list of nodes for commands with a special policy.
"""
if not self._aggregate_nodes:
raise RedisClusterException(
"Cannot execute FT.CURSOR commands without FT.AGGREGATE"
)

return self._aggregate_nodes

def keyslot(self, key: EncodableT) -> int:
"""
Find the keyslot for a given key.
Expand All @@ -643,39 +714,34 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No
self.response_callbacks[command] = callback

async def _determine_nodes(
self, command: str, *args: Any, node_flag: Optional[str] = None
self,
command: str,
*args: Any,
request_policy: RequestPolicy,
node_flag: Optional[str] = None,
) -> List["ClusterNode"]:
# Determine which nodes should be executed the command on.
# Returns a list of target nodes.
if not node_flag:
# get the nodes group for this command if it was predefined
node_flag = self.command_flags.get(command)

if node_flag in self.node_flags:
if node_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
if node_flag == self.__class__.PRIMARIES:
# return all primaries
return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
if node_flag == self.__class__.REPLICAS:
# return all replicas
return self.nodes_manager.get_nodes_by_server_type(REPLICA)
if node_flag == self.__class__.ALL_NODES:
# return all nodes
return list(self.nodes_manager.nodes_cache.values())
if node_flag == self.__class__.RANDOM:
# return a random node
return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
if node_flag in self._command_flags_mapping:
request_policy = self._command_flags_mapping[node_flag]

# get the node that holds the key's slot
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
self.load_balancing_strategy if command in READ_COMMANDS else None,
)
]
policy_callback = self._policies_callback_mapping[request_policy]

if request_policy == RequestPolicy.DEFAULT_KEYED:
nodes = await policy_callback(command, *args)
elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
nodes = policy_callback(command)
else:
nodes = policy_callback()

if command.lower() == "ft.aggregate":
self._aggregate_nodes = nodes

return nodes

async def _determine_slot(self, command: str, *args: Any) -> int:
if self.command_flags.get(command) == SLOT_ID:
Expand Down Expand Up @@ -780,6 +846,33 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
target_nodes_specified = True
retry_attempts = 0

command_policies = await self._policy_resolver.resolve(args[0].lower())

if not command_policies and not target_nodes_specified:
command_flag = self.command_flags.get(command)
if not command_flag:
# Fallback to default policy
if not self.get_default_node():
slot = None
else:
slot = await self._determine_slot(*args)
if not slot:
command_policies = CommandPolicies()
else:
command_policies = CommandPolicies(
request_policy=RequestPolicy.DEFAULT_KEYED,
response_policy=ResponsePolicy.DEFAULT_KEYED,
)
else:
if command_flag in self._command_flags_mapping:
command_policies = CommandPolicies(
request_policy=self._command_flags_mapping[command_flag]
)
else:
command_policies = CommandPolicies()
elif not command_policies and target_nodes_specified:
command_policies = CommandPolicies()

# Add one for the first execution
execute_attempts = 1 + retry_attempts
for _ in range(execute_attempts):
Expand All @@ -795,7 +888,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
if not target_nodes_specified:
# Determine the nodes to execute the command on
target_nodes = await self._determine_nodes(
*args, node_flag=passed_targets
*args,
request_policy=command_policies.request_policy,
node_flag=passed_targets,
)
if not target_nodes:
raise RedisClusterException(
Expand All @@ -806,10 +901,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
# Return the processed result
ret = await self._execute_command(target_nodes[0], *args, **kwargs)
if command in self.result_callbacks:
return self.result_callbacks[command](
ret = self.result_callbacks[command](
command, {target_nodes[0].name: ret}, **kwargs
)
return ret
return self._policies_callback_mapping[
command_policies.response_policy
](ret)
else:
keys = [node.name for node in target_nodes]
values = await asyncio.gather(
Expand All @@ -824,7 +921,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
return self.result_callbacks[command](
command, dict(zip(keys, values)), **kwargs
)
return dict(zip(keys, values))
return self._policies_callback_mapping[
command_policies.response_policy
](dict(zip(keys, values)))
except Exception as e:
if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
# The nodes and slots cache were should be reinitialized.
Expand Down Expand Up @@ -1740,6 +1839,7 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
self.kwargs = kwargs
self.position = position
self.result: Union[Any, Exception] = None
self.command_policies: Optional[CommandPolicies] = None

def __repr__(self) -> str:
return f"[{self.position}] {self.args} ({self.kwargs})"
Expand Down Expand Up @@ -1980,16 +2080,51 @@ async def _execute(
nodes = {}
for cmd in todo:
passed_targets = cmd.kwargs.pop("target_nodes", None)
command_policies = await client._policy_resolver.resolve(
cmd.args[0].lower()
)

if passed_targets and not client._is_node_flag(passed_targets):
target_nodes = client._parse_target_nodes(passed_targets)

if not command_policies:
command_policies = CommandPolicies()
else:
if not command_policies:
command_flag = client.command_flags.get(cmd.args[0])
if not command_flag:
# Fallback to default policy
if not client.get_default_node():
slot = None
else:
slot = await client._determine_slot(*cmd.args)
if not slot:
command_policies = CommandPolicies()
else:
command_policies = CommandPolicies(
request_policy=RequestPolicy.DEFAULT_KEYED,
response_policy=ResponsePolicy.DEFAULT_KEYED,
)
else:
if command_flag in client._command_flags_mapping:
command_policies = CommandPolicies(
request_policy=client._command_flags_mapping[
command_flag
]
)
else:
command_policies = CommandPolicies()

target_nodes = await client._determine_nodes(
*cmd.args, node_flag=passed_targets
*cmd.args,
request_policy=command_policies.request_policy,
node_flag=passed_targets,
)
if not target_nodes:
raise RedisClusterException(
f"No targets were found to execute {cmd.args} command on"
)
cmd.command_policies = command_policies
if len(target_nodes) > 1:
raise RedisClusterException(f"Too many targets for command {cmd.args}")
node = target_nodes[0]
Expand All @@ -2010,9 +2145,9 @@ async def _execute(
for cmd in todo:
if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
try:
cmd.result = await client.execute_command(
*cmd.args, **cmd.kwargs
)
cmd.result = client._policies_callback_mapping[
cmd.command_policies.response_policy
](await client.execute_command(*cmd.args, **cmd.kwargs))
except Exception as e:
cmd.result = e

Expand Down
Loading