Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows: Add support for Unix sockets #249

Merged
merged 1 commit into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ add new tests.

All types and methods that are available on all tier 1 platforms are defined in
the first level of the source, i.e. `src/*.rs` files. Additional API that is
platform specific, e.g. `Domain::UNIX`, is defined in `src/sys/*.rs` and only
platform specific, e.g. `Domain::VSOCK`, is defined in `src/sys/*.rs` and only
for the platforms that support it. For API that is not available on all tier 1
platforms the `all` feature is used, to indicate to the user that they're using
API that might is not available on all platforms.
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ impl Domain {
/// Domain for IPv6 communication, corresponding to `AF_INET6`.
pub const IPV6: Domain = Domain(sys::AF_INET6);

/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
pub const UNIX: Domain = Domain(sys::AF_UNIX);

/// Returns the correct domain for `address`.
pub const fn for_address(address: SocketAddr) -> Domain {
match address {
Expand Down
11 changes: 11 additions & 0 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::mem::{self, size_of, MaybeUninit};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::path::Path;
use std::{fmt, io};

#[cfg(windows)]
Expand Down Expand Up @@ -210,6 +211,16 @@ impl SockAddr {
_ => None,
}
}

/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// Returns an error if the path is longer than `SUN_LEN`.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
crate::sys::unix_sockaddr(path.as_ref())
}
}

impl From<SocketAddr> for SockAddr {
Expand Down
123 changes: 50 additions & 73 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use std::num::NonZeroU32;
)
))]
use std::num::NonZeroUsize;
#[cfg(feature = "all")]
use std::os::unix::ffi::OsStrExt;
#[cfg(all(
feature = "all",
Expand All @@ -40,9 +39,7 @@ use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd};
#[cfg(feature = "all")]
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
#[cfg(feature = "all")]
use std::path::Path;
#[cfg(not(all(target_os = "redox", not(feature = "all"))))]
use std::ptr;
use std::time::{Duration, Instant};
use std::{io, slice};
Expand All @@ -58,7 +55,7 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
pub(crate) use libc::c_int;

// Used in `Domain`.
pub(crate) use libc::{AF_INET, AF_INET6};
pub(crate) use libc::{AF_INET, AF_INET6, AF_UNIX};
// Used in `Type`.
#[cfg(all(feature = "all", not(target_os = "redox")))]
pub(crate) use libc::SOCK_RAW;
Expand Down Expand Up @@ -222,10 +219,6 @@ type IovLen = c_int;

/// Unix only API.
impl Domain {
/// Domain for Unix socket communication, corresponding to `AF_UNIX`.
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub const UNIX: Domain = Domain(libc::AF_UNIX);

/// Domain for low-level packet interface, corresponding to `AF_PACKET`.
#[cfg(all(
feature = "all",
Expand Down Expand Up @@ -460,71 +453,56 @@ impl<'a> MaybeUninitSlice<'a> {
}
}

/// Unix only API.
impl SockAddr {
/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
///
/// # Failure
///
/// Returns an error if the path is longer than `SUN_LEN`.
#[cfg(feature = "all")]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "all"))))]
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage: &mut libc::sockaddr_un =
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

let bytes = path.as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `storage` was initialized to zero above, so the path is
// already null terminated.
unsafe {
SockAddr::try_init(|storage, len| {
// Safety: `SockAddr::try_init` zeros the address, which is a
// valid representation.
let storage: &mut libc::sockaddr_un = unsafe { &mut *storage.cast() };
let len: &mut socklen_t = unsafe { &mut *len };

let bytes = path.as_ref().as_os_str().as_bytes();
let too_long = match bytes.first() {
None => false,
// linux abstract namespaces aren't null-terminated
Some(&0) => bytes.len() > storage.sun_path.len(),
Some(_) => bytes.len() >= storage.sun_path.len(),
};
if too_long {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
};

storage.sun_family = libc::AF_UNIX as sa_family_t;
// Safety: `bytes` and `addr.sun_path` are not overlapping and
// both point to valid memory.
// `SockAddr::try_init` zeroes the memory, so the path is
// already null terminated.
unsafe {
ptr::copy_nonoverlapping(
bytes.as_ptr(),
storage.sun_path.as_mut_ptr() as *mut u8,
bytes.len(),
)
};

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
let length = sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
};
*len = length as socklen_t;

Ok(())
})
}
.map(|(_, addr)| addr)
}
let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
sun_path_offset
+ bytes.len()
+ match bytes.first() {
Some(&0) | None => 0,
Some(_) => 1,
}
};
Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}

/// Unix only API.
impl SockAddr {
/// Constructs a `SockAddr` with the family `AF_VSOCK` and the provided CID/port.
///
/// # Errors
Expand All @@ -538,9 +516,8 @@ impl SockAddr {
doc(cfg(all(feature = "all", any(target_os = "android", target_os = "linux"))))
)]
pub fn vsock(cid: u32, port: u32) -> SockAddr {
// SAFETY: a `sockaddr_storage` of all zeros is valid, hence we can
// safely assume it's initialised.
let mut storage = unsafe { MaybeUninit::<sockaddr_storage>::zeroed().assume_init() };
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
{
let storage: &mut libc::sockaddr_vm =
unsafe { &mut *((&mut storage as *mut sockaddr_storage).cast()) };
Expand Down
48 changes: 46 additions & 2 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
use std::os::windows::io::{
AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
};
use std::path::Path;
use std::sync::Once;
use std::time::{Duration, Instant};
use std::{process, ptr, slice};
Expand Down Expand Up @@ -41,8 +42,8 @@ pub(crate) const MSG_TRUNC: c_int = 0x01;
// Used in `Domain`.
pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
// Used in `Type`.
pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
Expand Down Expand Up @@ -774,6 +775,49 @@ pub(crate) fn to_mreqn(
}
}

#[allow(unused_unsafe)] // TODO: replace with `unsafe_op_in_unsafe_fn` once stable.
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage: &mut windows_sys::Win32::Networking::WinSock::sockaddr_un =
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

// Windows expects a UTF-8 path here even though Windows paths are
// usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
// buffer, this could be used directly, relying on Windows to
// validate the path, but Rust hides this implementation detail.
//
// See <https://github.com/rust-lang/rust/pull/95290>.
let bytes = path
.to_str()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
.as_bytes();

// Windows appears to allow non-null-terminated paths, but this is
// not documented, so do not rely on it yet.
//
// See <https://github.com/rust-lang/socket2/issues/331>.
if bytes.len() >= storage.sun_path.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"path must be shorter than SUN_LEN",
));
}

storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
// `storage` was initialized to zero above, so the path is
// already null terminated.
storage.sun_path[..bytes.len()].copy_from_slice(bytes);

let base = storage as *const _ as usize;
let path = &storage.sun_path as *const _ as usize;
let sun_path_offset = path - base;
sun_path_offset + bytes.len() + 1
};
Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}

/// Windows only API.
impl crate::Socket {
/// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
Expand Down
26 changes: 21 additions & 5 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use std::fs::File;
use std::io;
#[cfg(not(target_os = "redox"))]
use std::io::IoSlice;
#[cfg(all(unix, feature = "all"))]
use std::io::Read;
use std::io::Write;
use std::mem::{self, MaybeUninit};
Expand All @@ -36,7 +35,6 @@ use std::os::windows::io::AsRawSocket;
use std::str;
use std::thread;
use std::time::Duration;
#[cfg(all(unix, feature = "all"))]
use std::{env, fs};

#[cfg(windows)]
Expand All @@ -62,7 +60,6 @@ fn domain_fmt_debug() {
let tests = &[
(Domain::IPV4, "AF_INET"),
(Domain::IPV6, "AF_INET6"),
#[cfg(unix)]
(Domain::UNIX, "AF_UNIX"),
#[cfg(all(feature = "all", any(target_os = "fuchsia", target_os = "linux")))]
(Domain::PACKET, "AF_PACKET"),
Expand Down Expand Up @@ -130,7 +127,6 @@ fn from_invalid_raw_fd_should_panic() {
}

#[test]
#[cfg(all(unix, feature = "all"))]
fn socket_address_unix() {
let string = "/tmp/socket";
let addr = SockAddr::unix(string).unwrap();
Expand Down Expand Up @@ -429,9 +425,29 @@ fn pair() {
assert_eq!(&buf[..n], DATA);
}

fn unix_sockets_supported() -> bool {
#[cfg(windows)]
{
// Only some versions of Windows support Unix sockets.
match Socket::new(Domain::UNIX, Type::STREAM, None) {
Ok(_) => {}
Err(err)
if err.raw_os_error()
== Some(windows_sys::Win32::Networking::WinSock::WSAEAFNOSUPPORT as i32) =>
{
return false;
}
Err(err) => panic!("socket error: {}", err),
}
}
true
}

#[test]
#[cfg(all(feature = "all", unix))]
fn unix() {
if !unix_sockets_supported() {
return;
}
let mut path = env::temp_dir();
path.push("socket2");
let _ = fs::remove_dir_all(&path);
Expand Down