Skip to content

Commit

Permalink
Cassandra connection: avoid cloning the entire message
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Nov 15, 2022
1 parent a469a28 commit 27672b6
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 50 deletions.
15 changes: 14 additions & 1 deletion shotover-proxy/src/frame/cassandra.rs
Expand Up @@ -8,7 +8,7 @@ use cassandra_protocol::frame::events::ServerEvent;
use cassandra_protocol::frame::message_batch::{
BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch,
};
use cassandra_protocol::frame::message_error::ErrorBody;
use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType};
use cassandra_protocol::frame::message_event::BodyResEvent;
use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
use cassandra_protocol::frame::message_query::BodyReqQuery;
Expand Down Expand Up @@ -161,6 +161,19 @@ impl CassandraFrame {
})
}

pub fn shotover_error(stream_id: i16, version: Version, message: &str) -> Self {
CassandraFrame {
version,
stream_id,
operation: CassandraOperation::Error(ErrorBody {
message: format!("Internal shotover error: {message}"),
ty: ErrorType::Server,
}),
tracing: Tracing::Response(None),
warnings: vec![],
}
}

pub fn from_bytes(bytes: Bytes) -> Result<Self> {
let frame = RawCassandraFrame::from_buffer(&bytes, Compression::None)
.map_err(|e| anyhow!("{e:?}"))?
Expand Down
71 changes: 44 additions & 27 deletions shotover-proxy/src/transforms/cassandra/connection.rs
@@ -1,13 +1,13 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra::CassandraMetadata;
use crate::frame::{CassandraFrame, Frame};
use crate::message::{Message, Metadata};
use crate::server::CodecReadError;
use crate::tcp;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::util::Response;
use crate::transforms::Messages;
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Opcode;
use cassandra_protocol::frame::{Opcode, Version};
use derivative::Derivative;
use futures::stream::FuturesOrdered;
use futures::{SinkExt, StreamExt};
Expand All @@ -25,7 +25,19 @@ use tracing::{error, Instrument};
/// A request queued for the tx task: the outgoing message plus the oneshot
/// channel its eventual `Response` must be delivered on.
struct Request {
    message: Message,
    return_chan: oneshot::Sender<Response>,
    // NOTE(review): both `message_id` and `stream_id` appear in this view and
    // hold the same cassandra stream id; the rest of the file only uses
    // `stream_id` — confirm `message_id` is not a stale leftover of a rename.
    message_id: i16,
    stream_id: i16,
}

/// The outcome of a single request sent over a `CassandraConnection`,
/// tagged with the cassandra stream id the request was sent on.
#[derive(Debug)]
pub struct Response {
    /// stream id of the request this response corresponds to
    pub stream_id: i16,
    /// the response message from the server, or the error that prevented
    /// one from being received
    pub response: Result<Message>,
}

/// Handed from the tx task to the rx task after a request is written, so the
/// rx task can route the response for `stream_id` back to the caller.
#[derive(Debug)]
struct ReturnChannel {
    return_chan: oneshot::Sender<Response>,
    stream_id: i16,
}

#[derive(Clone, Derivative)]
Expand All @@ -43,7 +55,7 @@ impl CassandraConnection {
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<Self> {
let (out_tx, out_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<ReturnChannel>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>();

if let Some(tls) = tls.as_mut() {
Expand Down Expand Up @@ -100,12 +112,12 @@ impl CassandraConnection {
/// Send a `Message` to this `CassandraConnection` and expect a response on `return_chan`
pub fn send(&self, message: Message, return_chan: oneshot::Sender<Response>) -> Result<()> {
// Convert the message to `Request` and send upstream
if let Some(message_id) = message.stream_id() {
if let Some(stream_id) = message.stream_id() {
self.connection
.send(Request {
message,
return_chan,
message_id,
stream_id,
})
.map_err(|x| x.into())
} else {
Expand All @@ -117,7 +129,7 @@ impl CassandraConnection {
async fn tx_process<T: AsyncWrite>(
write: WriteHalf<T>,
out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
return_tx: mpsc::UnboundedSender<ReturnChannel>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) {
Expand All @@ -131,15 +143,18 @@ async fn tx_process<T: AsyncWrite>(
async fn tx_process_fallible<T: AsyncWrite>(
write: WriteHalf<T>,
mut out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
return_tx: mpsc::UnboundedSender<ReturnChannel>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let mut in_w = FramedWrite::new(write, codec);
loop {
if let Some(request) = out_rx.recv().await {
in_w.send(vec![request.message.clone()]).await?;
return_tx.send(request)?;
in_w.send(vec![request.message]).await?;
return_tx.send(ReturnChannel {
return_chan: request.return_chan,
stream_id: request.stream_id,
})?;
} else {
// transform is shutting down, time to cleanly shutdown both tx_process and rx_process.
// We need to ensure that the rx_process task has shutdown before closing the write half of the tcpstream
Expand All @@ -163,7 +178,7 @@ async fn tx_process_fallible<T: AsyncWrite>(

async fn rx_process<T: AsyncRead>(
read: ReadHalf<T>,
return_rx: mpsc::UnboundedReceiver<Request>,
return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
codec: CassandraCodec,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
rx_process_has_shutdown_tx: oneshot::Sender<()>,
Expand All @@ -178,7 +193,7 @@ async fn rx_process<T: AsyncRead>(

async fn rx_process_fallible<T: AsyncRead>(
read: ReadHalf<T>,
mut return_rx: mpsc::UnboundedReceiver<Request>,
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
codec: CassandraCodec,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<()> {
Expand All @@ -198,7 +213,7 @@ async fn rx_process_fallible<T: AsyncRead>(
// In order to handle that we have two separate maps.
//
// We store the sender here if we receive from the tx_process task first
let mut from_tx_process: HashMap<i16, (oneshot::Sender<Response>, Message)> = HashMap::new();
let mut from_tx_process: HashMap<i16, oneshot::Sender<Response>> = HashMap::new();

// We store the response message here if we receive from the server first.
let mut from_server: HashMap<i16, Message> = HashMap::new();
Expand All @@ -209,7 +224,8 @@ async fn rx_process_fallible<T: AsyncRead>(
match response {
Some(Ok(response)) => {
for m in response {
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() {
let meta = m.metadata();
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![m]).unwrap();
}
Expand All @@ -218,8 +234,8 @@ async fn rx_process_fallible<T: AsyncRead>(
None => {
from_server.insert(stream_id, m);
},
Some((return_tx, original)) => {
return_tx.send(Response { original, response: Ok(m) })
Some(return_tx) => {
return_tx.send(Response { stream_id, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
}
}
Expand All @@ -237,13 +253,13 @@ async fn rx_process_fallible<T: AsyncRead>(
}
},
original_request = return_rx.recv() => {
if let Some(Request { message, return_chan, message_id }) = original_request {
match from_server.remove(&message_id) {
if let Some(ReturnChannel { return_chan, stream_id }) = original_request {
match from_server.remove(&stream_id) {
None => {
from_tx_process.insert(message_id, (return_chan, message));
from_tx_process.insert(stream_id, return_chan);
}
Some(m) => {
return_chan.send(Response { original: message, response: Ok(m) })
return_chan.send(Response { stream_id, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
}
}
Expand All @@ -259,14 +275,15 @@ pub async fn receive(
timeout_duration: Option<Duration>,
failed_requests: &metrics::Counter,
mut results: FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Messages> {
let expected_size = results.len();
let mut responses = Vec::with_capacity(expected_size);
while responses.len() < expected_size {
if let Some(timeout_duration) = timeout_duration {
match timeout(
timeout_duration,
receive_message(failed_requests, &mut results),
receive_message(failed_requests, &mut results, version),
)
.await
{
Expand All @@ -282,7 +299,7 @@ pub async fn receive(
}
}
} else {
responses.push(receive_message(failed_requests, &mut results).await?);
responses.push(receive_message(failed_requests, &mut results, version).await?);
}
}
Ok(responses)
Expand All @@ -291,6 +308,7 @@ pub async fn receive(
pub async fn receive_message(
failed_requests: &metrics::Counter,
results: &mut FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Message> {
match results.next().await {
Some(result) => match result? {
Expand All @@ -308,12 +326,11 @@ pub async fn receive_message(
Ok(message)
}
Response {
mut original,
stream_id,
response: Err(err),
} => {
original.set_error(err.to_string());
Ok(original)
}
} => Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &err.to_string()),
))),
},
None => unreachable!("Ran out of responses"),
}
Expand Down
46 changes: 32 additions & 14 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
@@ -1,10 +1,9 @@
use crate::error::ChainResponse;
use crate::frame::cassandra::{parse_statement_single, CassandraMetadata, Tracing};
use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{IntSize, Message, MessageValue, Messages};
use crate::message::{IntSize, Message, MessageValue, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::CassandraConnection;
use crate::transforms::util::Response;
use crate::transforms::cassandra::connection::{CassandraConnection, Response};
use crate::transforms::{Transform, Transforms, Wrapper};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -111,6 +110,7 @@ pub struct CassandraSinkCluster {
control_connection_address: Option<SocketAddr>,
init_handshake_complete: bool,

version: Version,
chain_name: String,
failed_requests: Counter,
read_timeout: Option<Duration>,
Expand Down Expand Up @@ -138,6 +138,7 @@ impl Clone for CassandraSinkCluster {
connection_factory: self.connection_factory.new_with_same_config(),
control_connection_address: None,
init_handshake_complete: false,
version: self.version,
chain_name: self.chain_name.clone(),
failed_requests: self.failed_requests.clone(),
read_timeout: self.read_timeout,
Expand Down Expand Up @@ -192,6 +193,8 @@ impl CassandraSinkCluster {
control_connection: None,
control_connection_address: None,
init_handshake_complete: false,
// Dummy value that gets replaced on the first message
version: Version::V4,
chain_name,
failed_requests,
read_timeout: receive_timeout,
Expand Down Expand Up @@ -229,6 +232,18 @@ fn create_query(messages: &Messages, query: &str, version: Version) -> Result<Me

impl CassandraSinkCluster {
async fn send_message(&mut self, mut messages: Messages) -> ChainResponse {
if let Some(message) = messages.first() {
if let Ok(Metadata::Cassandra(CassandraMetadata { version, .. })) = message.metadata() {
self.version = version;
} else {
return Err(anyhow!(
"Failed to extract cassandra version from incoming message: Not a valid cassandra message"
));
}
} else {
return Ok(vec![]);
}

if self.nodes_rx.has_changed()? {
self.pool.update_nodes(&mut self.nodes_rx);

Expand Down Expand Up @@ -265,13 +280,13 @@ impl CassandraSinkCluster {
let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers";
messages.insert(
table_to_rewrite.index + 1,
create_query(&messages, query, table_to_rewrite.version)?,
create_query(&messages, query, self.version)?,
);
if let RewriteTableTy::Peers = table_to_rewrite.ty {
let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.local";
messages.insert(
table_to_rewrite.index + 2,
create_query(&messages, query, table_to_rewrite.version)?,
create_query(&messages, query, self.version)?,
);
}
}
Expand Down Expand Up @@ -386,7 +401,7 @@ impl CassandraSinkCluster {
.get_replica_node_in_dc(
execute,
&self.local_shotover_node.rack,
&metadata.version,
self.version,
&mut self.rng,
)
.await
Expand All @@ -408,12 +423,12 @@ impl CassandraSinkCluster {
Err(GetReplicaErr::NoPreparedMetadata) => {
let id = execute.id.clone();
tracing::info!("forcing re-prepare on {:?}", id);
// this shotover node doesn't have the metadata
// this shotover node doesn't have the metadata.
// send an unprepared error in response to force
// the client to reprepare the query
return_chan_tx
.send(Response {
original: message.clone(),
stream_id: metadata.stream_id,
response: Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame {
operation: CassandraOperation::Error(ErrorBody {
Expand All @@ -424,7 +439,7 @@ impl CassandraSinkCluster {
}),
stream_id: metadata.stream_id,
tracing: Tracing::Response(None), // We didn't actually hit a node so we don't have a tracing id
version: metadata.version,
version: self.version,
warnings: vec![],
},
))),
Expand All @@ -449,15 +464,20 @@ impl CassandraSinkCluster {
responses_future.push_back(return_chan_rx)
}

let mut responses =
super::connection::receive(self.read_timeout, &self.failed_requests, responses_future)
.await?;
let mut responses = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future,
self.version,
)
.await?;

{
let mut prepare_responses = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future_prepare,
self.version,
)
.await?;

Expand Down Expand Up @@ -635,7 +655,6 @@ impl CassandraSinkCluster {
index,
ty,
warnings,
version: cassandra.version,
selects: select.columns.clone(),
});
}
Expand Down Expand Up @@ -997,7 +1016,6 @@ impl CassandraSinkCluster {
struct TableToRewrite {
index: usize,
ty: RewriteTableTy,
version: Version,
selects: Vec<SelectElement>,
warnings: Vec<String>,
}
Expand Down

0 comments on commit 27672b6

Please sign in to comment.