Skip to content

Commit

Permalink
fix(ws server): fix shutdown on connection closed (#1103)
Browse files Browse the repository at this point in the history
* fix(ws server): fix flaky shutdown test

* Update server/src/transport/ws.rs

* Update server/src/transport/ws.rs

* fix interval stream bug

* Update server/src/transport/ws.rs

* Update server/src/transport/ws.rs

* Update server/src/transport/ws.rs

* Update server/src/transport/ws.rs

* check conn_tx.closed as well

* add more tests + cleanup

* fix nit

* Update server/src/tests/ws.rs

* add comment in weird test

* rewrite tests without sleeps

* remove needless result

* fix compile warn
  • Loading branch information
niklasad1 committed Apr 27, 2023
1 parent 3cb95de commit 457d2d2
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 45 deletions.
2 changes: 1 addition & 1 deletion server/src/server.rs
Expand Up @@ -654,7 +654,7 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
ws_builder.set_max_message_size(data.max_request_body_size as usize);
let (sender, receiver) = ws_builder.finish();

let _ = ws::background_task::<L>(sender, receiver, data).await;
ws::background_task::<L>(sender, receiver, data).await;
}
.in_current_span(),
);
Expand Down
86 changes: 69 additions & 17 deletions server/src/tests/ws.rs
Expand Up @@ -24,6 +24,8 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::time::Duration;

use crate::server::BatchRequestConfig;
use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::types::SubscriptionId;
Expand Down Expand Up @@ -815,37 +817,87 @@ async fn notif_is_ignored() {
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
async fn close_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);
const CONCURRENT_CALLS: usize = 10;
init_logger();

let (handle, addr) = {
let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();

let mut module = RpcModule::new(());

module
.register_async_method("infinite_call", |_, _| async move {
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
};
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();

let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}

// Assert that the server has received the calls.
for _ in 0..CONCURRENT_CALLS {
assert!(rx.recv().await.is_some());
}

client.close().await.unwrap();
assert!(client.receive().await.is_err());

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

/// Dropping a client that still has calls in flight must not prevent the server
/// from shutting down.
///
/// The test fires `CONCURRENT_CALLS` requests at a method that never completes,
/// waits until the server has signalled that every call was started, drops the
/// client and then checks that stopping the server finishes within `MAX_TIMEOUT`.
#[tokio::test]
async fn drop_client_with_pending_calls_works() {
    const CONCURRENT_CALLS: usize = 10;
    const MAX_TIMEOUT: Duration = Duration::from_secs(60);
    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;

    init_logger();

    let (call_started_tx, mut call_started_rx) = tokio::sync::mpsc::unbounded_channel();
    // The ping interval is ten times the shutdown timeout so that pings cannot
    // be what closes the connection during the test.
    let ping_interval = MAX_TIMEOUT.checked_mul(10).unwrap();
    let (server_handle, server_addr) = server_with_infinite_call(ping_interval, call_started_tx).await;

    // The client only lives inside this scope; it is dropped at the closing brace
    // while all of its calls are still pending on the server.
    {
        let mut ws_client = WebSocketTestClient::new(server_addr).with_default_timeout().await.unwrap().unwrap();

        for _ in 0..CONCURRENT_CALLS {
            ws_client.send(REQUEST).with_default_timeout().await.unwrap().unwrap();
        }

        // Each started call sends one message on the channel; wait for all of them
        // so the calls are known to be in flight before the client goes away.
        for _ in 0..CONCURRENT_CALLS {
            let started = call_started_rx.recv().await;
            assert!(started.is_some());
        }
    }

    // Stop the server and ensure that it doesn't wait for the pending futures
    // to complete when the connection has already been closed.
    server_handle.stop().unwrap();
    let stopped = server_handle.stopped().with_timeout(MAX_TIMEOUT).await;
    assert!(stopped.is_ok());
}

async fn server_with_infinite_call(
timeout: Duration,
tx: tokio::sync::mpsc::UnboundedSender<()>,
) -> (crate::ServerHandle, std::net::SocketAddr) {
let server = ServerBuilder::default()
// Make sure that the ping_interval doesn't force the connection to be closed
.ping_interval(timeout)
.build("127.0.0.1:0")
.with_default_timeout()
.await
.unwrap()
.unwrap();

let mut module = RpcModule::new(tx);

module
.register_async_method("infinite_call", |_, mut ctx| async move {
let tx = std::sync::Arc::make_mut(&mut ctx);
tx.send(()).unwrap();
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
}
98 changes: 71 additions & 27 deletions server/src/transport/ws.rs
Expand Up @@ -226,11 +226,7 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData
response
}

pub(crate) async fn background_task<L: Logger>(
sender: Sender,
mut receiver: Receiver,
svc: ServiceData<L>,
) -> Result<(), Error> {
pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Receiver, svc: ServiceData<L>) {
let ServiceData {
methods,
max_request_body_size,
Expand All @@ -250,17 +246,17 @@ pub(crate) async fn background_task<L: Logger>(
} = svc;

let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let (mut conn_tx, conn_rx) = oneshot::channel();
let (conn_tx, conn_rx) = oneshot::channel();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection);
let pending_calls = FuturesUnordered::new();

// Spawn another task that sends out the responses on the Websocket.
tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let stopped = stop_handle.shutdown();
let stopped = stop_handle.clone().shutdown();

tokio::pin!(stopped);

Expand All @@ -272,11 +268,11 @@ pub(crate) async fn background_task<L: Logger>(
stopped = stop;
permit
}
None => break Ok(()),
None => break Ok(Shutdown::ConnectionClosed),
};

match try_recv(&mut receiver, &mut data, stopped).await {
Receive::Shutdown => break Ok(()),
Receive::Shutdown => break Ok(Shutdown::Stopped),
Receive::Ok(stop) => {
stopped = stop;
}
Expand All @@ -286,7 +282,7 @@ pub(crate) async fn background_task<L: Logger>(
match err {
SokettoError::Closed => {
tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
break Ok(());
break Ok(Shutdown::ConnectionClosed);
}
SokettoError::MessageTooLarge { current, maximum } => {
tracing::debug!(
Expand All @@ -300,7 +296,7 @@ pub(crate) async fn background_task<L: Logger>(
}
err => {
tracing::debug!("WS transport error: {}; terminate connection: {}", err, conn_id);
break Err(err.into());
break Err(err);
}
};
}
Expand All @@ -326,22 +322,11 @@ pub(crate) async fn background_task<L: Logger>(
// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
//
// This is not strictly not needed because `tokio::spawn` will drive these the completion
// but it's preferred that the `stop_handle.stopped()` should not return until all methods has been
// executed and the connection has been closed.
tokio::select! {
// All pending calls executed.
_ = pending_calls.for_each(|_| async {}) => {
_ = conn_tx.send(());
}
// The connection was closed, no point of waiting for the pending calls.
_ = conn_tx.closed() => {}
}
graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
result
drop(stop_handle);
}

/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`.
Expand All @@ -352,7 +337,11 @@ async fn send_task(
stop: oneshot::Receiver<()>,
) {
// Interval to send out continuously `pings`.
let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
let mut ping_interval = tokio::time::interval(ping_interval);
// This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
Expand Down Expand Up @@ -384,15 +373,18 @@ async fn send_task(
}

// Handle timer intervals.
Either::Right((Either::Left((_, stop)), next_rx)) => {
Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => {
if let Err(err) = send_ping(&mut ws_sender).await {
tracing::debug!("WS transport error: send ping failed: {}", err);
break;
}

rx_item = next_rx;
futs = future::select(ping_interval.next(), stop);
}

Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"),

// Server is stopped.
Either::Right((Either::Right(_), _)) => {
break;
Expand Down Expand Up @@ -558,3 +550,55 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
}
};
}

/// Why the receive loop of a WebSocket connection stopped without an error.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Shutdown {
    /// The loop ended because a shutdown was signalled while waiting for data.
    Stopped,
    /// The loop ended because the connection itself was closed.
    ConnectionClosed,
}

/// Enforce a graceful shutdown.
///
/// This will return once the connection has been terminated or all pending calls have been executed,
/// and the "send task" has finished.
///
/// - `result`: why the receive loop ended. If the connection is already closed
///   (`Shutdown::ConnectionClosed` or `SokettoError::Closed`) nothing is awaited
///   before the send task is shut down.
/// - `pending_calls`: the calls that are still in flight; they are polled to completion
///   unless the connection goes away first.
/// - `receiver` / `data`: the receiving half of the connection and its read buffer, used
///   only to detect that the peer has closed the connection.
/// - `conn_tx`: signals the send task to stop; it is also used to detect that the send task
///   has already dropped its receiving end.
/// - `send_task_handle`: awaited last so that the send task has finished before returning.
async fn graceful_shutdown<F: Future>(
    result: Result<Shutdown, SokettoError>,
    pending_calls: FuturesUnordered<F>,
    receiver: Receiver,
    data: Vec<u8>,
    mut conn_tx: oneshot::Sender<()>,
    send_task_handle: tokio::task::JoinHandle<()>,
) {
    match result {
        // The connection is already gone, there is nobody left to answer.
        Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (),
        Ok(Shutdown::Stopped) | Err(_) => {
            // Soketto doesn't have a way to signal when the connection is closed
            // thus just throw away the data and terminate the stream once the connection has
            // been terminated.
            //
            // The receiver is not cancel-safe such that it's used in a stream to enforce that.
            //
            // NOTE(review): any outcome other than `SokettoError::Closed` (including other
            // errors) keeps the stream alive and triggers another `receive`.
            let disconnect_stream =
                futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async move {
                    match receiver.receive(&mut data).await {
                        Err(SokettoError::Closed) => None,
                        _ => {
                            // The payload is never looked at; drop it right away so the buffer
                            // does not keep growing with every message received during shutdown.
                            data.clear();
                            Some(((), (receiver, data)))
                        }
                    }
                });

            let graceful_shutdown = pending_calls.for_each(|_| async {});
            let disconnect = disconnect_stream.for_each(|_| async {});

            // Either all pending calls have finished, the peer closed the connection,
            // or the send task dropped its end of the channel: fine to terminate.
            tokio::select! {
                _ = graceful_shutdown => {}
                _ = disconnect => {}
                _ = conn_tx.closed() => {}
            }
        }
    };

    // Send a message to close down the "send task".
    // Failure only means the send task is already gone, which is the desired end state.
    _ = conn_tx.send(());
    // Ensure that send task has been closed.
    _ = send_task_handle.await;
}

0 comments on commit 457d2d2

Please sign in to comment.