Skip to content

Commit

Permalink
Merge pull request #37 from adamreichold/read-timeout
Browse files Browse the repository at this point in the history
Extend RequestBuilder to support specifying timeouts
  • Loading branch information
sbstp committed Jan 8, 2020
2 parents 47853c1 + 73d650b commit aeb2bd2
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 13 deletions.
4 changes: 1 addition & 3 deletions src/happy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@ use std::sync::mpsc::channel;
use std::thread;
use std::time::{Duration, Instant};

const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
const RACE_DELAY: Duration = Duration::from_millis(200);

/// This function implements a basic form of the happy eyeballs RFC to quickly connect
/// to a domain which is available in both IPv4 and IPv6. Connection attempts are raced
/// against each other and the first to connect successfully wins the race.
///
/// If the timeout is not provided, a default timeout of 10 seconds is used.
pub fn connect<A>(addrs: A, timeout: impl Into<Option<Duration>>) -> io::Result<TcpStream>
pub fn connect<A>(addrs: A, timeout: Duration) -> io::Result<TcpStream>
where
A: ToSocketAddrs,
{
let timeout = timeout.into().unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
let addrs: Vec<_> = addrs.to_socket_addrs()?.collect();

if let [addr] = &addrs[..] {
Expand Down
31 changes: 30 additions & 1 deletion src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::convert::From;
use std::convert::TryInto;
use std::io::{prelude::*, BufWriter};
use std::str;
use std::time::Duration;

#[cfg(feature = "compress")]
use http::header::ACCEPT_ENCODING;
Expand Down Expand Up @@ -69,6 +70,8 @@ pub struct RequestBuilder<B = [u8; 0]> {
body: B,
max_redirections: u32,
follow_redirects: bool,
connect_timeout: Duration,
read_timeout: Duration,
#[cfg(feature = "charsets")]
pub(crate) default_charset: Option<Charset>,
#[cfg(feature = "compress")]
Expand Down Expand Up @@ -108,6 +111,8 @@ impl RequestBuilder {
body: [],
max_redirections: 5,
follow_redirects: true,
connect_timeout: Duration::from_secs(30),
read_timeout: Duration::from_secs(30),
#[cfg(feature = "charsets")]
default_charset: None,
#[cfg(feature = "compress")]
Expand Down Expand Up @@ -241,6 +246,8 @@ impl<B> RequestBuilder<B> {
body,
max_redirections: self.max_redirections,
follow_redirects: self.follow_redirects,
connect_timeout: self.connect_timeout,
read_timeout: self.read_timeout,
#[cfg(feature = "charsets")]
default_charset: self.default_charset,
#[cfg(feature = "compress")]
Expand Down Expand Up @@ -314,6 +321,22 @@ impl<B> RequestBuilder<B> {
self
}

/// Sets a connect timeout for this request.
///
/// The default is 30 seconds.
pub fn connect_timeout(mut self, duration: Duration) -> Self {
self.connect_timeout = duration;
self
}

/// Sets a read timeout for this request.
///
/// The default is 30 seconds.
pub fn read_timeout(mut self, duration: Duration) -> Self {
self.read_timeout = duration;
self
}

/// Set the default charset to use while parsing the response of this `Request`.
///
/// If the response does not say which charset it uses, this charset will be used to decode the request.
Expand Down Expand Up @@ -353,6 +376,8 @@ impl<B: AsRef<[u8]>> RequestBuilder<B> {
body: self.body,
max_redirections: self.max_redirections,
follow_redirects: self.follow_redirects,
connect_timeout: self.connect_timeout,
read_timeout: self.read_timeout,
#[cfg(feature = "charsets")]
default_charset: self.default_charset,
#[cfg(feature = "compress")]
Expand Down Expand Up @@ -386,6 +411,8 @@ pub struct PreparedRequest<B> {
body: B,
max_redirections: u32,
follow_redirects: bool,
connect_timeout: Duration,
read_timeout: Duration,
#[cfg(feature = "charsets")]
pub(crate) default_charset: Option<Charset>,
#[cfg(feature = "compress")]
Expand All @@ -405,6 +432,8 @@ impl PreparedRequest<Vec<u8>> {
body: Vec::new(),
max_redirections: 5,
follow_redirects: true,
connect_timeout: Duration::from_secs(30),
read_timeout: Duration::from_secs(30),
#[cfg(feature = "charsets")]
default_charset: None,
#[cfg(feature = "compress")]
Expand Down Expand Up @@ -526,7 +555,7 @@ impl<B: AsRef<[u8]>> PreparedRequest<B> {
let mut redirections = 0;

loop {
let mut stream = BaseStream::connect(&url)?;
let mut stream = BaseStream::connect(&url, self.connect_timeout, self.read_timeout)?;
self.write_request(&mut stream, &url)?;
let resp = parse_response(stream, self)?;

Expand Down
30 changes: 21 additions & 9 deletions src/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::io::Cursor;
use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::time::Duration;

#[cfg(feature = "tls")]
use native_tls::{HandshakeError, TlsConnector, TlsStream};
Expand All @@ -20,30 +21,41 @@ pub enum BaseStream {
}

impl BaseStream {
pub fn connect(url: &Url) -> Result<BaseStream> {
pub fn connect(url: &Url, connect_timeout: Duration, read_timeout: Duration) -> Result<BaseStream> {
let host = url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
let port = url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;

debug!("trying to connect to {}:{}", host, port);

Ok(match url.scheme() {
"http" => BaseStream::Plain(happy::connect((host, port), None)?),
match url.scheme() {
"http" => BaseStream::connect_tcp(host, port, connect_timeout, read_timeout).map(BaseStream::Plain),
#[cfg(feature = "tls")]
"https" => BaseStream::connect_tls(host, port)?,
_ => return Err(ErrorKind::InvalidBaseUrl.into()),
})
"https" => BaseStream::connect_tls(host, port, connect_timeout, read_timeout).map(BaseStream::Tls),
_ => Err(ErrorKind::InvalidBaseUrl.into()),
}
}

fn connect_tcp(host: &str, port: u16, connect_timeout: Duration, read_timeout: Duration) -> Result<TcpStream> {
let stream = happy::connect((host, port), connect_timeout)?;
stream.set_read_timeout(Some(read_timeout))?;
Ok(stream)
}

#[cfg(feature = "tls")]
fn connect_tls(host: &str, port: u16) -> Result<BaseStream> {
fn connect_tls(
host: &str,
port: u16,
connect_timeout: Duration,
read_timeout: Duration,
) -> Result<TlsStream<TcpStream>> {
let connector = TlsConnector::new()?;
let stream = happy::connect((host, port), None)?;
let stream = BaseStream::connect_tcp(host, port, connect_timeout, read_timeout)?;
let tls_stream = match connector.connect(host, stream) {
Ok(stream) => stream,
Err(HandshakeError::Failure(err)) => return Err(err.into()),
Err(HandshakeError::WouldBlock(_)) => panic!("socket configured in non-blocking mode"),
};
Ok(BaseStream::Tls(tls_stream))
Ok(tls_stream)
}

#[cfg(test)]
Expand Down
31 changes: 31 additions & 0 deletions tests/test_timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::io;
use std::net::TcpListener;
use std::thread;
use std::time::Duration;

#[test]
fn request_fails_due_to_read_timeout() {
let listener = TcpListener::bind("localhost:0").unwrap();
let port = listener.local_addr().unwrap().port();
let thread = thread::spawn(move || {
let _stream = listener.accept().unwrap();
thread::sleep(Duration::from_millis(500));
});

let result = attohttpc::get(format!("http://localhost:{}", port))
.read_timeout(Duration::from_millis(100))
.send();

match result {
Err(err) => match err.kind() {
attohttpc::ErrorKind::Io(err) => match err.kind() {
io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => (),
err => panic!("Unexpected I/O error: {:?}", err),
},
err => panic!("Unexpected error: {:?}", err),
},
Ok(resp) => panic!("Unexpected response: {:?}", resp),
}

thread.join().unwrap();
}

0 comments on commit aeb2bd2

Please sign in to comment.