Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RequestBodyTimeout and ResponseBodyTimeout #109

Closed
wants to merge 11 commits into from
2 changes: 2 additions & 0 deletions tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- Add `RequestBodyTimeout` and `ResponseBodyTimeout` which apply timeouts to
request and response bodies respectively.
- Add `StatusInRangeAsFailures` which is a response classifier that considers
responses with status code in a certain range as failures. Useful for HTTP
clients where both server errors (5xx) and client errors (4xx) are considered
Expand Down
2 changes: 2 additions & 0 deletions tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ default = []
full = [
"add-extension",
"auth",
"timeout",
"compression",
"compression-full",
"decompression-full",
Expand All @@ -66,6 +67,7 @@ full = [

add-extension = []
auth = ["base64"]
timeout = ["tower/timeout"]
follow-redirect = ["iri-string", "tower/util"]
fs = ["tokio/fs", "tokio-util/io", "mime_guess", "mime"]
map-request-body = []
Expand Down
10 changes: 9 additions & 1 deletion tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
//! use tower::{ServiceBuilder, Service, ServiceExt};
//! use hyper::Body;
//! use http::{Request, Response, HeaderValue, header::USER_AGENT};
//! use std::time::Duration;
//!
//! #[tokio::main]
//! async fn main() {
Expand Down Expand Up @@ -217,13 +218,16 @@
clippy::match_like_matches_macro,
clippy::type_complexity
)]
#![forbid(unsafe_code)]
#![cfg_attr(not(test), forbid(unsafe_code))]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like all unsafe usage went away in later commits, so this can be reverted?

#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]

#[macro_use]
pub(crate) mod macros;

#[cfg(test)]
pub(crate) mod test_utils;

#[cfg(feature = "auth")]
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
pub mod auth;
Expand Down Expand Up @@ -275,6 +279,10 @@ pub mod follow_redirect;
#[cfg_attr(docsrs, doc(cfg(feature = "metrics")))]
pub mod metrics;

#[cfg(feature = "timeout")]
#[cfg_attr(docsrs, doc(cfg(feature = "timeout")))]
pub mod timeout;

pub mod classify;
pub mod services;

Expand Down
7 changes: 7 additions & 0 deletions tower-http/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use http::{Request, Response};
use hyper::Body;
use tower::BoxError;

pub(crate) async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
132 changes: 132 additions & 0 deletions tower-http/src/timeout/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
//! Middleware that adds timeouts to request or response bodies.
//!
//! Note these middleware differ from [`tower::timeout::Timeout`] which only
//! adds a timeout to the response future and doesn't consider request bodies.
Comment on lines +3 to +4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mean for ResponseBodyTimeout? Is it equivalent to the tower middleware?


use http_body::Body;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::Sleep;
use tower::BoxError;

pub mod request_body;
pub mod response_body;

#[doc(inline)]
pub use self::{
request_body::{RequestBodyTimeout, RequestBodyTimeoutLayer},
response_body::{ResponseBodyTimeout, ResponseBodyTimeoutLayer},
};

/// An HTTP body with a timeout applied.
#[pin_project]
#[derive(Debug)]
pub struct TimeoutBody<B> {
#[pin]
inner: B,
#[pin]
state: State,
}

impl<B> TimeoutBody<B> {
pub(crate) fn new(inner: B, timeout: Duration) -> Self {
Self {
inner,
state: State::NotPolled(timeout),
}
}
}

// Only start the timeout after first poll of the body. This enum manages that.
#[allow(clippy::large_enum_variant)]
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State {
NotPolled(Duration),
SleepPending(#[pin] Sleep),
}
Copy link
Member Author

@davidpdrsn davidpdrsn Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how useful this would be in practice. I'm just wondering if users might receive a request and not poll the body immediately and hit the timeout because they were slow to call the first poll.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to have a Lazy and non lazy version?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's common to first hit the DB for an auth check with a value from an authorization header before starting to look at the http body. This seems like a good idea to me.

(based on my understanding that the body would usually be polled immediately, except for use cases where the body is streamed)


impl Future for State {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let new_state = match self.as_mut().project() {
StateProj::NotPolled(timeout) => State::SleepPending(tokio::time::sleep(*timeout)),
StateProj::SleepPending(sleep) => return sleep.poll(cx),
};
self.set(new_state);
}
}
}

impl<B> Body for TimeoutBody<B>
where
B: Body,
B::Error: Into<BoxError>,
{
type Data = B::Data;
type Error = BoxError;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let this = self.project();

if let Poll::Ready(chunk) = this.inner.poll_data(cx) {
let chunk = chunk.map(|chunk| chunk.map_err(Into::into));
return Poll::Ready(chunk);
}

if this.state.poll(cx).is_ready() {
let err = tower::timeout::error::Elapsed::new().into();
return Poll::Ready(Some(Err(err)));
}

Poll::Pending
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
let this = self.project();

if let Poll::Ready(trailers) = this.inner.poll_trailers(cx) {
let trailers = trailers.map_err(Into::into);
return Poll::Ready(trailers);
}

if this.state.poll(cx).is_ready() {
let err = tower::timeout::error::Elapsed::new().into();
return Poll::Ready(Err(err));
}

Poll::Pending
}

fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}

fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}
}

impl<B> Default for TimeoutBody<B>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this impl useful for?

where
B: Default,
{
fn default() -> Self {
// We assume that `B::default` is empty so the value of the timeout
// shouldn't matter. Polling an empty body should be quick
TimeoutBody::new(B::default(), Duration::from_secs(10))
}
}
167 changes: 167 additions & 0 deletions tower-http/src/timeout/request_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//! Middleware that adds timeouts to request bodies.
//!
//! Be careful using this with streaming requests as it might abort a stream
//! earlier than othwerwise intended.
//!
//! # Example
//!
//! ```
//! use tower_http::timeout::RequestBodyTimeoutLayer;
//! use tower::BoxError;
//! use hyper::{Request, Response, Body, Error};
//! use tower::{Service, ServiceExt, ServiceBuilder};
//! use std::time::Duration;
//! use bytes::Bytes;
//!
//! async fn handle<B>(request: Request<B>) -> Result<Response<Body>, BoxError>
//! where
//! B: http_body::Body,
//! B::Error: Into<BoxError>,
//! {
//! // Buffer the whole request body. The timeout will be automatically
//! // applied by `RequestBodyTimeoutLayer`.
//! let body_bytes = hyper::body::to_bytes(request.into_body())
//! .await
//! .map_err(Into::into)?;
//!
//! Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! let mut service = ServiceBuilder::new()
//! // Make sure the request body completes with 100 milliseconds.
//! .layer(RequestBodyTimeoutLayer::new(Duration::from_millis(100)))
//! .service_fn(handle);
//!
//! // Create a response body with a channel that we can use to send data
//! // asynchronously.
//! let (mut tx, body) = hyper::Body::channel();
//!
//! tokio::spawn(async move {
//! // Keep sending data forever. This would make the server hang if we
//! // didn't use `RequestBodyTimeoutLayer`.
//! loop {
//! tokio::time::sleep(Duration::from_secs(1)).await;
//! if tx.send_data(Bytes::from("foo")).await.is_err() {
//! break;
//! }
//! }
//! });
//!
//! let req = Request::new(body);
//!
//! // Calling the service should fail with a timeout error since buffering the
//! // request body hits the timeout.
//! let err = service.ready().await?.call(req).await.unwrap_err();
//! assert!(err.is::<tower::timeout::error::Elapsed>());
//! # Ok(())
//! # }
//! ```

use super::TimeoutBody;
use http::Request;
use std::{
task::{Context, Poll},
time::Duration,
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies [`RequestBodyTimeoutLayer`] which adds a timeout to the
/// request body.
///
/// If receiving the request body doesn't complete within the specified time, an
/// error is returned.
///
/// See the [module docs](crate::timeout::request_body) for an example.
#[derive(Debug, Copy, Clone)]
pub struct RequestBodyTimeoutLayer {
timeout: Duration,
}

impl RequestBodyTimeoutLayer {
/// Create a new `RequestBodyTimeoutLayer`.
pub fn new(timeout: Duration) -> Self {
RequestBodyTimeoutLayer { timeout }
}
}

impl<S> Layer<S> for RequestBodyTimeoutLayer {
type Service = RequestBodyTimeout<S>;

fn layer(&self, inner: S) -> Self::Service {
RequestBodyTimeout::new(inner, self.timeout)
}
}

/// Middleware that adds a timeout to the request bodies.
///
/// If receiving the request body doesn't complete within the specified time, an
/// error is returned.
///
/// See the [module docs](crate::timeout::request_body) for an example.
#[derive(Debug, Copy, Clone)]
pub struct RequestBodyTimeout<S> {
inner: S,
timeout: Duration,
}

impl<S> RequestBodyTimeout<S> {
/// Create a new `RequestBodyTimeout`.
pub fn new(inner: S, timeout: Duration) -> Self {
Self { inner, timeout }
}

define_inner_service_accessors!();

/// Returns a new [`Layer`] that wraps services with a
/// [`RequestBodyTimeoutLayer`]h middleware.
///
/// [`Layer`]: tower_layer::Layer
pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
RequestBodyTimeoutLayer::new(timeout)
}
}

impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S>
where
S: Service<Request<TimeoutBody<ReqBody>>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.inner
.call(req.map(|body| TimeoutBody::new(body, self.timeout)))
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use http::{Request, Response};
use hyper::{Body, Error};

#[allow(dead_code)]
async fn is_compatible_with_hyper() {
let svc = tower::service_fn(handle);
let svc = RequestBodyTimeout::new(svc, Duration::from_secs(1));

let make_service = tower::make::Shared::new(svc);

let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
let server = hyper::Server::bind(&addr).serve(make_service);
server.await.unwrap();
}

async fn handle<B>(_req: Request<B>) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from("Hello, World!")))
}
}
Loading