diff --git a/crates/base/src/server.rs b/crates/base/src/server.rs index 615d6468..885e6ff5 100644 --- a/crates/base/src/server.rs +++ b/crates/base/src/server.rs @@ -244,7 +244,7 @@ pub struct ServerFlags { pub tcp_nodelay: bool, pub graceful_exit_deadline_sec: u64, pub graceful_exit_keepalive_deadline_ms: Option, - pub request_idle_timeout_ms: Option, + pub request_read_timeout_ms: Option, } #[derive(Debug)] @@ -468,13 +468,13 @@ impl Server { let ServerFlags { tcp_nodelay, - request_idle_timeout_ms, + request_read_timeout_ms, mut graceful_exit_deadline_sec, mut graceful_exit_keepalive_deadline_ms, .. } = flags; - let request_idle_timeout_dur = request_idle_timeout_ms.map(Duration::from_millis); + let request_read_timeout_dur = request_read_timeout_ms.map(Duration::from_millis); let mut terminate_signal_fut = get_termination_signal(); loop { @@ -496,7 +496,7 @@ impl Server { event_tx, metric_src, graceful_exit_token.clone(), - request_idle_timeout_dur + request_read_timeout_dur ) } Err(e) => error!("socket error: {}", e) @@ -523,7 +523,7 @@ impl Server { event_tx, metric_src, graceful_exit_token.clone(), - request_idle_timeout_dur + request_read_timeout_dur ) } Err(e) => error!("socket error: {}", e) @@ -692,7 +692,7 @@ fn accept_stream( event_tx: Option>, metric_src: SharedMetricSource, graceful_exit_token: CancellationToken, - maybe_req_idle_timeout_dur: Option, + maybe_req_read_timeout_dur: Option, ) where I: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -700,15 +700,10 @@ fn accept_stream( tokio::task::spawn({ async move { let (service, cancel) = WorkerService::new(metric_src.clone(), req_tx); - let (io, maybe_timeout_tx) = if let Some(timeout_dur) = maybe_req_idle_timeout_dur { - let (timeout_tx, timeout_rx) = mpsc::unbounded_channel(); - - ( - crate::timeout::Stream::with_timeout(io, timeout_dur, timeout_rx), - Some(timeout_tx), - ) + let (io, maybe_timeout_tx) = if let Some(timeout_dur) = maybe_req_read_timeout_dur { + crate::timeout::Stream::with_timeout(io, timeout_dur) } else { - (crate::timeout::Stream::with_bypass(io), None) + crate::timeout::Stream::with_bypass(io) }; let _guard = cancel.drop_guard(); diff --git a/crates/base/src/timeout.rs b/crates/base/src/timeout.rs index 57134819..4fac82bf 100644 --- a/crates/base/src/timeout.rs +++ b/crates/base/src/timeout.rs @@ -11,7 +11,7 @@ use futures_util::Future; use pin_project::pin_project; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc::{UnboundedReceiver, UnboundedSender}, + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, time::{sleep, Instant, Sleep}, }; @@ -26,7 +26,7 @@ enum StreamKind { duration: Duration, waiting: bool, finished: bool, - state: UnboundedReceiver, + rx: UnboundedReceiver, }, Bypass, @@ -42,21 +42,29 @@ impl Stream { Self { inner, kind } } - pub(super) fn with_timeout(inner: S, duration: Duration, rx: UnboundedReceiver) -> Self { - Self::new( - inner, - StreamKind::UseTimeout { - sleep: Box::pin(sleep(duration)), - duration, - waiting: false, - finished: false, - state: rx, - }, + pub(super) fn with_timeout( + inner: S, + duration: Duration, + ) -> (Self, Option>) { + let (tx, rx) = mpsc::unbounded_channel(); + + ( + Self::new( + inner, + StreamKind::UseTimeout { + sleep: Box::pin(sleep(duration)), + duration, + waiting: false, + finished: false, + rx, + }, + ), + Some(tx), ) } - pub(super) fn with_bypass(inner: S) -> Self { - Self::new(inner, StreamKind::Bypass) + pub(super) fn with_bypass(inner: S) -> (Self, Option>) { + (Self::new(inner, StreamKind::Bypass), None) } } @@ -72,10 +80,10 @@ impl AsyncRead for Stream { duration, waiting, finished, - state, + rx, } => { if !*finished { - match Pin::new(state).poll_recv(cx) { + match Pin::new(rx).poll_recv(cx) { Poll::Ready(Some(State::Reset)) => { *waiting = false; diff --git a/crates/base/tests/integration_tests.rs b/crates/base/tests/integration_tests.rs index 3b45c913..e24f82fb 100644 --- a/crates/base/tests/integration_tests.rs +++ b/crates/base/tests/integration_tests.rs @@ -1428,7 +1428,7 @@ async fn oak_with_jsr_specifier() { ); } -async fn test_slowloris(request_idle_timeout_ms: u64, maybe_tls: Option, test_fn: F) +async fn test_slowloris(request_read_timeout_ms: u64, maybe_tls: Option, test_fn: F) where F: (FnOnce(Box) -> R) + Send + 'static, R: Future + Send, @@ -1445,7 +1445,7 @@ where None, None, ServerFlags { - request_idle_timeout_ms: Some(request_idle_timeout_ms), + request_read_timeout_ms: Some(request_read_timeout_ms), ..Default::default() }, health_tx, diff --git a/crates/cli/src/flags.rs b/crates/cli/src/flags.rs index 81581e55..298bad2b 100644 --- a/crates/cli/src/flags.rs +++ b/crates/cli/src/flags.rs @@ -141,8 +141,8 @@ fn get_start_command() -> Command { .value_parser(value_parser!(u64)), ) .arg( - arg!(--"request-idle-timeout" ) - .help("Maximum time that can be waited from when the connection is accepted until the request body is fully read") + arg!(--"request-read-timeout" ) + .help("Maximum time in milliseconds that can be waited from when the connection is accepted until the request body is fully read") .value_parser(value_parser!(u64)), ) .arg( diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index e31ed849..e4468757 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -122,8 +122,8 @@ fn main() -> Result<(), anyhow::Error> { sub_matches.get_one::("max-parallelism").cloned(); let maybe_request_wait_timeout = sub_matches.get_one::("request-wait-timeout").cloned(); - let maybe_request_idle_timeout = - sub_matches.get_one::("request-idle-timeout").cloned(); + let maybe_request_read_timeout = + sub_matches.get_one::("request-read-timeout").cloned(); let static_patterns = if let Some(val_ref) = sub_matches.get_many::("static") { val_ref.map(|s| s.as_str()).collect::>() @@ -196,7 +196,7 @@ fn main() -> Result<(), anyhow::Error> { tcp_nodelay, graceful_exit_deadline_sec, graceful_exit_keepalive_deadline_ms, - request_idle_timeout_ms: maybe_request_idle_timeout, + request_read_timeout_ms: maybe_request_read_timeout, }, None, WorkerEntrypoints {