From 12c4013aad399a01944a51894eb58b8e2480a2c6 Mon Sep 17 00:00:00 2001 From: Kolby ML <31669092+KolbyML@users.noreply.github.com> Date: Mon, 1 May 2023 12:36:45 -0600 Subject: [PATCH] Make Windows UDS work with tests and clean implementation --- src/lib.rs | 4 + src/net/tcp/stream.rs | 20 ++--- src/net/uds/listener.rs | 9 -- src/net/uds/stream.rs | 84 ++++++++---------- src/sys/unix/pipe.rs | 20 ++--- src/sys/unix/uds/mod.rs | 4 +- src/sys/windows/iocp.rs | 4 +- src/sys/windows/mod.rs | 2 +- src/sys/windows/net.rs | 5 +- src/sys/windows/stdnet/addr.rs | 55 +++++++++--- src/sys/windows/stdnet/listener.rs | 64 +++++++++----- src/sys/windows/stdnet/mod.rs | 6 +- src/sys/windows/stdnet/socket.rs | 91 +++++++++++--------- src/sys/windows/stdnet/stream.rs | 34 +++++--- src/sys/windows/udp.rs | 4 +- src/sys/windows/uds/listener.rs | 6 +- tests/unix_listener.rs | 6 +- tests/unix_pipe.rs | 4 +- tests/unix_stream.rs | 134 ++++++++++------------------- tests/util/mod.rs | 4 +- 20 files changed, 290 insertions(+), 270 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 56a7160be..aabd716b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,6 +91,10 @@ pub mod windows { //! Windows only extensions. pub use crate::sys::named_pipe::NamedPipe; + // blocking windows uds which mimick std implementation used for tests + cfg_net! { + pub use crate::sys::windows::stdnet; + } } pub mod features { diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 3264904f5..8a3f6a2f2 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -269,49 +269,49 @@ impl TcpStream { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 181806202..4e12c8feb 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -30,21 +30,12 @@ impl UnixListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. - #[cfg(unix)] - #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), } } - #[cfg(windows)] - pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener { - UnixListener { - inner: IoSource::new(listener), - } - } - /// Accepts a new incoming connection to this listener. /// /// The call is responsible for ensuring that the listening socket is in diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 0a04f035a..b0cd17ffd 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -41,21 +41,12 @@ impl UnixStream { /// The Unix stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). - #[cfg(unix)] - #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), } } - #[cfg(windows)] - pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { - UnixStream { - inner: IoSource::new(stream), - } - } - /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. @@ -170,34 +161,6 @@ impl UnixStream { /// # let _ = std::fs::remove_file(&file_path); /// let server = UnixListener::bind(&file_path).unwrap(); /// - /// let handle = std::thread::spawn(move || { - /// if let Ok((stream2, _)) = server.accept() { - /// // Wait until the stream is readable... - /// - /// // Read from the stream using a direct WinSock call, of course the - /// // `io::Read` implementation would be easier to use. - /// let mut buf = [0; 512]; - /// let n = stream2.try_io(|| { - /// let res = unsafe { - /// WinSock::recv( - /// stream2.as_raw_socket().try_into().unwrap(), - /// &mut buf as *mut _ as *mut _, - /// buf.len() as c_int, - /// 0 - /// ) - /// }; - /// if res != WinSock::SOCKET_ERROR { - /// Ok(res as usize) - /// } else { - /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure - /// // should return `WouldBlock` error. - /// Err(io::Error::last_os_error()) - /// } - /// }).unwrap(); - /// eprintln!("read {} bytes", n); - /// } - /// }); - /// /// let stream1 = UnixStream::connect(&file_path).unwrap(); /// /// // Wait until the stream is writable... @@ -226,6 +189,33 @@ impl UnixStream { /// })?; /// eprintln!("write {} bytes", n); /// + /// let handle = std::thread::spawn(move || { + /// if let Ok((stream2, _)) = server.accept() { + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// }).unwrap(); + /// eprintln!("read {} bytes", n); + /// } + /// }); /// # handle.join().unwrap(); /// # Ok(()) /// # } @@ -240,49 +230,49 @@ impl UnixStream { impl Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/sys/unix/pipe.rs b/src/sys/unix/pipe.rs index a5b74fda0..c2654ad59 100644 --- a/src/sys/unix/pipe.rs +++ b/src/sys/unix/pipe.rs @@ -321,29 +321,29 @@ impl event::Source for Sender { impl Write for Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } impl Write for &Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } @@ -486,21 +486,21 @@ impl event::Source for Receiver { impl Read for Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } impl Read for &Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 0d04ab2ae..5177f0e96 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -40,7 +40,7 @@ cfg_os_poll! { sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t; let bytes = path.as_os_str().as_bytes(); - match (bytes.get(0), bytes.len().cmp(&sockaddr.sun_path.len())) { + match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) { // Abstract paths don't need a null terminator (Some(&0), Ordering::Greater) => { return Err(io::Error::new( @@ -64,7 +64,7 @@ cfg_os_poll! { let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.get(0) { + match bytes.first() { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} diff --git a/src/sys/windows/iocp.rs b/src/sys/windows/iocp.rs index c71b695d4..01aeb9fb9 100644 --- a/src/sys/windows/iocp.rs +++ b/src/sys/windows/iocp.rs @@ -206,7 +206,7 @@ impl CompletionStatus { /// A completion key is a per-handle key that is specified when it is added /// to an I/O completion port via `add_handle` or `add_socket`. pub fn token(&self) -> usize { - self.0.lpCompletionKey as usize + self.0.lpCompletionKey } /// Returns a pointer to the `Overlapped` structure that was specified when @@ -268,6 +268,6 @@ mod tests { } assert_eq!(s[2].bytes_transferred(), 0); assert_eq!(s[2].token(), 0); - assert_eq!(s[2].overlapped(), 0 as *mut _); + assert_eq!(s[2].overlapped(), std::ptr::null_mut()); } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 07f7dda6c..889278ed7 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -28,7 +28,7 @@ macro_rules! wsa_syscall { } cfg_net! { - pub(crate) mod stdnet; + pub mod stdnet; pub(crate) mod uds; pub(crate) use self::uds::SocketAddr; } diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 38b17492c..c2b540560 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -1,6 +1,7 @@ use std::io; use std::mem; use std::net::SocketAddr; +use std::sync::Once; use windows_sys::Win32::Networking::WinSock::{ closesocket, ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, @@ -74,7 +75,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { }; let sockaddr_in = SOCKADDR_IN { - sin_family: AF_INET as u16, // 1 + sin_family: AF_INET, // 1 sin_port: addr.port().to_be(), sin_addr, sin_zero: [0; 8], @@ -96,7 +97,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { }; let sockaddr_in6 = SOCKADDR_IN6 { - sin6_family: AF_INET6 as u16, // 23 + sin6_family: AF_INET6, // 23 sin6_port: addr.port().to_be(), sin6_addr, sin6_flowinfo: addr.flowinfo(), diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 26b1fddde..88a96a024 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -3,9 +3,9 @@ use std::os::raw::c_int; use std::path::Path; use std::{fmt, io, mem}; -use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; +use windows_sys::Win32::Networking::WinSock::{SOCKADDR, SOCKADDR_UN}; -fn path_offset(addr: &sockaddr_un) -> usize { +fn path_offset(addr: &SOCKADDR_UN) -> usize { // Work with an actual instance of the type since using a null pointer is UB let base = addr as *const _ as usize; let path = &addr.sun_path as *const _ as usize; @@ -14,16 +14,16 @@ fn path_offset(addr: &sockaddr_un) -> usize { cfg_os_poll! { use windows_sys::Win32::Networking::WinSock::AF_UNIX; - pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { - let sockaddr = mem::MaybeUninit::::zeroed(); + pub(super) fn socket_addr(path: &Path) -> io::Result<(SOCKADDR_UN, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); - // This is safe to assume because a `sockaddr_un` filled with `0` + // This is safe to assume because a `SOCKADDR_UN` filled with `0` // bytes is properly initialized. // - // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `0` is a valid value for `SOCKADDR_UN::sun_family`; it is // `WinSock::AF_UNSPEC`. // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // `[0; 108]` is a valid value for `SOCKADDR_UN::sun_path`; it begins an // abstract path. let mut sockaddr = unsafe { sockaddr.assume_init() }; sockaddr.sun_family = AF_UNIX; @@ -66,8 +66,9 @@ cfg_os_poll! { } } -pub(crate) struct SocketAddr { - addr: sockaddr_un, +/// An address associated with a Unix socket. +pub struct SocketAddr { + addr: SOCKADDR_UN, len: c_int, } @@ -77,11 +78,11 @@ impl SocketAddr { F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, { let mut sockaddr = { - let sockaddr = mem::MaybeUninit::::zeroed(); + let sockaddr = mem::MaybeUninit::::zeroed(); unsafe { sockaddr.assume_init() } }; - let mut len = mem::size_of::() as c_int; + let mut len = mem::size_of::() as c_int; let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; Ok(( result, @@ -99,6 +100,38 @@ impl SocketAddr { SocketAddr::init(f).map(|(_, addr)| addr) } + cfg_os_poll! { + pub(crate) fn from_parts(sockaddr: SOCKADDR_UN, mut len: c_int) -> io::Result { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = path_offset(&sockaddr) as c_int; // i.e. zero-length address + } else if sockaddr.sun_family != windows_sys::Win32::Networking::WinSock::AF_UNIX as _ { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "file descriptor did not correspond to a Unix socket: {}", + sockaddr.sun_family + ), + )); + } + + Ok(SocketAddr { + addr: sockaddr, + len, + }) + } + } + + /// Returns the contents of this address if it is a `pathname` address. + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + pub(crate) fn address(&self) -> AddressKind<'_> { let len = self.len as usize - path_offset(&self.addr); // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 214167276..487aac9eb 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -1,34 +1,39 @@ +use super::{socket::Socket, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::{fmt, io, mem}; - use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; -use super::{socket::Socket, SocketAddr}; - -pub(crate) struct UnixListener(Socket); +/// A structure representing a Unix domain socket server. +pub struct UnixListener { + inner: Socket, +} impl UnixListener { - pub(crate) fn local_addr(&self) -> io::Result { + /// Returns the local socket address of this listener. + pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( - getsockname(self.0.as_raw_socket() as _, addr, len), + getsockname(self.inner.as_raw_socket() as _, addr, len), SOCKET_ERROR ) }) } - pub(crate) fn take_error(&self) -> io::Result> { - self.0.take_error() + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.inner.take_error() } } cfg_os_poll! { + use std::os::raw::c_int; use std::path::Path; - use super::{socket_addr, UnixStream}; + use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN; impl UnixListener { - pub(crate) fn bind>(path: P) -> io::Result { + /// Creates a new `UnixListener` bound to the specified socket. + pub fn bind>(path: P) -> io::Result { let inner = Socket::new()?; let (addr, len) = socket_addr(path.as_ref())?; @@ -37,16 +42,33 @@ cfg_os_poll! { SOCKET_ERROR )?; wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; - Ok(UnixListener(inner)) + Ok(UnixListener { + inner + }) } - pub(crate) fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - SocketAddr::init(|addr, len| self.0.accept(addr, len)) - .map(|(sock, addr)| (UnixStream(sock), addr)) + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corresponding [`UnixStream`] and + /// the remote peer's address will be returned. + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage: SOCKADDR_UN = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as c_int; + let sock = self.inner.accept(&mut storage as *mut _ as *mut _, &mut len)?; + let addr = SocketAddr::from_parts(storage, len)?; + Ok((UnixStream(sock), addr)) } - pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) + /// Moves the socket into or out of nonblocking mode. + /// + /// This will result in the `accept` operation becoming nonblocking, + /// i.e., immediately returning from their calls. If the IO operation is + /// successful, `Ok` is returned and no further action is required. If the + /// IO operation could not be completed and needs to be retried, an error + /// with kind [`io::ErrorKind::WouldBlock`] is returned. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) } } } @@ -54,7 +76,7 @@ cfg_os_poll! { impl fmt::Debug for UnixListener { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = fmt.debug_struct("UnixListener"); - builder.field("socket", &self.0.as_raw_socket()); + builder.field("socket", &self.inner.as_raw_socket()); if let Ok(addr) = self.local_addr() { builder.field("local", &addr); } @@ -64,19 +86,21 @@ impl fmt::Debug for UnixListener { impl AsRawSocket for UnixListener { fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() + self.inner.as_raw_socket() } } impl FromRawSocket for UnixListener { unsafe fn from_raw_socket(sock: RawSocket) -> Self { - UnixListener(Socket::from_raw_socket(sock)) + UnixListener { + inner: Socket::from_raw_socket(sock), + } } } impl IntoRawSocket for UnixListener { fn into_raw_socket(self) -> RawSocket { - let ret = self.0.as_raw_socket(); + let ret = self.inner.as_raw_socket(); mem::forget(self); ret } diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index 0eb5130d4..c1e1cc748 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -4,9 +4,9 @@ mod listener; mod socket; mod stream; -pub(crate) use self::addr::SocketAddr; -pub(crate) use self::listener::UnixListener; -pub(crate) use self::stream::UnixStream; +pub use self::addr::SocketAddr; +pub use self::listener::UnixListener; +pub use self::stream::UnixStream; cfg_os_poll! { pub(self) use self::addr::socket_addr; diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index 9212c1e04..d34d04037 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -1,12 +1,10 @@ use std::cmp::min; -use std::convert::TryInto; use std::io::{self, IoSlice, IoSliceMut}; use std::mem; use std::net::Shutdown; use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::ptr; - use windows_sys::Win32::Networking::WinSock::{self, closesocket, SOCKET, SOCKET_ERROR, WSABUF}; /// Maximum size of a buffer passed to system call like `recv` and `send`. @@ -42,42 +40,38 @@ impl Socket { ); match res { Ok(_) => Ok(total as usize), - Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN as i32) => Ok(0), + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN) => Ok(0), Err(err) => Err(err), } } - pub fn send(&self, buf: &[u8]) -> io::Result { - wsa_syscall!( - send( + pub fn write(&self, buf: &[u8]) -> io::Result { + let response = unsafe { + windows_sys::Win32::Networking::WinSock::send( self.0, buf.as_ptr().cast(), min(buf.len(), MAX_BUF_LEN) as c_int, 0, - ), - SOCKET_ERROR - ) - .map(|n| n as usize) + ) + }; + if response == SOCKET_ERROR { + return match unsafe { windows_sys::Win32::Networking::WinSock::WSAGetLastError() } { + windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, "brokenpipe")) + } + e => Err(std::io::Error::from_raw_os_error(e)), + }; + } + Ok(response as usize) } - pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { let mut total = 0; wsa_syscall!( WSASend( self.0, - // FIXME: From the `WSASend` docs [1]: - // > For a Winsock application, once the WSASend function is called, - // > the system owns these buffers and the application may not - // > access them. - // - // So what we're doing is actually UB as `bufs` needs to be `&mut - // [IoSlice<'_>]`. - // - // See: https://github.com/rust-lang/socket2-rs/issues/129. - // - // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend - bufs.as_ptr() as *mut _, - min(bufs.len(), u32::MAX as usize) as u32, + bufs.as_ptr() as *mut WSABUF, + bufs.len().min(u32::MAX as usize) as u32, &mut total, 0, std::ptr::null_mut(), @@ -94,7 +88,7 @@ impl Socket { Shutdown::Read => WinSock::SD_RECEIVE, Shutdown::Both => WinSock::SD_BOTH, }; - wsa_syscall!(shutdown(self.0, how.try_into().unwrap()), SOCKET_ERROR)?; + wsa_syscall!(shutdown(self.0, how), SOCKET_ERROR)?; Ok(()) } @@ -104,8 +98,8 @@ impl Socket { wsa_syscall!( getsockopt( self.0 as _, - WinSock::SOL_SOCKET.try_into().unwrap(), - WinSock::SO_ERROR.try_into().unwrap(), + WinSock::SOL_SOCKET, + WinSock::SO_ERROR, &mut val as *mut _ as *mut _, &mut len, ), @@ -116,36 +110,48 @@ impl Socket { if val == 0 { Ok(None) } else { - Ok(Some(io::Error::from_raw_os_error(val as i32))) + Ok(Some(io::Error::from_raw_os_error(val))) } } } cfg_os_poll! { + use windows_sys::Win32::Foundation::{HANDLE, HANDLE_FLAG_INHERIT, SetHandleInformation}; use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; use super::init; impl Socket { pub fn new() -> io::Result { init(); - wsa_syscall!( - WSASocketW( - WinSock::AF_UNIX.into(), - WinSock::SOCK_STREAM.into(), - 0, - ptr::null_mut(), - 0, - WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, - ), - INVALID_SOCKET - ).map(Socket) + match wsa_syscall!(WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM, + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED, + ), INVALID_SOCKET) { + Ok(res) => { + let socket = Socket(res); + socket.set_no_inherit()?; + Ok(socket) + }, + Err(e) => Err(e), + } } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { // WinSock's accept returns a socket with the same properties as the listener. it is // called on. In particular, the WSA_FLAG_NO_HANDLE_INHERIT will be inherited from the // listener. - wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) + match wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET) { + Ok(res) => { + let socket = Socket(res); + socket.set_no_inherit()?; + Ok(socket) + }, + Err(e) => Err(e), + } } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { @@ -156,6 +162,11 @@ cfg_os_poll! { )?; Ok(()) } + + pub fn set_no_inherit(&self) -> io::Result<()> { + syscall!(SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), PartialEq::eq, -1)?; + Ok(()) + } } } diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index ce1da2f54..9d78f5383 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -7,10 +7,12 @@ use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; use super::{socket::Socket, SocketAddr}; -pub(crate) struct UnixStream(pub(super) Socket); +/// A Unix stream socket. +pub struct UnixStream(pub(super) Socket); impl UnixStream { - pub(crate) fn local_addr(&self) -> io::Result { + /// Connects to the socket specified by [`address`]. + pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), @@ -19,7 +21,8 @@ impl UnixStream { }) } - pub(crate) fn peer_addr(&self) -> io::Result { + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getpeername(self.0.as_raw_socket() as _, addr, len), @@ -28,22 +31,28 @@ impl UnixStream { }) } - pub(crate) fn take_error(&self) -> io::Result> { + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { self.0.take_error() } - pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> { + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of [`Shutdown`]). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.0.shutdown(how) } } cfg_os_poll! { use std::path::Path; - use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; use super::socket_addr; impl UnixStream { - pub(crate) fn connect>(path: P) -> io::Result { + /// Connects to the socket named by `path`. + pub fn connect>(path: P) -> io::Result { let inner = Socket::new()?; let (addr, len) = socket_addr(path.as_ref())?; @@ -51,18 +60,19 @@ cfg_os_poll! { connect( inner.as_raw_socket() as _, &addr as *const _ as *const _, - len as i32, + len, ), SOCKET_ERROR ) { Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(ref err) if err.kind() == std::io::ErrorKind::Other => {} Err(e) => return Err(e), } Ok(UnixStream(inner)) } - pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) } } @@ -118,11 +128,11 @@ impl io::Write for UnixStream { impl<'a> io::Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.send(buf) + self.0.write(buf) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.0.send_vectored(bufs) + self.0.write_vectored(bufs) } fn flush(&mut self) -> io::Result<()> { diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index 87e269fa3..3e975973f 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -30,8 +30,8 @@ pub(crate) fn only_v6(socket: &net::UdpSocket) -> io::Result { syscall!( getsockopt( socket.as_raw_socket() as usize, - IPPROTO_IPV6 as i32, - IPV6_V6ONLY as i32, + IPPROTO_IPV6, + IPV6_V6ONLY, optval.as_mut_ptr().cast(), &mut optlen, ), diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index 4ba4395e5..f33312d50 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -13,9 +13,9 @@ pub(crate) fn bind(path: &Path) -> io::Result { } pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { - listener - .accept() - .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) + listener.set_nonblocking(true)?; + let (stream, addr) = listener.accept()?; + Ok((UnixStream::from_std(stream), addr)) } pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index c131497cc..2303ca530 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net"))] -#[cfg(windows)] -use mio::net; use mio::net::UnixListener; +#[cfg(windows)] +use mio::windows::stdnet as net; use mio::{Interest, Token}; use std::io::{self, Read}; #[cfg(unix)] @@ -139,7 +139,7 @@ fn unix_listener_deregister() { #[cfg(target_os = "linux")] #[test] -fn unix_listener_abstract_namesapce() { +fn unix_listener_abstract_namespace() { use rand::Rng; let num: u64 = rand::thread_rng().gen(); let name = format!("\u{0000}-mio-abstract-uds-{}", num); diff --git a/tests/unix_pipe.rs b/tests/unix_pipe.rs index a83e3833b..f8e6464c9 100644 --- a/tests/unix_pipe.rs +++ b/tests/unix_pipe.rs @@ -49,7 +49,7 @@ fn smoke() { ); let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); } #[test] @@ -162,7 +162,7 @@ fn from_child_process_io() { let mut buf = [0; 20]; let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); drop(sender); diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 0103d0b1b..eff0f3628 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net"))] -#[cfg(windows)] -use mio::net; use mio::net::UnixStream; +#[cfg(windows)] +use mio::windows::stdnet as net; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; @@ -12,8 +12,6 @@ use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; -#[cfg(windows)] -use std::time::Duration; #[macro_use] mod util; @@ -83,7 +81,6 @@ fn unix_stream_connect() { handle.join().unwrap(); } -#[cfg(unix)] #[test] fn unix_stream_from_std() { smoke_test( @@ -253,13 +250,7 @@ fn unix_stream_shutdown_write() { ); let err = stream.write(DATA2).unwrap_err(); - #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - #[cfg(windows)] - { - use windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN; - assert_eq!(err.raw_os_error(), Some(WSAESHUTDOWN)); - } // Read should be ok let mut buf = [0; DEFAULT_BUF_SIZE]; @@ -465,6 +456,8 @@ where assert!(stream.take_error().unwrap().is_none()); + // To comply with draining behavior on windows we have to check assert_would_block() + // https://github.com/tokio-rs/mio/issues/1611 assert_would_block(stream.read(&mut buf)); let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; @@ -492,107 +485,70 @@ where handle.join().unwrap(); } -#[cfg(windows)] -fn new_listener( +fn new_echo_listener( connections: usize, test_name: &'static str, - handle_stream: F, -) -> (thread::JoinHandle<()>, net::SocketAddr) -where - F: Fn(net::UnixStream) + std::marker::Send + 'static, -{ +) -> (thread::JoinHandle<()>, net::SocketAddr) { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - // We use mio's non-blocking listener here for windows, since there is no listener in std - // yet. We must be sure to poll before listener I/O. - let mut listener = net::UnixListener::bind(path).unwrap(); - let (mut poll, mut events) = init_with_poll(); - poll.registry() - .register(&mut listener, TOKEN_1, Interest::READABLE) - .unwrap(); - + let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { - poll.poll(&mut events, Some(Duration::from_millis(500))) - .unwrap(); - let (stream, _) = listener.accept().unwrap(); - assert_would_block(listener.accept()); - handle_stream(stream); + let (mut stream, _) = listener.accept().unwrap(); + + // On Linux based system it will cause a connection reset + // error when the reading side of the peer connection is + // shutdown, we don't consider it an actual here. + let (mut read, mut written) = (0, 0); + let mut buf = [0; DEFAULT_BUF_SIZE]; + loop { + let n = match stream.read(&mut buf) { + Ok(amount) => { + read += amount; + amount + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => panic!("{}", err), + }; + if n == 0 { + break; + } + match stream.write(&buf[..n]) { + Ok(amount) => written += amount, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => panic!("{}", err), + }; + } + assert_eq!(read, written, "unequal reads and writes"); } }); (handle, addr_receiver.recv().unwrap()) } -#[cfg(unix)] -fn new_listener( +fn new_noop_listener( connections: usize, + barrier: Arc, test_name: &'static str, - handle_stream: F, -) -> (thread::JoinHandle<()>, net::SocketAddr) -where - F: Fn(net::UnixStream) + std::marker::Send + 'static, -{ - let (addr_sender, addr_receiver) = channel(); +) -> (thread::JoinHandle<()>, net::SocketAddr) { + let (sender, receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); - addr_sender.send(local_addr).unwrap(); + sender.send(local_addr).unwrap(); for _ in 0..connections { let (stream, _) = listener.accept().unwrap(); - handle_stream(stream); + barrier.wait(); + stream.shutdown(Shutdown::Write).unwrap(); + barrier.wait(); + drop(stream); } }); - (handle, addr_receiver.recv().unwrap()) -} - -fn new_echo_listener( - connections: usize, - test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener(connections, test_name, |mut stream| { - // On Linux based system it will cause a connection reset - // error when the reading side of the peer connection is - // shutdown, we don't consider it an actual here. - let (mut read, mut written) = (0, 0); - let mut buf = [0; DEFAULT_BUF_SIZE]; - loop { - let n = match stream.read(&mut buf) { - Ok(amount) => { - read += amount; - amount - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => panic!("{}", err), - }; - if n == 0 { - break; - } - match stream.write(&buf[..n]) { - Ok(amount) => written += amount, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => panic!("{}", err), - }; - } - assert_eq!(read, written, "unequal reads and writes"); - }) -} - -fn new_noop_listener( - connections: usize, - barrier: Arc, - test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener(connections, test_name, move |stream| { - barrier.wait(); - stream.shutdown(Shutdown::Write).unwrap(); - barrier.wait(); - drop(stream); - }) + (handle, receiver.recv().unwrap()) } diff --git a/tests/util/mod.rs b/tests/util/mod.rs index 7a192d9b0..7fcb9fe5e 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -285,8 +285,8 @@ pub fn set_linger_zero(socket: &TcpStream) { let res = unsafe { setsockopt( socket.as_raw_socket() as _, - SOL_SOCKET as i32, - SO_LINGER as i32, + SOL_SOCKET, + SO_LINGER, &mut val as *mut _ as *mut _, size_of::() as _, )