diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index 5021e527847..97d1c24643d 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -78,7 +78,7 @@ impl BaseNodeBootstrapper<'_, B> where B: BlockchainBackend + 'static { pub async fn bootstrap(self) -> Result { - let base_node_config = &self.app_config.base_node; + let mut base_node_config = self.app_config.base_node.clone(); let mut p2p_config = self.app_config.base_node.p2p.clone(); let peer_seeds = &self.app_config.peer_seeds; @@ -95,6 +95,8 @@ where B: BlockchainBackend + 'static .collect::, _>>() .map_err(|e| ExitError::new(ExitCode::ConfigError, e))?; + base_node_config.state_machine.blockchain_sync_config.forced_sync_peers = sync_peers.clone(); + debug!(target: LOG_TARGET, "{} sync peer(s) configured", sync_peers.len()); let mempool_sync = MempoolSyncInitializer::new(mempool_config, self.mempool.clone()); diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index 436ebccd8fb..8ea2c04ce1f 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -35,7 +35,7 @@ use tari_comms::{ }; use tari_utilities::hex::Hex; use tokio::{ - sync::{mpsc, RwLock}, + sync::{mpsc, Mutex}, task, }; use tracing::{instrument, span, Instrument, Level}; @@ -65,7 +65,7 @@ const LOG_TARGET: &str = "c::base_node::sync_rpc"; pub struct BaseNodeSyncRpcService { db: AsyncBlockchainDb, - active_sessions: RwLock>>, + active_sessions: Mutex>>, base_node_service: LocalNodeCommsInterface, } @@ -73,7 +73,7 @@ impl BaseNodeSyncRpcService { pub fn new(db: AsyncBlockchainDb, base_node_service: LocalNodeCommsInterface) -> Self { Self { db, - active_sessions: RwLock::new(Vec::new()), + active_sessions: Mutex::new(Vec::new()), base_node_service, } } @@ -84,7 +84,7 @@ impl BaseNodeSyncRpcService { } pub async fn try_add_exclusive_session(&self, peer: NodeId) -> Result, RpcStatus> { - let mut lock = self.active_sessions.write().await; + let mut lock = self.active_sessions.lock().await; *lock = lock.drain(..).filter(|l| l.strong_count() > 0).collect(); debug!(target: LOG_TARGET, "Number of active sync sessions: {}", lock.len()); diff --git a/base_layer/p2p/src/services/liveness/state.rs b/base_layer/p2p/src/services/liveness/state.rs index 6cadbdcd102..0c2282811bc 100644 --- a/base_layer/p2p/src/services/liveness/state.rs +++ b/base_layer/p2p/src/services/liveness/state.rs @@ -173,9 +173,11 @@ impl LivenessState { let (node_id, _) = self.inflight_pings.get(&nonce)?; if node_id == sent_by { - self.inflight_pings - .remove(&nonce) - .map(|(node_id, sent_time)| self.add_latency_sample(node_id, sent_time.elapsed()).calc_average()) + self.inflight_pings.remove(&nonce).map(|(node_id, sent_time)| { + let latency = sent_time.elapsed(); + self.add_latency_sample(node_id, latency); + latency + }) } else { warn!( target: LOG_TARGET, diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 982595f0529..257905bf645 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -39,6 +39,7 @@ use std::{ use bytes::Bytes; use futures::{ + future, future::{BoxFuture, Either}, task::{Context, Poll}, FutureExt, @@ -491,7 +492,10 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId break; } } - None => break, + None => { + debug!(target: LOG_TARGET, 
"(stream={}) Request channel closed. Worker is terminating.", self.stream_id()); + break + }, } } } @@ -618,7 +622,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId ); } - let (response_tx, response_rx) = mpsc::channel(10); + let (response_tx, response_rx) = mpsc::channel(5); if let Err(mut rx) = reply.send(response_rx) { event!(Level::WARN, "Client request was cancelled after request was sent"); warn!( @@ -636,7 +640,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId if let Err(err) = self.send_request(req).await { warn!(target: LOG_TARGET, "{}", err); metrics::client_errors(&self.node_id, &self.protocol_id).inc(); - let _result = response_tx.send(Err(err.into())); + let _result = response_tx.send(Err(err.into())).await; return Ok(()); } @@ -654,7 +658,27 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId break; } - let resp = match self.read_response(request_id).await { + // Check if the response receiver has been dropped while receiving messages + let resp_result = { + let resp_fut = self.read_response(request_id); + tokio::pin!(resp_fut); + let closed_fut = response_tx.closed(); + tokio::pin!(closed_fut); + match future::select(resp_fut, closed_fut).await { + Either::Left((r, _)) => Some(r), + Either::Right(_) => None, + } + }; + let resp_result = match resp_result { + Some(r) => r, + None => { + self.premature_close(request_id, method).await?; + break; + }, + }; + + // let resp = match self.read_response(request_id).await { + let resp = match resp_result { Ok(resp) => { if let Some(t) = timer.take() { let _ = self.last_request_latency_tx.send(Some(t.elapsed())); @@ -682,14 +706,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId event!(Level::ERROR, "Response timed out"); metrics::client_timeouts(&self.node_id, &self.protocol_id).inc(); if response_tx.is_closed() { - let req = proto::rpc::RpcRequest { - request_id: u32::try_from(request_id).unwrap(), - method, - flags: RpcMessageFlags::FIN.bits().into(), - ..Default::default() - }; - - self.send_request(req).await?; + self.premature_close(request_id, method).await?; } else { let _result = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; } @@ -721,21 +738,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId // The consumer may drop the receiver before all responses are received. // We handle this by sending a 'FIN' message to the server. if response_tx.is_closed() { - warn!( - target: LOG_TARGET, - "(stream={}) Response receiver was dropped before the response/stream could complete for \ - protocol {}, interrupting the stream. ", - self.stream_id(), - self.protocol_name() - ); - let req = proto::rpc::RpcRequest { - request_id: u32::try_from(request_id).unwrap(), - method, - flags: RpcMessageFlags::FIN.bits().into(), - ..Default::default() - }; - - self.send_request(req).await?; + self.premature_close(request_id, method).await?; break; } else { let _result = response_tx.send(Ok(resp)).await; @@ -766,6 +769,29 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId Ok(()) } + async fn premature_close(&mut self, request_id: u16, method: u32) -> Result<(), RpcError> { + warn!( + target: LOG_TARGET, + "(stream={}) Response receiver was dropped before the response/stream could complete for protocol {}, \ + interrupting the stream. 
", + self.stream_id(), + self.protocol_name() + ); + let req = proto::rpc::RpcRequest { + request_id: u32::try_from(request_id).unwrap(), + method, + flags: RpcMessageFlags::FIN.bits().into(), + deadline: self.config.deadline.map(|d| d.as_secs()).unwrap_or(0), + ..Default::default() + }; + + // If we cannot set FIN quickly, just exit + if let Ok(res) = time::timeout(Duration::from_secs(2), self.send_request(req)).await { + res?; + } + Ok(()) + } + async fn send_request(&mut self, req: proto::rpc::RpcRequest) -> Result<(), RpcError> { let payload = req.to_encoded_bytes(); if payload.len() > rpc::max_request_size() { diff --git a/comms/core/src/protocol/rpc/server/early_close.rs b/comms/core/src/protocol/rpc/server/early_close.rs new file mode 100644 index 00000000000..82973bb8ef1 --- /dev/null +++ b/comms/core/src/protocol/rpc/server/early_close.rs @@ -0,0 +1,119 @@ +// Copyright 2022. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+use std::{
+    io,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+use futures::Sink;
+use tokio_stream::Stream;
+
+pub struct EarlyClose<TSock> {
+    inner: TSock,
+}
+
+impl<T, TSock: Stream<Item = io::Result<T>> + Unpin> EarlyClose<TSock> {
+    pub fn new(inner: TSock) -> Self {
+        Self { inner }
+    }
+}
+
+impl<TSock: Stream + Unpin> Stream for EarlyClose<TSock> {
+    type Item = TSock::Item;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        Pin::new(&mut self.inner).poll_next(cx)
+    }
+}
+
+impl<TItem, T, TSock> Sink<TItem> for EarlyClose<TSock>
+where TSock: Sink<TItem> + Stream<Item = io::Result<T>> + Unpin
+{
+    type Error = EarlyCloseError<T>;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_ready(cx) {
+            return Poll::Ready(r.map_err(Into::into));
+        }
+        check_for_early_close(&mut self.inner, cx)
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: TItem) -> Result<(), Self::Error> {
+        Pin::new(&mut self.inner).start_send(item)?;
+        Ok(())
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_flush(cx) {
+            return Poll::Ready(r.map_err(Into::into));
+        }
+        check_for_early_close(&mut self.inner, cx)
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_close(cx) {
+            return Poll::Ready(r.map_err(Into::into));
+        }
+        check_for_early_close(&mut self.inner, cx)
+    }
+}
+
+fn check_for_early_close<T, TSock: Stream<Item = io::Result<T>> + Unpin>(
+    sock: &mut TSock,
+    cx: &mut Context<'_>,
+) -> Poll<Result<(), EarlyCloseError<T>>> {
+    match Pin::new(sock).poll_next(cx) {
+        Poll::Ready(Some(Ok(msg))) => Poll::Ready(Err(EarlyCloseError::UnexpectedMessage(msg))),
+        Poll::Ready(Some(Err(err))) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
+        Poll::Pending => Poll::Pending,
+        Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err.into())),
+        Poll::Ready(None) => Poll::Ready(Err(
+            io::Error::new(io::ErrorKind::BrokenPipe, "Connection closed").into()
+        )),
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum EarlyCloseError<T> {
+    #[error(transparent)]
+    Io(#[from] io::Error),
+    #[error("Unexpected message")]
+    UnexpectedMessage(T),
+}
+
+impl<T> EarlyCloseError<T> {
+    pub fn io(&self) -> Option<&io::Error> {
+        match self {
+            Self::Io(err) => Some(err),
+            _ => None,
+        }
+    }
+
+    pub fn unexpected_message(&self) -> Option<&T> {
+        match self {
+            EarlyCloseError::UnexpectedMessage(msg) => Some(msg),
+            _ => None,
+        }
+    }
+}
diff --git a/comms/core/src/protocol/rpc/server/error.rs b/comms/core/src/protocol/rpc/server/error.rs
index 38f257b4238..ea3458b4e5e 100644
--- a/comms/core/src/protocol/rpc/server/error.rs
+++ b/comms/core/src/protocol/rpc/server/error.rs
@@ -22,10 +22,15 @@
 
 use std::io;
 
+use bytes::BytesMut;
 use prost::DecodeError;
 use tokio::sync::oneshot;
 
-use crate::{peer_manager::NodeId, proto, protocol::rpc::handshake::RpcHandshakeError};
+use crate::{
+    peer_manager::NodeId,
+    proto,
+    protocol::rpc::{handshake::RpcHandshakeError, server::early_close::EarlyCloseError},
+};
 
 #[derive(Debug, thiserror::Error)]
 pub enum RpcServerError {
@@ -55,6 +60,8 @@ pub enum RpcServerError {
     ServiceCallExceededDeadline,
     #[error("Stream read exceeded deadline")]
     ReadStreamExceededDeadline,
+    #[error("Early close error: {0}")]
+    EarlyCloseError(#[from] EarlyCloseError<BytesMut>),
 }
 
 impl From<oneshot::error::RecvError> for RpcServerError {
diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs
index 5dc56f354cd..0a6dfa2c85b 100644
--- a/comms/core/src/protocol/rpc/server/mod.rs
+++ b/comms/core/src/protocol/rpc/server/mod.rs
@@ -34,6 +34,7 @@
mod metrics; pub mod mock; +mod early_close; mod router; use std::{ @@ -50,6 +51,7 @@ use std::{ }; use futures::{future, stream, stream::FuturesUnordered, SinkExt, StreamExt}; +use log::*; use prost::Message; use router::Router; use tokio::{sync::mpsc, task::JoinHandle, time}; @@ -78,6 +80,7 @@ use crate::{ rpc::{ body::BodyBytes, message::{RpcMethod, RpcResponse}, + server::early_close::EarlyClose, }, ProtocolEvent, ProtocolId, @@ -89,7 +92,7 @@ use crate::{ Substream, }; -const LOG_TARGET: &str = "comms::rpc"; +const LOG_TARGET: &str = "comms::rpc::server"; pub trait NamedProtocolService { const PROTOCOL_NAME: &'static [u8]; @@ -323,18 +326,7 @@ where let _ = reply.send(num_active); }, GetNumActiveSessionsForPeer(node_id, reply) => { - let num_active = self - .sessions - .get(&node_id) - .map(|num_sessions| { - let max_sessions = self - .config - .maximum_sessions_per_client - .unwrap_or_else(BoundedExecutor::max_theoretical_tasks); - max_sessions.saturating_sub(*num_sessions) - }) - .unwrap_or(0); - + let num_active = self.sessions.get(&node_id).copied().unwrap_or(0); let _ = reply.send(num_active); }, } @@ -375,23 +367,23 @@ where } fn new_session_for(&mut self, node_id: NodeId) -> Result { + let count = self.sessions.entry(node_id.clone()).or_insert(0); match self.config.maximum_sessions_per_client { Some(max) if max > 0 => { - let count = self.sessions.entry(node_id.clone()).or_insert(0); - debug_assert!(*count <= max); if *count >= max { return Err(RpcServerError::MaxSessionsPerClientReached { node_id }); } - *count += 1; - Ok(*count) }, - Some(_) => Ok(0), - None => Ok(0), + Some(_) | None => {}, } + + *count += 1; + Ok(*count) } fn on_session_complete(&mut self, node_id: &NodeId) { + info!(target: LOG_TARGET, "Session complete for {}", node_id); if let Some(v) = self.sessions.get_mut(node_id) { *v -= 1; if *v == 0 { @@ -438,11 +430,20 @@ where }, }; - if let Err(err) = self.new_session_for(node_id.clone()) { - handshake - .reject_with_reason(HandshakeRejectReason::NoSessionsAvailable) - .await?; - return Err(err); + match self.new_session_for(node_id.clone()) { + Ok(num_sessions) => { + info!( + target: LOG_TARGET, + "NEW SESSION for {} ({} active) ", node_id, num_sessions + ); + }, + + Err(err) => { + handshake + .reject_with_reason(HandshakeRejectReason::NoSessionsAvailable) + .await?; + return Err(err); + }, } let version = handshake.perform_server_handshake().await?; @@ -467,7 +468,9 @@ where let num_sessions = metrics::num_sessions(&node_id, &service.protocol); num_sessions.inc(); service.start().await; + info!(target: LOG_TARGET, "END OF SESSION for {} ", node_id,); num_sessions.dec(); + node_id }) .map_err(|_| RpcServerError::MaximumSessionsReached)?; @@ -483,7 +486,7 @@ struct ActivePeerRpcService { protocol: ProtocolId, node_id: NodeId, service: TSvc, - framed: CanonicalFraming, + framed: EarlyClose>, comms_provider: TCommsProvider, logging_context_string: Arc, } @@ -513,7 +516,7 @@ where protocol, node_id, service, - framed, + framed: EarlyClose::new(framed), comms_provider, } } @@ -525,9 +528,17 @@ where ); if let Err(err) = self.run().await { metrics::error_counter(&self.node_id, &self.protocol, &err).inc(); - error!( + let level = match &err { + RpcServerError::Io(e) => err_to_log_level(&e), + RpcServerError::EarlyCloseError(e) => e.io().map(err_to_log_level).unwrap_or(log::Level::Error), + _ => log::Level::Error, + }; + log!( target: LOG_TARGET, - "({}) Rpc server exited with an error: {}", self.logging_context_string, err + level, + "({}) Rpc server exited with an 
error: {}", + self.logging_context_string, + err ); } } @@ -541,11 +552,14 @@ where request_bytes.observe(frame.len() as f64); if let Err(err) = self.handle_request(frame.freeze()).await { if let Err(err) = self.framed.close().await { - error!( + let level = err.io().map(err_to_log_level).unwrap_or(log::Level::Error); + + log!( target: LOG_TARGET, + level, "({}) Failed to close substream after socket error: {}", self.logging_context_string, - err + err, ); } error!( @@ -725,44 +739,50 @@ where .map(|resp| Bytes::from(resp.to_encoded_bytes())); loop { - // Check if the client interrupted the outgoing stream - if let Err(err) = self.check_interruptions().await { - match err { - err @ RpcServerError::ClientInterruptedStream => { - debug!(target: LOG_TARGET, "Stream was interrupted: {}", err); - break; - }, - err => { - error!(target: LOG_TARGET, "Stream was interrupted: {}", err); - return Err(err); - }, - } - } - let next_item = log_timing( self.logging_context_string.clone(), request_id, "message read", stream.next(), ); - match time::timeout(deadline, next_item).await { - Ok(Some(msg)) => { - response_bytes.observe(msg.len() as f64); - debug!( - target: LOG_TARGET, - "({}) Sending body len = {}", - self.logging_context_string, - msg.len() - ); + let timeout = time::sleep(deadline); - self.framed.send(msg).await?; + tokio::select! { + // Check if the client interrupted the outgoing stream + Err(err) = self.check_interruptions() => { + match err { + err @ RpcServerError::ClientInterruptedStream => { + debug!(target: LOG_TARGET, "Stream was interrupted by client: {}", err); + break; + }, + err => { + error!(target: LOG_TARGET, "Stream was interrupted: {}", err); + return Err(err); + }, + } }, - Ok(None) => { - debug!(target: LOG_TARGET, "{} Request complete", self.logging_context_string,); - break; + msg = next_item => { + match msg { + Some(msg) => { + response_bytes.observe(msg.len() as f64); + debug!( + target: LOG_TARGET, + "({}) Sending body len = {}", + self.logging_context_string, + msg.len() + ); + + self.framed.send(msg).await?; + }, + None => { + debug!(target: LOG_TARGET, "{} Request complete", self.logging_context_string,); + break; + }, + } }, - Err(_) => { - debug!( + + _ = timeout => { + debug!( target: LOG_TARGET, "({}) Failed to return result within client deadline ({:.0?})", self.logging_context_string, @@ -776,8 +796,8 @@ where ) .inc(); break; - }, - } + } + } // end select! 
} // end loop Ok(()) } @@ -833,11 +853,9 @@ async fn log_timing>(context_str: Arc, request_ ret } -#[allow(clippy::cognitive_complexity)] fn into_response(request_id: u32, result: Result) -> RpcResponse { match result { Ok(msg) => { - trace!(target: LOG_TARGET, "Sending body len = {}", msg.len()); let mut flags = RpcMessageFlags::empty(); if msg.is_finished() { flags |= RpcMessageFlags::FIN; @@ -860,3 +878,10 @@ fn into_response(request_id: u32, result: Result) -> RpcRe }, } } + +fn err_to_log_level(err: &io::Error) -> log::Level { + match err.kind() { + io::ErrorKind::BrokenPipe | io::ErrorKind::WriteZero => log::Level::Debug, + _ => log::Level::Error, + } +} diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs index 515ba4f41cb..6ebb3ea4661 100644 --- a/comms/core/src/protocol/rpc/test/smoke.rs +++ b/comms/core/src/protocol/rpc/test/smoke.rs @@ -551,7 +551,7 @@ async fn max_per_client_sessions() { let socket = inbound.incoming_mut().next().await.unwrap(); let framed = framing::canonical(socket, 1024); - let mut client = GreetingClient::builder() + let client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) .connect(framed) .await @@ -568,7 +568,6 @@ async fn max_per_client_sessions() { unpack_enum!(RpcError::HandshakeError(err) = err); unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err); - client.close().await; drop(client); let substream = outbound.get_yamux_control().open_stream().await.unwrap(); muxer diff --git a/comms/core/tests/greeting_service.rs b/comms/core/tests/greeting_service.rs index f06c738b51b..e455e00fde0 100644 --- a/comms/core/tests/greeting_service.rs +++ b/comms/core/tests/greeting_service.rs @@ -107,6 +107,7 @@ impl GreetingRpc for GreetingService { id, item_size, num_items, + delay_ms: delay_secs, } = request.into_message(); let (tx, rx) = mpsc::channel(10); let t = std::time::Instant::now(); @@ -118,7 +119,20 @@ impl GreetingRpc for GreetingService { .take(usize::try_from(num_items).unwrap()) .enumerate() { - tx.send(item).await.unwrap(); + if delay_secs > 0 { + time::sleep(Duration::from_millis(delay_secs)).await; + } + if tx.send(item).await.is_err() { + log::info!( + "[{}] reqid: {} t={:.2?} STREAM INTERRUPTED {}/{}", + id, + req_id, + t.elapsed(), + i + 1, + num_items + ); + return; + } log::info!( "[{}] reqid: {} t={:.2?} sent {}/{}", id, @@ -160,4 +174,6 @@ pub struct StreamLargeItemsRequest { pub num_items: u64, #[prost(uint64, tag = "3")] pub item_size: u64, + #[prost(uint64, tag = "4")] + pub delay_ms: u64, } diff --git a/comms/core/tests/rpc.rs b/comms/core/tests/rpc.rs new file mode 100644 index 00000000000..90e393012d9 --- /dev/null +++ b/comms/core/tests/rpc.rs @@ -0,0 +1,125 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![cfg(feature = "rpc")] + +mod greeting_service; +use greeting_service::{GreetingClient, GreetingServer, GreetingService, StreamLargeItemsRequest}; + +mod helpers; +use std::time::Duration; + +use futures::StreamExt; +use helpers::create_comms; +use tari_comms::{ + protocol::rpc::{RpcServer, RpcServerHandle}, + transports::TcpTransport, + CommsNode, +}; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tari_test_utils::async_assert_eventually; +use tokio::time; + +async fn spawn_node(signal: ShutdownSignal) -> (CommsNode, RpcServerHandle) { + let rpc_server = RpcServer::builder() + .with_unlimited_simultaneous_sessions() + .finish() + .add_service(GreetingServer::new(GreetingService::default())); + + let rpc_server_hnd = rpc_server.get_handle(); + let comms = create_comms(signal) + .add_rpc_server(rpc_server) + .spawn_with_transport(TcpTransport::new()) + .await + .unwrap(); + + comms + .node_identity() + .set_public_address(comms.listening_address().clone()); + (comms, rpc_server_hnd) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn client_prematurely_ends_session() { + env_logger::init(); + let shutdown = Shutdown::new(); + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, mut rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + { + let mut client = conn1_2.connect_rpc::().await.unwrap(); + + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 1); + + let mut stream = client + .stream_large_items(StreamLargeItemsRequest { + id: 1, + num_items: 100, + item_size: 2300 * 1024, + delay_ms: 50, + }) + .await + .unwrap(); + + let mut count = 0; + while let Some(r) = stream.next().await { + count += 1; + + let data = r.unwrap(); + assert_eq!(data.len(), 2300 * 1024); + // Prematurely drop the stream + if count == 5 { + log::info!("Ending the stream prematurely"); + drop(stream); + break; + } + } + + // Drop stream and client + } + + time::sleep(Duration::from_secs(1)).await; + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); +} diff --git a/comms/core/tests/rpc_stress.rs b/comms/core/tests/rpc_stress.rs index 3c77537f5fb..708121ca3db 100644 --- a/comms/core/tests/rpc_stress.rs +++ b/comms/core/tests/rpc_stress.rs @@ -40,7 +40,7 @@ use tari_comms::{ use tari_shutdown::{Shutdown, 
ShutdownSignal}; use tokio::{task, time::Instant}; -pub async fn spawn_node(signal: ShutdownSignal) -> CommsNode { +async fn spawn_node(signal: ShutdownSignal) -> CommsNode { let rpc_server = RpcServer::builder() .with_unlimited_simultaneous_sessions() .finish() @@ -132,6 +132,7 @@ async fn run_stress_test(test_params: Params) { id: i as u64, num_items: num_items as u64, item_size: payload_size as u64, + delay_ms: 0, }) .await .unwrap(); diff --git a/infrastructure/test_utils/src/futures/async_assert_eventually.rs b/infrastructure/test_utils/src/futures/async_assert_eventually.rs index 0449ab101d6..cd7ef71eb2a 100644 --- a/infrastructure/test_utils/src/futures/async_assert_eventually.rs +++ b/infrastructure/test_utils/src/futures/async_assert_eventually.rs @@ -43,7 +43,7 @@ macro_rules! async_assert_eventually { assert!( attempts <= $max_attempts, "assert_eventually assertion failed. Expression did not equal value after {} attempts.", - attempts + attempts - 1 ); tokio::time::sleep($interval).await; value = $check_expr;