diff --git a/Cargo.toml b/Cargo.toml index 36f0b94..0ab48d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,7 @@ members = [ "examples", "rsocket-test", ] + +[replace] +"rsocket_rust:0.7.1" = { path = "../rsocket-rust/rsocket" } +"rsocket_rust_transport_tcp:0.7.1" = { path = "../rsocket-rust/rsocket-transport-tcp" } diff --git a/rsocket-test/Cargo.toml b/rsocket-test/Cargo.toml index 34c3a50..2f63d38 100644 --- a/rsocket-test/Cargo.toml +++ b/rsocket-test/Cargo.toml @@ -36,3 +36,19 @@ version = "0.7.1" version = "1.0.3" default-features = false features = ["full"] + +[dev-dependencies.tokio-stream] +version = "0.1.7" +features = ["sync"] + +[dev-dependencies.anyhow] +version = "1.0.40" + +[dev-dependencies.async-trait] +version = "0.1.50" + +[dev-dependencies.serial_test] +version = "0.5.1" + +[dev-dependencies.async-stream] +version = "0.3.1" diff --git a/rsocket-test/tests/test_stream_cancellation.rs b/rsocket-test/tests/test_stream_cancellation.rs new file mode 100644 index 0000000..08a8bad --- /dev/null +++ b/rsocket-test/tests/test_stream_cancellation.rs @@ -0,0 +1,334 @@ +#[macro_use] +extern crate log; + +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use futures::StreamExt; +use tokio_stream::wrappers::ReceiverStream; + +use rsocket_rust::prelude::{Flux, Payload, RSocket}; + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures::Future; + use rsocket_rust_transport_websocket::{WebsocketClientTransport, WebsocketServerTransport}; + use serial_test::serial; + use tokio::runtime::Runtime; + use async_stream::stream; + use rsocket_rust::Client; + use rsocket_rust::prelude::*; + use rsocket_rust::utils::EchoRSocket; + use rsocket_rust_transport_tcp::{TcpClientTransport, TcpServerTransport, UnixClientTransport, UnixServerTransport}; + + use crate::TestSocket; + + #[serial] + #[test] + fn request_stream_can_be_cancelled_by_client_uds() { + init_logger(); + with_uds_test_socket_run(request_stream_can_be_cancelled_by_client); + } + + #[serial] + #[test] + fn request_stream_can_be_cancelled_by_client_tcp() { + init_logger(); + with_tcp_test_socket_run(request_stream_can_be_cancelled_by_client); + } + + #[serial] + #[test] + fn request_stream_can_be_cancelled_by_client_ws() { + init_logger(); + with_ws_test_socket_run(request_stream_can_be_cancelled_by_client); + } + + /// + /// Client requests a channel, consumes an item and drops the stream handle. + /// + /// Amount of active streams is verified before and after requesting and after dropping. + /// + /// Before request_stream: 0 subscribers + /// When request_stream is called: 1 subscriber + /// When request_stream handle is dropped: 0 subscribers + async fn request_stream_can_be_cancelled_by_client(client: Client) { + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("0") + ); + + let mut results = client.request_stream(Payload::from("")); + let payload = results.next().await.expect("valid payload").unwrap(); + assert_eq!(payload.metadata_utf8(), Some("subscribers: 1")); + assert_eq!(payload.data_utf8(), Some("0")); + + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("1") + ); + + debug!("when the Flux is dropped"); + drop(results); + // Give the server enough time to receive the CANCEL frame + tokio::time::sleep(Duration::from_millis(250)).await; + + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("0") + ); + } + + #[serial] + #[test] + fn request_channel_can_be_cancelled_by_client_uds() { + init_logger(); + with_uds_test_socket_run(request_channel_can_be_cancelled_by_client); + } + + #[serial] + #[test] + fn request_channel_can_be_cancelled_by_client_tcp() { + init_logger(); + with_tcp_test_socket_run(request_channel_can_be_cancelled_by_client); + } + + #[serial] + #[test] + fn request_channel_can_be_cancelled_by_client_ws() { + init_logger(); + with_ws_test_socket_run(request_channel_can_be_cancelled_by_client); + } + + /// + /// Client requests a stream, consumes an item and drops the stream handle. + /// + /// Amount of active streams is verified before and after requesting and after dropping. + /// + /// Before request_channel: 0 subscribers + /// When request_channel is called: 1 subscriber + /// When request_channel handle is dropped: 0 subscribers + async fn request_channel_can_be_cancelled_by_client(client: Client) { + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("0") + ); + + let mut results = client.request_channel( + stream!{ yield Ok(Payload::from("")) }.boxed() + ); + let payload = results.next().await.expect("valid payload").unwrap(); + assert_eq!(payload.metadata_utf8(), Some("subscribers: 1")); + assert_eq!(payload.data_utf8(), Some("0")); + + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("1") + ); + + debug!("when the Flux is dropped"); + drop(results); + // Give the server enough time to receive the CANCEL frame + tokio::time::sleep(Duration::from_millis(250)).await; + + assert_eq!( + client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(), + Some("0") + ); + } + + fn init_logger() { + let _ = env_logger::builder() + .format_timestamp_millis() + .filter_level(log::LevelFilter::Debug) + // .is_test(true) + .try_init(); + } + + /// Executes the [run_test] scenario using a client which is connected over a UDS transport to + /// a TestSocket + fn with_uds_test_socket_run(run_test: F) + where + F: (FnOnce(Client) -> Fut) + Send + 'static, + Fut: Future + Send + 'static, + { + info!("=====> begin uds"); + let server_runtime = Runtime::new().unwrap(); + + server_runtime.spawn(async move { + RSocketFactory::receive() + .transport(UnixServerTransport::from("/tmp/rsocket-uds.sock".to_owned())) + .acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) })) + .serve() + .await + }); + + std::thread::sleep(Duration::from_millis(500)); + + let client_runtime = Runtime::new().unwrap(); + + client_runtime.block_on(async { + let client = RSocketFactory::connect() + .acceptor(Box::new(|| Box::new(EchoRSocket))) + .transport(UnixClientTransport::from("/tmp/rsocket-uds.sock".to_owned())) + .setup(Payload::from("READY!")) + .mime_type("text/plain", "text/plain") + .start() + .await + .unwrap(); + run_test(client).await; + }); + info!("<===== uds done!"); + } + + /// Executes the [run_test] scenario using a client which is connected over a UDS transport to + /// a TestSocket + fn with_ws_test_socket_run(run_test: F) + where + F: (FnOnce(Client) -> Fut) + Send + 'static, + Fut: Future + Send + 'static, + { + info!("=====> begin ws"); + let server_runtime = Runtime::new().unwrap(); + server_runtime.spawn(async move { + RSocketFactory::receive() + .transport(WebsocketServerTransport::from("127.0.0.1:8080".to_owned())) + .acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) })) + .serve() + .await + }); + + std::thread::sleep(Duration::from_millis(500)); + + let client_runtime = Runtime::new().unwrap(); + + client_runtime.block_on(async { + let client = RSocketFactory::connect() + .acceptor(Box::new(|| Box::new(EchoRSocket))) + .transport(WebsocketClientTransport::from("127.0.0.1:8080")) + .setup(Payload::from("READY!")) + .mime_type("text/plain", "text/plain") + .start() + .await + .unwrap(); + + + run_test(client).await; + }); + info!("<===== ws done!"); + } + + /// Executes the [run_test] scenario using a client which is connected over a TCP transport to + /// a TestSocket + fn with_tcp_test_socket_run(run_test: F) + where + F: (FnOnce(Client) -> Fut) + Send + 'static, + Fut: Future + Send + 'static, + { + info!("=====> begin tcp"); + let server_runtime = Runtime::new().unwrap(); + server_runtime.spawn(async move { + RSocketFactory::receive() + .transport(TcpServerTransport::from("127.0.0.1:7878".to_owned())) + .acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) })) + .serve() + .await + }); + + std::thread::sleep(Duration::from_millis(500)); + + let client_runtime = Runtime::new().unwrap(); + + client_runtime.block_on(async { + let client = RSocketFactory::connect() + .acceptor(Box::new(|| Box::new(EchoRSocket))) + .transport(TcpClientTransport::from("127.0.0.1:7878".to_owned())) + .setup(Payload::from("READY!")) + .mime_type("text/plain", "text/plain") + .start() + .await + .unwrap(); + run_test(client).await; + }); + info!("<===== tpc done!"); + } +} + +/// Stateful socket for tests, can be used to count active subscribers. +struct TestSocket { + subscribers: Arc>, +} + +impl TestSocket { + fn new() -> Self { + TestSocket { + subscribers: Arc::new(Mutex::new(0)), + } + } + + fn inc_subscriber_count(subscribers: &Arc>) { + let mut guard = subscribers.lock().unwrap(); + *guard = *guard + 1; + info!(target: "TestSocket", "subscribers:({})", guard); + } + + fn dec_subscriber_count(subscribers: &Arc>) { + let mut guard = subscribers.lock().unwrap(); + *guard = *guard - 1; + info!(target: "TestSocket", "subscribers:({})", guard); + } +} + +#[async_trait] +impl RSocket for TestSocket { + async fn metadata_push(&self, _req: Payload) -> Result<()> { + unimplemented!(); + } + + async fn fire_and_forget(&self, _req: Payload) -> Result<()> { + unimplemented!(); + } + + async fn request_response(&self, req: Payload) -> Result> { + let subscribers = *self.subscribers.lock().unwrap(); + let response = match req.data_utf8() { + Some("subscribers") => format!("{}", subscribers), + _ => "Request payload did not contain a known key!".to_owned(), + }; + Ok(Some(Payload::builder().set_data_utf8(&response).build())) + } + + fn request_stream(&self, _req: Payload) -> Flux> { + let (tx, rx) = tokio::sync::mpsc::channel(32); + let subscribers = self.subscribers.clone(); + tokio::spawn(async move { + TestSocket::inc_subscriber_count(&subscribers); + + for i in 0 as u32..100 { + if tx.is_closed() { + debug!(target: "TestSocket", "tx is closed, break!"); + break; + } + let payload = Payload::builder() + .set_data_utf8(format!("{}", i).as_str()) + .set_metadata_utf8(format!("subscribers: {}", *subscribers.lock().unwrap()).as_str()) + .build(); + tx.send(Ok(payload)).await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + } + + TestSocket::dec_subscriber_count(&subscribers); + }); + + ReceiverStream::new(rx).boxed() + } + + fn request_channel(&self, _reqs: Flux>) -> Flux> { + self.request_stream(Payload::from("")) + } +} diff --git a/rsocket/Cargo.toml b/rsocket/Cargo.toml index 5544f30..0d00135 100644 --- a/rsocket/Cargo.toml +++ b/rsocket/Cargo.toml @@ -29,6 +29,10 @@ version = "1.0.3" default-features = false features = [ "macros", "rt", "rt-multi-thread", "sync", "time" ] +[dependencies.tokio-stream] +version = "0.1.7" +features = ["sync"] + [features] default = [] frame = [] diff --git a/rsocket/src/transport/socket.rs b/rsocket/src/transport/socket.rs index 2e34e65..4a139e8 100644 --- a/rsocket/src/transport/socket.rs +++ b/rsocket/src/transport/socket.rs @@ -7,6 +7,7 @@ use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; use dashmap::{mapref::entry::Entry, DashMap}; use futures::{Sink, SinkExt, Stream, StreamExt}; +use futures::future::{AbortHandle, Abortable}; use tokio::sync::{mpsc, oneshot, RwLock}; use super::fragmentation::{Joiner, Splitter}; @@ -28,6 +29,8 @@ pub(crate) struct DuplexSocket { canceller: mpsc::Sender, splitter: Option, joiners: Arc>, + /// AbortHandles for streams and channels associated by sid + abort_handles: Arc>, } #[derive(Clone)] @@ -58,6 +61,7 @@ impl DuplexSocket { handlers: Arc::new(DashMap::new()), joiners: Arc::new(DashMap::new()), splitter, + abort_handles: Arc::new(DashMap::new()), }; let cloned_socket = socket.clone(); @@ -291,6 +295,9 @@ impl DuplexSocket { #[inline] async fn on_cancel(&mut self, sid: u32, _flag: u16) { + if let Some((sid,abort_handle)) = self.abort_handles.remove(&sid) { + abort_handle.abort(); + } self.joiners.remove(&sid); if let Some((_, handler)) = self.handlers.remove(&sid) { let e: Result<_> = @@ -338,11 +345,14 @@ impl DuplexSocket { Handler::ResRR(c) => unreachable!(), Handler::ReqRS(sender) => { if flag & Frame::FLAG_NEXT != 0 { - if let Err(e) = sender.send(Ok(input)).await { + if sender.is_closed() { + self.send_cancel_frame(sid); + } else if let Err(e) = sender.send(Ok(input)).await { error!( "response successful payload for REQUEST_STREAM failed: sid={}", sid ); + self.send_cancel_frame(sid); } } if flag & Frame::FLAG_COMPLETE != 0 { @@ -352,8 +362,11 @@ impl DuplexSocket { Handler::ReqRC(sender) => { // TODO: support channel if flag & Frame::FLAG_NEXT != 0 { - if let Err(_) = sender.clone().send(Ok(input)).await { + if sender.is_closed() { + self.send_cancel_frame(sid); + } else if let Err(_) = sender.clone().send(Ok(input)).await { error!("response successful payload for REQUEST_CHANNEL failed: sid={}",sid); + self.send_cancel_frame(sid); } } if flag & Frame::FLAG_COMPLETE != 0 { @@ -366,6 +379,14 @@ impl DuplexSocket { } } + #[inline] + fn send_cancel_frame(&self, sid: u32) { + let cancel_frame = frame::Cancel::builder(sid, Frame::FLAG_COMPLETE).build(); + if let Err(e) = self.tx.send(cancel_frame) { + error!("Sending CANCEL frame failed: sid={}, reason: {}", sid, e); + } + } + pub(crate) async fn bind_responder(&self, responder: Box) { self.responder.set(responder).await; } @@ -454,9 +475,14 @@ impl DuplexSocket { let responder = self.responder.clone(); let mut tx = self.tx.clone(); let splitter = self.splitter.clone(); + let abort_handles = self.abort_handles.clone(); runtime::spawn(async move { - // TODO: support cancel - let mut payloads = responder.request_stream(input); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + abort_handles.insert(sid, abort_handle); + let mut payloads = Abortable::new( + responder.request_stream(input), + abort_registration + ); while let Some(next) = payloads.next().await { match next { Ok(it) => { @@ -471,6 +497,7 @@ impl DuplexSocket { } }; } + abort_handles.remove(&sid); let complete = frame::Payload::builder(sid, Frame::FLAG_COMPLETE).build(); tx.send(complete) .expect("Send stream complete response failed"); @@ -484,13 +511,21 @@ impl DuplexSocket { let (sender, mut receiver) = mpsc::channel::>(32); sender.send(Ok(first)).await.expect("Send failed!"); self.register_handler(sid, Handler::ReqRC(sender)).await; + let abort_handles = self.abort_handles.clone(); runtime::spawn(async move { // respond client channel - let mut outputs = responder.request_channel(Box::pin(stream! { + let outputs = responder.request_channel(Box::pin(stream! { while let Some(it) = receiver.recv().await{ yield it; } })); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + abort_handles.insert(sid, abort_handle); + let mut outputs = Abortable::new( + outputs, + abort_registration + ); + // TODO: support custom RequestN. let request_n = frame::RequestN::builder(sid, 0).build(); @@ -518,6 +553,7 @@ impl DuplexSocket { }; tx.send(sending).expect("Send failed!"); } + abort_handles.remove(&sid); let complete = frame::Payload::builder(sid, Frame::FLAG_COMPLETE).build(); if let Err(e) = tx.send(complete) { error!("complete REQUEST_CHANNEL failed: {}", e);