Skip to content

Commit

Permalink
Cluster Refactorings
Browse files Browse the repository at this point in the history
* Add tls support to `ClusterClientBuilder`
* Simplify cluster connection map; key with `host:port` string
  rather than as a potentially incomplete uri.
  • Loading branch information
0xWOF committed Feb 7, 2023
1 parent 7a98021 commit 0d0133c
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 70 deletions.
129 changes: 60 additions & 69 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::iter::Iterator;
use std::str::FromStr;
use std::thread;
use std::time::Duration;

Expand All @@ -54,7 +55,7 @@ use crate::cluster_pipeline::UNROUTABLE_ERROR;
use crate::cluster_routing::{Routable, RoutingInfo, Slot, SLOT_SIZE};
use crate::cmd::{cmd, Cmd};
use crate::connection::{
connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, IntoConnectionInfo,
connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo,
};
use crate::parser::parse_redis_value;
use crate::types::{ErrorKind, HashMap, HashSet, RedisError, RedisResult, Value};
Expand Down Expand Up @@ -90,25 +91,7 @@ impl ClusterConnection {
password: cluster_params.password,
read_timeout: RefCell::new(None),
write_timeout: RefCell::new(None),
#[cfg(feature = "tls")]
tls: {
if initial_nodes.is_empty() {
None
} else {
// TODO: Maybe should run through whole list and make sure they're all matching?
match &initial_nodes.get(0).unwrap().addr {
ConnectionAddr::Tcp(_, _) => None,
ConnectionAddr::TcpTls {
host: _,
port: _,
insecure,
} => Some(TlsMode::from_insecure_flag(*insecure)),
_ => None,
}
}
},
#[cfg(not(feature = "tls"))]
tls: None,
tls: cluster_params.tls,
initial_nodes: initial_nodes.to_vec(),
};
connection.create_initial_connections()?;
Expand Down Expand Up @@ -190,20 +173,9 @@ impl ClusterConnection {
let mut connections = HashMap::with_capacity(self.initial_nodes.len());

for info in self.initial_nodes.iter() {
let addr = match info.addr {
ConnectionAddr::Tcp(ref host, port) => format!("redis://{host}:{port}"),
ConnectionAddr::TcpTls {
ref host,
port,
insecure,
} => {
let tls_mode = TlsMode::from_insecure_flag(insecure);
build_connection_string(host, Some(port), Some(tls_mode))
}
_ => panic!("No reach."),
};
let addr = info.addr.to_string();

if let Ok(mut conn) = self.connect(info.clone()) {
if let Ok(mut conn) = self.connect(&addr) {
if conn.check_connection() {
connections.insert(addr, conn);
break;
Expand Down Expand Up @@ -255,7 +227,7 @@ impl ClusterConnection {
}
}

if let Ok(mut conn) = self.connect(addr.as_ref()) {
if let Ok(mut conn) = self.connect(addr) {
if conn.check_connection() {
conn.set_read_timeout(*self.read_timeout.borrow()).unwrap();
conn.set_write_timeout(*self.write_timeout.borrow())
Expand Down Expand Up @@ -328,12 +300,16 @@ impl ClusterConnection {
}
}

fn connect<T: IntoConnectionInfo>(&self, info: T) -> RedisResult<Connection> {
let mut connection_info = info.into_connection_info()?;
connection_info.redis.username = self.username.clone();
connection_info.redis.password = self.password.clone();
fn connect(&self, node: &str) -> RedisResult<Connection> {
let params = ClusterParams {
password: self.password.clone(),
username: self.username.clone(),
tls: self.tls,
..Default::default()
};
let info = get_connection_info(node, params)?;

let mut conn = connect(&connection_info, None)?;
let mut conn = connect(&info, None)?;
if self.read_from_replicas {
// If READONLY is sent to primary nodes, it will have no effect
cmd("READONLY").query(&mut conn)?;
Expand Down Expand Up @@ -487,19 +463,15 @@ impl ClusterConnection {
let kind = err.kind();

if kind == ErrorKind::Ask {
redirected = err
.redirect_node()
.map(|(node, _slot)| build_connection_string(node, None, self.tls));
redirected = err.redirect_node().map(|(node, _slot)| node.to_string());
is_asking = true;
} else if kind == ErrorKind::Moved {
// Refresh slots.
self.refresh_slots()?;
excludes.clear();

// Request again.
redirected = err
.redirect_node()
.map(|(node, _slot)| build_connection_string(node, None, self.tls));
redirected = err.redirect_node().map(|(node, _slot)| node.to_string());
is_asking = false;
continue;
} else if kind == ErrorKind::TryAgain || kind == ErrorKind::ClusterDown {
Expand Down Expand Up @@ -692,22 +664,16 @@ impl NodeCmd {
}
}

/// TlsMode indicates use or do not use verification of certification.
/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more.
#[derive(Clone, Copy)]
enum TlsMode {
pub enum TlsMode {
/// Secure verify certification.
Secure,
/// Insecure do not verify certification.
Insecure,
}

impl TlsMode {
fn from_insecure_flag(insecure: bool) -> TlsMode {
if insecure {
TlsMode::Insecure
} else {
TlsMode::Secure
}
}
}

fn get_random_connection<'a>(
connections: &'a mut HashMap<String, Connection>,
excludes: Option<&'a HashSet<String>>,
Expand All @@ -728,7 +694,7 @@ fn get_random_connection<'a>(
}

// Get slot data from connection.
fn get_slots(connection: &mut Connection, tls_mode: Option<TlsMode>) -> RedisResult<Vec<Slot>> {
fn get_slots(connection: &mut Connection, tls: Option<TlsMode>) -> RedisResult<Vec<Slot>> {
let mut cmd = Cmd::new();
cmd.arg("CLUSTER").arg("SLOTS");
let value = connection.req_command(&cmd)?;
Expand Down Expand Up @@ -778,7 +744,7 @@ fn get_slots(connection: &mut Connection, tls_mode: Option<TlsMode>) -> RedisRes
} else {
return None;
};
Some(build_connection_string(&ip, Some(port), tls_mode))
Some(get_connection_addr(ip.into_owned(), port, tls).to_string())
} else {
None
}
Expand All @@ -797,16 +763,41 @@ fn get_slots(connection: &mut Connection, tls_mode: Option<TlsMode>) -> RedisRes
Ok(result)
}

fn build_connection_string(host: &str, port: Option<u16>, tls_mode: Option<TlsMode>) -> String {
let host_port = match port {
Some(port) => format!("{host}:{port}"),
None => host.to_string(),
};
match tls_mode {
None => format!("redis://{host_port}"),
Some(TlsMode::Insecure) => {
format!("rediss://{host_port}/#insecure")
}
Some(TlsMode::Secure) => format!("rediss://{host_port}"),
// The node string passed to this function will always be in the format host:port as it is either:
// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode)
// - Returned from redis via the ASK/MOVED response
fn get_connection_info(node: &str, cluster_params: ClusterParams) -> RedisResult<ConnectionInfo> {
let mut split = node.split(':');
let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string");

let host = split.next().ok_or_else(invalid_error)?;
let port = split
.next()
.and_then(|string| u16::from_str(string).ok())
.ok_or_else(invalid_error)?;

Ok(ConnectionInfo {
addr: get_connection_addr(host.to_string(), port, cluster_params.tls),
redis: RedisConnectionInfo {
password: cluster_params.password,
username: cluster_params.username,
..Default::default()
},
})
}

fn get_connection_addr(host: String, port: u16, tls: Option<TlsMode>) -> ConnectionAddr {
match tls {
Some(TlsMode::Secure) => ConnectionAddr::TcpTls {
host,
port,
insecure: false,
},
Some(TlsMode::Insecure) => ConnectionAddr::TcpTls {
host,
port,
insecure: true,
},
_ => ConnectionAddr::Tcp(host, port),
}
}
28 changes: 27 additions & 1 deletion redis/src/cluster_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::cluster::ClusterConnection;
use crate::cluster::{ClusterConnection, TlsMode};
use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo};
use crate::types::{ErrorKind, RedisError, RedisResult};

Expand All @@ -8,6 +8,10 @@ pub(crate) struct ClusterParams {
pub(crate) password: Option<String>,
pub(crate) username: Option<String>,
pub(crate) read_from_replicas: bool,
/// tls indicates tls behavior of connections.
/// When Some(TlsMode), connections use tls and verify certification depends on TlsMode.
/// When None, connections do not use tls.
pub(crate) tls: Option<TlsMode>,
}

/// Used to configure and build a [`ClusterClient`].
Expand Down Expand Up @@ -65,6 +69,19 @@ impl ClusterClientBuilder {
} else {
&None
};
if cluster_params.tls.is_none() {
cluster_params.tls = match first_node.addr {
ConnectionAddr::TcpTls {
host: _,
port: _,
insecure,
} => Some(match insecure {
false => TlsMode::Secure,
true => TlsMode::Insecure,
}),
_ => None,
};
}

let mut nodes = Vec::with_capacity(initial_nodes.len());
for node in initial_nodes {
Expand Down Expand Up @@ -108,6 +125,15 @@ impl ClusterClientBuilder {
self
}

/// Sets TLS mode for the new ClusterClient.
///
/// It is extracted from the first node of initial_nodes if not set.
#[cfg(feature = "tls")]
pub fn tls(mut self, tls: TlsMode) -> ClusterClientBuilder {
self.cluster_params.tls = Some(tls);
self
}

/// Enables reading from replicas for all new connections (default is disabled).
///
/// If enabled, then read queries will go to the replica nodes & write queries will go to the
Expand Down
1 change: 1 addition & 0 deletions redis/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl ConnectionAddr {

impl fmt::Display for ConnectionAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Cluster::get_connection_info depends on the return value from this function
match *self {
ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
Expand Down

0 comments on commit 0d0133c

Please sign in to comment.