From 169ee81cef57713176ce0992ec02660a06f3346e Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Tue, 26 May 2026 23:08:02 -0700 Subject: [PATCH] Honor connect timeout for WebSocket dialing --- AGENTS.md | 2 +- docs/websocket.md | 6 + src/websocket/mod.rs | 361 ++++++++++++++++++++++++++++++++++--------- tests/integration.rs | 67 ++++++++ 4 files changed, 358 insertions(+), 78 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index fa84431..0f80411 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -95,7 +95,7 @@ metadata/update/DNS/TLS inspection modes, and executes requests via `src/http`. - `--tls` remains a compatibility alias for setting the minimum TLS version; prefer `--min-tls` in new docs/examples, and use `--max-tls` to cap negotiation or combine min/max for an exact TLS version. - Rust TLS version options accept only TLS 1.2 and TLS 1.3; legacy TLS 1.0/1.1 values are rejected consistently for CLI flags, config, WebSocket, and inspection paths. - WebSocket terminal sessions use the interactive prompt by default and can be controlled with `--ws-interactive auto|on|off`; output-file/clipboard/retry flags are rejected because the WebSocket path streams through the message loop instead of the normal response pipeline. -- `wss://` WebSocket handshakes build a rustls client config so `--ca-cert`, `--cert`/`--key`, `--insecure`, and TLS min/max settings apply; plain `ws://` rejects TLS flags. WebSocket requests use a custom dialer so `--dns-server` works for direct connections, and `--proxy` supports HTTP CONNECT plus SOCKS5/SOCKS5H tunnels before the WebSocket/TLS handshake. +- `wss://` WebSocket handshakes build a rustls client config so `--ca-cert`, `--cert`/`--key`, `--insecure`, and TLS min/max settings apply; plain `ws://` rejects TLS flags. WebSocket requests use a custom dialer so `--dns-server` works for direct connections, and `--proxy` supports HTTP CONNECT plus SOCKS5/SOCKS5H tunnels before the WebSocket/TLS handshake. `--connect-timeout` bounds WebSocket DNS, TCP, proxy negotiation, and TLS setup, capped by the remaining `--timeout` budget when both are set. - Metadata-only commands (`--help`, `--version`, `--buildinfo`) perform best-effort config parsing for presentation settings, but config errors and background auto-updates cannot block them. - Rust formatting code has a central `core::Printer`/`PrinterHandle` and ANSI `Sequence` abstraction; JSON/NDJSON write through the printer directly, other formatter/progress style helpers route escape emission through the shared sequence writer, and stderr metadata/inspection/error/warning renderers use the same printer for request/response headers and `--inspect-dns`/`--inspect-tls`. - Rust error rendering uses rich diagnostics for common CLI/config errors, styling labels, flags/options, invalid values, file paths, and config line context while preserving plain-text `Display` output. diff --git a/docs/websocket.md b/docs/websocket.md index 9a6da75..fdee397 100644 --- a/docs/websocket.md +++ b/docs/websocket.md @@ -98,6 +98,12 @@ The `--timeout` flag applies to the WebSocket handshake only. The connection sta fetch --timeout 5 ws://api.example.com/ws ``` +Use `--connect-timeout` to bound WebSocket connection setup phases such as custom DNS resolution, TCP connect, proxy CONNECT or SOCKS negotiation, and TLS handshakes. When both timeout flags are set, the connect timeout is capped by the remaining `--timeout` budget: + +```sh +fetch --connect-timeout 2 --timeout 10 wss://api.example.com/ws +``` + ## Limitations - WebSocket requires HTTP/1.1 for the upgrade handshake. Using `--http 3` with WebSocket is not supported. diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index f8bc2ab..51b8afb 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -1,8 +1,9 @@ +use std::future::Future; use std::io::{self, IsTerminal, Read}; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use base64::Engine; use futures_util::{SinkExt, StreamExt}; @@ -63,9 +64,11 @@ pub async fn execute(cli: &Cli) -> Result { print_request_metadata(cli, method, &url, Some(request.headers())); } - let connect = connect_websocket(cli, &url, request, connector); - let (mut stream, response) = if let Some(seconds) = cli.timeout { - let timeout = crate::http::duration_from_seconds("timeout", seconds)?; + let request_start = Instant::now(); + let request_timeout = websocket_request_timeout(cli)?; + let connect_timeout = websocket_connect_timeout(cli, request_timeout, request_start)?; + let connect = connect_websocket(cli, &url, request, connector, connect_timeout); + let (mut stream, response) = if let Some(timeout) = request_timeout { tokio::time::timeout(timeout, connect) .await .map_err(|_| { @@ -170,6 +173,7 @@ async fn connect_websocket( url: &Url, request: tokio_tungstenite::tungstenite::http::Request<()>, connector: Option, + timeout: TimeoutBudget, ) -> Result< ( tokio_tungstenite::WebSocketStream>, @@ -177,22 +181,32 @@ async fn connect_websocket( ), WsError, > { - let stream = dial_websocket(cli, url).await.map_err(websocket_io_error)?; - client_async_tls_with_config(request, stream, None, connector).await + let stream = dial_websocket(cli, url, timeout) + .await + .map_err(websocket_io_error)?; + timeout_ws( + timeout, + client_async_tls_with_config(request, stream, None, connector), + ) + .await } -async fn dial_websocket(cli: &Cli, url: &Url) -> Result { +async fn dial_websocket( + cli: &Cli, + url: &Url, + timeout: TimeoutBudget, +) -> Result { if let Some(proxy) = cli.proxy.as_deref() { - return dial_proxy(proxy, url, websocket_timeout(cli)?).await; + return dial_proxy(proxy, url, timeout).await; } - let stream = connect_tcp(url, cli.dns_server.as_deref(), websocket_timeout(cli)?).await?; + let stream = connect_tcp(url, cli.dns_server.as_deref(), timeout).await?; Ok(Box::pin(stream)) } async fn connect_tcp( url: &Url, dns_server: Option<&str>, - timeout: Option, + timeout: TimeoutBudget, ) -> Result { let host = url .host_str() @@ -202,22 +216,26 @@ async fn connect_tcp( .ok_or_else(|| FetchError::Message("URL port is required".to_string()))?; if host.parse::().is_ok() || dns_server.is_none() { - return TcpStream::connect((host, port)) - .await - .map_err(FetchError::from); + return timeout_fetch(timeout, async { + TcpStream::connect((host, port)) + .await + .map_err(FetchError::from) + }) + .await; } - let mut addrs = resolve_websocket_host(host, dns_server, timeout).await?; + let mut addrs = + timeout_fetch(timeout, resolve_websocket_host(host, dns_server, timeout)).await?; for addr in &mut addrs { addr.set_port(port); } - connect_first(addrs).await + timeout_fetch(timeout, connect_first(addrs)).await } async fn resolve_websocket_host( host: &str, dns_server: Option<&str>, - timeout: Option, + timeout: TimeoutBudget, ) -> Result, FetchError> { let Some(dns_server) = dns_server else { return tokio::net::lookup_host((host, 0)) @@ -230,13 +248,13 @@ async fn resolve_websocket_host( let server_url = Url::parse(dns_server).map_err(|err| { FetchError::Message(format!("invalid dns-server '{dns_server}': {err}")) })?; - crate::dns::doh::lookup_doh(&server_url, host, timeout) + crate::dns::doh::lookup_doh(&server_url, host, timeout.remaining()?) .await .map_err(|err| FetchError::Runtime(format!("lookup {host}: {err}")))? } else { let server_addr = crate::dns::resolver::normalize_udp_dns_server(dns_server) .map_err(|err| FetchError::Message(err.to_string()))?; - crate::dns::resolver::lookup_udp(&server_addr, host, timeout) + crate::dns::resolver::lookup_udp(&server_addr, host, timeout.remaining()?) .await .map_err(|err| FetchError::Runtime(format!("lookup {host}: {err}")))? }; @@ -262,7 +280,7 @@ async fn connect_first(addrs: Vec) -> Result async fn dial_proxy( proxy: &str, target: &Url, - timeout: Option, + timeout: TimeoutBudget, ) -> Result { let proxy_url = parse_proxy_url(proxy)?; match proxy_url.scheme() { @@ -290,7 +308,7 @@ async fn dial_http_proxy( raw_proxy: &str, proxy_url: &Url, target: &Url, - timeout: Option, + timeout: TimeoutBudget, ) -> Result { let stream = connect_proxy_tcp(proxy_url, timeout).await?; let mut stream: DialStream = if proxy_url.scheme() == "https" { @@ -302,9 +320,13 @@ async fn dial_http_proxy( FetchError::Message(format!("invalid proxy '{raw_proxy}': invalid host")) })?; let config = crate::tls::rustls_client_config(&[], None, None, false, None, None)?; - let stream = tokio_rustls::TlsConnector::from(Arc::new(config)) - .connect(server_name, stream) - .await?; + let stream = timeout_fetch(timeout, async { + tokio_rustls::TlsConnector::from(Arc::new(config)) + .connect(server_name, stream) + .await + .map_err(FetchError::from) + }) + .await?; Box::pin(stream) } else { Box::pin(stream) @@ -321,42 +343,46 @@ async fn dial_http_proxy( request.push_str("\r\n"); } request.push_str("\r\n"); - stream.write_all(request.as_bytes()).await?; + timeout_fetch(timeout, async { + stream.write_all(request.as_bytes()).await?; - let mut raw = Vec::new(); - let mut buf = [0_u8; 1]; - while !raw.ends_with(b"\r\n\r\n") { - if raw.len() >= 16 * 1024 { - return Err(FetchError::Runtime( - "proxy CONNECT response was too large".to_string(), - )); + let mut raw = Vec::new(); + let mut buf = [0_u8; 1]; + while !raw.ends_with(b"\r\n\r\n") { + if raw.len() >= 16 * 1024 { + return Err(FetchError::Runtime( + "proxy CONNECT response was too large".to_string(), + )); + } + let n = stream.read(&mut buf).await?; + if n == 0 { + return Err(FetchError::Runtime( + "proxy closed connection during CONNECT".to_string(), + )); + } + raw.extend_from_slice(&buf[..n]); } - let n = stream.read(&mut buf).await?; - if n == 0 { - return Err(FetchError::Runtime( - "proxy closed connection during CONNECT".to_string(), - )); + let response = String::from_utf8_lossy(&raw); + let status = response + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + if !(200..300).contains(&status) { + return Err(FetchError::Runtime(format!( + "proxy CONNECT failed with status {status}" + ))); } - raw.extend_from_slice(&buf[..n]); - } - let response = String::from_utf8_lossy(&raw); - let status = response - .lines() - .next() - .and_then(|line| line.split_whitespace().nth(1)) - .and_then(|value| value.parse::().ok()) - .unwrap_or(0); - if !(200..300).contains(&status) { - return Err(FetchError::Runtime(format!( - "proxy CONNECT failed with status {status}" - ))); - } + Ok(()) + }) + .await?; Ok(stream) } async fn connect_proxy_tcp( proxy_url: &Url, - timeout: Option, + timeout: TimeoutBudget, ) -> Result { let host = proxy_url .host_str() @@ -370,21 +396,14 @@ async fn connect_proxy_tcp( 80 } }); - let addrs = tokio::net::lookup_host((host, port)) - .await - .map_err(|err| FetchError::Runtime(format!("lookup {host}: {err}")))? - .collect(); - let connect = connect_first(addrs); - if let Some(timeout) = timeout { - tokio::time::timeout(timeout, connect).await.map_err(|_| { - FetchError::Runtime(format!( - "request timed out after {}", - crate::http::format_go_duration(timeout) - )) - })? - } else { - connect.await - } + timeout_fetch(timeout, async { + let addrs = tokio::net::lookup_host((host, port)) + .await + .map_err(|err| FetchError::Runtime(format!("lookup {host}: {err}")))? + .collect(); + connect_first(addrs).await + }) + .await } fn proxy_basic_auth(proxy_url: &Url) -> Result, FetchError> { @@ -407,7 +426,7 @@ fn proxy_basic_auth(proxy_url: &Url) -> Result, FetchError> { async fn dial_socks5_proxy( proxy_url: &Url, target: &Url, - timeout: Option, + timeout: TimeoutBudget, ) -> Result { let mut stream = connect_proxy_tcp(proxy_url, timeout).await?; let username = percent_encoding::percent_decode_str(proxy_url.username()) @@ -425,13 +444,25 @@ async fn dial_socks5_proxy( "SOCKS5 proxy credentials are too long".to_string(), )); } - stream.write_all(&[0x05, 0x02, 0x00, 0x02]).await?; + timeout_fetch(timeout, async { + stream.write_all(&[0x05, 0x02, 0x00, 0x02]).await?; + Ok(()) + }) + .await?; } else { - stream.write_all(&[0x05, 0x01, 0x00]).await?; + timeout_fetch(timeout, async { + stream.write_all(&[0x05, 0x01, 0x00]).await?; + Ok(()) + }) + .await?; } let mut method = [0_u8; 2]; - stream.read_exact(&mut method).await?; + timeout_fetch(timeout, async { + stream.read_exact(&mut method).await?; + Ok(()) + }) + .await?; if method[0] != 0x05 { return Err(FetchError::Runtime( "SOCKS5 proxy returned an invalid greeting".to_string(), @@ -449,9 +480,17 @@ async fn dial_socks5_proxy( auth.extend_from_slice(username.as_bytes()); auth.push(password.len() as u8); auth.extend_from_slice(password.as_bytes()); - stream.write_all(&auth).await?; + timeout_fetch(timeout, async { + stream.write_all(&auth).await?; + Ok(()) + }) + .await?; let mut response = [0_u8; 2]; - stream.read_exact(&mut response).await?; + timeout_fetch(timeout, async { + stream.read_exact(&mut response).await?; + Ok(()) + }) + .await?; if response != [0x01, 0x00] { return Err(FetchError::Runtime( "SOCKS5 proxy authentication failed".to_string(), @@ -471,18 +510,30 @@ async fn dial_socks5_proxy( } let mut request = vec![0x05, 0x01, 0x00]; - write_socks5_target(&mut request, proxy_url.scheme() == "socks5h", target).await?; - stream.write_all(&request).await?; + timeout_fetch( + timeout, + write_socks5_target(&mut request, proxy_url.scheme() == "socks5h", target), + ) + .await?; + timeout_fetch(timeout, async { + stream.write_all(&request).await?; + Ok(()) + }) + .await?; let mut response = [0_u8; 4]; - stream.read_exact(&mut response).await?; + timeout_fetch(timeout, async { + stream.read_exact(&mut response).await?; + Ok(()) + }) + .await?; if response[0] != 0x05 || response[1] != 0x00 { return Err(FetchError::Runtime(format!( "SOCKS5 proxy CONNECT failed with status {}", response[1] ))); } - read_socks5_bound_addr(&mut stream, response[3]).await?; + timeout_fetch(timeout, read_socks5_bound_addr(&mut stream, response[3])).await?; Ok(Box::pin(stream)) } @@ -573,12 +624,118 @@ fn url_authority(url: &Url) -> Result { }) } -fn websocket_timeout(cli: &Cli) -> Result, FetchError> { +#[derive(Clone, Copy)] +struct TimeoutBudget { + timeout: Option, + started_at: Instant, +} + +impl TimeoutBudget { + fn new(timeout: Option) -> Self { + Self { + timeout, + started_at: Instant::now(), + } + } + + fn remaining(self) -> Result, FetchError> { + let Some(timeout) = self.timeout else { + return Ok(None); + }; + let elapsed = self.started_at.elapsed(); + if elapsed >= timeout { + return Err(websocket_timeout_error(timeout)); + } + Ok(Some(timeout - elapsed)) + } + + fn timeout_error(self) -> FetchError { + websocket_timeout_error(self.timeout.expect("timeout checked by caller")) + } +} + +fn websocket_request_timeout(cli: &Cli) -> Result, FetchError> { cli.timeout .map(|seconds| crate::http::duration_from_seconds("timeout", seconds)) .transpose() } +fn websocket_connect_timeout( + cli: &Cli, + request_timeout: Option, + request_start: Instant, +) -> Result { + let connect_timeout = cli + .connect_timeout + .map(|seconds| crate::http::duration_from_seconds("connect-timeout", seconds)) + .transpose()?; + let request_remaining = remaining_timeout(request_timeout, request_start)?; + let timeout = match (connect_timeout, request_remaining) { + (Some(connect), Some(remaining)) => Some(connect.min(remaining)), + (Some(connect), None) => Some(connect), + (None, remaining) => remaining, + }; + Ok(TimeoutBudget::new(timeout)) +} + +fn remaining_timeout( + timeout: Option, + started_at: Instant, +) -> Result, FetchError> { + let Some(timeout) = timeout else { + return Ok(None); + }; + let elapsed = started_at.elapsed(); + if elapsed >= timeout { + return Err(websocket_timeout_error(timeout)); + } + Ok(Some(timeout - elapsed)) +} + +async fn timeout_fetch( + timeout: TimeoutBudget, + future: impl Future>, +) -> Result { + let Some(remaining) = timeout.remaining()? else { + return future.await; + }; + let started_at = Instant::now(); + match tokio::time::timeout(remaining, future).await { + Ok(Err(err)) if started_at.elapsed() >= remaining && is_timeout_error(&err) => { + Err(timeout.timeout_error()) + } + Ok(result) => result, + Err(_) => Err(timeout.timeout_error()), + } +} + +async fn timeout_ws( + timeout: TimeoutBudget, + future: impl Future>, +) -> Result { + let Some(remaining) = timeout.remaining().map_err(websocket_io_error)? else { + return future.await; + }; + tokio::time::timeout(remaining, future) + .await + .map_err(|_| websocket_io_error(timeout.timeout_error()))? +} + +fn websocket_timeout_error(timeout: Duration) -> FetchError { + FetchError::Runtime(format!( + "request timed out after {}", + crate::http::format_go_duration(timeout) + )) +} + +fn is_timeout_error(err: &FetchError) -> bool { + match err { + FetchError::Runtime(message) => message.contains("timed out"), + FetchError::Reqwest(err) => err.is_timeout(), + _ => false, + } +} + fn websocket_io_error(err: FetchError) -> WsError { WsError::Io(io::Error::other(err.to_string())) } @@ -869,4 +1026,54 @@ mod tests { let off_cli = Cli::try_parse_from(["fetch", "--color", "off", "ws://example.com"]).unwrap(); assert!(!use_color(&off_cli, true)); } + + #[test] + fn websocket_connect_timeout_uses_connect_timeout_when_shorter() { + let cli = Cli::try_parse_from([ + "fetch", + "--connect-timeout", + "0.25", + "--timeout", + "5", + "ws://example.com", + ]) + .unwrap(); + let request_timeout = websocket_request_timeout(&cli).unwrap(); + let budget = websocket_connect_timeout(&cli, request_timeout, Instant::now()).unwrap(); + let remaining = budget.remaining().unwrap().unwrap(); + + assert!(remaining <= Duration::from_millis(250)); + assert!(remaining > Duration::from_millis(200)); + } + + #[test] + fn websocket_connect_timeout_is_bounded_by_remaining_request_timeout() { + let cli = Cli::try_parse_from([ + "fetch", + "--connect-timeout", + "5", + "--timeout", + "0.25", + "ws://example.com", + ]) + .unwrap(); + let request_timeout = websocket_request_timeout(&cli).unwrap(); + let request_start = Instant::now() - Duration::from_millis(100); + let budget = websocket_connect_timeout(&cli, request_timeout, request_start).unwrap(); + let remaining = budget.remaining().unwrap().unwrap(); + + assert!(remaining <= Duration::from_millis(150)); + assert!(remaining > Duration::from_millis(100)); + } + + #[test] + fn websocket_connect_timeout_falls_back_to_request_timeout() { + let cli = Cli::try_parse_from(["fetch", "--timeout", "0.25", "ws://example.com"]).unwrap(); + let request_timeout = websocket_request_timeout(&cli).unwrap(); + let budget = websocket_connect_timeout(&cli, request_timeout, Instant::now()).unwrap(); + let remaining = budget.remaining().unwrap().unwrap(); + + assert!(remaining <= Duration::from_millis(250)); + assert!(remaining > Duration::from_millis(200)); + } } diff --git a/tests/integration.rs b/tests/integration.rs index ac1c3a2..7b3da31 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1956,6 +1956,25 @@ fn start_http_connect_proxy(target_addr: String) -> (String, mpsc::Receiver String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind stalling proxy"); + let proxy_url = format!("{scheme}://{}", listener.local_addr().unwrap()); + thread::spawn(move || { + for conn in listener.incoming() { + let Ok(mut conn) = conn else { + break; + }; + thread::spawn(move || { + let _ = conn.set_read_timeout(Some(Duration::from_millis(200))); + let mut buf = [0_u8; 1024]; + let _ = conn.read(&mut buf); + thread::sleep(Duration::from_secs(5)); + }); + } + }); + proxy_url +} + fn handle_http_connect_proxy_conn( mut conn: TcpStream, target_addr: &str, @@ -6425,6 +6444,54 @@ fn websocket_custom_dns_and_proxy_cases() { assert!(res.stdout.contains("echo: socks websocket")); } +#[test] +fn websocket_connect_timeout_covers_dns_and_proxy_handshakes() { + let unresponsive_dns_addr = start_unresponsive_udp_dns_server(); + let res = run_fetch(&[ + "--dns-server", + &unresponsive_dns_addr, + "--connect-timeout", + "0.05", + "--timeout", + "1", + "ws://ws-dns-timeout.test:80", + "--ws-interactive", + "off", + ]); + assert_exit(&res, 1); + assert!(res.stderr.contains("request timed out after 50ms")); + + let http_proxy = start_stalling_proxy("http"); + let res = run_fetch(&[ + "--proxy", + &http_proxy, + "--connect-timeout", + "0.05", + "--timeout", + "1", + "ws://example.com/socket", + "--ws-interactive", + "off", + ]); + assert_exit(&res, 1); + assert!(res.stderr.contains("request timed out after 50ms")); + + let socks_proxy = start_stalling_proxy("socks5"); + let res = run_fetch(&[ + "--proxy", + &socks_proxy, + "--connect-timeout", + "0.05", + "--timeout", + "1", + "ws://example.com/socket", + "--ws-interactive", + "off", + ]); + assert_exit(&res, 1); + assert!(res.stderr.contains("request timed out after 50ms")); +} + #[test] fn websocket_dry_run_prints_effective_handshake_headers() { let res = run_fetch(&[