Skip to content

Commit

Permalink
fix: handle stream read error case by explicitly closing the substream (
Browse files Browse the repository at this point in the history
#3321)


Description
---
- adds yamux stream id to logs to enrich rpc session tracing info
- handle stream read error case by explicitly closing the substream
- updates rpc tests to use yamux

Motivation and Context
---
May allow us to diagnose where slowness in an RPC session/substream occurs.

How Has This Been Tested?
---
Tests updated
  • Loading branch information
sdbondi committed Sep 8, 2021
1 parent 47bafbf commit 336f4d6
Show file tree
Hide file tree
Showing 18 changed files with 279 additions and 174 deletions.
6 changes: 2 additions & 4 deletions applications/tari_base_node/src/command_handler.rs
Expand Up @@ -123,10 +123,8 @@ impl CommandHandler {

self.executor.spawn(async move {
let mut status_line = StatusLine::new();
let version = format!("v{}", consts::APP_VERSION_NUMBER);
status_line.add_field("", version);
let network = format!("{}", config.network);
status_line.add_field("", network);
status_line.add_field("", format!("v{}", consts::APP_VERSION_NUMBER));
status_line.add_field("", config.network);
status_line.add_field("State", state_info.borrow().state_info.short_desc());

let metadata = node.get_metadata().await.unwrap();
Expand Down
3 changes: 1 addition & 2 deletions base_layer/wallet/src/connectivity_service/test.rs
Expand Up @@ -35,7 +35,6 @@ use tari_comms::{
mocks::{create_connectivity_mock, ConnectivityManagerMockState},
node_identity::build_node_identity,
},
Substream,
};
use tari_shutdown::Shutdown;
use tari_test_utils::runtime::spawn_until_shutdown;
Expand All @@ -46,7 +45,7 @@ use tokio::{

async fn setup() -> (
WalletConnectivityHandle,
MockRpcServer<MockRpcImpl, Substream>,
MockRpcServer<MockRpcImpl>,
ConnectivityManagerMockState,
Shutdown,
) {
Expand Down
3 changes: 1 addition & 2 deletions base_layer/wallet/tests/output_manager_service/service.rs
Expand Up @@ -36,7 +36,6 @@ use tari_comms::{
node_identity::build_node_identity,
},
types::CommsSecretKey,
Substream,
};
use tari_core::{
base_node::rpc::BaseNodeWalletRpcServer,
Expand Down Expand Up @@ -97,7 +96,7 @@ async fn setup_output_manager_service<T: OutputManagerBackend + 'static>(
OutputManagerHandle,
Shutdown,
TransactionServiceHandle,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
ConnectivityManagerMockState,
Expand Down
5 changes: 2 additions & 3 deletions base_layer/wallet/tests/transaction_service/service.rs
Expand Up @@ -72,7 +72,6 @@ use tari_comms::{
},
types::CommsSecretKey,
CommsNode,
Substream,
};
use tari_comms_dht::outbound::mock::{
create_outbound_service_mock,
Expand Down Expand Up @@ -244,7 +243,7 @@ pub fn setup_transaction_service_no_comms(
Sender<DomainMessage<base_node_proto::BaseNodeServiceResponse>>,
Sender<DomainMessage<proto::TransactionCancelledMessage>>,
Shutdown,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
) {
Expand All @@ -268,7 +267,7 @@ pub fn setup_transaction_service_no_comms_and_oms_backend(
Sender<DomainMessage<base_node_proto::BaseNodeServiceResponse>>,
Sender<DomainMessage<proto::TransactionCancelledMessage>>,
Shutdown,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
) {
Expand Down
Expand Up @@ -32,7 +32,6 @@ use tari_comms::{
},
types::CommsPublicKey,
NodeIdentity,
Substream,
};
use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServiceMockState};
use tari_core::{
Expand Down Expand Up @@ -96,7 +95,7 @@ pub async fn setup(
TransactionServiceResources<TransactionServiceSqliteDatabase>,
ConnectivityManagerMockState,
OutboundServiceMockState,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
broadcast::Sender<Duration>,
Expand Down
6 changes: 3 additions & 3 deletions comms/rpc_macros/src/generator.rs
Expand Up @@ -194,15 +194,15 @@ impl RpcCodeGenerator {
.collect::<TokenStream>();

let client_struct_body = quote! {
pub async fn connect<TSubstream>(framed: #dep_mod::CanonicalFraming<TSubstream>) -> Result<Self, #dep_mod::RpcError>
where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + 'static {
pub async fn connect(framed: #dep_mod::CanonicalFraming<#dep_mod::Substream>) -> Result<Self, #dep_mod::RpcError> {
use #dep_mod::NamedProtocolService;
let inner = #dep_mod::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?;
Ok(Self { inner })
}

pub fn builder() -> #dep_mod::RpcClientBuilder<Self> {
#dep_mod::RpcClientBuilder::new()
use #dep_mod::NamedProtocolService;
#dep_mod::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into())
}

#client_methods
Expand Down
2 changes: 1 addition & 1 deletion comms/rpc_macros/src/lib.rs
Expand Up @@ -12,7 +12,7 @@ mod options;
///
/// Generates Tari RPC "harness code" for a given trait.
///
/// ```no_run
/// ```no_run,ignore
/// # use tari_comms_rpc_macros::tari_rpc;
/// # use tari_comms::protocol::rpc::{Request, Streaming, Response, RpcStatus, RpcServer};
/// use tari_comms::{framing, memsocket::MemorySocket};
Expand Down
11 changes: 7 additions & 4 deletions comms/rpc_macros/tests/macro.rs
Expand Up @@ -25,12 +25,12 @@ use prost::Message;
use std::{collections::HashMap, ops::AddAssign, sync::Arc};
use tari_comms::{
framing,
memsocket::MemorySocket,
message::MessageExt,
protocol::{
rpc,
rpc::{NamedProtocolService, Request, Response, RpcStatus, RpcStatusCode, Streaming},
},
test_utils::transport::build_multiplexed_connections,
};
use tari_comms_rpc_macros::tari_rpc;
use tari_test_utils::unpack_enum;
Expand Down Expand Up @@ -152,9 +152,12 @@ async fn it_returns_an_error_for_invalid_method_nums() {

#[tokio::test]
async fn it_generates_client_calls() {
let (sock_client, sock_server) = MemorySocket::new_pair();
let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024)));
let mut sock_server = framing::canonical(sock_server, 1024);
let (_, sock_client, mut sock_server) = build_multiplexed_connections().await;
let client = task::spawn(TestClient::connect(framing::canonical(
sock_client.get_yamux_control().open_stream().await.unwrap(),
1024,
)));
let mut sock_server = framing::canonical(sock_server.incoming_mut().next().await.unwrap(), 1024);
let mut handshake = rpc::Handshake::new(&mut sock_server);
handshake.perform_server_handshake().await.unwrap();
// Wait for client to connect
Expand Down
37 changes: 29 additions & 8 deletions comms/src/memsocket/mod.rs
Expand Up @@ -30,6 +30,7 @@ use futures::{
stream::{FusedStream, Stream},
task::{Context, Poll},
};
use log::*;
use std::{
cmp,
collections::{hash_map::Entry, HashMap},
Expand Down Expand Up @@ -433,6 +434,7 @@ impl AsyncRead for MemorySocket {
buf.advance(bytes_to_read);

current_buffer.advance(bytes_to_read);
trace!("reading {} bytes", bytes_to_read);

bytes_read += bytes_to_read;
}
Expand Down Expand Up @@ -462,11 +464,12 @@ impl AsyncRead for MemorySocket {

impl AsyncWrite for MemorySocket {
/// Attempt to write bytes from `buf` into the outgoing channel.
fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let len = buf.len();

match self.outgoing.poll_ready(context) {
match self.outgoing.poll_ready(cx) {
Poll::Ready(Ok(())) => {
trace!("writing {} bytes", len);
if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) {
if e.is_disconnected() {
return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e)));
Expand All @@ -475,6 +478,7 @@ impl AsyncWrite for MemorySocket {
// Unbounded channels should only ever have "Disconnected" errors
unreachable!();
}
Poll::Ready(Ok(len))
},
Poll::Ready(Err(e)) => {
if e.is_disconnected() {
Expand All @@ -484,19 +488,18 @@ impl AsyncWrite for MemorySocket {
// Unbounded channels should only ever have "Disconnected" errors
unreachable!();
},
Poll::Pending => return Poll::Pending,
Poll::Pending => Poll::Pending,
}

Poll::Ready(Ok(len))
}

/// Attempt to flush the channel. Cannot Fail.
fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
trace!("flush");
Poll::Ready(Ok(()))
}

/// Attempt to close the channel. Cannot Fail.
fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
self.outgoing.close_channel();

Poll::Ready(Ok(()))
Expand All @@ -506,7 +509,8 @@ impl AsyncWrite for MemorySocket {
#[cfg(test)]
mod test {
use super::*;
use crate::runtime;
use crate::{framing, runtime};
use futures::SinkExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_stream::StreamExt;

Expand Down Expand Up @@ -705,4 +709,21 @@ mod test {

Ok(())
}

#[runtime::test]
async fn read_and_write_canonical_framing() -> io::Result<()> {
let (a, b) = MemorySocket::new_pair();
let mut a = framing::canonical(a, 1024);
let mut b = framing::canonical(b, 1024);

a.send(Bytes::from_static(b"frame-1")).await?;
b.send(Bytes::from_static(b"frame-2")).await?;
let msg = b.next().await.unwrap()?;
assert_eq!(&msg[..], b"frame-1");

let msg = a.next().await.unwrap()?;
assert_eq!(&msg[..], b"frame-2");

Ok(())
}
}
42 changes: 29 additions & 13 deletions comms/src/multiplexing/yamux.rs
Expand Up @@ -166,7 +166,7 @@ pub struct IncomingSubstreams {
}

impl IncomingSubstreams {
pub fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
pub(self) fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
Self {
inner,
substream_counter,
Expand Down Expand Up @@ -205,6 +205,12 @@ pub struct Substream {
counter_guard: CounterGuard,
}

impl Substream {
    /// Returns the yamux stream id of this substream.
    ///
    /// Used to enrich RPC session/substream tracing logs so a log line can be
    /// correlated with a specific multiplexed stream on a connection.
    pub fn id(&self) -> yamux::StreamId {
        // Delegate to the underlying yamux stream (behind the compat wrapper).
        self.stream.get_ref().id()
    }
}

impl tokio::io::AsyncRead for Substream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
Expand Down Expand Up @@ -242,13 +248,17 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
}
}

#[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))]
#[tracing::instrument(name = "yamux::incoming_worker::run", skip(self), fields(connection = %self.connection))]
pub async fn run(mut self) {
loop {
tokio::select! {
biased;

_ = &mut self.shutdown_signal => {
_ = self.shutdown_signal.wait() => {
debug!(
target: LOG_TARGET,
"{} Yamux connection shutdown", self.connection
);
let mut control = self.connection.control();
if let Err(err) = control.close().await {
error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err);
Expand All @@ -259,31 +269,37 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
result = self.connection.next_stream() => {
match result {
Ok(Some(stream)) => {
event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() {
event!(Level::TRACE, "yamux::incoming_worker::new_stream {}", stream);
if self.sender.send(stream).await.is_err() {
debug!(
target: LOG_TARGET,
"Incoming peer substream task is shutting down because the internal stream sender channel \
was closed"
"{} Incoming peer substream task is shutting down because the internal stream sender channel \
was closed",
self.connection
);
break;
}
},
Ok(None) =>{
debug!(
target: LOG_TARGET,
"Incoming peer substream completed. IncomingWorker exiting"
"{} Incoming peer substream completed. IncomingWorker exiting",
self.connection
);
break;
}
Err(err) => {
event!(
Level::ERROR,
"Incoming peer substream task received an error because '{}'",
err
);
error!(
Level::ERROR,
"{} Incoming peer substream task received an error because '{}'",
self.connection,
err
);
error!(
target: LOG_TARGET,
"Incoming peer substream task received an error because '{}'", err
"{} Incoming peer substream task received an error because '{}'",
self.connection,
err
);
break;
},
Expand Down

0 comments on commit 336f4d6

Please sign in to comment.