diff --git a/src/error.rs b/src/error.rs index a07c1d9..b18c4bc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use std::io; use thiserror::Error; pub type Result = std::result::Result; @@ -222,3 +223,15 @@ pub enum Error { #[error("{0}")] Other(String), } + +impl From for io::Error { + fn from(error: Error) -> Self { + match error { + e @ Error::ErrEof => io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()), + e @ Error::ErrStreamClosed => { + io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) + } + e => io::Error::new(io::ErrorKind::Other, e.to_string()), + } + } +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index ad54792..13ed433 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -11,9 +11,12 @@ use crate::queue::pending_queue::PendingQueue; use bytes::Bytes; use std::fmt; use std::future::Future; +use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::{mpsc, Mutex, Notify}; #[derive(Debug, Copy, Clone, PartialEq)] @@ -416,6 +419,8 @@ impl Stream { } } + /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to + /// be read (once chunk is complete). pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { // No lock is required as it reads the size with atomic load function. let reassembly_queue = self.reassembly_queue.lock().await; @@ -471,3 +476,287 @@ impl Stream { Ok(()) } } + +/// Default capacity of the temporary read buffer used by [`PollStream`]. +const DEFAULT_READ_BUF_SIZE: usize = 4096; + +/// State of the read `Future` in [`PollStream`]. +enum ReadFut<'a> { + /// Nothing in progress. + Idle, + /// Reading data from the underlying stream. + Reading(Pin>> + Send + 'a>>), + /// Finished reading, but there's unread data in the temporary buffer. + RemainingData(Vec), +} + +impl<'a> ReadFut<'a> { + /// Gets a mutable reference to the future stored inside `Reading(future)`. + /// + /// # Panics + /// + /// Panics if `ReadFut` variant is not `Reading`. + fn get_reading_mut( + &mut self, + ) -> &mut Pin>> + Send + 'a>> { + match self { + ReadFut::Reading(ref mut fut) => fut, + _ => panic!("expected ReadFut to be Reading"), + } + } +} + +/// A wrapper around around [`Stream`], which implements [`AsyncRead`] and +/// [`AsyncWrite`]. +/// +/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an +/// additional overhead. +pub struct PollStream<'a> { + stream: Arc, + + read_fut: ReadFut<'a>, + write_fut: Option> + Send + 'a>>>, + shutdown_fut: Option> + Send + 'a>>>, + + read_buf_cap: usize, +} + +impl PollStream<'_> { + /// Constructs a new `PollStream`. + /// + /// # Examples + /// + /// ``` + /// use webrtc_sctp::stream::{Stream, PollStream}; + /// use std::sync::Arc; + /// + /// let stream = Arc::new(Stream::default()); + /// let poll_stream = PollStream::new(stream); + /// ``` + pub fn new(stream: Arc) -> Self { + Self { + stream, + read_fut: ReadFut::Idle, + write_fut: None, + shutdown_fut: None, + read_buf_cap: DEFAULT_READ_BUF_SIZE, + } + } + + /// Get back the inner stream. + pub fn into_inner(self) -> Arc { + self.stream + } + + /// Obtain a clone of the inner stream. + pub fn clone_inner(&self) -> Arc { + self.stream.clone() + } + + /// stream_identifier returns the Stream identifier associated to the stream. + pub fn stream_identifier(&self) -> u16 { + self.stream.stream_identifier + } + + /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. + pub fn buffered_amount(&self) -> usize { + self.stream.buffered_amount.load(Ordering::SeqCst) + } + + /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is + /// considered "low." Defaults to 0. + pub fn buffered_amount_low_threshold(&self) -> usize { + self.stream.buffered_amount_low.load(Ordering::SeqCst) + } + + /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to + /// be read (once chunk is complete). + pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { + // No lock is required as it reads the size with atomic load function. + let reassembly_queue = self.stream.reassembly_queue.lock().await; + reassembly_queue.get_num_bytes() + } + + /// Set the capacity of the temporary read buffer (default: 4096). + pub fn set_read_buf_capacity(&mut self, capacity: usize) { + self.read_buf_cap = capacity + } +} + +impl AsyncRead for PollStream<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let fut = match self.read_fut { + ReadFut::Idle => { + // read into a temporary buffer because `buf` has an unonymous lifetime, which can + // be shorter than the lifetime of `read_fut`. + let stream = self.stream.clone(); + let mut temp_buf = vec![0; self.read_buf_cap]; + self.read_fut = ReadFut::Reading(Box::pin(async move { + let res = stream.read(temp_buf.as_mut_slice()).await; + match res { + Ok(n) => { + temp_buf.truncate(n); + Ok(temp_buf) + } + Err(e) => Err(e), + } + })); + self.read_fut.get_reading_mut() + } + ReadFut::Reading(ref mut fut) => fut, + ReadFut::RemainingData(ref mut data) => { + let remaining = buf.remaining(); + let len = std::cmp::min(data.len(), remaining); + buf.put_slice(&data[..len]); + if data.len() > remaining { + // ReadFut remains to be RemainingData + data.drain(0..len); + } else { + self.read_fut = ReadFut::Idle; + } + return Poll::Ready(Ok(())); + } + }; + + loop { + match fut.as_mut().poll(cx) { + Poll::Pending => return Poll::Pending, + // retry immediately upon empty data or incomplete chunks + // since there's no way to setup a waker. + Poll::Ready(Err(Error::ErrTryAgain)) => {} + // EOF has been reached => don't touch buf and just return Ok + Poll::Ready(Err(Error::ErrEof)) => { + self.read_fut = ReadFut::Idle; + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(e)) => { + self.read_fut = ReadFut::Idle; + return Poll::Ready(Err(e.into())); + } + Poll::Ready(Ok(mut temp_buf)) => { + let remaining = buf.remaining(); + let len = std::cmp::min(temp_buf.len(), remaining); + buf.put_slice(&temp_buf[..len]); + if temp_buf.len() > remaining { + temp_buf.drain(0..len); + self.read_fut = ReadFut::RemainingData(temp_buf); + } else { + self.read_fut = ReadFut::Idle; + } + return Poll::Ready(Ok(())); + } + } + } + } +} + +impl AsyncWrite for PollStream<'_> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let (fut, fut_is_new) = match self.write_fut.as_mut() { + Some(fut) => (fut, false), + None => { + let stream = self.stream.clone(); + let bytes = Bytes::copy_from_slice(buf); + ( + self.write_fut + .get_or_insert(Box::pin(async move { stream.write(&bytes).await })), + true, + ) + } + }; + + match fut.as_mut().poll(cx) { + Poll::Pending => { + // If it's the first time we're polling the future, `Poll::Pending` can't be + // returned because that would mean the `PollStream` is not ready for writing. And + // this is not true since we've just created a future, which is going to write the + // buf to the underlying stream. + // + // It's okay to return `Poll::Ready` if the data is buffered (this is what the + // buffered writer and `File` do). + if fut_is_new { + Poll::Ready(Ok(buf.len())) + } else { + // If it's the subsequent poll, it's okay to return `Poll::Pending` as it + // indicates that the `PollStream` is not ready for writing. Only one future + // can be in progress at the time. + Poll::Pending + } + } + Poll::Ready(Err(e)) => { + self.write_fut = None; + Poll::Ready(Err(e.into())) + } + Poll::Ready(Ok(n)) => { + self.write_fut = None; + Poll::Ready(Ok(n)) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.write_fut.as_mut() { + Some(fut) => match fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + self.write_fut = None; + Poll::Ready(Err(e.into())) + } + Poll::Ready(Ok(_)) => { + self.write_fut = None; + Poll::Ready(Ok(())) + } + }, + None => Poll::Ready(Ok(())), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let fut = match self.shutdown_fut.as_mut() { + Some(fut) => fut, + None => { + let stream = self.stream.clone(); + self.shutdown_fut + .get_or_insert(Box::pin(async move { stream.close().await })) + } + }; + + match fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + } + } +} + +impl<'a> Clone for PollStream<'a> { + fn clone(&self) -> PollStream<'a> { + PollStream::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollStream<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollStream") + .field("stream", &self.stream) + .finish() + } +} + +impl AsRef for PollStream<'_> { + fn as_ref(&self) -> &Stream { + &*self.stream + } +} diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index c80e0da..cab45af 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -1,6 +1,8 @@ use super::*; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; #[test] fn test_stream_buffered_amount() -> Result<()> { @@ -69,3 +71,57 @@ async fn test_stream_amount_on_buffered_amount_low() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_poll_stream() -> std::result::Result<(), io::Error> { + let s = Arc::new(Stream::new( + "test_poll_stream".to_owned(), + 0, + 4096, + Arc::new(AtomicU32::new(4096)), + Arc::new(AtomicU8::new(AssociationState::Established as u8)), + None, + Arc::new(PendingQueue::new()), + )); + let mut poll_stream = PollStream::new(s.clone()); + + // getters + assert_eq!(0, poll_stream.stream_identifier()); + assert_eq!(0, poll_stream.buffered_amount()); + assert_eq!(0, poll_stream.buffered_amount_low_threshold()); + assert_eq!(0, poll_stream.get_num_bytes_in_reassembly_queue().await); + + // async write + let n = poll_stream.write(&[1, 2, 3]).await?; + assert_eq!(3, n); + poll_stream.flush().await?; + assert_eq!(3, poll_stream.buffered_amount()); + + // async read + // 1. pretend that we've received a chunk + let sc = s.clone(); + sc.handle_data(ChunkPayloadData { + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; + // 2. read it + let mut buf = [0; 5]; + poll_stream.read(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + // shutdown + poll_stream.shutdown().await?; + assert_eq!(true, sc.closed.load(Ordering::Relaxed)); + assert!(poll_stream.read(&mut buf).await.is_err()); + + // misc. + let clone = poll_stream.clone(); + assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier()); + + Ok(()) +}