From 5d6132411a0f017ec6079850587bc4c965fbd611 Mon Sep 17 00:00:00 2001 From: John Starks Date: Tue, 16 Aug 2022 11:37:05 -0700 Subject: [PATCH] Windows: Add support for Unix sockets Newer versions of Windows support AF_UNIX stream sockets. This change adds Windows support for the `SockAddr::unix` function and the `Domain::UNIX` constant. Since Unix sockets are now available on all tier1 platforms, this also removes `all` feature requirement from the `SockAddr::unix` function. --- CONTRIBUTING.md | 2 +- src/lib.rs | 3 ++ src/sockaddr.rs | 11 ++++ src/sys/unix.rs | 123 ++++++++++++++++++--------------------------- src/sys/windows.rs | 48 +++++++++++++++++- tests/socket.rs | 26 ++++++++-- 6 files changed, 132 insertions(+), 81 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b865720a..a6e40972 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ add new tests. All types and methods that are available on all tier 1 platforms are defined in the first level of the source, i.e. `src/*.rs` files. Additional API that is -platform specific, e.g. `Domain::UNIX`, is defined in `src/sys/*.rs` and only +platform specific, e.g. `Domain::VSOCK`, is defined in `src/sys/*.rs` and only for the platforms that support it. For API that is not available on all tier 1 platforms the `all` feature is used, to indicate to the user that they're using API that might is not available on all platforms. diff --git a/src/lib.rs b/src/lib.rs index 74ffa433..f1e94c4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -210,6 +210,9 @@ impl Domain { /// Domain for IPv6 communication, corresponding to `AF_INET6`. pub const IPV6: Domain = Domain(sys::AF_INET6); + /// Domain for Unix socket communication, corresponding to `AF_UNIX`. + pub const UNIX: Domain = Domain(sys::AF_UNIX); + /// Returns the correct domain for `address`. pub const fn for_address(address: SocketAddr) -> Domain { match address { diff --git a/src/sockaddr.rs b/src/sockaddr.rs index 07caad95..5a25adcc 100644 --- a/src/sockaddr.rs +++ b/src/sockaddr.rs @@ -1,5 +1,6 @@ use std::mem::{self, size_of, MaybeUninit}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::path::Path; use std::{fmt, io}; #[cfg(windows)] @@ -210,6 +211,16 @@ impl SockAddr { _ => None, } } + + /// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path. + /// + /// Returns an error if the path is longer than `SUN_LEN`. + pub fn unix

(path: P) -> io::Result + where + P: AsRef, + { + crate::sys::unix_sockaddr(path.as_ref()) + } } impl From for SockAddr { diff --git a/src/sys/unix.rs b/src/sys/unix.rs index d2c93104..8d83b1fc 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -25,7 +25,6 @@ use std::num::NonZeroU32; ) ))] use std::num::NonZeroUsize; -#[cfg(feature = "all")] use std::os::unix::ffi::OsStrExt; #[cfg(all( feature = "all", @@ -40,9 +39,7 @@ use std::os::unix::io::RawFd; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd}; #[cfg(feature = "all")] use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; -#[cfg(feature = "all")] use std::path::Path; -#[cfg(not(all(target_os = "redox", not(feature = "all"))))] use std::ptr; use std::time::{Duration, Instant}; use std::{io, slice}; @@ -58,7 +55,7 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; pub(crate) use libc::c_int; // Used in `Domain`. -pub(crate) use libc::{AF_INET, AF_INET6}; +pub(crate) use libc::{AF_INET, AF_INET6, AF_UNIX}; // Used in `Type`. #[cfg(all(feature = "all", not(target_os = "redox")))] pub(crate) use libc::SOCK_RAW; @@ -222,10 +219,6 @@ type IovLen = c_int; /// Unix only API. impl Domain { - /// Domain for Unix socket communication, corresponding to `AF_UNIX`. - #[cfg_attr(docsrs, doc(cfg(unix)))] - pub const UNIX: Domain = Domain(libc::AF_UNIX); - /// Domain for low-level packet interface, corresponding to `AF_PACKET`. #[cfg(all( feature = "all", @@ -460,71 +453,56 @@ impl<'a> MaybeUninitSlice<'a> { } } -/// Unix only API. -impl SockAddr { - /// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path. - /// - /// # Failure - /// - /// Returns an error if the path is longer than `SUN_LEN`. - #[cfg(feature = "all")] - #[cfg_attr(docsrs, doc(cfg(all(unix, feature = "all"))))] - #[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable. - pub fn unix

(path: P) -> io::Result - where - P: AsRef, - { +#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable. +pub(crate) fn unix_sockaddr(path: &Path) -> io::Result { + // SAFETY: a `sockaddr_storage` of all zeros is valid. + let mut storage = unsafe { mem::zeroed::() }; + let len = { + let storage: &mut libc::sockaddr_un = + unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() }; + + let bytes = path.as_os_str().as_bytes(); + let too_long = match bytes.first() { + None => false, + // linux abstract namespaces aren't null-terminated + Some(&0) => bytes.len() > storage.sun_path.len(), + Some(_) => bytes.len() >= storage.sun_path.len(), + }; + if too_long { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + + storage.sun_family = libc::AF_UNIX as sa_family_t; + // Safety: `bytes` and `addr.sun_path` are not overlapping and + // both point to valid memory. + // `storage` was initialized to zero above, so the path is + // already null terminated. unsafe { - SockAddr::try_init(|storage, len| { - // Safety: `SockAddr::try_init` zeros the address, which is a - // valid representation. - let storage: &mut libc::sockaddr_un = unsafe { &mut *storage.cast() }; - let len: &mut socklen_t = unsafe { &mut *len }; - - let bytes = path.as_ref().as_os_str().as_bytes(); - let too_long = match bytes.first() { - None => false, - // linux abstract namespaces aren't null-terminated - Some(&0) => bytes.len() > storage.sun_path.len(), - Some(_) => bytes.len() >= storage.sun_path.len(), - }; - if too_long { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "path must be shorter than SUN_LEN", - )); - } + ptr::copy_nonoverlapping( + bytes.as_ptr(), + storage.sun_path.as_mut_ptr() as *mut u8, + bytes.len(), + ) + }; - storage.sun_family = libc::AF_UNIX as sa_family_t; - // Safety: `bytes` and `addr.sun_path` are not overlapping and - // both point to valid memory. - // `SockAddr::try_init` zeroes the memory, so the path is - // already null terminated. - unsafe { - ptr::copy_nonoverlapping( - bytes.as_ptr(), - storage.sun_path.as_mut_ptr() as *mut u8, - bytes.len(), - ) - }; - - let base = storage as *const _ as usize; - let path = &storage.sun_path as *const _ as usize; - let sun_path_offset = path - base; - let length = sun_path_offset - + bytes.len() - + match bytes.first() { - Some(&0) | None => 0, - Some(_) => 1, - }; - *len = length as socklen_t; - - Ok(()) - }) - } - .map(|(_, addr)| addr) - } + let base = storage as *const _ as usize; + let path = &storage.sun_path as *const _ as usize; + let sun_path_offset = path - base; + sun_path_offset + + bytes.len() + + match bytes.first() { + Some(&0) | None => 0, + Some(_) => 1, + } + }; + Ok(unsafe { SockAddr::new(storage, len as socklen_t) }) +} +/// Unix only API. +impl SockAddr { /// Constructs a `SockAddr` with the family `AF_VSOCK` and the provided CID/port. /// /// # Errors @@ -538,9 +516,8 @@ impl SockAddr { doc(cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))) )] pub fn vsock(cid: u32, port: u32) -> SockAddr { - // SAFETY: a `sockaddr_storage` of all zeros is valid, hence we can - // safely assume it's initialised. - let mut storage = unsafe { MaybeUninit::::zeroed().assume_init() }; + // SAFETY: a `sockaddr_storage` of all zeros is valid. + let mut storage = unsafe { mem::zeroed::() }; { let storage: &mut libc::sockaddr_vm = unsafe { &mut *((&mut storage as *mut sockaddr_storage).cast()) }; diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 983ee9d1..a24720c6 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -14,6 +14,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown}; use std::os::windows::io::{ AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket, }; +use std::path::Path; use std::sync::Once; use std::time::{Duration, Instant}; use std::{process, ptr, slice}; @@ -41,8 +42,8 @@ pub(crate) const MSG_TRUNC: c_int = 0x01; // Used in `Domain`. pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int; pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int; -const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int; -const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int; +pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int; +pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int; // Used in `Type`. pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int; pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int; @@ -774,6 +775,49 @@ pub(crate) fn to_mreqn( } } +#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable. +pub(crate) fn unix_sockaddr(path: &Path) -> io::Result { + // SAFETY: a `sockaddr_storage` of all zeros is valid. + let mut storage = unsafe { mem::zeroed::() }; + let len = { + let storage: &mut windows_sys::Win32::Networking::WinSock::sockaddr_un = + unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() }; + + // Windows expects a UTF-8 path here even though Windows paths are + // usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded + // buffer, this could be used directly, relying on Windows to + // validate the path, but Rust hides this implementation detail. + // + // See . + let bytes = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))? + .as_bytes(); + + // Windows appears to allow non-null-terminated paths, but this is + // not documented, so do not rely on it yet. + // + // See . + if bytes.len() >= storage.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + + storage.sun_family = crate::sys::AF_UNIX as sa_family_t; + // `storage` was initialized to zero above, so the path is + // already null terminated. + storage.sun_path[..bytes.len()].copy_from_slice(bytes); + + let base = storage as *const _ as usize; + let path = &storage.sun_path as *const _ as usize; + let sun_path_offset = path - base; + sun_path_offset + bytes.len() + 1 + }; + Ok(unsafe { SockAddr::new(storage, len as socklen_t) }) +} + /// Windows only API. impl crate::Socket { /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`. diff --git a/tests/socket.rs b/tests/socket.rs index f36720fa..ecbe9e58 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -12,7 +12,6 @@ use std::fs::File; use std::io; #[cfg(not(target_os = "redox"))] use std::io::IoSlice; -#[cfg(all(unix, feature = "all"))] use std::io::Read; use std::io::Write; use std::mem::{self, MaybeUninit}; @@ -36,7 +35,6 @@ use std::os::windows::io::AsRawSocket; use std::str; use std::thread; use std::time::Duration; -#[cfg(all(unix, feature = "all"))] use std::{env, fs}; #[cfg(windows)] @@ -62,7 +60,6 @@ fn domain_fmt_debug() { let tests = &[ (Domain::IPV4, "AF_INET"), (Domain::IPV6, "AF_INET6"), - #[cfg(unix)] (Domain::UNIX, "AF_UNIX"), #[cfg(all(feature = "all", any(target_os = "fuchsia", target_os = "linux")))] (Domain::PACKET, "AF_PACKET"), @@ -130,7 +127,6 @@ fn from_invalid_raw_fd_should_panic() { } #[test] -#[cfg(all(unix, feature = "all"))] fn socket_address_unix() { let string = "/tmp/socket"; let addr = SockAddr::unix(string).unwrap(); @@ -429,9 +425,29 @@ fn pair() { assert_eq!(&buf[..n], DATA); } +fn unix_sockets_supported() -> bool { + #[cfg(windows)] + { + // Only some versions of Windows support Unix sockets. + match Socket::new(Domain::UNIX, Type::STREAM, None) { + Ok(_) => {} + Err(err) + if err.raw_os_error() + == Some(windows_sys::Win32::Networking::WinSock::WSAEAFNOSUPPORT as i32) => + { + return false; + } + Err(err) => panic!("socket error: {}", err), + } + } + true +} + #[test] -#[cfg(all(feature = "all", unix))] fn unix() { + if !unix_sockets_supported() { + return; + } let mut path = env::temp_dir(); path.push("socket2"); let _ = fs::remove_dir_all(&path);