Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ procfs = "0.15.1"
criterion = "0.5.1"
kuchikiki = "0.8"
http02 = { version = "0.2.11", package = "http"}
http-body-util = "0.1.0"
rand = "0.8"
mockito = "1.0.2"
test-case = "3.0.0"
reqwest = { version = "0.12", features = ["blocking", "json"] }
tower = { version = "0.5.1", features = ["util"] }
aws-smithy-types = "1.0.1"
aws-smithy-runtime = {version = "1.0.1", features = ["client", "test-util"]}
aws-smithy-http = "0.60.0"
Expand Down
4 changes: 4 additions & 0 deletions src/bin/cratesfyi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,10 @@ impl Context for BinContext {
};
}

async fn async_pool(&self) -> Result<Pool> {
self.pool()
}

fn pool(&self) -> Result<Pool> {
Ok(self
.pool
Expand Down
1 change: 1 addition & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub trait Context {
async fn async_storage(&self) -> Result<Arc<AsyncStorage>>;
async fn cdn(&self) -> Result<Arc<CdnBackend>>;
fn pool(&self) -> Result<Pool>;
async fn async_pool(&self) -> Result<Pool>;
fn service_metrics(&self) -> Result<Arc<ServiceMetrics>>;
fn instance_metrics(&self) -> Result<Arc<InstanceMetrics>>;
fn index(&self) -> Result<Arc<Index>>;
Expand Down
112 changes: 108 additions & 4 deletions src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use crate::{
ServiceMetrics,
};
use anyhow::Context as _;
use axum::async_trait;
use axum::{async_trait, body::Body, http::Request, response::Response as AxumResponse, Router};
use fn_error_context::context;
use futures_util::{stream::TryStreamExt, FutureExt};
use http_body_util::BodyExt; // for `collect`
use once_cell::sync::OnceCell;
use reqwest::{
blocking::{Client, ClientBuilder, RequestBuilder, Response},
Expand All @@ -27,6 +28,7 @@ use std::{
};
use tokio::runtime::{Builder, Runtime};
use tokio::sync::oneshot::Sender;
use tower::ServiceExt;
use tracing::{debug, error, instrument, trace};

#[track_caller]
Expand Down Expand Up @@ -126,7 +128,6 @@ pub(crate) fn assert_success(path: &str, web: &TestFrontend) -> Result<()> {
assert!(status.is_success(), "failed to GET {path}: {status}");
Ok(())
}

/// Make sure that a URL returns a status code between 200-299,
/// also check the cache-control headers.
pub(crate) fn assert_success_cached(
Expand Down Expand Up @@ -259,6 +260,96 @@ pub(crate) fn assert_redirect_cached(
Ok(redirect_response)
}

pub(crate) trait AxumResponseTestExt {
async fn text(self) -> String;
}

impl AxumResponseTestExt for axum::response::Response {
async fn text(self) -> String {
String::from_utf8_lossy(&self.into_body().collect().await.unwrap().to_bytes()).to_string()
}
}

pub(crate) trait AxumRouterTestExt {
async fn assert_success(&self, path: &str) -> Result<()>;
async fn get(&self, path: &str) -> Result<AxumResponse>;
async fn assert_redirect_common(
&self,
path: &str,
expected_target: &str,
) -> Result<AxumResponse>;
async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result<AxumResponse>;
}

impl AxumRouterTestExt for axum::Router {
/// Make sure that a URL returns a status code between 200-299
async fn assert_success(&self, path: &str) -> Result<()> {
let response = self
.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await?;

let status = response.status();
assert!(status.is_success(), "failed to GET {path}: {status}");
Ok(())
}
/// simple `get` method
async fn get(&self, path: &str) -> Result<AxumResponse> {
Ok(self
.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await?)
}

async fn assert_redirect_common(
&self,
path: &str,
expected_target: &str,
) -> Result<AxumResponse> {
let response = self.get(path).await?;
let status = response.status();
if !status.is_redirection() {
anyhow::bail!("non-redirect from GET {path}: {status}");
}

let redirect_target = response
.headers()
.get("Location")
.context("missing 'Location' header")?
.to_str()
.context("non-ASCII redirect")?;

// FIXME: not sure we need this
// if !expected_target.starts_with("http") {
// // TODO: Should be able to use Url::make_relative,
// // but https://github.com/servo/rust-url/issues/766
// let base = format!("http://{}", web.server_addr());
// redirect_target = redirect_target
// .strip_prefix(&base)
// .unwrap_or(redirect_target);
// }

if redirect_target != expected_target {
anyhow::bail!("got redirect to {redirect_target}");
}

Ok(response)
}

#[context("expected redirect from {path} to {expected_target}")]
async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result<AxumResponse> {
let redirect_response = self.assert_redirect_common(path, expected_target).await?;

let response = self.get(expected_target).await?;
let status = response.status();
if !status.is_success() {
anyhow::bail!("failed to GET {expected_target}: {status}");
}

Ok(redirect_response)
}
}

pub(crate) struct TestEnvironment {
build_queue: OnceCell<Arc<BuildQueue>>,
async_build_queue: tokio::sync::OnceCell<Arc<AsyncBuildQueue>>,
Expand Down Expand Up @@ -534,6 +625,13 @@ impl TestEnvironment {
self.runtime().block_on(self.async_fake_release())
}

pub(crate) async fn web_app(&self) -> Router {
let template_data = Arc::new(TemplateData::new(1).unwrap());
build_axum_app(self, template_data)
.await
.expect("could not build axum app")
}

pub(crate) async fn async_fake_release(&self) -> fakes::FakeRelease {
fakes::FakeRelease::new(
self.async_db().await,
Expand Down Expand Up @@ -569,6 +667,10 @@ impl Context for TestEnvironment {
Ok(TestEnvironment::cdn(self).await)
}

async fn async_pool(&self) -> Result<Pool> {
Ok(self.async_db().await.pool())
}

fn pool(&self) -> Result<Pool> {
Ok(self.db().pool())
}
Expand Down Expand Up @@ -734,10 +836,12 @@ impl TestFrontend {
let (tx, rx) = tokio::sync::oneshot::channel::<()>();

debug!("building axum app");
let axum_app = build_axum_app(context, template_data).expect("could not build axum app");
let runtime = context.runtime().unwrap();
let axum_app = runtime
.block_on(build_axum_app(context, template_data))
.expect("could not build axum app");

let handle = thread::spawn({
let runtime = context.runtime().unwrap();
move || {
runtime.block_on(async {
axum::serve(axum_listener, axum_app.into_make_service())
Expand Down
27 changes: 15 additions & 12 deletions src/web/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,16 +393,16 @@ async fn set_sentry_transaction_name_from_axum_route(
next.run(request).await
}

fn apply_middleware(
async fn apply_middleware(
router: AxumRouter,
context: &dyn Context,
template_data: Option<Arc<TemplateData>>,
) -> Result<AxumRouter> {
let config = context.config()?;
let has_templates = template_data.is_some();
let runtime = context.runtime()?;
let async_storage = runtime.block_on(context.async_storage())?;
let build_queue = runtime.block_on(context.async_build_queue())?;

let async_storage = context.async_storage().await?;
let build_queue = context.async_build_queue().await?;

Ok(router.layer(
ServiceBuilder::new()
Expand All @@ -419,12 +419,11 @@ fn apply_middleware(
.then_some(middleware::from_fn(log_timeouts_to_sentry)),
))
.layer(option_layer(config.request_timeout.map(TimeoutLayer::new)))
.layer(Extension(context.pool()?))
.layer(Extension(context.async_pool().await?))
.layer(Extension(build_queue))
.layer(Extension(context.service_metrics()?))
.layer(Extension(context.instance_metrics()?))
.layer(Extension(context.config()?))
.layer(Extension(context.storage()?))
.layer(Extension(async_storage))
.layer(option_layer(template_data.map(Extension)))
.layer(middleware::from_fn(csp::csp_middleware))
Expand All @@ -435,15 +434,15 @@ fn apply_middleware(
))
}

pub(crate) fn build_axum_app(
pub(crate) async fn build_axum_app(
context: &dyn Context,
template_data: Arc<TemplateData>,
) -> Result<AxumRouter, Error> {
apply_middleware(routes::build_axum_routes(), context, Some(template_data))
apply_middleware(routes::build_axum_routes(), context, Some(template_data)).await
}

pub(crate) fn build_metrics_axum_app(context: &dyn Context) -> Result<AxumRouter, Error> {
apply_middleware(routes::build_metric_routes(), context, None)
pub(crate) async fn build_metrics_axum_app(context: &dyn Context) -> Result<AxumRouter, Error> {
apply_middleware(routes::build_metric_routes(), context, None).await
}

pub fn start_background_metrics_webserver(
Expand All @@ -458,8 +457,10 @@ pub fn start_background_metrics_webserver(
axum_addr.port()
);

let metrics_axum_app = build_metrics_axum_app(context)?.into_make_service();
let runtime = context.runtime()?;
let metrics_axum_app = runtime
.block_on(build_metrics_axum_app(context))?
.into_make_service();

runtime.spawn(async move {
match tokio::net::TcpListener::bind(axum_addr)
Expand Down Expand Up @@ -501,8 +502,10 @@ pub fn start_web_server(addr: Option<SocketAddr>, context: &dyn Context) -> Resu
context.storage()?;
context.repository_stats_updater()?;

let app = build_axum_app(context, template_data)?.into_make_service();
context.runtime()?.block_on(async {
let app = build_axum_app(context, template_data)
.await?
.into_make_service();
let listener = tokio::net::TcpListener::bind(axum_addr)
.await
.context("error binding socket for metrics web server")?;
Expand Down
Loading
Loading