Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Value support for Cassandra types #409

Merged
merged 7 commits into from Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion shotover-proxy/Cargo.toml
Expand Up @@ -54,7 +54,8 @@ halfbrown = "0.1.11"

# Transform dependencies
redis-protocol = "3.0.1"
cassandra-protocol = { git = "https://github.com/krojew/cdrs-tokio" }
#cassandra-protocol = { git = "https://github.com/krojew/cdrs-tokio" }
cassandra-protocol = { git = "https://github.com/conorbros/cdrs-tokio", branch = "typed-collections" }
rukai marked this conversation as resolved.
Show resolved Hide resolved
conorbros marked this conversation as resolved.
Show resolved Hide resolved
rdkafka = "0.28"
crc16 = "0.4.0"
ordered-float = { version = "2.8.0", features = ["serde"] }
Expand Down
275 changes: 206 additions & 69 deletions shotover-proxy/src/message/mod.rs
Expand Up @@ -4,15 +4,19 @@ use bytes::Bytes;
use cassandra_protocol::{
frame::{
frame_error::{AdditionalErrorInfo, ErrorBody},
frame_result::{ColSpec, ColType},
frame_result::{ColSpec, ColType, ColTypeOptionValue},
Direction, Flags, Frame as CassandraFrame, Opcode, Serialize as CassandraSerialize,
},
types::{
cassandra_type::CassandraType,
data_serialization_types::{
decode_ascii, decode_bigint, decode_boolean, decode_decimal, decode_double,
decode_float, decode_inet, decode_int, decode_smallint, decode_tinyint, decode_varchar,
decode_ascii, decode_bigint, decode_boolean, decode_date, decode_decimal,
decode_double, decode_float, decode_inet, decode_int, decode_list, decode_map,
decode_set, decode_smallint, decode_time, decode_timestamp, decode_tinyint,
decode_tuple, decode_udt, decode_varchar, decode_varint,
},
CBytes,
prelude::{List, Map, Tuple, Udt},
AsCassandraType, CBytes,
},
};
use num::BigInt;
Expand Down Expand Up @@ -385,7 +389,7 @@ impl From<Value> for Frame {
fn from(value: Value) -> Frame {
match value {
Value::NULL => Frame::Null,
Value::None => unimplemented!(),
Value::None => todo!(),
Value::Bytes(b) => Frame::BulkString(b.to_vec()),
Value::Strings(s) => Frame::SimpleString(s),
Value::Integer(i, _) => Frame::Integer(i),
Expand All @@ -394,24 +398,24 @@ impl From<Value> for Frame {
Value::Inet(i) => Frame::SimpleString(i.to_string()),
Value::List(l) => Frame::Array(l.into_iter().map(|v| v.into()).collect()),
Value::Rows(r) => Frame::Array(r.into_iter().map(|v| Value::List(v).into()).collect()),
Value::NamedRows(_) => unimplemented!(),
Value::Document(_) => unimplemented!(),
Value::NamedRows(_) => todo!(),
Value::Document(_) => todo!(),
Value::FragmentedResponse(l) => Frame::Array(l.into_iter().map(|v| v.into()).collect()),
Value::Ascii(_) => unimplemented!(),
Value::Double(_) => unimplemented!(),
Value::Set(_) => unimplemented!(),
Value::Map(_) => unimplemented!(),
Value::Varint(_) => unimplemented!(),
Value::Decimal(_) => unimplemented!(),
Value::Date(_) => unimplemented!(),
Value::Timestamp(_) => unimplemented!(),
Value::Timeuuid(_) => unimplemented!(),
Value::Varchar(_) => unimplemented!(),
Value::Uuid(_) => unimplemented!(),
Value::Time(_) => unimplemented!(),
Value::Counter(_) => unimplemented!(),
Value::Tuple(_) => unimplemented!(),
Value::Udt(_) => unimplemented!(),
Value::Ascii(_a) => todo!(),
Value::Double(_d) => todo!(),
Value::Set(_s) => todo!(),
Value::Map(_) => todo!(),
Value::Varint(_v) => todo!(),
Value::Decimal(_d) => todo!(),
Value::Date(_date) => todo!(),
Value::Timestamp(_timestamp) => todo!(),
Value::Timeuuid(_timeuuid) => todo!(),
Value::Varchar(_v) => todo!(),
Value::Uuid(_uuid) => todo!(),
Value::Time(_t) => todo!(),
Value::Counter(_c) => todo!(),
Value::Tuple(_) => todo!(),
Value::Udt(_) => todo!(),
}
}
}
Expand Down Expand Up @@ -447,27 +451,200 @@ impl Value {
}
ColType::Uuid => Value::Bytes(Bytes::copy_from_slice(actual_bytes)),
ColType::Varchar => Value::Strings(decode_varchar(actual_bytes).unwrap()),
ColType::Varint => unimplemented!("We dont have a varint type yet"),
ColType::Varint => Value::Varint(decode_varint(actual_bytes).unwrap()),
ColType::Timeuuid => Value::Bytes(Bytes::copy_from_slice(actual_bytes)),
ColType::Inet => Value::Inet(decode_inet(actual_bytes).unwrap()),
ColType::Timestamp => Value::NULL,
ColType::Date => Value::NULL,
ColType::Time => Value::NULL,
ColType::Date => Value::Date(decode_date(actual_bytes).unwrap()),
ColType::Timestamp => Value::Timestamp(decode_timestamp(actual_bytes).unwrap()),
ColType::Time => Value::Time(decode_time(actual_bytes).unwrap()),
ColType::Smallint => {
Value::Integer(decode_smallint(actual_bytes).unwrap() as i64, IntSize::I16)
}
ColType::Tinyint => {
Value::Integer(decode_tinyint(actual_bytes).unwrap() as i64, IntSize::I8)
}
// TODO: process collection types based on ColTypeOption
// (https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L569)
_ => Value::NULL,
ColType::List => {
let decoded_list = decode_list(actual_bytes).unwrap();
let list = List::new(spec.col_type.clone(), decoded_list)
.as_cassandra_type()
.unwrap()
.unwrap();

let typed_list = Value::create_list(list);
Value::List(typed_list)
}
ColType::Map => {
let decoded_map = decode_map(actual_bytes).unwrap();
let map = Map::new(decoded_map, spec.col_type.clone())
.as_cassandra_type()
.unwrap()
.unwrap();

#[allow(clippy::mutable_key_type)]
let typed_map = Value::create_map(map);

Value::Map(typed_map)
}
ColType::Set => {
let decoded_set = decode_set(actual_bytes).unwrap();
let set = List::new(spec.col_type.clone(), decoded_set)
.as_cassandra_type()
.unwrap()
.unwrap();

#[allow(clippy::mutable_key_type)]
let typed_set = Value::create_set(set);
Value::Set(typed_set)
}
ColType::Udt => {
if let Some(ColTypeOptionValue::UdtType(ref list_type_option)) =
spec.col_type.value
{
let len = list_type_option.descriptions.len();
let decoded_udt = decode_udt(actual_bytes, len).unwrap();

let udt = Udt::new(decoded_udt, list_type_option)
.as_cassandra_type()
.unwrap()
.unwrap();

let typed_udt = Value::create_udt(udt);

Value::Udt(typed_udt)
} else {
panic!("Not a UDT. This indicates a bug in cassandra_protocol")
}
}
ColType::Tuple => {
if let Some(ColTypeOptionValue::TupleType(ref list_type_option)) =
spec.col_type.value
{
let len = list_type_option.types.len();
let decoded_tuple = decode_tuple(actual_bytes, len).unwrap();
let tuple = Tuple::new(decoded_tuple, list_type_option)
.as_cassandra_type()
.unwrap()
.unwrap();

let typed_tuple = Value::create_tuple(tuple);
Value::Tuple(typed_tuple)
} else {
panic!("Not a Tuple. This indicates a bug in cassandra_protocol")
}
}
ColType::Custom => unimplemented!(),
ColType::Null => Value::NULL,
}
} else {
Value::NULL
}
}

fn create_udt(collection: CassandraType) -> BTreeMap<String, Value> {
if let CassandraType::Udt(udt) = collection {
let mut values = BTreeMap::new();
udt.into_iter().for_each(|(key, element)| {
values.insert(key, Value::create_element(element));
});

values
} else {
panic!("Not a UDT. Only CassandraType::Udt should be passed to this method.");
}
}

#[allow(clippy::mutable_key_type)]
fn create_map(collection: CassandraType) -> BTreeMap<Value, Value> {
if let CassandraType::Map(map) = collection {
let mut value_list = BTreeMap::new();
for (key, value) in map.into_iter() {
value_list.insert(Value::create_element(key), Value::create_element(value));
}

value_list
} else {
panic!("Not a Map. Only CassandraType::Map should be passed to this method");
}
}

fn create_element(element: CassandraType) -> Value {
rukai marked this conversation as resolved.
Show resolved Hide resolved
match element {
CassandraType::Ascii(a) => Value::Ascii(a),
CassandraType::Bigint(b) => Value::Integer(b, IntSize::I64),
CassandraType::Blob(b) => Value::Bytes(b.into_vec().into()),
CassandraType::Boolean(b) => Value::Boolean(b),
CassandraType::Counter(c) => Value::Counter(c),
CassandraType::Decimal(d) => {
let big_decimal = BigDecimal::new(d.unscaled, d.scale.into());
Value::Decimal(big_decimal)
}
CassandraType::Double(d) => Value::Double(d.into()),
CassandraType::Float(f) => Value::Float(f.into()),
CassandraType::Int(c) => Value::Integer(c as i64, IntSize::I64),
CassandraType::Timestamp(t) => Value::Timestamp(t),
CassandraType::Uuid(u) => Value::Uuid(u),
CassandraType::Varchar(v) => Value::Varchar(v),
CassandraType::Varint(v) => Value::Varint(v),
CassandraType::Timeuuid(t) => Value::Timeuuid(t),
CassandraType::Inet(i) => Value::Inet(i),
CassandraType::Date(d) => Value::Date(d),
CassandraType::Time(d) => Value::Time(d),
CassandraType::Smallint(d) => Value::Integer(d.into(), IntSize::I16),
CassandraType::Tinyint(d) => Value::Integer(d.into(), IntSize::I8),
CassandraType::List(_) => Value::List(Value::create_list(element)),
CassandraType::Map(_) => Value::Map(Value::create_map(element)),
CassandraType::Set(_) => Value::Set(Value::create_set(element)),
CassandraType::Udt(_) => Value::Udt(Value::create_udt(element)),
CassandraType::Tuple(_) => Value::Tuple(Value::create_tuple(element)),
CassandraType::Null => Value::NULL,
}
}

fn create_list(collection: CassandraType) -> Vec<Value> {
match collection {
CassandraType::List(collection) => {
let mut value_list = Vec::with_capacity(collection.len());
for element in collection.into_iter() {
rukai marked this conversation as resolved.
Show resolved Hide resolved
value_list.push(Value::create_element(element));
}

value_list
}
_ => panic!("Not a List. Only CassandraType::List should be passed to this method."),
}
}

#[allow(clippy::mutable_key_type)]
fn create_set(collection: CassandraType) -> BTreeSet<Value> {
rukai marked this conversation as resolved.
Show resolved Hide resolved
match collection {
CassandraType::List(collection) | CassandraType::Set(collection) => {
let mut value_list = BTreeSet::new();
for element in collection.into_iter() {
value_list.insert(Value::create_element(element));
}

value_list
}
_ => panic!(
"Not a List or Set. Only CassandraType::List or CassandraType::Set should be passed to this method."
),
}
}

fn create_tuple(collection: CassandraType) -> Vec<Value> {
match collection {
CassandraType::Tuple(collection) => {
let mut value_list = Vec::with_capacity(collection.len());
for element in collection.into_iter() {
value_list.push(Value::create_element(element));
}

value_list
}
_ => panic!("Not a Tuple. Only CassandraType::Tuple should be passed to this method."),
}
}

pub fn into_str_bytes(self) -> Bytes {
match self {
Value::NULL => Bytes::from("".to_string()),
Expand Down Expand Up @@ -500,46 +677,6 @@ impl Value {
Value::Udt(_) => unimplemented!(),
}
}

pub fn into_bytes(self) -> Bytes {
match self {
Value::NULL => Bytes::new(),
Value::None => Bytes::new(),
Value::Bytes(b) => b,
Value::Strings(s) => Bytes::from(s),
Value::Integer(i, _) => Bytes::from(Vec::from(i.to_le_bytes())),
Value::Float(f) => Bytes::from(Vec::from(f.to_le_bytes())),
Value::Boolean(b) => Bytes::from(Vec::from(if b {
(1_u8).to_le_bytes()
} else {
(0_u8).to_le_bytes()
})),
Value::Inet(i) => Bytes::from(match i {
IpAddr::V4(four) => Vec::from(four.octets()),
IpAddr::V6(six) => Vec::from(six.octets()),
}),
Value::FragmentedResponse(_) => unimplemented!(),
Value::Document(_) => unimplemented!(),
Value::NamedRows(_) => unimplemented!(),
Value::List(_) => unimplemented!(),
Value::Rows(_) => unimplemented!(),
Value::Ascii(_) => unimplemented!(),
Value::Double(_) => unimplemented!(),
Value::Set(_) => unimplemented!(),
Value::Map(_) => unimplemented!(),
Value::Varint(_) => unimplemented!(),
Value::Decimal(_) => unimplemented!(),
Value::Date(_) => unimplemented!(),
Value::Timestamp(_) => unimplemented!(),
Value::Timeuuid(_) => unimplemented!(),
Value::Varchar(_) => unimplemented!(),
Value::Uuid(_) => unimplemented!(),
Value::Time(_) => unimplemented!(),
Value::Counter(_) => unimplemented!(),
Value::Tuple(_) => unimplemented!(),
Value::Udt(_) => unimplemented!(),
}
}
}

impl From<Value> for cassandra_protocol::types::value::Bytes {
Expand Down
23 changes: 22 additions & 1 deletion shotover-proxy/src/protocols/cassandra_codec.rs
Expand Up @@ -178,7 +178,28 @@ impl CassandraCodec {
let _ = temp.write_i32::<BigEndian>(*x as i32).unwrap();
temp
}
_ => unreachable!(),
Value::None => unimplemented!(),
Value::Inet(_) => unimplemented!(),
Value::FragmentedResponse(_) => unimplemented!(),
Value::Document(_) => unimplemented!(),
Value::NamedRows(_) => unimplemented!(),
Value::List(_) => unimplemented!(),
Value::Rows(_) => unimplemented!(),
Value::Ascii(_) => unimplemented!(),
Value::Double(_) => unimplemented!(),
Value::Set(_) => unimplemented!(),
Value::Map(_) => unimplemented!(),
Value::Varint(_) => unimplemented!(),
Value::Decimal(_) => unimplemented!(),
Value::Date(_) => unimplemented!(),
Value::Timestamp(_) => unimplemented!(),
Value::Timeuuid(_) => unimplemented!(),
Value::Varchar(_) => unimplemented!(),
Value::Uuid(_) => unimplemented!(),
Value::Time(_) => unimplemented!(),
Value::Counter(_) => unimplemented!(),
Value::Tuple(_) => unimplemented!(),
Value::Udt(_) => unimplemented!(),
})
})
.collect()
Expand Down