diff --git a/Cargo.toml b/Cargo.toml index 5e82985..23081fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ repository = "https://github.com/webrtc-rs/data" [dependencies] util = { package = "webrtc-util", version = "0.5.3", default-features = false, features = ["conn", "marshal"] } -sctp = { package = "webrtc-sctp", version = "0.4.3" } +sctp = { package = "webrtc-sctp", version = "0.5.0" } tokio = { version = "1.15.0", features = ["full"] } bytes = "1.1.0" derive_builder = "0.10.2" diff --git a/src/data_channel/data_channel_test.rs b/src/data_channel/data_channel_test.rs index 21755b2..89b44c4 100644 --- a/src/data_channel/data_channel_test.rs +++ b/src/data_channel/data_channel_test.rs @@ -5,6 +5,8 @@ use super::*; use util::conn::conn_bridge::*; use util::conn::*; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; use tokio::sync::{broadcast, mpsc}; use tokio::time::Duration; @@ -406,8 +408,6 @@ async fn test_data_channel_channel_type_partial_reliable_timed_unordered() -> Re pr_ordered_unordered_test(ChannelType::PartialReliableTimedUnordered, false).await } -//TODO: remove this conditional test -#[cfg(not(target_os = "macos"))] #[tokio::test] async fn test_data_channel_buffered_amount() -> Result<()> { let sbuf = vec![0u8; 1000]; @@ -482,6 +482,9 @@ async fn test_data_channel_buffered_amount() -> Result<()> { let dc1_cloned = Arc::clone(&dc1); tokio::spawn(async move { while let Ok(n) = dc1_cloned.read(&mut rbuf[..]).await { + if n == 0 { + break; + } assert_eq!(n, rbuf.len(), "received length should match"); } }); @@ -509,8 +512,6 @@ async fn test_data_channel_buffered_amount() -> Result<()> { Ok(()) } -//TODO: remove this conditional test -#[cfg(not(target_os = "macos"))] #[tokio::test] async fn test_stats() -> Result<()> { let sbuf = vec![0u8; 1000]; @@ -603,3 +604,57 @@ async fn test_stats() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_poll_data_channel() -> Result<()> { + let mut sbuf = vec![0u8; 1000]; + let mut rbuf = vec![0u8; 1500]; + + let (br, ca, cb) = Bridge::new(0, None, None); + + let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; + + let cfg = Config { + channel_type: ChannelType::Reliable, + reliability_parameter: 123, + label: "data".to_string(), + ..Default::default() + }; + + let dc0 = Arc::new(DataChannel::dial(&a0, 100, cfg.clone()).await?); + bridge_process_at_least_one(&br).await; + + let dc1 = Arc::new(DataChannel::accept(&a1, Config::default()).await?); + bridge_process_at_least_one(&br).await; + + let mut poll_dc0 = PollDataChannel::new(dc0.clone()); + let mut poll_dc1 = PollDataChannel::new(dc1.clone()); + + sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); + let n = poll_dc0 + .write(&Bytes::from(sbuf.clone())) + .await + .map_err(|e| Error::new(e.to_string()))?; + assert_eq!(sbuf.len(), n, "data length should match"); + + bridge_process_at_least_one(&br).await; + + let n = poll_dc1 + .read(&mut rbuf[..]) + .await + .map_err(|e| Error::new(e.to_string()))?; + assert_eq!(sbuf.len(), n, "data length should match"); + assert_eq!( + 1, + u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), + "data should match" + ); + + dc0.close().await?; + dc1.close().await?; + bridge_process_at_least_one(&br).await; + + close_association_pair(&br, a0, a1).await; + + Ok(()) +} diff --git a/src/data_channel/mod.rs b/src/data_channel/mod.rs index c762b78..b0bb357 100644 --- a/src/data_channel/mod.rs +++ b/src/data_channel/mod.rs @@ -9,12 +9,18 @@ use crate::{ use sctp::{ association::Association, chunk::chunk_payload_data::PayloadProtocolIdentifier, stream::*, }; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use util::marshal::*; use bytes::{Buf, Bytes}; use derive_builder::Builder; +use std::fmt; +use std::io; +use std::net::Shutdown; +use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll}; const RECEIVE_MTU: usize = 8192; @@ -146,7 +152,7 @@ impl DataChannel { Err(err) => { // When the peer sees that an incoming stream was // reset, it also resets its corresponding outgoing stream. - self.stream.close().await?; + self.stream.shutdown(Shutdown::Both).await?; return Err(err.into()); } @@ -289,7 +295,7 @@ impl DataChannel { // a corresponding notification to the application layer that the reset // has been performed. Streams are available for reuse after a reset // has been performed. - Ok(self.stream.close().await?) + Ok(self.stream.shutdown(Shutdown::Both).await?) } /// BufferedAmount returns the number of bytes of data currently queued to be @@ -333,3 +339,119 @@ impl DataChannel { ); } } + +/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and +/// [`AsyncWrite`]. +/// +/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an +/// additional overhead. +pub struct PollDataChannel { + data_channel: Arc, + poll_stream: PollStream, +} + +impl PollDataChannel { + /// Constructs a new `PollDataChannel`. + pub fn new(data_channel: Arc) -> Self { + let stream = data_channel.stream.clone(); + Self { + data_channel, + poll_stream: PollStream::new(stream), + } + } + + /// Get back the inner data_channel. + pub fn into_inner(self) -> Arc { + self.data_channel + } + + /// Obtain a clone of the inner data_channel. + pub fn clone_inner(&self) -> Arc { + self.data_channel.clone() + } + + /// MessagesSent returns the number of messages sent + pub fn messages_sent(&self) -> usize { + self.data_channel.messages_sent.load(Ordering::SeqCst) + } + + /// MessagesReceived returns the number of messages received + pub fn messages_received(&self) -> usize { + self.data_channel.messages_received.load(Ordering::SeqCst) + } + + /// BytesSent returns the number of bytes sent + pub fn bytes_sent(&self) -> usize { + self.data_channel.bytes_sent.load(Ordering::SeqCst) + } + + /// BytesReceived returns the number of bytes received + pub fn bytes_received(&self) -> usize { + self.data_channel.bytes_received.load(Ordering::SeqCst) + } + + /// StreamIdentifier returns the Stream identifier associated to the stream. + pub fn stream_identifier(&self) -> u16 { + self.poll_stream.stream_identifier() + } + + /// BufferedAmount returns the number of bytes of data currently queued to be + /// sent over this stream. + pub fn buffered_amount(&self) -> usize { + self.poll_stream.buffered_amount() + } + + /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing + /// data that is considered "low." Defaults to 0. + pub fn buffered_amount_low_threshold(&self) -> usize { + self.poll_stream.buffered_amount_low_threshold() + } +} + +impl AsyncRead for PollDataChannel { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.poll_stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for PollDataChannel { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.poll_stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.poll_stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.poll_stream).poll_shutdown(cx) + } +} + +impl Clone for PollDataChannel { + fn clone(&self) -> PollDataChannel { + PollDataChannel::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollDataChannel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollDataChannel") + .field("data_channel", &self.data_channel) + .finish() + } +} + +impl AsRef for PollDataChannel { + fn as_ref(&self) -> &DataChannel { + &*self.data_channel + } +}