Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Order in usage of ClusterParams. #997

Merged
merged 2 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 23 additions & 30 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ use std::time::Duration;

use rand::{seq::IteratorRandom, thread_rng, Rng};

use crate::cluster_client::RetryParams;
use crate::cluster_pipeline::UNROUTABLE_ERROR;
use crate::cluster_routing::{
MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, SlotAddr,
Expand Down Expand Up @@ -211,14 +210,9 @@ pub struct ClusterConnection<C = Connection> {
connections: RefCell<HashMap<String, C>>,
slots: RefCell<SlotMap>,
auto_reconnect: RefCell<bool>,
read_from_replicas: bool,
username: Option<String>,
password: Option<String>,
read_timeout: RefCell<Option<Duration>>,
write_timeout: RefCell<Option<Duration>>,
tls: Option<TlsMode>,
tls_params: Option<TlsConnParams>,
retry_params: RetryParams,
cluster_params: ClusterParams,
}

impl<C> ClusterConnection<C>
Expand All @@ -228,21 +222,15 @@ where
pub(crate) fn new(
cluster_params: ClusterParams,
initial_nodes: Vec<ConnectionInfo>,
tls_params: Option<TlsConnParams>,
) -> RedisResult<Self> {
let connection = Self {
connections: RefCell::new(HashMap::new()),
slots: RefCell::new(SlotMap::new()),
auto_reconnect: RefCell::new(true),
read_from_replicas: cluster_params.read_from_replicas,
username: cluster_params.username,
password: cluster_params.password,
read_timeout: RefCell::new(None),
write_timeout: RefCell::new(None),
tls: cluster_params.tls,
tls_params,
initial_nodes: initial_nodes.to_vec(),
retry_params: cluster_params.retry_params,
cluster_params,
};
connection.create_initial_connections()?;

Expand Down Expand Up @@ -387,7 +375,7 @@ where

for conn in samples.iter_mut() {
let value = conn.req_command(&slot_cmd())?;
if let Ok(mut slots_data) = parse_slots(value, self.tls) {
if let Ok(mut slots_data) = parse_slots(value, self.cluster_params.tls) {
slots_data.sort_by_key(|s| s.start());
let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
if prev_end != slot_data.start() {
Expand All @@ -413,7 +401,10 @@ where
)));
}

new_slots = Some(SlotMap::from_slots(&slots_data, self.read_from_replicas));
new_slots = Some(SlotMap::from_slots(
&slots_data,
self.cluster_params.read_from_replicas,
));
break;
}
}
Expand All @@ -429,16 +420,11 @@ where
}

fn connect(&self, node: &str) -> RedisResult<C> {
let params = ClusterParams {
password: self.password.clone(),
username: self.username.clone(),
tls: self.tls,
..Default::default()
};
let info = get_connection_info(node, params, self.tls_params.clone())?;
let params = self.cluster_params.clone();
let info = get_connection_info(node, params)?;

let mut conn = C::connect(info, None)?;
if self.read_from_replicas {
if self.cluster_params.read_from_replicas {
// If READONLY is sent to primary nodes, it will have no effect
cmd("READONLY").query(&mut conn)?;
}
Expand Down Expand Up @@ -698,7 +684,7 @@ where
match rv {
Ok(rv) => return Ok(rv),
Err(err) => {
if retries == self.retry_params.number_of_retries {
if retries == self.cluster_params.retry_params.number_of_retries {
return Err(err);
}
retries += 1;
Expand All @@ -719,7 +705,10 @@ where
}
ErrorKind::TryAgain | ErrorKind::ClusterDown => {
// Sleep and retry.
let sleep_time = self.retry_params.wait_time_for_retry(retries);
let sleep_time = self
.cluster_params
.retry_params
.wait_time_for_retry(retries);
thread::sleep(sleep_time);
}
ErrorKind::IoError => {
Expand Down Expand Up @@ -991,7 +980,6 @@ pub(crate) fn parse_slots(raw_slot_resp: Value, tls: Option<TlsMode>) -> RedisRe
pub(crate) fn get_connection_info(
node: &str,
cluster_params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<ConnectionInfo> {
let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string");

Expand All @@ -1005,7 +993,12 @@ pub(crate) fn get_connection_info(
.ok_or_else(invalid_error)?;

Ok(ConnectionInfo {
addr: get_connection_addr(host.to_string(), port, cluster_params.tls, tls_params),
addr: get_connection_addr(
host.to_string(),
port,
cluster_params.tls,
cluster_params.tls_params,
),
redis: RedisConnectionInfo {
password: cluster_params.password,
username: cluster_params.username,
Expand Down Expand Up @@ -1069,13 +1062,13 @@ mod tests {
];

for (input, expected) in cases {
let res = get_connection_info(input, ClusterParams::default(), None);
let res = get_connection_info(input, ClusterParams::default());
assert_eq!(res.unwrap().addr, expected);
}

let cases = vec![":0", "[]:6379"];
for input in cases {
let res = get_connection_info(input, ClusterParams::default(), None);
let res = get_connection_info(input, ClusterParams::default());
assert_eq!(
res.err(),
Some(RedisError::from((
Expand Down
49 changes: 13 additions & 36 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ use crate::{
Value,
};

#[cfg(feature = "tls-rustls")]
use crate::tls::TlsConnParams;

#[cfg(not(feature = "tls-rustls"))]
use crate::connection::TlsConnParams;

Expand Down Expand Up @@ -77,9 +74,8 @@ where
pub(crate) async fn new(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<ClusterConnection<C>> {
ClusterConnInner::new(initial_nodes, cluster_params, tls_params)
ClusterConnInner::new(initial_nodes, cluster_params)
.await
.map(|inner| {
let (tx, mut rx) = mpsc::channel::<Message<_>>(100);
Expand Down Expand Up @@ -171,7 +167,6 @@ struct InnerCore<C> {
conn_lock: RwLock<(ConnectionMap<C>, SlotMap)>,
cluster_params: ClusterParams,
pending_requests: Mutex<Vec<PendingRequest<Response, C>>>,
tls_params: Option<TlsConnParams>,
}

type Core<C> = Arc<InnerCore<C>>;
Expand Down Expand Up @@ -457,15 +452,12 @@ where
async fn new(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<Self> {
let connections =
Self::create_initial_connections(initial_nodes, &cluster_params, &tls_params).await?;
let connections = Self::create_initial_connections(initial_nodes, &cluster_params).await?;
let inner = Arc::new(InnerCore {
conn_lock: RwLock::new((connections, Default::default())),
cluster_params,
pending_requests: Mutex::new(Vec::new()),
tls_params,
});
let mut connection = ClusterConnInner {
inner,
Expand All @@ -480,15 +472,13 @@ where
async fn create_initial_connections(
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
tls_params: &Option<TlsConnParams>,
) -> RedisResult<ConnectionMap<C>> {
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|info| {
let params = params.clone();
let tls_params = tls_params.clone();
async move {
let addr = info.addr.to_string();
let result = connect_and_check(&addr, params, tls_params).await;
let result = connect_and_check(&addr, params).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Err(e) => {
Expand Down Expand Up @@ -528,7 +518,6 @@ where
&addr,
connections.remove(&addr),
&inner.cluster_params,
inner.tls_params.clone(),
)
.await;
if let Ok(conn) = conn {
Expand Down Expand Up @@ -584,13 +573,8 @@ where
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn = Self::get_or_create_conn(
addr,
connection,
&inner.cluster_params,
inner.tls_params.clone(),
)
.await;
let conn =
Self::get_or_create_conn(addr, connection, &inner.cluster_params).await;
if let Ok(conn) = conn {
connections.insert(addr.to_string(), async { conn }.boxed().shared());
}
Expand Down Expand Up @@ -888,12 +872,10 @@ where

let addr_conn_option = match conn {
Some((addr, Some(conn))) => Some((addr, conn.await)),
Some((addr, None)) => {
connect_and_check(&addr, core.cluster_params.clone(), core.tls_params.clone())
.await
.ok()
.map(|conn| (addr, conn))
}
Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone())
.await
.ok()
.map(|conn| (addr, conn)),
None => None,
};

Expand Down Expand Up @@ -1047,16 +1029,15 @@ where
addr: &str,
conn_option: Option<ConnectionFuture<C>>,
params: &ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<C> {
if let Some(conn) = conn_option {
let mut conn = conn.await;
match check_connection(&mut conn).await {
Ok(_) => Ok(conn),
Err(_) => connect_and_check(addr, params.clone(), tls_params).await,
Err(_) => connect_and_check(addr, params.clone()).await,
}
} else {
connect_and_check(addr, params.clone(), tls_params).await
connect_and_check(addr, params.clone()).await
}
}
}
Expand Down Expand Up @@ -1258,16 +1239,12 @@ impl Connect for MultiplexedConnection {
}
}

async fn connect_and_check<C>(
node: &str,
params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<C>
async fn connect_and_check<C>(node: &str, params: ClusterParams) -> RedisResult<C>
where
C: ConnectionLike + Connect + Send + 'static,
{
let read_from_replicas = params.read_from_replicas;
let info = get_connection_info(node, params, tls_params)?;
let info = get_connection_info(node, params)?;
let mut conn = C::connect(info).await?;
check_connection(&mut conn).await?;
if read_from_replicas {
Expand Down