Skip to content

Commit

Permalink
Add TcpSocket, a basic TCP socket builder (#1358)
Browse files Browse the repository at this point in the history
This provides `TcpSocket`, a basic API for building a TCP socket. The
goal is not to provide comprehensive coverage of all system options, but
to provide an API for the most common cases.

This is added now as a replacement for the removal of
`TcpStream::connect_std`. The `connect_std` function from v0.6 was used
until now as the strategy to set socket option before obtaining a mio
TcpStream.

Providing some strategy for customizing a `TcpStream` is required for
Hyper to be able to upgrade.
  • Loading branch information
carllerche committed Oct 6, 2020
1 parent 87a15fc commit 5b09e60
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 126 deletions.
2 changes: 1 addition & 1 deletion src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

cfg_tcp! {
mod tcp;
pub use self::tcp::{TcpListener, TcpStream};
pub use self::tcp::{TcpListener, TcpSocket, TcpStream};
}

cfg_udp! {
Expand Down
17 changes: 15 additions & 2 deletions src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::{fmt, io};

use super::TcpStream;
use super::{TcpSocket, TcpStream};
use crate::io_source::IoSource;
use crate::{event, sys, Interest, Registry, Token};

Expand Down Expand Up @@ -49,7 +49,20 @@ impl TcpListener {
/// 3. Bind the socket to the specified address.
/// 4. Calls `listen` on the socket to prepare it to receive new connections.
pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
sys::tcp::bind(addr).map(TcpListener::from_std)
let socket = TcpSocket::new_for_addr(addr)?;

// On platforms with Berkeley-derived sockets, this allows to quickly
// rebind a socket, without needing to wait for the OS to clean up the
// previous one.
//
// On Windows, this allows rebinding sockets which are actively in use,
// which allows “socket hijacking”, so we explicitly don't set it here.
// https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
#[cfg(not(windows))]
socket.set_reuseaddr(true)?;

socket.bind(addr)?;
socket.listen(1024)
}

/// Creates a new `TcpListener` from a standard `net::TcpListener`.
Expand Down
3 changes: 3 additions & 0 deletions src/net/tcp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
mod listener;
pub use self::listener::TcpListener;

mod socket;
pub use self::socket::TcpSocket;

mod stream;
pub use self::stream::TcpStream;
107 changes: 107 additions & 0 deletions src/net/tcp/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use crate::net::{TcpStream, TcpListener};
use crate::sys;

use std::io;
use std::mem;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd, FromRawFd};

/// A non-blocking TCP socket used to configure a stream or listener.
///
/// The `TcpSocket` type wraps the operating-system's socket handle. This type
/// is used to configure the socket before establishing a connection or start
/// listening for inbound connections.
///
/// The socket will be closed when the value is dropped.
#[derive(Debug)]
pub struct TcpSocket {
sys: sys::tcp::TcpSocket,
}

impl TcpSocket {
/// Create a new IPv4 TCP socket.
///
/// This calls `socket(2)`.
pub fn new_v4() -> io::Result<TcpSocket> {
sys::tcp::new_v4_socket().map(|sys| TcpSocket {
sys
})
}

/// Create a new IPv6 TCP socket.
///
/// This calls `socket(2)`.
pub fn new_v6() -> io::Result<TcpSocket> {
sys::tcp::new_v6_socket().map(|sys| TcpSocket {
sys
})
}

pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result<TcpSocket> {
if addr.is_ipv4() {
TcpSocket::new_v4()
} else {
TcpSocket::new_v6()
}
}

/// Bind `addr` to the TCP socket.
pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
sys::tcp::bind(self.sys, addr)
}

/// Connect the socket to `addr`.
///
/// This consumes the socket and performs the connect operation. Once the
/// connection completes, the socket is now a non-blocking `TcpStream` and
/// can be used as such.
pub fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
let stream = sys::tcp::connect(self.sys, addr)?;

// Don't close the socket
mem::forget(self);
Ok(TcpStream::from_std(stream))
}

/// Listen for inbound connections, converting the socket to a
/// `TcpListener`.
pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
let listener = sys::tcp::listen(self.sys, backlog)?;

// Don't close the socket
mem::forget(self);
Ok(TcpListener::from_std(listener))
}

/// Sets the value of `SO_REUSEADDR` on this socket.
pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
sys::tcp::set_reuseaddr(self.sys, reuseaddr)
}
}

impl Drop for TcpSocket {
fn drop(&mut self) {
sys::tcp::close(self.sys);
}
}

#[cfg(unix)]
impl AsRawFd for TcpSocket {
fn as_raw_fd(&self) -> RawFd {
self.sys
}
}

#[cfg(unix)]
impl FromRawFd for TcpSocket {
/// Converts a `RawFd` to a `TcpStream`.
///
/// # Notes
///
/// The caller is responsible for ensuring that the socket is in
/// non-blocking mode.
unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
TcpSocket { sys: fd }
}
}
6 changes: 4 additions & 2 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};

use crate::io_source::IoSource;
use crate::{event, sys, Interest, Registry, Token};
use crate::{event, Interest, Registry, Token};
use crate::net::TcpSocket;

/// A non-blocking TCP stream between a local socket and a remote socket.
///
Expand Down Expand Up @@ -48,7 +49,8 @@ impl TcpStream {
/// Create a new TCP stream and issue a non-blocking connect to the
/// specified address.
pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
sys::tcp::connect(addr).map(TcpStream::from_std)
let socket = TcpSocket::new_for_addr(addr)?;
socket.connect(addr)
}

/// Creates a new `TcpStream` from a standard `net::TcpStream`.
Expand Down
26 changes: 24 additions & 2 deletions src/sys/shell/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
use std::io;
use std::net::{self, SocketAddr};

pub fn connect(_: SocketAddr) -> io::Result<net::TcpStream> {
pub(crate) type TcpSocket = i32;

pub(crate) fn new_v4_socket() -> io::Result<TcpSocket> {
os_required!();
}

pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {
os_required!();
}

pub(crate) fn bind(_socket: TcpSocket, _addr: SocketAddr) -> io::Result<()> {
os_required!();
}

pub(crate) fn connect(_: TcpSocket, _addr: SocketAddr) -> io::Result<net::TcpStream> {
os_required!();
}

pub(crate) fn listen(_: TcpSocket, _: u32) -> io::Result<net::TcpListener> {
os_required!();
}

pub(crate) fn close(_: TcpSocket) {
os_required!();
}

pub fn bind(_: SocketAddr) -> io::Result<net::TcpListener> {
pub(crate) fn set_reuseaddr(_: TcpSocket, _: bool) -> io::Result<()> {
os_required!();
}

Expand Down
2 changes: 1 addition & 1 deletion src/sys/unix/net.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(all(feature = "os-poll", any(feature = "tcp", feature = "udp")))]
use std::net::SocketAddr;

#[cfg(all(feature = "os-poll", any(feature = "tcp", feature = "udp")))]
#[cfg(all(feature = "os-poll", any(feature = "udp")))]
pub(crate) fn new_ip_socket(
addr: SocketAddr,
socket_type: libc::c_int,
Expand Down
93 changes: 50 additions & 43 deletions src/sys/unix/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,59 @@ use std::mem::{size_of, MaybeUninit};
use std::net::{self, SocketAddr};
use std::os::unix::io::{AsRawFd, FromRawFd};

use crate::sys::unix::net::{new_ip_socket, socket_addr, to_socket_addr};
use crate::sys::unix::net::{new_socket, socket_addr, to_socket_addr};

pub fn connect(addr: SocketAddr) -> io::Result<net::TcpStream> {
new_ip_socket(addr, libc::SOCK_STREAM)
.and_then(|socket| {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(connect(socket, raw_addr, raw_addr_length))
.or_else(|err| match err {
// Connect hasn't finished, but that is fine.
ref err if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(0),
err => Err(err),
})
.map(|_| socket)
.map_err(|err| {
// Close the socket if we hit an error, ignoring the error
// from closing since we can't pass back two errors.
let _ = unsafe { libc::close(socket) };
err
})
})
.map(|socket| unsafe { net::TcpStream::from_raw_fd(socket) })
pub type TcpSocket = libc::c_int;

pub(crate) fn new_v4_socket() -> io::Result<TcpSocket> {
new_socket(libc::AF_INET, libc::SOCK_STREAM)
}

pub fn bind(addr: SocketAddr) -> io::Result<net::TcpListener> {
new_ip_socket(addr, libc::SOCK_STREAM).and_then(|socket| {
// Set SO_REUSEADDR (mirrors what libstd does).
syscall!(setsockopt(
socket,
libc::SOL_SOCKET,
libc::SO_REUSEADDR,
&1 as *const libc::c_int as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
))
.and_then(|_| {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(bind(socket, raw_addr, raw_addr_length))
})
.and_then(|_| syscall!(listen(socket, 1024)))
.map_err(|err| {
// Close the socket if we hit an error, ignoring the error
// from closing since we can't pass back two errors.
let _ = unsafe { libc::close(socket) };
err
})
.map(|_| unsafe { net::TcpListener::from_raw_fd(socket) })
})
pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {
new_socket(libc::AF_INET6, libc::SOCK_STREAM)
}

pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(bind(socket, raw_addr, raw_addr_length))?;
Ok(())
}

pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream> {
let (raw_addr, raw_addr_length) = socket_addr(&addr);

match syscall!(connect(socket, raw_addr, raw_addr_length)) {
Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => {
Err(err)
}
_ => {
Ok(unsafe { net::TcpStream::from_raw_fd(socket) })
}
}
}

pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result<net::TcpListener> {
use std::convert::TryInto;

let backlog = backlog.try_into().unwrap_or(i32::max_value());
syscall!(listen(socket, backlog))?;
Ok(unsafe { net::TcpListener::from_raw_fd(socket) })
}

pub(crate) fn close(socket: TcpSocket) {
let _ = unsafe { net::TcpStream::from_raw_fd(socket) };
}

pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()> {
let val: libc::c_int = if reuseaddr { 1 } else { 0 };
syscall!(setsockopt(
socket,
libc::SOL_SOCKET,
libc::SO_REUSEADDR,
&val as *const libc::c_int as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
))?;
Ok(())
}

pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> {
Expand Down
29 changes: 9 additions & 20 deletions src/sys/windows/net.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::io;
use std::mem::size_of_val;
use std::net::SocketAddr;
#[cfg(all(feature = "os-poll", feature = "tcp"))]
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
use std::sync::Once;

use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
use winapi::um::winsock2::{
ioctlsocket, socket, FIONBIO, INVALID_SOCKET, PF_INET, PF_INET6, SOCKET,
ioctlsocket, socket, FIONBIO, INVALID_SOCKET, SOCKET,
};

/// Initialise the network stack for Windows.
Expand All @@ -23,12 +21,19 @@ pub(crate) fn init() {
}

/// Create a new non-blocking socket.
pub(crate) fn new_socket(addr: SocketAddr, socket_type: c_int) -> io::Result<SOCKET> {
#[cfg(feature = "udp")]
pub(crate) fn new_ip_socket(addr: SocketAddr, socket_type: c_int) -> io::Result<SOCKET> {
use winapi::um::winsock2::{PF_INET, PF_INET6};

let domain = match addr {
SocketAddr::V4(..) => PF_INET,
SocketAddr::V6(..) => PF_INET6,
};

new_socket(domain, socket_type)
}

pub(crate) fn new_socket(domain: c_int, socket_type: c_int) -> io::Result<SOCKET> {
syscall!(
socket(domain, socket_type, 0),
PartialEq::eq,
Expand All @@ -51,19 +56,3 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (*const SOCKADDR, c_int) {
),
}
}

#[cfg(all(feature = "os-poll", feature = "tcp"))]
pub(crate) fn inaddr_any(other: SocketAddr) -> SocketAddr {
match other {
SocketAddr::V4(..) => {
let any = Ipv4Addr::new(0, 0, 0, 0);
let addr = SocketAddrV4::new(any, 0);
SocketAddr::V4(addr)
}
SocketAddr::V6(..) => {
let any = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
let addr = SocketAddrV6::new(any, 0, 0, 0);
SocketAddr::V6(addr)
}
}
}
Loading

0 comments on commit 5b09e60

Please sign in to comment.