diff --git a/AGENTS.md b/AGENTS.md index af19856..0d606da 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -89,7 +89,8 @@ metadata/update/DNS/TLS inspection modes, and executes requests via `src/http`. - gRPC calls and reflection advertise `grpc-accept-encoding: gzip`; response frames with the compressed flag are decompressed with the response `grpc-encoding` before protobuf decoding, with unsupported encodings reported by name. - gRPC standard request headers, status extraction from headers/trailers, and full framed-body reads live under `src/grpc`; request execution and reflection should reuse those helpers instead of duplicating protocol handling. - Client-streaming and bidi gRPC calls stream JSON input into framed protobuf request bodies instead of materializing the whole stream up front; stdin-backed gRPC JSON streams use the shared incremental parser behind a blocking stdin bridge, and Windows pipe stdin is peeked before reads so complete request messages can be sent before EOF without byte-at-a-time reads. -- `--inspect-dns` resolves the URL hostname without making an HTTP request, showing common DNS record types, resolver backend, duration, and per-record TTLs from direct UDP or DoH responses. +- Custom UDP DNS queries advertise EDNS(0) and retry truncated responses over TCP. +- `--inspect-dns` resolves the URL hostname without making an HTTP request, showing common DNS record types, resolver backend, duration, and per-record TTLs from direct UDP or DoH responses. UDP inspection queries retry truncated UDP responses over TCP; if TCP fallback cannot complete the lookup, render a warning about incomplete results and exit non-zero instead of silently omitting that record type. - `--inspect-tls --http 3` performs QUIC/TLS inspection with `h3` ALPN instead of the TCP TLS path. - `--inspect-tls` honors `--dns-server` for both TCP and QUIC inspection, resolving domain targets through the configured UDP or DoH resolver before the TLS handshake. - Rust `--inspect-tls` renders a verified certificate chain when verification succeeds, appending omitted trusted roots or replacing server-sent cross-signed roots with the matching platform/custom trusted root for expiry display; `--insecure` keeps the raw peer chain. diff --git a/docs/advanced-features.md b/docs/advanced-features.md index a50b40f..07db156 100644 --- a/docs/advanced-features.md +++ b/docs/advanced-features.md @@ -23,6 +23,8 @@ fetch --dns-server 1.1.1.1:53 example.com fetch --dns-server "[2001:4860:4860::8888]:53" example.com ``` +UDP DNS queries advertise EDNS(0) and retry truncated responses over TCP. + ### DNS-over-HTTPS (DoH) Use HTTPS URL for encrypted DNS queries: @@ -47,7 +49,7 @@ fetch --inspect-dns example.com fetch --inspect-dns --dns-server https://1.1.1.1/dns-query example.com ``` -The output shows the resolver backend, A, AAAA, CNAME, TXT, MX, NS, SOA, SRV, CAA, SVCB, and HTTPS records when present, address count, record count, lookup duration, and per-record TTLs. +The output shows the resolver backend, A, AAAA, CNAME, TXT, MX, NS, SOA, SRV, CAA, SVCB, and HTTPS records when present, address count, record count, lookup duration, and per-record TTLs. UDP DNS inspection advertises EDNS(0) and retries truncated UDP responses over TCP; if TCP fallback cannot complete the lookup, `fetch` warns that the results are incomplete and exits with a non-zero status. ### Configuration File diff --git a/docs/cli-reference.md b/docs/cli-reference.md index 0506e61..cde1b22 100644 --- a/docs/cli-reference.md +++ b/docs/cli-reference.md @@ -353,7 +353,8 @@ fetch --retry 3 --retry-delay 0.5 example.com ### `--dns-server IP[:PORT]|URL` Use custom DNS server. Supports UDP DNS and DNS-over-HTTPS for requests and -DNS/TLS inspection. +DNS/TLS inspection. UDP DNS queries advertise EDNS(0) and retry truncated +responses over TCP. ```sh fetch --dns-server 8.8.8.8 example.com @@ -363,7 +364,7 @@ fetch --dns-server https://1.1.1.1/dns-query example.com ### `--inspect-dns` -Inspect DNS resolution for the URL hostname only (no HTTP request is made). Displays the resolver backend, A, AAAA, CNAME, TXT, MX, NS, SOA, SRV, CAA, SVCB, and HTTPS records when present, along with per-record TTLs, address count, record count, and lookup duration. +Inspect DNS resolution for the URL hostname only (no HTTP request is made). Displays the resolver backend, A, AAAA, CNAME, TXT, MX, NS, SOA, SRV, CAA, SVCB, and HTTPS records when present, along with per-record TTLs, address count, record count, and lookup duration. UDP DNS inspection advertises EDNS(0) and retries truncated UDP responses over TCP; if TCP fallback cannot complete the lookup, `fetch` warns that the results are incomplete and exits with a non-zero status. ```sh fetch --inspect-dns example.com diff --git a/src/dns/inspect.rs b/src/dns/inspect.rs index 34140cd..29a7c3b 100644 --- a/src/dns/inspect.rs +++ b/src/dns/inspect.rs @@ -3,7 +3,6 @@ use std::net::IpAddr; use std::time::{Duration, Instant}; use futures_util::future::join_all; -use tokio::net::UdpSocket; use url::Url; use crate::cli::Cli; @@ -79,6 +78,34 @@ struct QueryType { dns_type: u16, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum QueryErrorKind { + Other, + Truncated, +} + +#[derive(Debug)] +struct QueryError { + kind: QueryErrorKind, + message: String, +} + +impl QueryError { + fn other(err: impl ToString) -> Self { + Self { + kind: QueryErrorKind::Other, + message: err.to_string(), + } + } + + fn truncated(err: impl ToString) -> Self { + Self { + kind: QueryErrorKind::Truncated, + message: err.to_string(), + } + } +} + #[derive(Clone, Debug, Eq, PartialEq)] struct Record { typ: String, @@ -92,7 +119,9 @@ struct Inspection { host: String, resolver: String, records: HashMap>, + warnings: Vec, duration: Duration, + exit_code: i32, } #[derive(Clone, Debug, Eq, PartialEq)] @@ -127,9 +156,9 @@ pub async fn execute(cli: &Cli) -> Result { .await; match inspected { - Ok(()) => { + Ok(code) => { printer.flush_to(&mut std::io::stderr())?; - Ok(0) + Ok(code) } Err(err) => { write_error_with_color(err, cli.color.as_deref()); @@ -140,9 +169,21 @@ pub async fn execute(cli: &Cli) -> Result { #[cfg(test)] async fn inspect(url: &Url, dns_server: Option<&str>) -> Result { + let (out, _) = inspect_with_code(url, dns_server).await?; + Ok(out) +} + +#[cfg(test)] +async fn inspect_with_code( + url: &Url, + dns_server: Option<&str>, +) -> Result<(String, i32), FetchError> { let mut out = Printer::new(false); - inspect_to(url, dns_server, &mut out, None).await?; - Ok(out.into_string().expect("DNS inspection output is UTF-8")) + let code = inspect_to(url, dns_server, &mut out, None).await?; + Ok(( + out.into_string().expect("DNS inspection output is UTF-8"), + code, + )) } async fn inspect_to( @@ -150,7 +191,7 @@ async fn inspect_to( dns_server: Option<&str>, out: &mut Printer, timeout: Option, -) -> Result<(), FetchError> { +) -> Result { let host = url .host_str() .filter(|host| !host.is_empty()) @@ -160,12 +201,12 @@ async fn inspect_to( if let Ok(ip) = host.parse::() { render_ip_literal_to(out, host, ip, target.label(), start.elapsed()); - return Ok(()); + return Ok(0); } let result = lookup(host, target, start, timeout).await?; render_to(&result, out); - Ok(()) + Ok(result.exit_code) } async fn lookup( @@ -178,7 +219,9 @@ async fn lookup( host: host.to_string(), resolver: target.label().to_string(), records: HashMap::new(), + warnings: Vec::new(), duration: Duration::ZERO, + exit_code: 0, }; if matches!(target, ResolverTarget::Default { .. }) { @@ -207,7 +250,7 @@ async fn lookup( let client = doh_client.as_ref(); let target = ⌖ async move { - match (target, client) { + let result = match (target, client) { (ResolverTarget::Doh { url, .. }, Some(client)) => { lookup_doh_records(client, url, host, query_type).await } @@ -220,14 +263,16 @@ async fn lookup( (ResolverTarget::Doh { .. }, None) => { unreachable!("DoH client initialized above") } - } + }; + (query_type, result) } }); let results = join_all(futures).await; let mut first_err = None; + let mut truncated_types = Vec::new(); let mut seen: HashMap = HashMap::new(); - for result in results { + for (query_type, result) in results { match result { Ok(records) => { for record in records { @@ -243,6 +288,12 @@ async fn lookup( records.push(record); } } + Err(err) if err.kind == QueryErrorKind::Truncated => { + truncated_types.push(query_type.label); + if first_err.is_none() { + first_err = Some(err); + } + } Err(err) if first_err.is_none() => { first_err = Some(err); } @@ -251,11 +302,16 @@ async fn lookup( } out.duration = start.elapsed(); + if !truncated_types.is_empty() { + out.warnings.push(truncated_warning(&truncated_types)); + out.exit_code = 1; + return Ok(out); + } if record_count(&out) > 0 { return Ok(out); } if let Some(err) = first_err { - return Err(format!("lookup {host}: {err}").into()); + return Err(format!("lookup {host}: {}", err.message).into()); } Err(format!("lookup {host}: no DNS records found").into()) } @@ -285,11 +341,11 @@ async fn lookup_doh_records( server_url: &Url, host: &str, query_type: QueryType, -) -> Result, FetchError> { +) -> Result, QueryError> { let records = crate::dns::doh::lookup_doh_records_with_client(client, server_url, host, query_type.label) .await - .map_err(|err| FetchError::Message(err.to_string()))?; + .map_err(QueryError::other)?; Ok(records .into_iter() .map(|answer| { @@ -309,38 +365,35 @@ async fn lookup_udp_records( host: &str, query_type: QueryType, timeout: Duration, -) -> Result, FetchError> { +) -> Result, QueryError> { let id = dns_query_id(); - let raw = wire::build_query(id, host, query_type.dns_type) - .map_err(|err| FetchError::Message(err.to_string()))?; - let socket = UdpSocket::bind(if server_addr.starts_with('[') { - "[::]:0" - } else { - "0.0.0.0:0" - }) - .await?; - socket.connect(server_addr).await?; - socket.send(&raw).await?; - - let mut buf = vec![0u8; 4096]; - let n = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await { - Ok(Ok(n)) => n, - Ok(Err(err)) => return Err(err.into()), - Err(_) => return Err("DNS lookup timed out".into()), + let raw = wire::build_query(id, host, query_type.dns_type).map_err(QueryError::other)?; + let mut response = crate::dns::transport::query_udp(server_addr, &raw, timeout) + .await + .map_err(QueryError::other)?; + let raw_records = match wire::parse_response(&response, id) { + Ok(_) => wire::parse_response(&response, id).map_err(QueryError::other)?, + Err(err) if err.is_truncated() => { + response = crate::dns::transport::query_tcp(server_addr, &raw, timeout) + .await + .map_err(QueryError::truncated)?; + wire::parse_response(&response, id).map_err(QueryError::truncated)? + } + Err(err) => return Err(QueryError::other(err)), }; - let raw_records = - wire::parse_response(&buf[..n], id).map_err(|err| FetchError::Message(err.to_string()))?; let mut records = Vec::new(); for raw_record in raw_records { if raw_record.class != DNS_CLASS_IN { continue; } if let Some(value) = resource_value( - &buf[..n], + &response, raw_record.typ, raw_record.data_offset, raw_record.data.len(), - )? { + ) + .map_err(QueryError::other)? + { records.push(Record { typ: type_label(raw_record.typ).to_string(), value, @@ -558,6 +611,16 @@ fn render_to(result: &Inspection, out: &mut Printer) { out.push_str("Duration: "); out.write_styled(&format_duration(result.duration), &[Sequence::Dim]); out.push_str("\n"); + render_warnings(out, &result.warnings); +} + +fn render_warnings(out: &mut Printer, warnings: &[String]) { + if !warnings.is_empty() { + out.push('\n'); + } + for warning in warnings { + core::write_warning_msg_no_flush(out, warning); + } } fn write_dns_title(out: &mut Printer, host: &str, resolver: &str) { @@ -629,6 +692,19 @@ fn record_count(result: &Inspection) -> usize { result.records.values().map(Vec::len).sum() } +fn truncated_warning(types: &[&'static str]) -> String { + if types.len() == 1 { + return format!( + "DNS response for {} was truncated over UDP after EDNS(0), and TCP fallback failed; results are incomplete", + types[0] + ); + } + format!( + "DNS responses for {} were truncated over UDP after EDNS(0), and TCP fallback failed; results are incomplete", + types.join(", ") + ) +} + fn format_duration(duration: Duration) -> String { let nanos = duration.as_nanos(); let rounded = if nanos < 1_000_000 { @@ -989,7 +1065,8 @@ impl ResolverTarget { mod tests { use super::*; use clap::Parser; - use std::net::UdpSocket as StdUdpSocket; + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream as StdTcpStream, UdpSocket as StdUdpSocket}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::thread; @@ -1214,6 +1291,47 @@ mod tests { stop(); } + #[tokio::test] + async fn test_inspect_udp_truncated_response_uses_tcp_fallback() { + let (addr, stop) = start_truncated_udp_tcp_server(DNS_TYPE_TXT); + + let (out, code) = inspect_with_code( + &Url::parse("https://example.com").unwrap(), + Some(addr.as_str()), + ) + .await + .unwrap(); + + assert_eq!(code, 0); + assert!(out.contains("TXT\n")); + assert!(out.contains("v=spf1 -all (TTL 2m)")); + assert!(!out.contains("warning:")); + stop(); + } + + #[tokio::test] + async fn test_inspect_udp_truncated_response_warns_and_exits_nonzero() { + let (addr, stop) = start_truncated_udp_server(DNS_TYPE_TXT); + + let (out, code) = inspect_with_code( + &Url::parse("https://example.com").unwrap(), + Some(addr.as_str()), + ) + .await + .unwrap(); + + assert_eq!(code, 1); + assert!(out.contains("A\n")); + assert!(out.contains("192.0.2.10 (TTL 42s)")); + assert!(out.contains( + "\n\nwarning: DNS response for TXT was truncated over UDP after EDNS(0), and TCP fallback failed; results are incomplete\n" + )); + assert!(out.ends_with( + "warning: DNS response for TXT was truncated over UDP after EDNS(0), and TCP fallback failed; results are incomplete\n" + )); + stop(); + } + #[tokio::test] async fn test_lookup_uses_default_resolver_when_no_system_dns_server_discovered() { let target = resolver_target_from_resolv_conf(None, Some("# no nameservers\n")); @@ -1257,7 +1375,9 @@ mod tests { has_ttl: true, }], )]), + warnings: Vec::new(), duration: Duration::ZERO, + exit_code: 0, }); assert!(out.contains("└─ 192.0.2.1 (TTL 1m)")); @@ -1278,7 +1398,9 @@ mod tests { has_ttl: true, }], )]), + warnings: Vec::new(), duration: Duration::ZERO, + exit_code: 0, }, true, ); @@ -1316,7 +1438,9 @@ mod tests { }, ], )]), + warnings: Vec::new(), duration: Duration::ZERO, + exit_code: 0, }); let first = out.find("192.0.2.10").unwrap(); @@ -1477,6 +1601,180 @@ mod tests { }) } + fn start_truncated_udp_server(truncated_type: u16) -> (String, impl FnOnce()) { + let socket = StdUdpSocket::bind("127.0.0.1:0").unwrap(); + socket + .set_read_timeout(Some(Duration::from_millis(100))) + .unwrap(); + let addr = socket.local_addr().unwrap().to_string(); + let done = Arc::new(Mutex::new(false)); + let thread_done = done.clone(); + let handle = thread::spawn(move || { + let mut buf = [0u8; 512]; + loop { + if *thread_done.lock().unwrap() { + return; + } + let Ok((n, peer)) = socket.recv_from(&mut buf) else { + continue; + }; + if n < 12 { + continue; + } + let query_type = read_question_type(&buf[..n]).unwrap_or_default(); + let mut response = Vec::new(); + response.extend_from_slice(&buf[0..2]); + let flags = if query_type == truncated_type { + 0x8380u16 + } else { + 0x8180u16 + }; + response.extend_from_slice(&flags.to_be_bytes()); + response.extend_from_slice(&1u16.to_be_bytes()); + response.extend_from_slice(&(u16::from(query_type == DNS_TYPE_A)).to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + let question_name_end = question_end(&buf[..n]).unwrap_or(12); + let question_end = (question_name_end + 4).min(n); + response.extend_from_slice(&buf[12..question_end]); + if query_type == DNS_TYPE_A { + response.extend_from_slice(&[0xc0, 0x0c]); + response.extend_from_slice(&DNS_TYPE_A.to_be_bytes()); + response.extend_from_slice(&DNS_CLASS_IN.to_be_bytes()); + response.extend_from_slice(&42u32.to_be_bytes()); + response.extend_from_slice(&4u16.to_be_bytes()); + response.extend_from_slice(&[192, 0, 2, 10]); + } + let _ = socket.send_to(&response, peer); + } + }); + + (addr, move || { + *done.lock().unwrap() = true; + let _ = StdUdpSocket::bind("127.0.0.1:0") + .unwrap() + .send_to(&[0], "127.0.0.1:9"); + handle.join().unwrap(); + }) + } + + fn start_truncated_udp_tcp_server(truncated_type: u16) -> (String, impl FnOnce()) { + let udp_socket = StdUdpSocket::bind("127.0.0.1:0").unwrap(); + udp_socket + .set_read_timeout(Some(Duration::from_millis(100))) + .unwrap(); + let addr = udp_socket.local_addr().unwrap(); + let tcp_listener = TcpListener::bind(addr).unwrap(); + tcp_listener.set_nonblocking(true).unwrap(); + let done = Arc::new(Mutex::new(false)); + + let udp_done = done.clone(); + let udp_handle = thread::spawn(move || { + let mut buf = [0u8; 512]; + loop { + if *udp_done.lock().unwrap() { + return; + } + let Ok((n, peer)) = udp_socket.recv_from(&mut buf) else { + continue; + }; + if n < 12 { + continue; + } + let query = &buf[..n]; + let query_type = read_question_type(query).unwrap_or_default(); + let flags = if query_type == truncated_type { + 0x8380u16 + } else { + 0x8180u16 + }; + let mut response = inspect_response_header(query, flags, query_type == DNS_TYPE_A); + if query_type == DNS_TYPE_A { + write_raw_answer(&mut response, DNS_TYPE_A, 42, &[192, 0, 2, 10]); + } + let _ = udp_socket.send_to(&response, peer); + } + }); + + let tcp_done = done.clone(); + let tcp_handle = thread::spawn(move || { + loop { + if *tcp_done.lock().unwrap() { + return; + } + match tcp_listener.accept() { + Ok((mut stream, _)) => { + let Some(query) = read_tcp_query(&mut stream) else { + continue; + }; + let query_type = read_question_type(&query).unwrap_or_default(); + let mut response = + inspect_response_header(&query, 0x8180, query_type == truncated_type); + if query_type == truncated_type && query_type == DNS_TYPE_TXT { + write_txt_answer(&mut response, 120, "v=spf1 -all"); + } + let mut framed = Vec::with_capacity(response.len() + 2); + framed.extend_from_slice(&(response.len() as u16).to_be_bytes()); + framed.extend_from_slice(&response); + let _ = stream.write_all(&framed); + } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + thread::sleep(Duration::from_millis(10)); + } + Err(_) => {} + } + } + }); + + (addr.to_string(), move || { + *done.lock().unwrap() = true; + let _ = StdUdpSocket::bind("127.0.0.1:0") + .unwrap() + .send_to(&[0], addr); + let _ = StdTcpStream::connect(addr); + udp_handle.join().unwrap(); + tcp_handle.join().unwrap(); + }) + } + + fn inspect_response_header(query: &[u8], flags: u16, has_answer: bool) -> Vec { + let question_name_end = question_end(query).unwrap_or(12); + let question_end = (question_name_end + 4).min(query.len()); + let mut response = Vec::new(); + response.extend_from_slice(&query[0..2]); + response.extend_from_slice(&flags.to_be_bytes()); + response.extend_from_slice(&1u16.to_be_bytes()); + response.extend_from_slice(&(u16::from(has_answer)).to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&query[12..question_end]); + response + } + + fn write_raw_answer(response: &mut Vec, dns_type: u16, ttl: u32, data: &[u8]) { + response.extend_from_slice(&[0xc0, 0x0c]); + response.extend_from_slice(&dns_type.to_be_bytes()); + response.extend_from_slice(&DNS_CLASS_IN.to_be_bytes()); + response.extend_from_slice(&ttl.to_be_bytes()); + response.extend_from_slice(&(data.len() as u16).to_be_bytes()); + response.extend_from_slice(data); + } + + fn write_txt_answer(response: &mut Vec, ttl: u32, value: &str) { + let mut data = vec![value.len() as u8]; + data.extend_from_slice(value.as_bytes()); + write_raw_answer(response, DNS_TYPE_TXT, ttl, &data); + } + + fn read_tcp_query(stream: &mut StdTcpStream) -> Option> { + let mut len_buf = [0u8; 2]; + stream.read_exact(&mut len_buf).ok()?; + let len = usize::from(u16::from_be_bytes(len_buf)); + let mut query = vec![0u8; len]; + stream.read_exact(&mut query).ok()?; + Some(query) + } + fn read_question_type(raw: &[u8]) -> Option { let end = question_end(raw)?; Some(u16::from_be_bytes([raw[end], raw[end + 1]])) diff --git a/src/dns/mod.rs b/src/dns/mod.rs index bf64b1d..553da84 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -2,5 +2,6 @@ pub(crate) mod custom; pub mod doh; pub mod inspect; pub mod resolver; +pub(crate) mod transport; pub(crate) mod util; pub(crate) mod wire; diff --git a/src/dns/resolver.rs b/src/dns/resolver.rs index 0209f68..83fa60c 100644 --- a/src/dns/resolver.rs +++ b/src/dns/resolver.rs @@ -2,8 +2,6 @@ use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::time::Duration; -use tokio::net::UdpSocket; - use crate::dns::util::{dns_query_id, udp_dns_timeout}; use crate::dns::wire; @@ -67,30 +65,19 @@ pub async fn lookup_udp_type( ) -> Result, ResolverError> { let id = dns_query_id(); let raw = wire::build_query(id, host, dns_type).map_err(resolver_error)?; - let bind_addr = if server_addr.starts_with('[') { - "[::]:0" - } else { - "0.0.0.0:0" - }; - let socket = UdpSocket::bind(bind_addr) - .await - .map_err(|err| ResolverError(err.to_string()))?; - socket - .connect(server_addr) - .await - .map_err(|err| ResolverError(err.to_string()))?; - socket - .send(&raw) + let response = crate::dns::transport::query_udp(server_addr, &raw, timeout) .await - .map_err(|err| ResolverError(err.to_string()))?; - - let mut buf = vec![0u8; 4096]; - let n = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await { - Ok(Ok(n)) => n, - Ok(Err(err)) => return Err(ResolverError(err.to_string())), - Err(_) => return Err(ResolverError("DNS lookup timed out".to_string())), - }; - dns_records_from_response(&buf[..n], id) + .map_err(resolver_error)?; + match dns_records_from_response(&response, id) { + Ok(records) => Ok(records), + Err(err) if err.is_truncated() => { + let response = crate::dns::transport::query_tcp(server_addr, &raw, timeout) + .await + .map_err(resolver_error)?; + dns_records_from_response(&response, id).map_err(resolver_error) + } + Err(err) => Err(resolver_error(err)), + } } pub fn normalize_udp_dns_server(server: &str) -> Result { @@ -118,8 +105,8 @@ fn dns_server_value_error(server: &str) -> ResolverError { fn dns_records_from_response( raw: &[u8], expected_id: u16, -) -> Result, ResolverError> { - let records = wire::parse_response(raw, expected_id).map_err(resolver_error)?; +) -> Result, wire::WireError> { + let records = wire::parse_response(raw, expected_id)?; Ok(records.into_iter().filter_map(ip_record).collect()) } @@ -154,7 +141,8 @@ fn resolver_error(err: impl ToString) -> ResolverError { #[cfg(test)] mod tests { use super::*; - use std::net::UdpSocket as StdUdpSocket; + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream as StdTcpStream, UdpSocket as StdUdpSocket}; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::{Duration, Instant}; @@ -209,6 +197,20 @@ mod tests { stop(); } + #[tokio::test] + async fn lookup_udp_type_falls_back_to_tcp_on_truncated_udp_response() { + let (addr, stop) = start_truncated_udp_tcp_server(); + + let records = lookup_udp_type(&addr, "example.com", DNS_TYPE_A, Duration::from_secs(1)) + .await + .unwrap(); + + assert_eq!(records.len(), 1); + assert_eq!(records[0].ip.to_string(), "203.0.113.10"); + assert_eq!(records[0].ttl, Some(55)); + stop(); + } + #[tokio::test] async fn lookup_udp_ip_literal_skips_server() { let addrs = lookup_udp("127.0.0.1:9", "127.0.0.1", None).await.unwrap(); @@ -399,6 +401,112 @@ mod tests { }) } + fn start_truncated_udp_tcp_server() -> (String, impl FnOnce()) { + let udp_socket = StdUdpSocket::bind("127.0.0.1:0").unwrap(); + udp_socket + .set_read_timeout(Some(Duration::from_millis(100))) + .unwrap(); + let addr = udp_socket.local_addr().unwrap(); + let tcp_listener = TcpListener::bind(addr).unwrap(); + tcp_listener.set_nonblocking(true).unwrap(); + let done = Arc::new(Mutex::new(false)); + + let udp_done = done.clone(); + let udp_handle = thread::spawn(move || { + let mut buf = [0u8; 512]; + loop { + if *udp_done.lock().unwrap() { + return; + } + let Ok((n, peer)) = udp_socket.recv_from(&mut buf) else { + continue; + }; + if n < 12 { + continue; + } + let response = truncated_response(&buf[..n]); + let _ = udp_socket.send_to(&response, peer); + } + }); + + let tcp_done = done.clone(); + let tcp_handle = thread::spawn(move || { + loop { + if *tcp_done.lock().unwrap() { + return; + } + match tcp_listener.accept() { + Ok((mut stream, _)) => { + let Some(query) = read_tcp_query(&mut stream) else { + continue; + }; + let mut response = success_response(&query); + let mut framed = Vec::with_capacity(response.len() + 2); + framed.extend_from_slice(&(response.len() as u16).to_be_bytes()); + framed.append(&mut response); + let _ = stream.write_all(&framed); + } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + thread::sleep(Duration::from_millis(10)); + } + Err(_) => {} + } + } + }); + + (addr.to_string(), move || { + *done.lock().unwrap() = true; + let _ = StdUdpSocket::bind("127.0.0.1:0") + .unwrap() + .send_to(&[0], addr); + let _ = StdTcpStream::connect(addr); + udp_handle.join().unwrap(); + tcp_handle.join().unwrap(); + }) + } + + fn truncated_response(query: &[u8]) -> Vec { + let question_name_end = question_end(query).unwrap_or(12); + let question_end = (question_name_end + 4).min(query.len()); + let mut response = Vec::new(); + response.extend_from_slice(&query[0..2]); + response.extend_from_slice(&0x8380u16.to_be_bytes()); + response.extend_from_slice(&1u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&query[12..question_end]); + response + } + + fn success_response(query: &[u8]) -> Vec { + let query_type = read_question_type(query).unwrap_or_default(); + let question_name_end = question_end(query).unwrap_or(12); + let question_end = (question_name_end + 4).min(query.len()); + let mut response = Vec::new(); + response.extend_from_slice(&query[0..2]); + let answer_count = u16::from(query_type == DNS_TYPE_A); + response.extend_from_slice(&0x8180u16.to_be_bytes()); + response.extend_from_slice(&1u16.to_be_bytes()); + response.extend_from_slice(&answer_count.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&0u16.to_be_bytes()); + response.extend_from_slice(&query[12..question_end]); + if query_type == DNS_TYPE_A { + write_answer(&mut response, DNS_TYPE_A, 55, &[203, 0, 113, 10]); + } + response + } + + fn read_tcp_query(stream: &mut StdTcpStream) -> Option> { + let mut len_buf = [0u8; 2]; + stream.read_exact(&mut len_buf).ok()?; + let len = usize::from(u16::from_be_bytes(len_buf)); + let mut query = vec![0u8; len]; + stream.read_exact(&mut query).ok()?; + Some(query) + } + fn write_answer(response: &mut Vec, dns_type: u16, ttl: u32, data: &[u8]) { response.extend_from_slice(&[0xc0, 0x0c]); response.extend_from_slice(&dns_type.to_be_bytes()); diff --git a/src/dns/transport.rs b/src/dns/transport.rs new file mode 100644 index 0000000..719cbea --- /dev/null +++ b/src/dns/transport.rs @@ -0,0 +1,92 @@ +use std::fmt; +use std::time::Duration; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpStream, UdpSocket}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct DnsTransportError(String); + +impl fmt::Display for DnsTransportError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for DnsTransportError {} + +pub(crate) async fn query_udp( + server_addr: &str, + query: &[u8], + timeout: Duration, +) -> Result, DnsTransportError> { + let socket = UdpSocket::bind(if server_addr.starts_with('[') { + "[::]:0" + } else { + "0.0.0.0:0" + }) + .await + .map_err(transport_error)?; + socket.connect(server_addr).await.map_err(transport_error)?; + socket.send(query).await.map_err(transport_error)?; + + let mut buf = vec![0u8; 4096]; + let n = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await { + Ok(Ok(n)) => n, + Ok(Err(err)) => return Err(transport_error(err)), + Err(_) => return Err(DnsTransportError("DNS lookup timed out".to_string())), + }; + buf.truncate(n); + Ok(buf) +} + +pub(crate) async fn query_tcp( + server_addr: &str, + query: &[u8], + timeout: Duration, +) -> Result, DnsTransportError> { + if query.len() > usize::from(u16::MAX) { + return Err(DnsTransportError( + "DNS query is too large for TCP".to_string(), + )); + } + + let mut framed_query = Vec::with_capacity(query.len() + 2); + framed_query.extend_from_slice(&(query.len() as u16).to_be_bytes()); + framed_query.extend_from_slice(query); + + match tokio::time::timeout(timeout, query_tcp_inner(server_addr, &framed_query)).await { + Ok(result) => result, + Err(_) => Err(DnsTransportError("DNS lookup timed out".to_string())), + } +} + +async fn query_tcp_inner( + server_addr: &str, + framed_query: &[u8], +) -> Result, DnsTransportError> { + let mut stream = TcpStream::connect(server_addr) + .await + .map_err(transport_error)?; + stream + .write_all(framed_query) + .await + .map_err(transport_error)?; + + let mut len_buf = [0u8; 2]; + stream + .read_exact(&mut len_buf) + .await + .map_err(transport_error)?; + let response_len = usize::from(u16::from_be_bytes(len_buf)); + let mut response = vec![0u8; response_len]; + stream + .read_exact(&mut response) + .await + .map_err(transport_error)?; + Ok(response) +} + +fn transport_error(err: impl ToString) -> DnsTransportError { + DnsTransportError(err.to_string()) +} diff --git a/src/dns/wire.rs b/src/dns/wire.rs index 140fd1e..50c1d91 100644 --- a/src/dns/wire.rs +++ b/src/dns/wire.rs @@ -11,7 +11,12 @@ pub(crate) const TYPE_SRV: u16 = 33; pub(crate) const TYPE_SVCB: u16 = 64; pub(crate) const TYPE_HTTPS: u16 = 65; pub(crate) const TYPE_CAA: u16 = 257; +pub(crate) const TYPE_OPT: u16 = 41; pub(crate) const CLASS_IN: u16 = 1; +pub(crate) const EDNS_UDP_PAYLOAD_SIZE: u16 = 1232; + +const TRUNCATED_RESPONSE: &str = "DNS response was truncated"; +const FLAG_TRUNCATED: u16 = 0x0200; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct WireError(String); @@ -24,6 +29,12 @@ impl fmt::Display for WireError { impl std::error::Error for WireError {} +impl WireError { + pub(crate) fn is_truncated(&self) -> bool { + self.0 == TRUNCATED_RESPONSE + } +} + #[derive(Debug, Clone, Copy)] pub(crate) struct ResourceRecord<'a> { pub(crate) typ: u16, @@ -40,10 +51,11 @@ pub(crate) fn build_query(id: u16, host: &str, dns_type: u16) -> Result, raw.extend_from_slice(&1u16.to_be_bytes()); raw.extend_from_slice(&0u16.to_be_bytes()); raw.extend_from_slice(&0u16.to_be_bytes()); - raw.extend_from_slice(&0u16.to_be_bytes()); + raw.extend_from_slice(&1u16.to_be_bytes()); write_name(&mut raw, host)?; raw.extend_from_slice(&dns_type.to_be_bytes()); raw.extend_from_slice(&CLASS_IN.to_be_bytes()); + write_opt_record(&mut raw); Ok(raw) } @@ -59,6 +71,9 @@ pub(crate) fn parse_response<'a>( return Err(WireError("mismatched DNS response ID".to_string())); } let flags = read_u16(raw, 2)?; + if flags & FLAG_TRUNCATED != 0 { + return Err(WireError(TRUNCATED_RESPONSE.to_string())); + } let rcode = i32::from(flags & 0x000f); if rcode != 0 { let name = rcode_name(rcode); @@ -67,9 +82,6 @@ pub(crate) fn parse_response<'a>( } return Err(WireError(format!("no such host: {name}"))); } - if flags & 0x0200 != 0 { - return Err(WireError("DNS response was truncated".to_string())); - } let question_count = usize::from(read_u16(raw, 4)?); let answer_count = usize::from(read_u16(raw, 6)?); @@ -195,6 +207,14 @@ fn write_name(raw: &mut Vec, host: &str) -> Result<(), WireError> { Ok(()) } +fn write_opt_record(raw: &mut Vec) { + raw.push(0); + raw.extend_from_slice(&TYPE_OPT.to_be_bytes()); + raw.extend_from_slice(&EDNS_UDP_PAYLOAD_SIZE.to_be_bytes()); + raw.extend_from_slice(&0u32.to_be_bytes()); + raw.extend_from_slice(&0u16.to_be_bytes()); +} + fn rcode_name(status: i32) -> &'static str { match status { 1 => "FormatError", @@ -205,3 +225,25 @@ fn rcode_name(status: i32) -> &'static str { _ => "", } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_query_advertises_edns0_udp_payload_size() { + let query = build_query(0x1234, "example.com", TYPE_A).unwrap(); + + assert_eq!(read_u16(&query, 4).unwrap(), 1); + assert_eq!(read_u16(&query, 10).unwrap(), 1); + + let (_, question_end) = read_name(&query, 12).unwrap(); + let opt = question_end + 4; + assert_eq!(query[opt], 0); + assert_eq!(read_u16(&query, opt + 1).unwrap(), TYPE_OPT); + assert_eq!(read_u16(&query, opt + 3).unwrap(), EDNS_UDP_PAYLOAD_SIZE); + assert_eq!(read_u32(&query, opt + 5).unwrap(), 0); + assert_eq!(read_u16(&query, opt + 9).unwrap(), 0); + assert_eq!(opt + 11, query.len()); + } +}