Skip to content

Commit

Permalink
Add support for multi-shard commands.
Browse files Browse the repository at this point in the history
This allows the async cluster client to split keyed commands so that
keys will be grouped by the slot they belong in, and sent only to the
relevant shard.
  • Loading branch information
nihohit committed Jul 23, 2023
1 parent aed3204 commit 302f2cc
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 21 deletions.
49 changes: 34 additions & 15 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,24 +539,35 @@ where
Ok(())
}

async fn execute_on_multiple_nodes(
async fn execute_on_multiple_nodes<'a>(
func: fn(C, Arc<Cmd>) -> RedisFuture<'static, Response>,
cmd: &Arc<Cmd>,
routing: &MultipleNodeRoutingInfo,
cmd: &'a Arc<Cmd>,
routing: &'a MultipleNodeRoutingInfo,
core: Core<C>,
response_policy: Option<ResponsePolicy>,
) -> (OperationTarget, RedisResult<Response>) {
let read_guard = core.conn_lock.read().await;
let connections: Vec<(String, ConnectionFuture<C>)> = read_guard
let connections: Vec<_> = read_guard
.1
.addresses_for_multi_routing(routing)
.into_iter()
.filter_map(|addr| {
read_guard
.0
.get(addr)
.cloned()
.map(|conn| (addr.to_string(), conn))
.enumerate()
.filter_map(|(index, addr)| {
read_guard.0.get(addr).cloned().map(|conn| {
let cmd = match routing {
MultipleNodeRoutingInfo::MultiSlot(vec) => {
let mut new_cmd = Cmd::new();
new_cmd.arg(cmd.arg_idx(0));
let (_, indices) = vec.get(index).unwrap();
for index in indices {
new_cmd.arg(cmd.arg_idx(*index));
}
Arc::new(new_cmd)
}
_ => cmd.clone(),
};
(addr.to_string(), conn, cmd)
})
})
.collect();
drop(read_guard);
Expand All @@ -566,10 +577,10 @@ where
Response::Multiple(_) => unreachable!(),
};

let run_func = |(_, conn)| {
let run_func = |(_, conn, cmd)| {
Box::pin(async move {
let conn = conn.await;
Ok(extract_result(func(conn, cmd.clone()).await?))
Ok(extract_result(func(conn, cmd).await?))
})
};

Expand Down Expand Up @@ -611,17 +622,25 @@ where
Some(ResponsePolicy::CombineArrays) => {
future::try_join_all(connections.into_iter().map(run_func))
.await
.and_then(crate::cluster_routing::combine_array_results)
.and_then(|results| match routing {
MultipleNodeRoutingInfo::MultiSlot(vec) => {
crate::cluster_routing::combine_and_sort_array_results(
results,
vec.iter().map(|(_, indices)| indices),
)
}
_ => crate::cluster_routing::combine_array_results(results),
})
}
Some(ResponsePolicy::Special) | None => {
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
// TODO - once RESP3 is merged, return a map value here.
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
future::try_join_all(connections.into_iter().map(|(addr, conn)| async move {
future::try_join_all(connections.into_iter().map(|(addr, conn, cmd)| async move {
let conn = conn.await;
Ok(Value::Bulk(vec![
Value::Data(addr.into_bytes()),
extract_result(func(conn, cmd.clone()).await?),
extract_result(func(conn, cmd).await?),
]))
}))
.await
Expand Down
128 changes: 123 additions & 5 deletions redis/src/cluster_routing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cmp::min;
use std::collections::{BTreeMap, HashSet};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::iter::Iterator;

use rand::seq::SliceRandom;
Expand Down Expand Up @@ -62,6 +62,8 @@ pub(crate) enum SingleNodeRoutingInfo {
/// Routing targets for a command that has to be sent to more than one node.
pub(crate) enum MultipleNodeRoutingInfo {
// NOTE(review): presumably every node in the cluster, replicas included - confirm against `SlotMap::addresses_for_multi_routing`.
AllNodes,
// NOTE(review): presumably only the primary node of each shard - confirm against `SlotMap::addresses_for_multi_routing`.
AllMasters,
/// Instructions on how to split a multi-slot command (e.g. MGET, MSET) into sub-commands.
/// Each tuple holds the route of one sub-command, together with the indices of the arguments
/// of the original command that should be copied into that sub-command.
MultiSlot(Vec<(Route, Vec<usize>)>),
}

pub(crate) fn aggregate(values: Vec<Value>, op: AggregateOp) -> RedisResult<Value> {
Expand Down Expand Up @@ -155,6 +157,31 @@ pub(crate) fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
Ok(Value::Bulk(results))
}

/// Merges the array replies of the sub-commands of a split multi-slot command into a single
/// array, ordered like the keys of the original command.
///
/// `values` holds one reply per sub-command. `sorting_order` yields, in the same order as
/// `values`, the argument indices of the original command that were copied into each
/// sub-command. Argument index `0` is the command name, so the key at argument index `i` ends
/// up at position `i - 1` of the combined array.
///
/// # Errors
///
/// Returns a `TypeError` if any of the replies is not an array.
///
/// # Panics
///
/// Panics if the number of replies differs from the number of index lists, if a reply's length
/// differs from the length of its index list, or if an index does not fit in the combined array.
pub(crate) fn combine_and_sort_array_results<'a>(
    values: Vec<Value>,
    sorting_order: impl IntoIterator<Item = &'a Vec<usize>> + ExactSizeIterator,
) -> RedisResult<Value> {
    assert_eq!(values.len(), sorting_order.len());

    // The combined array needs one entry per returned element, NOT one entry per sub-command:
    // a single sub-command carries several keys whenever those keys share a slot.
    let total_len: usize = values
        .iter()
        .map(|value| match value {
            Value::Bulk(items) => items.len(),
            // Non-array replies are rejected with an error in the loop below.
            _ => 0,
        })
        .sum();
    let mut results = vec![Value::Nil; total_len];

    for (key_indices, value) in sorting_order.into_iter().zip(values) {
        match value {
            Value::Bulk(values) => {
                assert_eq!(values.len(), key_indices.len());
                for (index, value) in key_indices.iter().zip(values) {
                    // Shift by one to skip the command name at argument index 0.
                    results[*index - 1] = value;
                }
            }
            _ => {
                return Err((ErrorKind::TypeError, "expected array of values as response").into());
            }
        }
    }

    Ok(Value::Bulk(results))
}

fn get_slot(key: &[u8]) -> u16 {
let key = match get_hashtag(key) {
Some(tag) => tag,
Expand All @@ -173,6 +200,40 @@ fn get_route(is_readonly: bool, key: &[u8]) -> Route {
}
}

/// Builds the routing for a command whose keys may belong to different slots.
///
/// Starting at `first_key_index`, every key argument of `r` is mapped to its route and the
/// argument indices are grouped per route. When `has_values` is set, each key is expected to be
/// followed by a value argument, whose index is grouped together with its key.
///
/// Returns `None` if `has_values` is set and a key has no value after it. If all keys map to a
/// single route the result is a single-node routing to that route, otherwise it is a
/// `MultiSlot` routing holding the per-route argument indices.
fn multi_shard<R>(
    r: &R,
    cmd: &[u8],
    first_key_index: usize,
    has_values: bool,
) -> Option<RoutingInfo>
where
    R: Routable + ?Sized,
{
    let readonly = is_readonly_cmd(cmd);
    // A key takes up one argument, or two when it is followed by its value.
    let stride = if has_values { 2 } else { 1 };

    let mut indices_by_route: HashMap<Route, Vec<usize>> = HashMap::new();
    let mut key_pos = first_key_index;
    while let Some(key) = r.arg_idx(key_pos) {
        let indices = indices_by_route
            .entry(get_route(readonly, key))
            .or_default();
        indices.push(key_pos);

        if has_values {
            let value_pos = key_pos + 1;
            // A key without a value makes the whole command unroutable.
            r.arg_idx(value_pos)?;
            indices.push(value_pos);
        }

        key_pos += stride;
    }

    let mut grouped: Vec<(Route, Vec<usize>)> = indices_by_route.into_iter().collect();
    let routing = match grouped.len() {
        // Every key lives in the same slot - no need to split the command.
        1 => RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(
            grouped.pop().unwrap().0,
        )),
        _ => RoutingInfo::MultiNode(MultipleNodeRoutingInfo::MultiSlot(grouped)),
    };
    Some(routing)
}

impl RoutingInfo {
pub(crate) fn response_policy<R>(r: &R) -> Option<ResponsePolicy>
where
Expand All @@ -187,7 +248,7 @@ impl RoutingInfo {
Some(Aggregate(AggregateOp::Sum))
}

b"MSETNX" | b"WAIT" => Some(Aggregate(AggregateOp::Min)),
b"WAIT" => Some(Aggregate(AggregateOp::Min)),

b"CONFIG SET" | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH"
| b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"LATENCY RESET" | b"MEMORY PURGE"
Expand Down Expand Up @@ -248,7 +309,8 @@ impl RoutingInfo {
Some(RoutingInfo::MultiNode(MultipleNodeRoutingInfo::AllNodes))
}

// TODO - multi shard handling - b"MGET" |b"MSETNX" |b"DEL" |b"EXISTS" |b"UNLINK" |b"TOUCH" |b"MSET"
b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" => multi_shard(r, cmd, 1, false),
b"MSET" => multi_shard(r, cmd, 1, true),
// TODO - special handling - b"SCAN"
b"SCAN" | b"CLIENT SETNAME" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" | b"MOVE"
| b"BITOP" => None,
Expand Down Expand Up @@ -394,7 +456,7 @@ impl Slot {
}
}

#[derive(Eq, PartialEq, Clone, Copy, Debug)]
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
pub(crate) enum SlotAddr {
Master,
Replica,
Expand Down Expand Up @@ -504,13 +566,17 @@ impl SlotMap {
MultipleNodeRoutingInfo::AllMasters => {
self.all_unique_addresses(true).into_iter().collect()
}
MultipleNodeRoutingInfo::MultiSlot(routes) => routes
.iter()
.flat_map(|(route, _)| self.slot_addr_for_route(route))
.collect(),
}
}
}

/// Defines the slot and the [`SlotAddr`] to which
/// a command should be sent
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
pub(crate) struct Route(u16, SlotAddr);

impl Route {
Expand Down Expand Up @@ -776,6 +842,58 @@ mod tests {
]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210));
}

#[test]
fn test_multi_shard() {
    // "bar" and "{bar}vaz" are expected to end up in the same route (argument indices 2 and 4).
    let keys = ["foo", "bar", "baz", "{bar}vaz"];

    // Each case is a command name and the kind of node its routes are expected to target.
    let cases = [("DEL", SlotAddr::Master), ("MGET", SlotAddr::Replica)];

    for (name, slot_addr) in cases {
        let mut command = crate::cmd(name);
        for key in keys {
            command.arg(key);
        }
        let routing = RoutingInfo::for_routable(&command);

        let expected = std::collections::HashMap::from([
            (Route(4813, slot_addr), vec![3]),
            (Route(5061, slot_addr), vec![2, 4]),
            (Route(12182, slot_addr), vec![1]),
        ]);

        // The order of the routes is unspecified, so compare them as a map.
        let routes_match = match &routing {
            Some(RoutingInfo::MultiNode(MultipleNodeRoutingInfo::MultiSlot(routes))) => {
                let actual: std::collections::HashMap<Route, Vec<usize>> =
                    routes.iter().cloned().collect();
                expected == actual
            }
            _ => false,
        };
        assert!(routes_match, "{routing:?}");
    }
}

#[test]
fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() {
    // All three keys are expected to resolve to one slot, so the command must not be split.
    let mut del = crate::cmd("DEL");
    del.arg("foo").arg("{foo}bar").arg("{foo}baz");
    let routing = RoutingInfo::for_routable(&del);

    let expected_route = Route(12182, SlotAddr::Master);
    let routed_to_single_node = match &routing {
        Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
            *route == expected_route
        }
        _ => false,
    };

    assert!(routed_to_single_node, "{routing:?}");
}

#[test]
fn test_slot_map() {
let slot_map = SlotMap::from_slots(
Expand Down
40 changes: 39 additions & 1 deletion redis/tests/test_cluster_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,11 +1082,49 @@ fn test_cluster_fan_out_and_combine_arrays_of_values() {
result.sort();
assert_eq!(
result,
vec![format!("key:6379"), format!("key:6381"),],
vec!["key:6379".to_string(), "key:6381".to_string(),],
"{result:?}"
);
}

// Sends an MGET with three keys through the mocked cluster connection and checks that the
// combined reply lists one value per key, in the order the keys appear in the command.
#[test]
fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() {
let name = "test_cluster_split_multi_shard_command_and_combine_arrays_of_values";
let mut cmd = cmd("MGET");
cmd.arg("foo").arg("bar").arg("baz");
let MockEnv {
runtime,
async_connection: mut connection,
handler: _handler,
..
} = MockEnv::with_client_builder(
ClusterClient::builder(vec![&*format!("redis://{name}")])
.retries(0)
.read_from_replicas(),
name,
move |received_cmd: &[u8], port| {
// NOTE(review): presumably answers the cluster setup commands and returns early through `?` - confirm against the test helpers.
respond_startup_with_replica_using_config(name, received_cmd, None)?;
let cmd_str = std::str::from_utf8(received_cmd).unwrap();
// Reply with "<key>-<port>" for every key whose name appears in the received command,
// so the final result shows which port was asked for which key.
let results = ["foo", "bar", "baz"]
.iter()
.filter_map(|expected_key| {
if cmd_str.contains(expected_key) {
Some(Value::Data(format!("{expected_key}-{port}").into_bytes()))
} else {
None
}
})
.collect();
// NOTE(review): `Err(Ok(value))` looks like the mock handler's way of replying with `value` - confirm against `MockEnv`.
Err(Ok(Value::Bulk(results)))
},
);

let result = runtime
.block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
.unwrap();
// The values come back in the order of the keys in the original command; the port suffixes
// show that "foo" was requested from a different port than "bar" and "baz".
assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]);
}

#[test]
fn test_async_cluster_with_username_and_password() {
let cluster = TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| {
Expand Down

0 comments on commit 302f2cc

Please sign in to comment.