fix(rpc)!: read from substream while streaming to check for interruptions (#3548)

Description
---
- fixes an issue where long-running streams continue long after the client has dropped the stream
- detects stream interruptions on the server side and ends the RPC stream
- (breaking change) the client now sends a protocol message back to explicitly end the stream (see the sketch below this list)
- checks for shutdown signal termination on the client side while reading the stream
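
A minimal, self-contained sketch of that explicit end-of-stream message, for illustration only: `RpcRequest` and the `FIN` bit below are stand-ins (with a hypothetical bit value) for the generated `proto::rpc::RpcRequest` type and `RpcMessageFlags::FIN` used in the diff further down.

```rust
/// Stand-in for the generated proto::rpc::RpcRequest message; field names are
/// assumed to mirror the ones used in comms/src/protocol/rpc/client.rs below.
#[derive(Debug)]
struct RpcRequest {
    request_id: u32,
    method: u32,
    deadline: u64,
    flags: u32,
    payload: Vec<u8>,
}

/// Hypothetical bit value; the real one comes from RpcMessageFlags::FIN.
const FIN: u32 = 0b0000_0100;

/// Build the empty FIN-flagged request the client sends to tell the server
/// it will not read any more of the stream identified by `request_id`.
fn fin_request(request_id: u32, method: u32) -> RpcRequest {
    RpcRequest {
        request_id,
        method,
        deadline: 0,     // a termination message does not need a deadline
        flags: FIN,      // FIN: explicitly end the stream
        payload: vec![], // no body
    }
}

fn main() {
    println!("{:?}", fin_request(1, 0));
}
```

On the server side, a decoded request carrying the FIN flag is treated as "stop streaming", as the change in `server/mod.rs` shows.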

Motivation and Context
---
Streams should terminate as soon as they are no longer used. This was a particular problem when
streaming UTXOs to wallets and pruned nodes: when erroring out and retrying, new streams
would be established while the old streams continued until completion.

The main trick was to read from the yamux stream just before writing to it, without blocking if
there was nothing to read. This allows yamux to receive the close-stream message.
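
As an illustration of that trick, here is a minimal sketch (an assumption-laden stand-in, not the code in this PR) that polls the substream exactly once and treats any ready result as an interruption; the actual change is the `check_interruptions` method added to `comms/src/protocol/rpc/server/mod.rs` below, which also distinguishes an unexpected message from an error or a closed stream.

```rust
use futures::{future, Stream, StreamExt};
use std::task::Poll;

/// Poll the framed substream a single time without awaiting, so a Pending read
/// never blocks the outgoing stream. Returns true if the client sent something,
/// the substream errored, or it was closed — i.e. the server should stop streaming.
async fn client_interrupted<S>(framed: &mut S) -> bool
where
    S: Stream + Unpin,
{
    future::poll_fn(|cx| match framed.poll_next_unpin(cx) {
        // A message, an error, or end-of-stream: the client is done with this stream.
        Poll::Ready(_) => Poll::Ready(true),
        // Nothing buffered on the substream: keep streaming, do not wait here.
        Poll::Pending => Poll::Ready(false),
    })
    .await
}
```

A server-side send loop would call this just before writing each message and break out of the loop when it returns `true`.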

How Has This Been Tested?
---
New unit test
Manually: added code to forcefully error out of pruned sync after starting the stream. The node had another local node as a forced_sync_peer and so would continuously retry that peer. This caused many RPC sessions (up to 30 observed) to accumulate, along with many concurrent streams. With these changes, the number of sessions in use for that peer stays between 0 and 1.
sdbondi committed Nov 8, 2021
1 parent e17ee64 commit 9194501
Showing 5 changed files with 135 additions and 13 deletions.
57 changes: 48 additions & 9 deletions comms/src/protocol/rpc/client.rs
@@ -587,13 +587,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
self.protocol_name(),
);
rx.close();
// RPC is strictly request/response
// If the client drops the RpcClient request at this point, we have two options:
// 1. Obey the protocol: receive the response
// 2. Error out and immediately close the session (seems brittle and may be unexpected)
// Option 1 has the disadvantage when receiving large/many streamed responses, however if all client handles
// have been dropped, then read_reply will exit early, the stream will close and the server-side
// can exit early
return Ok(());
}

if let Err(err) = self.send_request(req).await {
@@ -603,6 +597,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
}

loop {
if self.shutdown_signal.is_triggered() {
debug!(
target: LOG_TARGET,
"[{}, stream_id: {}, req_id: {}] Client connector closed. Quitting stream early",
self.protocol_name(),
self.stream_id(),
request_id
);
break;
}

let resp = match self.read_response(request_id).await {
Ok(resp) => {
let latency = start.elapsed();
@@ -667,10 +672,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
warn!(
target: LOG_TARGET,
"(stream={}) Response receiver was dropped before the response/stream could complete for \
protocol {}, the stream will continue until completed",
protocol {}, interrupting the stream. ",
self.stream_id(),
self.protocol_name()
);
let req = proto::rpc::RpcRequest {
request_id: request_id as u32,
method,
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
flags: RpcMessageFlags::FIN.bits().into(),
payload: vec![],
};

self.send_request(req).await?;
break;
} else {
let _ = response_tx.send(Ok(resp)).await;
}
@@ -714,7 +729,31 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId

async fn read_response(&mut self, request_id: u16) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id);
let resp = reader.read_response().await?;

let mut num_ignored = 0;
let resp = loop {
match reader.read_response().await {
Ok(resp) => break resp,
Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected })
if actual.saturating_add(1) == request_id =>
{
warn!(
target: LOG_TARGET,
"Possible delayed response received for previous request {}", actual
);
num_ignored += 1;

// Be lenient for a number of messages that may have been buffered to come through for the previous
// request.
const MAX_ALLOWED_IGNORED: usize = 5;
if num_ignored > MAX_ALLOWED_IGNORED {
return Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected });
}
continue;
}
Err(err) => return Err(err),
}
};
Ok(resp)
}

4 changes: 4 additions & 0 deletions comms/src/protocol/rpc/server/error.rs
@@ -35,10 +35,14 @@ pub enum RpcServerError {
MaximumSessionsReached,
#[error("Internal service request canceled")]
RequestCanceled,
#[error("Stream was closed by remote")]
StreamClosedByRemote,
#[error("Handshake error: {0}")]
HandshakeError(#[from] RpcHandshakeError),
#[error("Service not found for protocol `{0}`")]
ProtocolServiceNotFound(String),
#[error("Unexpected incoming message")]
UnexpectedIncomingMessage,
}

impl From<oneshot::error::RecvError> for RpcServerError {
32 changes: 30 additions & 2 deletions comms/src/protocol/rpc/server/mod.rs
@@ -63,15 +63,18 @@ use crate::{
Bytes,
Substream,
};
use futures::{stream, SinkExt, StreamExt};
use futures::{future, stream, SinkExt, StreamExt};
use prost::Message;
use std::{
borrow::Cow,
future::Future,
pin::Pin,
sync::Arc,
task::Poll,
time::{Duration, Instant},
};
use tokio::{sync::mpsc, time};
use tokio_stream::Stream;
use tower::Service;
use tower_make::MakeService;
use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level};
@@ -502,6 +505,11 @@ where
}

let msg_flags = RpcMessageFlags::from_bits_truncate(decoded_msg.flags as u8);

if msg_flags.contains(RpcMessageFlags::FIN) {
debug!(target: LOG_TARGET, "({}) Client sent FIN.", self.logging_context_string);
return Ok(());
}
if msg_flags.contains(RpcMessageFlags::ACK) {
debug!(
target: LOG_TARGET,
@@ -594,6 +602,12 @@ where
.map(|resp| Bytes::from(resp.to_encoded_bytes()));

loop {
// Check if the client interrupted the outgoing stream
if let Err(err) = self.check_interruptions().await {
warn!(target: LOG_TARGET, "{}", err);
break;
}

let next_item = log_timing(
self.logging_context_string.clone(),
request_id,
@@ -602,7 +616,7 @@
);
match time::timeout(deadline, next_item).await {
Ok(Some(msg)) => {
trace!(
debug!(
target: LOG_TARGET,
"({}) Sending body len = {}",
self.logging_context_string,
@@ -630,6 +644,20 @@
Ok(())
}

async fn check_interruptions(&mut self) -> Result<(), RpcServerError> {
let check = future::poll_fn(|cx| match Pin::new(&mut self.framed).poll_next(cx) {
Poll::Ready(Some(Ok(_))) => Poll::Ready(Some(RpcServerError::UnexpectedIncomingMessage)),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(RpcServerError::from(err))),
Poll::Ready(None) => Poll::Ready(Some(RpcServerError::StreamClosedByRemote)),
Poll::Pending => Poll::Ready(None),
})
.await;
match check {
Some(err) => Err(err),
None => Ok(()),
}
}

fn create_request_context(&self, request_id: u32) -> RequestContext {
RequestContext::new(request_id, self.node_id.clone(), Box::new(self.comms_provider.clone()))
}
5 changes: 4 additions & 1 deletion comms/src/protocol/rpc/test/greeting_service.rs
@@ -172,7 +172,10 @@ impl GreetingRpc for GreetingService {
tokio::spawn(async move {
for _ in 0..num_items {
time::sleep(Duration::from_millis(delay_ms)).await;
tx.send(Ok(item.clone())).await.unwrap();
if tx.send(Ok(item.clone())).await.is_err() {
log::info!("stream was interrupted");
break;
}
}
});

50 changes: 49 additions & 1 deletion comms/src/protocol/rpc/test/smoke.rs
@@ -62,6 +62,7 @@ use tari_test_utils::unpack_enum;
use tokio::{
sync::{mpsc, RwLock},
task,
time,
};

pub(super) async fn setup_service<T: GreetingRpc>(
@@ -389,7 +390,7 @@ async fn stream_still_works_after_cancel() {
// Request was sent
assert_eq!(service_impl.call_count(), 1);

// Subsequent call still works, after waiting for the previous one
// Subsequent call still works
let resp = client
.slow_stream(SlowStreamRequest {
num_items: 100,
@@ -403,3 +404,50 @@
r.unwrap();
});
}

#[runtime::test]
async fn stream_interruption_handling() {
let service_impl = GreetingService::default();
let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await;
let socket = muxer.incoming_mut().next().await.unwrap();

let framed = framing::canonical(socket, 1024);
let mut client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();

let mut resp = client
.slow_stream(SlowStreamRequest {
num_items: 10000,
item_size: 100,
delay_ms: 100,
})
.await
.unwrap();

let _ = resp.next().await.unwrap().unwrap();
// Drop it before the stream is finished
drop(resp);

// Subsequent call still works, without waiting
let mut resp = client
.slow_stream(SlowStreamRequest {
num_items: 100,
item_size: 100,
delay_ms: 1,
})
.await
.unwrap();

let next_fut = resp.next();
tokio::pin!(next_fut);
// Allow 10 seconds, if the previous stream is still streaming, it will take a while for this stream to start and
// the timeout will expire
time::timeout(Duration::from_secs(10), next_fut)
.await
.unwrap()
.unwrap()
.unwrap();
}
