Skip to content

Commit

Permalink
212/fix-infinite-hangs
Browse files Browse the repository at this point in the history
  • Loading branch information
alk888 committed Feb 14, 2024
1 parent 7db8d45 commit 90e7d51
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 32 deletions.
80 changes: 69 additions & 11 deletions src/client/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use futures::Stream;
use rabbitmq_stream_protocol::Response;
use std::sync::{atomic::AtomicU32, Arc};
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
};
use tracing::trace;

use dashmap::DashMap;
Expand All @@ -17,7 +20,7 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
pub(crate) struct Dispatcher<T>(DispatcherState<T>);

pub(crate) struct DispatcherState<T> {
requests: Arc<DashMap<u32, Sender<Response>>>,
requests: Arc<RequestsMap>,
correlation_id: Arc<AtomicU32>,
handler: Arc<RwLock<Option<T>>>,
}
Expand All @@ -32,13 +35,49 @@ impl<T> Clone for DispatcherState<T> {
}
}

struct RequestsMap {
requests: DashMap<u32, Sender<Response>>,
closed: AtomicBool,
}

impl RequestsMap {
fn new() -> RequestsMap {
RequestsMap {
requests: DashMap::new(),
closed: AtomicBool::new(false),
}
}

fn insert(&self, correlation_id: u32, sender: Sender<Response>) -> bool {
if self.closed.load(Ordering::Relaxed) {
return false;
}
self.requests.insert(correlation_id, sender);
true
}

fn remove(&self, correlation_id: u32) -> Option<Sender<Response>> {
self.requests.remove(&correlation_id).map(|r| r.1)
}

fn close(&self) {
self.closed.store(true, Ordering::Relaxed);
self.requests.clear();
}

#[cfg(test)]
fn len(&self) -> usize {
self.requests.len()
}
}

impl<T> Dispatcher<T>
where
T: MessageHandler,
{
pub fn new() -> Dispatcher<T> {
Dispatcher(DispatcherState {
requests: Arc::new(DashMap::new()),
requests: Arc::new(RequestsMap::new()),
correlation_id: Arc::new(AtomicU32::new(0)),
handler: Arc::new(RwLock::new(None)),
})
Expand All @@ -47,23 +86,25 @@ where
#[cfg(test)]
pub fn with_handler(handler: T) -> Dispatcher<T> {
Dispatcher(DispatcherState {
requests: Arc::new(DashMap::new()),
requests: Arc::new(RequestsMap::new()),
correlation_id: Arc::new(AtomicU32::new(0)),
handler: Arc::new(RwLock::new(Some(handler))),
})
}

pub async fn response_channel(&self) -> (u32, Receiver<Response>) {
pub fn response_channel(&self) -> Option<(u32, Receiver<Response>)> {
let (tx, rx) = channel(1);

let correlation_id = self
.0
.correlation_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

self.0.requests.insert(correlation_id, tx);

(correlation_id, rx)
if self.0.requests.insert(correlation_id, tx) {
Some((correlation_id, rx))
} else {
None
}
}

#[cfg(test)]
Expand All @@ -75,6 +116,7 @@ where
let mut guard = self.0.handler.write().await;
*guard = Some(handler);
}

pub async fn start<R>(&self, stream: ChannelReceiver<R>)
where
R: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
Expand All @@ -89,10 +131,10 @@ where
T: MessageHandler,
{
pub async fn dispatch(&self, correlation_id: u32, response: Response) {
let receiver = self.requests.remove(&correlation_id);
let receiver = self.requests.remove(correlation_id);

if let Some(rcv) = receiver {
let _ = rcv.1.send(response).await;
let _ = rcv.send(response).await;
}
}

Expand All @@ -103,6 +145,7 @@ where
}

pub async fn close(self, error: Option<ClientError>) {
self.requests.close();
if let Some(handler) = self.handler.read().await.as_ref() {
if let Some(err) = error {
let _ = handler.handle_message(Some(Err(err))).await;
Expand Down Expand Up @@ -265,7 +308,7 @@ mod tests {

dispatcher.start(rx).await;

let (correlation_id, mut rx) = dispatcher.response_channel().await;
let (correlation_id, mut rx) = dispatcher.response_channel().unwrap();

let req: Request = PeerPropertiesCommand::new(correlation_id, HashMap::new()).into();

Expand Down Expand Up @@ -298,4 +341,19 @@ mod tests {

assert!(matches!(response, Some(..)));
}

#[tokio::test]
async fn should_reject_requests_after_closing() {
let mock_source = MockIO::push();

let dispatcher = Dispatcher::with_handler(|_| async { Ok(()) });

let maybe_channel = dispatcher.response_channel();
assert!(maybe_channel.is_some());

dispatcher.0.requests.close();

let maybe_channel = dispatcher.response_channel();
assert!(maybe_channel.is_none());
}
}
51 changes: 30 additions & 21 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::convert::TryFrom;

use std::ops::DerefMut;
use std::{
collections::HashMap,
io,
Expand Down Expand Up @@ -79,6 +80,7 @@ mod message;
mod metadata;
mod metrics;
mod options;
mod task;

#[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
#[pin_project(project = StreamProj)]
Expand Down Expand Up @@ -138,7 +140,7 @@ pub struct ClientState {
heartbeat: u32,
max_frame_size: u32,
last_heatbeat: Instant,
heartbeat_task: Option<JoinHandle<()>>,
heartbeat_task: Option<task::TaskHandle>,
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -249,9 +251,7 @@ impl Client {

let mut state = self.state.write().await;

if let Some(heartbeat_task) = state.heartbeat_task.take() {
heartbeat_task.abort();
}
state.heartbeat_task.take();

drop(state);
self.channel.close().await
Expand Down Expand Up @@ -476,6 +476,9 @@ impl Client {
})
.await?;

// Start heartbeat task after connection is established
self.start_hearbeat_task(self.state.write().await.deref_mut());

Ok(())
}

Expand Down Expand Up @@ -545,13 +548,15 @@ impl Client {
T: FromResponse,
M: FnOnce(u32) -> R,
{
let (correlation_id, mut receiver) = self.dispatcher.response_channel().await;
let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else {
return Err(ClientError::ConnectionClosed);
};

self.channel
.send(msg_factory(correlation_id).into())
.await?;

let response = receiver.recv().await.expect("It should contain a response");
let response = receiver.recv().await.ok_or(ClientError::ConnectionClosed)?;

self.handle_response::<T>(response).await
}
Expand Down Expand Up @@ -609,21 +614,8 @@ impl Client {
heart_beat
);

if let Some(task) = state.heartbeat_task.take() {
task.abort();
}

if heart_beat != 0 {
let heartbeat_interval = (heart_beat / 2).max(1);
let channel = self.channel.clone();
let heartbeat_task = tokio::spawn(async move {
loop {
trace!("Sending heartbeat");
let _ = channel.send(HeartBeatCommand::default().into()).await;
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
}
});
state.heartbeat_task = Some(heartbeat_task);
if state.heartbeat_task.take().is_some() {
self.start_hearbeat_task(&mut state);
}

drop(state);
Expand All @@ -636,6 +628,23 @@ impl Client {
self.tune_notifier.notify_one();
}

fn start_hearbeat_task(&self, state: &mut ClientState) {
if state.heartbeat == 0 {
return;
}
let heartbeat_interval = (state.heartbeat / 2).max(1);
let channel = self.channel.clone();
let heartbeat_task = tokio::spawn(async move {
loop {
trace!("Sending heartbeat");
let _ = channel.send(HeartBeatCommand::default().into()).await;
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
}
})
.into();
state.heartbeat_task = Some(heartbeat_task);
}

async fn handle_heart_beat_command(&self) {
trace!("Received heartbeat");
let mut state = self.state.write().await;
Expand Down
15 changes: 15 additions & 0 deletions src/client/task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub struct TaskHandle {
task: tokio::task::JoinHandle<()>,
}

impl From<tokio::task::JoinHandle<()>> for TaskHandle {
fn from(task: tokio::task::JoinHandle<()>) -> Self {
TaskHandle { task }
}
}

impl Drop for TaskHandle {
fn drop(&mut self) {
self.task.abort();
}
}
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub enum ClientError {
GenericError(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("Client already closed")]
AlreadyClosed,
#[error("Connection closed")]
ConnectionClosed,
#[error(transparent)]
Tls(#[from] tokio_rustls::rustls::Error),
#[error("Request error: {0:?}")]
Expand Down
10 changes: 10 additions & 0 deletions tests/integration/client_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use fake::{Fake, Faker};
use rabbitmq_stream_protocol::commands::close::CloseRequest;
use tokio::sync::mpsc::channel;

use rabbitmq_stream_client::error::ClientError;
Expand Down Expand Up @@ -368,3 +369,12 @@ async fn client_publish() {
delivery.messages.get(0).unwrap().data()
);
}

#[cfg(test)]
#[tokio::test(flavor = "multi_thread")]
async fn client_handle_unexpected_connection_interruption() {
let mut options = ClientOptions::default();
options.set_port(5672);
let res = Client::connect(options).await;
assert!(matches!(res, Err(ClientError::ConnectionClosed)));
}

0 comments on commit 90e7d51

Please sign in to comment.