Skip to content

Commit

Permalink
fix(comms/rpc): detect early close in all cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbondi committed Sep 8, 2022
1 parent dffea23 commit 8bf1d5f
Show file tree
Hide file tree
Showing 12 changed files with 426 additions and 104 deletions.
4 changes: 3 additions & 1 deletion applications/tari_base_node/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl<B> BaseNodeBootstrapper<'_, B>
where B: BlockchainBackend + 'static
{
pub async fn bootstrap(self) -> Result<ServiceHandles, ExitError> {
let base_node_config = &self.app_config.base_node;
let mut base_node_config = self.app_config.base_node.clone();
let mut p2p_config = self.app_config.base_node.p2p.clone();
let peer_seeds = &self.app_config.peer_seeds;

Expand All @@ -95,6 +95,8 @@ where B: BlockchainBackend + 'static
.collect::<Result<Vec<_>, _>>()
.map_err(|e| ExitError::new(ExitCode::ConfigError, e))?;

base_node_config.state_machine.blockchain_sync_config.forced_sync_peers = sync_peers.clone();

debug!(target: LOG_TARGET, "{} sync peer(s) configured", sync_peers.len());

let mempool_sync = MempoolSyncInitializer::new(mempool_config, self.mempool.clone());
Expand Down
8 changes: 4 additions & 4 deletions base_layer/core/src/base_node/sync/rpc/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use tari_comms::{
};
use tari_utilities::hex::Hex;
use tokio::{
sync::{mpsc, RwLock},
sync::{mpsc, Mutex},
task,
};
use tracing::{instrument, span, Instrument, Level};
Expand Down Expand Up @@ -65,15 +65,15 @@ const LOG_TARGET: &str = "c::base_node::sync_rpc";

pub struct BaseNodeSyncRpcService<B> {
db: AsyncBlockchainDb<B>,
active_sessions: RwLock<Vec<Weak<NodeId>>>,
active_sessions: Mutex<Vec<Weak<NodeId>>>,
base_node_service: LocalNodeCommsInterface,
}

impl<B: BlockchainBackend + 'static> BaseNodeSyncRpcService<B> {
pub fn new(db: AsyncBlockchainDb<B>, base_node_service: LocalNodeCommsInterface) -> Self {
Self {
db,
active_sessions: RwLock::new(Vec::new()),
active_sessions: Mutex::new(Vec::new()),
base_node_service,
}
}
Expand All @@ -84,7 +84,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncRpcService<B> {
}

pub async fn try_add_exclusive_session(&self, peer: NodeId) -> Result<Arc<NodeId>, RpcStatus> {
let mut lock = self.active_sessions.write().await;
let mut lock = self.active_sessions.lock().await;
*lock = lock.drain(..).filter(|l| l.strong_count() > 0).collect();
debug!(target: LOG_TARGET, "Number of active sync sessions: {}", lock.len());

Expand Down
8 changes: 5 additions & 3 deletions base_layer/p2p/src/services/liveness/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ impl LivenessState {

let (node_id, _) = self.inflight_pings.get(&nonce)?;
if node_id == sent_by {
self.inflight_pings
.remove(&nonce)
.map(|(node_id, sent_time)| self.add_latency_sample(node_id, sent_time.elapsed()).calc_average())
self.inflight_pings.remove(&nonce).map(|(node_id, sent_time)| {
let latency = sent_time.elapsed();
self.add_latency_sample(node_id, latency);
latency
})
} else {
warn!(
target: LOG_TARGET,
Expand Down
80 changes: 53 additions & 27 deletions comms/core/src/protocol/rpc/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use std::{

use bytes::Bytes;
use futures::{
future,
future::{BoxFuture, Either},
task::{Context, Poll},
FutureExt,
Expand Down Expand Up @@ -491,7 +492,10 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
break;
}
}
None => break,
None => {
debug!(target: LOG_TARGET, "(stream={}) Request channel closed. Worker is terminating.", self.stream_id());
break
},
}
}
}
Expand Down Expand Up @@ -618,7 +622,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
);
}

let (response_tx, response_rx) = mpsc::channel(10);
let (response_tx, response_rx) = mpsc::channel(5);
if let Err(mut rx) = reply.send(response_rx) {
event!(Level::WARN, "Client request was cancelled after request was sent");
warn!(
Expand All @@ -636,7 +640,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
if let Err(err) = self.send_request(req).await {
warn!(target: LOG_TARGET, "{}", err);
metrics::client_errors(&self.node_id, &self.protocol_id).inc();
let _result = response_tx.send(Err(err.into()));
let _result = response_tx.send(Err(err.into())).await;
return Ok(());
}

Expand All @@ -654,7 +658,27 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
break;
}

let resp = match self.read_response(request_id).await {
// Check if the response receiver has been dropped while receiving messages
let resp_result = {
let resp_fut = self.read_response(request_id);
tokio::pin!(resp_fut);
let closed_fut = response_tx.closed();
tokio::pin!(closed_fut);
match future::select(resp_fut, closed_fut).await {
Either::Left((r, _)) => Some(r),
Either::Right(_) => None,
}
};
let resp_result = match resp_result {
Some(r) => r,
None => {
self.premature_close(request_id, method).await?;
break;
},
};

// let resp = match self.read_response(request_id).await {
let resp = match resp_result {
Ok(resp) => {
if let Some(t) = timer.take() {
let _ = self.last_request_latency_tx.send(Some(t.elapsed()));
Expand Down Expand Up @@ -682,14 +706,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
event!(Level::ERROR, "Response timed out");
metrics::client_timeouts(&self.node_id, &self.protocol_id).inc();
if response_tx.is_closed() {
let req = proto::rpc::RpcRequest {
request_id: u32::try_from(request_id).unwrap(),
method,
flags: RpcMessageFlags::FIN.bits().into(),
..Default::default()
};

self.send_request(req).await?;
self.premature_close(request_id, method).await?;
} else {
let _result = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await;
}
Expand Down Expand Up @@ -721,21 +738,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
// The consumer may drop the receiver before all responses are received.
// We handle this by sending a 'FIN' message to the server.
if response_tx.is_closed() {
warn!(
target: LOG_TARGET,
"(stream={}) Response receiver was dropped before the response/stream could complete for \
protocol {}, interrupting the stream. ",
self.stream_id(),
self.protocol_name()
);
let req = proto::rpc::RpcRequest {
request_id: u32::try_from(request_id).unwrap(),
method,
flags: RpcMessageFlags::FIN.bits().into(),
..Default::default()
};

self.send_request(req).await?;
self.premature_close(request_id, method).await?;
break;
} else {
let _result = response_tx.send(Ok(resp)).await;
Expand Down Expand Up @@ -766,6 +769,29 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
Ok(())
}

async fn premature_close(&mut self, request_id: u16, method: u32) -> Result<(), RpcError> {
warn!(
target: LOG_TARGET,
"(stream={}) Response receiver was dropped before the response/stream could complete for protocol {}, \
interrupting the stream. ",
self.stream_id(),
self.protocol_name()
);
let req = proto::rpc::RpcRequest {
request_id: u32::try_from(request_id).unwrap(),
method,
flags: RpcMessageFlags::FIN.bits().into(),
deadline: self.config.deadline.map(|d| d.as_secs()).unwrap_or(0),
..Default::default()
};

// If we cannot set FIN quickly, just exit
if let Ok(res) = time::timeout(Duration::from_secs(2), self.send_request(req)).await {
res?;
}
Ok(())
}

async fn send_request(&mut self, req: proto::rpc::RpcRequest) -> Result<(), RpcError> {
let payload = req.to_encoded_bytes();
if payload.len() > rpc::max_request_size() {
Expand Down
119 changes: 119 additions & 0 deletions comms/core/src/protocol/rpc/server/early_close.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright 2022. The Tari Project
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
// following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
// disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
// following disclaimer in the documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
// products derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{
io,
pin::Pin,
task::{Context, Poll},
};

use futures::Sink;
use tokio_stream::Stream;

pub struct EarlyClose<TSock> {
inner: TSock,
}

impl<T, TSock: Stream<Item = io::Result<T>> + Unpin> EarlyClose<TSock> {
pub fn new(inner: TSock) -> Self {
Self { inner }
}
}

impl<TSock: Stream + Unpin> Stream for EarlyClose<TSock> {
type Item = TSock::Item;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}

impl<TItem, TSock, T> Sink<TItem> for EarlyClose<TSock>
where TSock: Sink<TItem, Error = io::Error> + Stream<Item = io::Result<T>> + Unpin
{
type Error = EarlyCloseError<T>;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_ready(cx) {
return Poll::Ready(r.map_err(Into::into));
}
check_for_early_close(&mut self.inner, cx)
}

fn start_send(mut self: Pin<&mut Self>, item: TItem) -> Result<(), Self::Error> {
Pin::new(&mut self.inner).start_send(item)?;
Ok(())
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_flush(cx) {
return Poll::Ready(r.map_err(Into::into));
}
check_for_early_close(&mut self.inner, cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Poll::Ready(r) = Pin::new(&mut self.inner).poll_close(cx) {
return Poll::Ready(r.map_err(Into::into));
}
check_for_early_close(&mut self.inner, cx)
}
}

fn check_for_early_close<T, TSock: Stream<Item = io::Result<T>> + Unpin>(
sock: &mut TSock,
cx: &mut Context<'_>,
) -> Poll<Result<(), EarlyCloseError<T>>> {
match Pin::new(sock).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => Poll::Ready(Err(EarlyCloseError::UnexpectedMessage(msg))),
Poll::Ready(Some(Err(err))) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err.into())),
Poll::Ready(None) => Poll::Ready(Err(
io::Error::new(io::ErrorKind::BrokenPipe, "Connection closed").into()
)),
}
}

#[derive(Debug, thiserror::Error)]
pub enum EarlyCloseError<T> {
#[error(transparent)]
Io(#[from] io::Error),
#[error("Unexpected message")]
UnexpectedMessage(T),
}

impl<T> EarlyCloseError<T> {
pub fn io(&self) -> Option<&io::Error> {
match self {
Self::Io(err) => Some(err),
_ => None,
}
}

pub fn unexpected_message(&self) -> Option<&T> {
match self {
EarlyCloseError::UnexpectedMessage(msg) => Some(msg),
_ => None,
}
}
}
9 changes: 8 additions & 1 deletion comms/core/src/protocol/rpc/server/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@

use std::io;

use bytes::BytesMut;
use prost::DecodeError;
use tokio::sync::oneshot;

use crate::{peer_manager::NodeId, proto, protocol::rpc::handshake::RpcHandshakeError};
use crate::{
peer_manager::NodeId,
proto,
protocol::rpc::{handshake::RpcHandshakeError, server::early_close::EarlyCloseError},
};

#[derive(Debug, thiserror::Error)]
pub enum RpcServerError {
Expand Down Expand Up @@ -55,6 +60,8 @@ pub enum RpcServerError {
ServiceCallExceededDeadline,
#[error("Stream read exceeded deadline")]
ReadStreamExceededDeadline,
#[error("Early close error: {0}")]
EarlyCloseError(#[from] EarlyCloseError<BytesMut>),
}

impl From<oneshot::error::RecvError> for RpcServerError {
Expand Down

0 comments on commit 8bf1d5f

Please sign in to comment.