move topology to mod (#782)
conorbros committed Sep 2, 2022
1 parent 895850f commit 02cf509
Showing 3 changed files with 173 additions and 162 deletions.
162 changes: 2 additions & 160 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs
@@ -24,10 +24,12 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot, RwLock};
use topology::{create_topology_task, TaskConnectionInfo};
use uuid::Uuid;
use version_compare::Cmp;

pub mod node;
pub mod topology;

#[derive(Deserialize, Debug, Clone)]
pub struct CassandraSinkClusterConfig {
@@ -720,160 +722,6 @@ enum RewriteTableTy {
Peers,
}

pub fn create_topology_task(
nodes: Arc<RwLock<Vec<CassandraNode>>>,
mut handshake_rx: mpsc::Receiver<TaskConnectionInfo>,
data_center: String,
) {
tokio::spawn(async move {
while let Some(handshake) = handshake_rx.recv().await {
let mut attempts = 0;
while let Err(err) = topology_task_process(&nodes, &handshake, &data_center).await {
tracing::error!("topology task failed, retrying, error was: {err:?}");
attempts += 1;
if attempts > 3 {
// 3 attempts have failed, lets try a new handshake
break;
}
}

// Sleep for an hour.
// TODO: This is a crude way to ensure we dont overload the transforms with too many topology changes.
// This will be replaced with:
// * the task subscribes to events
// * the transforms request a reload when they hit connection errors
tokio::time::sleep(std::time::Duration::from_secs(60 * 60)).await;
}
});
}

async fn topology_task_process(
nodes: &Arc<RwLock<Vec<CassandraNode>>>,
handshake: &TaskConnectionInfo,
data_center: &str,
) -> Result<()> {
let outbound = handshake
.connection_factory
.new_connection(handshake.address)
.await?;

let (peers_tx, peers_rx) = oneshot::channel();
outbound.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
stream_id: 0,
tracing_id: None,
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
"SELECT peer, rack, data_center, tokens FROM system.peers",
)),
params: Box::new(QueryParams::default()),
},
})),
peers_tx,
)?;

let (local_tx, local_rx) = oneshot::channel();
outbound.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
stream_id: 1,
tracing_id: None,
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
"SELECT broadcast_address, rack, data_center, tokens FROM system.local",
)),
params: Box::new(QueryParams::default()),
},
})),
local_tx,
)?;

let (new_nodes, more_nodes) = tokio::join!(
async { system_peers_into_nodes(peers_rx.await?.response?, data_center) },
async { system_peers_into_nodes(local_rx.await?.response?, data_center) }
);
let mut new_nodes = new_nodes?;
new_nodes.extend(more_nodes?);

let mut write_lock = nodes.write().await;
let expensive_drop = std::mem::replace(&mut *write_lock, new_nodes);

// Make sure to drop write_lock before the expensive_drop which will have to perform many deallocations.
std::mem::drop(write_lock);
std::mem::drop(expensive_drop);

Ok(())
}

fn system_peers_into_nodes(
mut response: Message,
config_data_center: &str,
) -> Result<Vec<CassandraNode>> {
if let Some(Frame::Cassandra(frame)) = response.frame() {
match &mut frame.operation {
CassandraOperation::Result(CassandraResult::Rows {
value: MessageValue::Rows(rows),
..
}) => rows
.iter_mut()
.filter(|row| {
if let Some(MessageValue::Varchar(data_center)) = row.get(2) {
data_center == config_data_center
} else {
false
}
})
.map(|row| {
if row.len() != 4 {
return Err(anyhow!("expected 4 columns but was {}", row.len()));
}

let tokens = if let Some(MessageValue::List(list)) = row.pop() {
list.into_iter()
.map::<Result<String>, _>(|x| match x {
MessageValue::Varchar(a) => Ok(a),
_ => Err(anyhow!("tokens value not a varchar")),
})
.collect::<Result<Vec<String>>>()?
} else {
return Err(anyhow!("tokens not a list"));
};
let _data_center = row.pop();
let rack = if let Some(MessageValue::Varchar(value)) = row.pop() {
value
} else {
return Err(anyhow!("rack not a varchar"));
};
let address = if let Some(MessageValue::Inet(value)) = row.pop() {
value
} else {
return Err(anyhow!("address not an inet"));
};

Ok(CassandraNode {
address,
rack,
_tokens: tokens,
outbound: None,
})
})
.collect(),
operation => Err(anyhow!(
"system.peers returned unexpected cassandra operation: {:?}",
operation
)),
}
} else {
Err(anyhow!(
"Failed to parse system.peers response {:?}",
response
))
}
}

fn is_use_statement(request: &mut Message) -> bool {
if let Some(Frame::Cassandra(frame)) = request.frame() {
if let CassandraOperation::Query { query, .. } = &mut frame.operation {
@@ -1031,9 +879,3 @@ impl Transform for CassandraSinkCluster {
.set_pushed_messages_tx(pushed_messages_tx);
}
}

#[derive(Debug)]
pub struct TaskConnectionInfo {
pub connection_factory: ConnectionFactory,
pub address: SocketAddr,
}
170 changes: 170 additions & 0 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/topology.rs
@@ -0,0 +1,170 @@
use super::node::{CassandraNode, ConnectionFactory};
use crate::frame::cassandra::parse_statement_single;
use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{Message, MessageValue};
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Version;
use cassandra_protocol::query::QueryParams;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, RwLock};

#[derive(Debug)]
pub struct TaskConnectionInfo {
pub connection_factory: ConnectionFactory,
pub address: SocketAddr,
}

pub fn create_topology_task(
nodes: Arc<RwLock<Vec<CassandraNode>>>,
mut handshake_rx: mpsc::Receiver<TaskConnectionInfo>,
data_center: String,
) {
tokio::spawn(async move {
while let Some(handshake) = handshake_rx.recv().await {
let mut attempts = 0;
while let Err(err) = topology_task_process(&nodes, &handshake, &data_center).await {
tracing::error!("topology task failed, retrying, error was: {err:?}");
attempts += 1;
if attempts > 3 {
// 3 attempts have failed, lets try a new handshake
break;
}
}

// Sleep for an hour.
// TODO: This is a crude way to ensure we dont overload the transforms with too many topology changes.
// This will be replaced with:
// * the task subscribes to events
// * the transforms request a reload when they hit connection errors
tokio::time::sleep(std::time::Duration::from_secs(60 * 60)).await;
}
});
}

async fn topology_task_process(
nodes: &Arc<RwLock<Vec<CassandraNode>>>,
handshake: &TaskConnectionInfo,
data_center: &str,
) -> Result<()> {
let outbound = handshake
.connection_factory
.new_connection(handshake.address)
.await?;

let (peers_tx, peers_rx) = oneshot::channel();
outbound.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
stream_id: 0,
tracing_id: None,
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
"SELECT peer, rack, data_center, tokens FROM system.peers",
)),
params: Box::new(QueryParams::default()),
},
})),
peers_tx,
)?;

let (local_tx, local_rx) = oneshot::channel();
outbound.send(
Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
stream_id: 1,
tracing_id: None,
warnings: vec![],
operation: CassandraOperation::Query {
query: Box::new(parse_statement_single(
"SELECT broadcast_address, rack, data_center, tokens FROM system.local",
)),
params: Box::new(QueryParams::default()),
},
})),
local_tx,
)?;

let (new_nodes, more_nodes) = tokio::join!(
async { system_peers_into_nodes(peers_rx.await?.response?, data_center) },
async { system_peers_into_nodes(local_rx.await?.response?, data_center) }
);
let mut new_nodes = new_nodes?;
new_nodes.extend(more_nodes?);

let mut write_lock = nodes.write().await;
let expensive_drop = std::mem::replace(&mut *write_lock, new_nodes);

// Make sure to drop write_lock before the expensive_drop which will have to perform many deallocations.
std::mem::drop(write_lock);
std::mem::drop(expensive_drop);

Ok(())
}

fn system_peers_into_nodes(
mut response: Message,
config_data_center: &str,
) -> Result<Vec<CassandraNode>> {
if let Some(Frame::Cassandra(frame)) = response.frame() {
match &mut frame.operation {
CassandraOperation::Result(CassandraResult::Rows {
value: MessageValue::Rows(rows),
..
}) => rows
.iter_mut()
.filter(|row| {
if let Some(MessageValue::Varchar(data_center)) = row.get(2) {
data_center == config_data_center
} else {
false
}
})
.map(|row| {
if row.len() != 4 {
return Err(anyhow!("expected 4 columns but was {}", row.len()));
}

let tokens = if let Some(MessageValue::List(list)) = row.pop() {
list.into_iter()
.map::<Result<String>, _>(|x| match x {
MessageValue::Varchar(a) => Ok(a),
_ => Err(anyhow!("tokens value not a varchar")),
})
.collect::<Result<Vec<String>>>()?
} else {
return Err(anyhow!("tokens not a list"));
};
let _data_center = row.pop();
let rack = if let Some(MessageValue::Varchar(value)) = row.pop() {
value
} else {
return Err(anyhow!("rack not a varchar"));
};
let address = if let Some(MessageValue::Inet(value)) = row.pop() {
value
} else {
return Err(anyhow!("address not an inet"));
};

Ok(CassandraNode {
address,
rack,
_tokens: tokens,
outbound: None,
})
})
.collect(),
operation => Err(anyhow!(
"system.peers returned unexpected cassandra operation: {:?}",
operation
)),
}
} else {
Err(anyhow!(
"Failed to parse system.peers response {:?}",
response
))
}
}
3 changes: 1 addition & 2 deletions shotover-proxy/tests/cassandra_int_tests/cluster.rs
@@ -3,9 +3,8 @@ use shotover_proxy::frame::{CassandraFrame, CassandraOperation, Frame};
use shotover_proxy::message::Message;
use shotover_proxy::tls::{TlsConnector, TlsConnectorConfig};
use shotover_proxy::transforms::cassandra::sink_cluster::{
create_topology_task,
node::{CassandraNode, ConnectionFactory},
TaskConnectionInfo,
topology::{create_topology_task, TaskConnectionInfo},
};
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
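
For context, the move keeps the public signatures unchanged, so downstream code only needs to switch its import path to the new topology submodule, as the test change above shows. Below is a minimal, hedged sketch of how a caller might wire the task up under the new paths. The helper name spawn_topology, the channel capacity, and the "datacenter1" string are illustrative; the ConnectionFactory and contact address are assumed to come from the transform's existing setup (their construction is not part of this commit), and the function must be called from within a tokio runtime.

use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};

use shotover_proxy::transforms::cassandra::sink_cluster::{
    node::{CassandraNode, ConnectionFactory},
    topology::{create_topology_task, TaskConnectionInfo},
};

// Illustrative helper: the caller supplies a ConnectionFactory and contact address
// obtained from its own configuration.
fn spawn_topology(
    connection_factory: ConnectionFactory,
    address: SocketAddr,
) -> Arc<RwLock<Vec<CassandraNode>>> {
    // Shared node list that the background task keeps up to date.
    let nodes = Arc::new(RwLock::new(Vec::<CassandraNode>::new()));

    // Channel over which the transform hands connection info to the task.
    let (handshake_tx, handshake_rx) = mpsc::channel(1);

    // Spawn the background task; it refreshes `nodes` for the named data center
    // whenever a TaskConnectionInfo arrives on the channel.
    create_topology_task(nodes.clone(), handshake_rx, "datacenter1".to_string());

    // After a successful handshake the transform sends its connection info.
    let _ = handshake_tx.try_send(TaskConnectionInfo {
        connection_factory,
        address,
    });

    // The caller holds onto the shared node list for routing decisions.
    nodes
}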
