Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cassandra connection avoid cloning entire message #913

Merged
merged 4 commits into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 14 additions & 1 deletion shotover-proxy/src/frame/cassandra.rs
Expand Up @@ -8,7 +8,7 @@ use cassandra_protocol::frame::events::ServerEvent;
use cassandra_protocol::frame::message_batch::{
BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch,
};
use cassandra_protocol::frame::message_error::ErrorBody;
use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType};
use cassandra_protocol::frame::message_event::BodyResEvent;
use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
use cassandra_protocol::frame::message_query::BodyReqQuery;
Expand Down Expand Up @@ -161,6 +161,19 @@ impl CassandraFrame {
})
}

pub fn shotover_error(stream_id: i16, version: Version, message: &str) -> Self {
CassandraFrame {
version,
stream_id,
operation: CassandraOperation::Error(ErrorBody {
message: format!("Internal shotover error: {message}"),
ty: ErrorType::Server,
}),
tracing: Tracing::Response(None),
warnings: vec![],
}
}

pub fn from_bytes(bytes: Bytes) -> Result<Self> {
let frame = RawCassandraFrame::from_buffer(&bytes, Compression::None)
.map_err(|e| anyhow!("{e:?}"))?
Expand Down
77 changes: 47 additions & 30 deletions shotover-proxy/src/transforms/cassandra/connection.rs
@@ -1,13 +1,13 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra::CassandraMetadata;
use crate::frame::{CassandraFrame, Frame};
use crate::message::{Message, Metadata};
use crate::server::CodecReadError;
use crate::tcp;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::util::Response;
use crate::transforms::Messages;
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Opcode;
use cassandra_protocol::frame::{Opcode, Version};
use derivative::Derivative;
use futures::stream::FuturesOrdered;
use futures::{SinkExt, StreamExt};
Expand All @@ -25,7 +25,19 @@ use tracing::{error, Instrument};
// A request handed from the transform to the tx_process task: the outgoing
// message plus the oneshot channel its eventual response must be sent down.
// NOTE(review): this span is a flattened diff — `message_id: i16,` is the
// removed line and `stream_id: i16,` its replacement; only `stream_id`
// exists in the final code. Confirm against the merged file.
struct Request {
    message: Message,
    return_chan: oneshot::Sender<Response>,
    message_id: i16,
    stream_id: i16,
}

/// The outcome of a single in-flight request, delivered back to the
/// caller over its oneshot channel.
#[derive(Debug)]
pub struct Response {
    /// Cassandra stream id of the request this response answers; lets the
    /// receiver build an error frame for the right stream on failure.
    pub stream_id: i16,
    /// The server's reply, or the error that prevented one.
    pub response: Result<Message>,
}

/// Sent from tx_process to rx_process after a request has been written out:
/// carries only the oneshot sender and the stream id needed to route the
/// eventual server response (the original message is no longer kept,
/// avoiding a clone of the whole message).
#[derive(Debug)]
struct ReturnChannel {
    return_chan: oneshot::Sender<Response>,
    stream_id: i16,
}

#[derive(Clone, Derivative)]
Expand All @@ -43,7 +55,7 @@ impl CassandraConnection {
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<Self> {
let (out_tx, out_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<ReturnChannel>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>();

if let Some(tls) = tls.as_mut() {
Expand Down Expand Up @@ -100,12 +112,12 @@ impl CassandraConnection {
/// Send a `Message` to this `CassandraConnection` and expect a response on `return_chan`
pub fn send(&self, message: Message, return_chan: oneshot::Sender<Response>) -> Result<()> {
// Convert the message to `Request` and send upstream
if let Some(message_id) = message.stream_id() {
if let Some(stream_id) = message.stream_id() {
self.connection
.send(Request {
message,
return_chan,
message_id,
stream_id,
})
.map_err(|x| x.into())
} else {
Expand All @@ -117,7 +129,7 @@ impl CassandraConnection {
async fn tx_process<T: AsyncWrite>(
write: WriteHalf<T>,
out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
return_tx: mpsc::UnboundedSender<ReturnChannel>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) {
Expand All @@ -131,15 +143,18 @@ async fn tx_process<T: AsyncWrite>(
async fn tx_process_fallible<T: AsyncWrite>(
write: WriteHalf<T>,
mut out_rx: mpsc::UnboundedReceiver<Request>,
return_tx: mpsc::UnboundedSender<Request>,
return_tx: mpsc::UnboundedSender<ReturnChannel>,
codec: CassandraCodec,
rx_process_has_shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let mut in_w = FramedWrite::new(write, codec);
loop {
if let Some(request) = out_rx.recv().await {
in_w.send(vec![request.message.clone()]).await?;
return_tx.send(request)?;
in_w.send(vec![request.message]).await?;
return_tx.send(ReturnChannel {
return_chan: request.return_chan,
stream_id: request.stream_id,
})?;
} else {
// transform is shutting down, time to cleanly shutdown both tx_process and rx_process.
// We need to ensure that the rx_process task has shutdown before closing the write half of the tcpstream
Expand All @@ -163,7 +178,7 @@ async fn tx_process_fallible<T: AsyncWrite>(

async fn rx_process<T: AsyncRead>(
read: ReadHalf<T>,
return_rx: mpsc::UnboundedReceiver<Request>,
return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
codec: CassandraCodec,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
rx_process_has_shutdown_tx: oneshot::Sender<()>,
Expand All @@ -178,7 +193,7 @@ async fn rx_process<T: AsyncRead>(

async fn rx_process_fallible<T: AsyncRead>(
read: ReadHalf<T>,
mut return_rx: mpsc::UnboundedReceiver<Request>,
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
codec: CassandraCodec,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<()> {
Expand All @@ -192,13 +207,13 @@ async fn rx_process_fallible<T: AsyncRead>(
// Implementation:
// To process a message we need to receive things from two different sources:
// 1. the response from the cassandra server
// 2. the oneshot::Sender and original message from the tx_process task
// 2. the oneshot::Sender from the tx_process task
//
// We can receive these in any order.
// In order to handle that we have two separate maps.
//
// We store the sender + original message here if we receive from the tx_process task first
let mut from_tx_process: HashMap<i16, (oneshot::Sender<Response>, Message)> = HashMap::new();
// We store the sender here if we receive from the tx_process task first
let mut from_tx_process: HashMap<i16, oneshot::Sender<Response>> = HashMap::new();

// We store the response message here if we receive from the server first.
let mut from_server: HashMap<i16, Message> = HashMap::new();
Expand All @@ -209,7 +224,8 @@ async fn rx_process_fallible<T: AsyncRead>(
match response {
Some(Ok(response)) => {
for m in response {
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() {
let meta = m.metadata();
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![m]).unwrap();
}
Expand All @@ -218,8 +234,8 @@ async fn rx_process_fallible<T: AsyncRead>(
None => {
from_server.insert(stream_id, m);
},
Some((return_tx, original)) => {
return_tx.send(Response { original, response: Ok(m) })
Some(return_tx) => {
return_tx.send(Response { stream_id, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
}
}
Expand All @@ -236,14 +252,14 @@ async fn rx_process_fallible<T: AsyncRead>(
None => return Ok(())
}
},
original_request = return_rx.recv() => {
if let Some(Request { message, return_chan, message_id }) = original_request {
match from_server.remove(&message_id) {
return_chan = return_rx.recv() => {
if let Some(ReturnChannel { return_chan, stream_id }) = return_chan {
match from_server.remove(&stream_id) {
None => {
from_tx_process.insert(message_id, (return_chan, message));
from_tx_process.insert(stream_id, return_chan);
}
Some(m) => {
return_chan.send(Response { original: message, response: Ok(m) })
return_chan.send(Response { stream_id, response: Ok(m) })
.map_err(|_| anyhow!("couldn't send message"))?;
}
}
Expand All @@ -259,14 +275,15 @@ pub async fn receive(
timeout_duration: Option<Duration>,
failed_requests: &metrics::Counter,
mut results: FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Messages> {
let expected_size = results.len();
let mut responses = Vec::with_capacity(expected_size);
while responses.len() < expected_size {
if let Some(timeout_duration) = timeout_duration {
match timeout(
timeout_duration,
receive_message(failed_requests, &mut results),
receive_message(failed_requests, &mut results, version),
)
.await
{
Expand All @@ -282,7 +299,7 @@ pub async fn receive(
}
}
} else {
responses.push(receive_message(failed_requests, &mut results).await?);
responses.push(receive_message(failed_requests, &mut results, version).await?);
}
}
Ok(responses)
Expand All @@ -291,6 +308,7 @@ pub async fn receive(
pub async fn receive_message(
failed_requests: &metrics::Counter,
results: &mut FuturesOrdered<oneshot::Receiver<Response>>,
version: Version,
) -> Result<Message> {
match results.next().await {
Some(result) => match result? {
Expand All @@ -308,12 +326,11 @@ pub async fn receive_message(
Ok(message)
}
Response {
mut original,
stream_id,
response: Err(err),
} => {
original.set_error(err.to_string());
Ok(original)
}
} => Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &err.to_string()),
))),
},
None => unreachable!("Ran out of responses"),
}
Expand Down
46 changes: 32 additions & 14 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
@@ -1,10 +1,9 @@
use crate::error::ChainResponse;
use crate::frame::cassandra::{parse_statement_single, CassandraMetadata, Tracing};
use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{IntSize, Message, MessageValue, Messages};
use crate::message::{IntSize, Message, MessageValue, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::CassandraConnection;
use crate::transforms::util::Response;
use crate::transforms::cassandra::connection::{CassandraConnection, Response};
use crate::transforms::{Transform, Transforms, Wrapper};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -111,6 +110,7 @@ pub struct CassandraSinkCluster {
control_connection_address: Option<SocketAddr>,
init_handshake_complete: bool,

version: Version,
chain_name: String,
failed_requests: Counter,
read_timeout: Option<Duration>,
Expand Down Expand Up @@ -138,6 +138,7 @@ impl Clone for CassandraSinkCluster {
connection_factory: self.connection_factory.new_with_same_config(),
control_connection_address: None,
init_handshake_complete: false,
version: self.version,
chain_name: self.chain_name.clone(),
failed_requests: self.failed_requests.clone(),
read_timeout: self.read_timeout,
Expand Down Expand Up @@ -192,6 +193,8 @@ impl CassandraSinkCluster {
control_connection: None,
control_connection_address: None,
init_handshake_complete: false,
// Dummy value that gets replaced on the first message
version: Version::V4,
chain_name,
failed_requests,
read_timeout: receive_timeout,
Expand Down Expand Up @@ -229,6 +232,18 @@ fn create_query(messages: &Messages, query: &str, version: Version) -> Result<Me

impl CassandraSinkCluster {
async fn send_message(&mut self, mut messages: Messages) -> ChainResponse {
if let Some(message) = messages.first() {
conorbros marked this conversation as resolved.
Show resolved Hide resolved
if let Ok(Metadata::Cassandra(CassandraMetadata { version, .. })) = message.metadata() {
self.version = version;
} else {
return Err(anyhow!(
"Failed to extract cassandra version from incoming message: Not a valid cassandra message"
));
}
} else {
return Ok(vec![]);
conorbros marked this conversation as resolved.
Show resolved Hide resolved
}

if self.nodes_rx.has_changed()? {
self.pool.update_nodes(&mut self.nodes_rx);

Expand Down Expand Up @@ -265,13 +280,13 @@ impl CassandraSinkCluster {
let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.peers";
messages.insert(
table_to_rewrite.index + 1,
create_query(&messages, query, table_to_rewrite.version)?,
create_query(&messages, query, self.version)?,
);
if let RewriteTableTy::Peers = table_to_rewrite.ty {
let query = "SELECT rack, data_center, schema_version, tokens, release_version FROM system.local";
messages.insert(
table_to_rewrite.index + 2,
create_query(&messages, query, table_to_rewrite.version)?,
create_query(&messages, query, self.version)?,
);
}
}
Expand Down Expand Up @@ -386,7 +401,7 @@ impl CassandraSinkCluster {
.get_replica_node_in_dc(
execute,
&self.local_shotover_node.rack,
&metadata.version,
self.version,
&mut self.rng,
)
.await
Expand All @@ -408,12 +423,12 @@ impl CassandraSinkCluster {
Err(GetReplicaErr::NoPreparedMetadata) => {
let id = execute.id.clone();
tracing::info!("forcing re-prepare on {:?}", id);
// this shotover node doesn't have the metadata
// this shotover node doesn't have the metadata.
// send an unprepared error in response to force
// the client to reprepare the query
return_chan_tx
.send(Response {
original: message.clone(),
stream_id: metadata.stream_id,
response: Ok(Message::from_frame(Frame::Cassandra(
CassandraFrame {
operation: CassandraOperation::Error(ErrorBody {
Expand All @@ -424,7 +439,7 @@ impl CassandraSinkCluster {
}),
stream_id: metadata.stream_id,
tracing: Tracing::Response(None), // We didn't actually hit a node so we don't have a tracing id
version: metadata.version,
version: self.version,
warnings: vec![],
},
))),
Expand All @@ -449,15 +464,20 @@ impl CassandraSinkCluster {
responses_future.push_back(return_chan_rx)
}

let mut responses =
super::connection::receive(self.read_timeout, &self.failed_requests, responses_future)
.await?;
let mut responses = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future,
self.version,
)
.await?;

{
let mut prepare_responses = super::connection::receive(
self.read_timeout,
&self.failed_requests,
responses_future_prepare,
self.version,
)
.await?;

Expand Down Expand Up @@ -635,7 +655,6 @@ impl CassandraSinkCluster {
index,
ty,
warnings,
version: cassandra.version,
selects: select.columns.clone(),
});
}
Expand Down Expand Up @@ -997,7 +1016,6 @@ impl CassandraSinkCluster {
struct TableToRewrite {
index: usize,
ty: RewriteTableTy,
version: Version,
selects: Vec<SelectElement>,
warnings: Vec<String>,
}
Expand Down