From e3aca2dc67667abfc41f20c995982b0911ddd51f Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Wed, 14 Dec 2022 10:19:08 +1100 Subject: [PATCH] Do not clone transforms when running chain in reverse (#960) --- shotover-proxy/src/transforms/chain.rs | 4 +-- shotover-proxy/src/transforms/mod.rs | 49 +++++++++++++++++++++----- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/shotover-proxy/src/transforms/chain.rs b/shotover-proxy/src/transforms/chain.rs index 290564e65..42243708f 100644 --- a/shotover-proxy/src/transforms/chain.rs +++ b/shotover-proxy/src/transforms/chain.rs @@ -191,9 +191,7 @@ impl TransformChain { pub async fn process_request_rev(&mut self, mut wrapper: Wrapper<'_>) -> ChainResponse { let start = Instant::now(); - - let mut chain: Vec<_> = self.chain.iter().cloned().rev().collect(); - wrapper.reset(&mut chain); + wrapper.reset_rev(&mut self.chain); let result = wrapper.call_next_transform_pushed().await; self.chain_total.increment(1); diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index 1d6d54c24..0029dfb1f 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -43,8 +43,10 @@ use futures::Future; use metrics::{counter, histogram}; use serde::Deserialize; use std::fmt::{Debug, Formatter}; +use std::iter::Rev; use std::net::SocketAddr; use std::pin::Pin; +use std::slice::IterMut; use strum_macros::IntoStaticStr; use tokio::sync::mpsc; use tokio::time::Instant; @@ -444,15 +446,13 @@ pub async fn build_chain_from_config( Ok(TransformChainBuilder::new(transforms, name)) } -use std::slice::IterMut; - /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the /// remaining transforms that will process the messages attached to this [`Wrapper`]. /// Most [`Transform`] authors will only be interested in [`Wrapper.messages`]. #[derive(Debug)] pub struct Wrapper<'a> { pub messages: Messages, - transforms: IterMut<'a, Transforms>, + transforms: TransformIter<'a>, pub client_details: String, /// Contains the shotover source's ip address and port which the message was received on pub local_addr: SocketAddr, @@ -463,6 +463,33 @@ pub struct Wrapper<'a> { pub flush: bool, } +#[derive(Debug)] +enum TransformIter<'a> { + Forwards(IterMut<'a, Transforms>), + Backwards(Rev>), +} + +impl<'a> TransformIter<'a> { + fn new_forwards(transforms: &'a mut [Transforms]) -> TransformIter<'a> { + TransformIter::Forwards(transforms.iter_mut()) + } + + fn new_backwards(transforms: &'a mut [Transforms]) -> TransformIter<'a> { + TransformIter::Backwards(transforms.iter_mut().rev()) + } +} + +impl<'a> Iterator for TransformIter<'a> { + type Item = &'a mut Transforms; + + fn next(&mut self) -> Option { + match self { + TransformIter::Forwards(iter) => iter.next(), + TransformIter::Backwards(iter) => iter.next(), + } + } +} + /// [`Wrapper`] will not (cannot) bring the current list of transforms that it needs to traverse with it /// This is purely to make it convenient to clone all the data within Wrapper rather than it's transform /// state. @@ -470,7 +497,7 @@ impl<'a> Clone for Wrapper<'a> { fn clone(&self) -> Self { Wrapper { messages: self.messages.clone(), - transforms: [].iter_mut(), + transforms: TransformIter::new_forwards(&mut []), client_details: self.client_details.clone(), chain_name: self.chain_name.clone(), local_addr: self.local_addr, @@ -540,7 +567,7 @@ impl<'a> Wrapper<'a> { pub fn new(m: Messages) -> Self { Wrapper { messages: m, - transforms: [].iter_mut(), + transforms: TransformIter::new_forwards(&mut []), client_details: "".to_string(), local_addr: "127.0.0.1:8000".parse().unwrap(), chain_name: "".to_string(), @@ -551,7 +578,7 @@ impl<'a> Wrapper<'a> { pub fn new_with_chain_name(m: Messages, chain_name: String, local_addr: SocketAddr) -> Self { Wrapper { messages: m, - transforms: [].iter_mut(), + transforms: TransformIter::new_forwards(&mut []), client_details: "".to_string(), local_addr, chain_name, @@ -562,7 +589,7 @@ impl<'a> Wrapper<'a> { pub fn flush_with_chain_name(chain_name: String) -> Self { Wrapper { messages: vec![], - transforms: [].iter_mut(), + transforms: TransformIter::new_forwards(&mut []), client_details: "".into(), // The connection is closed so we need to just fake an address here local_addr: "127.0.0.1:10000".parse().unwrap(), @@ -579,7 +606,7 @@ impl<'a> Wrapper<'a> { ) -> Self { Wrapper { messages: m, - transforms: [].iter_mut(), + transforms: TransformIter::new_forwards(&mut []), client_details, local_addr, chain_name, @@ -588,7 +615,11 @@ impl<'a> Wrapper<'a> { } pub fn reset(&mut self, transforms: &'a mut [Transforms]) { - self.transforms = transforms.iter_mut(); + self.transforms = TransformIter::new_forwards(transforms); + } + + pub fn reset_rev(&mut self, transforms: &'a mut [Transforms]) { + self.transforms = TransformIter::new_backwards(transforms); } }