From 336f4d68b2753f64a92b2942651d76ae0f20517d Mon Sep 17 00:00:00 2001
From: Stan Bondi
Date: Wed, 8 Sep 2021 21:29:07 +0400
Subject: [PATCH] fix: handle stream read error case by explicitly closing the
 substream (#3321)

Description
---
- adds yamux stream id to logs to enrich rpc session tracing info
- handle stream read error case by explicitly closing the substream
- updates rpc tests to use yamux

Motivation and Context
---
May allow us to diagnose where slowness occurs in an RPC session/substream.

How Has This Been Tested?
---
Tests updated
---
 .../tari_base_node/src/command_handler.rs     |   6 +-
 .../wallet/src/connectivity_service/test.rs   |   3 +-
 .../tests/output_manager_service/service.rs   |   3 +-
 .../tests/transaction_service/service.rs      |   5 +-
 .../transaction_protocols.rs                  |   3 +-
 comms/rpc_macros/src/generator.rs             |   6 +-
 comms/rpc_macros/src/lib.rs                   |   2 +-
 comms/rpc_macros/tests/macro.rs               |  11 +-
 comms/src/memsocket/mod.rs                    |  37 +++-
 comms/src/multiplexing/yamux.rs               |  42 +++--
 comms/src/protocol/rpc/client.rs              |  75 +++++---
 comms/src/protocol/rpc/handshake.rs           |   5 +-
 comms/src/protocol/rpc/mod.rs                 |   1 +
 comms/src/protocol/rpc/server/mock.rs         |  10 +-
 comms/src/protocol/rpc/server/mod.rs          | 162 +++++++++++-------
 comms/src/protocol/rpc/server/router.rs       |  11 +-
 .../src/protocol/rpc/test/greeting_service.rs |   6 +-
 comms/src/protocol/rpc/test/smoke.rs          |  64 ++++---
 18 files changed, 278 insertions(+), 174 deletions(-)

diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs
index 76d79551ed..ef8fd3a57c 100644
--- a/applications/tari_base_node/src/command_handler.rs
+++ b/applications/tari_base_node/src/command_handler.rs
@@ -123,10 +123,8 @@ impl CommandHandler {
         self.executor.spawn(async move {
             let mut status_line = StatusLine::new();

-            let version = format!("v{}", consts::APP_VERSION_NUMBER);
-            status_line.add_field("", version);
-            let network = format!("{}", config.network);
-            status_line.add_field("", network);
+            status_line.add_field("", format!("v{}", consts::APP_VERSION_NUMBER));
+            status_line.add_field("", config.network);
             status_line.add_field("State", state_info.borrow().state_info.short_desc());

             let metadata = node.get_metadata().await.unwrap();
diff --git a/base_layer/wallet/src/connectivity_service/test.rs b/base_layer/wallet/src/connectivity_service/test.rs
index 9a8f5a2da9..db486a3027 100644
--- a/base_layer/wallet/src/connectivity_service/test.rs
+++ b/base_layer/wallet/src/connectivity_service/test.rs
@@ -35,7 +35,6 @@ use tari_comms::{
         mocks::{create_connectivity_mock, ConnectivityManagerMockState},
         node_identity::build_node_identity,
     },
-    Substream,
 };
 use tari_shutdown::Shutdown;
 use tari_test_utils::runtime::spawn_until_shutdown;
@@ -46,7 +45,7 @@ use tokio::{

 async fn setup() -> (
     WalletConnectivityHandle,
-    MockRpcServer<MockRpcImpl, Substream>,
+    MockRpcServer<MockRpcImpl>,
     ConnectivityManagerMockState,
     Shutdown,
 ) {
diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs
index ceb4ed4c8b..6cd4f2f031 100644
--- a/base_layer/wallet/tests/output_manager_service/service.rs
+++ b/base_layer/wallet/tests/output_manager_service/service.rs
@@ -36,7 +36,6 @@ use tari_comms::{
         node_identity::build_node_identity,
     },
     types::CommsSecretKey,
-    Substream,
 };
 use tari_core::{
     base_node::rpc::BaseNodeWalletRpcServer,
@@ -97,7 +96,7 @@ async fn setup_output_manager_service(
     OutputManagerHandle,
     Shutdown,
     TransactionServiceHandle,
-    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
+    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
     Arc<NodeIdentity>,
     BaseNodeWalletRpcMockState,
     ConnectivityManagerMockState,
diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs
index a7db549994..1f0617874a 100644
--- a/base_layer/wallet/tests/transaction_service/service.rs
+++ b/base_layer/wallet/tests/transaction_service/service.rs
@@ -72,7 +72,6 @@ use tari_comms::{
     },
     types::CommsSecretKey,
     CommsNode,
-    Substream,
 };
 use tari_comms_dht::outbound::mock::{
     create_outbound_service_mock,
@@ -244,7 +243,7 @@ pub fn setup_transaction_service_no_comms(
     Sender>,
     Sender>,
     Shutdown,
-    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
+    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
     Arc<NodeIdentity>,
     BaseNodeWalletRpcMockState,
 ) {
@@ -268,7 +267,7 @@ pub fn setup_transaction_service_no_comms_and_oms_backend(
     Sender>,
     Sender>,
     Shutdown,
-    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
+    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
     Arc<NodeIdentity>,
     BaseNodeWalletRpcMockState,
 ) {
diff --git a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs
index d3d8c16ca2..31d12b8fd7 100644
--- a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs
+++ b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs
@@ -32,7 +32,6 @@ use tari_comms::{
     },
     types::CommsPublicKey,
     NodeIdentity,
-    Substream,
 };
 use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServiceMockState};
 use tari_core::{
@@ -96,7 +95,7 @@ pub async fn setup(
     TransactionServiceResources,
     ConnectivityManagerMockState,
     OutboundServiceMockState,
-    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
+    MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
     Arc<NodeIdentity>,
     BaseNodeWalletRpcMockState,
     broadcast::Sender,
diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs
index 5f44066f19..a6f4ac1917 100644
--- a/comms/rpc_macros/src/generator.rs
+++ b/comms/rpc_macros/src/generator.rs
@@ -194,15 +194,15 @@ impl RpcCodeGenerator {
             .collect::<TokenStream>();

         let client_struct_body = quote! {
-            pub async fn connect<TSubstream>(framed: #dep_mod::CanonicalFraming<TSubstream>) -> Result<Self, #dep_mod::RpcError>
-            where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + 'static {
+            pub async fn connect(framed: #dep_mod::CanonicalFraming<#dep_mod::Substream>) -> Result<Self, #dep_mod::RpcError> {
                 use #dep_mod::NamedProtocolService;
                 let inner = #dep_mod::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?;
                 Ok(Self { inner })
             }

             pub fn builder() -> #dep_mod::RpcClientBuilder<Self> {
-                #dep_mod::RpcClientBuilder::new()
+                use #dep_mod::NamedProtocolService;
+                #dep_mod::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into())
             }

             #client_methods
diff --git a/comms/rpc_macros/src/lib.rs b/comms/rpc_macros/src/lib.rs
index 69442606de..0e883e96da 100644
--- a/comms/rpc_macros/src/lib.rs
+++ b/comms/rpc_macros/src/lib.rs
@@ -12,7 +12,7 @@ mod options;
 ///
 /// Generates Tari RPC "harness code" for a given trait.
 ///
-/// ```no_run
+/// ```no_run,ignore
 /// # use tari_comms_rpc_macros::tari_rpc;
 /// # use tari_comms::protocol::rpc::{Request, Streaming, Response, RpcStatus, RpcServer};
 /// use tari_comms::{framing, memsocket::MemorySocket};
diff --git a/comms/rpc_macros/tests/macro.rs b/comms/rpc_macros/tests/macro.rs
index b41f3f9914..71f0b7053f 100644
--- a/comms/rpc_macros/tests/macro.rs
+++ b/comms/rpc_macros/tests/macro.rs
@@ -25,12 +25,12 @@ use prost::Message;
 use std::{collections::HashMap, ops::AddAssign, sync::Arc};
 use tari_comms::{
     framing,
-    memsocket::MemorySocket,
     message::MessageExt,
     protocol::{
         rpc,
         rpc::{NamedProtocolService, Request, Response, RpcStatus, RpcStatusCode, Streaming},
     },
+    test_utils::transport::build_multiplexed_connections,
 };
 use tari_comms_rpc_macros::tari_rpc;
 use tari_test_utils::unpack_enum;
@@ -152,9 +152,12 @@ async fn it_returns_an_error_for_invalid_method_nums() {

 #[tokio::test]
 async fn it_generates_client_calls() {
-    let (sock_client, sock_server) = MemorySocket::new_pair();
-    let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024)));
-    let mut sock_server = framing::canonical(sock_server, 1024);
+    let (_, sock_client, mut sock_server) = build_multiplexed_connections().await;
+    let client = task::spawn(TestClient::connect(framing::canonical(
+        sock_client.get_yamux_control().open_stream().await.unwrap(),
+        1024,
+    )));
+    let mut sock_server = framing::canonical(sock_server.incoming_mut().next().await.unwrap(), 1024);
     let mut handshake = rpc::Handshake::new(&mut sock_server);
     handshake.perform_server_handshake().await.unwrap();
     // Wait for client to connect
diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs
index ed77fc6146..caaa683593 100644
--- a/comms/src/memsocket/mod.rs
+++ b/comms/src/memsocket/mod.rs
@@ -30,6 +30,7 @@ use futures::{
     stream::{FusedStream, Stream},
     task::{Context, Poll},
 };
+use log::*;
 use std::{
     cmp,
     collections::{hash_map::Entry, HashMap},
@@ -433,6 +434,7 @@ impl AsyncRead for MemorySocket {
                     buf.advance(bytes_to_read);
                     current_buffer.advance(bytes_to_read);

+                    trace!("reading {} bytes", bytes_to_read);
                     bytes_read += bytes_to_read;
                 }
@@ -462,11 +464,12 @@ impl AsyncRead for MemorySocket {

 impl AsyncWrite for MemorySocket {
     /// Attempt to write bytes from `buf` into the outgoing channel.
-    fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
+    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
         let len = buf.len();

-        match self.outgoing.poll_ready(context) {
+        match self.outgoing.poll_ready(cx) {
             Poll::Ready(Ok(())) => {
+                trace!("writing {} bytes", len);
                 if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) {
                     if e.is_disconnected() {
                         return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e)));
                     }

                     // Unbounded channels should only ever have "Disconnected" errors
                     unreachable!();
                 }
+                Poll::Ready(Ok(len))
             },
             Poll::Ready(Err(e)) => {
                 if e.is_disconnected() {
                     return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e)));
                 }

                 // Unbounded channels should only ever have "Disconnected" errors
                 unreachable!();
             },
-            Poll::Pending => return Poll::Pending,
+            Poll::Pending => Poll::Pending,
         }
-
-        Poll::Ready(Ok(len))
     }

     /// Attempt to flush the channel. Cannot fail.
-    fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
+    fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
+        trace!("flush");
         Poll::Ready(Ok(()))
     }

     /// Attempt to close the channel. Cannot fail.
-    fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
+    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
         self.outgoing.close_channel();

         Poll::Ready(Ok(()))
@@ -506,7 +509,8 @@ impl AsyncWrite for MemorySocket {
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::runtime;
+    use crate::{framing, runtime};
+    use futures::SinkExt;
     use tokio::io::{AsyncReadExt, AsyncWriteExt};
     use tokio_stream::StreamExt;

@@ -705,4 +709,21 @@ mod test {

         Ok(())
     }
+
+    #[runtime::test]
+    async fn read_and_write_canonical_framing() -> io::Result<()> {
+        let (a, b) = MemorySocket::new_pair();
+        let mut a = framing::canonical(a, 1024);
+        let mut b = framing::canonical(b, 1024);
+
+        a.send(Bytes::from_static(b"frame-1")).await?;
+        b.send(Bytes::from_static(b"frame-2")).await?;
+        let msg = b.next().await.unwrap()?;
+        assert_eq!(&msg[..], b"frame-1");
+
+        let msg = a.next().await.unwrap()?;
+        assert_eq!(&msg[..], b"frame-2");
+
+        Ok(())
+    }
 }
diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs
index 1723033739..28a14dfff7 100644
--- a/comms/src/multiplexing/yamux.rs
+++ b/comms/src/multiplexing/yamux.rs
@@ -166,7 +166,7 @@ pub struct IncomingSubstreams {
 }

 impl IncomingSubstreams {
-    pub fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
+    pub(self) fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
         Self {
             inner,
             substream_counter,
@@ -205,6 +205,12 @@ pub struct Substream {
     counter_guard: CounterGuard,
 }

+impl Substream {
+    pub fn id(&self) -> yamux::StreamId {
+        self.stream.get_ref().id()
+    }
+}
+
 impl tokio::io::AsyncRead for Substream {
     fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
         Pin::new(&mut self.stream).poll_read(cx, buf)
@@ -242,13 +248,17 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
         }
     }

-    #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))]
+    #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self), fields(connection = %self.connection))]
     pub async fn run(mut self) {
         loop {
             tokio::select! {
                 biased;

-                _ = &mut self.shutdown_signal => {
+                _ = self.shutdown_signal.wait() => {
+                    debug!(
+                        target: LOG_TARGET,
+                        "{} Yamux connection shutdown", self.connection
+                    );
                     let mut control = self.connection.control();
                     if let Err(err) = control.close().await {
                         error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err);
@@ -259,11 +269,13 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
                 result = self.connection.next_stream() => {
                     match result {
                         Ok(Some(stream)) => {
-                            event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() {
+                            event!(Level::TRACE, "yamux::incoming_worker::new_stream {}", stream);
+                            if self.sender.send(stream).await.is_err() {
                                 debug!(
                                     target: LOG_TARGET,
-                                    "Incoming peer substream task is shutting down because the internal stream sender channel \
-                                     was closed"
+                                    "{} Incoming peer substream task is shutting down because the internal stream sender channel \
+                                     was closed",
+                                    self.connection
                                 );
                                 break;
                             }
                         },
                         Ok(None) =>{
                             debug!(
                                 target: LOG_TARGET,
-                                "Incoming peer substream completed. IncomingWorker exiting"
+                                "{} Incoming peer substream completed. IncomingWorker exiting",
+                                self.connection
                             );
                             break;
                         }
IncomingWorker exiting", + self.connection ); break; } Err(err) => { event!( - Level::ERROR, - "Incoming peer substream task received an error because '{}'", - err - ); - error!( + Level::ERROR, + "{} Incoming peer substream task received an error because '{}'", + self.connection, + err + ); + error!( target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err + "{} Incoming peer substream task received an error because '{}'", + self.connection, + err ); break; }, diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 5467befeba..a6d6554f3e 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -38,6 +38,7 @@ use crate::{ ProtocolId, }, runtime::task, + Substream, }; use bytes::Bytes; use futures::{ @@ -60,7 +61,6 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ - io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot, Mutex}, time, }; @@ -76,14 +76,11 @@ pub struct RpcClient { impl RpcClient { /// Create a new RpcClient using the given framed substream and perform the RPC handshake. - pub async fn connect( + pub async fn connect( config: RpcClientConfig, - framed: CanonicalFraming, + framed: CanonicalFraming, protocol_name: ProtocolId, - ) -> Result - where - TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, - { + ) -> Result { let (request_tx, request_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let shutdown_signal = shutdown.to_signal(); @@ -224,14 +221,14 @@ where TClient: From + NamedProtocolService self } - pub(crate) fn with_protocol_id(mut self, protocol_id: ProtocolId) -> Self { + /// Set the protocol ID associated with this client. This is used for logging purposes only. + pub fn with_protocol_id(mut self, protocol_id: ProtocolId) -> Self { self.protocol_id = Some(protocol_id); self } /// Negotiates and establishes a session to the peer's RPC service - pub async fn connect(self, framed: CanonicalFraming) -> Result - where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static { + pub async fn connect(self, framed: CanonicalFraming) -> Result { RpcClient::connect( self.config, framed, @@ -346,10 +343,10 @@ impl Service> for ClientConnector { } } -pub struct RpcClientWorker { +struct RpcClientWorker { config: RpcClientConfig, request_rx: mpsc::Receiver, - framed: CanonicalFraming, + framed: CanonicalFraming, // Request ids are limited to u16::MAX because varint encoding is used over the wire and the magnitude of the value // sent determines the byte size. 
A u16 will be more than enough for the purpose next_request_id: u16, @@ -359,13 +356,11 @@ pub struct RpcClientWorker { shutdown_signal: ShutdownSignal, } -impl RpcClientWorker -where TSubstream: AsyncRead + AsyncWrite + Unpin + Send -{ - pub fn new( +impl RpcClientWorker { + pub(self) fn new( config: RpcClientConfig, request_rx: mpsc::Receiver, - framed: CanonicalFraming, + framed: CanonicalFraming, ready_tx: oneshot::Sender>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, @@ -386,11 +381,16 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send String::from_utf8_lossy(&self.protocol_id) } - #[tracing::instrument(name = "rpc_client_worker run", skip(self), fields(next_request_id= self.next_request_id))] + fn stream_id(&self) -> yamux::StreamId { + self.framed.get_ref().id() + } + + #[tracing::instrument(name = "rpc_client_worker run", skip(self), fields(next_request_id = self.next_request_id))] async fn run(mut self) { debug!( target: LOG_TARGET, - "Performing client handshake for '{}'", + "(stream={}) Performing client handshake for '{}'", + self.stream_id(), self.protocol_name() ); let start = Instant::now(); @@ -400,7 +400,8 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let latency = start.elapsed(); debug!( target: LOG_TARGET, - "RPC Session ({}) negotiation completed. Latency: {:.0?}", + "(stream={}) RPC Session ({}) negotiation completed. Latency: {:.0?}", + self.stream_id(), self.protocol_name(), latency ); @@ -428,7 +429,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send match req { Some(req) => { if let Err(err) = self.handle_request(req).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); + error!(target: LOG_TARGET, "(stream={}) Unexpected error: {}. Worker is terminating.", self.stream_id(), err); break; } } @@ -439,12 +440,18 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send } if let Err(err) = self.framed.close().await { - debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); + debug!( + target: LOG_TARGET, + "(stream={}) IO Error when closing substream: {}", + self.stream_id(), + err + ); } debug!( target: LOG_TARGET, - "RpcClientWorker ({}) terminated.", + "(stream={}) RpcClientWorker ({}) terminated.", + self.stream_id(), self.protocol_name() ); } @@ -477,14 +484,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send debug!( target: LOG_TARGET, - "Ping (protocol {}) sent in {:.2?}", + "(stream={}) Ping (protocol {}) sent in {:.2?}", + self.stream_id(), self.protocol_name(), start.elapsed() ); let resp = match self.read_reply().await { Ok(resp) => resp, Err(RpcError::ReplyTimeout) => { - debug!(target: LOG_TARGET, "Ping timed out after {:.0?}", start.elapsed()); + debug!( + target: LOG_TARGET, + "(stream={}) Ping timed out after {:.0?}", + self.stream_id(), + start.elapsed() + ); let _ = reply.send(Err(RpcStatus::timed_out("Response timed out"))); return Ok(()); }, @@ -499,7 +512,12 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let resp_flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8); if !resp_flags.contains(RpcMessageFlags::ACK) { - warn!(target: LOG_TARGET, "Invalid ping response {:?}", resp); + warn!( + target: LOG_TARGET, + "(stream={}) Invalid ping response {:?}", + self.stream_id(), + resp + ); let _ = reply.send(Err(RpcStatus::protocol_error(format!( "Received invalid ping response on protocol '{}'", self.protocol_name() @@ -613,8 +631,9 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if response_tx.is_closed() { warn!( 
                 target: LOG_TARGET,
-                "Response receiver was dropped before the response/stream could complete for protocol {}, \
-                 the stream will continue until completed",
+                "(stream={}) Response receiver was dropped before the response/stream could complete for \
+                 protocol {}, the stream will continue until completed",
+                self.framed.get_ref().id(),
                 self.protocol_name()
             );
         } else {
diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs
index b39c15e6d7..4e07a0294d 100644
--- a/comms/src/protocol/rpc/handshake.rs
+++ b/comms/src/protocol/rpc/handshake.rs
@@ -138,15 +138,18 @@ where T: AsyncRead + AsyncWrite + Unpin
         let msg = proto::rpc::RpcSession {
             supported_versions: SUPPORTED_RPC_VERSIONS.to_vec(),
         };
+        let payload = msg.to_encoded_bytes();
+        debug!(target: LOG_TARGET, "Sending client handshake ({} bytes)", payload.len());
         // It is possible that the server rejects the session and closes the substream before we've had a chance to send
         // anything. Rather than returning an IO error, let's ignore the send error and see if we can receive anything,
         // or return an IO error similarly to what send would have done.
-        if let Err(err) = self.framed.send(msg.to_encoded_bytes().into()).await {
+        if let Err(err) = self.framed.send(payload.into()).await {
             warn!(
                 target: LOG_TARGET,
                 "IO error when sending new session handshake to peer: {}", err
             );
         }
+        self.framed.flush().await?;
         match self.recv_next_frame().await {
             Ok(Some(Ok(msg))) => {
                 let msg = proto::rpc::RpcSessionReply::decode(&mut msg.freeze())?;
diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs
index 2244979adf..33208df391 100644
--- a/comms/src/protocol/rpc/mod.rs
+++ b/comms/src/protocol/rpc/mod.rs
@@ -63,6 +63,7 @@ pub const RPC_MAX_FRAME_SIZE: usize = 4 * 1024 * 1024; // 4 MiB
 pub mod __macro_reexports {
     pub use crate::{
         framing::CanonicalFraming,
+        multiplexing::Substream,
         protocol::{
             rpc::{
                 client_pool::RpcPoolClient,
diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs
index 69659ba03b..dae0f9ce93 100644
--- a/comms/src/protocol/rpc/server/mock.rs
+++ b/comms/src/protocol/rpc/server/mock.rs
@@ -194,14 +194,14 @@ impl RpcCommsProvider for MockCommsProvider {
     }
 }

-pub struct MockRpcServer<TSvc, TSubstream> {
-    inner: Option<PeerRpcServer<TSvc, TSubstream>>,
-    protocol_tx: ProtocolNotificationTx<TSubstream>,
+pub struct MockRpcServer<TSvc> {
+    inner: Option<PeerRpcServer<TSvc>>,
+    protocol_tx: ProtocolNotificationTx<Substream>,
     our_node: Arc<NodeIdentity>,
     request_tx: mpsc::Sender,
 }

-impl<TSvc, TSubstream> MockRpcServer<TSvc, TSubstream>
+impl<TSvc> MockRpcServer<TSvc>
 where
     TSvc: MakeService<
         ProtocolId,
         Request<Bytes>,
@@ -259,7 +259,7 @@ where
     }
 }

-impl MockRpcServer<MockRpcImpl, Substream> {
+impl MockRpcServer<MockRpcImpl> {
     pub async fn create_mockimpl_connection(&self, peer: Peer) -> PeerConnection {
         // MockRpcImpl accepts any protocol
         self.create_connection(peer, ProtocolId::new()).await
diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs
index 88fdb7ee61..d2b5f842eb 100644
--- a/comms/src/protocol/rpc/server/mod.rs
+++ b/comms/src/protocol/rpc/server/mod.rs
@@ -52,19 +52,16 @@ use crate::{
     proto,
     protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx},
     Bytes,
+    Substream,
 };
 use futures::SinkExt;
 use prost::Message;
 use std::{
-    borrow::Cow,
     future::Future,
+    sync::Arc,
     time::{Duration, Instant},
 };
-use tokio::{
-    io::{AsyncRead, AsyncWrite},
-    sync::mpsc,
-    time,
-};
+use tokio::{sync::mpsc, time};
 use tokio_stream::StreamExt;
 use tower::Service;
 use tower_make::MakeService;
@@ -116,14 +113,13 @@ impl RpcServer {
         RpcServerHandle::new(self.request_tx.clone())
     }

-    pub(super) async fn serve<S, TSubstream, TCommsProvider>(
+    pub(super) async fn serve<S, TCommsProvider>(
         self,
         service: S,
-        notifications: ProtocolNotificationRx<TSubstream>,
+        notifications: ProtocolNotificationRx<Substream>,
         comms_provider: TCommsProvider,
     ) -> Result<(), RpcServerError>
     where
-        TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
         S: MakeService<
             ProtocolId,
             Request<Bytes>,
@@ -197,18 +193,17 @@ impl Default for RpcServerBuilder {
     }
 }

-pub(super) struct PeerRpcServer<TSvc, TSubstream, TCommsProvider> {
+pub(super) struct PeerRpcServer<TSvc, TCommsProvider> {
     executor: BoundedExecutor,
     config: RpcServerBuilder,
     service: TSvc,
-    protocol_notifications: Option<ProtocolNotificationRx<TSubstream>>,
+    protocol_notifications: Option<ProtocolNotificationRx<Substream>>,
     comms_provider: TCommsProvider,
     request_rx: mpsc::Receiver,
 }

-impl<TSvc, TSubstream, TCommsProvider> PeerRpcServer<TSvc, TSubstream, TCommsProvider>
+impl<TSvc, TCommsProvider> PeerRpcServer<TSvc, TCommsProvider>
 where
-    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     TSvc: MakeService<
         ProtocolId,
         Request<Bytes>,
@@ -225,7 +220,7 @@ where
     fn new(
         config: RpcServerBuilder,
         service: TSvc,
-        protocol_notifications: ProtocolNotificationRx<TSubstream>,
+        protocol_notifications: ProtocolNotificationRx<Substream>,
         comms_provider: TCommsProvider,
         request_rx: mpsc::Receiver,
     ) -> Self {
@@ -289,7 +284,7 @@ where
     #[tracing::instrument(name = "rpc::server::new_client_connection", skip(self, notification), err)]
     async fn handle_protocol_notification(
         &mut self,
-        notification: ProtocolNotification<TSubstream>,
+        notification: ProtocolNotification<Substream>,
     ) -> Result<(), RpcServerError> {
         match notification.event {
             ProtocolEvent::NewInboundSubstream(node_id, substream) => {
@@ -318,7 +313,7 @@ where
         &mut self,
         protocol: ProtocolId,
         node_id: NodeId,
-        mut framed: CanonicalFraming<TSubstream>,
+        mut framed: CanonicalFraming<Substream>,
     ) -> Result<(), RpcServerError> {
         let mut handshake = Handshake::new(&mut framed).with_timeout(self.config.handshake_timeout);
@@ -357,14 +352,14 @@ where
             "Server negotiated RPC v{} with client node `{}`", version, node_id
         );

-        let service = ActivePeerRpcService {
-            config: self.config.clone(),
+        let service = ActivePeerRpcService::new(
+            self.config.clone(),
             protocol,
-            node_id: node_id.clone(),
-            framed,
+            node_id.clone(),
             service,
-            comms_provider: self.comms_provider.clone(),
-        };
+            framed,
+            self.comms_provider.clone(),
+        );

         self.executor
             .try_spawn(service.start())
@@ -374,64 +369,91 @@ where
     }
 }

-struct ActivePeerRpcService<TSvc, TSubstream, TCommsProvider> {
+struct ActivePeerRpcService<TSvc, TCommsProvider> {
     config: RpcServerBuilder,
     protocol: ProtocolId,
     node_id: NodeId,
     service: TSvc,
-    framed: CanonicalFraming<TSubstream>,
+    framed: CanonicalFraming<Substream>,
     comms_provider: TCommsProvider,
+    logging_context_string: Arc<String>,
 }

-impl<TSvc, TSubstream, TCommsProvider> ActivePeerRpcService<TSvc, TSubstream, TCommsProvider>
+impl<TSvc, TCommsProvider> ActivePeerRpcService<TSvc, TCommsProvider>
 where
-    TSubstream: AsyncRead + AsyncWrite + Unpin,
     TSvc: Service<Request<Bytes>, Response = Response<Body>, Error = RpcStatus>,
     TCommsProvider: RpcCommsProvider + Send + Clone + 'static,
 {
+    pub(self) fn new(
+        config: RpcServerBuilder,
+        protocol: ProtocolId,
+        node_id: NodeId,
+        service: TSvc,
+        framed: CanonicalFraming<Substream>,
+        comms_provider: TCommsProvider,
+    ) -> Self {
+        Self {
+            logging_context_string: Arc::new(format!(
+                "stream_id: {}, peer: {}, protocol: {}",
+                framed.get_ref().id(),
+                node_id,
+                String::from_utf8_lossy(&protocol)
+            )),
+
+            config,
+            protocol,
+            node_id,
+            service,
+            framed,
+            comms_provider,
+        }
+    }
+
     async fn start(mut self) {
         debug!(
             target: LOG_TARGET,
-            "(Peer = `{}`) Rpc server ({}) started.",
-            self.node_id,
-            self.protocol_name()
+            "({}) Rpc server started.", self.logging_context_string,
         );
         if let Err(err) = self.run().await {
             error!(
                 target: LOG_TARGET,
-                "(Peer = `{}`) Rpc server ({}) exited with an error: {}",
-                self.node_id,
-                self.protocol_name(),
-                err
+                "({}) Rpc server exited with an error: {}", self.logging_context_string, err
             );
         }
         debug!(
             target: LOG_TARGET,
-            "(Peer = {}) Rpc service ({}) shutdown",
-            self.node_id,
-            self.protocol_name()
+            "({}) Rpc service shutdown", self.logging_context_string
         );
     }

-    fn protocol_name(&self) -> Cow<'_, str> {
-        String::from_utf8_lossy(&self.protocol)
-    }
-
     async fn run(&mut self) -> Result<(), RpcServerError> {
         while let Some(result) = self.framed.next().await {
-            let start = Instant::now();
-            if let Err(err) = self.handle(result?.freeze()).await {
-                self.framed.close().await?;
-                return Err(err);
+            match result {
+                Ok(frame) => {
+                    let start = Instant::now();
+                    if let Err(err) = self.handle(frame.freeze()).await {
+                        self.framed.close().await?;
+                        return Err(err);
+                    }
+                    let elapsed = start.elapsed();
+                    debug!(
+                        target: LOG_TARGET,
+                        "({}) RPC request completed in {:.0?}{}",
+                        self.logging_context_string,
+                        elapsed,
+                        if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" }
+                    );
+                },
+                Err(err) => {
+                    if let Err(err) = self.framed.close().await {
+                        error!(
+                            target: LOG_TARGET,
+                            "({}) Failed to close substream after socket error: {}", self.logging_context_string, err
+                        );
+                    }
+                    return Err(err.into());
+                },
             }
-            let elapsed = start.elapsed();
-            debug!(
-                target: LOG_TARGET,
-                "RPC ({}) request completed in {:.0?}{}",
-                self.protocol_name(),
-                elapsed,
-                if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" }
-            );
         }

         self.framed.close().await?;
@@ -450,7 +472,7 @@ where
         if deadline < self.config.minimum_client_deadline {
             debug!(
                 target: LOG_TARGET,
-                "[Peer=`{}`] Client has an invalid deadline. {}", self.node_id, decoded_msg
+                "({}) Client has an invalid deadline. {}", self.logging_context_string, decoded_msg
             );
             // Let the client know that they have disobeyed the spec
             let status = RpcStatus::bad_request(format!(
@@ -471,9 +493,7 @@ where
         if msg_flags.contains(RpcMessageFlags::ACK) {
             debug!(
                 target: LOG_TARGET,
-                "[Peer=`{}` {}] sending ACK response.",
-                self.node_id,
-                self.protocol_name()
+                "({}) sending ACK response.", self.logging_context_string
             );
             let ack = proto::rpc::RpcResponse {
                 request_id,
@@ -487,7 +507,7 @@ where

         debug!(
             target: LOG_TARGET,
-            "[Peer=`{}`] Got request {}", self.node_id, decoded_msg
+            "({}) Got request {}", self.logging_context_string, decoded_msg
         );

         let req = Request::with_context(
@@ -496,7 +516,12 @@ where
             decoded_msg.message.into(),
         );

-        let service_call = log_timing(request_id, "service call", self.service.call(req));
+        let service_call = log_timing(
+            self.logging_context_string.clone(),
+            request_id,
+            "service call",
+            self.service.call(req),
+        );
         let service_result = time::timeout(deadline, service_call).await;
         let service_result = match service_result {
             Ok(v) => v,
@@ -545,7 +570,12 @@ where
         let mut message = body.into_message();

         loop {
-            let msg_read = log_timing(request_id, "message read", message.next());
+            let msg_read = log_timing(
+                self.logging_context_string.clone(),
+                request_id,
+                "message read",
+                message.next(),
+            );
             match time::timeout(deadline, msg_read).await {
                 Ok(Some(msg)) => {
                     let resp = match msg {
@@ -573,8 +603,13 @@ where
                     },
                 };

-                let is_valid =
-                    log_timing(request_id, "transmit", self.send_response(request_id, resp)).await?;
+                let is_valid = log_timing(
+                    self.logging_context_string.clone(),
+                    request_id,
+                    "transmit",
+                    self.send_response(request_id, resp),
+                )
+                .await?;

                 if !is_valid {
                     break;
@@ -647,14 +682,15 @@ where
     }
 }

-async fn log_timing<R, F: Future<Output = R>>(request_id: u32, tag: &str, fut: F) -> R {
+async fn log_timing<R, F: Future<Output = R>>(context_str: Arc<String>, request_id: u32, tag: &str, fut: F) -> R {
     let t = Instant::now();
     let span = span!(Level::TRACE, "rpc::internal::timing::{}::{}", request_id, tag);
     let ret = fut.instrument(span).await;
     let elapsed = t.elapsed();
     trace!(
         target: LOG_TARGET,
-        "RPC TIMING(REQ_ID={}): '{}' took {:.2}s{}",
+        "({}) RPC TIMING(REQ_ID={}): '{}' took {:.2}s{}",
+        context_str,
         request_id,
         tag,
         elapsed.as_secs_f32(),
diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs
index 1d40988075..342454e122 100644
--- a/comms/src/protocol/rpc/server/router.rs
+++ b/comms/src/protocol/rpc/server/router.rs
@@ -42,6 +42,7 @@ use crate::{
     },
     runtime::task,
     Bytes,
+    Substream,
 };
 use futures::{
     future::BoxFuture,
     stream,
     FutureExt,
 };
 use std::sync::Arc;
-use tokio::{
-    io::{AsyncRead, AsyncWrite},
-    sync::mpsc,
-};
+use tokio::sync::mpsc;
 use tower::Service;
 use tower_make::MakeService;
@@ -133,13 +131,12 @@ where
     >>::Future: Send + 'static,
 {
     /// Start all services
-    pub(crate) async fn serve<TSubstream, TCommsProvider>(
+    pub(crate) async fn serve<TCommsProvider>(
         self,
-        protocol_notifications: ProtocolNotificationRx<TSubstream>,
+        protocol_notifications: ProtocolNotificationRx<Substream>,
         comms_provider: TCommsProvider,
     ) -> Result<(), RpcError>
     where
-        TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
         TCommsProvider: RpcCommsProvider + Clone + Send + 'static,
     {
         self.server
diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs
index b303ce5fc7..f66221b5ae 100644
--- a/comms/src/protocol/rpc/test/greeting_service.rs
+++ b/comms/src/protocol/rpc/test/greeting_service.rs
@@ -27,6 +27,7 @@ use crate::{
         ProtocolId,
     },
     utils,
+    Substream,
 };
 use core::iter;
 use std::{
@@ -393,14 +394,13 @@ impl __rpc_deps::NamedProtocolService for GreetingClient {
 }

 impl GreetingClient {
-    pub async fn connect<TSubstream>(framed: __rpc_deps::CanonicalFraming<TSubstream>) -> Result<Self, __rpc_deps::RpcError>
-    where TSubstream: __rpc_deps::AsyncRead + __rpc_deps::AsyncWrite + Unpin + Send + 'static {
+    pub async fn connect(framed: __rpc_deps::CanonicalFraming<Substream>) -> Result<Self, __rpc_deps::RpcError> {
         let inner = __rpc_deps::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?;
         Ok(Self { inner })
     }

     pub fn builder() -> __rpc_deps::RpcClientBuilder<Self> {
-        __rpc_deps::RpcClientBuilder::new()
+        __rpc_deps::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into())
     }

     pub async fn say_hello(&mut self, request: SayHelloRequest) -> Result {
diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs
index bc0f1bb25d..553c0001cd 100644
--- a/comms/src/protocol/rpc/test/smoke.rs
+++ b/comms/src/protocol/rpc/test/smoke.rs
@@ -22,7 +22,7 @@

 use crate::{
     framing,
-    memsocket::MemorySocket,
+    multiplexing::Yamux,
     protocol::{
         rpc::{
             context::RpcCommsBackend,
@@ -50,10 +50,11 @@ use crate::{
         ProtocolNotification,
     },
     runtime,
-    test_utils::node_identity::build_node_identity,
+    test_utils::{node_identity::build_node_identity, transport::build_multiplexed_connections},
     NodeIdentity,
+    Substream,
 };
-use futures::{future, future::Either, StreamExt};
+use futures::StreamExt;
 use std::{sync::Arc, time::Duration};
 use tari_crypto::tari_utilities::hex::Hex;
 use tari_shutdown::Shutdown;
@@ -67,7 +68,7 @@ pub(super) async fn setup_service<T: GreetingRpc>(
     service_impl: T,
     num_concurrent_sessions: usize,
 ) -> (
-    mpsc::Sender<ProtocolNotification<MemorySocket>>,
+    mpsc::Sender<ProtocolNotification<Substream>>,
     task::JoinHandle<()>,
     RpcCommsBackend,
     Shutdown,
 ) {
@@ -86,11 +87,10 @@ pub(super) async fn setup_service<T: GreetingRpc>(
                 .add_service(GreetingServer::new(service_impl))
                 .serve(notif_rx, context);

-            futures::pin_mut!(fut);
-
-            match future::select(shutdown_signal, fut).await {
-                Either::Left(_) => {},
-                Either::Right((r, _)) => r.unwrap(),
+            tokio::select! {
+                biased;
+                _ = shutdown_signal => {},
+                r = fut => r.unwrap(),
             }
         }
     });
@@ -100,31 +100,35 @@ pub(super) async fn setup_service<T: GreetingRpc>(
 pub(super) async fn setup<T: GreetingRpc>(
     service_impl: T,
     num_concurrent_sessions: usize,
-) -> (MemorySocket, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown) {
+) -> (Yamux, Yamux, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown) {
     let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await;
-    let (inbound, outbound) = MemorySocket::new_pair();
-    let node_identity = build_node_identity(Default::default());
+    let (_, inbound, outbound) = build_multiplexed_connections().await;
+    let substream = outbound.get_yamux_control().open_stream().await.unwrap();

+    let node_identity = build_node_identity(Default::default());
     // Notify that a peer wants to speak the greeting RPC protocol
     context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap();
     notif_tx
         .send(ProtocolNotification::new(
             ProtocolId::from_static(b"/test/greeting/1.0"),
-            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), inbound),
+            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream),
         ))
         .await
         .unwrap();

-    (outbound, server_hnd, node_identity, shutdown)
+    (inbound, outbound, server_hnd, node_identity, shutdown)
 }

 #[runtime::test]
 async fn request_response_errors_and_streaming() {
-    let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await;
+    let (mut muxer, _outbound, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, 1024);
     let mut client = GreetingClient::builder()
         .with_deadline(Duration::from_secs(5))
+        .with_deadline_grace_period(Duration::from_secs(5))
+        .with_handshake_timeout(Duration::from_secs(5))
         .connect(framed)
         .await
         .unwrap();
@@ -200,7 +204,8 @@ async fn request_response_errors_and_streaming() {

 #[runtime::test]
 async fn concurrent_requests() {
-    let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::default(), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, 1024);
     let mut client = GreetingClient::builder()
@@ -240,7 +245,8 @@ async fn concurrent_requests() {

 #[runtime::test]
 async fn response_too_big() {
-    let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE);
     let mut client = GreetingClient::builder().connect(framed).await.unwrap();
@@ -261,7 +267,8 @@ async fn response_too_big() {

 #[runtime::test]
 async fn ping_latency() {
-    let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE);
     let mut client = GreetingClient::builder().connect(framed).await.unwrap();
@@ -274,7 +281,8 @@ async fn ping_latency() {

 #[runtime::test]
 async fn server_shutdown_before_connect() {
-    let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let (mut muxer, _outbound, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, 1024);
     shutdown.trigger();
@@ -288,7 +296,8 @@ async fn server_shutdown_before_connect() {

 #[runtime::test]
 async fn timeout() {
     let delay = Arc::new(RwLock::new(Duration::from_secs(10)));
-    let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();
     let framed = framing::canonical(socket, 1024);
     let mut client = GreetingClient::builder()
         .with_deadline(Duration::from_secs(1))
@@ -313,7 +322,9 @@ async fn timeout() {
 async fn unknown_protocol() {
     let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await;

-    let (inbound, socket) = MemorySocket::new_pair();
+    let (_, inbound, mut outbound) = build_multiplexed_connections().await;
+    let in_substream = inbound.get_yamux_control().open_stream().await.unwrap();
+
     let node_identity = build_node_identity(Default::default());

     // This case should never happen because protocols are preregistered with the connection manager and so a
@@ -322,12 +333,13 @@ async fn unknown_protocol() {
     notif_tx
         .send(ProtocolNotification::new(
             ProtocolId::from_static(b"this-is-junk"),
-            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), inbound),
+            ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), in_substream),
         ))
         .await
         .unwrap();

-    let framed = framing::canonical(socket, 1024);
+    let out_socket = outbound.incoming_mut().next().await.unwrap();
+    let framed = framing::canonical(out_socket, 1024);
     let err = GreetingClient::connect(framed).await.unwrap_err();
     assert!(matches!(
         err,
@@ -337,7 +349,8 @@ async fn unknown_protocol() {

 #[runtime::test]
 async fn rejected_no_sessions_available() {
-    let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();
     let framed = framing::canonical(socket, 1024);
     let err = GreetingClient::builder().connect(framed).await.unwrap_err();
     assert!(matches!(
@@ -349,7 +362,8 @@ async fn rejected_no_sessions_available() {

 #[runtime::test]
 async fn stream_still_works_after_cancel() {
     let service_impl = GreetingService::default();
-    let (socket, _, _, _shutdown) = setup(service_impl.clone(), 1).await;
+    let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await;
+    let socket = muxer.incoming_mut().next().await.unwrap();

     let framed = framing::canonical(socket, 1024);
     let mut client = GreetingClient::builder()
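
---
Note for reviewers: the core change in ActivePeerRpcService::run is easiest to see in isolation.
The following is a minimal, self-contained sketch of the same read-error pattern, not the Tari
code itself: it assumes the `futures` and `tokio-util` crates, and `run_session` plus the
length-delimited framing are illustrative stand-ins for the real server loop and canonical framing.

use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Framed, LengthDelimitedCodec};

// Illustrative stand-in for the per-substream server loop.
async fn run_session<S>(mut framed: Framed<S, LengthDelimitedCodec>) -> std::io::Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    while let Some(result) = framed.next().await {
        match result {
            Ok(frame) => {
                // Handle the request frame here.
                let _ = frame;
            },
            Err(err) => {
                // On a read error, explicitly close the substream so the remote
                // side observes a clean EOF instead of a half-open stream, then
                // propagate the original error.
                if let Err(close_err) = framed.close().await {
                    eprintln!("failed to close substream after read error: {}", close_err);
                }
                return Err(err);
            },
        }
    }
    // Stream ended normally; close cleanly.
    framed.close().await
}

Without the explicit close(), a substream that hit a read error could linger until the whole
yamux session was torn down, which is consistent with the hard-to-diagnose RPC slowness this
patch aims to surface via the stream-id log context.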