From 9d1941133435724ec881cfbe0ca4e2c90d9deb80 Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Tue, 26 May 2026 23:47:21 -0700 Subject: [PATCH] Centralize timeout budget handling --- AGENTS.md | 2 + src/app.rs | 3 +- src/config/mod.rs | 2 +- src/dns/inspect.rs | 37 ++----- src/duration.rs | 216 ++++++++++++++++++++++++++++++++++++++++- src/grpc/reflection.rs | 5 +- src/http/client.rs | 68 ++----------- src/http/mod.rs | 74 +------------- src/tls/inspect.rs | 30 +----- src/update.rs | 2 +- src/websocket/mod.rs | 107 +++----------------- 11 files changed, 261 insertions(+), 285 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 4ad715f..8f35fc7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -101,7 +101,9 @@ metadata/update/DNS/TLS inspection modes, and executes requests via `src/http`. - Rust error rendering uses rich diagnostics for common CLI/config errors, styling labels, flags/options, invalid values, file paths, and config line context while preserving plain-text `Display` output. - Rust `-vvv` output prints config, DNS, TCP, and TTFB debug metadata through the central printer, including color policy and the blank response-header separator before formatted bodies. - Rust `--timing` enables DNS pre-resolution timing and wraps reqwest's connector service so the waterfall includes DNS, TCP, TTFB, and Body phases. reqwest does not currently expose a separate TLS handshake duration, so Rust reports the combined TCP/TLS connector phase as TCP timing. +- The top-level app future in `src/app.rs` is heap-pinned before the shutdown-signal `tokio::select!`; do not move it back to `tokio::pin!` on the main stack because the combined async request/WebSocket/inspection state can overflow Windows' smaller main-thread stack even for metadata commands. - Rust response body paging is controlled by `--pager auto|on|off` or `pager = ...`; `auto` routes terminal stdout through `less -FIRX`, `on` forces the pager, and `off` disables it. Image responses and output-file writes bypass the pager. +- Timeout duration parsing, Go-style duration formatting, elapsed request budgets, connect/DNS timeout caps, and shared `request timed out after ...` errors live in `src/duration.rs`; HTTP, WebSocket, DNS inspection, and TLS inspection paths should reuse `TimeoutBudget` instead of recomputing remaining time locally. - Custom/pre-resolved DNS observes timeout budgets before the reqwest client is built: `--connect-timeout` bounds DNS resolution when set, otherwise DNS uses the remaining `--timeout` budget, and DoH lookup clients receive the same budget. - Custom/pre-resolved DNS is scoped to the request URL; manual redirects that change scheme, host, or port rebuild the reqwest client and resolve the redirect target so `--dns-server`, `-vvv`, and `--timing` stay aligned with the actual target. - Custom/pre-resolved DNS runs A and AAAA lookups concurrently for both UDP and DoH, preserving any successful records when the other family fails. diff --git a/src/app.rs b/src/app.rs index ef4e725..4ac314a 100644 --- a/src/app.rs +++ b/src/app.rs @@ -27,8 +27,7 @@ pub async fn main_entry() -> i32 { }; let signal_color = cli.color.clone(); - let run = run(cli); - tokio::pin!(run); + let mut run = Box::pin(run(cli)); tokio::select! { result = &mut run => match result { Ok(code) => code, diff --git a/src/config/mod.rs b/src/config/mod.rs index 827755e..38308f4 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -642,7 +642,7 @@ fn parse_duration_seconds( let seconds = value .parse::() .map_err(|_| value_error(path, line_num, option, value, usage))?; - if !seconds.is_finite() || !(0.0..=crate::http::MAX_DURATION_SECONDS).contains(&seconds) { + if !seconds.is_finite() || !(0.0..=crate::duration::MAX_DURATION_SECONDS).contains(&seconds) { return Err(value_error(path, line_num, option, value, usage)); } Ok(seconds) diff --git a/src/dns/inspect.rs b/src/dns/inspect.rs index caaf362..5b69dd3 100644 --- a/src/dns/inspect.rs +++ b/src/dns/inspect.rs @@ -11,6 +11,7 @@ use crate::cli::Cli; use crate::core::{self, Printer, Sequence}; use crate::dns::util::{dns_query_id, udp_dns_timeout}; use crate::dns::wire; +use crate::duration::{TimeoutBudget, duration_from_seconds}; use crate::error::{FetchError, write_error_with_color, write_warning_with_color}; const DNS_TYPE_A: u16 = wire::TYPE_A; @@ -114,25 +115,17 @@ pub async fn execute(cli: &Cli) -> Result { let timeout = cli .timeout - .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) + .map(|seconds| duration_from_seconds("timeout", seconds)) .transpose()?; let use_color = core::color_enabled(cli.color.as_deref(), std::io::stderr().is_terminal()); - let inspected = if let Some(timeout) = timeout { - match tokio::time::timeout( + let inspected = TimeoutBudget::new(timeout) + .run(inspect_with_color( + &url, + cli.dns_server.as_deref(), + use_color, timeout, - inspect_with_color(&url, cli.dns_server.as_deref(), use_color, Some(timeout)), - ) - .await - { - Ok(result) => result, - Err(_) => Err(FetchError::Message(format!( - "request timed out after {}", - format_timeout(timeout) - ))), - } - } else { - inspect_with_color(&url, cli.dns_server.as_deref(), use_color, None).await - }; + )) + .await; match inspected { Ok(output) => { @@ -652,18 +645,6 @@ fn format_duration(duration: Duration) -> String { format_go_duration_nanos(rounded) } -fn format_timeout(timeout: Duration) -> String { - let seconds = timeout.as_secs_f64(); - if timeout.subsec_nanos() == 0 { - format!("{}s", timeout.as_secs()) - } else { - format!("{seconds:.3}s") - .trim_end_matches('0') - .trim_end_matches('.') - .to_string() - } -} - fn format_go_duration_nanos(nanos: u128) -> String { if nanos < 1_000 { return format!("{nanos}ns"); diff --git a/src/duration.rs b/src/duration.rs index fbdd3a7..3e8c13e 100644 --- a/src/duration.rs +++ b/src/duration.rs @@ -1,4 +1,7 @@ -use std::time::Duration; +use std::future::Future; +use std::time::{Duration, Instant}; + +use crate::error::FetchError; const NANOS_PER_MICRO: u128 = 1_000; const NANOS_PER_MILLI: u128 = 1_000_000; @@ -7,6 +10,108 @@ const NANOS_PER_MINUTE: u128 = 60 * NANOS_PER_SECOND; const NANOS_PER_HOUR: u128 = 60 * NANOS_PER_MINUTE; const NANOS_PER_DAY: u128 = 24 * NANOS_PER_HOUR; +pub(crate) const MAX_DURATION_SECONDS: f64 = i64::MAX as f64 / 1_000_000_000_f64; + +#[derive(Clone, Copy, Debug)] +pub(crate) struct TimeoutBudget { + timeout: Option, + started_at: Instant, +} + +impl TimeoutBudget { + pub(crate) fn new(timeout: Option) -> Self { + Self::started_at(timeout, Instant::now()) + } + + pub(crate) fn started_at(timeout: Option, started_at: Instant) -> Self { + Self { + timeout, + started_at, + } + } + + pub(crate) fn for_connect( + connect_timeout: Option, + request_timeout: Option, + request_started_at: Instant, + ) -> Result { + let request_remaining = remaining_timeout(request_timeout, request_started_at)?; + Ok(Self::new(min_timeout(connect_timeout, request_remaining))) + } + + pub(crate) fn timeout(self) -> Option { + self.timeout + } + + pub(crate) fn remaining(self) -> Result, FetchError> { + remaining_timeout(self.timeout, self.started_at) + } + + pub(crate) fn timeout_error(self) -> FetchError { + request_timeout_error(self.timeout.expect("timeout checked by caller")) + } + + pub(crate) async fn run( + self, + future: impl Future>, + ) -> Result { + let Some(remaining) = self.remaining()? else { + return future.await; + }; + let started_at = Instant::now(); + match tokio::time::timeout(remaining, future).await { + Ok(Err(err)) if started_at.elapsed() >= remaining && is_timeout_error(&err) => { + Err(self.timeout_error()) + } + Ok(result) => result, + Err(_) => Err(self.timeout_error()), + } + } +} + +pub(crate) fn duration_from_seconds(flag: &str, seconds: f64) -> Result { + if !seconds.is_finite() || !(0.0..=MAX_DURATION_SECONDS).contains(&seconds) { + return Err(format!("{flag} must be a non-negative number").into()); + } + Ok(Duration::from_secs_f64(seconds)) +} + +pub(crate) fn remaining_timeout( + timeout: Option, + started_at: Instant, +) -> Result, FetchError> { + let Some(timeout) = timeout else { + return Ok(None); + }; + let elapsed = started_at.elapsed(); + if elapsed >= timeout { + return Err(request_timeout_error(timeout)); + } + Ok(Some(timeout - elapsed)) +} + +pub(crate) fn request_timeout_error(timeout: Duration) -> FetchError { + FetchError::Runtime(request_timeout_message(timeout)) +} + +pub(crate) fn request_timeout_message(timeout: Duration) -> String { + format!("request timed out after {}", format_go_duration(timeout)) +} + +pub(crate) fn format_go_duration(duration: Duration) -> String { + let nanos = duration.as_nanos(); + if nanos < 1_000 { + return format!("{nanos}ns"); + } + if nanos < 1_000_000 { + return format_duration_unit(nanos, 1_000, "us"); + } + if nanos < 1_000_000_000 { + return format_duration_unit(nanos, 1_000_000, "ms"); + } + format_duration_unit(nanos, 1_000_000_000, "s") +} + pub(crate) fn parse_duration_interval(value: &str) -> Option { let mut rest = value.trim(); if rest.is_empty() || rest.starts_with('-') { @@ -88,6 +193,46 @@ fn duration_unit(value: &str) -> Option<(&'static str, u128)> { None } +fn min_timeout(left: Option, right: Option) -> Option { + match (left, right) { + (Some(left), Some(right)) => Some(left.min(right)), + (Some(left), None) => Some(left), + (None, right) => right, + } +} + +fn is_timeout_error(err: &FetchError) -> bool { + match err { + FetchError::Message(message) | FetchError::Runtime(message) => { + message.contains("timed out") + } + FetchError::Reqwest(err) => err.is_timeout(), + _ => false, + } +} + +fn format_duration_unit(nanos: u128, unit_nanos: u128, suffix: &str) -> String { + let whole = nanos / unit_nanos; + let remainder = nanos % unit_nanos; + if remainder == 0 { + return format!("{whole}{suffix}"); + } + + let digits = match suffix { + "us" => 3_u32, + "ms" => 6_u32, + _ => 9_u32, + }; + let scale = 10_u128.pow(digits); + let fraction_value = remainder * scale / unit_nanos; + let fraction = format!( + "{fraction_value:0width$}", + width = usize::try_from(digits).expect("small duration precision") + ); + let fraction = fraction.trim_end_matches('0'); + format!("{whole}.{fraction}{suffix}") +} + #[cfg(test)] mod tests { use super::*; @@ -137,4 +282,73 @@ mod tests { assert_eq!(parse_duration_interval("1sec"), None); assert_eq!(parse_duration_interval("garbage"), None); } + + #[test] + fn duration_from_seconds_rejects_values_outside_supported_range() { + assert_eq!( + duration_from_seconds("timeout", 1.5).unwrap(), + Duration::from_millis(1500) + ); + + for seconds in [-1.0, f64::NAN, f64::INFINITY, 1e100] { + let err = duration_from_seconds("timeout", seconds).unwrap_err(); + assert_eq!(err.to_string(), "timeout must be a non-negative number"); + } + } + + #[test] + fn timeout_budget_for_connect_uses_shortest_available_timeout() { + let budget = TimeoutBudget::for_connect( + Some(Duration::from_secs(5)), + Some(Duration::from_millis(250)), + Instant::now() - Duration::from_millis(100), + ) + .unwrap(); + let remaining = budget.remaining().unwrap().unwrap(); + + assert!(remaining <= Duration::from_millis(150)); + assert!(remaining > Duration::from_millis(100)); + + let budget = TimeoutBudget::for_connect( + Some(Duration::from_millis(250)), + Some(Duration::from_secs(5)), + Instant::now(), + ) + .unwrap(); + assert!(budget.timeout().unwrap() <= Duration::from_millis(250)); + } + + #[test] + fn remaining_timeout_reports_expired_request_budget() { + let err = remaining_timeout( + Some(Duration::from_millis(10)), + Instant::now() - Duration::from_millis(20), + ) + .unwrap_err(); + + assert_eq!(err.to_string(), "request timed out after 10ms"); + } + + #[test] + fn request_timeout_message_uses_go_duration_units() { + assert_eq!( + request_timeout_message(Duration::from_nanos(100)), + "request timed out after 100ns" + ); + assert_eq!( + request_timeout_message(Duration::from_millis(50)), + "request timed out after 50ms" + ); + } + + #[test] + fn format_go_duration_matches_common_go_units() { + assert_eq!(format_go_duration(Duration::from_nanos(100)), "100ns"); + assert_eq!(format_go_duration(Duration::from_nanos(1_500)), "1.5us"); + assert_eq!(format_go_duration(Duration::from_nanos(1_500_000)), "1.5ms"); + assert_eq!( + format_go_duration(Duration::from_nanos(1_500_000_000)), + "1.5s" + ); + } } diff --git a/src/grpc/reflection.rs b/src/grpc/reflection.rs index 933adfa..c4af6b7 100644 --- a/src/grpc/reflection.rs +++ b/src/grpc/reflection.rs @@ -7,6 +7,7 @@ use url::Url; use crate::cli::Cli; use crate::core; +use crate::duration::duration_from_seconds; use crate::error::FetchError; use crate::grpc::encoding::{self, MessageEncoding}; use crate::grpc::framing; @@ -33,11 +34,11 @@ pub async fn execute_discovery(cli: &Cli) -> Result { let request_start = Instant::now(); let request_timeout = cli .timeout - .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) + .map(|seconds| duration_from_seconds("timeout", seconds)) .transpose()?; let connect_timeout = cli .connect_timeout - .map(|seconds| crate::http::duration_from_seconds("connect-timeout", seconds)) + .map(|seconds| duration_from_seconds("connect-timeout", seconds)) .transpose()?; crate::tls::install_default_crypto_provider(); let connect_timing = crate::http::client::ConnectionTiming::default(); diff --git a/src/http/client.rs b/src/http/client.rs index 953d25f..29947a9 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -13,6 +13,7 @@ use url::Url; use crate::cli::{Cli, HttpVersion}; use crate::dns::custom; +use crate::duration::TimeoutBudget; use crate::error::FetchError; use crate::timing::DnsTiming; @@ -58,11 +59,12 @@ pub(crate) async fn build_client_for_url( context: &ClientBuildContext<'_>, ) -> Result { let http_version = context.mode.http_version(); - let dns_timeout = dns_resolution_timeout( - context.request_timeout, + let dns_timeout = TimeoutBudget::for_connect( context.connect_timeout, + context.request_timeout, context.request_start, - )?; + )? + .timeout(); let dns_resolution = resolve_dns_for_client(cli, url, http_version, dns_timeout).await?; let mut builder = Client::builder() .use_rustls_tls() @@ -84,7 +86,7 @@ pub(crate) async fn build_client_for_url( builder = builder.danger_accept_invalid_certs(true); } if let Some(timeout) = - remaining_request_timeout(context.request_timeout, context.request_start)? + TimeoutBudget::started_at(context.request_timeout, context.request_start).remaining()? { builder = builder.timeout(timeout); } @@ -121,36 +123,6 @@ pub(crate) fn configure_unix_socket( } } -fn dns_resolution_timeout( - request_timeout: Option, - connect_timeout: Option, - start: Instant, -) -> Result, FetchError> { - let remaining = remaining_request_timeout(request_timeout, start)?; - Ok(match (connect_timeout, remaining) { - (Some(connect), Some(remaining)) => Some(connect.min(remaining)), - (Some(connect), None) => Some(connect), - (None, remaining) => remaining, - }) -} - -fn remaining_request_timeout( - timeout: Option, - start: Instant, -) -> Result, FetchError> { - let Some(timeout) = timeout else { - return Ok(None); - }; - let elapsed = start.elapsed(); - if elapsed >= timeout { - return Err(FetchError::Runtime(format!( - "request timed out after {}", - crate::http::format_go_duration(timeout) - ))); - } - Ok(Some(timeout - elapsed)) -} - async fn resolve_dns_for_client( cli: &Cli, url: &Url, @@ -158,33 +130,7 @@ async fn resolve_dns_for_client( timeout: Option, ) -> Result, FetchError> { let resolve = resolve_dns_for_client_inner(cli, url, http_version, timeout); - if let Some(timeout) = timeout { - let start = Instant::now(); - match tokio::time::timeout(timeout, resolve).await { - Ok(Err(err)) if start.elapsed() >= timeout && is_timeout_error(&err) => { - Err(request_timeout_error(timeout)) - } - Ok(result) => result, - Err(_) => Err(request_timeout_error(timeout)), - } - } else { - resolve.await - } -} - -fn request_timeout_error(timeout: Duration) -> FetchError { - FetchError::Runtime(format!( - "request timed out after {}", - crate::http::format_go_duration(timeout) - )) -} - -fn is_timeout_error(err: &FetchError) -> bool { - match err { - FetchError::Runtime(message) => message.contains("timed out"), - FetchError::Reqwest(err) => err.is_timeout(), - _ => false, - } + TimeoutBudget::new(timeout).run(resolve).await } async fn resolve_dns_for_client_inner( diff --git a/src/http/mod.rs b/src/http/mod.rs index 6af702a..8a382c8 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -34,6 +34,7 @@ use crate::auth::aws_sigv4; use crate::auth::digest; use crate::cli::{Cli, CompressionMode, HttpVersion}; use crate::core; +use crate::duration::{duration_from_seconds, request_timeout_message}; use crate::error::{FetchError, write_error_with_color, write_warning_with_color}; use crate::format::content_type::{self, ContentType}; use crate::format::css; @@ -111,7 +112,6 @@ const MAX_BUFFERED_RESPONSE_BYTES: usize = 16 * 1024 * 1024; const MAX_DISCARDED_RESPONSE_BYTES: usize = 1024 * 1024; const BINARY_RESPONSE_WARNING: &str = "the response body appears to be binary\n\nTo output to the terminal anyway, use '--output -'"; -pub(crate) const MAX_DURATION_SECONDS: f64 = i64::MAX as f64 / 1_000_000_000_f64; pub async fn execute(cli: &Cli) -> Result { let http_version = @@ -2695,10 +2695,7 @@ fn timeout_error_message(cli: &Cli, err: &reqwest::Error) -> Option { } let seconds = cli.timeout?; let duration = duration_from_seconds("timeout", seconds).ok()?; - Some(format!( - "request timed out after {}", - format_go_duration(duration) - )) + Some(request_timeout_message(duration)) } fn reqwest_request_error_message(err: &reqwest::Error) -> String { @@ -2786,42 +2783,6 @@ fn is_certificate_validation_message(message: &str) -> bool { || lower.contains("hostname"))) } -pub(crate) fn format_go_duration(duration: Duration) -> String { - let nanos = duration.as_nanos(); - if nanos < 1_000 { - return format!("{nanos}ns"); - } - if nanos < 1_000_000 { - return format_duration_unit(nanos, 1_000, "us"); - } - if nanos < 1_000_000_000 { - return format_duration_unit(nanos, 1_000_000, "ms"); - } - format_duration_unit(nanos, 1_000_000_000, "s") -} - -fn format_duration_unit(nanos: u128, unit_nanos: u128, suffix: &str) -> String { - let whole = nanos / unit_nanos; - let remainder = nanos % unit_nanos; - if remainder == 0 { - return format!("{whole}{suffix}"); - } - - let digits = match suffix { - "us" => 3_u32, - "ms" => 6_u32, - _ => 9_u32, - }; - let scale = 10_u128.pow(digits); - let fraction_value = remainder * scale / unit_nanos; - let fraction = format!( - "{fraction_value:0width$}", - width = usize::try_from(digits).expect("small duration precision") - ); - let fraction = fraction.trim_end_matches('0'); - format!("{whole}.{fraction}{suffix}") -} - fn retry_reason(status: StatusCode) -> String { format!( "{} {}", @@ -2893,13 +2854,6 @@ pub(crate) fn request_target(url: &Url) -> String { target } -pub(crate) fn duration_from_seconds(flag: &str, seconds: f64) -> Result { - if !seconds.is_finite() || !(0.0..=MAX_DURATION_SECONDS).contains(&seconds) { - return Err(format!("{flag} must be a non-negative number").into()); - } - Ok(Duration::from_secs_f64(seconds)) -} - pub(crate) fn total_attempts_for_retry(retry_count: usize) -> Result { retry_count.checked_add(1).ok_or_else(|| { FetchError::invalid_value( @@ -4457,19 +4411,6 @@ mod tests { assert_eq!(format_delay(Duration::from_secs(1)), "1.0s"); } - #[test] - fn duration_from_seconds_rejects_values_outside_supported_range() { - assert_eq!( - duration_from_seconds("timeout", 1.5).unwrap(), - Duration::from_millis(1500) - ); - - for seconds in [-1.0, f64::NAN, f64::INFINITY, 1e100] { - let err = duration_from_seconds("timeout", seconds).unwrap_err(); - assert_eq!(err.to_string(), "timeout must be a non-negative number"); - } - } - #[test] fn total_attempts_for_retry_rejects_overflow() { assert_eq!(total_attempts_for_retry(0).unwrap(), 1); @@ -4591,17 +4532,6 @@ mod tests { } } - #[test] - fn format_go_duration_matches_common_go_units() { - assert_eq!(format_go_duration(Duration::from_nanos(100)), "100ns"); - assert_eq!(format_go_duration(Duration::from_nanos(1_500)), "1.5us"); - assert_eq!(format_go_duration(Duration::from_nanos(1_500_000)), "1.5ms"); - assert_eq!( - format_go_duration(Duration::from_nanos(1_500_000_000)), - "1.5s" - ); - } - #[test] fn http2_plain_http_rejects_like_go_transport() { let url = Url::parse("http://127.0.0.1:3000/").unwrap(); diff --git a/src/tls/inspect.rs b/src/tls/inspect.rs index c491a7b..a389117 100644 --- a/src/tls/inspect.rs +++ b/src/tls/inspect.rs @@ -19,6 +19,7 @@ use url::Url; use crate::cli::{Cli, HttpVersion}; use crate::core::{self, Printer, Sequence}; +use crate::duration::{TimeoutBudget, duration_from_seconds}; use crate::error::{FetchError, write_warning_with_color}; pub async fn execute(cli: &Cli) -> Result { @@ -36,20 +37,11 @@ pub async fn execute(cli: &Cli) -> Result { let timeout = cli .timeout - .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) + .map(|seconds| duration_from_seconds("timeout", seconds)) .transpose()?; - let inspection = if let Some(timeout) = timeout { - tokio::time::timeout(timeout, inspect(cli, &url, http_version, Some(timeout))) - .await - .map_err(|_| { - FetchError::Message(format!( - "request timed out after {}", - format_timeout(timeout) - )) - })?? - } else { - inspect(cli, &url, http_version, None).await? - }; + let inspection = TimeoutBudget::new(timeout) + .run(inspect(cli, &url, http_version, timeout)) + .await?; if !cli.silent { eprint!( @@ -799,18 +791,6 @@ fn cipher_suite_label(cipher: SupportedCipherSuite) -> String { format!("{:?}", cipher.suite()) } -fn format_timeout(timeout: Duration) -> String { - let seconds = timeout.as_secs_f64(); - if timeout.subsec_nanos() == 0 { - format!("{}s", timeout.as_secs()) - } else { - format!("{seconds:.3}s") - .trim_end_matches('0') - .trim_end_matches('.') - .to_string() - } -} - #[derive(Clone)] struct Inspection { version: Option, diff --git a/src/update.rs b/src/update.rs index dea5c27..e08cc6f 100644 --- a/src/update.rs +++ b/src/update.rs @@ -51,7 +51,7 @@ struct ReleaseArtifact<'a> { pub async fn execute(cli: &Cli) -> Result { let timeout = cli .timeout - .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) + .map(|seconds| crate::duration::duration_from_seconds("timeout", seconds)) .transpose()?; let mut builder = reqwest::Client::builder().use_rustls_tls(); if let Some(timeout) = timeout { diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 1784faf..5bb8749 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -21,6 +21,7 @@ use url::Url; use crate::auth::aws_sigv4; use crate::cli::Cli; use crate::core; +use crate::duration::{TimeoutBudget, duration_from_seconds}; use crate::error::{FetchError, write_warning_with_color}; use crate::format::json; @@ -66,21 +67,14 @@ pub async fn execute(cli: &Cli) -> Result { let request_start = Instant::now(); let request_timeout = websocket_request_timeout(cli)?; + let request_budget = TimeoutBudget::started_at(request_timeout, request_start); let connect_timeout = websocket_connect_timeout(cli, request_timeout, request_start)?; - let connect = connect_websocket(cli, &url, request, connector, connect_timeout); - let (mut stream, response) = if let Some(timeout) = request_timeout { - tokio::time::timeout(timeout, connect) + let connect = async { + connect_websocket(cli, &url, request, connector, connect_timeout) .await - .map_err(|_| { - FetchError::Message(format!( - "request timed out after {}", - crate::timing::format_timing_duration(timeout) - )) - })? - .map_err(websocket_error)? - } else { - connect.await.map_err(websocket_error)? + .map_err(websocket_error) }; + let (mut stream, response) = request_budget.run(connect).await?; if cli.verbose > 0 && !cli.silent { let status = response.status(); @@ -611,39 +605,9 @@ fn url_authority(url: &Url) -> Result { }) } -#[derive(Clone, Copy)] -struct TimeoutBudget { - timeout: Option, - started_at: Instant, -} - -impl TimeoutBudget { - fn new(timeout: Option) -> Self { - Self { - timeout, - started_at: Instant::now(), - } - } - - fn remaining(self) -> Result, FetchError> { - let Some(timeout) = self.timeout else { - return Ok(None); - }; - let elapsed = self.started_at.elapsed(); - if elapsed >= timeout { - return Err(websocket_timeout_error(timeout)); - } - Ok(Some(timeout - elapsed)) - } - - fn timeout_error(self) -> FetchError { - websocket_timeout_error(self.timeout.expect("timeout checked by caller")) - } -} - fn websocket_request_timeout(cli: &Cli) -> Result, FetchError> { cli.timeout - .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) + .map(|seconds| duration_from_seconds("timeout", seconds)) .transpose() } @@ -654,46 +618,16 @@ fn websocket_connect_timeout( ) -> Result { let connect_timeout = cli .connect_timeout - .map(|seconds| crate::http::duration_from_seconds("connect-timeout", seconds)) + .map(|seconds| duration_from_seconds("connect-timeout", seconds)) .transpose()?; - let request_remaining = remaining_timeout(request_timeout, request_start)?; - let timeout = match (connect_timeout, request_remaining) { - (Some(connect), Some(remaining)) => Some(connect.min(remaining)), - (Some(connect), None) => Some(connect), - (None, remaining) => remaining, - }; - Ok(TimeoutBudget::new(timeout)) -} - -fn remaining_timeout( - timeout: Option, - started_at: Instant, -) -> Result, FetchError> { - let Some(timeout) = timeout else { - return Ok(None); - }; - let elapsed = started_at.elapsed(); - if elapsed >= timeout { - return Err(websocket_timeout_error(timeout)); - } - Ok(Some(timeout - elapsed)) + TimeoutBudget::for_connect(connect_timeout, request_timeout, request_start) } async fn timeout_fetch( timeout: TimeoutBudget, future: impl Future>, ) -> Result { - let Some(remaining) = timeout.remaining()? else { - return future.await; - }; - let started_at = Instant::now(); - match tokio::time::timeout(remaining, future).await { - Ok(Err(err)) if started_at.elapsed() >= remaining && is_timeout_error(&err) => { - Err(timeout.timeout_error()) - } - Ok(result) => result, - Err(_) => Err(timeout.timeout_error()), - } + timeout.run(future).await } async fn timeout_ws( @@ -708,21 +642,6 @@ async fn timeout_ws( .map_err(|_| websocket_io_error(timeout.timeout_error()))? } -fn websocket_timeout_error(timeout: Duration) -> FetchError { - FetchError::Runtime(format!( - "request timed out after {}", - crate::http::format_go_duration(timeout) - )) -} - -fn is_timeout_error(err: &FetchError) -> bool { - match err { - FetchError::Runtime(message) => message.contains("timed out"), - FetchError::Reqwest(err) => err.is_timeout(), - _ => false, - } -} - fn websocket_io_error(err: FetchError) -> WsError { WsError::Io(io::Error::other(err.to_string())) } @@ -929,7 +848,11 @@ fn write_binary_indicator(cli: &Cli, len: usize) { } fn websocket_error(err: WsError) -> FetchError { - FetchError::Message(err.to_string()) + let message = err.to_string(); + if let Some(start) = message.find("request timed out after ") { + return FetchError::Runtime(message[start..].to_string()); + } + FetchError::Message(message) } #[cfg(test)]