diff --git a/Cargo.toml b/Cargo.toml index bcb5e73..e81ac0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ edition = "2024" license = "Apache-2.0" rust-version = "1.87.0" -version = "0.15.0" +version = "0.16.0" [workspace.dependencies] anyhow = "1" @@ -43,6 +43,7 @@ futures = "0.3" futures-util = "0.3" hickory-resolver = "0.25.1" html-escape = "0.2.13" +httpdate = "1" humantime = "2" indicatif = "0.18.0" indicatif-log-bridge = "0.2.1" @@ -55,6 +56,7 @@ parking_lot = "0.12" pem = "3" percent-encoding = "2.3" reqwest = "0.12" +rstest = "0.23" sectxtlib = "0.3.1" sequoia-openpgp = { version = "2", default-features = false } serde = "1" @@ -73,10 +75,10 @@ walkdir = "2.4" # internal dependencies -csaf-walker = { version = "0.15.0", path = "csaf", default-features = false } -sbom-walker = { version = "0.15.0", path = "sbom", default-features = false } -walker-common = { version = "0.15.0", path = "common" } -walker-extras = { version = "0.15.0", path = "extras" } +csaf-walker = { version = "0.16.0", path = "csaf", default-features = false } +sbom-walker = { version = "0.16.0", path = "sbom", default-features = false } +walker-common = { version = "0.16.0", path = "common" } +walker-extras = { version = "0.16.0", path = "extras" } [workspace.metadata.release] tag = false diff --git a/common/Cargo.toml b/common/Cargo.toml index 732c379..c795965 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -28,6 +28,7 @@ fluent-uri = { workspace = true } fsquirrel = { workspace = true } futures-util = { workspace = true } html-escape = { workspace = true } +httpdate = { workspace = true } humantime = { workspace = true } indicatif = { workspace = true } indicatif-log-bridge = { workspace = true } @@ -86,6 +87,12 @@ denylist = [ "_semver", ] +[dev-dependencies] +hyper = { version = "1", features = ["server", "http1"] } +hyper-util = { version = "0.1", features = ["tokio"] } +rstest = { workspace = true } +tokio = { workspace = true, features = ["macros", 
"rt-multi-thread", "net"] } + [package.metadata.release] enable-features = ["sequoia-openpgp/crypto-nettle"] tag = true diff --git a/common/src/cli/client.rs b/common/src/cli/client.rs index 1b39a91..ab8ed43 100644 --- a/common/src/cli/client.rs +++ b/common/src/cli/client.rs @@ -10,14 +10,18 @@ pub struct ClientArguments { /// Per-request retries count #[arg(short, long, default_value = "5")] pub retries: usize, + + /// Per-request minimum delay after rate limit (429). + #[arg(long, default_value = "10s")] + pub default_retry_after: humantime::Duration, } impl From for FetcherOptions { fn from(value: ClientArguments) -> Self { - FetcherOptions { - timeout: value.timeout.into(), - retries: value.retries, - } + FetcherOptions::new() + .timeout(value.timeout) + .retries(value.retries) + .retry_after(value.default_retry_after.into()) } } diff --git a/common/src/fetcher/mod.rs b/common/src/fetcher/mod.rs index e2c73ce..ccb3666 100644 --- a/common/src/fetcher/mod.rs +++ b/common/src/fetcher/mod.rs @@ -4,6 +4,7 @@ mod data; use backon::{ExponentialBuilder, Retryable}; pub use data::*; +use crate::http::calculate_retry_after_from_response_header; use reqwest::{Client, ClientBuilder, IntoUrl, Method, Response}; use std::fmt::Debug; use std::future::Future; @@ -19,6 +20,8 @@ use url::Url; pub struct Fetcher { client: Client, retries: usize, + /// *default_retry_after* is used when a 429 response does not include a Retry-After header + default_retry_after: Duration, } /// Error when retrieving @@ -26,14 +29,18 @@ pub struct Fetcher { pub enum Error { #[error("Request error: {0}")] Request(#[from] reqwest::Error), + #[error("Rate limited (HTTP 429), retry after {0:?}")] + RateLimited(Duration), } /// Options for the [`Fetcher`] #[non_exhaustive] #[derive(Clone, Debug)] pub struct FetcherOptions { - pub timeout: Duration, - pub retries: usize, + timeout: Duration, + retries: usize, + default_retry_after: Duration, + max_retry_after: Duration, } impl FetcherOptions { @@ -53,6 
+60,26 @@ impl FetcherOptions {
         self.retries = retries;
         self
     }
+
+    /// Set the default retry-after duration when a 429 response doesn't include a Retry-After header.
+    ///
+    /// Panics if `duration` is greater than the currently configured maximum retry-after.
+    pub fn retry_after(mut self, duration: Duration) -> Self {
+        let max = self.max_retry_after;
+        assert!(duration <= max, "Default retry-after cannot exceed max retry-after ({max:?})");
+        self.default_retry_after = duration;
+        self
+    }
+
+    /// Set the default retry-after duration (see [`Self::retry_after`]) together with its maximum.
+    ///
+    /// Panics if `default` is greater than `max`.
+    pub fn retry_after_with_max(mut self, default: Duration, max: Duration) -> Self {
+        assert!(default <= max, "Default retry-after cannot be greater than max retry-after");
+        self.default_retry_after = default;
+        self.max_retry_after = max;
+        self
+    }
 }
 
 impl Default for FetcherOptions {
@@ -60,6 +87,8 @@ impl Default for FetcherOptions {
         Self {
             timeout: Duration::from_secs(30),
             retries: 5,
+            default_retry_after: Duration::from_secs(10),
+            max_retry_after: Duration::from_secs(300), // 5 min (`from_mins` is newer than the MSRV)
         }
     }
 }
@@ -83,6 +112,7 @@ impl Fetcher {
         Self {
             client,
             retries: options.retries,
+            default_retry_after: options.default_retry_after,
         }
     }
 
@@ -110,19 +140,23 @@ impl Fetcher {
         let url = url.into_url()?;
 
         let retries = self.retries;
-        let backoff = ExponentialBuilder::default();
-
-        (|| async {
-            match self.fetch_once(url.clone(), &processor).await {
-                Ok(result) => Ok(result),
-                Err(err) => {
-                    log::info!("Failed to retrieve: {err}");
-                    Err(err)
+        let retry = ExponentialBuilder::default().with_max_times(retries);
+
+        (|| async { self.fetch_once(url.clone(), &processor).await })
+            .retry(retry)
+            .notify(|err, dur| log::info!("Failed to retrieve: {err}, retrying in {dur:?}"))
+            .adjust(|e, dur| {
+                // `None` means the retries are exhausted; keep it that way, otherwise a
+                // server that keeps answering 429 would make us retry forever.
+                let backoff = dur?;
+                if let Error::RateLimited(retry_after) = e {
+                    // only use the server-provided delay if it's longer
+                    Some(backoff.max(*retry_after))
+                } else {
+                    Some(backoff) // minimum delay as per backoff strategy
                 }
-            }
-        })
-        .retry(&backoff.with_max_times(retries))
-        .await
+
})
+            .await
     }
 
     async fn fetch_once(
@@ -134,6 +168,14 @@ impl Fetcher {
 
         log::debug!("Response Status: {}", response.status());
 
+        // Check for rate limiting
+        if let Some(retry_after) =
+            calculate_retry_after_from_response_header(&response, self.default_retry_after)
+        {
+            log::info!("Rate limited (429), retry after: {:?}", retry_after);
+            return Err(Error::RateLimited(retry_after));
+        }
+
         Ok(processor.process(response).await?)
     }
 }
diff --git a/common/src/http.rs b/common/src/http.rs
new file mode 100644
index 0000000..50a5c39
--- /dev/null
+++ b/common/src/http.rs
@@ -0,0 +1,59 @@
+use std::time::Duration;
+
+use reqwest::{Response, StatusCode, header};
+
+/// A parsed `Retry-After` header value.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum RetryAfter {
+    /// Relative delay, from the delay-seconds form.
+    Duration(Duration),
+    /// Absolute point in time, from the HTTP-date form.
+    After(std::time::SystemTime),
+}
+
+/// Parse Retry-After header value.
+/// Supports both delay-seconds (numeric) and HTTP-date formats as per RFC7231.
+/// Returns `None` when the value matches neither format.
+fn parse_retry_after(value: &str) -> Option<RetryAfter> {
+    // Try parsing as seconds (numeric)
+    if let Ok(seconds) = value.parse::<u64>() {
+        return Some(RetryAfter::Duration(Duration::from_secs(seconds)));
+    }
+
+    // Try parsing as HTTP-date (RFC7231 format)
+    // Common formats: "Sun, 06 Nov 1994 08:49:37 GMT" (IMF-fixdate preferred)
+    if let Ok(datetime) = httpdate::parse_http_date(value) {
+        return Some(RetryAfter::After(datetime));
+    }
+
+    None
+}
+
+/// Determine how long to wait before retrying a rate-limited request.
+///
+/// Returns `None` unless the response status is 429 (Too Many Requests). For a 429, returns
+/// the delay requested by the `Retry-After` header, or `default_duration` when the header is
+/// missing, cannot be parsed, or names a date that is already in the past.
+pub fn calculate_retry_after_from_response_header(
+    response: &Response,
+    default_duration: Duration,
+) -> Option<Duration> {
+    if response.status() == StatusCode::TOO_MANY_REQUESTS {
+        let retry_after = response
+            .headers()
+            .get(header::RETRY_AFTER)
+            .and_then(|v| v.to_str().ok())
+            .and_then(parse_retry_after)
+            .and_then(|retry| match retry {
+                RetryAfter::Duration(d) => Some(d),
+                // Time left until the given date; `None` (and thus the default)
+                // when that date is already in the past.
+                RetryAfter::After(after) => {
+                    after.duration_since(std::time::SystemTime::now()).ok()
+                }
+            })
+
.unwrap_or(default_duration); + return Some(retry_after); + } + None +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 9c11922..c561cc8 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -4,6 +4,7 @@ pub mod changes; pub mod compression; pub mod fetcher; +pub mod http; pub mod locale; pub mod progress; pub mod report; diff --git a/common/tests/fetcher.rs b/common/tests/fetcher.rs new file mode 100644 index 0000000..f4bcb70 --- /dev/null +++ b/common/tests/fetcher.rs @@ -0,0 +1,258 @@ +use reqwest::StatusCode; +use rstest::rstest; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; +use tokio::net::TcpListener; +use walker_common::fetcher::{Fetcher, FetcherOptions}; + +/// Test helper to start a mock HTTP server +async fn start_mock_server(handler: F) -> String +where + F: Fn(hyper::Request) -> hyper::Response + Send + Sync + 'static, +{ + use hyper::service::service_fn; + use hyper_util::rt::TokioIo; + use std::convert::Infallible; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let handler = Arc::new(handler); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + let handler = handler.clone(); + + tokio::spawn(async move { + let service = service_fn(move |req| { + let handler = handler.clone(); + async move { Ok::<_, Infallible>(handler(req)) } + }); + + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {:?}", err); + } + }); + } + }); + + format!("http://{}", addr) +} + +#[tokio::test] +async fn test_successful_fetch() { + let server = start_mock_server(|_req| { + hyper::Response::builder() + .status(StatusCode::OK) + .body("Hello, World!".to_string()) + .unwrap() + }) + .await; + + let fetcher = Fetcher::new(FetcherOptions::new()).await.unwrap(); + let result: String = 
fetcher.fetch(&server).await.unwrap(); + + assert_eq!(result, "Hello, World!"); +} + +#[rstest] +#[case::with_retry_after_header(Some("1"), 1)] +#[case::without_retry_after_header(None, 10)] +#[tokio::test] +async fn test_rate_limit_retry_after( + #[case] retry_after_header: Option<&str>, + #[case] expected_min_wait_secs: u64, +) { + let attempt_count = Arc::new(AtomicUsize::new(0)); + let attempt_count_clone = attempt_count.clone(); + let retry_after_header = retry_after_header.map(String::from); + + let server = start_mock_server(move |_req| { + let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst); + + // First request returns 429 + if count == 0 { + let mut builder = hyper::Response::builder().status(StatusCode::TOO_MANY_REQUESTS); + + if let Some(ref header_value) = retry_after_header { + builder = builder.header("Retry-After", header_value.as_str()); + } + + builder.body("Rate limited".to_string()).unwrap() + } else { + // Subsequent requests succeed + hyper::Response::builder() + .status(StatusCode::OK) + .body("Success after retry".to_string()) + .unwrap() + } + }) + .await; + + let fetcher = Fetcher::new(FetcherOptions::new().retries(3)) + .await + .unwrap(); + + let start = std::time::Instant::now(); + let result: String = fetcher.fetch(&server).await.unwrap(); + let elapsed = start.elapsed(); + + assert_eq!(result, "Success after retry"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 2); + + // Should have waited at least the expected duration + assert!( + elapsed >= Duration::from_secs(expected_min_wait_secs), + "Expected at least {}s wait, got {:?}", + expected_min_wait_secs, + elapsed + ); +} + +#[rstest] +#[case::succeeds_after_retries(2, 5, true, 3)] +#[case::exhausts_retries(usize::MAX, 2, false, 3)] +#[tokio::test] +async fn test_retry_behavior( + #[case] fail_until: usize, + #[case] max_retries: usize, + #[case] should_succeed: bool, + #[case] expected_attempts: usize, +) { + let attempt_count = Arc::new(AtomicUsize::new(0)); + let 
attempt_count_clone = attempt_count.clone(); + + let server = start_mock_server(move |_req| { + let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst); + + if count < fail_until { + hyper::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body("Server error".to_string()) + .unwrap() + } else { + hyper::Response::builder() + .status(StatusCode::OK) + .body("Success".to_string()) + .unwrap() + } + }) + .await; + + let fetcher = Fetcher::new(FetcherOptions::new().retries(max_retries)) + .await + .unwrap(); + + let result: Result = fetcher.fetch(&server).await; + + assert_eq!(result.is_ok(), should_succeed); + if should_succeed { + assert_eq!(result.unwrap(), "Success"); + } + assert_eq!(attempt_count.load(Ordering::SeqCst), expected_attempts); +} + +#[tokio::test] +async fn test_multiple_rate_limits() { + let attempt_count = Arc::new(AtomicUsize::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let server = start_mock_server(move |_req| { + let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst); + + // Return 429 for first two attempts + if count < 2 { + hyper::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .header("Retry-After", "1") + .body("Rate limited".to_string()) + .unwrap() + } else { + hyper::Response::builder() + .status(StatusCode::OK) + .body("Success".to_string()) + .unwrap() + } + }) + .await; + + let fetcher = Fetcher::new(FetcherOptions::new().retries(5)) + .await + .unwrap(); + + let start = std::time::Instant::now(); + let result: String = fetcher.fetch(&server).await.unwrap(); + let elapsed = start.elapsed(); + + assert_eq!(result, "Success"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 3); + + // Should have waited at least 2 seconds (1 second for each 429) + assert!( + elapsed >= Duration::from_secs(2), + "Expected at least 2s wait, got {:?}", + elapsed + ); +} + +#[rstest] +#[case::custom_default_2_seconds(2)] +#[case::custom_default_3_seconds(3)] +#[tokio::test] +async fn 
test_configurable_default_retry_after(#[case] custom_default_secs: u64) { + let attempt_count = Arc::new(AtomicUsize::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let server = start_mock_server(move |_req| { + let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst); + + // First request returns 429 without Retry-After header + if count == 0 { + hyper::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body("Rate limited".to_string()) + .unwrap() + } else { + hyper::Response::builder() + .status(StatusCode::OK) + .body("Success".to_string()) + .unwrap() + } + }) + .await; + + let fetcher = Fetcher::new( + FetcherOptions::new() + .retries(3) + .retry_after(Duration::from_secs(custom_default_secs)), + ) + .await + .unwrap(); + + let start = std::time::Instant::now(); + let result: String = fetcher.fetch(&server).await.unwrap(); + let elapsed = start.elapsed(); + + assert_eq!(result, "Success"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 2); + + // Should have waited at least the custom default + assert!( + elapsed >= Duration::from_secs(custom_default_secs), + "Expected at least {}s wait (custom default), got {:?}", + custom_default_secs, + elapsed + ); + + // Should not have waited 10 seconds (the standard default) + assert!( + elapsed < Duration::from_secs(10), + "Expected less than 10s, got {:?}", + elapsed + ); +} diff --git a/extras/src/visitors/send/clap.rs b/extras/src/visitors/send/clap.rs index a170139..fb6f606 100644 --- a/extras/src/visitors/send/clap.rs +++ b/extras/src/visitors/send/clap.rs @@ -107,12 +107,9 @@ impl SendArguments { ) .await?; - Ok(SendVisitor { - url: target, - sender, - retries, - min_delay: Some(min_delay.into()), - max_delay: Some(max_delay.into()), - }) + Ok(SendVisitor::new(target, sender) + .retries(retries) + .min_delay(min_delay) + .max_delay(max_delay)) } } diff --git a/extras/src/visitors/send/mod.rs b/extras/src/visitors/send/mod.rs index e72bf05..4049959 100644 --- 
a/extras/src/visitors/send/mod.rs +++ b/extras/src/visitors/send/mod.rs @@ -2,7 +2,10 @@ use backon::{ExponentialBuilder, Retryable}; use bytes::Bytes; use reqwest::{Body, Method, StatusCode, Url, header}; use std::time::Duration; -use walker_common::sender::{self, HttpSender}; +use walker_common::{ + http::calculate_retry_after_from_response_header, + sender::{self, HttpSender}, +}; #[cfg(feature = "sbom-walker")] mod sbom; @@ -31,6 +34,8 @@ pub enum SendError { Server(StatusCode), #[error("unexpected status: {0}")] UnexpectedStatus(StatusCode), + #[error("Rate limited (HTTP 429), retry after {0:?}")] + RateLimited(Duration), } /// Send data to a remote sink. @@ -44,13 +49,16 @@ pub struct SendVisitor { pub sender: HttpSender, /// The number of retries in case of a server or transmission failure - pub retries: usize, + retries: usize, - /// The minimum delay between retries - pub min_delay: Option, + /// The minimum delay between retries, will be overruled by the retry-after header if present. + min_delay: Option, - /// The maximum delay between retries - pub max_delay: Option, + /// The maximum delay between retries, will be overruled by the retry-after header if present. 
+    max_delay: Option<Duration>,
+
+    /// The default retry-after duration when a 429 response doesn't include a Retry-After header
+    default_retry_after: Duration,
 }
 
 impl SendVisitor {
@@ -61,6 +69,7 @@ impl SendVisitor {
             retries: 0,
             min_delay: None,
             max_delay: None,
+            default_retry_after: Duration::from_secs(10),
         }
     }
 
@@ -120,6 +129,18 @@ impl SendVisitor {
             .await
             .map_err(|err| SendOnceError::Temporary(err.into()))?;
 
+        if let Some(retry_after) =
+            calculate_retry_after_from_response_header(&response, self.default_retry_after)
+        {
+            log::info!(
+                "Rate limited (429) when uploading {name}, retry after: {:?}",
+                retry_after
+            );
+            return Err(SendOnceError::Temporary(SendError::RateLimited(
+                retry_after,
+            )));
+        }
+
         let status = response.status();
 
         if status.is_success() {
@@ -157,10 +178,21 @@ impl SendVisitor {
         Ok(
             (|| async { self.send_once(name, data.clone(), &customizer).await })
                 .retry(retry)
-                .sleep(tokio::time::sleep)
                 .when(|e| matches!(e, SendOnceError::Temporary(_)))
+                .adjust(|e, dur| {
+                    // `None` means the retries are exhausted; keep it that way, otherwise a
+                    // sink that keeps answering 429 would make us retry forever.
+                    let backoff = dur?;
+                    if let SendOnceError::Temporary(SendError::RateLimited(retry_after)) = e {
+                        // only use the server-provided delay if it's longer
+                        Some(backoff.max(*retry_after))
+                    } else {
+                        // minimum delay as per backoff strategy
+                        Some(backoff)
+                    }
+                })
                 .notify(|err, dur| {
                     log::info!("retrying {err} after {dur:?}");
                 })
                 .await?,
         )