Skip to content

Commit

Permalink
stamp: polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
nyannyacha committed May 10, 2024
1 parent 1f5a40a commit 4f606e9
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 37 deletions.
23 changes: 9 additions & 14 deletions crates/base/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ pub struct ServerFlags {
pub tcp_nodelay: bool,
pub graceful_exit_deadline_sec: u64,
pub graceful_exit_keepalive_deadline_ms: Option<u64>,
pub request_idle_timeout_ms: Option<u64>,
pub request_read_timeout_ms: Option<u64>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -468,13 +468,13 @@ impl Server {

let ServerFlags {
tcp_nodelay,
request_idle_timeout_ms,
request_read_timeout_ms,
mut graceful_exit_deadline_sec,
mut graceful_exit_keepalive_deadline_ms,
..
} = flags;

let request_idle_timeout_dur = request_idle_timeout_ms.map(Duration::from_millis);
let request_read_timeout_dur = request_read_timeout_ms.map(Duration::from_millis);
let mut terminate_signal_fut = get_termination_signal();

loop {
Expand All @@ -496,7 +496,7 @@ impl Server {
event_tx,
metric_src,
graceful_exit_token.clone(),
request_idle_timeout_dur
request_read_timeout_dur
)
}
Err(e) => error!("socket error: {}", e)
Expand All @@ -523,7 +523,7 @@ impl Server {
event_tx,
metric_src,
graceful_exit_token.clone(),
request_idle_timeout_dur
request_read_timeout_dur
)
}
Err(e) => error!("socket error: {}", e)
Expand Down Expand Up @@ -692,23 +692,18 @@ fn accept_stream<I>(
event_tx: Option<UnboundedSender<ServerEvent>>,
metric_src: SharedMetricSource,
graceful_exit_token: CancellationToken,
maybe_req_idle_timeout_dur: Option<Duration>,
maybe_req_read_timeout_dur: Option<Duration>,
) where
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
metric_src.incl_active_io();
tokio::task::spawn({
async move {
let (service, cancel) = WorkerService::new(metric_src.clone(), req_tx);
let (io, maybe_timeout_tx) = if let Some(timeout_dur) = maybe_req_idle_timeout_dur {
let (timeout_tx, timeout_rx) = mpsc::unbounded_channel();

(
crate::timeout::Stream::with_timeout(io, timeout_dur, timeout_rx),
Some(timeout_tx),
)
let (io, maybe_timeout_tx) = if let Some(timeout_dur) = maybe_req_read_timeout_dur {
crate::timeout::Stream::with_timeout(io, timeout_dur)
} else {
(crate::timeout::Stream::with_bypass(io), None)
crate::timeout::Stream::with_bypass(io)
};

let _guard = cancel.drop_guard();
Expand Down
40 changes: 24 additions & 16 deletions crates/base/src/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use futures_util::Future;
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
time::{sleep, Instant, Sleep},
};

Expand All @@ -26,7 +26,7 @@ enum StreamKind {
duration: Duration,
waiting: bool,
finished: bool,
state: UnboundedReceiver<State>,
rx: UnboundedReceiver<State>,
},

Bypass,
Expand All @@ -42,21 +42,29 @@ impl<S> Stream<S> {
Self { inner, kind }
}

pub(super) fn with_timeout(inner: S, duration: Duration, rx: UnboundedReceiver<State>) -> Self {
Self::new(
inner,
StreamKind::UseTimeout {
sleep: Box::pin(sleep(duration)),
duration,
waiting: false,
finished: false,
state: rx,
},
pub(super) fn with_timeout(
inner: S,
duration: Duration,
) -> (Self, Option<UnboundedSender<State>>) {
let (tx, rx) = mpsc::unbounded_channel();

(
Self::new(
inner,
StreamKind::UseTimeout {
sleep: Box::pin(sleep(duration)),
duration,
waiting: false,
finished: false,
rx,
},
),
Some(tx),
)
}

pub(super) fn with_bypass(inner: S) -> Self {
Self::new(inner, StreamKind::Bypass)
pub(super) fn with_bypass(inner: S) -> (Self, Option<UnboundedSender<State>>) {
(Self::new(inner, StreamKind::Bypass), None)
}
}

Expand All @@ -72,10 +80,10 @@ impl<S: AsyncRead + Unpin> AsyncRead for Stream<S> {
duration,
waiting,
finished,
state,
rx,
} => {
if !*finished {
match Pin::new(state).poll_recv(cx) {
match Pin::new(rx).poll_recv(cx) {
Poll::Ready(Some(State::Reset)) => {
*waiting = false;

Expand Down
4 changes: 2 additions & 2 deletions crates/base/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1428,7 +1428,7 @@ async fn oak_with_jsr_specifier() {
);
}

async fn test_slowloris<F, R>(request_idle_timeout_ms: u64, maybe_tls: Option<Tls>, test_fn: F)
async fn test_slowloris<F, R>(request_read_timeout_ms: u64, maybe_tls: Option<Tls>, test_fn: F)
where
F: (FnOnce(Box<dyn AsyncReadWrite>) -> R) + Send + 'static,
R: Future<Output = bool> + Send,
Expand All @@ -1445,7 +1445,7 @@ where
None,
None,
ServerFlags {
request_idle_timeout_ms: Some(request_idle_timeout_ms),
request_read_timeout_ms: Some(request_read_timeout_ms),
..Default::default()
},
health_tx,
Expand Down
4 changes: 2 additions & 2 deletions crates/cli/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ fn get_start_command() -> Command {
.value_parser(value_parser!(u64)),
)
.arg(
arg!(--"request-idle-timeout" <MILLISECONDS>)
.help("Maximum time that can be waited from when the connection is accepted until the request body is fully read")
arg!(--"request-read-timeout" <MILLISECONDS>)
.help("Maximum time in milliseconds that can be waited from when the connection is accepted until the request body is fully read")
.value_parser(value_parser!(u64)),
)
.arg(
Expand Down
6 changes: 3 additions & 3 deletions crates/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ fn main() -> Result<(), anyhow::Error> {
sub_matches.get_one::<usize>("max-parallelism").cloned();
let maybe_request_wait_timeout =
sub_matches.get_one::<u64>("request-wait-timeout").cloned();
let maybe_request_idle_timeout =
sub_matches.get_one::<u64>("request-idle-timeout").cloned();
let maybe_request_read_timeout =
sub_matches.get_one::<u64>("request-read-timeout").cloned();
let static_patterns =
if let Some(val_ref) = sub_matches.get_many::<String>("static") {
val_ref.map(|s| s.as_str()).collect::<Vec<&str>>()
Expand Down Expand Up @@ -196,7 +196,7 @@ fn main() -> Result<(), anyhow::Error> {
tcp_nodelay,
graceful_exit_deadline_sec,
graceful_exit_keepalive_deadline_ms,
request_idle_timeout_ms: maybe_request_idle_timeout,
request_read_timeout_ms: maybe_request_read_timeout,
},
None,
WorkerEntrypoints {
Expand Down

0 comments on commit 4f606e9

Please sign in to comment.