Skip to content
This repository was archived by the owner on Aug 23, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repository = "https://github.com/webrtc-rs/data"

[dependencies]
util = { package = "webrtc-util", version = "0.5.3", default-features = false, features = ["conn", "marshal"] }
sctp = { package = "webrtc-sctp", version = "0.4.3" }
sctp = { package = "webrtc-sctp", version = "0.5.0" }
tokio = { version = "1.15.0", features = ["full"] }
bytes = "1.1.0"
derive_builder = "0.10.2"
Expand Down
63 changes: 59 additions & 4 deletions src/data_channel/data_channel_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use super::*;
use util::conn::conn_bridge::*;
use util::conn::*;

use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::{broadcast, mpsc};
use tokio::time::Duration;

Expand Down Expand Up @@ -406,8 +408,6 @@ async fn test_data_channel_channel_type_partial_reliable_timed_unordered() -> Re
pr_ordered_unordered_test(ChannelType::PartialReliableTimedUnordered, false).await
}

//TODO: remove this conditional test
#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn test_data_channel_buffered_amount() -> Result<()> {
let sbuf = vec![0u8; 1000];
Expand Down Expand Up @@ -482,6 +482,9 @@ async fn test_data_channel_buffered_amount() -> Result<()> {
let dc1_cloned = Arc::clone(&dc1);
tokio::spawn(async move {
while let Ok(n) = dc1_cloned.read(&mut rbuf[..]).await {
if n == 0 {
break;
}
assert_eq!(n, rbuf.len(), "received length should match");
}
});
Expand Down Expand Up @@ -509,8 +512,6 @@ async fn test_data_channel_buffered_amount() -> Result<()> {
Ok(())
}

//TODO: remove this conditional test
#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn test_stats() -> Result<()> {
let sbuf = vec![0u8; 1000];
Expand Down Expand Up @@ -603,3 +604,57 @@ async fn test_stats() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_poll_data_channel() -> Result<()> {
let mut sbuf = vec![0u8; 1000];
let mut rbuf = vec![0u8; 1500];

let (br, ca, cb) = Bridge::new(0, None, None);

let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?;

let cfg = Config {
channel_type: ChannelType::Reliable,
reliability_parameter: 123,
label: "data".to_string(),
..Default::default()
};

let dc0 = Arc::new(DataChannel::dial(&a0, 100, cfg.clone()).await?);
bridge_process_at_least_one(&br).await;

let dc1 = Arc::new(DataChannel::accept(&a1, Config::default()).await?);
bridge_process_at_least_one(&br).await;

let mut poll_dc0 = PollDataChannel::new(dc0.clone());
let mut poll_dc1 = PollDataChannel::new(dc1.clone());

sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps using a more "tricky" value than 1 here would be more robust. e.g. a random value with bits set in all octets such as 961666284

let n = poll_dc0
.write(&Bytes::from(sbuf.clone()))
.await
.map_err(|e| Error::new(e.to_string()))?;
assert_eq!(sbuf.len(), n, "data length should match");

bridge_process_at_least_one(&br).await;

let n = poll_dc1
.read(&mut rbuf[..])
.await
.map_err(|e| Error::new(e.to_string()))?;
assert_eq!(sbuf.len(), n, "data length should match");
assert_eq!(
1,
u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]),
"data should match"
);

dc0.close().await?;
dc1.close().await?;
bridge_process_at_least_one(&br).await;

close_association_pair(&br, a0, a1).await;

Ok(())
}
126 changes: 124 additions & 2 deletions src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ use crate::{
use sctp::{
association::Association, chunk::chunk_payload_data::PayloadProtocolIdentifier, stream::*,
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use util::marshal::*;

use bytes::{Buf, Bytes};
use derive_builder::Builder;
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

const RECEIVE_MTU: usize = 8192;

Expand Down Expand Up @@ -146,7 +152,7 @@ impl DataChannel {
Err(err) => {
// When the peer sees that an incoming stream was
// reset, it also resets its corresponding outgoing stream.
self.stream.close().await?;
self.stream.shutdown(Shutdown::Both).await?;

return Err(err.into());
}
Expand Down Expand Up @@ -289,7 +295,7 @@ impl DataChannel {
// a corresponding notification to the application layer that the reset
// has been performed. Streams are available for reuse after a reset
// has been performed.
Ok(self.stream.close().await?)
Ok(self.stream.shutdown(Shutdown::Both).await?)
}

/// BufferedAmount returns the number of bytes of data currently queued to be
Expand Down Expand Up @@ -333,3 +339,119 @@ impl DataChannel {
);
}
}

/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
/// [`AsyncWrite`].
///
/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
/// additional overhead.
pub struct PollDataChannel {
data_channel: Arc<DataChannel>,
poll_stream: PollStream,
}

impl PollDataChannel {
/// Constructs a new `PollDataChannel`.
pub fn new(data_channel: Arc<DataChannel>) -> Self {
let stream = data_channel.stream.clone();
Self {
data_channel,
poll_stream: PollStream::new(stream),
}
}

/// Get back the inner data_channel.
pub fn into_inner(self) -> Arc<DataChannel> {
self.data_channel
}

/// Obtain a clone of the inner data_channel.
pub fn clone_inner(&self) -> Arc<DataChannel> {
self.data_channel.clone()
}

/// MessagesSent returns the number of messages sent
pub fn messages_sent(&self) -> usize {
self.data_channel.messages_sent.load(Ordering::SeqCst)
}

/// MessagesReceived returns the number of messages received
pub fn messages_received(&self) -> usize {
self.data_channel.messages_received.load(Ordering::SeqCst)
}

/// BytesSent returns the number of bytes sent
pub fn bytes_sent(&self) -> usize {
self.data_channel.bytes_sent.load(Ordering::SeqCst)
}

/// BytesReceived returns the number of bytes received
pub fn bytes_received(&self) -> usize {
self.data_channel.bytes_received.load(Ordering::SeqCst)
}

/// StreamIdentifier returns the Stream identifier associated to the stream.
pub fn stream_identifier(&self) -> u16 {
self.poll_stream.stream_identifier()
}

/// BufferedAmount returns the number of bytes of data currently queued to be
/// sent over this stream.
pub fn buffered_amount(&self) -> usize {
self.poll_stream.buffered_amount()
}

/// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
/// data that is considered "low." Defaults to 0.
pub fn buffered_amount_low_threshold(&self) -> usize {
self.poll_stream.buffered_amount_low_threshold()
}
}

impl AsyncRead for PollDataChannel {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_read(cx, buf)
}
}

impl AsyncWrite for PollDataChannel {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.poll_stream).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_shutdown(cx)
}
}

impl Clone for PollDataChannel {
fn clone(&self) -> PollDataChannel {
PollDataChannel::new(self.clone_inner())
}
}

impl fmt::Debug for PollDataChannel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollDataChannel")
.field("data_channel", &self.data_channel)
.finish()
}
}

impl AsRef<DataChannel> for PollDataChannel {
fn as_ref(&self) -> &DataChannel {
&*self.data_channel
}
}