Skip to content

Commit

Permalink
Use direct system calls in TcpListener::bind
Browse files Browse the repository at this point in the history
Safes on a systemcall on platforms that support SOCK_NONBLOCK and SOCK_CLOEXEC.
  • Loading branch information
Thomasdezeeuw committed Jul 21, 2019
1 parent cd2c104 commit 66f37de
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 125 deletions.
29 changes: 6 additions & 23 deletions src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
//!
//! [portability guidelines]: ../struct.Poll.html#portability

#[cfg(debug_assertions)]
use crate::poll::SelectorId;
use crate::{event, sys, Interests, Registry, Token};

use net2::TcpBuilder;
use std::fmt;
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{self, SocketAddr};
use std::time::Duration;

#[cfg(debug_assertions)]
use crate::poll::SelectorId;
use crate::{event, sys, Interests, Registry, Token};

/*
*
* ===== TcpStream =====
Expand Down Expand Up @@ -425,24 +424,8 @@ impl TcpListener {
/// combination with the `TcpListener::from_listener` method to transfer
/// ownership into mio.
pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
// Create the socket
let sock = match addr {
SocketAddr::V4(..) => TcpBuilder::new_v4(),
SocketAddr::V6(..) => TcpBuilder::new_v6(),
}?;

// Set SO_REUSEADDR, but only on Unix (mirrors what libstd does)
if cfg!(unix) {
sock.reuse_address(true)?;
}

// Bind the socket
sock.bind(addr)?;

// listen
let listener = sock.listen(1024)?;
Ok(TcpListener {
sys: sys::TcpListener::new(listener)?,
sys::TcpListener::bind(addr).map(|sys| TcpListener {
sys,
#[cfg(debug_assertions)]
selector_id: SelectorId::new(),
})
Expand Down
130 changes: 78 additions & 52 deletions src/sys/unix/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::mem::size_of_val;
use std::mem::{size_of, size_of_val};
use std::net::{self, SocketAddr};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::time::Duration;
Expand All @@ -21,45 +21,10 @@ pub struct TcpListener {
}

impl TcpStream {
pub fn connect(address: SocketAddr) -> io::Result<TcpStream> {
let domain = match address {
SocketAddr::V4(..) => libc::AF_INET,
SocketAddr::V6(..) => libc::AF_INET6,
};
#[cfg(any(
target_os = "ios", // Darwin doesn't have SOCK_NONBLOCK or SOCK_CLOEXEC.
target_os = "macos",
target_os = "solaris" // Not sure about Solaris, couldn't find anything online.
))]
let socket_type = libc::SOCK_STREAM;
#[cfg(any(
target_os = "android",
target_os = "bitrig",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
))]
let socket_type = libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC;

let socket = syscall!(socket(domain, socket_type, 0));

#[cfg(any(target_os = "ios", target_os = "macos", target_os = "solaris"))]
let socket = socket.and_then(|socket| {
// For platforms that don't support flags in socket, we need to
// set the flags ourselves.
syscall!(fcntl(
socket,
libc::F_SETFL,
libc::O_NONBLOCK | libc::O_CLOEXEC
))
.map(|_| socket)
});

socket
pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
new_socket(addr)
.and_then(|socket| {
let (raw_address, raw_address_length) = socket_address(&address);
let (raw_address, raw_address_length) = socket_address(&addr);
syscall!(connect(socket, raw_address, raw_address_length))
.or_else(|err| match err {
// Connect hasn't finished, but that is fine.
Expand Down Expand Up @@ -162,19 +127,6 @@ impl TcpStream {
}
}

fn socket_address(address: &SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
match address {
SocketAddr::V4(ref address) => (
address as *const _ as *const libc::sockaddr,
size_of_val(address) as libc::socklen_t,
),
SocketAddr::V6(ref address) => (
address as *const _ as *const libc::sockaddr,
size_of_val(address) as libc::socklen_t,
),
}
}

impl<'a> Read for &'a TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&self.inner).read(buf)
Expand Down Expand Up @@ -245,6 +197,27 @@ impl AsRawFd for TcpStream {
}

impl TcpListener {
pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
new_socket(addr).and_then(|socket| {
// Set SO_REUSEADDR (mirrors what libstd does).
syscall!(setsockopt(
socket,
libc::SOL_SOCKET,
libc::SO_REUSEADDR,
&1 as *const libc::c_int as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
))
.and_then(|_| {
let (raw_address, raw_address_length) = socket_address(&addr);
syscall!(bind(socket, raw_address, raw_address_length))
})
.and_then(|_| syscall!(listen(socket, 1024)))
.map(|_| TcpListener {
inner: unsafe { net::TcpListener::from_raw_fd(socket) },
})
})
}

pub fn new(inner: net::TcpListener) -> io::Result<TcpListener> {
set_nonblock(inner.as_raw_fd())?;
Ok(TcpListener { inner })
Expand Down Expand Up @@ -319,3 +292,56 @@ impl AsRawFd for TcpListener {
self.inner.as_raw_fd()
}
}

/// Create a new non-blocking socket.
fn new_socket(addr: SocketAddr) -> io::Result<libc::c_int> {
let domain = match addr {
SocketAddr::V4(..) => libc::AF_INET,
SocketAddr::V6(..) => libc::AF_INET6,
};
#[cfg(any(
target_os = "ios", // Darwin doesn't have SOCK_NONBLOCK or SOCK_CLOEXEC.
target_os = "macos",
target_os = "solaris" // Not sure about Solaris, couldn't find anything online.
))]
let socket_type = libc::SOCK_STREAM;
#[cfg(any(
target_os = "android",
target_os = "bitrig",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "linux",
target_os = "netbsd",
target_os = "openbsd"
))]
let socket_type = libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC;

let socket = syscall!(socket(domain, socket_type, 0));

#[cfg(any(target_os = "ios", target_os = "macos", target_os = "solaris"))]
let socket = socket.and_then(|socket| {
// For platforms that don't support flags in socket, we need to
// set the flags ourselves.
syscall!(fcntl(
socket,
libc::F_SETFL,
libc::O_NONBLOCK | libc::O_CLOEXEC
))
.map(|_| socket)
});

socket
}

fn socket_address(address: &SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
match address {
SocketAddr::V4(ref address) => (
address as *const _ as *const libc::sockaddr,
size_of_val(address) as libc::socklen_t,
),
SocketAddr::V6(ref address) => (
address as *const _ as *const libc::sockaddr,
size_of_val(address) as libc::socklen_t,
),
}
}
121 changes: 71 additions & 50 deletions src/sys/windows/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::mem::size_of_val;
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET;
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;

use net2::TcpStreamExt;
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
use winapi::um::winsock2::{
bind, connect, ioctlsocket, socket, FIONBIO, INVALID_SOCKET, PF_INET, PF_INET6, SOCKET_ERROR,
SOCK_STREAM,
bind, connect, ioctlsocket, listen, socket, FIONBIO, INVALID_SOCKET, PF_INET, PF_INET6, SOCKET,
SOCKET_ERROR, SOCK_STREAM,
};

use crate::poll;
Expand Down Expand Up @@ -85,48 +85,36 @@ macro_rules! wouldblock {
}

impl TcpStream {
pub fn connect(address: SocketAddr) -> io::Result<TcpStream> {
let domain = match address {
SocketAddr::V4(..) => PF_INET,
SocketAddr::V6(..) => PF_INET6,
};

syscall!(
socket(domain, SOCK_STREAM, 0),
PartialEq::eq,
INVALID_SOCKET
)
.and_then(|socket| {
syscall!(ioctlsocket(socket, FIONBIO, &mut 1), PartialEq::ne, 0)
pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
new_socket(addr)
.and_then(|socket| {
// Required for a future `connect_overlapped` operation to be
// executed successfully.
let any_addr = inaddr_any(addr);
let (raw_addr, raw_addr_length) = socket_addr(&any_addr);
syscall!(
bind(socket, raw_addr, raw_addr_length),
PartialEq::eq,
SOCKET_ERROR
)
.and_then(|_| {
// Required for a future `connect_overlapped` operation to be
// executed successfully.
let any_address = inaddr_any(address);
let (raw_address, raw_address_length) = socket_address(&any_address);
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(
bind(socket, raw_address, raw_address_length),
connect(socket, raw_addr, raw_addr_length),
PartialEq::eq,
SOCKET_ERROR
)
.and_then(|_| {
let (raw_address, raw_address_length) = socket_address(&address);
syscall!(
connect(socket, raw_address, raw_address_length),
PartialEq::eq,
SOCKET_ERROR
)
.or_else(|err| match err {
ref err if err.kind() == io::ErrorKind::WouldBlock => Ok(0),
err => Err(err),
})
.or_else(|err| match err {
ref err if err.kind() == io::ErrorKind::WouldBlock => Ok(0),
err => Err(err),
})
})
.map(|_| socket)
})
.map(|socket| TcpStream {
internal: Arc::new(RwLock::new(None)),
inner: unsafe { net::TcpStream::from_raw_socket(socket as SOCKET) },
})
})
.map(|socket| TcpStream {
internal: Arc::new(RwLock::new(None)),
inner: unsafe { net::TcpStream::from_raw_socket(socket as StdSocket) },
})
}

pub fn connect_stream(stream: net::TcpStream, addr: SocketAddr) -> io::Result<TcpStream> {
Expand Down Expand Up @@ -242,19 +230,6 @@ fn inaddr_any(other: SocketAddr) -> SocketAddr {
}
}

fn socket_address(address: &SocketAddr) -> (*const SOCKADDR, c_int) {
match address {
SocketAddr::V4(ref address) => (
address as *const _ as *const SOCKADDR,
size_of_val(address) as c_int,
),
SocketAddr::V6(ref address) => (
address as *const _ as *const SOCKADDR,
size_of_val(address) as c_int,
),
}
}

impl super::SocketState for TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
let internal = self.internal.read().unwrap();
Expand Down Expand Up @@ -443,6 +418,22 @@ impl AsRawSocket for TcpStream {
}

impl TcpListener {
pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
new_socket(addr).and_then(|socket| {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(
bind(socket, raw_addr, raw_addr_length,),
PartialEq::eq,
SOCKET_ERROR
)
.and_then(|_| syscall!(listen(socket, 1024), PartialEq::eq, SOCKET_ERROR))
.map(|_| TcpListener {
internal: Arc::new(RwLock::new(None)),
inner: unsafe { net::TcpListener::from_raw_socket(socket as StdSocket) },
})
})
}

pub fn new(inner: net::TcpListener) -> io::Result<TcpListener> {
inner.set_nonblocking(true)?;
Ok(TcpListener {
Expand Down Expand Up @@ -595,3 +586,33 @@ impl AsRawSocket for TcpListener {
self.inner.as_raw_socket()
}
}

/// Create a new non-blocking socket.
fn new_socket(addr: SocketAddr) -> io::Result<SOCKET> {
let domain = match addr {
SocketAddr::V4(..) => PF_INET,
SocketAddr::V6(..) => PF_INET6,
};

syscall!(
socket(domain, SOCK_STREAM, 0),
PartialEq::eq,
INVALID_SOCKET
)
.and_then(|socket| {
syscall!(ioctlsocket(socket, FIONBIO, &mut 1), PartialEq::ne, 0).map(|_| socket as SOCKET)
})
}

fn socket_addr(addr: &SocketAddr) -> (*const SOCKADDR, c_int) {
match addr {
SocketAddr::V4(ref addr) => (
addr as *const _ as *const SOCKADDR,
size_of_val(addr) as c_int,
),
SocketAddr::V6(ref addr) => (
addr as *const _ as *const SOCKADDR,
size_of_val(addr) as c_int,
),
}
}

0 comments on commit 66f37de

Please sign in to comment.