From d92f4a29809a546757bed3f425013ec57bb6ed5d Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Fri, 1 Dec 2023 12:42:18 +0000 Subject: [PATCH 1/2] Move `tls_params` into `ClusterParams`. The two objects are used together, so it makes sense to join them. --- redis/src/cluster.rs | 18 ++++++---- redis/src/cluster_async/mod.rs | 49 +++++++------------------ redis/src/cluster_client.rs | 66 ++++++++++++---------------------- 3 files changed, 47 insertions(+), 86 deletions(-) diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 765fa2146..d0c487f86 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -228,7 +228,6 @@ where pub(crate) fn new( cluster_params: ClusterParams, initial_nodes: Vec, - tls_params: Option, ) -> RedisResult { let connection = Self { connections: RefCell::new(HashMap::new()), @@ -240,7 +239,7 @@ where read_timeout: RefCell::new(None), write_timeout: RefCell::new(None), tls: cluster_params.tls, - tls_params, + tls_params: cluster_params.tls_params, initial_nodes: initial_nodes.to_vec(), retry_params: cluster_params.retry_params, }; @@ -433,9 +432,10 @@ where password: self.password.clone(), username: self.username.clone(), tls: self.tls, + tls_params: self.tls_params.clone(), ..Default::default() }; - let info = get_connection_info(node, params, self.tls_params.clone())?; + let info = get_connection_info(node, params)?; let mut conn = C::connect(info, None)?; if self.read_from_replicas { @@ -991,7 +991,6 @@ pub(crate) fn parse_slots(raw_slot_resp: Value, tls: Option) -> RedisRe pub(crate) fn get_connection_info( node: &str, cluster_params: ClusterParams, - tls_params: Option, ) -> RedisResult { let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string"); @@ -1005,7 +1004,12 @@ pub(crate) fn get_connection_info( .ok_or_else(invalid_error)?; Ok(ConnectionInfo { - addr: get_connection_addr(host.to_string(), port, cluster_params.tls, tls_params), + addr: get_connection_addr( + host.to_string(), + port, + cluster_params.tls, + cluster_params.tls_params, + ), redis: RedisConnectionInfo { password: cluster_params.password, username: cluster_params.username, @@ -1069,13 +1073,13 @@ mod tests { ]; for (input, expected) in cases { - let res = get_connection_info(input, ClusterParams::default(), None); + let res = get_connection_info(input, ClusterParams::default()); assert_eq!(res.unwrap().addr, expected); } let cases = vec![":0", "[]:6379"]; for input in cases { - let res = get_connection_info(input, ClusterParams::default(), None); + let res = get_connection_info(input, ClusterParams::default()); assert_eq!( res.err(), Some(RedisError::from(( diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 2e88a2fe3..a54f5a44e 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -44,9 +44,6 @@ use crate::{ Value, }; -#[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; - #[cfg(not(feature = "tls-rustls"))] use crate::connection::TlsConnParams; @@ -77,9 +74,8 @@ where pub(crate) async fn new( initial_nodes: &[ConnectionInfo], cluster_params: ClusterParams, - tls_params: Option, ) -> RedisResult> { - ClusterConnInner::new(initial_nodes, cluster_params, tls_params) + ClusterConnInner::new(initial_nodes, cluster_params) .await .map(|inner| { let (tx, mut rx) = mpsc::channel::>(100); @@ -171,7 +167,6 @@ struct InnerCore { conn_lock: RwLock<(ConnectionMap, SlotMap)>, cluster_params: ClusterParams, pending_requests: Mutex>>, - tls_params: Option, } type Core = Arc>; @@ -457,15 +452,12 @@ where async fn new( initial_nodes: &[ConnectionInfo], cluster_params: ClusterParams, - tls_params: Option, ) -> RedisResult { - let connections = - Self::create_initial_connections(initial_nodes, &cluster_params, &tls_params).await?; + let connections = Self::create_initial_connections(initial_nodes, &cluster_params).await?; let inner = Arc::new(InnerCore { conn_lock: RwLock::new((connections, Default::default())), cluster_params, pending_requests: Mutex::new(Vec::new()), - tls_params, }); let mut connection = ClusterConnInner { inner, @@ -480,15 +472,13 @@ where async fn create_initial_connections( initial_nodes: &[ConnectionInfo], params: &ClusterParams, - tls_params: &Option, ) -> RedisResult> { let connections = stream::iter(initial_nodes.iter().cloned()) .map(|info| { let params = params.clone(); - let tls_params = tls_params.clone(); async move { let addr = info.addr.to_string(); - let result = connect_and_check(&addr, params, tls_params).await; + let result = connect_and_check(&addr, params).await; match result { Ok(conn) => Some((addr, async { conn }.boxed().shared())), Err(e) => { @@ -528,7 +518,6 @@ where &addr, connections.remove(&addr), &inner.cluster_params, - inner.tls_params.clone(), ) .await; if let Ok(conn) = conn { @@ -584,13 +573,8 @@ where .fold( HashMap::with_capacity(nodes_len), |mut connections, (addr, connection)| async { - let conn = Self::get_or_create_conn( - addr, - connection, - &inner.cluster_params, - inner.tls_params.clone(), - ) - .await; + let conn = + Self::get_or_create_conn(addr, connection, &inner.cluster_params).await; if let Ok(conn) = conn { connections.insert(addr.to_string(), async { conn }.boxed().shared()); } @@ -888,12 +872,10 @@ where let addr_conn_option = match conn { Some((addr, Some(conn))) => Some((addr, conn.await)), - Some((addr, None)) => { - connect_and_check(&addr, core.cluster_params.clone(), core.tls_params.clone()) - .await - .ok() - .map(|conn| (addr, conn)) - } + Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone()) + .await + .ok() + .map(|conn| (addr, conn)), None => None, }; @@ -1047,16 +1029,15 @@ where addr: &str, conn_option: Option>, params: &ClusterParams, - tls_params: Option, ) -> RedisResult { if let Some(conn) = conn_option { let mut conn = conn.await; match check_connection(&mut conn).await { Ok(_) => Ok(conn), - Err(_) => connect_and_check(addr, params.clone(), tls_params).await, + Err(_) => connect_and_check(addr, params.clone()).await, } } else { - connect_and_check(addr, params.clone(), tls_params).await + connect_and_check(addr, params.clone()).await } } } @@ -1258,16 +1239,12 @@ impl Connect for MultiplexedConnection { } } -async fn connect_and_check( - node: &str, - params: ClusterParams, - tls_params: Option, -) -> RedisResult +async fn connect_and_check(node: &str, params: ClusterParams) -> RedisResult where C: ConnectionLike + Connect + Send + 'static, { let read_from_replicas = params.read_from_replicas; - let info = get_connection_info(node, params, tls_params)?; + let info = get_connection_info(node, params)?; let mut conn = C::connect(info).await?; check_connection(&mut conn).await?; if read_from_replicas { diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index 94d32663b..12a9b652b 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -80,17 +80,29 @@ pub(crate) struct ClusterParams { /// When None, connections do not use tls. pub(crate) tls: Option, pub(crate) retry_params: RetryParams, + pub(crate) tls_params: Option, } -impl From for ClusterParams { - fn from(value: BuilderParams) -> Self { - Self { +impl ClusterParams { + fn from(value: BuilderParams) -> RedisResult { + #[cfg(not(feature = "tls-rustls"))] + let tls_params = None; + + #[cfg(feature = "tls-rustls")] + let tls_params = { + let retrieved_tls_params = value.certs.clone().map(retrieve_tls_certificates); + + retrieved_tls_params.transpose()? + }; + + Ok(Self { password: value.password, username: value.username, read_from_replicas: value.read_from_replicas, tls: value.tls, retry_params: value.retries_configuration, - } + tls_params, + }) } } @@ -131,20 +143,6 @@ impl ClusterClientBuilder { pub fn build(self) -> RedisResult { let initial_nodes = self.initial_nodes?; - #[cfg(not(feature = "tls-rustls"))] - let tls_params = None; - - #[cfg(feature = "tls-rustls")] - let tls_params = { - let retrieved_tls_params = self - .builder_params - .certs - .clone() - .map(retrieve_tls_certificates); - - retrieved_tls_params.transpose()? - }; - let first_node = match initial_nodes.first() { Some(node) => node, None => { @@ -155,7 +153,7 @@ impl ClusterClientBuilder { } }; - let mut cluster_params: ClusterParams = self.builder_params.into(); + let mut cluster_params = ClusterParams::from(self.builder_params)?; let password = if cluster_params.password.is_none() { cluster_params.password = first_node.redis.password.clone(); &cluster_params.password @@ -210,7 +208,6 @@ impl ClusterClientBuilder { Ok(ClusterClient { initial_nodes: nodes, cluster_params, - tls_params, }) } @@ -311,7 +308,6 @@ impl ClusterClientBuilder { pub struct ClusterClient { initial_nodes: Vec, cluster_params: ClusterParams, - tls_params: Option, } impl ClusterClient { @@ -344,11 +340,7 @@ impl ClusterClient { /// /// An error is returned if there is a failure while creating connections or slots. pub fn get_connection(&self) -> RedisResult { - cluster::ClusterConnection::new( - self.cluster_params.clone(), - self.initial_nodes.clone(), - self.tls_params.clone(), - ) + cluster::ClusterConnection::new(self.cluster_params.clone(), self.initial_nodes.clone()) } /// Creates new connections to Redis Cluster nodes and returns a @@ -359,12 +351,8 @@ impl ClusterClient { /// An error is returned if there is a failure while creating connections or slots. #[cfg(feature = "cluster-async")] pub async fn get_async_connection(&self) -> RedisResult { - cluster_async::ClusterConnection::new( - &self.initial_nodes, - self.cluster_params.clone(), - self.tls_params.clone(), - ) - .await + cluster_async::ClusterConnection::new(&self.initial_nodes, self.cluster_params.clone()) + .await } #[doc(hidden)] @@ -372,11 +360,7 @@ impl ClusterClient { where C: crate::ConnectionLike + crate::cluster::Connect + Send, { - cluster::ClusterConnection::new( - self.cluster_params.clone(), - self.initial_nodes.clone(), - self.tls_params.clone(), - ) + cluster::ClusterConnection::new(self.cluster_params.clone(), self.initial_nodes.clone()) } #[doc(hidden)] @@ -393,12 +377,8 @@ impl ClusterClient { + Unpin + 'static, { - cluster_async::ClusterConnection::new( - &self.initial_nodes, - self.cluster_params.clone(), - self.tls_params.clone(), - ) - .await + cluster_async::ClusterConnection::new(&self.initial_nodes, self.cluster_params.clone()) + .await } /// Use `new()`. From 261dfa0410e469833fa5aa8a5f6ddbfd3a0ca7a2 Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Fri, 1 Dec 2023 12:45:55 +0000 Subject: [PATCH 2/2] Cluster: Contain `ClusterParams` internally. This matches the async cluster's structure, and ensures that fields that are added to `ClusterParams` are automatically added to the sync cluster --- redis/src/cluster.rs | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index d0c487f86..ca358c86b 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -43,7 +43,6 @@ use std::time::Duration; use rand::{seq::IteratorRandom, thread_rng, Rng}; -use crate::cluster_client::RetryParams; use crate::cluster_pipeline::UNROUTABLE_ERROR; use crate::cluster_routing::{ MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, SlotAddr, @@ -211,14 +210,9 @@ pub struct ClusterConnection { connections: RefCell>, slots: RefCell, auto_reconnect: RefCell, - read_from_replicas: bool, - username: Option, - password: Option, read_timeout: RefCell>, write_timeout: RefCell>, - tls: Option, - tls_params: Option, - retry_params: RetryParams, + cluster_params: ClusterParams, } impl ClusterConnection @@ -233,15 +227,10 @@ where connections: RefCell::new(HashMap::new()), slots: RefCell::new(SlotMap::new()), auto_reconnect: RefCell::new(true), - read_from_replicas: cluster_params.read_from_replicas, - username: cluster_params.username, - password: cluster_params.password, read_timeout: RefCell::new(None), write_timeout: RefCell::new(None), - tls: cluster_params.tls, - tls_params: cluster_params.tls_params, initial_nodes: initial_nodes.to_vec(), - retry_params: cluster_params.retry_params, + cluster_params, }; connection.create_initial_connections()?; @@ -386,7 +375,7 @@ where for conn in samples.iter_mut() { let value = conn.req_command(&slot_cmd())?; - if let Ok(mut slots_data) = parse_slots(value, self.tls) { + if let Ok(mut slots_data) = parse_slots(value, self.cluster_params.tls) { slots_data.sort_by_key(|s| s.start()); let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| { if prev_end != slot_data.start() { @@ -412,7 +401,10 @@ where ))); } - new_slots = Some(SlotMap::from_slots(&slots_data, self.read_from_replicas)); + new_slots = Some(SlotMap::from_slots( + &slots_data, + self.cluster_params.read_from_replicas, + )); break; } } @@ -428,17 +420,11 @@ where } fn connect(&self, node: &str) -> RedisResult { - let params = ClusterParams { - password: self.password.clone(), - username: self.username.clone(), - tls: self.tls, - tls_params: self.tls_params.clone(), - ..Default::default() - }; + let params = self.cluster_params.clone(); let info = get_connection_info(node, params)?; let mut conn = C::connect(info, None)?; - if self.read_from_replicas { + if self.cluster_params.read_from_replicas { // If READONLY is sent to primary nodes, it will have no effect cmd("READONLY").query(&mut conn)?; } @@ -698,7 +684,7 @@ where match rv { Ok(rv) => return Ok(rv), Err(err) => { - if retries == self.retry_params.number_of_retries { + if retries == self.cluster_params.retry_params.number_of_retries { return Err(err); } retries += 1; @@ -719,7 +705,10 @@ where } ErrorKind::TryAgain | ErrorKind::ClusterDown => { // Sleep and retry. - let sleep_time = self.retry_params.wait_time_for_retry(retries); + let sleep_time = self + .cluster_params + .retry_params + .wait_time_for_retry(retries); thread::sleep(sleep_time); } ErrorKind::IoError => {