From b20e9a90b1b3527a5663b2d4bdb23713845388ef Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Tue, 12 Apr 2022 17:57:09 +0400 Subject: [PATCH 01/17] implement tokio's AsyncRead and AsyncWrite for Stream Closes https://github.com/webrtc-rs/webrtc/issues/110 --- src/error.rs | 13 ++++++++++++ src/stream/mod.rs | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/error.rs b/src/error.rs index a07c1d9..b18c4bc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use std::io; use thiserror::Error; pub type Result = std::result::Result; @@ -222,3 +223,15 @@ pub enum Error { #[error("{0}")] Other(String), } + +impl From for io::Error { + fn from(error: Error) -> Self { + match error { + e @ Error::ErrEof => io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()), + e @ Error::ErrStreamClosed => { + io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) + } + e => io::Error::new(io::ErrorKind::Other, e.to_string()), + } + } +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index ad54792..5701fb6 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -14,6 +14,9 @@ use std::future::Future; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll}; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::{mpsc, Mutex, Notify}; #[derive(Debug, Copy, Clone, PartialEq)] @@ -471,3 +474,50 @@ impl Stream { Ok(()) } } + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + match Pin::new(&mut Box::pin(self.read(buf.initialized_mut()))).poll(cx) { + Poll::Pending => return Poll::Pending, + // retry immediately upon empty data or incomplete chunks + // since there's no way to setup a waker. + Poll::Ready(Err(Error::ErrTryAgain)) => {}, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())), + } + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match Pin::new(&mut Box::pin(self.write(&Bytes::copy_from_slice(buf)))).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // sctp flush is a no-op + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut Box::pin(self.close())).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + } + } +} From 51fb1cee79d733d102f66599312467cd78422092 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 13 Apr 2022 15:45:12 +0400 Subject: [PATCH 02/17] add PollStream wrapper calling poll on a stream directly is incorrect because the object is destroyed every time the function exits --- src/stream/mod.rs | 59 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 5701fb6..ecd0441 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -475,14 +475,39 @@ impl Stream { } } -impl AsyncRead for Stream { +/// PollStream is a wrapper around [`Stream`], which implements [`AsyncRead`] and [`AsyncWrite`]. +struct PollStream { + stream: Arc, + read_op: Option>>>>, + write_op: Option>>>>, + shutdown_op: Option>>>>, +} + +impl PollStream { + /// Creates a new PollStream. + pub fn new(stream: Arc) -> Self { + Self { + stream, + read_op: None, + write_op: None, + shutdown_op: None, + } + } +} + +impl AsyncRead for PollStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + if self.read_op.is_none() { + self.read_op = Some(Box::pin(self.stream.clone().read(buf.initialized_mut()))); + } + + let read_op = Pin::new(&mut self.read_op.unwrap()); loop { - match Pin::new(&mut Box::pin(self.read(buf.initialized_mut()))).poll(cx) { + match read_op.poll(cx) { Poll::Pending => return Poll::Pending, // retry immediately upon empty data or incomplete chunks // since there's no way to setup a waker. @@ -494,27 +519,43 @@ impl AsyncRead for Stream { } } -impl AsyncWrite for Stream { +impl AsyncWrite for PollStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match Pin::new(&mut Box::pin(self.write(&Bytes::copy_from_slice(buf)))).poll(cx) { + if self.write_op.is_none() { + self.write_op = Some(Box::pin(self.stream.clone().write(&Bytes::copy_from_slice(buf)))); + } + + let write_op = Pin::new(&mut self.write_op.unwrap()); + match write_op.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), } } - #[inline] - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - // sctp flush is a no-op - Poll::Ready(Ok(())) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.write_op { + Some(op) => + match Pin::new(&mut op).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + }, + None => Poll::Ready(Ok(())), + } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut Box::pin(self.close())).poll(cx) { + if self.shutdown_op.is_none() { + self.shutdown_op = Some(Box::pin(self.stream.clone().close())); + } + + let shutdown_op = Pin::new(&mut self.shutdown_op.unwrap()); + match shutdown_op.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), From 58c8f8b4e195cd851df78e1e73b40de806ea8541 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 13 Apr 2022 17:22:16 +0400 Subject: [PATCH 03/17] add a temporary buffer to read into --- src/stream/mod.rs | 48 +++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index ecd0441..058559b 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -11,11 +11,11 @@ use crate::queue::pending_queue::PendingQueue; use bytes::Bytes; use std::fmt; use std::future::Future; +use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::io; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::{mpsc, Mutex, Notify}; @@ -479,6 +479,7 @@ impl Stream { struct PollStream { stream: Arc, read_op: Option>>>>, + read_buf: Vec, write_op: Option>>>>, shutdown_op: Option>>>>, } @@ -486,9 +487,10 @@ struct PollStream { impl PollStream { /// Creates a new PollStream. pub fn new(stream: Arc) -> Self { - Self { + Self { stream, read_op: None, + read_buf: Vec::new(), write_op: None, shutdown_op: None, } @@ -501,8 +503,11 @@ impl AsyncRead for PollStream { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - if self.read_op.is_none() { - self.read_op = Some(Box::pin(self.stream.clone().read(buf.initialized_mut()))); + if self.read_op.is_none() { + self.read_buf = Pin::new(Vec::with_capacity(buf.capacity())); + self.read_op = Some(Box::pin( + self.stream.clone().read(self.read_buf.as_mut_slice()), + )); } let read_op = Pin::new(&mut self.read_op.unwrap()); @@ -511,9 +516,13 @@ impl AsyncRead for PollStream { Poll::Pending => return Poll::Pending, // retry immediately upon empty data or incomplete chunks // since there's no way to setup a waker. - Poll::Ready(Err(Error::ErrTryAgain)) => {}, + Poll::Ready(Err(Error::ErrTryAgain)) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())), + Poll::Ready(Ok(n)) => { + self.read_op = None; + buf.put_slice(self.read_buf.as_slice()); + return Poll::Ready(Ok(())); + } } } } @@ -525,32 +534,39 @@ impl AsyncWrite for PollStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if self.write_op.is_none() { - self.write_op = Some(Box::pin(self.stream.clone().write(&Bytes::copy_from_slice(buf)))); + if self.write_op.is_none() { + self.write_op = Some(Box::pin( + self.stream.clone().write(&Bytes::copy_from_slice(buf)), + )); } let write_op = Pin::new(&mut self.write_op.unwrap()); match write_op.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Ok(n)) => { + self.write_op = None; + Poll::Ready(Ok(n)) + } } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.write_op { - Some(op) => - match Pin::new(&mut op).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - }, + Some(op) => match Pin::new(&mut op).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => { + self.write_op = None; + Poll::Ready(Ok(())) + } + }, None => Poll::Ready(Ok(())), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.shutdown_op.is_none() { + if self.shutdown_op.is_none() { self.shutdown_op = Some(Box::pin(self.stream.clone().close())); } From 3c6f6bc770ecccbbd002d917018e0ae82b68c174 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Fri, 15 Apr 2022 16:43:44 +0400 Subject: [PATCH 04/17] copy some of fn from Stream also, implement Clone, Debug and AsRef and comment out AsyncWrite for now --- src/stream/mod.rs | 231 +++++++++++++++++++++++++++++++++------------- 1 file changed, 166 insertions(+), 65 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 058559b..5499936 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -419,6 +419,8 @@ impl Stream { } } + /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to + /// be read (once chunk is complete). pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { // No lock is required as it reads the size with atomic load function. let reassembly_queue = self.reassembly_queue.lock().await; @@ -475,52 +477,127 @@ impl Stream { } } -/// PollStream is a wrapper around [`Stream`], which implements [`AsyncRead`] and [`AsyncWrite`]. -struct PollStream { +// struct ReadFuture { +// buf: Arc>, +// inner: Pin> + Send>>, +// } + +// impl ReadFuture { +// pub fn new(mut buf: Vec, stream: Arc) -> Self { +// let mut buf = Arc::new(buf); +// Self { buf: buf.clone(), inner: Box::pin(stream.read(buf.as_mut_slice())) } +// } +// } + +// impl Future for ReadFuture { +// type Output = Result>; +// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +// match self.as_mut().inner.as_mut().poll(cx) { +// Poll::Ready(Ok(n)) => Poll::Ready(Ok(self.buf.to_vec())), +// Poll::Ready(Err(e)) => Poll::Ready(Err(e)), +// Poll::Pending => Poll::Pending, +// } +// } +// } + +/// A wrapper around around [`Stream`], which implements [`AsyncRead`] and +/// [`AsyncWrite`]. +/// +/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an +/// additional overhead. +struct PollStream<'a> { stream: Arc, - read_op: Option>>>>, + + read_fut: Option> + Send + 'a>>>, + write_fut: Option> + Send + 'a>>>, + shutdown_fut: Option> + Send + 'a>>>, + read_buf: Vec, - write_op: Option>>>>, - shutdown_op: Option>>>>, } -impl PollStream { +impl PollStream<'_> { /// Creates a new PollStream. pub fn new(stream: Arc) -> Self { Self { stream, - read_op: None, + read_fut: None, + write_fut: None, + shutdown_fut: None, read_buf: Vec::new(), - write_op: None, - shutdown_op: None, } } + + /// Get back the inner stream. + pub fn into_inner(self) -> Arc { + self.stream + } + + /// Obtain a clone of the inner stream. + pub fn clone_inner(&self) -> Arc { + self.stream.clone() + } + + /// stream_identifier returns the Stream identifier associated to the stream. + pub fn stream_identifier(&self) -> u16 { + self.stream.stream_identifier + } + + /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. + pub fn buffered_amount(&self) -> usize { + self.stream.buffered_amount.load(Ordering::SeqCst) + } + + /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is + /// considered "low." Defaults to 0. + pub fn buffered_amount_low_threshold(&self) -> usize { + self.stream.buffered_amount_low.load(Ordering::SeqCst) + } + + /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to + /// be read (once chunk is complete). + pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { + // No lock is required as it reads the size with atomic load function. + let reassembly_queue = self.stream.reassembly_queue.lock().await; + reassembly_queue.get_num_bytes() + } } -impl AsyncRead for PollStream { +impl AsyncRead for PollStream<'_> { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - if self.read_op.is_none() { - self.read_buf = Pin::new(Vec::with_capacity(buf.capacity())); - self.read_op = Some(Box::pin( - self.stream.clone().read(self.read_buf.as_mut_slice()), - )); + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } - let read_op = Pin::new(&mut self.read_op.unwrap()); + let fut = match self.read_fut.as_mut() { + Some(fut) => fut, + None => { + + // read into a temporary buffer because `buf` has an unonymous lifetime, which can be + // shorter than the lifetime of `read_fut`. + let stream = self.stream.clone(); + let temp_buf = Vec::with_capacity(buf.capacity()); + self.read_fut.get_or_insert(Box::pin(( move || { + stream.read(temp_buf.as_mut_slice()) + })() + )) + } + }; + loop { - match read_op.poll(cx) { + match fut.as_mut().poll(cx) { Poll::Pending => return Poll::Pending, // retry immediately upon empty data or incomplete chunks // since there's no way to setup a waker. Poll::Ready(Err(Error::ErrTryAgain)) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(n)) => { - self.read_op = None; - buf.put_slice(self.read_buf.as_slice()); + Poll::Ready(Ok(read_buf)) => { + let len = std::cmp::min(read_buf, buf.remaining()); + buf.put_slice(&self.read_buf[..len]); + self.read_fut = None; return Poll::Ready(Ok(())); } } @@ -528,53 +605,77 @@ impl AsyncRead for PollStream { } } -impl AsyncWrite for PollStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.write_op.is_none() { - self.write_op = Some(Box::pin( - self.stream.clone().write(&Bytes::copy_from_slice(buf)), - )); - } - - let write_op = Pin::new(&mut self.write_op.unwrap()); - match write_op.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(n)) => { - self.write_op = None; - Poll::Ready(Ok(n)) - } - } +// impl AsyncWrite for PollStream<'a> { +// fn poll_write( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &[u8], +// ) -> Poll> { +// if self.write_fut.is_none() { +// let s = Pin::into_inner(self); +// s.write_fut = Some(Box::pin( +// self.stream.write(&Bytes::copy_from_slice(buf)), +// )); +// } + +// let write_fut = self.write_fut.unwrap().as_mut(); +// match write_fut.poll(cx) { +// Poll::Pending => Poll::Pending, +// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), +// Poll::Ready(Ok(n)) => { +// let s = Pin::into_inner(self); +// s.write_fut = None; +// Poll::Ready(Ok(n)) +// } +// } +// } + +// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// match self.write_fut { +// Some(op) => match op.as_mut().poll(cx) { +// Poll::Pending => Poll::Pending, +// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), +// Poll::Ready(Ok(_)) => { +// let s = Pin::into_inner(self); +// s.write_fut = None; +// Poll::Ready(Ok(())) +// } +// }, +// None => Poll::Ready(Ok(())), +// } +// } + +// fn poll_shutdown(self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll> { +// if self.shutdown_fut.is_none() { +// let s = Pin::into_inner(self); +// s.shutdown_fut = Some(Box::pin(self.stream.close())); +// } + +// let shutdown_fut = self.shutdown_fut.unwrap().as_mut(); +// match shutdown_fut.poll(cx) { +// Poll::Pending => Poll::Pending, +// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), +// Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), +// } +// } +// } + +impl<'a> Clone for PollStream<'a> { + fn clone(&self) -> PollStream<'a> { + PollStream::new(self.clone_inner()) } +} - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.write_op { - Some(op) => match Pin::new(&mut op).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(_)) => { - self.write_op = None; - Poll::Ready(Ok(())) - } - }, - None => Poll::Ready(Ok(())), - } +impl fmt::Debug for PollStream<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollStream") + .field("stream", &self.stream) + .finish() } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.shutdown_op.is_none() { - self.shutdown_op = Some(Box::pin(self.stream.clone().close())); - } - - let shutdown_op = Pin::new(&mut self.shutdown_op.unwrap()); - match shutdown_op.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - } +impl AsRef for PollStream<'_> { + fn as_ref(&self) -> &Stream { + &*self.stream } } From 3f856f8ee37956add3b379c88158fd23e30e8f78 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sat, 16 Apr 2022 12:52:03 +0400 Subject: [PATCH 05/17] fix remaining errors --- src/stream/mod.rs | 158 +++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 85 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 5499936..649f04f 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -477,29 +477,6 @@ impl Stream { } } -// struct ReadFuture { -// buf: Arc>, -// inner: Pin> + Send>>, -// } - -// impl ReadFuture { -// pub fn new(mut buf: Vec, stream: Arc) -> Self { -// let mut buf = Arc::new(buf); -// Self { buf: buf.clone(), inner: Box::pin(stream.read(buf.as_mut_slice())) } -// } -// } - -// impl Future for ReadFuture { -// type Output = Result>; -// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -// match self.as_mut().inner.as_mut().poll(cx) { -// Poll::Ready(Ok(n)) => Poll::Ready(Ok(self.buf.to_vec())), -// Poll::Ready(Err(e)) => Poll::Ready(Err(e)), -// Poll::Pending => Poll::Pending, -// } -// } -// } - /// A wrapper around around [`Stream`], which implements [`AsyncRead`] and /// [`AsyncWrite`]. /// @@ -508,7 +485,7 @@ impl Stream { struct PollStream<'a> { stream: Arc, - read_fut: Option> + Send + 'a>>>, + read_fut: Option>> + Send + 'a>>>, write_fut: Option> + Send + 'a>>>, shutdown_fut: Option> + Send + 'a>>>, @@ -575,15 +552,20 @@ impl AsyncRead for PollStream<'_> { let fut = match self.read_fut.as_mut() { Some(fut) => fut, None => { - // read into a temporary buffer because `buf` has an unonymous lifetime, which can be // shorter than the lifetime of `read_fut`. let stream = self.stream.clone(); - let temp_buf = Vec::with_capacity(buf.capacity()); - self.read_fut.get_or_insert(Box::pin(( move || { - stream.read(temp_buf.as_mut_slice()) - })() - )) + let mut temp_buf = Vec::with_capacity(buf.capacity()); + self.read_fut.get_or_insert(Box::pin(async move { + let res = stream.read(temp_buf.as_mut_slice()).await; + match res { + Ok(n) => { + temp_buf.truncate(n); + Ok(temp_buf) + }, + Err(e) => Err(e), + } + })) } }; @@ -595,8 +577,8 @@ impl AsyncRead for PollStream<'_> { Poll::Ready(Err(Error::ErrTryAgain)) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), Poll::Ready(Ok(read_buf)) => { - let len = std::cmp::min(read_buf, buf.remaining()); - buf.put_slice(&self.read_buf[..len]); + let len = std::cmp::min(read_buf.len(), buf.remaining()); + buf.put_slice(&read_buf[..len]); self.read_fut = None; return Poll::Ready(Ok(())); } @@ -605,60 +587,66 @@ impl AsyncRead for PollStream<'_> { } } -// impl AsyncWrite for PollStream<'a> { -// fn poll_write( -// self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &[u8], -// ) -> Poll> { -// if self.write_fut.is_none() { -// let s = Pin::into_inner(self); -// s.write_fut = Some(Box::pin( -// self.stream.write(&Bytes::copy_from_slice(buf)), -// )); -// } - -// let write_fut = self.write_fut.unwrap().as_mut(); -// match write_fut.poll(cx) { -// Poll::Pending => Poll::Pending, -// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), -// Poll::Ready(Ok(n)) => { -// let s = Pin::into_inner(self); -// s.write_fut = None; -// Poll::Ready(Ok(n)) -// } -// } -// } - -// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// match self.write_fut { -// Some(op) => match op.as_mut().poll(cx) { -// Poll::Pending => Poll::Pending, -// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), -// Poll::Ready(Ok(_)) => { -// let s = Pin::into_inner(self); -// s.write_fut = None; -// Poll::Ready(Ok(())) -// } -// }, -// None => Poll::Ready(Ok(())), -// } -// } - -// fn poll_shutdown(self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll> { -// if self.shutdown_fut.is_none() { -// let s = Pin::into_inner(self); -// s.shutdown_fut = Some(Box::pin(self.stream.close())); -// } +impl AsyncWrite for PollStream<'_> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let fut = match self.write_fut.as_mut() { + Some(fut) => fut, + None => { + let stream = self.stream.clone(); + let bytes = Bytes::copy_from_slice(buf); + self.write_fut.get_or_insert(Box::pin(async move { + stream.write(&bytes).await + })) + } + }; + + match fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(n)) => { + self.write_fut = None; + Poll::Ready(Ok(n)) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.write_fut.as_mut() { + Some(fut) => match fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => { + // XXX: is a data race between poll_write and poll_flush possible? + self.write_fut = None; + Poll::Ready(Ok(())) + } + }, + None => Poll::Ready(Ok(())), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let fut = match self.shutdown_fut.as_mut() { + Some(fut) => fut, + None => { + let stream = self.stream.clone(); + self.shutdown_fut.get_or_insert(Box::pin(async move { + stream.close().await + })) + } + }; -// let shutdown_fut = self.shutdown_fut.unwrap().as_mut(); -// match shutdown_fut.poll(cx) { -// Poll::Pending => Poll::Pending, -// Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), -// Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), -// } -// } -// } + match fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + } + } +} impl<'a> Clone for PollStream<'a> { fn clone(&self) -> PollStream<'a> { From 454cd6dfd58bbb2141ee2d8a2578bfb763ea4bc0 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sat, 16 Apr 2022 21:24:25 +0400 Subject: [PATCH 06/17] clear futures upon errors too --- src/stream/mod.rs | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 649f04f..75d0ceb 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -503,12 +503,12 @@ impl PollStream<'_> { read_buf: Vec::new(), } } - + /// Get back the inner stream. pub fn into_inner(self) -> Arc { self.stream } - + /// Obtain a clone of the inner stream. pub fn clone_inner(&self) -> Arc { self.stream.clone() @@ -559,10 +559,10 @@ impl AsyncRead for PollStream<'_> { self.read_fut.get_or_insert(Box::pin(async move { let res = stream.read(temp_buf.as_mut_slice()).await; match res { - Ok(n) => { + Ok(n) => { temp_buf.truncate(n); - Ok(temp_buf) - }, + Ok(temp_buf) + } Err(e) => Err(e), } })) @@ -575,7 +575,10 @@ impl AsyncRead for PollStream<'_> { // retry immediately upon empty data or incomplete chunks // since there's no way to setup a waker. Poll::Ready(Err(Error::ErrTryAgain)) => {} - Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Ready(Err(e)) => { + self.read_fut = None; + return Poll::Ready(Err(e.into())); + } Poll::Ready(Ok(read_buf)) => { let len = std::cmp::min(read_buf.len(), buf.remaining()); buf.put_slice(&read_buf[..len]); @@ -598,15 +601,17 @@ impl AsyncWrite for PollStream<'_> { None => { let stream = self.stream.clone(); let bytes = Bytes::copy_from_slice(buf); - self.write_fut.get_or_insert(Box::pin(async move { - stream.write(&bytes).await - })) + self.write_fut + .get_or_insert(Box::pin(async move { stream.write(&bytes).await })) } }; match fut.as_mut().poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Err(e)) => { + self.write_fut = None; + Poll::Ready(Err(e.into())) + } Poll::Ready(Ok(n)) => { self.write_fut = None; Poll::Ready(Ok(n)) @@ -618,9 +623,11 @@ impl AsyncWrite for PollStream<'_> { match self.write_fut.as_mut() { Some(fut) => match fut.as_mut().poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Err(e)) => { + self.write_fut = None; + Poll::Ready(Err(e.into())) + } Poll::Ready(Ok(_)) => { - // XXX: is a data race between poll_write and poll_flush possible? self.write_fut = None; Poll::Ready(Ok(())) } @@ -634,12 +641,11 @@ impl AsyncWrite for PollStream<'_> { Some(fut) => fut, None => { let stream = self.stream.clone(); - self.shutdown_fut.get_or_insert(Box::pin(async move { - stream.close().await - })) + self.shutdown_fut + .get_or_insert(Box::pin(async move { stream.close().await })) } }; - + match fut.as_mut().poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), From dbb854c6d9d5cd62f37d8e2e8f005da2a67dd68f Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sat, 16 Apr 2022 21:47:49 +0400 Subject: [PATCH 07/17] handle EOF --- src/stream/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 75d0ceb..41af76d 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -575,6 +575,11 @@ impl AsyncRead for PollStream<'_> { // retry immediately upon empty data or incomplete chunks // since there's no way to setup a waker. Poll::Ready(Err(Error::ErrTryAgain)) => {} + // EOF has been reached => don't touch buf and just return Ok + Poll::Ready(Err(Error::ErrEof)) => { + self.read_fut = None; + return Poll::Ready(Ok(())); + } Poll::Ready(Err(e)) => { self.read_fut = None; return Poll::Ready(Err(e.into())); From 1017d64420648987ee8f3e487ec88dfa8942f9d1 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sat, 16 Apr 2022 21:54:46 +0400 Subject: [PATCH 08/17] remove read_buf and make PollStream public --- src/stream/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 41af76d..f67debc 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -482,14 +482,12 @@ impl Stream { /// /// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an /// additional overhead. -struct PollStream<'a> { +pub struct PollStream<'a> { stream: Arc, read_fut: Option>> + Send + 'a>>>, write_fut: Option> + Send + 'a>>>, shutdown_fut: Option> + Send + 'a>>>, - - read_buf: Vec, } impl PollStream<'_> { @@ -500,7 +498,6 @@ impl PollStream<'_> { read_fut: None, write_fut: None, shutdown_fut: None, - read_buf: Vec::new(), } } From 38f064cac0d5ab9af4a7760ea9fa9f7ceafd05ed Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sat, 16 Apr 2022 22:03:52 +0400 Subject: [PATCH 09/17] add doc test for PollStream --- src/stream/mod.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index f67debc..05ddb47 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -491,7 +491,17 @@ pub struct PollStream<'a> { } impl PollStream<'_> { - /// Creates a new PollStream. + /// Constructs a new `PollStream`. + /// + /// # Examples + /// + /// ``` + /// use webrtc_sctp::stream::{Stream, PollStream}; + /// use std::sync::Arc; + /// + /// let stream = Stream::default(); + /// let poll_stream = PollStream::new(Arc::new(stream)); + /// ``` pub fn new(stream: Arc) -> Self { Self { stream, From 88b26dece9586b28dffeb1f60072f76c1dc68a6b Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sun, 17 Apr 2022 08:42:48 +0400 Subject: [PATCH 10/17] allocate temp_buf also add a test --- src/stream/mod.rs | 2 +- src/stream/stream_test.rs | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 05ddb47..22907f0 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -562,7 +562,7 @@ impl AsyncRead for PollStream<'_> { // read into a temporary buffer because `buf` has an unonymous lifetime, which can be // shorter than the lifetime of `read_fut`. let stream = self.stream.clone(); - let mut temp_buf = Vec::with_capacity(buf.capacity()); + let mut temp_buf = vec![0; buf.remaining()]; self.read_fut.get_or_insert(Box::pin(async move { let res = stream.read(temp_buf.as_mut_slice()).await; match res { diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index c80e0da..cb1c339 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -1,6 +1,8 @@ use super::*; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; #[test] fn test_stream_buffered_amount() -> Result<()> { @@ -69,3 +71,39 @@ async fn test_stream_amount_on_buffered_amount_low() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_poll_stream() -> std::result::Result<(), std::io::Error> { + let s = Arc::new(Stream::new( + "test_poll_stream".to_owned(), + 0, + 4096, + Arc::new(AtomicU32::new(4096)), + Arc::new(AtomicU8::new(AssociationState::Established as u8)), + None, + Arc::new(PendingQueue::new()), + )); + let mut poll_stream = PollStream::new(s.clone()); + + // write + let n = poll_stream.write(&[1, 2, 3]).await?; + assert_eq!(3, n); + assert_eq!(3, poll_stream.buffered_amount()); + + // read + s.clone() + .handle_data(ChunkPayloadData { + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; + let mut buf = [0; 5]; + poll_stream.read(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + Ok(()) +} From 8176d58fef38c786513bf04ce7e41325cbc975a9 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sun, 17 Apr 2022 09:07:40 +0400 Subject: [PATCH 11/17] test shutdown --- src/stream/stream_test.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index cb1c339..ca0c43f 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -73,7 +73,7 @@ async fn test_stream_amount_on_buffered_amount_low() -> Result<()> { } #[tokio::test] -async fn test_poll_stream() -> std::result::Result<(), std::io::Error> { +async fn test_poll_stream() -> std::result::Result<(), io::Error> { let s = Arc::new(Stream::new( "test_poll_stream".to_owned(), 0, @@ -85,14 +85,15 @@ async fn test_poll_stream() -> std::result::Result<(), std::io::Error> { )); let mut poll_stream = PollStream::new(s.clone()); - // write + // async write let n = poll_stream.write(&[1, 2, 3]).await?; assert_eq!(3, n); assert_eq!(3, poll_stream.buffered_amount()); - // read - s.clone() - .handle_data(ChunkPayloadData { + // async read + // 1. pretend that we've received a chunk + let sc = s.clone(); + sc.handle_data(ChunkPayloadData { unordered: true, beginning_fragment: true, ending_fragment: true, @@ -101,9 +102,15 @@ async fn test_poll_stream() -> std::result::Result<(), std::io::Error> { ..Default::default() }) .await; + // 2. read it let mut buf = [0; 5]; poll_stream.read(&mut buf).await?; assert_eq!(buf, [0, 1, 2, 3, 4]); + // shutdown + poll_stream.shutdown().await?; + assert_eq!(true, sc.closed.load(Ordering::Relaxed)); + assert!(poll_stream.read(&mut buf).await.is_err()); + Ok(()) } From f8632c1d2d9003f71095f4f2ea53b226273d5270 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sun, 17 Apr 2022 09:14:28 +0400 Subject: [PATCH 12/17] write a few more tests --- src/stream/mod.rs | 4 ++-- src/stream/stream_test.rs | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 22907f0..3f640d6 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -499,8 +499,8 @@ impl PollStream<'_> { /// use webrtc_sctp::stream::{Stream, PollStream}; /// use std::sync::Arc; /// - /// let stream = Stream::default(); - /// let poll_stream = PollStream::new(Arc::new(stream)); + /// let stream = Arc::new(Stream::default()); + /// let poll_stream = PollStream::new(stream); /// ``` pub fn new(stream: Arc) -> Self { Self { diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index ca0c43f..a9bebf6 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -85,6 +85,12 @@ async fn test_poll_stream() -> std::result::Result<(), io::Error> { )); let mut poll_stream = PollStream::new(s.clone()); + // getters + assert_eq!(0, poll_stream.stream_identifier()); + assert_eq!(0, poll_stream.buffered_amount()); + assert_eq!(0, poll_stream.buffered_amount_low_threshold()); + assert_eq!(0, poll_stream.get_num_bytes_in_reassembly_queue().await); + // async write let n = poll_stream.write(&[1, 2, 3]).await?; assert_eq!(3, n); @@ -112,5 +118,9 @@ async fn test_poll_stream() -> std::result::Result<(), io::Error> { assert_eq!(true, sc.closed.load(Ordering::Relaxed)); assert!(poll_stream.read(&mut buf).await.is_err()); + // misc. + let clone = poll_stream.clone(); + assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier()); + Ok(()) } From da5053739197c0a6b1578fbfb026ab9242f31300 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Sun, 17 Apr 2022 20:42:03 +0400 Subject: [PATCH 13/17] fix an issue with Pending state during poll_write --- src/stream/mod.rs | 29 ++++++++++++++++++++++++----- src/stream/stream_test.rs | 17 +++++++++-------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 3f640d6..f8220d4 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -608,18 +608,37 @@ impl AsyncWrite for PollStream<'_> { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let fut = match self.write_fut.as_mut() { - Some(fut) => fut, + let (fut, fut_is_new) = match self.write_fut.as_mut() { + Some(fut) => (fut, false), None => { let stream = self.stream.clone(); let bytes = Bytes::copy_from_slice(buf); - self.write_fut - .get_or_insert(Box::pin(async move { stream.write(&bytes).await })) + ( + self.write_fut + .get_or_insert(Box::pin(async move { stream.write(&bytes).await })), + true, + ) } }; match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, + Poll::Pending => { + // If it's the first time we're polling the future, `Poll::Pending` can't be + // returned because that would mean the `PollStream` is not ready for writing. And + // this is not true since we've just created a future, which is going to write the + // buf to the underlying stream. + // + // It's okay to return `Poll::Ready` if the data is buffered (this is what the + // buffered writer and `File` do). + if fut_is_new { + Poll::Ready(Ok(buf.len())) + } else { + // If it's the subsequent poll, it's okay to return `Poll::Pending` as it + // indicates that the `PollStream` is not ready for writing. Only one future + // can be in progress at the time. + Poll::Pending + } + } Poll::Ready(Err(e)) => { self.write_fut = None; Poll::Ready(Err(e.into())) diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index a9bebf6..cab45af 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -94,20 +94,21 @@ async fn test_poll_stream() -> std::result::Result<(), io::Error> { // async write let n = poll_stream.write(&[1, 2, 3]).await?; assert_eq!(3, n); + poll_stream.flush().await?; assert_eq!(3, poll_stream.buffered_amount()); // async read // 1. pretend that we've received a chunk let sc = s.clone(); sc.handle_data(ChunkPayloadData { - unordered: true, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), - payload_type: PayloadProtocolIdentifier::Binary, - ..Default::default() - }) - .await; + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; // 2. read it let mut buf = [0; 5]; poll_stream.read(&mut buf).await?; From 5e72dab5a866cffdd71c77e1419f5db050d4e6a8 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Tue, 19 Apr 2022 14:54:56 +0400 Subject: [PATCH 14/17] use fixed size read buffer https://github.com/webrtc-rs/sctp/pull/9#discussion_r852639786 --- src/stream/mod.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index f8220d4..f4b2313 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -477,6 +477,9 @@ impl Stream { } } +/// Default capacity of a temporary read buffer used by [`PollStream`]. +const DEFAULT_READ_BUF_SIZE: usize = 4096; + /// A wrapper around around [`Stream`], which implements [`AsyncRead`] and /// [`AsyncWrite`]. /// @@ -488,6 +491,8 @@ pub struct PollStream<'a> { read_fut: Option>> + Send + 'a>>>, write_fut: Option> + Send + 'a>>>, shutdown_fut: Option> + Send + 'a>>>, + + read_buf_cap: usize, } impl PollStream<'_> { @@ -508,6 +513,7 @@ impl PollStream<'_> { read_fut: None, write_fut: None, shutdown_fut: None, + read_buf_cap: DEFAULT_READ_BUF_SIZE, } } @@ -544,6 +550,12 @@ impl PollStream<'_> { let reassembly_queue = self.stream.reassembly_queue.lock().await; reassembly_queue.get_num_bytes() } + + + /// Set the capacity of the temporary read buffer (default: 4096). + pub fn set_read_buf_capacity(&mut self, capacity: usize) { + self.read_buf_cap = capacity + } } impl AsyncRead for PollStream<'_> { @@ -562,7 +574,7 @@ impl AsyncRead for PollStream<'_> { // read into a temporary buffer because `buf` has an unonymous lifetime, which can be // shorter than the lifetime of `read_fut`. let stream = self.stream.clone(); - let mut temp_buf = vec![0; buf.remaining()]; + let mut temp_buf = vec![0; self.read_buf_cap]; self.read_fut.get_or_insert(Box::pin(async move { let res = stream.read(temp_buf.as_mut_slice()).await; match res { From 91b868ec3a6c29d3788fdcca026832f340031072 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Tue, 19 Apr 2022 16:41:31 +0400 Subject: [PATCH 15/17] fix a bug where a part of data was lost https://github.com/webrtc-rs/sctp/pull/9#discussion_r852640210 BEFORE: if `temp_buf.len()` is greater than `buf.remaining()`, than we're loosing some data in `temp_buf`. AFTER: if `temp_buf.len()` is greater than `buf.remaining()`, than we're switching to a special variant of `ReadFut` - `RemainingData(Vec)`. When another poll_read comes in, remaining data is used to populate `buf`. --- src/stream/mod.rs | 70 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index f4b2313..b6b69b6 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -477,9 +477,33 @@ impl Stream { } } -/// Default capacity of a temporary read buffer used by [`PollStream`]. +/// Default capacity of the temporary read buffer used by [`PollStream`]. const DEFAULT_READ_BUF_SIZE: usize = 4096; +/// State of the read `Future` in [`PollStream`]. +enum ReadFut<'a> { + /// Nothing in progress. + Idle, + /// Reading data from the underlying stream. + Reading(Pin>> + Send + 'a>>), + /// Finished reading, but there's unread data in the temporary buffer. + RemainingData(Vec), +} + +impl<'a> ReadFut<'a> { + /// Gets a mutable reference to the future stored inside `Reading(future)`. + /// + /// # Panics + /// + /// Panics if `ReadFut` variant is not `Reading`. + fn get_reading_mut(&mut self) -> &mut Pin>> + Send + 'a>> { + match self { + ReadFut::Reading(ref mut fut) => fut, + _ => panic!("expected ReadFut to be Reading"), + } + } +} + /// A wrapper around around [`Stream`], which implements [`AsyncRead`] and /// [`AsyncWrite`]. /// @@ -488,7 +512,7 @@ const DEFAULT_READ_BUF_SIZE: usize = 4096; pub struct PollStream<'a> { stream: Arc, - read_fut: Option>> + Send + 'a>>>, + read_fut: ReadFut<'a>, write_fut: Option> + Send + 'a>>>, shutdown_fut: Option> + Send + 'a>>>, @@ -510,7 +534,7 @@ impl PollStream<'_> { pub fn new(stream: Arc) -> Self { Self { stream, - read_fut: None, + read_fut: ReadFut::Idle, write_fut: None, shutdown_fut: None, read_buf_cap: DEFAULT_READ_BUF_SIZE, @@ -551,7 +575,6 @@ impl PollStream<'_> { reassembly_queue.get_num_bytes() } - /// Set the capacity of the temporary read buffer (default: 4096). pub fn set_read_buf_capacity(&mut self, capacity: usize) { self.read_buf_cap = capacity @@ -568,14 +591,13 @@ impl AsyncRead for PollStream<'_> { return Poll::Ready(Ok(())); } - let fut = match self.read_fut.as_mut() { - Some(fut) => fut, - None => { - // read into a temporary buffer because `buf` has an unonymous lifetime, which can be - // shorter than the lifetime of `read_fut`. + let fut = match self.read_fut { + ReadFut::Idle => { + // read into a temporary buffer because `buf` has an unonymous lifetime, which can + // be shorter than the lifetime of `read_fut`. let stream = self.stream.clone(); let mut temp_buf = vec![0; self.read_buf_cap]; - self.read_fut.get_or_insert(Box::pin(async move { + self.read_fut = ReadFut::Reading(Box::pin(async move { let res = stream.read(temp_buf.as_mut_slice()).await; match res { Ok(n) => { @@ -584,7 +606,15 @@ impl AsyncRead for PollStream<'_> { } Err(e) => Err(e), } - })) + })); + self.read_fut.get_reading_mut() + } + ReadFut::Reading(ref mut fut) => fut, + ReadFut::RemainingData(ref data) => { + let len = std::cmp::min(data.len(), buf.remaining()); + buf.put_slice(&data[..len]); + self.read_fut = ReadFut::Idle; + return Poll::Ready(Ok(())); } }; @@ -596,17 +626,23 @@ impl AsyncRead for PollStream<'_> { Poll::Ready(Err(Error::ErrTryAgain)) => {} // EOF has been reached => don't touch buf and just return Ok Poll::Ready(Err(Error::ErrEof)) => { - self.read_fut = None; + self.read_fut = ReadFut::Idle; return Poll::Ready(Ok(())); } Poll::Ready(Err(e)) => { - self.read_fut = None; + self.read_fut = ReadFut::Idle; return Poll::Ready(Err(e.into())); } - Poll::Ready(Ok(read_buf)) => { - let len = std::cmp::min(read_buf.len(), buf.remaining()); - buf.put_slice(&read_buf[..len]); - self.read_fut = None; + Poll::Ready(Ok(mut temp_buf)) => { + let remaining = buf.remaining(); + let len = std::cmp::min(temp_buf.len(), remaining); + buf.put_slice(&temp_buf[..len]); + if temp_buf.len() > remaining { + temp_buf.drain(0..len); + self.read_fut = ReadFut::RemainingData(temp_buf); + } else { + self.read_fut = ReadFut::Idle; + } return Poll::Ready(Ok(())); } } From 032b264141986f332880358a4e724a6c0cb5fc60 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Tue, 19 Apr 2022 17:30:00 +0400 Subject: [PATCH 16/17] don't set read_fut to idle if data is not fully read --- src/stream/mod.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index b6b69b6..c7e10b5 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -610,10 +610,16 @@ impl AsyncRead for PollStream<'_> { self.read_fut.get_reading_mut() } ReadFut::Reading(ref mut fut) => fut, - ReadFut::RemainingData(ref data) => { - let len = std::cmp::min(data.len(), buf.remaining()); + ReadFut::RemainingData(ref mut data) => { + let remaining = buf.remaining(); + let len = std::cmp::min(data.len(), remaining); buf.put_slice(&data[..len]); - self.read_fut = ReadFut::Idle; + if data.len() > remaining { + // ReadFut remains to be RemainingData + data.drain(0..len); + } else { + self.read_fut = ReadFut::Idle; + } return Poll::Ready(Ok(())); } }; From 58db7bfb8f23efef284312bcb167a99c1e1e6a64 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Thu, 21 Apr 2022 10:48:54 +0400 Subject: [PATCH 17/17] format code --- src/stream/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/stream/mod.rs b/src/stream/mod.rs index c7e10b5..13ed433 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -496,7 +496,9 @@ impl<'a> ReadFut<'a> { /// # Panics /// /// Panics if `ReadFut` variant is not `Reading`. - fn get_reading_mut(&mut self) -> &mut Pin>> + Send + 'a>> { + fn get_reading_mut( + &mut self, + ) -> &mut Pin>> + Send + 'a>> { match self { ReadFut::Reading(ref mut fut) => fut, _ => panic!("expected ReadFut to be Reading"),