diff --git a/Cargo.lock b/Cargo.lock index 668b6a213..a494e5308 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -765,6 +765,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctor" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn 1.0.109", +] + [[package]] name = "cxx" version = "1.0.93" @@ -1020,6 +1030,15 @@ dependencies = [ "log", ] +[[package]] +name = "erased-serde" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2b0c2380453a92ea8b6c8e5f64ecaafccddde8ceab55ff7a8ac1029f894569" +dependencies = [ + "serde", +] + [[package]] name = "errno" version = "0.2.8" @@ -1232,6 +1251,17 @@ dependencies = [ "wasi 0.11.0+wasi-snapshot-preview1", ] +[[package]] +name = "ghost" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e0cd8a998937e25c6ba7cc276b96ec5cc3f4dc4ab5de9ede4fb152bdd5c5eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "gimli" version = "0.27.2" @@ -1548,6 +1578,16 @@ version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" +[[package]] +name = "inventory" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498ae1c9c329c7972b917506239b557a60386839192f1cf0ca034f345b65db99" +dependencies = [ + "ctor", + "ghost", +] + [[package]] name = "io-lifetimes" version = "1.0.9" @@ -3078,6 +3118,7 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", + "typetag", "uuid", "version-compare", ] @@ -3649,6 +3690,30 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "typetag" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69bf9bd14fed1815295233a0eee76a963283b53ebcbd674d463f697d3bfcae0c" +dependencies = [ + "erased-serde", + "inventory", + "once_cell", + "serde", + "typetag-impl", +] + +[[package]] +name = "typetag-impl" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf9f5f225956dc2254c6c27500deac9390a066b2e8a1a571300627a7c4400a33" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "unicode-bidi" version = "0.3.13" diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 0b56338ff..8c81a8061 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -77,6 +77,7 @@ chacha20poly1305 = { version = "0.10.0", features = ["std"] } generic-array = { version = "0.14", features = ["serde"] } dyn-clone = "1.0.10" kafka-protocol = "0.6.0" +typetag = "0.2.5" [dev-dependencies] criterion = { git = "https://github.com/shotover/criterion.rs", branch = "0.4.0-bench_with_input_fn", features = ["async_tokio"] } diff --git a/shotover-proxy/benches/benches/chain.rs b/shotover-proxy/benches/benches/chain.rs index 64fabd8cc..23019ab66 100644 --- a/shotover-proxy/benches/benches/chain.rs +++ b/shotover-proxy/benches/benches/chain.rs @@ -12,11 +12,12 @@ use shotover_proxy::transforms::chain::{TransformChain, TransformChainBuilder}; use shotover_proxy::transforms::debug::returner::{DebugReturner, Response}; use shotover_proxy::transforms::filter::QueryTypeFilter; use shotover_proxy::transforms::null::NullSink; +#[cfg(feature = "alpha-transforms")] use shotover_proxy::transforms::protect::{KeyManagerConfig, ProtectConfig}; use shotover_proxy::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite; use shotover_proxy::transforms::redis::timestamp_tagging::RedisTimestampTagger; use shotover_proxy::transforms::throttling::RequestThrottlingConfig; -use shotover_proxy::transforms::Wrapper; +use shotover_proxy::transforms::{TransformConfig, Wrapper}; fn criterion_benchmark(c: &mut Criterion) { let rt = tokio::runtime::Runtime::new().unwrap(); @@ -175,11 +176,13 @@ fn criterion_benchmark(c: &mut Criterion) { { let chain = TransformChainBuilder::new( vec![ - RequestThrottlingConfig { - // an absurdly large value is given so that all messages will pass through - max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(), - } - .get_builder() + rt.block_on( + RequestThrottlingConfig { + // an absurdly large value is given so that all messages will pass through + max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(), + } + .get_builder("".to_owned()), + ) .unwrap(), Box::::default(), ], @@ -272,6 +275,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); } + #[cfg(feature = "alpha-transforms")] { let chain = TransformChainBuilder::new( vec![ @@ -290,7 +294,7 @@ fn criterion_benchmark(c: &mut Criterion) { kek_id: "".to_string(), }, } - .get_builder(), + .get_builder("".to_owned()), ) .unwrap(), Box::::default(), @@ -332,6 +336,7 @@ fn criterion_benchmark(c: &mut Criterion) { } } +#[cfg(feature = "alpha-transforms")] fn cassandra_parsed_query(query: &str) -> Wrapper { Wrapper::new_with_chain_name( vec![Message::from_frame(Frame::Cassandra(CassandraFrame { diff --git a/shotover-proxy/src/config/chain.rs b/shotover-proxy/src/config/chain.rs new file mode 100644 index 000000000..ed76e91a8 --- /dev/null +++ b/shotover-proxy/src/config/chain.rs @@ -0,0 +1,108 @@ +use crate::transforms::chain::TransformChainBuilder; +use crate::transforms::{TransformBuilder, TransformConfig}; +use anyhow::Result; +use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor}; +use serde::Deserialize; +use std::fmt::{self, Debug}; +use std::iter; + +#[derive(Deserialize, Debug)] +pub struct TransformChainConfig( + #[serde(rename = "TransformChain", deserialize_with = "vec_transform_config")] + pub Vec>, +); + +impl TransformChainConfig { + pub async fn get_builder(&self, name: String) -> Result { + let mut transforms: Vec> = Vec::new(); + for tc in &self.0 { + transforms.push(tc.get_builder(name.clone()).await?) + } + Ok(TransformChainBuilder::new(transforms, name)) + } +} + +/// This function is a custom deserializer that works around a mismatch in the way yaml and typetag represent things, +/// resulting in typetagged structs with no fields failing to deserialize from a single line yaml entry. +/// e.g. with typetag + yaml + the default serializer: +/// this would fail to deserialize: +/// ```yaml +/// chain_config: +/// redis_chain: +/// - NullSink +/// ``` +/// +/// but this would work fine: +/// ```yaml +/// chain_config: +/// redis_chain: +/// - NullSink: {} +/// ``` +/// +/// With the use of this custom deserializer both cases now deserialize correctly. +/// The implementation was a suggestion from dtolnay: https://github.com/dtolnay/typetag/pull/40#issuecomment-1454961686 +fn vec_transform_config<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + struct VecTransformConfigVisitor; + + impl<'de> Visitor<'de> for VecTransformConfigVisitor { + type Value = Vec>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("list of TransformConfig") + } + + fn visit_seq(self, mut seq: S) -> Result + where + S: SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(item) = seq.next_element_seed(TransformConfigVisitor)? { + vec.push(item); + } + Ok(vec) + } + } + + struct TransformConfigVisitor; + + impl<'de> Visitor<'de> for TransformConfigVisitor { + type Value = Box; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("TransformConfig") + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + let de = serde::de::value::MapAccessDeserializer::new(map); + Deserialize::deserialize(de) + } + + fn visit_str(self, string: &str) -> Result + where + E: serde::de::Error, + { + let singleton_map = iter::once((string, ())); + let de = serde::de::value::MapDeserializer::new(singleton_map); + Deserialize::deserialize(de) + } + } + + impl<'de> DeserializeSeed<'de> for TransformConfigVisitor { + type Value = Box; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + } + + deserializer.deserialize_seq(VecTransformConfigVisitor) +} diff --git a/shotover-proxy/src/config/mod.rs b/shotover-proxy/src/config/mod.rs index 025ef1c86..bcb7b6ec9 100644 --- a/shotover-proxy/src/config/mod.rs +++ b/shotover-proxy/src/config/mod.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Context, Result}; use serde::Deserialize; +pub mod chain; pub mod topology; #[derive(Deserialize, Debug, Clone)] diff --git a/shotover-proxy/src/config/topology.rs b/shotover-proxy/src/config/topology.rs index edce2b34e..9acc498ce 100644 --- a/shotover-proxy/src/config/topology.rs +++ b/shotover-proxy/src/config/topology.rs @@ -1,6 +1,6 @@ +use crate::config::chain::TransformChainConfig; use crate::sources::{Sources, SourcesConfig}; use crate::transforms::chain::TransformChainBuilder; -use crate::transforms::{build_chain_from_config, TransformsConfig}; use anyhow::{anyhow, Context, Result}; use itertools::Itertools; use serde::Deserialize; @@ -11,7 +11,7 @@ use tracing::info; #[derive(Deserialize, Debug)] pub struct Topology { pub sources: HashMap, - pub chain_config: HashMap>, + pub chain_config: HashMap, pub source_to_chain_mapping: HashMap, } @@ -20,6 +20,7 @@ impl Topology { let file = std::fs::File::open(filepath).map_err(|err| { anyhow!(err).context(format!("Couldn't open the topology file {}", filepath)) })?; + let deserializer = serde_yaml::Deserializer::from_reader(file); serde_yaml::with::singleton_map_recursive::deserialize(deserializer) .context(format!("Failed to parse topology file {}", filepath)) @@ -28,10 +29,7 @@ impl Topology { async fn build_chains(&self) -> Result> { let mut result = HashMap::new(); for (key, value) in &self.chain_config { - result.insert( - key.clone(), - build_chain_from_config(key.clone(), value).await?, - ); + result.insert(key.clone(), value.get_builder(key.clone()).await?); } Ok(result) } @@ -95,22 +93,27 @@ impl Topology { #[cfg(test)] mod topology_tests { + use crate::config::chain::TransformChainConfig; use crate::config::topology::Topology; use crate::transforms::coalesce::CoalesceConfig; + use crate::transforms::debug::printer::DebugPrinterConfig; + use crate::transforms::null::NullSinkConfig; + use crate::transforms::TransformConfig; use crate::{ sources::{redis::RedisConfig, Sources, SourcesConfig}, transforms::{ distributed::tuneable_consistency_scatter::TuneableConsistencyScatterConfig, parallel_map::ParallelMapConfig, redis::cache::RedisConfig as RedisCacheConfig, - TransformsConfig, }, }; use std::collections::HashMap; use tokio::sync::watch; - async fn run_test_topology(chain: Vec) -> anyhow::Result> { + async fn run_test_topology( + chain: Vec>, + ) -> anyhow::Result> { let mut chain_config = HashMap::new(); - chain_config.insert("redis_chain".to_string(), chain); + chain_config.insert("redis_chain".to_string(), TransformChainConfig(chain)); let redis_source = SourcesConfig::Redis(RedisConfig { listen_addr: "127.0.0.1".to_string(), @@ -147,12 +150,9 @@ redis_chain: #[tokio::test] async fn test_validate_chain_valid_chain() { - run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]) - .await - .unwrap(); + run_test_topology(vec![Box::new(DebugPrinterConfig), Box::new(NullSinkConfig)]) + .await + .unwrap(); } #[tokio::test] @@ -169,11 +169,11 @@ redis_chain: "#; let error = run_test_topology(vec![ - TransformsConfig::Coalesce(CoalesceConfig { + Box::new(CoalesceConfig { flush_when_buffered_message_count: None, flush_when_millis_since_last_flush: None, }), - TransformsConfig::NullSink, + Box::new(NullSinkConfig), ]) .await .unwrap_err() @@ -190,9 +190,9 @@ redis_chain: "#; let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::NullSink, + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(NullSinkConfig), ]) .await .unwrap_err() @@ -209,9 +209,9 @@ redis_chain: "#; let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), ]) .await .unwrap_err() @@ -229,10 +229,10 @@ redis_chain: "#; let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), ]) .await .unwrap_err() @@ -242,20 +242,20 @@ redis_chain: } #[tokio::test] - async fn test_validate_chain_valid_subchain_scatter() { - let subchain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; + async fn test_validate_chain_valid_subchain_consistent_scatter() { + let subchain = TransformChainConfig(vec![ + Box::new(DebugPrinterConfig) as Box, + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]); let mut route_map = HashMap::new(); route_map.insert("subchain-1".to_string(), subchain); run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::TuneableConsistencyScatter(TuneableConsistencyScatterConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(TuneableConsistencyScatterConfig { route_map, write_consistency: 1, read_consistency: 1, @@ -274,20 +274,20 @@ redis_chain: Terminating transform "NullSink" is not last in chain. Terminating transform must be last in chain. "#; - let subchain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; + let subchain = TransformChainConfig(vec![ + Box::new(DebugPrinterConfig) as Box, + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]); let mut route_map = HashMap::new(); route_map.insert("subchain-1".to_string(), subchain); let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::TuneableConsistencyScatter(TuneableConsistencyScatterConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(TuneableConsistencyScatterConfig { route_map, write_consistency: 1, read_consistency: 1, @@ -302,22 +302,20 @@ redis_chain: #[tokio::test] async fn test_validate_chain_valid_subchain_redis_cache() { - let chain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; - let caching_schema = HashMap::new(); run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::RedisCache(RedisCacheConfig { - chain, + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(RedisCacheConfig { + chain: TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]), caching_schema, }), - TransformsConfig::NullSink, + Box::new(NullSinkConfig), ]) .await .unwrap(); @@ -332,21 +330,19 @@ redis_chain: Terminating transform "NullSink" is not last in chain. Terminating transform must be last in chain. "#; - let chain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; - let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::RedisCache(RedisCacheConfig { - chain, + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(RedisCacheConfig { + chain: TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]), caching_schema: HashMap::new(), }), - TransformsConfig::NullSink, + Box::new(NullSinkConfig), ]) .await .unwrap_err() @@ -357,18 +353,16 @@ redis_chain: #[tokio::test] async fn test_validate_chain_valid_subchain_parallel_map() { - let chain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; - run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::ParallelMap(ParallelMapConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(ParallelMapConfig { parallelism: 1, - chain, + chain: TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]), ordered_results: false, }), ]) @@ -385,19 +379,17 @@ redis_chain: Terminating transform "NullSink" is not last in chain. Terminating transform must be last in chain. "#; - let chain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; - let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::ParallelMap(ParallelMapConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(ParallelMapConfig { parallelism: 1, - chain, + chain: TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]), ordered_results: false, }), ]) @@ -417,20 +409,20 @@ redis_chain: Terminating transform "NullSink" is not last in chain. Terminating transform must be last in chain. "#; - let subchain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - ]; + let subchain = TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + ]); let mut route_map = HashMap::new(); route_map.insert("subchain-1".to_string(), subchain); let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::TuneableConsistencyScatter(TuneableConsistencyScatterConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(TuneableConsistencyScatterConfig { route_map, write_consistency: 1, read_consistency: 1, @@ -452,18 +444,18 @@ redis_chain: Non-terminating transform "DebugPrinter" is last in chain. Last transform must be terminating. "#; - let subchain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - ]; + let subchain = TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + ]); let mut route_map = HashMap::new(); route_map.insert("subchain-1".to_string(), subchain); let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::TuneableConsistencyScatter(TuneableConsistencyScatterConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(TuneableConsistencyScatterConfig { route_map, write_consistency: 1, read_consistency: 1, @@ -486,19 +478,19 @@ redis_chain: Non-terminating transform "DebugPrinter" is last in chain. Last transform must be terminating. "#; - let subchain = vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::NullSink, - TransformsConfig::DebugPrinter, - ]; + let subchain = TransformChainConfig(vec![ + Box::new(DebugPrinterConfig), + Box::new(NullSinkConfig), + Box::new(DebugPrinterConfig), + ]); let mut route_map = HashMap::new(); route_map.insert("subchain-1".to_string(), subchain); let error = run_test_topology(vec![ - TransformsConfig::DebugPrinter, - TransformsConfig::DebugPrinter, - TransformsConfig::TuneableConsistencyScatter(TuneableConsistencyScatterConfig { + Box::new(DebugPrinterConfig), + Box::new(DebugPrinterConfig), + Box::new(TuneableConsistencyScatterConfig { route_map, write_consistency: 1, read_consistency: 1, diff --git a/shotover-proxy/src/lib.rs b/shotover-proxy/src/lib.rs index 2cb7bfa91..2b568f039 100644 --- a/shotover-proxy/src/lib.rs +++ b/shotover-proxy/src/lib.rs @@ -6,7 +6,7 @@ //! the [`transforms::Transforms`] enum. //! //! To allow your [`transforms::Transform`] to be configurable in Shotover config files you will need to create -//! a serializable config struct and register it in the [`transforms::TransformsConfig`] enum (note plural Transform**s**). +//! a serializable config struct and implement [`transforms::TransformConfig`] trait. //! //! ## Messages //! * [`message::Message`], the main struct that carries database queries/frames around in Shotover. @@ -15,7 +15,7 @@ //! * [`transforms::Wrapper`], used to wrap messages as they traverse the [`transforms::Transform`] chain. //! * [`transforms::Transform`], the main extension trait for adding your own behaviour to Shotover. //! * [`transforms::Transforms`], the enum to register with (add a variant) for enabling your own transform. -//! * [`transforms::TransformsConfig`], the enum to register with (add a variant) for configuring your own transform. +//! * [`transforms::TransformConfig`], the trait to implement for configuring your own transform. // Accidentally printing would break json log output #![deny(clippy::print_stdout)] diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 742da23f8..5eaab63ea 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -2,7 +2,7 @@ use crate::frame::{CassandraOperation, CassandraResult, Frame}; use crate::message::Message; use crate::message_value::{IntSize, MessageValue}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; -use crate::transforms::Transforms; +use crate::transforms::{TransformConfig, Transforms}; use crate::{ error::ChainResponse, transforms::{Transform, TransformBuilder, Wrapper}, @@ -20,8 +20,10 @@ pub struct CassandraPeersRewriteConfig { pub port: u16, } -impl CassandraPeersRewriteConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "CassandraPeersRewrite")] +#[async_trait(?Send)] +impl TransformConfig for CassandraPeersRewriteConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(CassandraPeersRewrite::new(self.port))) } } diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs index df0626f8b..0e6359de3 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster/mod.rs @@ -6,7 +6,7 @@ use crate::message::{Message, Messages, Metadata}; use crate::message_value::{IntSize, MessageValue}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response}; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use cassandra_protocol::events::ServerEvent; @@ -79,8 +79,10 @@ pub struct CassandraSinkClusterConfig { pub read_timeout: Option, } -impl CassandraSinkClusterConfig { - pub async fn get_builder(&self, chain_name: String) -> Result> { +#[typetag::deserialize(name = "CassandraSinkCluster")] +#[async_trait(?Send)] +impl TransformConfig for CassandraSinkClusterConfig { + async fn get_builder(&self, chain_name: String) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; let mut shotover_nodes = self.shotover_nodes.clone(); let index = self diff --git a/shotover-proxy/src/transforms/cassandra/sink_single.rs b/shotover-proxy/src/transforms/cassandra/sink_single.rs index 1a29b201f..4c78eac4f 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_single.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_single.rs @@ -5,7 +5,7 @@ use crate::frame::cassandra::CassandraMetadata; use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::Response; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use cassandra_protocol::frame::Version; @@ -25,8 +25,10 @@ pub struct CassandraSinkSingleConfig { pub read_timeout: Option, } -impl CassandraSinkSingleConfig { - pub async fn get_builder(&self, chain_name: String) -> Result> { +#[typetag::deserialize(name = "CassandraSinkSingle")] +#[async_trait(?Send)] +impl TransformConfig for CassandraSinkSingleConfig { + async fn get_builder(&self, chain_name: String) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; Ok(Box::new(CassandraSinkSingleBuilder::new( self.address.clone(), diff --git a/shotover-proxy/src/transforms/coalesce.rs b/shotover-proxy/src/transforms/coalesce.rs index 734c75409..bbd1a10a4 100644 --- a/shotover-proxy/src/transforms/coalesce.rs +++ b/shotover-proxy/src/transforms/coalesce.rs @@ -1,7 +1,6 @@ -use super::Transforms; use crate::error::ChainResponse; use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::Deserialize; @@ -21,8 +20,10 @@ pub struct CoalesceConfig { pub flush_when_millis_since_last_flush: Option, } -impl CoalesceConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "Coalesce")] +#[async_trait(?Send)] +impl TransformConfig for CoalesceConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(Coalesce { buffer: Vec::with_capacity(self.flush_when_buffered_message_count.unwrap_or(0)), flush_when_buffered_message_count: self.flush_when_buffered_message_count, diff --git a/shotover-proxy/src/transforms/debug/force_parse.rs b/shotover-proxy/src/transforms/debug/force_parse.rs index 781194114..2c46ca55c 100644 --- a/shotover-proxy/src/transforms/debug/force_parse.rs +++ b/shotover-proxy/src/transforms/debug/force_parse.rs @@ -5,7 +5,10 @@ /// without worrying about the performance impact of other transform logic. /// It could also be used to ensure that messages round trip correctly when parsed. use crate::error::ChainResponse; +#[cfg(feature = "alpha-transforms")] +use crate::transforms::TransformConfig; use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +#[cfg(feature = "alpha-transforms")] use anyhow::Result; use async_trait::async_trait; use serde::Deserialize; @@ -14,12 +17,15 @@ use serde::Deserialize; /// Must be individually enabled at the request or response level. #[derive(Deserialize, Debug)] pub struct DebugForceParseConfig { - parse_requests: bool, - parse_responses: bool, + pub parse_requests: bool, + pub parse_responses: bool, } -impl DebugForceParseConfig { - pub async fn get_builder(&self) -> Result> { +#[cfg(feature = "alpha-transforms")] +#[typetag::deserialize(name = "DebugForceParse")] +#[async_trait(?Send)] +impl TransformConfig for DebugForceParseConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(DebugForceParse { parse_requests: self.parse_requests, parse_responses: self.parse_responses, @@ -33,12 +39,15 @@ impl DebugForceParseConfig { /// Must be individually enabled at the request or response level. #[derive(Deserialize, Debug)] pub struct DebugForceEncodeConfig { - encode_requests: bool, - encode_responses: bool, + pub encode_requests: bool, + pub encode_responses: bool, } -impl DebugForceEncodeConfig { - pub async fn get_builder(&self) -> Result> { +#[cfg(feature = "alpha-transforms")] +#[typetag::deserialize(name = "DebugForceEncode")] +#[async_trait(?Send)] +impl TransformConfig for DebugForceEncodeConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(DebugForceParse { parse_requests: self.encode_requests, parse_responses: self.encode_responses, diff --git a/shotover-proxy/src/transforms/debug/printer.rs b/shotover-proxy/src/transforms/debug/printer.rs index 6b42104d9..092cacd43 100644 --- a/shotover-proxy/src/transforms/debug/printer.rs +++ b/shotover-proxy/src/transforms/debug/printer.rs @@ -1,8 +1,20 @@ -use tracing::info; - use crate::error::ChainResponse; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use anyhow::Result; use async_trait::async_trait; +use serde::Deserialize; +use tracing::info; + +#[derive(Deserialize, Debug)] +pub struct DebugPrinterConfig; + +#[typetag::deserialize(name = "DebugPrinter")] +#[async_trait(?Send)] +impl TransformConfig for DebugPrinterConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { + Ok(Box::new(DebugPrinter::new())) + } +} #[derive(Debug, Clone)] pub struct DebugPrinter { diff --git a/shotover-proxy/src/transforms/debug/returner.rs b/shotover-proxy/src/transforms/debug/returner.rs index 798e66364..0accd7b19 100644 --- a/shotover-proxy/src/transforms/debug/returner.rs +++ b/shotover-proxy/src/transforms/debug/returner.rs @@ -1,6 +1,8 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; -use crate::transforms::{ChainResponse, Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{ + ChainResponse, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::Deserialize; @@ -11,8 +13,10 @@ pub struct DebugReturnerConfig { response: Response, } -impl DebugReturnerConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "DebugReturner")] +#[async_trait(?Send)] +impl TransformConfig for DebugReturnerConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(DebugReturner::new(self.response.clone()))) } } diff --git a/shotover-proxy/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover-proxy/src/transforms/distributed/tuneable_consistency_scatter.rs index 3da93d169..dd8715635 100644 --- a/shotover-proxy/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover-proxy/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -1,10 +1,9 @@ +use crate::config::chain::TransformChainConfig; use crate::error::ChainResponse; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, QueryType}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{ - build_chain_from_config, Transform, TransformBuilder, Transforms, TransformsConfig, Wrapper, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::Result; use async_trait::async_trait; use futures::stream::FuturesUnordered; @@ -15,18 +14,20 @@ use tracing::{error, warn}; #[derive(Deserialize, Debug)] pub struct TuneableConsistencyScatterConfig { - pub route_map: HashMap>, + pub route_map: HashMap, pub write_consistency: i32, pub read_consistency: i32, } -impl TuneableConsistencyScatterConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "TuneableConsistencyScatter")] +#[async_trait(?Send)] +impl TransformConfig for TuneableConsistencyScatterConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { let mut route_map = Vec::with_capacity(self.route_map.len()); warn!("Using this transform is considered unstable - Does not work with REDIS pipelines"); for (key, value) in &self.route_map { - route_map.push(build_chain_from_config(key.clone(), value).await?); + route_map.push(value.get_builder(key.clone()).await?); } route_map.sort_by_key(|x| x.name.clone()); diff --git a/shotover-proxy/src/transforms/filter.rs b/shotover-proxy/src/transforms/filter.rs index b9715e40d..80e288b82 100644 --- a/shotover-proxy/src/transforms/filter.rs +++ b/shotover-proxy/src/transforms/filter.rs @@ -1,6 +1,6 @@ use crate::error::ChainResponse; use crate::message::{Message, QueryType}; -use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::Deserialize; @@ -20,8 +20,10 @@ pub struct QueryTypeFilterConfig { pub filter: QueryType, } -impl QueryTypeFilterConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "QueryTypeFilter")] +#[async_trait(?Send)] +impl TransformConfig for QueryTypeFilterConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(QueryTypeFilter { filter: self.filter.clone(), })) diff --git a/shotover-proxy/src/transforms/kafka/sink_single.rs b/shotover-proxy/src/transforms/kafka/sink_single.rs index 7b9d6b134..008931bcd 100644 --- a/shotover-proxy/src/transforms/kafka/sink_single.rs +++ b/shotover-proxy/src/transforms/kafka/sink_single.rs @@ -22,8 +22,14 @@ pub struct KafkaSinkSingleConfig { pub read_timeout: Option, } -impl KafkaSinkSingleConfig { - pub async fn get_builder(&self, chain_name: String) -> Result> { +#[cfg(feature = "alpha-transforms")] +use crate::transforms::TransformConfig; + +#[cfg(feature = "alpha-transforms")] +#[typetag::deserialize(name = "KafkaSinkSingle")] +#[async_trait(?Send)] +impl TransformConfig for KafkaSinkSingleConfig { + async fn get_builder(&self, chain_name: String) -> Result> { Ok(Box::new(KafkaSinkSingleBuilder::new( self.address.clone(), chain_name, diff --git a/shotover-proxy/src/transforms/load_balance.rs b/shotover-proxy/src/transforms/load_balance.rs index c5bebe982..989f851ae 100644 --- a/shotover-proxy/src/transforms/load_balance.rs +++ b/shotover-proxy/src/transforms/load_balance.rs @@ -1,9 +1,8 @@ use super::Transforms; +use crate::config::chain::TransformChainConfig; use crate::error::ChainResponse; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{ - build_chain_from_config, Transform, TransformBuilder, TransformsConfig, Wrapper, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::Deserialize; @@ -14,12 +13,14 @@ use tokio::sync::Mutex; pub struct ConnectionBalanceAndPoolConfig { pub name: String, pub max_connections: usize, - pub chain: Vec, + pub chain: TransformChainConfig, } -impl ConnectionBalanceAndPoolConfig { - pub async fn get_builder(&self) -> Result> { - let chain = build_chain_from_config(self.name.clone(), &self.chain).await?; +#[typetag::deserialize(name = "ConnectionBalanceAndPool")] +#[async_trait(?Send)] +impl TransformConfig for ConnectionBalanceAndPoolConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { + let chain = self.chain.get_builder(self.name.clone()).await?; Ok(Box::new(ConnectionBalanceAndPoolBuilder { max_connections: self.max_connections, diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index 8d9fa29c0..9be452962 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -1,52 +1,35 @@ use crate::error::ChainResponse; use crate::message::Messages; -use crate::transforms::cassandra::peers_rewrite::{ - CassandraPeersRewrite, CassandraPeersRewriteConfig, -}; -use crate::transforms::cassandra::sink_cluster::{ - CassandraSinkCluster, CassandraSinkClusterConfig, -}; -use crate::transforms::cassandra::sink_single::{CassandraSinkSingle, CassandraSinkSingleConfig}; -use crate::transforms::chain::TransformChainBuilder; -use crate::transforms::coalesce::{Coalesce, CoalesceConfig}; +use crate::transforms::cassandra::peers_rewrite::CassandraPeersRewrite; +use crate::transforms::cassandra::sink_cluster::CassandraSinkCluster; +use crate::transforms::cassandra::sink_single::CassandraSinkSingle; +use crate::transforms::coalesce::Coalesce; use crate::transforms::debug::force_parse::DebugForceParse; -#[cfg(feature = "alpha-transforms")] -use crate::transforms::debug::force_parse::{DebugForceEncodeConfig, DebugForceParseConfig}; use crate::transforms::debug::printer::DebugPrinter; use crate::transforms::debug::random_delay::DebugRandomDelay; -use crate::transforms::debug::returner::{DebugReturner, DebugReturnerConfig}; -use crate::transforms::distributed::tuneable_consistency_scatter::{ - TuneableConsistencyScatterConfig, TuneableConsistentencyScatter, -}; -use crate::transforms::filter::{QueryTypeFilter, QueryTypeFilterConfig}; +use crate::transforms::debug::returner::DebugReturner; +use crate::transforms::distributed::tuneable_consistency_scatter::TuneableConsistentencyScatter; +use crate::transforms::filter::QueryTypeFilter; use crate::transforms::kafka::sink_single::KafkaSinkSingle; -#[cfg(feature = "alpha-transforms")] -use crate::transforms::kafka::sink_single::KafkaSinkSingleConfig; use crate::transforms::load_balance::ConnectionBalanceAndPool; use crate::transforms::loopback::Loopback; use crate::transforms::null::NullSink; -use crate::transforms::parallel_map::{ParallelMap, ParallelMapConfig}; +use crate::transforms::parallel_map::ParallelMap; use crate::transforms::protect::Protect; -#[cfg(feature = "alpha-transforms")] -use crate::transforms::protect::ProtectConfig; -use crate::transforms::query_counter::{QueryCounter, QueryCounterConfig}; -use crate::transforms::redis::cache::{RedisConfig, SimpleRedisCache}; -use crate::transforms::redis::cluster_ports_rewrite::{ - RedisClusterPortsRewrite, RedisClusterPortsRewriteConfig, -}; -use crate::transforms::redis::sink_cluster::{RedisSinkCluster, RedisSinkClusterConfig}; -use crate::transforms::redis::sink_single::{RedisSinkSingle, RedisSinkSingleConfig}; +use crate::transforms::query_counter::QueryCounter; +use crate::transforms::redis::cache::SimpleRedisCache; +use crate::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite; +use crate::transforms::redis::sink_cluster::RedisSinkCluster; +use crate::transforms::redis::sink_single::RedisSinkSingle; use crate::transforms::redis::timestamp_tagging::RedisTimestampTagger; -use crate::transforms::tee::{Tee, TeeConfig}; -use crate::transforms::throttling::{RequestThrottling, RequestThrottlingConfig}; +use crate::transforms::tee::Tee; +use crate::transforms::throttling::RequestThrottling; use anyhow::{anyhow, Result}; -use async_recursion::async_recursion; use async_trait::async_trait; use core::fmt; use dyn_clone::DynClone; use futures::Future; use metrics::{counter, histogram}; -use serde::Deserialize; use std::fmt::{Debug, Formatter}; use std::iter::Rev; use std::net::SocketAddr; @@ -235,95 +218,10 @@ impl Transforms { } } -/// The TransformsConfig enum is responsible for TransformConfig registration and enum dispatch -/// in the transform chain. Allows you to register your config struct for the config file. -#[derive(Deserialize, Debug)] -pub enum TransformsConfig { - #[cfg(feature = "alpha-transforms")] - KafkaSinkSingle(KafkaSinkSingleConfig), - CassandraSinkSingle(CassandraSinkSingleConfig), - CassandraSinkCluster(CassandraSinkClusterConfig), - RedisSinkSingle(RedisSinkSingleConfig), - CassandraPeersRewrite(CassandraPeersRewriteConfig), - RedisCache(RedisConfig), - Tee(TeeConfig), - TuneableConsistencyScatter(TuneableConsistencyScatterConfig), - RedisSinkCluster(RedisSinkClusterConfig), - RedisClusterPortsRewrite(RedisClusterPortsRewriteConfig), - RedisTimestampTagger, - DebugPrinter, - DebugReturner(DebugReturnerConfig), - NullSink, - #[cfg(test)] - Loopback, - #[cfg(feature = "alpha-transforms")] - Protect(ProtectConfig), - #[cfg(feature = "alpha-transforms")] - DebugForceParse(DebugForceParseConfig), - #[cfg(feature = "alpha-transforms")] - DebugForceEncode(DebugForceEncodeConfig), - ParallelMap(ParallelMapConfig), - //PoolConnections(ConnectionBalanceAndPoolConfig), - Coalesce(CoalesceConfig), - QueryTypeFilter(QueryTypeFilterConfig), - QueryCounter(QueryCounterConfig), - RequestThrottling(RequestThrottlingConfig), -} - -impl TransformsConfig { - #[async_recursion] - pub async fn get_builder(&self, chain_name: String) -> Result> { - match self { - #[cfg(feature = "alpha-transforms")] - TransformsConfig::KafkaSinkSingle(c) => c.get_builder(chain_name).await, - TransformsConfig::CassandraSinkSingle(c) => c.get_builder(chain_name).await, - TransformsConfig::CassandraSinkCluster(c) => c.get_builder(chain_name).await, - TransformsConfig::CassandraPeersRewrite(c) => c.get_builder().await, - TransformsConfig::RedisCache(r) => r.get_builder().await, - TransformsConfig::Tee(t) => t.get_builder().await, - TransformsConfig::RedisSinkSingle(r) => r.get_builder(chain_name).await, - TransformsConfig::TuneableConsistencyScatter(c) => c.get_builder().await, - TransformsConfig::RedisTimestampTagger => { - Ok(Box::new(RedisTimestampTagger::new()) as Box) - } - TransformsConfig::RedisClusterPortsRewrite(r) => r.get_builder().await, - TransformsConfig::DebugPrinter => { - Ok(Box::new(DebugPrinter::new()) as Box) - } - TransformsConfig::DebugReturner(d) => d.get_builder().await, - TransformsConfig::NullSink => { - Ok(Box::::default() as Box) - } - #[cfg(test)] - TransformsConfig::Loopback => { - Ok(Box::::default() as Box) - } - #[cfg(feature = "alpha-transforms")] - TransformsConfig::Protect(p) => p.get_builder().await, - #[cfg(feature = "alpha-transforms")] - TransformsConfig::DebugForceParse(d) => d.get_builder().await, - #[cfg(feature = "alpha-transforms")] - TransformsConfig::DebugForceEncode(d) => d.get_builder().await, - TransformsConfig::RedisSinkCluster(r) => r.get_builder(chain_name).await, - TransformsConfig::ParallelMap(s) => s.get_builder().await, - // TransformsConfig::PoolConnections(s) => s.get_builder().await, - TransformsConfig::Coalesce(s) => s.get_builder().await, - TransformsConfig::QueryTypeFilter(s) => s.get_builder().await, - TransformsConfig::QueryCounter(s) => s.get_builder().await, - TransformsConfig::RequestThrottling(s) => s.get_builder(), - } - } -} - -pub async fn build_chain_from_config( - name: String, - transform_configs: &[TransformsConfig], -) -> Result { - let mut transforms: Vec> = Vec::new(); - for tc in transform_configs { - transforms.push(tc.get_builder(name.clone()).await?) - } - Ok(TransformChainBuilder::new(transforms, name)) +#[typetag::deserialize] +#[async_trait(?Send)] +pub trait TransformConfig: Debug { + async fn get_builder(&self, chain_name: String) -> Result>; } /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the @@ -515,7 +413,8 @@ impl<'a> Wrapper<'a> { /// but make sure to copy the value from the TransformBuilder to ensure all instances are referring to the same value. /// /// Once you have created your [`Transform`], you will need to create -/// new enum variants in [Transforms], [TransformBuilder] and [TransformsConfig] to make them configurable in Shotover. +/// new enum variants in [Transforms]. +/// And implement the [TransformBuilder] and [TransformConfig] traits to make them configurable in Shotover. /// Shotover uses a concept called enum dispatch to provide dynamic configuration of transform chains /// with minimal impact on performance. /// diff --git a/shotover-proxy/src/transforms/null.rs b/shotover-proxy/src/transforms/null.rs index 0025d344c..96407cc7b 100644 --- a/shotover-proxy/src/transforms/null.rs +++ b/shotover-proxy/src/transforms/null.rs @@ -1,8 +1,19 @@ use crate::error::ChainResponse; -use crate::transforms::{Transform, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use anyhow::Result; use async_trait::async_trait; +use serde::Deserialize; -use super::{TransformBuilder, Transforms}; +#[derive(Deserialize, Debug)] +pub struct NullSinkConfig; + +#[typetag::deserialize(name = "NullSink")] +#[async_trait(?Send)] +impl TransformConfig for NullSinkConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { + Ok(Box::new(NullSink {})) + } +} #[derive(Debug, Clone, Default)] pub struct NullSink {} diff --git a/shotover-proxy/src/transforms/parallel_map.rs b/shotover-proxy/src/transforms/parallel_map.rs index 0786826f5..20d53e2ee 100644 --- a/shotover-proxy/src/transforms/parallel_map.rs +++ b/shotover-proxy/src/transforms/parallel_map.rs @@ -1,9 +1,8 @@ +use crate::config::chain::TransformChainConfig; use crate::error::ChainResponse; use crate::message::Messages; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{ - build_chain_from_config, Transform, TransformBuilder, Transforms, TransformsConfig, Wrapper, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::Result; use async_trait::async_trait; use futures::stream::{FuturesOrdered, FuturesUnordered}; @@ -69,13 +68,15 @@ where #[derive(Deserialize, Debug)] pub struct ParallelMapConfig { pub parallelism: u32, - pub chain: Vec, + pub chain: TransformChainConfig, pub ordered_results: bool, } -impl ParallelMapConfig { - pub async fn get_builder(&self) -> Result> { - let chain = build_chain_from_config("parallel_map_chain".into(), &self.chain).await?; +#[typetag::deserialize(name = "ParallelMap")] +#[async_trait(?Send)] +impl TransformConfig for ParallelMapConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { + let chain = self.chain.get_builder("parallel_map_chain".into()).await?; Ok(Box::new(ParallelMapBuilder { chains: std::iter::repeat(chain) diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 5a211f15f..dea521e24 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -25,8 +25,14 @@ pub struct ProtectConfig { pub key_manager: KeyManagerConfig, } -impl ProtectConfig { - pub async fn get_builder(&self) -> Result> { +#[cfg(feature = "alpha-transforms")] +use crate::transforms::TransformConfig; + +#[cfg(feature = "alpha-transforms")] +#[typetag::deserialize(name = "Protect")] +#[async_trait(?Send)] +impl TransformConfig for ProtectConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(Protect { keyspace_table_columns: self .keyspace_table_columns diff --git a/shotover-proxy/src/transforms/query_counter.rs b/shotover-proxy/src/transforms/query_counter.rs index ba395e15d..d94c78a67 100644 --- a/shotover-proxy/src/transforms/query_counter.rs +++ b/shotover-proxy/src/transforms/query_counter.rs @@ -1,6 +1,7 @@ use crate::error::ChainResponse; use crate::frame::Frame; use crate::frame::RedisFrame; +use crate::transforms::TransformConfig; use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; use anyhow::Result; use async_trait::async_trait; @@ -85,8 +86,10 @@ fn get_redis_query_type(frame: &RedisFrame) -> Option { None } -impl QueryCounterConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "QueryCounter")] +#[async_trait(?Send)] +impl TransformConfig for QueryCounterConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(QueryCounter::new(self.name.clone()))) } } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index bab2d027c..de2a8c1bc 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,10 +1,9 @@ +use crate::config::chain::TransformChainConfig; use crate::error::ChainResponse; use crate::frame::{CassandraFrame, CassandraOperation, Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{ - build_chain_from_config, Transform, TransformBuilder, Transforms, TransformsConfig, Wrapper, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{bail, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -78,11 +77,13 @@ impl From<&TableCacheSchemaConfig> for TableCacheSchema { #[derive(Deserialize, Debug)] pub struct RedisConfig { pub caching_schema: HashMap, - pub chain: Vec, + pub chain: TransformChainConfig, } -impl RedisConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "RedisCache")] +#[async_trait(?Send)] +impl TransformConfig for RedisConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { let missed_requests = register_counter!("cache_miss"); let caching_schema: HashMap = self @@ -92,7 +93,7 @@ impl RedisConfig { .collect(); Ok(Box::new(SimpleRedisCacheBuilder { - cache_chain: build_chain_from_config("cache_chain".to_string(), &self.chain).await?, + cache_chain: self.chain.get_builder("cache_chain".to_string()).await?, caching_schema, missed_requests, })) diff --git a/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs b/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs index 1ea4a0e68..e0d01b0f7 100644 --- a/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs @@ -6,15 +6,17 @@ use serde::Deserialize; use crate::error::ChainResponse; use crate::frame::Frame; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; #[derive(Deserialize, Debug)] pub struct RedisClusterPortsRewriteConfig { pub new_port: u16, } -impl RedisClusterPortsRewriteConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "RedisClusterPortsRewrite")] +#[async_trait(?Send)] +impl TransformConfig for RedisClusterPortsRewriteConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(RedisClusterPortsRewrite { new_port: self.new_port, })) diff --git a/shotover-proxy/src/transforms/redis/sink_cluster.rs b/shotover-proxy/src/transforms/redis/sink_cluster.rs index b443c5924..3b991d7a8 100644 --- a/shotover-proxy/src/transforms/redis/sink_cluster.rs +++ b/shotover-proxy/src/transforms/redis/sink_cluster.rs @@ -9,7 +9,8 @@ use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; use crate::transforms::util::{Request, Response}; use crate::transforms::{ - ResponseFuture, Transform, TransformBuilder, Transforms, Wrapper, CONTEXT_CHAIN_NAME, + ResponseFuture, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, + CONTEXT_CHAIN_NAME, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use async_trait::async_trait; @@ -45,8 +46,10 @@ pub struct RedisSinkClusterConfig { pub connect_timeout_ms: u64, } -impl RedisSinkClusterConfig { - pub async fn get_builder(&self, chain_name: String) -> Result> { +#[typetag::deserialize(name = "RedisSinkCluster")] +#[async_trait(?Send)] +impl TransformConfig for RedisSinkClusterConfig { + async fn get_builder(&self, chain_name: String) -> Result> { let mut cluster = RedisSinkCluster::new( self.first_contact_points.clone(), self.direct_destination.clone(), diff --git a/shotover-proxy/src/transforms/redis/sink_single.rs b/shotover-proxy/src/transforms/redis/sink_single.rs index 48be9fb9f..a70d90563 100644 --- a/shotover-proxy/src/transforms/redis/sink_single.rs +++ b/shotover-proxy/src/transforms/redis/sink_single.rs @@ -6,8 +6,9 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::tcp; use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; -use crate::transforms::ChainResponse; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{ + ChainResponse, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, +}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use futures::{FutureExt, SinkExt, StreamExt}; @@ -29,8 +30,10 @@ pub struct RedisSinkSingleConfig { pub connect_timeout_ms: u64, } -impl RedisSinkSingleConfig { - pub async fn get_builder(&self, chain_name: String) -> Result> { +#[typetag::deserialize(name = "RedisSinkSingle")] +#[async_trait(?Send)] +impl TransformConfig for RedisSinkSingleConfig { + async fn get_builder(&self, chain_name: String) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; Ok(Box::new(RedisSinkSingleBuilder::new( self.address.clone(), diff --git a/shotover-proxy/src/transforms/redis/timestamp_tagging.rs b/shotover-proxy/src/transforms/redis/timestamp_tagging.rs index c52031521..4156692ac 100644 --- a/shotover-proxy/src/transforms/redis/timestamp_tagging.rs +++ b/shotover-proxy/src/transforms/redis/timestamp_tagging.rs @@ -2,15 +2,27 @@ use crate::error::ChainResponse; use crate::frame::redis::redis_query_type; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, QueryType}; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use bytes::Bytes; use itertools::Itertools; +use serde::Deserialize; use std::fmt::Write; use std::time::{SystemTime, UNIX_EPOCH}; use tracing::{debug, trace}; +#[derive(Deserialize, Debug)] +pub struct RedisTimestampTaggerConfig; + +#[typetag::deserialize(name = "RedisTimestampTagger")] +#[async_trait(?Send)] +impl TransformConfig for RedisTimestampTaggerConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { + Ok(Box::new(RedisTimestampTagger {})) + } +} + #[derive(Clone, Default)] pub struct RedisTimestampTagger {} diff --git a/shotover-proxy/src/transforms/tee.rs b/shotover-proxy/src/transforms/tee.rs index bb89b0ae1..39a9d7bce 100644 --- a/shotover-proxy/src/transforms/tee.rs +++ b/shotover-proxy/src/transforms/tee.rs @@ -1,8 +1,7 @@ +use crate::config::chain::TransformChainConfig; use crate::error::ChainResponse; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{ - build_chain_from_config, Transform, TransformBuilder, Transforms, TransformsConfig, Wrapper, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::Result; use async_trait::async_trait; use metrics::{register_counter, Counter}; @@ -102,7 +101,7 @@ pub enum ConsistencyBehavior { pub struct TeeConfig { pub behavior: Option, pub timeout_micros: Option, - pub chain: Vec, + pub chain: TransformChainConfig, pub buffer_size: Option, } @@ -110,11 +109,13 @@ pub struct TeeConfig { pub enum ConsistencyBehaviorConfig { Ignore, FailOnMismatch, - SubchainOnMismatch(Vec), + SubchainOnMismatch(TransformChainConfig), } -impl TeeConfig { - pub async fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "Tee")] +#[async_trait(?Send)] +impl TransformConfig for TeeConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { let buffer_size = self.buffer_size.unwrap_or(5); let behavior = match &self.behavior { Some(ConsistencyBehaviorConfig::Ignore) => ConsistencyBehaviorBuilder::Ignore, @@ -123,12 +124,14 @@ impl TeeConfig { } Some(ConsistencyBehaviorConfig::SubchainOnMismatch(mismatch_chain)) => { ConsistencyBehaviorBuilder::SubchainOnMismatch( - build_chain_from_config("mismatch_chain".to_string(), mismatch_chain).await?, + mismatch_chain + .get_builder("mismatch_chain".to_string()) + .await?, ) } None => ConsistencyBehaviorBuilder::Ignore, }; - let tee_chain = build_chain_from_config("tee_chain".to_string(), &self.chain).await?; + let tee_chain = self.chain.get_builder("tee_chain".to_string()).await?; Ok(Box::new(TeeBuilder::new( tee_chain, @@ -196,7 +199,7 @@ impl Transform for Tee { #[cfg(test)] mod tests { use super::*; - use crate::transforms::TransformsConfig; + use crate::transforms::null::NullSinkConfig; #[tokio::test] async fn test_validate_no_subchain() { @@ -204,10 +207,10 @@ mod tests { let config = TeeConfig { behavior: Some(ConsistencyBehaviorConfig::Ignore), timeout_micros: None, - chain: vec![TransformsConfig::NullSink], + chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, }; - let transform = config.get_builder().await.unwrap(); + let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } @@ -216,10 +219,10 @@ mod tests { let config = TeeConfig { behavior: Some(ConsistencyBehaviorConfig::FailOnMismatch), timeout_micros: None, - chain: vec![TransformsConfig::NullSink], + chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, }; - let transform = config.get_builder().await.unwrap(); + let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } @@ -228,16 +231,15 @@ mod tests { #[tokio::test] async fn test_validate_invalid_chain() { let config = TeeConfig { - behavior: Some(ConsistencyBehaviorConfig::SubchainOnMismatch(vec![ - TransformsConfig::NullSink, - TransformsConfig::NullSink, - ])), + behavior: Some(ConsistencyBehaviorConfig::SubchainOnMismatch( + TransformChainConfig(vec![Box::new(NullSinkConfig), Box::new(NullSinkConfig)]), + )), timeout_micros: None, - chain: vec![TransformsConfig::NullSink], + chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, }; - let transform = config.get_builder().await.unwrap(); + let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); let expected = vec!["Tee:", " mismatch_chain:", " Terminating transform \"NullSink\" is not last in chain. Terminating transform must be last in chain."]; assert_eq!(result, expected); @@ -246,15 +248,15 @@ mod tests { #[tokio::test] async fn test_validate_valid_chain() { let config = TeeConfig { - behavior: Some(ConsistencyBehaviorConfig::SubchainOnMismatch(vec![ - TransformsConfig::NullSink, - ])), + behavior: Some(ConsistencyBehaviorConfig::SubchainOnMismatch( + TransformChainConfig(vec![Box::new(NullSinkConfig)]), + )), timeout_micros: None, - chain: vec![TransformsConfig::NullSink], + chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, }; - let transform = config.get_builder().await.unwrap(); + let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } diff --git a/shotover-proxy/src/transforms/throttling.rs b/shotover-proxy/src/transforms/throttling.rs index 7e3a34099..99c68a771 100644 --- a/shotover-proxy/src/transforms/throttling.rs +++ b/shotover-proxy/src/transforms/throttling.rs @@ -1,8 +1,5 @@ -use crate::{ - error::ChainResponse, - message::Message, - transforms::{Transform, TransformBuilder, Wrapper}, -}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::{error::ChainResponse, message::Message}; use anyhow::Result; use async_trait::async_trait; use governor::{ @@ -23,8 +20,10 @@ pub struct RequestThrottlingConfig { pub max_requests_per_second: NonZeroU32, } -impl RequestThrottlingConfig { - pub fn get_builder(&self) -> Result> { +#[typetag::deserialize(name = "RequestThrottling")] +#[async_trait(?Send)] +impl TransformConfig for RequestThrottlingConfig { + async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second( self.max_requests_per_second,