From 3fcc6a00c663dfab6ea7a196f56d689eda5990d2 Mon Sep 17 00:00:00 2001
From: Stan Bondi
Date: Mon, 12 Sep 2022 16:41:30 +0400
Subject: [PATCH] fix(comms/messaging): fix possible deadlock in outbound pipeline (#4657)

Description
---
- Fixes a possible (rare) deadlock when broadcasting many messages, caused by the bounded internal channel between the outbound pipeline and outbound messaging
- Reduces the number of inbound and outbound pipeline workers
- Greatly reduces the buffer size between inbound messaging and the inbound pipeline so that substream backpressure is applied sooner
- Adds a "last resort" timeout to the inbound and outbound pipelines

Motivation and Context
---
The outbound pipeline could deadlock when all pipeline workers are busy and the outbound sink service is full: the pipeline then waits both for a free executor slot and for a free slot to send on the channel, and neither can become available.

How Has This Been Tested?
---
Memorynet; manually: wallet stress tests (2 x wallets, 2 x base nodes), checked SAF message exchange
---
 base_layer/p2p/src/config.rs                  |  7 +-
 base_layer/p2p/src/initialization.rs          |  4 +-
 base_layer/wallet/tests/contacts_service.rs   |  1 -
 base_layer/wallet/tests/wallet.rs             |  2 -
 base_layer/wallet_ffi/src/lib.rs              |  1 -
 common/config/presets/c_base_node.toml        |  7 +-
 common/config/presets/d_console_wallet.toml   |  7 +-
 comms/core/src/pipeline/builder.rs            | 16 +---
 comms/core/src/pipeline/inbound.rs            | 22 +++++-
 comms/core/src/pipeline/outbound.rs           | 31 +++++---
 comms/core/src/pipeline/sink.rs               | 21 ++++-
 .../core/src/protocol/messaging/extension.rs  |  6 +-
 comms/core/src/protocol/messaging/outbound.rs | 76 +++++++------
 comms/core/src/protocol/messaging/protocol.rs | 24 +++---
 comms/dht/examples/memory_net/utilities.rs    |  1 -
 comms/dht/src/dht.rs                          |  5 +-
 comms/dht/src/inbound/dht_handler/task.rs     |  4 +-
 comms/dht/src/inbound/forward.rs              | 11 ++-
 comms/dht/src/outbound/broadcast.rs           |  4 -
 comms/dht/src/outbound/error.rs               |  2 -
 comms/dht/src/outbound/mock.rs                | 19 ++---
 comms/dht/src/outbound/requester.rs           | 29 +++++++
 .../dht/src/store_forward/saf_handler/task.rs |  4 +-
 comms/dht/tests/dht.rs                        |  1 -
 24 files changed, 166 insertions(+), 139 deletions(-)

diff --git a/base_layer/p2p/src/config.rs b/base_layer/p2p/src/config.rs
index 9d880bafa1..b3222d8ac8 100644
--- a/base_layer/p2p/src/config.rs
+++ b/base_layer/p2p/src/config.rs
@@ -95,8 +95,6 @@ pub struct P2pConfig {
     /// The maximum number of concurrent outbound tasks allowed before back-pressure is applied to outbound messaging
     /// queue
     pub max_concurrent_outbound_tasks: usize,
-    /// The size of the buffer (channel) which holds pending outbound message requests
-    pub outbound_buffer_size: usize,
     /// Configuration for DHT
     pub dht: DhtConfig,
     /// Set to true to allow peers to provide test addresses (loopback, memory etc.).
If set to false, memory @@ -131,9 +129,8 @@ impl Default for P2pConfig { transport: Default::default(), datastore_path: PathBuf::from("peer_db"), peer_database_name: "peers".to_string(), - max_concurrent_inbound_tasks: 50, - max_concurrent_outbound_tasks: 100, - outbound_buffer_size: 100, + max_concurrent_inbound_tasks: 4, + max_concurrent_outbound_tasks: 4, dht: DhtConfig { database_url: DbConnectionUrl::file("dht.sqlite"), ..Default::default() diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9edcbddc56..0aac3467aa 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -186,7 +186,6 @@ pub async fn initialize_local_test_comms>( let dht_outbound_layer = dht.outbound_middleware_layer(); let (event_sender, _) = broadcast::channel(100); let pipeline = pipeline::Builder::new() - .outbound_buffer_size(10) .with_outbound_pipeline(outbound_rx, |sink| { ServiceBuilder::new().layer(dht_outbound_layer).service(sink) }) @@ -333,7 +332,7 @@ async fn configure_comms_and_dht( let node_identity = comms.node_identity(); let shutdown_signal = comms.shutdown_signal(); // Create outbound channel - let (outbound_tx, outbound_rx) = mpsc::channel(config.outbound_buffer_size); + let (outbound_tx, outbound_rx) = mpsc::channel(config.dht.outbound_buffer_size); let mut dht = Dht::builder(); dht.with_config(config.dht.clone()).with_outbound_sender(outbound_tx); @@ -350,7 +349,6 @@ async fn configure_comms_and_dht( // Hook up DHT messaging middlewares let messaging_pipeline = pipeline::Builder::new() - .outbound_buffer_size(config.outbound_buffer_size) .with_outbound_pipeline(outbound_rx, |sink| { ServiceBuilder::new().layer(dht_outbound_layer).service(sink) }) diff --git a/base_layer/wallet/tests/contacts_service.rs b/base_layer/wallet/tests/contacts_service.rs index 62520c7471..e31f5e5cd4 100644 --- a/base_layer/wallet/tests/contacts_service.rs +++ b/base_layer/wallet/tests/contacts_service.rs @@ -83,7 +83,6 @@ pub fn setup_contacts_service( peer_database_name: random::string(8), max_concurrent_inbound_tasks: 10, max_concurrent_outbound_tasks: 10, - outbound_buffer_size: 100, dht: DhtConfig { discovery_request_timeout: Duration::from_secs(1), auto_join: true, diff --git a/base_layer/wallet/tests/wallet.rs b/base_layer/wallet/tests/wallet.rs index 9206f435fe..a0cae8e830 100644 --- a/base_layer/wallet/tests/wallet.rs +++ b/base_layer/wallet/tests/wallet.rs @@ -129,7 +129,6 @@ async fn create_wallet( peer_database_name: random::string(8), max_concurrent_inbound_tasks: 10, max_concurrent_outbound_tasks: 10, - outbound_buffer_size: 100, dht: DhtConfig { discovery_request_timeout: Duration::from_secs(1), auto_join: true, @@ -672,7 +671,6 @@ async fn test_import_utxo() { peer_database_name: random::string(8), max_concurrent_inbound_tasks: 10, max_concurrent_outbound_tasks: 10, - outbound_buffer_size: 10, dht: Default::default(), allow_test_addresses: true, listener_liveness_allowlist_cidrs: StringList::new(), diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index aa85a2e18c..73c2901e9c 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -3899,7 +3899,6 @@ pub unsafe extern "C" fn comms_config_create( peer_database_name: database_name_string, max_concurrent_inbound_tasks: 25, max_concurrent_outbound_tasks: 50, - outbound_buffer_size: 50, dht: DhtConfig { discovery_request_timeout: Duration::from_secs(discovery_timeout_in_secs), database_url: 
DbConnectionUrl::File(dht_database_path), diff --git a/common/config/presets/c_base_node.toml b/common/config/presets/c_base_node.toml index f497013be0..8b72c4a989 100644 --- a/common/config/presets/c_base_node.toml +++ b/common/config/presets/c_base_node.toml @@ -157,13 +157,10 @@ track_reorgs = true #peer_database_name = "peers" # The maximum number of concurrent Inbound tasks allowed before back-pressure is applied to peers -#max_concurrent_inbound_tasks = 50 +#max_concurrent_inbound_tasks = 4 # The maximum number of concurrent outbound tasks allowed before back-pressure is applied to outbound messaging queue -#max_concurrent_outbound_tasks = 100 - -# The size of the buffer (channel) which holds pending outbound message requests -#outbound_buffer_size = 100 +#max_concurrent_outbound_tasks = 4 # Set to true to allow peers to provide test addresses (loopback, memory etc.). If set to false, memory # addresses, loopback, local-link (i.e addresses used in local tests) will not be accepted from peers. This diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index beb6ee206c..a44929a546 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -189,13 +189,10 @@ event_channel_size = 3500 #peer_database_name = "peers" # The maximum number of concurrent Inbound tasks allowed before back-pressure is applied to peers -#max_concurrent_inbound_tasks = 50 +#max_concurrent_inbound_tasks = 4 # The maximum number of concurrent outbound tasks allowed before back-pressure is applied to outbound messaging queue -#max_concurrent_outbound_tasks = 100 - -# The size of the buffer (channel) which holds pending outbound message requests -#outbound_buffer_size = 100 +#max_concurrent_outbound_tasks = 4 # Set to true to allow peers to provide test addresses (loopback, memory etc.). If set to false, memory # addresses, loopback, local-link (i.e addresses used in local tests) will not be accepted from peers. 
This
diff --git a/comms/core/src/pipeline/builder.rs b/comms/core/src/pipeline/builder.rs
index 2aa88da405..b4d6a438b8 100644
--- a/comms/core/src/pipeline/builder.rs
+++ b/comms/core/src/pipeline/builder.rs
@@ -30,16 +30,14 @@ use crate::{
 };
 
 const DEFAULT_MAX_CONCURRENT_TASKS: usize = 50;
-const DEFAULT_OUTBOUND_BUFFER_SIZE: usize = 50;
 
-type OutboundMessageSinkService = SinkService<mpsc::Sender<OutboundMessage>>;
+type OutboundMessageSinkService = SinkService<mpsc::UnboundedSender<OutboundMessage>>;
 
 /// Message pipeline builder
 #[derive(Default)]
 pub struct Builder<TInSvc, TOutSvc, TOutMsg> {
     max_concurrent_inbound_tasks: usize,
     max_concurrent_outbound_tasks: Option<usize>,
-    outbound_buffer_size: usize,
     inbound: Option<TInSvc>,
     outbound_rx: Option<mpsc::Receiver<TOutMsg>>,
     outbound_pipeline_factory: Option<Box<dyn FnOnce(OutboundMessageSinkService) -> TOutSvc>>,
 }
@@ -50,7 +48,6 @@ impl Builder<(), (), ()> {
         Self {
             max_concurrent_inbound_tasks: DEFAULT_MAX_CONCURRENT_TASKS,
             max_concurrent_outbound_tasks: None,
-            outbound_buffer_size: DEFAULT_OUTBOUND_BUFFER_SIZE,
             inbound: None,
             outbound_rx: None,
             outbound_pipeline_factory: None,
@@ -69,11 +66,6 @@ impl<TInSvc, TOutSvc, TOutMsg> Builder<TInSvc, TOutSvc, TOutMsg> {
         self
     }
 
-    pub fn outbound_buffer_size(mut self, buf_size: usize) -> Self {
-        self.outbound_buffer_size = buf_size;
-        self
-    }
-
     pub fn with_outbound_pipeline<F, TNewOutSvc, TNewOutMsg>(
         self,
         receiver: mpsc::Receiver<TNewOutMsg>,
         factory: F,
     ) -> Builder<TInSvc, TNewOutSvc, TNewOutMsg>
     where
         // Factory function takes in a SinkService and returns a new composed service
         F: FnOnce(OutboundMessageSinkService) -> TNewOutSvc,
     {
@@ -87,7 +79,6 @@ impl Builder {
         Builder {
             outbound_rx: Some(receiver),
             outbound_pipeline_factory: Some(Box::new(factory)),
             max_concurrent_inbound_tasks: self.max_concurrent_inbound_tasks,
             max_concurrent_outbound_tasks: self.max_concurrent_outbound_tasks,
             inbound: self.inbound,
-            outbound_buffer_size: self.outbound_buffer_size,
         }
     }
@@ -100,7 +91,6 @@ impl Builder {
             max_concurrent_outbound_tasks: self.max_concurrent_outbound_tasks,
             outbound_rx: self.outbound_rx,
             outbound_pipeline_factory: self.outbound_pipeline_factory,
-            outbound_buffer_size: self.outbound_buffer_size,
         }
     }
 }
@@ -111,7 +101,7 @@ where TInSvc: Service + Clone + Send + 'static,
 {
     fn build_outbound(&mut self) -> Result<OutboundPipelineConfig<TOutMsg, TOutSvc>, PipelineBuilderError> {
-        let (out_sender, out_receiver) = mpsc::channel(self.outbound_buffer_size);
+        let (out_sender, out_receiver) = mpsc::unbounded_channel();
 
         let in_receiver = self
             .outbound_rx
@@ -157,7 +147,7 @@ pub struct OutboundPipelineConfig {
     /// Messages read from this stream are passed to the pipeline
     pub in_receiver: mpsc::Receiver<TInItem>,
     /// Receiver of `OutboundMessage`s coming from the pipeline
-    pub out_receiver: mpsc::Receiver<OutboundMessage>,
+    pub out_receiver: mpsc::UnboundedReceiver<OutboundMessage>,
     /// The pipeline (`tower::Service`) to run for each in_stream message
     pub pipeline: TPipeline,
 }
diff --git a/comms/core/src/pipeline/inbound.rs b/comms/core/src/pipeline/inbound.rs
index f77d5f66bb..7c6e89dab4 100644
--- a/comms/core/src/pipeline/inbound.rs
+++ b/comms/core/src/pipeline/inbound.rs
@@ -20,12 +20,15 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
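Note on the builder change above: the channel between the outbound pipeline and the final sink is now unbounded, so a pipeline task can never block while handing its result to the sink. A minimal sketch of the deadlock shape this removes; the types and counts are illustrative, not the actual comms pipeline, and only standard tokio::sync::mpsc semantics are assumed:

use tokio::sync::mpsc;

// With a bounded channel, `send(..).await` parks until capacity frees up.
// If every pipeline worker is parked in that send while the consumer is
// itself waiting for a free worker slot, nothing drains the channel and
// both sides wait forever. `UnboundedSender::send` never waits, which
// breaks that cycle (trading it for unbounded memory growth if the sink
// stalls -- hence the last-resort timeouts added below).
#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<u32>();

    for i in 0..100 {
        let tx = tx.clone();
        tokio::spawn(async move {
            // Never blocks; errors only if the receiver was dropped.
            let _ = tx.send(i);
        });
    }
    drop(tx);

    let mut received = 0;
    while rx.recv().await.is_some() {
        received += 1;
    }
    assert_eq!(received, 100);
}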
-use std::{fmt::Display, time::Instant}; +use std::{ + fmt::Display, + time::{Duration, Instant}, +}; use futures::future::FusedFuture; use log::*; use tari_shutdown::ShutdownSignal; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, time}; use tower::{Service, ServiceExt}; use crate::bounded_executor::BoundedExecutor; @@ -103,8 +106,19 @@ where .spawn(async move { let timer = Instant::now(); trace!(target: LOG_TARGET, "Start inbound pipeline {}", id); - if let Err(err) = service.oneshot(item).await { - warn!(target: LOG_TARGET, "Inbound pipeline returned an error: '{}'", err); + match time::timeout(Duration::from_secs(30), service.oneshot(item)).await { + Ok(Ok(_)) => {}, + Ok(Err(err)) => { + warn!(target: LOG_TARGET, "Inbound pipeline returned an error: '{}'", err); + }, + Err(_) => { + error!( + target: LOG_TARGET, + "Inbound pipeline {} timed out and was aborted. THIS SHOULD NOT HAPPEN: there was a \ + deadlock or excessive delay in processing this pipeline.", + id + ); + }, } trace!( target: LOG_TARGET, diff --git a/comms/core/src/pipeline/outbound.rs b/comms/core/src/pipeline/outbound.rs index 6f2dc115b3..e25692d328 100644 --- a/comms/core/src/pipeline/outbound.rs +++ b/comms/core/src/pipeline/outbound.rs @@ -20,11 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{fmt::Display, time::Instant}; +use std::{ + fmt::Display, + time::{Duration, Instant}, +}; use futures::future::Either; use log::*; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, time}; use tower::{Service, ServiceExt}; use crate::{ @@ -93,16 +96,26 @@ where let pipeline = self.config.pipeline.clone(); let id = current_id; current_id = (current_id + 1) % u64::MAX; - self.executor .spawn(async move { let timer = Instant::now(); trace!(target: LOG_TARGET, "Start outbound pipeline {}", id); - if let Err(err) = pipeline.oneshot(msg).await { - error!( - target: LOG_TARGET, - "Outbound pipeline {} returned an error: '{}'", id, err - ); + match time::timeout(Duration::from_secs(30), pipeline.oneshot(msg)).await { + Ok(Ok(_)) => {}, + Ok(Err(err)) => { + error!( + target: LOG_TARGET, + "Outbound pipeline {} returned an error: '{}'", id, err + ); + }, + Err(_) => { + error!( + target: LOG_TARGET, + "Outbound pipeline {} timed out and was aborted. THIS SHOULD NOT HAPPEN: \ + there was a deadlock or excessive delay in processing this pipeline.", + id + ); + }, } trace!( @@ -174,7 +187,7 @@ mod test { ) .await .unwrap(); - let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); + let (out_tx, out_rx) = mpsc::unbounded_channel(); let (msg_tx, mut msg_rx) = mpsc::channel(NUM_ITEMS); let executor = Handle::current(); diff --git a/comms/core/src/pipeline/sink.rs b/comms/core/src/pipeline/sink.rs index df7fe3cdb5..376792fd12 100644 --- a/comms/core/src/pipeline/sink.rs +++ b/comms/core/src/pipeline/sink.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
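Note on the changes above: both the inbound and outbound pipelines now wrap each spawned task in tokio::time::timeout as a last-resort guard. A standalone sketch of the pattern, assuming only the standard tokio API (the helper name and error strings are illustrative; the 30-second duration mirrors the patch):

use std::time::Duration;
use tokio::time;

// Wrap a pipeline future in a deadline. On timeout the future is dropped,
// which cancels it at its next await point, so a stuck task frees its
// bounded-executor slot instead of pinning it forever.
async fn run_with_deadline<F, T, E>(fut: F) -> Result<(), String>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    match time::timeout(Duration::from_secs(30), fut).await {
        Ok(Ok(_)) => Ok(()),
        Ok(Err(err)) => Err(format!("pipeline returned an error: '{}'", err)),
        Err(_elapsed) => Err("pipeline timed out and was aborted".to_string()),
    }
}

Dropping the future is what makes this a safety net rather than a fix: a genuinely deadlocked task is abandoned and loudly logged, rather than silently holding a worker slot.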
-use std::task::Poll;
+use std::{future, task::Poll};
 
 use futures::{future::BoxFuture, task::Context, FutureExt};
 use tower::Service;
@@ -59,3 +59,22 @@ where T: Send + 'static
         .boxed()
     }
 }
+
+impl<T> Service<T> for SinkService<mpsc::UnboundedSender<T>>
+where T: Send + 'static
+{
+    type Error = PipelineError;
+    type Future = future::Ready<Result<Self::Response, Self::Error>>;
+    type Response = ();
+
+    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, item: T) -> Self::Future {
+        let sink = self.0.clone();
+        let result = sink
+            .send(item)
+            .map_err(|_| anyhow::anyhow!("sink closed in sink service"));
+        future::ready(result)
+    }
+}
diff --git a/comms/core/src/protocol/messaging/extension.rs b/comms/core/src/protocol/messaging/extension.rs
index eabbc99800..c7d1eb4c68 100644
--- a/comms/core/src/protocol/messaging/extension.rs
+++ b/comms/core/src/protocol/messaging/extension.rs
@@ -39,9 +39,9 @@ use crate::{
     runtime::task,
 };
 
-/// Buffer size for inbound messages from _all_ peers. This should be large enough to buffer quite a few incoming
-/// messages before creating backpressure on peers speaking the messaging protocol.
-pub const INBOUND_MESSAGE_BUFFER_SIZE: usize = 100;
+/// Buffer size for inbound messages from _all_ peers. If the message consumer is slow to get through this queue,
+/// sending peers will start to experience backpressure (this is a good thing).
+pub const INBOUND_MESSAGE_BUFFER_SIZE: usize = 10;
 /// Buffer size notifications that a peer wants to speak /tari/messaging. This buffer is used for all peers, but a low
 /// value is ok because this events happen once (or less) per connecting peer. For e.g. a value of 10 would allow 10
 /// peers to concurrently request to speak /tari/messaging.
diff --git a/comms/core/src/protocol/messaging/outbound.rs b/comms/core/src/protocol/messaging/outbound.rs
index f67ab63581..6f21e346b5 100644
--- a/comms/core/src/protocol/messaging/outbound.rs
+++ b/comms/core/src/protocol/messaging/outbound.rs
@@ -163,37 +163,28 @@ impl OutboundMessaging {
     }
 
     async fn try_dial_peer(&mut self) -> Result<PeerConnection, MessagingProtocolError> {
-        let span = span!(
-            Level::DEBUG,
-            "dial_peer",
-            node_id = self.peer_node_id.to_string().as_str()
-        );
-        async move {
-            loop {
-                match self.connectivity.dial_peer(self.peer_node_id.clone()).await {
-                    Ok(conn) => break Ok(conn),
-                    Err(ConnectivityError::DialCancelled) => {
-                        debug!(
-                            target: LOG_TARGET,
-                            "Dial was cancelled for peer '{}'. This is probably because of connection tie-breaking. \
-                             Retrying...",
-                            self.peer_node_id,
-                        );
-                        continue;
-                    },
-                    Err(err) => {
-                        debug!(
-                            target: LOG_TARGET,
-                            "MessagingProtocol failed to dial peer '{}' because '{:?}'", self.peer_node_id, err
-                        );
+        loop {
+            match self.connectivity.dial_peer(self.peer_node_id.clone()).await {
+                Ok(conn) => break Ok(conn),
+                Err(ConnectivityError::DialCancelled) => {
+                    debug!(
+                        target: LOG_TARGET,
+                        "Dial was cancelled for peer '{}'. This is probably because of connection tie-breaking.
\ + Retrying...", + self.peer_node_id, + ); + continue; + }, + Err(err) => { + debug!( + target: LOG_TARGET, + "MessagingProtocol failed to dial peer '{}' because '{:?}'", self.peer_node_id, err + ); - break Err(MessagingProtocolError::PeerDialFailed(err)); - }, - } + break Err(MessagingProtocolError::PeerDialFailed(err)); + }, } } - .instrument(span) - .await } async fn try_establish( @@ -232,27 +223,16 @@ impl OutboundMessaging { &mut self, conn: &mut PeerConnection, ) -> Result, MessagingProtocolError> { - let span = span!( - Level::DEBUG, - "open_substream", - node_id = self.peer_node_id.to_string().as_str() - ); - async move { - match conn.open_substream(&MESSAGING_PROTOCOL).await { - Ok(substream) => Ok(substream), - Err(err) => { - debug!( - target: LOG_TARGET, - "MessagingProtocol failed to open a substream to peer '{}' because '{}'", - self.peer_node_id, - err - ); - Err(err.into()) - }, - } + match conn.open_substream(&MESSAGING_PROTOCOL).await { + Ok(substream) => Ok(substream), + Err(err) => { + debug!( + target: LOG_TARGET, + "MessagingProtocol failed to open a substream to peer '{}' because '{}'", self.peer_node_id, err + ); + Err(err.into()) + }, } - .instrument(span) - .await } async fn start_forwarding_messages( diff --git a/comms/core/src/protocol/messaging/protocol.rs b/comms/core/src/protocol/messaging/protocol.rs index 3d02b055ff..0e383ae9c5 100644 --- a/comms/core/src/protocol/messaging/protocol.rs +++ b/comms/core/src/protocol/messaging/protocol.rs @@ -54,7 +54,7 @@ use crate::{ const LOG_TARGET: &str = "comms::protocol::messaging"; pub(super) static MESSAGING_PROTOCOL: Bytes = Bytes::from_static(b"t/msg/0.1"); -const INTERNAL_MESSAGING_EVENT_CHANNEL_SIZE: usize = 150; +const INTERNAL_MESSAGING_EVENT_CHANNEL_SIZE: usize = 10; /// The maximum amount of inbound messages to accept within the `RATE_LIMIT_RESTOCK_INTERVAL` window const RATE_LIMIT_CAPACITY: usize = 10; @@ -163,11 +163,11 @@ impl MessagingProtocol { loop { tokio::select! 
{ Some(event) = self.internal_messaging_event_rx.recv() => { - self.handle_internal_messaging_event(event).await; + self.handle_internal_messaging_event(event); }, Some(msg) = self.retry_queue_rx.recv() => { - if let Err(err) = self.handle_retry_queue_messages(msg).await { + if let Err(err) = self.handle_retry_queue_messages(msg) { error!( target: LOG_TARGET, "Failed to retry outbound message because '{}'", @@ -177,7 +177,7 @@ impl MessagingProtocol { }, Some(req) = self.request_rx.recv() => { - if let Err(err) = self.handle_request(req).await { + if let Err(err) = self.handle_request(req) { error!( target: LOG_TARGET, "Failed to handle request because '{}'", @@ -187,7 +187,7 @@ impl MessagingProtocol { }, Some(notification) = self.proto_notification.recv() => { - self.handle_protocol_notification(notification).await; + self.handle_protocol_notification(notification); }, _ = &mut shutdown_signal => { @@ -204,7 +204,7 @@ impl MessagingProtocol { framing::canonical(socket, MAX_FRAME_LENGTH) } - async fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { + fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { use MessagingEvent::OutboundProtocolExited; trace!(target: LOG_TARGET, "Internal messaging event '{}'", event); match event { @@ -231,26 +231,26 @@ impl MessagingProtocol { } } - async fn handle_request(&mut self, req: MessagingRequest) -> Result<(), MessagingProtocolError> { + fn handle_request(&mut self, req: MessagingRequest) -> Result<(), MessagingProtocolError> { use MessagingRequest::SendMessage; match req { SendMessage(msg) => { trace!(target: LOG_TARGET, "Received request to send message ({})", msg); - self.send_message(msg).await?; + self.send_message(msg)?; }, } Ok(()) } - async fn handle_retry_queue_messages(&mut self, msg: OutboundMessage) -> Result<(), MessagingProtocolError> { + fn handle_retry_queue_messages(&mut self, msg: OutboundMessage) -> Result<(), MessagingProtocolError> { debug!(target: LOG_TARGET, "Retrying outbound message ({})", msg); - self.send_message(msg).await?; + self.send_message(msg)?; Ok(()) } // #[tracing::instrument(skip(self, out_msg), err)] - async fn send_message(&mut self, out_msg: OutboundMessage) -> Result<(), MessagingProtocolError> { + fn send_message(&mut self, out_msg: OutboundMessage) -> Result<(), MessagingProtocolError> { let peer_node_id = out_msg.peer_node_id.clone(); let sender = loop { match self.active_queues.entry(peer_node_id.clone()) { @@ -315,7 +315,7 @@ impl MessagingProtocol { task::spawn(inbound_messaging.run(substream)); } - async fn handle_protocol_notification(&mut self, notification: ProtocolNotification) { + fn handle_protocol_notification(&mut self, notification: ProtocolNotification) { match notification.event { // Peer negotiated to speak the messaging protocol with us ProtocolEvent::NewInboundSubstream(node_id, substream) => { diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index f9596e16ab..7f33285172 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -949,7 +949,6 @@ async fn setup_comms_dht( let dht_outbound_layer = dht.outbound_middleware_layer(); let pipeline = pipeline::Builder::new() - .outbound_buffer_size(10) .with_outbound_pipeline(outbound_rx, |sink| { ServiceBuilder::new().layer(dht_outbound_layer).service(sink) }) diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index c70db16c34..f9a00d4387 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ 
-615,7 +615,10 @@ mod test { service.call(inbound_message).await.unwrap(); - assert_eq!(oms_mock_state.call_count().await, 1); + oms_mock_state + .wait_call_count(1, Duration::from_secs(10)) + .await + .unwrap(); let (params, _) = oms_mock_state.pop_call().await.unwrap(); // Check that OMS got a request to forward with the original Dht Header diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index e20dc71a24..c4c0e52f84 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -225,7 +225,7 @@ where S: Service ); // Propagate message to closer peers self.outbound_service - .send_raw( + .send_raw_no_wait( SendMessageParams::new() .propagate(origin_public_key.clone().into(), vec![ origin_peer.node_id, @@ -349,7 +349,7 @@ where S: Service trace!(target: LOG_TARGET, "Sending discovery response to {}", dest_public_key); self.outbound_service - .send_message_no_header( + .send_message_no_header_no_wait( SendMessageParams::new() .direct_public_key(dest_public_key) .with_destination(NodeDestination::Unknown) diff --git a/comms/dht/src/inbound/forward.rs b/comms/dht/src/inbound/forward.rs index ddc7aab54e..2bb455b67e 100644 --- a/comms/dht/src/inbound/forward.rs +++ b/comms/dht/src/inbound/forward.rs @@ -237,7 +237,9 @@ where S: Service if !is_already_forwarded { send_params.with_dht_header(dht_header.clone()); - self.outbound_service.send_raw(send_params.finish(), body).await?; + self.outbound_service + .send_raw_no_wait(send_params.finish(), body) + .await?; } Ok(()) @@ -254,6 +256,8 @@ where S: Service #[cfg(test)] mod test { + use std::time::Duration; + use tari_comms::{runtime, runtime::task, wrap_in_envelope_body}; use tokio::sync::mpsc; @@ -306,7 +310,10 @@ mod test { service.call(msg).await.unwrap(); assert!(spy.is_called()); - assert_eq!(oms_mock_state.call_count().await, 1); + oms_mock_state + .wait_call_count(1, Duration::from_secs(10)) + .await + .unwrap(); let (params, body) = oms_mock_state.pop_call().await.unwrap(); // Header and body are preserved when forwarding diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 71d079029d..51c8dc37ab 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -267,10 +267,6 @@ where S: Service match self.select_peers(broadcast_strategy.clone()).await { Ok(mut peers) => { - if reply_tx.is_closed() { - return Err(DhtOutboundError::ReplyChannelCanceled); - } - let mut reply_tx = Some(reply_tx); trace!( diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 0759f7e7ce..e8ee3fcc34 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -47,8 +47,6 @@ pub enum DhtOutboundError { RequesterReplyChannelClosed, #[error("Peer selection failed")] PeerSelectionFailed, - #[error("Reply channel cancelled")] - ReplyChannelCanceled, #[error("Attempted to send a message to ourselves")] SendToOurselves, #[error("Discovery process failed")] diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index d56d26de24..7d7b58d926 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -184,34 +184,31 @@ impl OutboundServiceMock { match behaviour.direct { ResponseType::Queued => { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body).await; - reply_tx.send(response).expect("Reply channel cancelled"); + let _ignore = reply_tx.send(response); inner_reply_tx.reply_success(); }, 
ResponseType::QueuedFail => {
                            let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body).await;
-                            reply_tx.send(response).expect("Reply channel cancelled");
+                            let _ignore = reply_tx.send(response);
                            inner_reply_tx.reply_fail(SendFailReason::PeerDialFailed);
                        },
                        ResponseType::QueuedSuccessDelay(delay) => {
                            let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body).await;
-                            reply_tx.send(response).expect("Reply channel cancelled");
+                            let _ignore = reply_tx.send(response);
                            sleep(delay).await;
                            inner_reply_tx.reply_success();
                        },
                        resp => {
-                            reply_tx
-                                .send(SendMessageResponse::Failed(SendFailure::General(format!(
-                                    "Unexpected mock response {:?}",
-                                    resp
-                                ))))
-                                .expect("Reply channel cancelled");
+                            let _ignore = reply_tx.send(SendMessageResponse::Failed(SendFailure::General(
+                                format!("Unexpected mock response {:?}", resp),
+                            )));
                        },
                    };
                },
                BroadcastStrategy::ClosestNodes(_) => {
                    if behaviour.broadcast == ResponseType::Queued {
                        let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body).await;
-                        reply_tx.send(response).expect("Reply channel cancelled");
+                        let _ignore = reply_tx.send(response);
                        inner_reply_tx.reply_success();
                    } else {
                        reply_tx
@@ -223,7 +220,7 @@ impl OutboundServiceMock {
                },
                _ => {
                    let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body).await;
-                    reply_tx.send(response).expect("Reply channel cancelled");
+                    let _ignore = reply_tx.send(response);
                    inner_reply_tx.reply_success();
                },
            }
diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs
index a3e4465483..945b64bc8b 100644
--- a/comms/dht/src/outbound/requester.rs
+++ b/comms/dht/src/outbound/requester.rs
@@ -269,6 +269,22 @@ impl OutboundMessageRequester {
         self.send_raw(params, body).await
     }
 
+    /// Send a message without a domain header part
+    pub async fn send_message_no_header_no_wait<T>(
+        &mut self,
+        params: FinalSendMessageParams,
+        message: T,
+    ) -> Result<(), DhtOutboundError>
+    where
+        T: prost::Message,
+    {
+        if cfg!(debug_assertions) {
+            trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message);
+        }
+        let body = wrap_in_envelope_body!(message).to_encoded_bytes();
+        self.send_raw_no_wait(params, body).await
+    }
+
     /// Send a raw message
     pub async fn send_raw(
         &mut self,
@@ -285,6 +301,19 @@ impl OutboundMessageRequester {
             .map_err(|_| DhtOutboundError::RequesterReplyChannelClosed)
     }
 
+    /// Send a raw message without waiting for a reply
+    pub async fn send_raw_no_wait(
+        &mut self,
+        params: FinalSendMessageParams,
+        body: Vec<u8>,
+    ) -> Result<(), DhtOutboundError> {
+        let (reply_tx, _) = oneshot::channel();
+        self.sender
+            .send(DhtOutboundRequest::SendMessage(Box::new(params), body.into(), reply_tx))
+            .await?;
+        Ok(())
+    }
+
     #[cfg(test)]
     pub fn get_mpsc_sender(&self) -> mpsc::Sender<DhtOutboundRequest> {
         self.sender.clone()
diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs
index 5ed85a8174..0aada15e4e 100644
--- a/comms/dht/src/store_forward/saf_handler/task.rs
+++ b/comms/dht/src/store_forward/saf_handler/task.rs
@@ -229,15 +229,13 @@ where S: Service
 
         match self
             .outbound_service
-            .send_message_no_header(
+            .send_message_no_header_no_wait(
                 SendMessageParams::new()
                     .direct_public_key(message.source_peer.public_key.clone())
                     .with_dht_message_type(DhtMessageType::SafStoredMessages)
                     .finish(),
                 stored_messages,
             )
-            .await?
- .resolve() .await { Ok(_) => { diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 60586f3251..9928c1df79 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -201,7 +201,6 @@ async fn setup_comms_dht( let dht_outbound_layer = dht.outbound_middleware_layer(); let pipeline = pipeline::Builder::new() - .outbound_buffer_size(10) .with_outbound_pipeline(outbound_rx, |sink| { ServiceBuilder::new().layer(dht_outbound_layer).service(sink) })
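Note on the no-wait senders added in comms/dht/src/outbound/requester.rs: they queue the request and drop the reply receiver immediately, so forwarding paths (DHT handlers, store-and-forward, propagation) no longer await a delivery result from inside the pipeline. A minimal sketch of the pattern, with a simplified request type (the real DhtOutboundRequest carries the send params and message body):

use tokio::sync::{mpsc, oneshot};

struct Request {
    // The handler replies here. With the receiver already dropped, its
    // `send` returns an Err that can simply be ignored, which is why the
    // mock above replaced `.expect("Reply channel cancelled")` with
    // `let _ignore = reply_tx.send(..)`.
    reply_tx: oneshot::Sender<&'static str>,
}

// Fire-and-forget: still applies backpressure on the request queue itself
// (`send(..).await` waits for mpsc capacity) but never waits on the
// outcome of the actual send.
async fn send_no_wait(requests: &mpsc::Sender<Request>) -> Result<(), &'static str> {
    let (reply_tx, _) = oneshot::channel(); // reply receiver dropped immediately
    requests
        .send(Request { reply_tx })
        .await
        .map_err(|_| "requester channel closed")
}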