Skip to content
This repository was archived by the owner on Aug 23, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -222,3 +223,15 @@ pub enum Error {
#[error("{0}")]
Other(String),
}

impl From<Error> for io::Error {
fn from(error: Error) -> Self {
match error {
e @ Error::ErrEof => io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()),
e @ Error::ErrStreamClosed => {
io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
}
e => io::Error::new(io::ErrorKind::Other, e.to_string()),
}
}
}
289 changes: 289 additions & 0 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ use crate::queue::pending_queue::PendingQueue;
use bytes::Bytes;
use std::fmt;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::{mpsc, Mutex, Notify};

#[derive(Debug, Copy, Clone, PartialEq)]
Expand Down Expand Up @@ -416,6 +419,8 @@ impl Stream {
}
}

/// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
/// be read (once chunk is complete).
pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
// No lock is required as it reads the size with atomic load function.
let reassembly_queue = self.reassembly_queue.lock().await;
Expand Down Expand Up @@ -471,3 +476,287 @@ impl Stream {
Ok(())
}
}

/// Default capacity of the temporary read buffer used by [`PollStream`].
const DEFAULT_READ_BUF_SIZE: usize = 4096;

/// State of the read `Future` in [`PollStream`].
enum ReadFut<'a> {
/// Nothing in progress.
Idle,
/// Reading data from the underlying stream.
Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + 'a>>),
/// Finished reading, but there's unread data in the temporary buffer.
RemainingData(Vec<u8>),
}

impl<'a> ReadFut<'a> {
/// Gets a mutable reference to the future stored inside `Reading(future)`.
///
/// # Panics
///
/// Panics if `ReadFut` variant is not `Reading`.
fn get_reading_mut(
&mut self,
) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + 'a>> {
match self {
ReadFut::Reading(ref mut fut) => fut,
_ => panic!("expected ReadFut to be Reading"),
}
}
}

/// A wrapper around around [`Stream`], which implements [`AsyncRead`] and
/// [`AsyncWrite`].
///
/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
/// additional overhead.
pub struct PollStream<'a> {
stream: Arc<Stream>,

read_fut: ReadFut<'a>,
write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>>>,
shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>>,

read_buf_cap: usize,
}

impl PollStream<'_> {
/// Constructs a new `PollStream`.
///
/// # Examples
///
/// ```
/// use webrtc_sctp::stream::{Stream, PollStream};
/// use std::sync::Arc;
///
/// let stream = Arc::new(Stream::default());
/// let poll_stream = PollStream::new(stream);
/// ```
pub fn new(stream: Arc<Stream>) -> Self {
Self {
stream,
read_fut: ReadFut::Idle,
write_fut: None,
shutdown_fut: None,
read_buf_cap: DEFAULT_READ_BUF_SIZE,
}
}

/// Get back the inner stream.
pub fn into_inner(self) -> Arc<Stream> {
self.stream
}

/// Obtain a clone of the inner stream.
pub fn clone_inner(&self) -> Arc<Stream> {
self.stream.clone()
}

/// stream_identifier returns the Stream identifier associated to the stream.
pub fn stream_identifier(&self) -> u16 {
self.stream.stream_identifier
}

/// buffered_amount returns the number of bytes of data currently queued to be sent over this stream.
pub fn buffered_amount(&self) -> usize {
self.stream.buffered_amount.load(Ordering::SeqCst)
}

/// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is
/// considered "low." Defaults to 0.
pub fn buffered_amount_low_threshold(&self) -> usize {
self.stream.buffered_amount_low.load(Ordering::SeqCst)
}

/// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
/// be read (once chunk is complete).
pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
// No lock is required as it reads the size with atomic load function.
let reassembly_queue = self.stream.reassembly_queue.lock().await;
reassembly_queue.get_num_bytes()
}

/// Set the capacity of the temporary read buffer (default: 4096).
pub fn set_read_buf_capacity(&mut self, capacity: usize) {
self.read_buf_cap = capacity
}
}

impl AsyncRead for PollStream<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}

let fut = match self.read_fut {
ReadFut::Idle => {
// read into a temporary buffer because `buf` has an unonymous lifetime, which can
// be shorter than the lifetime of `read_fut`.
let stream = self.stream.clone();
let mut temp_buf = vec![0; self.read_buf_cap];
self.read_fut = ReadFut::Reading(Box::pin(async move {
let res = stream.read(temp_buf.as_mut_slice()).await;
match res {
Ok(n) => {
temp_buf.truncate(n);
Ok(temp_buf)
}
Err(e) => Err(e),
}
}));
self.read_fut.get_reading_mut()
}
ReadFut::Reading(ref mut fut) => fut,
ReadFut::RemainingData(ref mut data) => {
let remaining = buf.remaining();
let len = std::cmp::min(data.len(), remaining);
buf.put_slice(&data[..len]);
if data.len() > remaining {
// ReadFut remains to be RemainingData
data.drain(0..len);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
};

loop {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
// retry immediately upon empty data or incomplete chunks
// since there's no way to setup a waker.
Poll::Ready(Err(Error::ErrTryAgain)) => {}
// EOF has been reached => don't touch buf and just return Ok
Poll::Ready(Err(Error::ErrEof)) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Ok(()));
}
Poll::Ready(Err(e)) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Err(e.into()));
}
Poll::Ready(Ok(mut temp_buf)) => {
let remaining = buf.remaining();
let len = std::cmp::min(temp_buf.len(), remaining);
buf.put_slice(&temp_buf[..len]);
if temp_buf.len() > remaining {
temp_buf.drain(0..len);
self.read_fut = ReadFut::RemainingData(temp_buf);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
}
}
}
}

impl AsyncWrite for PollStream<'_> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let (fut, fut_is_new) = match self.write_fut.as_mut() {
Some(fut) => (fut, false),
None => {
let stream = self.stream.clone();
let bytes = Bytes::copy_from_slice(buf);
(
self.write_fut
.get_or_insert(Box::pin(async move { stream.write(&bytes).await })),
true,
)
}
};

match fut.as_mut().poll(cx) {
Poll::Pending => {
// If it's the first time we're polling the future, `Poll::Pending` can't be
// returned because that would mean the `PollStream` is not ready for writing. And
// this is not true since we've just created a future, which is going to write the
// buf to the underlying stream.
//
// It's okay to return `Poll::Ready` if the data is buffered (this is what the
// buffered writer and `File` do).
if fut_is_new {
Poll::Ready(Ok(buf.len()))
} else {
// If it's the subsequent poll, it's okay to return `Poll::Pending` as it
// indicates that the `PollStream` is not ready for writing. Only one future
// can be in progress at the time.
Poll::Pending
}
}
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(n)) => {
self.write_fut = None;
Poll::Ready(Ok(n))
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.write_fut.as_mut() {
Some(fut) => match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(_)) => {
self.write_fut = None;
Poll::Ready(Ok(()))
}
},
None => Poll::Ready(Ok(())),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let fut = match self.shutdown_fut.as_mut() {
Some(fut) => fut,
None => {
let stream = self.stream.clone();
self.shutdown_fut
.get_or_insert(Box::pin(async move { stream.close().await }))
}
};

match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
}
}
}

impl<'a> Clone for PollStream<'a> {
fn clone(&self) -> PollStream<'a> {
PollStream::new(self.clone_inner())
}
}

impl fmt::Debug for PollStream<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollStream")
.field("stream", &self.stream)
.finish()
}
}

impl AsRef<Stream> for PollStream<'_> {
fn as_ref(&self) -> &Stream {
&*self.stream
}
}
56 changes: 56 additions & 0 deletions src/stream/stream_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;

#[test]
fn test_stream_buffered_amount() -> Result<()> {
Expand Down Expand Up @@ -69,3 +71,57 @@ async fn test_stream_amount_on_buffered_amount_low() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_poll_stream() -> std::result::Result<(), io::Error> {
let s = Arc::new(Stream::new(
"test_poll_stream".to_owned(),
0,
4096,
Arc::new(AtomicU32::new(4096)),
Arc::new(AtomicU8::new(AssociationState::Established as u8)),
None,
Arc::new(PendingQueue::new()),
));
let mut poll_stream = PollStream::new(s.clone());

// getters
assert_eq!(0, poll_stream.stream_identifier());
assert_eq!(0, poll_stream.buffered_amount());
assert_eq!(0, poll_stream.buffered_amount_low_threshold());
assert_eq!(0, poll_stream.get_num_bytes_in_reassembly_queue().await);

// async write
let n = poll_stream.write(&[1, 2, 3]).await?;
assert_eq!(3, n);
poll_stream.flush().await?;
assert_eq!(3, poll_stream.buffered_amount());

// async read
// 1. pretend that we've received a chunk
let sc = s.clone();
sc.handle_data(ChunkPayloadData {
unordered: true,
beginning_fragment: true,
ending_fragment: true,
user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
payload_type: PayloadProtocolIdentifier::Binary,
..Default::default()
})
.await;
// 2. read it
let mut buf = [0; 5];
poll_stream.read(&mut buf).await?;
assert_eq!(buf, [0, 1, 2, 3, 4]);

// shutdown
poll_stream.shutdown().await?;
assert_eq!(true, sc.closed.load(Ordering::Relaxed));
assert!(poll_stream.read(&mut buf).await.is_err());

// misc.
let clone = poll_stream.clone();
assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier());

Ok(())
}