Skip to content

Commit

Permalink
Add connect_timeout
Browse files Browse the repository at this point in the history
Unlike the rest of the API, connect_timeout does not correspond to a
single API call. However, it's a fundamental operation, and is a pain to
set up.
  • Loading branch information
sfackler committed Jun 8, 2017
1 parent b465cb3 commit aa17f4b
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,29 @@ impl Socket {
self.inner.connect(addr)
}

/// Initiate a connection on this socket to the specified address, only
/// only waiting for a certain period of time for the connection to be
/// established.
///
/// Unlike many other methods on `Socket`, this does *not* correspond to a
/// single C function. It sets the socket to nonblocking mode, connects via
/// connect(2), and then waits for the connection to complete with poll(2)
/// on Unix and select on Windows. When the connection is complete, the
/// socket is set back to blocking mode.
///
/// # Warnings
///
/// The nonblocking state of the socket is overridden by this function -
/// it will be returned in blocking mode on success, and in an indeterminate
/// state on failure.
///
/// If the connection request times out, the it may still be processing in
/// the background - a second call to `connect` or `connect_timeout` may
/// fail.
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.inner.connect_timeout(addr, timeout)
}

/// Binds this socket to the specified address.
///
/// This function directly corresponds to the bind(2) function on Windows
Expand Down Expand Up @@ -667,3 +690,33 @@ impl From<Protocol> for i32 {
a.into()
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn connect_timeout_unrouteable() {
// this IP is unroutable, so connections should always time out
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {}
Err(e) => panic!("unexpected error {}", e),
}
}

#[test]
fn connect_timeout_valid() {
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
socket.listen(128).unwrap();

let addr = socket.local_addr().unwrap();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.connect_timeout(&addr, Duration::from_millis(250)).unwrap();
}
}
47 changes: 47 additions & 0 deletions src/sys/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,53 @@ impl Socket {
}
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;

match self.connect(addr) {
Ok(()) => {
self.set_nonblocking(false)?;
return Ok(());
}
// there's no io::ErrorKind conversion registered for EINPROGRESS :(
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e),
}

let mut pollfd = libc::pollfd {
fd: self.as_raw_fd(),
events: libc::POLLOUT,
revents: 0,
};

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let mut timeout = timeout.as_secs()
.saturating_mul(1_000)
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
if timeout == 0 {
timeout = 1;
}

let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;

let r = unsafe { libc::poll(&mut pollfd, 1, timeout) };
match r {
-1 => return Err(io::Error::last_os_error()),
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
_ => {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
}

self.set_nonblocking(false)
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
unsafe {
let mut storage: libc::sockaddr_storage = mem::zeroed();
Expand Down
52 changes: 52 additions & 0 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,58 @@ impl Socket {
}
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;

match self.connect(addr) {
Ok(()) => {
self.set_nonblocking(false)?;
return Ok(());
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let mut timeout = timeval {
tv_sec: timeout.as_secs() as c_long,
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
};
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
timeout.tv_usec = 1;
}

let fds = unsafe {
let mut fds = mem::zeroed::<fd_set>();
fds.fd_count = 1;
fds.fd_array[0] = self.socket;
fds
};

let mut writefds = fds;
let mut errorfds = fds;

let r = unsafe {
ws2_32::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout)
};

match r {
SOCKET_ERROR => return Err(io::Error::last_os_error()),
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
_ => {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
}

self.set_nonblocking(false)
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
unsafe {
let mut storage: SOCKADDR_STORAGE = mem::zeroed();
Expand Down

0 comments on commit aa17f4b

Please sign in to comment.