Skip to content

Commit

Permalink
CassandraSinkCluster keyspace based routing - handle use statements (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 22, 2022
1 parent 486ca44 commit 2985349
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 35 deletions.
95 changes: 90 additions & 5 deletions shotover-proxy/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
use crate::frame::cassandra::CassandraOperation;
use crate::frame::cassandra::{CassandraMetadata, CassandraOperation};
use crate::frame::{CassandraFrame, Frame, MessageType};
use crate::message::{Encodable, Message, Messages};
use crate::message::{Encodable, Message, Messages, Metadata};
use crate::server::CodecReadError;
use anyhow::{anyhow, Result};
use bytes::{Buf, BufMut, BytesMut};
use cassandra_protocol::compression::Compression;
use cassandra_protocol::frame::message_error::{AdditionalErrorInfo, ErrorBody};
use cassandra_protocol::frame::{CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Version};
use cassandra_protocol::frame::{
CheckEnvelopeSizeError, Envelope as RawCassandraFrame, Opcode, Version,
};
use cql3_parser::cassandra_statement::CassandraStatement;
use cql3_parser::common::Identifier;
use tokio_util::codec::{Decoder, Encoder};
use tracing::info;

#[derive(Debug, Clone)]
pub struct CassandraCodec {
compressor: Compression,
messages: Vec<Message>,
current_use_keyspace: Option<Identifier>,
}

impl Default for CassandraCodec {
Expand All @@ -27,6 +32,7 @@ impl CassandraCodec {
CassandraCodec {
compressor: Compression::None,
messages: vec![],
current_use_keyspace: None,
}
}
}
Expand Down Expand Up @@ -65,8 +71,23 @@ impl Decoder for CassandraCodec {
return Err(reject_protocol_version(version.into()));
}

self.messages
.push(Message::from_bytes(bytes.freeze(), MessageType::Cassandra));
let mut message = Message::from_bytes(bytes.freeze(), MessageType::Cassandra);

if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(&mut message) {
self.current_use_keyspace = Some(keyspace);
}

if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(&mut message, keyspace);
}
}

self.messages.push(message);
}
Err(CheckEnvelopeSizeError::NotEnoughBytes) => {
if self.messages.is_empty() || src.remaining() != 0 {
Expand All @@ -89,6 +110,70 @@ impl Decoder for CassandraCodec {
}
}

fn get_use_keyspace(message: &mut Message) -> Option<Identifier> {
if let Some(Frame::Cassandra(frame)) = message.frame() {
if let CassandraOperation::Query { query, .. } = &mut frame.operation {
if let CassandraStatement::Use(keyspace) = query.as_ref() {
return Some(keyspace.clone());
}
}
}
None
}

fn set_default_keyspace(message: &mut Message, keyspace: &Identifier) {
// TODO: rewrite Operation::Prepared in the same way
if let Some(Frame::Cassandra(frame)) = message.frame() {
for query in frame.operation.queries() {
let name = match query {
CassandraStatement::AlterMaterializedView(x) => &mut x.name,
CassandraStatement::AlterTable(x) => &mut x.name,
CassandraStatement::AlterType(x) => &mut x.name,
CassandraStatement::CreateAggregate(x) => &mut x.name,
CassandraStatement::CreateFunction(x) => &mut x.name,
CassandraStatement::CreateIndex(x) => &mut x.table,
CassandraStatement::CreateMaterializedView(x) => &mut x.name,
CassandraStatement::CreateTable(x) => &mut x.name,
CassandraStatement::CreateTrigger(x) => &mut x.name,
CassandraStatement::CreateType(x) => &mut x.name,
CassandraStatement::Delete(x) => &mut x.table_name,
CassandraStatement::DropAggregate(x) => &mut x.name,
CassandraStatement::DropFunction(x) => &mut x.name,
CassandraStatement::DropIndex(x) => &mut x.name,
CassandraStatement::DropMaterializedView(x) => &mut x.name,
CassandraStatement::DropTable(x) => &mut x.name,
CassandraStatement::DropTrigger(x) => &mut x.name,
CassandraStatement::DropType(x) => &mut x.name,
CassandraStatement::Insert(x) => &mut x.table_name,
CassandraStatement::Select(x) => &mut x.table_name,
CassandraStatement::Truncate(name) => name,
CassandraStatement::Update(x) => &mut x.table_name,
CassandraStatement::AlterKeyspace(_)
| CassandraStatement::AlterRole(_)
| CassandraStatement::AlterUser(_)
| CassandraStatement::ApplyBatch
| CassandraStatement::CreateKeyspace(_)
| CassandraStatement::CreateRole(_)
| CassandraStatement::CreateUser(_)
| CassandraStatement::DropRole(_)
| CassandraStatement::DropUser(_)
| CassandraStatement::Grant(_)
| CassandraStatement::ListRoles(_)
| CassandraStatement::Revoke(_)
| CassandraStatement::DropKeyspace(_)
| CassandraStatement::ListPermissions(_)
| CassandraStatement::Use(_)
| CassandraStatement::Unknown(_) => {
return;
}
};
if name.keyspace.is_none() {
name.keyspace = Some(keyspace.clone());
}
}
}
}

/// If the client tried to use a protocol that we dont support then we need to reject it.
/// The rejection process is sending back an error and then closing the connection.
fn reject_protocol_version(version: u8) -> CodecReadError {
Expand Down
13 changes: 4 additions & 9 deletions shotover-proxy/src/frame/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub mod raw_frame {
version: frame.version,
stream_id: frame.stream_id,
tracing_id: frame.tracing_id,
opcode: frame.opcode,
})
}

Expand All @@ -85,20 +86,13 @@ pub mod raw_frame {
_ => nonzero!(1u32),
})
}

pub(crate) fn get_opcode(bytes: &[u8]) -> Result<Opcode> {
if bytes.len() < 9 {
bail!("Cassandra frame too short, needs at least 9 bytes for header");
}
let opcode = Opcode::try_from(bytes[4])?;
Ok(opcode)
}
}

pub(crate) struct CassandraMetadata {
pub struct CassandraMetadata {
pub version: Version,
pub stream_id: StreamId,
pub tracing_id: Option<Uuid>,
pub opcode: Opcode,
// missing `warnings` field because we are not using it currently
}

Expand All @@ -119,6 +113,7 @@ impl CassandraFrame {
version: self.version,
stream_id: self.stream_id,
tracing_id: self.tracing_id,
opcode: self.operation.to_opcode(),
}
}

Expand Down
14 changes: 2 additions & 12 deletions shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use std::net::IpAddr;
use std::num::NonZeroU32;
use uuid::Uuid;

enum Metadata {
pub enum Metadata {
Cassandra(CassandraMetadata),
Redis,
None,
Expand Down Expand Up @@ -171,16 +171,6 @@ impl Message {
}
}

/// Only use for messages read straight from the socket
/// that are definitely in an unparsed state
/// (haven't passed through any transforms where they might have been parsed or modified)
pub(crate) fn as_raw_bytes(&self) -> Option<&Bytes> {
match self.inner.as_ref().unwrap() {
MessageInner::RawBytes { bytes, .. } => Some(bytes),
_ => None,
}
}

/// Batch messages have a cell count of 1 cell per inner message.
/// Cell count is determined as follows:
/// * Regular message - 1 cell
Expand Down Expand Up @@ -270,7 +260,7 @@ impl Message {
}

/// Get metadata for this `Message`
fn metadata(&self) -> Result<Metadata> {
pub fn metadata(&self) -> Result<Metadata> {
match self.inner.as_ref().unwrap() {
MessageInner::RawBytes {
bytes,
Expand Down
16 changes: 9 additions & 7 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra;
use crate::message::Message;
use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Message, Metadata};
use crate::tls::TlsConnector;
use crate::transforms::util::Response;
use crate::transforms::Messages;
Expand Down Expand Up @@ -139,7 +139,7 @@ async fn rx_process_fallible<T: AsyncRead>(
match response {
Ok(response) => {
for m in response {
if let Ok(Opcode::Event) = cassandra::raw_frame::get_opcode(m.as_raw_bytes().unwrap()) {
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = m.metadata() {
if let Some(ref pushed_messages_tx) = pushed_messages_tx {
pushed_messages_tx.send(vec![m]).unwrap();
}
Expand Down Expand Up @@ -223,10 +223,12 @@ pub async fn receive_message(
response: Ok(message),
..
} => {
if let Some(raw_bytes) = message.as_raw_bytes() {
if let Ok(Opcode::Error) = cassandra::raw_frame::get_opcode(raw_bytes) {
failed_requests.increment(1);
}
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Error,
..
})) = message.metadata()
{
failed_requests.increment(1);
}
Ok(message)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,6 @@ impl CassandraSinkCluster {
}
}

// TODO: handle use statement state
fn is_system_query(&self, request: &mut Message) -> bool {
if let Some(Frame::Cassandra(frame)) = request.frame() {
if let CassandraOperation::Query { query, .. } = &mut frame.operation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tokio::time::{sleep, timeout};

use crate::cassandra_int_tests::cluster::run_topology_task;
use crate::helpers::cassandra::{
assert_query_result, CassandraConnection, CassandraDriver, ResultValue,
assert_query_result, run_query, CassandraConnection, CassandraDriver, ResultValue,
};
use crate::helpers::ShotoverManager;
use std::net::SocketAddr;
Expand All @@ -32,6 +32,10 @@ async fn test_rewrite_system_peers(connection: &CassandraConnection) {
async fn test_rewrite_system_peers_v2(connection: &CassandraConnection) {
let all_columns = "peer, peer_port, data_center, host_id, native_address, native_port, preferred_ip, preferred_port, rack, release_version, schema_version, tokens";
assert_query_result(connection, "SELECT * FROM system.peers_v2;", &[]).await;

run_query(connection, "USE system;").await;
assert_query_result(connection, "SELECT * FROM peers_v2;", &[]).await;

assert_query_result(
connection,
&format!("SELECT {all_columns} FROM system.peers_v2;"),
Expand Down

0 comments on commit 2985349

Please sign in to comment.