Skip to content

Commit

Permalink
Windows: Add support for Unix sockets
Browse files Browse the repository at this point in the history
Newer versions of Windows support AF_UNIX stream sockets. This change
adds Windows support for the `SockAddr::unix` function and the
`Domain::UNIX` constant.

Since Unix sockets are now available on all tier1 platforms, this also
removes `all` feature requirement from the `SockAddr::unix` function.
  • Loading branch information
jstarks committed Jul 22, 2021
1 parent 1c67209 commit 20b6d40
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 74 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ add new tests.

All types and methods that are available on all tier 1 platforms are defined in
the first level of the source, i.e. `src/*.rs` files. Additional API that is
platform specific, e.g. `Domain::UNIX`, is defined in `src/sys/*.rs` and only
platform specific, e.g. `Domain::VSOCK`, is defined in `src/sys/*.rs` and only
for the platforms that support it. For API that is not available on all tier 1
platforms the `all` feature is used, to indicate to the user that they're using
API that might is not available on all platforms.
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ impl Domain {
/// Domain for IPv6 communication, corresponding to `AF_INET6`.
pub const IPV6: Domain = Domain(sys::AF_INET6);

/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
pub const UNIX: Domain = Domain(sys::AF_UNIX);

/// Returns the correct domain for `address`.
pub const fn for_address(address: SocketAddr) -> Domain {
match address {
Expand Down
13 changes: 13 additions & 0 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::mem::{self, size_of, MaybeUninit};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::path::Path;
use std::{fmt, io};

use crate::sys::{
Expand Down Expand Up @@ -208,6 +209,18 @@ impl SockAddr {
_ => None,
}
}

/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// # Failure
///
/// Returns an error if the path is longer than `SUN_LEN`.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
crate::sys::unix_sockaddr(path.as_ref())
}
}

impl From<SocketAddr> for SockAddr {
Expand Down
115 changes: 48 additions & 67 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use std::num::NonZeroU32;
)
))]
use std::num::NonZeroUsize;
#[cfg(feature = "all")]
use std::os::unix::ffi::OsStrExt;
#[cfg(all(
feature = "all",
Expand All @@ -40,7 +39,6 @@ use std::os::unix::io::RawFd;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
#[cfg(feature = "all")]
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
#[cfg(feature = "all")]
use std::path::Path;
#[cfg(not(all(target_os = "redox", not(feature = "all"))))]
use std::ptr;
Expand All @@ -58,7 +56,7 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
pub(crate) use libc::c_int;

// Used in `Domain`.
pub(crate) use libc::{AF_INET, AF_INET6};
pub(crate) use libc::{AF_INET, AF_INET6, AF_UNIX};
// Used in `Type`.
#[cfg(all(feature = "all", not(target_os = "redox")))]
pub(crate) use libc::SOCK_RAW;
Expand Down Expand Up @@ -204,10 +202,6 @@ type IovLen = c_int;

/// Unix only API.
impl Domain {
/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub const UNIX: Domain = Domain(libc::AF_UNIX);

/// Domain for low-level packet interface, corresponding to `AF_PACKET`.
#[cfg(all(
feature = "all",
Expand Down Expand Up @@ -436,70 +430,57 @@ impl<'a> MaybeUninitSlice<'a> {
}
}

/// Unix only API.
impl SockAddr {
/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// # Failure
///
/// Returns an error if the path is longer than `SUN_LEN`.
#[cfg(feature = "all")]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "all"))))]
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
unsafe {
SockAddr::init(|storage, len| {
// Safety: `SockAddr::init` zeros the address, which is a valid
// representation.
let storage: &mut libc::sockaddr_un = unsafe { &mut *storage.cast() };
let len: &mut socklen_t = unsafe { &mut *len };

let bytes = path.as_ref().as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
unsafe {
SockAddr::init(|storage, len| {
// Safety: `SockAddr::init` zeros the address, which is a valid
// representation.
let storage: &mut libc::sockaddr_un = unsafe { &mut *storage.cast() };
let len: &mut socklen_t = unsafe { &mut *len };

let bytes = path.as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `SockAddr::init` zeroes the memory, so the path is already
// null terminated.
unsafe {
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `SockAddr::init` zeroes the memory, so the path is already
// null terminated.
unsafe {
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
};

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
let length = sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
};
*len = length as socklen_t;

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
let length = sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
};
*len = length as socklen_t;

Ok(())
})
}
.map(|(_, addr)| addr)
Ok(())
})
}
.map(|(_, addr)| addr)
}

impl SockAddr {
Expand Down
49 changes: 48 additions & 1 deletion src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::marker::PhantomData;
use std::mem::{self, size_of, MaybeUninit};
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
use std::os::windows::prelude::*;
use std::path::Path;
use std::sync::Once;
use std::time::{Duration, Instant};
use std::{ptr, slice};
Expand Down Expand Up @@ -44,7 +45,7 @@ pub(crate) use winapi::ctypes::c_int;
pub(crate) const MSG_TRUNC: c_int = 0x01;

// Used in `Domain`.
pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6};
pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6, AF_UNIX};
// Used in `Type`.
pub(crate) use winapi::shared::ws2def::{SOCK_DGRAM, SOCK_STREAM};
#[cfg(feature = "all")]
Expand Down Expand Up @@ -735,6 +736,52 @@ pub(crate) fn from_in6_addr(addr: in6_addr) -> Ipv6Addr {
Ipv6Addr::from(*unsafe { addr.u.Byte() })
}

/// This type is not yet in winapi.
#[repr(C)]
#[allow(non_camel_case_types)]
struct sockaddr_un {
pub sun_family: sa_family_t,
pub sun_path: [u8; 108],
}

#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
unsafe {
SockAddr::init(|storage, len| {
// Safety: `SockAddr::init` zeros the address, which is a valid
// representation.
let storage: &mut crate::sys::sockaddr_un = unsafe { &mut *storage.cast() };
let len: &mut socklen_t = unsafe { &mut *len };

let bytes = path
.to_str()
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8")
})?
.as_bytes();

if bytes.len() > storage.sun_path.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
storage.sun_path[..bytes.len()].copy_from_slice(bytes);

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
let length = sun_path_offset + bytes.len();
*len = length as socklen_t;

Ok(())
})
}
.map(|(_, addr)| addr)
}

/// Windows only API.
impl crate::Socket {
/// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
Expand Down
25 changes: 20 additions & 5 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::fs::File;
use std::io;
#[cfg(not(target_os = "redox"))]
use std::io::IoSlice;
#[cfg(all(unix, feature = "all"))]
use std::io::Read;
use std::io::Write;
use std::mem::{self, MaybeUninit};
Expand All @@ -35,7 +34,6 @@ use std::os::windows::io::AsRawSocket;
use std::str;
use std::thread;
use std::time::Duration;
#[cfg(all(unix, feature = "all"))]
use std::{env, fs};

#[cfg(windows)]
Expand Down Expand Up @@ -65,7 +63,6 @@ fn domain_fmt_debug() {
let tests = &[
(Domain::IPV4, "AF_INET"),
(Domain::IPV6, "AF_INET6"),
#[cfg(unix)]
(Domain::UNIX, "AF_UNIX"),
#[cfg(all(feature = "all", any(target_os = "fuchsia", target_os = "linux")))]
(Domain::PACKET, "AF_PACKET"),
Expand Down Expand Up @@ -133,7 +130,6 @@ fn from_invalid_raw_fd_should_panic() {
}

#[test]
#[cfg(all(unix, feature = "all"))]
fn socket_address_unix() {
let string = "/tmp/socket";
let addr = SockAddr::unix(string).unwrap();
Expand Down Expand Up @@ -432,9 +428,28 @@ fn pair() {
assert_eq!(&buf[..n], DATA);
}

fn unix_sockets_supported() -> bool {
#[cfg(windows)]
{
// Only some versions of Windows support Unix sockets.
match Socket::new(Domain::UNIX, Type::STREAM, None) {
Ok(_) => {}
Err(err)
if err.raw_os_error() == Some(winapi::um::winsock2::WSAEAFNOSUPPORT as i32) =>
{
return false;
}
Err(err) => panic!("socket error: {}", err),
}
}
true
}

#[test]
#[cfg(all(feature = "all", unix))]
fn unix() {
if !unix_sockets_supported() {
return;
}
let mut path = env::temp_dir();
path.push("socket2");
let _ = fs::remove_dir_all(&path);
Expand Down

0 comments on commit 20b6d40

Please sign in to comment.