Skip to content

Commit

Permalink
Windows: Add support for Unix sockets
Browse files Browse the repository at this point in the history
Newer versions of Windows support AF_UNIX stream sockets. This change
adds Windows support for the `SockAddr::unix` function and the
`Domain::UNIX` constant.

Since Unix sockets are now available on all tier1 platforms, this also
removes `all` feature requirement from the `SockAddr::unix` function.
  • Loading branch information
jstarks authored and Thomasdezeeuw committed Aug 18, 2022
1 parent eaa6300 commit 5d61324
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 81 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ add new tests.

All types and methods that are available on all tier 1 platforms are defined in
the first level of the source, i.e. `src/*.rs` files. Additional API that is
platform specific, e.g. `Domain::UNIX`, is defined in `src/sys/*.rs` and only
platform specific, e.g. `Domain::VSOCK`, is defined in `src/sys/*.rs` and only
for the platforms that support it. For API that is not available on all tier 1
platforms the `all` feature is used, to indicate to the user that they're using
API that might is not available on all platforms.
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ impl Domain {
/// Domain for IPv6 communication, corresponding to `AF_INET6`.
pub const IPV6: Domain = Domain(sys::AF_INET6);

/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
pub const UNIX: Domain = Domain(sys::AF_UNIX);

/// Returns the correct domain for `address`.
pub const fn for_address(address: SocketAddr) -> Domain {
match address {
Expand Down
11 changes: 11 additions & 0 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::mem::{self, size_of, MaybeUninit};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::path::Path;
use std::{fmt, io};

#[cfg(windows)]
Expand Down Expand Up @@ -210,6 +211,16 @@ impl SockAddr {
_ => None,
}
}

/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// Returns an error if the path is longer than `SUN_LEN`.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
crate::sys::unix_sockaddr(path.as_ref())
}
}

impl From<SocketAddr> for SockAddr {
Expand Down
123 changes: 50 additions & 73 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use std::num::NonZeroU32;
)
))]
use std::num::NonZeroUsize;
#[cfg(feature = "all")]
use std::os::unix::ffi::OsStrExt;
#[cfg(all(
feature = "all",
Expand All @@ -40,9 +39,7 @@ use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd};
#[cfg(feature = "all")]
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
#[cfg(feature = "all")]
use std::path::Path;
#[cfg(not(all(target_os = "redox", not(feature = "all"))))]
use std::ptr;
use std::time::{Duration, Instant};
use std::{io, slice};
Expand All @@ -58,7 +55,7 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
pub(crate) use libc::c_int;

// Used in `Domain`.
pub(crate) use libc::{AF_INET, AF_INET6};
pub(crate) use libc::{AF_INET, AF_INET6, AF_UNIX};
// Used in `Type`.
#[cfg(all(feature = "all", not(target_os = "redox")))]
pub(crate) use libc::SOCK_RAW;
Expand Down Expand Up @@ -222,10 +219,6 @@ type IovLen = c_int;

/// Unix only API.
impl Domain {
/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub const UNIX: Domain = Domain(libc::AF_UNIX);

/// Domain for low-level packet interface, corresponding to `AF_PACKET`.
#[cfg(all(
feature = "all",
Expand Down Expand Up @@ -460,71 +453,56 @@ impl<'a> MaybeUninitSlice<'a> {
}
}

/// Unix only API.
impl SockAddr {
/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// # Failure
///
/// Returns an error if the path is longer than `SUN_LEN`.
#[cfg(feature = "all")]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "all"))))]
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage: &mut libc::sockaddr_un =
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

let bytes = path.as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `storage` was initialized to zero above, so the path is
// already null terminated.
unsafe {
SockAddr::try_init(|storage, len| {
// Safety: `SockAddr::try_init` zeros the address, which is a
// valid representation.
let storage: &mut libc::sockaddr_un = unsafe { &mut *storage.cast() };
let len: &mut socklen_t = unsafe { &mut *len };

let bytes = path.as_ref().as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
};

storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `SockAddr::try_init` zeroes the memory, so the path is
// already null terminated.
unsafe {
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
};

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
let length = sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
};
*len = length as socklen_t;

Ok(())
})
}
.map(|(_, addr)| addr)
}
let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
}
};
Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}

/// Unix only API.
impl SockAddr {
/// Constructs a `SockAddr` with the family `AF_VSOCK` and the provided CID/port.
///
/// # Errors
Expand All @@ -538,9 +516,8 @@ impl SockAddr {
doc(cfg(all(feature = "all", any(target_os = "android", target_os = "linux"))))
)]
pub fn vsock(cid: u32, port: u32) -> SockAddr {
// SAFETY: a `sockaddr_storage` of all zeros is valid, hence we can
// safely assume it's initialised.
let mut storage = unsafe { MaybeUninit::<sockaddr_storage>::zeroed().assume_init() };
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
{
let storage: &mut libc::sockaddr_vm =
unsafe { &mut *((&mut storage as *mut sockaddr_storage).cast()) };
Expand Down
48 changes: 46 additions & 2 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
use std::os::windows::io::{
AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
};
use std::path::Path;
use std::sync::Once;
use std::time::{Duration, Instant};
use std::{process, ptr, slice};
Expand Down Expand Up @@ -41,8 +42,8 @@ pub(crate) const MSG_TRUNC: c_int = 0x01;
// Used in `Domain`.
pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
// Used in `Type`.
pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
Expand Down Expand Up @@ -774,6 +775,49 @@ pub(crate) fn to_mreqn(
}
}

#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage: &mut windows_sys::Win32::Networking::WinSock::sockaddr_un =
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

// Windows expects a UTF-8 path here even though Windows paths are
// usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
// buffer, this could be used directly, relying on Windows to
// validate the path, but Rust hides this implementation detail.
//
// See <https://github.com/rust-lang/rust/pull/95290>.
let bytes = path
.to_str()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
.as_bytes();

// Windows appears to allow non-null-terminated paths, but this is
// not documented, so do not rely on it yet.
//
// See <https://github.com/rust-lang/socket2/issues/331>.
if bytes.len() >= storage.sun_path.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
// `storage` was initialized to zero above, so the path is
// already null terminated.
storage.sun_path[..bytes.len()].copy_from_slice(bytes);

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
sun_path_offset + bytes.len() + 1
};
Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}

/// Windows only API.
impl crate::Socket {
/// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
Expand Down
26 changes: 21 additions & 5 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use std::fs::File;
use std::io;
#[cfg(not(target_os = "redox"))]
use std::io::IoSlice;
#[cfg(all(unix, feature = "all"))]
use std::io::Read;
use std::io::Write;
use std::mem::{self, MaybeUninit};
Expand All @@ -36,7 +35,6 @@ use std::os::windows::io::AsRawSocket;
use std::str;
use std::thread;
use std::time::Duration;
#[cfg(all(unix, feature = "all"))]
use std::{env, fs};

#[cfg(windows)]
Expand All @@ -62,7 +60,6 @@ fn domain_fmt_debug() {
let tests = &[
(Domain::IPV4, "AF_INET"),
(Domain::IPV6, "AF_INET6"),
#[cfg(unix)]
(Domain::UNIX, "AF_UNIX"),
#[cfg(all(feature = "all", any(target_os = "fuchsia", target_os = "linux")))]
(Domain::PACKET, "AF_PACKET"),
Expand Down Expand Up @@ -130,7 +127,6 @@ fn from_invalid_raw_fd_should_panic() {
}

#[test]
#[cfg(all(unix, feature = "all"))]
fn socket_address_unix() {
let string = "/tmp/socket";
let addr = SockAddr::unix(string).unwrap();
Expand Down Expand Up @@ -429,9 +425,29 @@ fn pair() {
assert_eq!(&buf[..n], DATA);
}

fn unix_sockets_supported() -> bool {
#[cfg(windows)]
{
// Only some versions of Windows support Unix sockets.
match Socket::new(Domain::UNIX, Type::STREAM, None) {
Ok(_) => {}
Err(err)
if err.raw_os_error()
== Some(windows_sys::Win32::Networking::WinSock::WSAEAFNOSUPPORT as i32) =>
{
return false;
}
Err(err) => panic!("socket error: {}", err),
}
}
true
}

#[test]
#[cfg(all(feature = "all", unix))]
fn unix() {
if !unix_sockets_supported() {
return;
}
let mut path = env::temp_dir();
path.push("socket2");
let _ = fs::remove_dir_all(&path);
Expand Down

0 comments on commit 5d61324

Please sign in to comment.