Skip to content

Commit

Permalink
Include destination in cassandra connection errors (#926)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Nov 21, 2022
1 parent e31cd45 commit 17f08ff
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Expand Up @@ -59,6 +59,8 @@ impl CassandraConnection {
let (return_tx, return_rx) = mpsc::unbounded_channel::<ReturnChannel>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<String>();

let destination = format!("{host:?}");

if let Some(tls) = tls.as_mut() {
let tls_stream = tls.connect(connect_timeout, host).await?;
let (read, write) = split(tls_stream);
Expand All @@ -69,6 +71,7 @@ impl CassandraConnection {
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
destination.clone(),
)
.in_current_span(),
);
Expand All @@ -79,6 +82,7 @@ impl CassandraConnection {
codec.clone(),
pushed_messages_tx,
rx_process_has_shutdown_tx,
destination,
)
.in_current_span(),
);
Expand All @@ -92,6 +96,7 @@ impl CassandraConnection {
return_tx,
codec.clone(),
rx_process_has_shutdown_rx,
destination.clone(),
)
.in_current_span(),
);
Expand All @@ -102,6 +107,7 @@ impl CassandraConnection {
codec.clone(),
pushed_messages_tx,
rx_process_has_shutdown_tx,
destination,
)
.in_current_span(),
);
Expand Down Expand Up @@ -149,6 +155,8 @@ async fn tx_process<T: AsyncWrite>(
return_tx: mpsc::UnboundedSender<ReturnChannel>,
codec: CassandraCodec,
mut rx_process_has_shutdown_rx: oneshot::Receiver<String>,
// Only used for error reporting
destination: String,
) {
let mut in_w = FramedWrite::new(write, codec);

Expand All @@ -159,10 +167,10 @@ async fn tx_process<T: AsyncWrite>(
loop {
if let Some(request) = out_rx.recv().await {
if let Some(error) = &connection_dead_error {
send_error_to_request(request.return_chan, request.stream_id, error);
send_error_to_request(request.return_chan, request.stream_id, &destination, error);
} else if let Err(error) = in_w.send(vec![request.message]).await {
let error = format!("{:?}", error);
send_error_to_request(request.return_chan, request.stream_id, &error);
send_error_to_request(request.return_chan, request.stream_id, &destination, &error);
connection_dead_error = Some(error.clone());
} else if let Err(mpsc::error::SendError(return_chan)) = return_tx.send(ReturnChannel {
return_chan: request.return_chan,
Expand All @@ -171,7 +179,12 @@ async fn tx_process<T: AsyncWrite>(
let error = rx_process_has_shutdown_rx
.try_recv()
.expect("Rx task must send this before closing return_tx");
send_error_to_request(return_chan.return_chan, return_chan.stream_id, &error);
send_error_to_request(
return_chan.return_chan,
return_chan.stream_id,
&destination,
&error,
);
connection_dead_error = Some(error.clone());
}
}
Expand All @@ -198,11 +211,18 @@ async fn tx_process<T: AsyncWrite>(
}
}

fn send_error_to_request(return_chan: oneshot::Sender<Response>, stream_id: i16, error: &str) {
fn send_error_to_request(
return_chan: oneshot::Sender<Response>,
stream_id: i16,
destination: &str,
error: &str,
) {
return_chan
.send(Response {
stream_id,
response: Err(anyhow!(error.to_owned())),
response: Err(anyhow!(
"Connection to destination cassandra node {destination} was closed: {error}"
)),
})
.ok();
}
Expand All @@ -213,6 +233,8 @@ async fn rx_process<T: AsyncRead>(
codec: CassandraCodec,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
// Only used for error reporting
destination: String,
) {
let mut reader = FramedRead::new(read, codec);

Expand Down Expand Up @@ -261,18 +283,18 @@ async fn rx_process<T: AsyncRead>(
Some(Err(CodecReadError::Io(err))) => {
// Manually handle Io errors so they can use the nicer Display formatting
let error_message = format!("IO error: {err}");
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, &error_message).await;
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, destination, &error_message).await;
return;
}
Some(Err(err)) => {
// Anyhow errors should be formatted with Debug
let error_message = format!("{err:?}");
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, &error_message).await;
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, destination, &error_message).await;
return;
}
None => {
// We know the connection wasnt closed by the tx task dropping its writer because the tx task must outlive the rx task
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, "The destination cassandra node closed the conection").await;
send_errors_and_shutdown(return_rx, from_tx_process, rx_process_has_shutdown_tx, destination, "The destination cassandra node closed the conection").await;
return;
}
}
Expand Down Expand Up @@ -303,21 +325,26 @@ async fn send_errors_and_shutdown(
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
mut waiting: HashMap<i16, oneshot::Sender<Response>>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
destination: String,
message: &str,
) {
// Ensure we send this before closing return_rx.
// This means that when the tx task finds return_rx is closed, it can rely on rx_process_has_shutdown_tx being already sent
rx_process_has_shutdown_tx
// Dont send the full message here because the tx task is responsible for that.
.send(message.to_owned())
.expect("Tx task must outlive rx task");

return_rx.close();

let full_message =
format!("Connection to destination cassandra node {destination} was closed: {message}");

for (stream_id, return_tx) in waiting.drain() {
return_tx
.send(Response {
stream_id,
response: Err(anyhow!(message.to_owned())),
response: Err(anyhow!(full_message.to_owned())),
})
.ok();
}
Expand All @@ -328,7 +355,7 @@ async fn send_errors_and_shutdown(
.return_chan
.send(Response {
stream_id: return_chan.stream_id,
response: Err(anyhow!(message.to_owned())),
response: Err(anyhow!(full_message.to_owned())),
})
.ok();
}
Expand Down Expand Up @@ -391,12 +418,9 @@ async fn receive_message(
Response {
stream_id,
response: Err(err),
} => {
tracing::error!("receive() error: {:?}", err);
Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &format!("{:?}", err)),
)))
}
} => Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &format!("{:?}", err)),
))),
},
None => unreachable!("Ran out of responses"),
}
Expand Down

0 comments on commit 17f08ff

Please sign in to comment.