diff --git a/Cargo.lock b/Cargo.lock index 2a616f0..e054c15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2011,7 +2011,6 @@ dependencies = [ "flate2", "futures-util", "hex", - "http", "liblzma", "md-5", "num-format", diff --git a/Cargo.toml b/Cargo.toml index 030fc6c..a9db352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,6 @@ flate2 = "1.0.35" futures-util = "0.3.30" hex = "0.4.3" home = "0.5.11" -http = "1.1.0" indicatif = "0.17.8" indoc = "2.0.5" liblzma = "0.3.4" diff --git a/postgresql_archive/Cargo.toml b/postgresql_archive/Cargo.toml index f6cbdf6..ef903d9 100644 --- a/postgresql_archive/Cargo.toml +++ b/postgresql_archive/Cargo.toml @@ -15,7 +15,6 @@ async-trait = { workspace = true } flate2 = { workspace = true } futures-util = { workspace = true } hex = { workspace = true } -http = { workspace = true } liblzma = { workspace = true } md-5 = { workspace = true, optional = true } num-format = { workspace = true } diff --git a/postgresql_archive/src/repository/github/repository.rs b/postgresql_archive/src/repository/github/repository.rs index c1eb49e..7c43389 100644 --- a/postgresql_archive/src/repository/github/repository.rs +++ b/postgresql_archive/src/repository/github/repository.rs @@ -8,10 +8,9 @@ use crate::Error::{ use crate::{hasher, matcher, Result}; use async_trait::async_trait; use futures_util::StreamExt; -use http::{header, Extensions}; use regex::Regex; -use reqwest::{Request, Response}; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next}; +use reqwest::header::HeaderMap; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::RetryTransientMiddleware; use reqwest_tracing::TracingMiddleware; @@ -114,6 +113,7 @@ impl GitHub { loop { let request = client .get(&self.releases_url) + .headers(Self::headers()) .query(&[("page", page.to_string().as_str()), ("per_page", "100")]); let response = request.send().await?.error_for_status()?; let response_releases = response.json::>().await?; @@ -199,6 +199,20 @@ impl GitHub { Ok((asset, asset_hash, asset_hasher_fn)) } + + /// Returns the headers for the GitHub request. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.append( + GITHUB_API_VERSION_HEADER, + GITHUB_API_VERSION.parse().unwrap(), + ); + headers.append("User-Agent", USER_AGENT.parse().unwrap()); + if let Some(token) = &*GITHUB_TOKEN { + headers.append("Authorization", format!("Bearer {token}").parse().unwrap()); + } + headers + } } #[async_trait] @@ -224,7 +238,9 @@ impl Repository for GitHub { let client = reqwest_client(); debug!("Downloading archive {}", asset.browser_download_url); - let request = client.get(&asset.browser_download_url); + let request = client + .get(&asset.browser_download_url) + .headers(Self::headers()); let response = request.send().await?.error_for_status()?; #[cfg(feature = "indicatif")] let span = tracing::Span::current(); @@ -257,7 +273,9 @@ impl Repository for GitHub { "Downloading archive hash {}", asset_hash.browser_download_url ); - let request = client.get(&asset_hash.browser_download_url); + let request = client + .get(&asset_hash.browser_download_url) + .headers(Self::headers()); let response = request.send().await?.error_for_status()?; let text = response.text().await?; let re = Regex::new(&format!(r"[0-9a-f]{{{hash_len}}}"))?; @@ -281,51 +299,11 @@ impl Repository for GitHub { } } -/// Middleware to add headers to the request. If a GitHub token is set, then it is added as a -/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit. -#[derive(Debug)] -struct GithubMiddleware; - -impl GithubMiddleware { - #[expect(clippy::unnecessary_wraps)] - fn add_headers(request: &mut Request) -> Result<()> { - let headers = request.headers_mut(); - headers.append( - GITHUB_API_VERSION_HEADER, - GITHUB_API_VERSION.parse().unwrap(), - ); - headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap()); - if let Some(token) = &*GITHUB_TOKEN { - headers.append( - header::AUTHORIZATION, - format!("Bearer {token}").parse().unwrap(), - ); - } - Ok(()) - } -} - -#[async_trait::async_trait] -impl Middleware for GithubMiddleware { - async fn handle( - &self, - mut request: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> reqwest_middleware::Result { - match GithubMiddleware::add_headers(&mut request) { - Ok(()) => next.run(request, extensions).await, - Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())), - } - } -} - -/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors. +/// Creates a new reqwest client with middleware for tracing, and retrying transient errors. fn reqwest_client() -> ClientWithMiddleware { let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); ClientBuilder::new(reqwest::Client::new()) .with(TracingMiddleware::default()) - .with(GithubMiddleware) .with(RetryTransientMiddleware::new_with_policy(retry_policy)) .build() } diff --git a/postgresql_archive/src/repository/maven/repository.rs b/postgresql_archive/src/repository/maven/repository.rs index 83f7ec6..0d7d920 100644 --- a/postgresql_archive/src/repository/maven/repository.rs +++ b/postgresql_archive/src/repository/maven/repository.rs @@ -5,9 +5,8 @@ use crate::Error::{ArchiveHashMismatch, ParseError, RepositoryFailure, VersionNo use crate::{hasher, Result}; use async_trait::async_trait; use futures_util::StreamExt; -use http::{header, Extensions}; -use reqwest::{Request, Response}; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next}; +use reqwest::header::HeaderMap; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::RetryTransientMiddleware; use reqwest_tracing::TracingMiddleware; @@ -58,7 +57,7 @@ impl Maven { debug!("Attempting to locate release for version requirement {version_req}"); let client = reqwest_client(); let url = format!("{}/maven-metadata.xml", self.url); - let request = client.get(&url); + let request = client.get(&url).headers(Self::headers()); let response = request.send().await?.error_for_status()?; let text = response.text().await?; let metadata: Metadata = @@ -86,6 +85,13 @@ impl Maven { None => Err(VersionNotFound(version_req.to_string())), } } + + /// Returns the headers for the Maven request. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.append("User-Agent", USER_AGENT.parse().unwrap()); + headers + } } #[async_trait] @@ -125,13 +131,13 @@ impl Repository for Maven { let archive_hash_url = format!("{archive_url}.{extension}"); let client = reqwest_client(); debug!("Downloading archive hash {archive_hash_url}"); - let request = client.get(&archive_hash_url); + let request = client.get(&archive_hash_url).headers(Self::headers()); let response = request.send().await?.error_for_status()?; let hash = response.text().await?; debug!("Archive hash {archive_hash_url} downloaded: {}", hash.len(),); debug!("Downloading archive {archive_url}"); - let request = client.get(&archive_url); + let request = client.get(&archive_url).headers(Self::headers()); let response = request.send().await?.error_for_status()?; #[cfg(feature = "indicatif")] let span = tracing::Span::current(); @@ -159,40 +165,11 @@ impl Repository for Maven { } } -/// Middleware to add headers to the request. -#[derive(Debug)] -struct MavenMiddleware; - -impl MavenMiddleware { - #[expect(clippy::unnecessary_wraps)] - fn add_headers(request: &mut Request) -> Result<()> { - let headers = request.headers_mut(); - headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap()); - Ok(()) - } -} - -#[async_trait::async_trait] -impl Middleware for MavenMiddleware { - async fn handle( - &self, - mut request: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> reqwest_middleware::Result { - match MavenMiddleware::add_headers(&mut request) { - Ok(()) => next.run(request, extensions).await, - Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())), - } - } -} - -/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors. +/// Creates a new reqwest client with middleware for tracing, and retrying transient errors. fn reqwest_client() -> ClientWithMiddleware { let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); ClientBuilder::new(reqwest::Client::new()) .with(TracingMiddleware::default()) - .with(MavenMiddleware) .with(RetryTransientMiddleware::new_with_policy(retry_policy)) .build() }