diff --git a/server/src/server.rs b/server/src/server.rs index eecde7abae..9f65b2d169 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -654,7 +654,7 @@ impl hyper::service::Service> for TowerSe ws_builder.set_max_message_size(data.max_request_body_size as usize); let (sender, receiver) = ws_builder.finish(); - let _ = ws::background_task::(sender, receiver, data).await; + ws::background_task::(sender, receiver, data).await; } .in_current_span(), ); diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index e77e7dffe0..cf220cb375 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -24,6 +24,8 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::time::Duration; + use crate::server::BatchRequestConfig; use crate::tests::helpers::{deser_call, init_logger, server_with_context}; use crate::types::SubscriptionId; @@ -815,25 +817,14 @@ async fn notif_is_ignored() { } #[tokio::test] -async fn drop_client_with_pending_calls_works() { +async fn close_client_with_pending_calls_works() { + const MAX_TIMEOUT: Duration = Duration::from_secs(60); + const CONCURRENT_CALLS: usize = 10; init_logger(); - let (handle, addr) = { - let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); - - let mut module = RpcModule::new(()); - - module - .register_async_method("infinite_call", |_, _| async move { - futures_util::future::pending::<()>().await; - "ok" - }) - .unwrap(); - let addr = server.local_addr().unwrap(); - - (server.start(module).unwrap(), addr) - }; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await; let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); for _ in 0..10 { @@ -841,11 +832,72 @@ async fn drop_client_with_pending_calls_works() { client.send(req).with_default_timeout().await.unwrap().unwrap(); } + // Assert that the server has received the calls. + for _ in 0..CONCURRENT_CALLS { + assert!(rx.recv().await.is_some()); + } + client.close().await.unwrap(); assert!(client.receive().await.is_err()); // Stop the server and ensure that the server doesn't wait for futures to complete // when the connection has already been closed. handle.stop().unwrap(); - assert!(handle.stopped().with_default_timeout().await.is_ok()); + assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok()); +} + +#[tokio::test] +async fn drop_client_with_pending_calls_works() { + const MAX_TIMEOUT: Duration = Duration::from_secs(60); + const CONCURRENT_CALLS: usize = 10; + init_logger(); + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await; + + { + let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); + + for _ in 0..CONCURRENT_CALLS { + let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#; + client.send(req).with_default_timeout().await.unwrap().unwrap(); + } + // Assert that the server has received the calls. + for _ in 0..CONCURRENT_CALLS { + assert!(rx.recv().await.is_some()); + } + } + + // Stop the server and ensure that the server doesn't wait for futures to complete + // when the connection has already been closed. + handle.stop().unwrap(); + assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok()); +} + +async fn server_with_infinite_call( + timeout: Duration, + tx: tokio::sync::mpsc::UnboundedSender<()>, +) -> (crate::ServerHandle, std::net::SocketAddr) { + let server = ServerBuilder::default() + // Make sure that the ping_interval doesn't force the connection to be closed + .ping_interval(timeout) + .build("127.0.0.1:0") + .with_default_timeout() + .await + .unwrap() + .unwrap(); + + let mut module = RpcModule::new(tx); + + module + .register_async_method("infinite_call", |_, mut ctx| async move { + let tx = std::sync::Arc::make_mut(&mut ctx); + tx.send(()).unwrap(); + futures_util::future::pending::<()>().await; + "ok" + }) + .unwrap(); + let addr = server.local_addr().unwrap(); + + (server.start(module).unwrap(), addr) } diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 6fcd07a6fc..63ef8faea2 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -226,11 +226,7 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData response } -pub(crate) async fn background_task( - sender: Sender, - mut receiver: Receiver, - svc: ServiceData, -) -> Result<(), Error> { +pub(crate) async fn background_task(sender: Sender, mut receiver: Receiver, svc: ServiceData) { let ServiceData { methods, max_request_body_size, @@ -250,17 +246,17 @@ pub(crate) async fn background_task( } = svc; let (tx, rx) = mpsc::channel::(message_buffer_capacity as usize); - let (mut conn_tx, conn_rx) = oneshot::channel(); + let (conn_tx, conn_rx) = oneshot::channel(); let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length); let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); let pending_calls = FuturesUnordered::new(); // Spawn another task that sends out the responses on the Websocket. - tokio::spawn(send_task(rx, sender, ping_interval, conn_rx)); + let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx)); // Buffer for incoming data. let mut data = Vec::with_capacity(100); - let stopped = stop_handle.shutdown(); + let stopped = stop_handle.clone().shutdown(); tokio::pin!(stopped); @@ -272,11 +268,11 @@ pub(crate) async fn background_task( stopped = stop; permit } - None => break Ok(()), + None => break Ok(Shutdown::ConnectionClosed), }; match try_recv(&mut receiver, &mut data, stopped).await { - Receive::Shutdown => break Ok(()), + Receive::Shutdown => break Ok(Shutdown::Stopped), Receive::Ok(stop) => { stopped = stop; } @@ -286,7 +282,7 @@ pub(crate) async fn background_task( match err { SokettoError::Closed => { tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id); - break Ok(()); + break Ok(Shutdown::ConnectionClosed); } SokettoError::MessageTooLarge { current, maximum } => { tracing::debug!( @@ -300,7 +296,7 @@ pub(crate) async fn background_task( } err => { tracing::debug!("WS transport error: {}; terminate connection: {}", err, conn_id); - break Err(err.into()); + break Err(err); } }; } @@ -326,22 +322,11 @@ pub(crate) async fn background_task( // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee // proper drop behaviour. - // - // This is not strictly not needed because `tokio::spawn` will drive these the completion - // but it's preferred that the `stop_handle.stopped()` should not return until all methods has been - // executed and the connection has been closed. - tokio::select! { - // All pending calls executed. - _ = pending_calls.for_each(|_| async {}) => { - _ = conn_tx.send(()); - } - // The connection was closed, no point of waiting for the pending calls. - _ = conn_tx.closed() => {} - } + graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await; logger.on_disconnect(remote_addr, TransportProtocol::WebSocket); drop(conn); - result + drop(stop_handle); } /// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. @@ -352,7 +337,11 @@ async fn send_task( stop: oneshot::Receiver<()>, ) { // Interval to send out continuously `pings`. - let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval)); + let mut ping_interval = tokio::time::interval(ping_interval); + // This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed. + ping_interval.tick().await; + + let ping_interval = IntervalStream::new(ping_interval); let rx = ReceiverStream::new(rx); tokio::pin!(ping_interval, rx, stop); @@ -384,15 +373,18 @@ async fn send_task( } // Handle timer intervals. - Either::Right((Either::Left((_, stop)), next_rx)) => { + Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => { if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!("WS transport error: send ping failed: {}", err); break; } + rx_item = next_rx; futs = future::select(ping_interval.next(), stop); } + Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"), + // Server is stopped. Either::Right((Either::Right(_), _)) => { break; @@ -558,3 +550,55 @@ async fn execute_unchecked_call(params: ExecuteCallParams) { } }; } + +#[derive(Debug, Copy, Clone)] +pub(crate) enum Shutdown { + Stopped, + ConnectionClosed, +} + +/// Enforce a graceful shutdown. +/// +/// This will return once the connection has been terminated or all pending calls have been executed. +async fn graceful_shutdown( + result: Result, + pending_calls: FuturesUnordered, + receiver: Receiver, + data: Vec, + mut conn_tx: oneshot::Sender<()>, + send_task_handle: tokio::task::JoinHandle<()>, +) { + match result { + Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (), + Ok(Shutdown::Stopped) | Err(_) => { + // Soketto doesn't have a way to signal when the connection is closed + // thus just throw away the data and terminate the stream once the connection has + // been terminated. + // + // The receiver is not cancel-safe such that it's used in a stream to enforce that. + let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async { + if let Err(SokettoError::Closed) = receiver.receive(&mut data).await { + None + } else { + Some(((), (receiver, data))) + } + }); + + let graceful_shutdown = pending_calls.for_each(|_| async {}); + let disconnect = disconnect_stream.for_each(|_| async {}); + + // All pending calls has been finished or the connection closed. + // Fine to terminate + tokio::select! { + _ = graceful_shutdown => {} + _ = disconnect => {} + _ = conn_tx.closed() => {} + } + } + }; + + // Send a message to close down the "send task". + _ = conn_tx.send(()); + // Ensure that send task has been closed. + _ = send_task_handle.await; +}