Skip to content

Commit

Permalink
Add connect_timeout
Browse files Browse the repository at this point in the history
Unlike the rest of the API, connect_timeout does not correspond to a
single API call. However, it's a fundamental operation, and is a pain to
set up.
  • Loading branch information
sfackler committed Jun 8, 2017
1 parent b465cb3 commit 53b29f3
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 1 deletion.
54 changes: 54 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,30 @@ impl Socket {
self.inner.connect(addr)
}

/// Initiate a connection on this socket to the specified address, only
/// only waiting for a certain period of time for the connection to be
/// established.
///
/// Unlike many other methods on `Socket`, this does *not* correspond to a
/// single C function. It sets the socket to nonblocking mode, connects via
/// connect(2), and then waits for the connection to complete with poll(2)
/// on Unix and select on Windows. When the connection is complete, the
/// socket is set back to blocking mode. On Unix, this will loop over
/// `EINTR` errors.
///
/// # Warnings
///
/// The nonblocking state of the socket is overridden by this function -
/// it will be returned in blocking mode on success, and in an indeterminate
/// state on failure.
///
/// If the connection request times out, the it may still be processing in
/// the background - a second call to `connect` or `connect_timeout` may
/// fail.
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.inner.connect_timeout(addr, timeout)
}

/// Binds this socket to the specified address.
///
/// This function directly corresponds to the bind(2) function on Windows
Expand Down Expand Up @@ -667,3 +691,33 @@ impl From<Protocol> for i32 {
a.into()
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn connect_timeout_unrouteable() {
// this IP is unroutable, so connections should always time out
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {}
Err(e) => panic!("unexpected error {}", e),
}
}

#[test]
fn connect_timeout_valid() {
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
socket.listen(128).unwrap();

let addr = socket.local_addr().unwrap();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.connect_timeout(&addr, Duration::from_millis(250)).unwrap();
}
}
63 changes: 62 additions & 1 deletion src/sys/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}
use std::ops::Neg;
use std::os::unix::prelude::*;
use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT};
use std::time::Duration;
use std::time::{Duration, Instant};

use libc::{self, c_void, c_int, sockaddr_in, sockaddr_storage, sockaddr_in6};
use libc::{sockaddr, socklen_t, AF_INET, AF_INET6, ssize_t};
Expand Down Expand Up @@ -118,6 +118,67 @@ impl Socket {
}
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;
let r = self.connect(addr);
self.set_nonblocking(false)?;

match r {
Ok(()) => return Ok(()),
// there's no io::ErrorKind conversion registered for EINPROGRESS :(
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e),
}

let mut pollfd = libc::pollfd {
fd: self.fd,
events: libc::POLLOUT,
revents: 0,
};

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let start = Instant::now();

loop {
let elapsed = start.elapsed();
if elapsed >= timeout {
return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out"));
}

let timeout = timeout - elapsed;
let mut timeout = timeout.as_secs()
.saturating_mul(1_000)
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
if timeout == 0 {
timeout = 1;
}

let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;

match unsafe { libc::poll(&mut pollfd, 1, timeout) } {
-1 => {
let err = io::Error::last_os_error();
if err.kind() != io::ErrorKind::Interrupted {
return Err(err);
}
}
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
_ => {
if pollfd.revents & libc::POLLOUT == 0 {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
return Ok(());
}
}
}
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
unsafe {
let mut storage: libc::sockaddr_storage = mem::zeroed();
Expand Down
48 changes: 48 additions & 0 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,54 @@ impl Socket {
}
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;
let r = self.connect(addr);
self.set_nonblocking(true)?;

match r {
Ok(()) => return Ok(()),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let mut timeout = timeval {
tv_sec: timeout.as_secs() as c_long,
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
};
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
timeout.tv_usec = 1;
}

let fds = unsafe {
let mut fds = mem::zeroed::<fd_set>();
fds.fd_count = 1;
fds.fd_array[0] = self.socket;
fds
};

let mut writefds = fds;
let mut errorfds = fds;

match unsafe { ws2_32::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout) } {
SOCKET_ERROR => return Err(io::Error::last_os_error()),
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
_ => {
if writefds.fd_count != 1 {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
Ok(())
}
}
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
unsafe {
let mut storage: SOCKADDR_STORAGE = mem::zeroed();
Expand Down

0 comments on commit 53b29f3

Please sign in to comment.