diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 0849a369ab2..ad136e807f2 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -55,6 +55,7 @@ net = [ "mio/tcp", "mio/udp", "mio/uds", + "winapi/winbase", ] process = [ "bytes", @@ -94,7 +95,7 @@ pin-project-lite = "0.2.0" bytes = { version = "1.0.0", optional = true } once_cell = { version = "1.5.2", optional = true } memchr = { version = "2.2", optional = true } -mio = { version = "0.7.6", optional = true } +mio = { version = "0.7.9", optional = true } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.11.0", optional = true } @@ -111,6 +112,9 @@ signal-hook-registry = { version = "1.1.1", optional = true } libc = { version = "0.2.42" } nix = { version = "0.19.0" } +[target.'cfg(windows)'.dependencies.miow] +version = "0.3.6" + [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" default-features = false @@ -124,6 +128,9 @@ proptest = "0.10.0" tempfile = "3.1.0" async-stream = "0.3" +[target.'cfg(windows)'.dev-dependencies.ntapi] +version = "0.3.6" + [target.'cfg(loom)'.dev-dependencies] loom = { version = "0.4", features = ["futures", "checkpoint"] } diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index 9ae098fb072..22285b78c3e 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -183,6 +183,16 @@ macro_rules! cfg_net_unix { } } +macro_rules! cfg_net_windows { + ($($item:item)*) => { + $( + #[cfg(all(target_os = "windows", feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(all(target_os = "windows", feature = "net"))))] + $item + )* + } +} + macro_rules! cfg_process { ($($item:item)*) => { $( diff --git a/tokio/src/net/mod.rs b/tokio/src/net/mod.rs index 2f17f9eab5e..6aab037b7df 100644 --- a/tokio/src/net/mod.rs +++ b/tokio/src/net/mod.rs @@ -46,3 +46,8 @@ cfg_net_unix! { pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; } + +cfg_net_windows! { + pub mod windows; + pub use windows::named_pipe::{NamedPipe, NamedPipeServer, NamedPipeServerBuilder}; +} diff --git a/tokio/src/net/windows/mod.rs b/tokio/src/net/windows/mod.rs new file mode 100644 index 00000000000..d4b530c0f89 --- /dev/null +++ b/tokio/src/net/windows/mod.rs @@ -0,0 +1,3 @@ +//! Windows platform functionality. + +pub mod named_pipe; diff --git a/tokio/src/net/windows/named_pipe.rs b/tokio/src/net/windows/named_pipe.rs new file mode 100644 index 00000000000..b505c42773c --- /dev/null +++ b/tokio/src/net/windows/named_pipe.rs @@ -0,0 +1,390 @@ +//! Windows named pipes. + +use mio::windows::NamedPipe as MioNamedPipe; +use miow::pipe::{NamedPipe as RawNamedPipe, NamedPipeBuilder}; +use winapi::{ + shared::winerror::*, + um::{minwinbase::*, namedpipeapi::WaitNamedPipeW, winbase::*}, +}; + +use std::{ + ffi::{OsStr, OsString}, + fs::OpenOptions, + future::Future, + io::{self, ErrorKind, IoSlice, Result}, + mem, + os::windows::prelude::*, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, +}; + +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; + +/// Default in/out buffer size. +pub const DEFAULT_BUFFER_SIZE: u32 = 65_536; + +fn mio_from_miow(pipe: RawNamedPipe) -> MioNamedPipe { + // Safety: nothing actually unsafe about this. The trait fn includes `unsafe`. + unsafe { MioNamedPipe::from_raw_handle(pipe.into_raw_handle()) } +} + +/// Connecting instance future. +#[derive(Debug)] +enum ConnectingInstance { + New(NamedPipe), + Connecting(NamedPipe), + Error(io::Error), + Ready(Option), +} + +impl ConnectingInstance { + fn new(mio: MioNamedPipe) -> Result { + Ok(Self::New(NamedPipe::server(mio)?)) + } +} + +impl Future for ConnectingInstance { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match mem::replace(&mut *self, ConnectingInstance::Ready(None)) { + Self::Ready(None) => { + // poll on completed future + break Poll::Pending; + } + Self::Ready(Some(pipe)) => break Poll::Ready(Ok(pipe)), + Self::Connecting(pipe) => match pipe.poll_write_ready(cx) { + Poll::Ready(Ok(_)) => break Poll::Ready(Ok(pipe)), + Poll::Ready(Err(err)) => break Poll::Ready(Err(err)), + Poll::Pending => { + *self = Self::Connecting(pipe); + break Poll::Pending; + } + }, + Self::New(pipe) => match pipe.io_ref().connect() { + Ok(()) => { + *self = Self::Connecting(pipe); + // loop again to poll for write readiness + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + *self = Self::Connecting(pipe); + // loop again to poll for write readiness + } + Err(err) => break Poll::Ready(Err(err)), + }, + Self::Error(err) => break Poll::Ready(Err(err)), + } + } + } +} + +/// A builder structure for creating a new `NamedPipeServer` instance. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct NamedPipeServerBuilder { + pipe_name: OsString, + out_buffer_size: u32, + in_buffer_size: u32, + accept_remote: bool, +} + +impl NamedPipeServerBuilder { + /// Creates a default builder instance. + pub fn new>(addr: A) -> NamedPipeServerBuilder { + let pipe_name = addr.into(); + NamedPipeServerBuilder { + pipe_name, + in_buffer_size: DEFAULT_BUFFER_SIZE, + out_buffer_size: DEFAULT_BUFFER_SIZE, + accept_remote: false, + } + } + + /// Returns `self` with the given input buffer size (defaults to [`DEFAULT_BUFFER_SIZE`]). + pub fn with_in_buffer_size(mut self, in_buffer_size: u32) -> Self { + self.in_buffer_size = in_buffer_size; + self + } + + /// Returns `self` with the given output buffer size (defaults to [`DEFAULT_BUFFER_SIZE`]). + pub fn with_out_buffer_size(mut self, out_buffer_size: u32) -> Self { + self.out_buffer_size = out_buffer_size; + self + } + + /// Returns `self` with `accept_remote` set to `value` (defaults to `false`). + /// + /// `accept_remote` indicates whether this server can accept remote clients or not. + pub fn with_accept_remote(mut self, value: bool) -> Self { + self.accept_remote = value; + self + } + + /// Creates a new [`NamedPipeServer`] with the given security attributes. + /// + /// # Errors + /// + /// It'll fail if pipe with this name already exists. + /// + /// # Unsafety + /// + /// `security_attributes` must point to a well-formed `SECURITY_ATTRIBUTES` structure. + pub unsafe fn build_with_security_attributes( + self, + security_attributes: *mut SECURITY_ATTRIBUTES, + ) -> Result { + let mio = mio_from_miow( + self.miow_builder() + .first(true) + .with_security_attributes(security_attributes)?, + ); + self._build(mio) + } + + /// Creates a new [`NamedPipeServer`]. + /// + /// # Errors + /// + /// * It'll fail if pipe with this name already exists. + pub fn build(self) -> Result { + let mio = mio_from_miow(self.miow_builder().first(true).create()?); + self._build(mio) + } + + fn _build(self, mio: MioNamedPipe) -> Result { + let first_instance = ConnectingInstance::new(mio)?; + Ok(NamedPipeServer { + builder: self, + next_instance: Mutex::new(first_instance), + }) + } + + fn next_instance(&self) -> ConnectingInstance { + let instance = self + .miow_builder() + .first(false) + .create() + .map(mio_from_miow) + .and_then(ConnectingInstance::new); + match instance { + Ok(instance) => instance, + Err(err) => ConnectingInstance::Error(err), + } + } + + fn miow_builder(&self) -> NamedPipeBuilder { + let mut miow_builder = NamedPipeBuilder::new(&self.pipe_name); + miow_builder + .inbound(true) + .outbound(true) + .out_buffer_size(self.out_buffer_size) + .in_buffer_size(self.in_buffer_size) + .accept_remote(self.accept_remote); + miow_builder + } +} + +/// Named pipe server (see [`NamedPipeServerBuilder`]). +#[derive(Debug)] +pub struct NamedPipeServer { + builder: NamedPipeServerBuilder, + // At least one instance will always exist. + next_instance: Mutex, +} + +impl NamedPipeServer { + /// Returns `'static` future that will wait for a client. + /// + /// # Errors + /// + /// This future will resolve successfuly even if client disconnects + /// before an actual call to `ConnectNamedPipe`, but any attempt to write + /// will lead to `ERROR_NO_DATA`. + pub fn accept(&self) -> impl Future> + 'static { + let next_instance = self.builder.next_instance(); + mem::replace(&mut *self.next_instance.lock().unwrap(), next_instance) + } +} + +#[derive(Debug)] +enum NamedPipeInner { + Client(PollEvented), + Server(PollEvented), +} + +/// Non-blocking windows named pipe. +#[derive(Debug)] +pub struct NamedPipe { + inner: NamedPipeInner, +} + +impl NamedPipe { + fn server(mio: MioNamedPipe) -> Result { + let io = PollEvented::new(mio)?; + Ok(Self { + inner: NamedPipeInner::Server(io), + }) + } + + fn client(mio: MioNamedPipe) -> Result { + let io = PollEvented::new(mio)?; + Ok(Self { + inner: NamedPipeInner::Client(io), + }) + } + + fn io_ref(&self) -> &PollEvented { + match &self.inner { + NamedPipeInner::Client(io) | NamedPipeInner::Server(io) => io, + } + } + + /// Will try to connect to a named pipe. Returned pipe may not be writable. + /// + /// # Errors + /// + /// Will error with `ERROR_PIPE_BUSY` if there are no available instances. + fn open(addr: &OsStr) -> Result { + let file = OpenOptions::new() + .read(true) + .write(true) + .custom_flags(FILE_FLAG_OVERLAPPED) + .security_qos_flags(SECURITY_IDENTIFICATION) + .open(addr)?; + + let pipe = unsafe { MioNamedPipe::from_raw_handle(file.into_raw_handle()) }; + Self::client(pipe) + } + + /// Connects to a nemed pipe server by `addr`. + /// + /// # Errors + /// + /// It'll error if there is no such pipe. + pub async fn connect>(addr: A) -> Result { + let mut pipe_name = into_wide(addr.as_ref()); + let mut busy = false; + + loop { + if !busy { + // pipe instance may be available, so trying to open it + match Self::open(addr.as_ref()) { + Ok(pipe) => { + // Pipe is opened. + pipe.writable().await?; + return Ok(pipe); + } + Err(err) if err.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => { + // We should wait since there are no free instances + } + Err(err) => return Err(err), + } + } + + let (status, name) = wait_pipe(pipe_name).await?; + pipe_name = name; + busy = matches!(status, WaitPipeResult::Busy); + } + } + + /// Polls for read readiness. + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.io_ref() + .registration() + .poll_read_ready(cx) + .map_ok(|_| ()) + } + + /// Polls for write readiness. + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.io_ref() + .registration() + .poll_write_ready(cx) + .map_ok(|_| ()) + } + + /// Polls for any of the requested ready states. + pub async fn ready(&self, interest: Interest) -> Result { + let event = self.io_ref().registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the pipe client or server to become readable. + pub async fn readable(&self) -> Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Waits for the pipe client or server to become writeable. + pub async fn writable(&self) -> Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } +} + +impl AsRawHandle for NamedPipe { + fn as_raw_handle(&self) -> RawHandle { + self.io_ref().as_raw_handle() + } +} + +impl AsyncRead for NamedPipe { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + unsafe { self.io_ref().poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipe { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.io_ref().poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.io_ref().poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +fn into_wide(s: &OsStr) -> Vec { + s.encode_wide().chain(Some(0)).collect() +} + +#[derive(Debug, Clone, Copy)] +enum WaitPipeResult { + Available, + Busy, +} + +// `wide_name` will be returned. +async fn wait_pipe(wide_name: Vec) -> Result<(WaitPipeResult, Vec)> { + crate::task::spawn_blocking(move || { + let result = unsafe { WaitNamedPipeW(wide_name.as_ptr(), 0) }; + if result > 0 { + Ok((WaitPipeResult::Available, wide_name)) + } else { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(ERROR_SEM_TIMEOUT as i32) { + Ok((WaitPipeResult::Busy, wide_name)) + } else { + Err(err) + } + } + }) + .await? +} diff --git a/tokio/tests/named_pipe.rs b/tokio/tests/named_pipe.rs new file mode 100644 index 00000000000..e97c625e13b --- /dev/null +++ b/tokio/tests/named_pipe.rs @@ -0,0 +1,230 @@ +#![cfg(feature = "full")] +#![warn(rust_2018_idioms)] +#![cfg(target_os = "windows")] + +use std::io; +use std::mem::{size_of, transmute, zeroed}; +use std::os::raw::c_void; +use std::os::windows::io::AsRawHandle; +use std::slice::from_raw_parts; + +use bytes::Buf; +use futures::{ + future::{select, try_join, try_join_all, Either}, + stream::FuturesUnordered, + StreamExt, +}; +use ntapi::ntioapi::FileDirectoryInformation; +use ntapi::ntioapi::NtQueryDirectoryFile; +use ntapi::ntioapi::FILE_DIRECTORY_INFORMATION; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{NamedPipe, NamedPipeServerBuilder}, + time::{sleep, Duration}, +}; +use winapi::shared::ntdef::UNICODE_STRING; +use winapi::shared::winerror::*; + +#[tokio::test] +async fn basic() -> io::Result<()> { + const NUM_CLIENTS: u32 = 255; + const PIPE_NAME: &'static str = r"\\.\pipe\test-named-pipe-basic"; + let mut buf = [0_u8; 16]; + + // Create server to avoid NotFound from clients. + let server = NamedPipeServerBuilder::new(PIPE_NAME).build()?; + + let server = async move { + let mut pipe; + for _ in 0..NUM_CLIENTS { + pipe = server.accept().await?; + let mut buf = Vec::new(); + pipe.read_buf(&mut buf).await?; + let i = (&*buf).get_u32_le(); + pipe.write_all(format!("Server to {}", i).as_bytes()) + .await?; + } + std::io::Result::Ok(()) + }; + + // concurrent clients + let clients = (0..NUM_CLIENTS) + .map(|i| async move { + let mut pipe = NamedPipe::connect(PIPE_NAME).await?; + pipe.write_all(&i.to_le_bytes()).await?; + let mut buf = Vec::new(); + pipe.read_buf(&mut buf).await?; + assert_eq!(buf, format!("Server to {}", i).as_bytes()); + std::io::Result::Ok(()) + }) + .collect::>() + .fold(Ok(()), |a, x| async move { a.and(x) }); + + try_join(server, clients).await?; + + // client returns not found if there is no server + let err = NamedPipe::connect(PIPE_NAME).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::NotFound); + + let server = NamedPipeServerBuilder::new(PIPE_NAME).build()?.accept(); + let client = NamedPipe::connect(PIPE_NAME); + let (mut server, mut client) = try_join(server, client).await?; + + ping_pong(&mut server, &mut client).await?; + + drop(server); + + // Client reads when server is gone + let len = client.read(&mut buf).await.unwrap(); + assert_eq!(len, 0); + + drop(client); + + let server = NamedPipeServerBuilder::new(PIPE_NAME).build()?.accept(); + let client = NamedPipe::connect(PIPE_NAME); + let (mut server, mut client) = try_join(server, client).await?; + + ping_pong(&mut server, &mut client).await?; + + drop(client); + + // Server reads when client is gone + let len = server.read(&mut buf).await?; + assert_eq!(len, 0); + + // There is no way to connect to a connected server instance + // even if client is gone. + let timeout = sleep(Duration::from_millis(300)); + let client = NamedPipe::connect(PIPE_NAME); + futures::pin_mut!(client); + futures::pin_mut!(timeout); + let result = select(timeout, client).await; + assert!(matches!(result, Either::Left(_))); + + Ok(()) +} + +async fn ping_pong(l: &mut NamedPipe, r: &mut NamedPipe) -> io::Result<()> { + let mut buf = [b' '; 5]; + + l.write_all(b"ping").await?; + r.read(&mut buf).await?; + assert_eq!(&buf, b"ping "); + r.write_all(b"pong").await?; + l.read(&mut buf).await?; + assert_eq!(&buf, b"pong "); + Ok(()) +} + +#[tokio::test] +async fn immediate_disconnect() -> io::Result<()> { + const PIPE_NAME: &'static str = r"\\.\pipe\test-named-pipe-immediate-disconnect"; + + let server = NamedPipeServerBuilder::new(PIPE_NAME).build()?; + + // there is one instance + assert_eq!(num_instances("test-named-pipe-immediate-disconnect")?, 1); + + let _ = NamedPipe::connect(PIPE_NAME).await?; + + let mut instance = server.accept().await?; + + // instance will be broken because client is gone + match instance.write_all(b"ping").await { + Err(e) if e.raw_os_error() == Some(ERROR_NO_DATA as i32) => (), + x => panic!("{:?}", x), + } + + Ok(()) +} + +#[tokio::test] +async fn connection_order() -> io::Result<()> { + const PIPE_NAME: &'static str = r"\\.\pipe\test-named-pipe-connection-order"; + + let server = NamedPipeServerBuilder::new(PIPE_NAME).build()?; + + // Clients must connect to instances in a natural order, or this loop will hang + for _ in 0..1024 { + // Some time to finally close last loop's handles. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + let servers = vec![server.accept(), server.accept(), server.accept()]; + + // now there are four instances + assert_eq!(num_instances("test-named-pipe-connection-order")?, 4); + + let clients = vec![ + NamedPipe::connect(PIPE_NAME), + NamedPipe::connect(PIPE_NAME), + NamedPipe::connect(PIPE_NAME), + ]; + + let (mut servers, mut clients) = + try_join(try_join_all(servers), try_join_all(clients)).await?; + + for s in servers.iter_mut() { + s.write_all(b"ping").await?; + } + + for c in clients[..3].iter_mut() { + let mut buf = [0_u8; 4]; + c.read(&mut buf).await?; + assert_eq!(&buf, b"ping"); + c.write_all(b"pong").await?; + } + + for s in servers.iter_mut() { + let mut buf = [0_u8; 4]; + s.read(&mut buf).await?; + assert_eq!(&buf, b"pong"); + } + } + + Ok(()) +} + +fn num_instances>(pipe_name: T) -> io::Result { + let mut name = pipe_name.as_ref().encode_utf16().collect::>(); + let mut name = UNICODE_STRING { + Length: (name.len() * size_of::()) as u16, + MaximumLength: (name.len() * size_of::()) as u16, + Buffer: name.as_mut_ptr(), + }; + let root = std::fs::File::open(r"\\.\Pipe\")?; + let mut io_status_block = unsafe { zeroed() }; + let mut file_directory_information = [0_u8; 1024]; + + let status = unsafe { + NtQueryDirectoryFile( + root.as_raw_handle(), + std::ptr::null_mut(), + None, + std::ptr::null_mut(), + &mut io_status_block, + &mut file_directory_information as *mut _ as *mut c_void, + 1024, + FileDirectoryInformation, + 0, + &mut name, + 0, + ) + }; + + if status as u32 != NO_ERROR { + return Err(io::Error::last_os_error()); + } + + let info = unsafe { transmute::<_, &FILE_DIRECTORY_INFORMATION>(&file_directory_information) }; + let raw_name = unsafe { + from_raw_parts( + info.FileName.as_ptr(), + info.FileNameLength as usize / size_of::(), + ) + }; + let name = String::from_utf16(raw_name).unwrap(); + let num_instances = unsafe { *info.EndOfFile.QuadPart() }; + + assert_eq!(name, pipe_name.as_ref()); + + Ok(num_instances as u32) +}