diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 8f776ef5d..3b3d2471c 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -870,10 +870,20 @@ where let addr_conn_option = match conn { Some((addr, Some(conn))) => Some((addr, conn.await)), - Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone()) - .await - .ok() - .map(|conn| (addr, conn)), + Some((addr, None)) => { + match connect_and_check::(&addr, core.cluster_params.clone()).await { + Ok(conn) => { + let conn_clone = conn.clone(); + core.conn_lock + .write() + .await + .0 + .insert(addr.clone(), async { conn_clone }.boxed().shared()); + Some((addr, conn)) + } + Err(_) => None, + } + } None => None, }; diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 7256238ae..530db6048 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -638,6 +638,48 @@ fn test_async_cluster_ask_redirect() { assert_eq!(value, Ok(Some(123))); } +#[test] +fn test_async_cluster_ask_save_new_connection() { + let name = "node"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + if port != 6391 { + respond_startup_two_nodes(name, cmd)?; + return Err(parse_redis_value(b"-ASK 14000 node:6391\r\n")); + } + + if contains_slice(cmd, b"PING") { + ping_attempts_clone.fetch_add(1, Ordering::Relaxed); + } + respond_startup_two_nodes(name, cmd)?; + Err(Ok(Value::Okay)) + } + }, + ); + + for _ in 0..4 { + runtime + .block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ) + .unwrap(); + } + + assert_eq!(ping_attempts.load(Ordering::Relaxed), 1); +} + #[test] fn test_async_cluster_ask_redirect_even_if_original_call_had_no_route() { let name = "node"; @@ -1456,6 +1498,81 @@ fn test_async_cluster_non_retryable_error_should_not_retry() { assert_eq!(completed.load(Ordering::SeqCst), 1); } +#[test] +fn test_async_cluster_saves_reconnected_connection() { + let name = "test_async_cluster_saves_reconnected_connection"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let get_attempts = AtomicI32::new(0); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |cmd: &[u8], port| { + if port == 6380 { + respond_startup_two_nodes(name, cmd)?; + return Err(parse_redis_value( + format!("-MOVED 123 {name}:6379\r\n").as_bytes(), + )); + } + + if contains_slice(cmd, b"PING") { + let connect_attempt = ping_attempts_clone.fetch_add(1, Ordering::Relaxed); + let past_get_attempts = get_attempts.load(Ordering::Relaxed); + // We want connection checks to fail after the first GET attempt, until it retries. Hence, we wait for 5 PINGs - + // 1. initial connection, + // 2. refresh slots on client creation, + // 3. refresh_connections `check_connection` after first GET failed, + // 4. refresh_connections `connect_and_check` after first GET failed, + // 5. reconnect on 2nd GET attempt. + // more than 5 attempts mean that the server reconnects more than once, which is the behavior we're testing against. + if past_get_attempts != 1 || connect_attempt > 3 { + respond_startup_two_nodes(name, cmd)?; + } + if connect_attempt > 5 { + panic!("Too many pings!"); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "mock-io-error", + )))) + } else { + respond_startup_two_nodes(name, cmd)?; + let past_get_attempts = get_attempts.fetch_add(1, Ordering::Relaxed); + // we fail the initial GET request, and after that we'll fail the first reconnect attempt, in the `refresh_connections` attempt. + if past_get_attempts == 0 { + // Error once with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "mock-io-error", + )))) + } else { + Err(Ok(Value::Data(b"123".to_vec()))) + } + } + }, + ); + + for _ in 0..4 { + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + // If you need to change the number here due to a change in the cluster, you probably also need to adjust the test. + // See the PING counts above to explain why 5 is the target number. + assert_eq!(ping_attempts.load(Ordering::Acquire), 5); +} + #[cfg(feature = "tls-rustls")] mod mtls_test { use crate::support::mtls_test::create_cluster_client_from_cluster;