Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ flate2 = "1.0.35"
futures-util = "0.3.30"
hex = "0.4.3"
home = "0.5.11"
http = "1.1.0"
indicatif = "0.17.8"
indoc = "2.0.5"
liblzma = "0.3.4"
Expand Down
1 change: 0 additions & 1 deletion postgresql_archive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ async-trait = { workspace = true }
flate2 = { workspace = true }
futures-util = { workspace = true }
hex = { workspace = true }
http = { workspace = true }
liblzma = { workspace = true }
md-5 = { workspace = true, optional = true }
num-format = { workspace = true }
Expand Down
70 changes: 24 additions & 46 deletions postgresql_archive/src/repository/github/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use crate::Error::{
use crate::{hasher, matcher, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use http::{header, Extensions};
use regex::Regex;
use reqwest::{Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use reqwest::header::HeaderMap;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use reqwest_tracing::TracingMiddleware;
Expand Down Expand Up @@ -114,6 +113,7 @@ impl GitHub {
loop {
let request = client
.get(&self.releases_url)
.headers(Self::headers())
.query(&[("page", page.to_string().as_str()), ("per_page", "100")]);
let response = request.send().await?.error_for_status()?;
let response_releases = response.json::<Vec<Release>>().await?;
Expand Down Expand Up @@ -199,6 +199,20 @@ impl GitHub {

Ok((asset, asset_hash, asset_hasher_fn))
}

/// Returns the headers for the GitHub request.
fn headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.append(
GITHUB_API_VERSION_HEADER,
GITHUB_API_VERSION.parse().unwrap(),
);
headers.append("User-Agent", USER_AGENT.parse().unwrap());
if let Some(token) = &*GITHUB_TOKEN {
headers.append("Authorization", format!("Bearer {token}").parse().unwrap());
}
headers
}
}

#[async_trait]
Expand All @@ -224,7 +238,9 @@ impl Repository for GitHub {

let client = reqwest_client();
debug!("Downloading archive {}", asset.browser_download_url);
let request = client.get(&asset.browser_download_url);
let request = client
.get(&asset.browser_download_url)
.headers(Self::headers());
let response = request.send().await?.error_for_status()?;
#[cfg(feature = "indicatif")]
let span = tracing::Span::current();
Expand Down Expand Up @@ -257,7 +273,9 @@ impl Repository for GitHub {
"Downloading archive hash {}",
asset_hash.browser_download_url
);
let request = client.get(&asset_hash.browser_download_url);
let request = client
.get(&asset_hash.browser_download_url)
.headers(Self::headers());
let response = request.send().await?.error_for_status()?;
let text = response.text().await?;
let re = Regex::new(&format!(r"[0-9a-f]{{{hash_len}}}"))?;
Expand All @@ -281,51 +299,11 @@ impl Repository for GitHub {
}
}

/// Middleware to add headers to the request. If a GitHub token is set, then it is added as a
/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit.
#[derive(Debug)]
struct GithubMiddleware;

impl GithubMiddleware {
#[expect(clippy::unnecessary_wraps)]
fn add_headers(request: &mut Request) -> Result<()> {
let headers = request.headers_mut();
headers.append(
GITHUB_API_VERSION_HEADER,
GITHUB_API_VERSION.parse().unwrap(),
);
headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap());
if let Some(token) = &*GITHUB_TOKEN {
headers.append(
header::AUTHORIZATION,
format!("Bearer {token}").parse().unwrap(),
);
}
Ok(())
}
}

#[async_trait::async_trait]
impl Middleware for GithubMiddleware {
async fn handle(
&self,
mut request: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
match GithubMiddleware::add_headers(&mut request) {
Ok(()) => next.run(request, extensions).await,
Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())),
}
}
}

/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors.
/// Creates a new reqwest client with middleware for tracing, and retrying transient errors.
fn reqwest_client() -> ClientWithMiddleware {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::default())
.with(GithubMiddleware)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build()
}
Expand Down
49 changes: 13 additions & 36 deletions postgresql_archive/src/repository/maven/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ use crate::Error::{ArchiveHashMismatch, ParseError, RepositoryFailure, VersionNo
use crate::{hasher, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use http::{header, Extensions};
use reqwest::{Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use reqwest::header::HeaderMap;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use reqwest_tracing::TracingMiddleware;
Expand Down Expand Up @@ -58,7 +57,7 @@ impl Maven {
debug!("Attempting to locate release for version requirement {version_req}");
let client = reqwest_client();
let url = format!("{}/maven-metadata.xml", self.url);
let request = client.get(&url);
let request = client.get(&url).headers(Self::headers());
let response = request.send().await?.error_for_status()?;
let text = response.text().await?;
let metadata: Metadata =
Expand Down Expand Up @@ -86,6 +85,13 @@ impl Maven {
None => Err(VersionNotFound(version_req.to_string())),
}
}

/// Returns the headers for the Maven request.
fn headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.append("User-Agent", USER_AGENT.parse().unwrap());
headers
}
}

#[async_trait]
Expand Down Expand Up @@ -125,13 +131,13 @@ impl Repository for Maven {
let archive_hash_url = format!("{archive_url}.{extension}");
let client = reqwest_client();
debug!("Downloading archive hash {archive_hash_url}");
let request = client.get(&archive_hash_url);
let request = client.get(&archive_hash_url).headers(Self::headers());
let response = request.send().await?.error_for_status()?;
let hash = response.text().await?;
debug!("Archive hash {archive_hash_url} downloaded: {}", hash.len(),);

debug!("Downloading archive {archive_url}");
let request = client.get(&archive_url);
let request = client.get(&archive_url).headers(Self::headers());
let response = request.send().await?.error_for_status()?;
#[cfg(feature = "indicatif")]
let span = tracing::Span::current();
Expand Down Expand Up @@ -159,40 +165,11 @@ impl Repository for Maven {
}
}

/// Middleware to add headers to the request.
#[derive(Debug)]
struct MavenMiddleware;

impl MavenMiddleware {
#[expect(clippy::unnecessary_wraps)]
fn add_headers(request: &mut Request) -> Result<()> {
let headers = request.headers_mut();
headers.append(header::USER_AGENT, USER_AGENT.parse().unwrap());
Ok(())
}
}

#[async_trait::async_trait]
impl Middleware for MavenMiddleware {
async fn handle(
&self,
mut request: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
match MavenMiddleware::add_headers(&mut request) {
Ok(()) => next.run(request, extensions).await,
Err(error) => Err(reqwest_middleware::Error::Middleware(error.into())),
}
}
}

/// Creates a new reqwest client with middleware for tracing, GitHub, and retrying transient errors.
/// Creates a new reqwest client with middleware for tracing, and retrying transient errors.
fn reqwest_client() -> ClientWithMiddleware {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::default())
.with(MavenMiddleware)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build()
}
Expand Down