From 83a398c844691786846c483c72dc51f47f889324 Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 6 Jan 2023 16:13:13 -0600 Subject: [PATCH] io: implement StreamedFd StreamedFd is a higher-level wrapper around an AsyncFd which provides the PollEvented optimization to users. Fixes #5324. --- tokio-util/src/io/mod.rs | 5 ++ tokio-util/src/io/stream_fd.rs | 132 +++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 tokio-util/src/io/stream_fd.rs diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index 6c40d739014..fec96910f45 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -20,6 +20,11 @@ mod stream_reader; cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; + + #[cfg(unix)] + mod stream_fd; + #[cfg(unix)] + pub use stream_fd::*; } pub use self::copy_to_bytes::CopyToBytes; diff --git a/tokio-util/src/io/stream_fd.rs b/tokio-util/src/io/stream_fd.rs new file mode 100644 index 00000000000..fcc69c573f0 --- /dev/null +++ b/tokio-util/src/io/stream_fd.rs @@ -0,0 +1,132 @@ +use std::io; +use std::io::{Error, Read, Write}; +use std::os::fd::AsRawFd; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::io::unix::AsyncFd; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// Provides async reading and writing semantics to a pollable file descriptor that is a byte +/// stream. +/// +/// [`AsyncFd`] provides a way to poll file descriptors for IO readiness, but leaves reading and +/// writing to the user. This is a higher-level utility which handles this for users. +/// +/// # Warning +/// The underlying IO source this is constructed from must not be capable of nonblocking reads and +/// writes, and must be pollable. +/// +/// The underlying IO source must also be a continuous stream of bytes in either direction. It must +/// be guaranteed that a partial read or write signals a loss of readiness. +/// +/// The underlying IO source must also be self-flushing. This will assume that flushing is a no-op. +/// +/// +/// [`AsyncFd`]: struct@tokio::io::unix::AsyncFd +#[derive(Debug)] +pub struct StreamFd +where + T: AsRawFd, +{ + inner: AsyncFd, +} + +impl StreamFd +where + T: AsRawFd, +{ + /// Construct a new StreamFd from an IO source. + /// + /// # Panics + /// Panics if called from outside a tokio runtime context. + /// + /// [`RawFd`]: struct@std::os::fd::RawFd + pub fn new(fd: T) -> io::Result { + let inner = AsyncFd::new(fd)?; + + Ok(Self { inner }) + } +} + +// note: taken from PollEvented +impl AsyncRead for StreamFd +where + T: AsRawFd + Read + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + + loop { + let mut guard = ready!(this.inner.poll_read_ready_mut(cx))?; + + // safety: we will not be reading these bytes + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + let len = b.len(); + + match guard.get_inner_mut().read(b) { + Ok(n) => { + if n > 0 && n < len { + guard.clear_ready(); + } + + // Safety: We trust `File::read` to have filled up `n` bytes in the + // buffer. + unsafe { buf.assume_init(n) }; + buf.advance(n); + return Poll::Ready(Ok(())); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + guard.clear_ready(); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } +} + +impl AsyncWrite for StreamFd +where + T: AsRawFd + Write + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + + loop { + let mut guard = ready!(this.inner.poll_write_ready_mut(cx))?; + + match guard.get_inner_mut().write(buf) { + Ok(n) => { + // if we write only part of our buffer, this is sufficient on unix to show + // that the socket buffer is full + if n > 0 && n < buf.len() { + guard.clear_ready(); + } + + return Poll::Ready(Ok(n)); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + guard.clear_ready(); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unimplemented!("Shutdown is not implemented for this type") + } +}