From baf7cabfe1a50dfaf5d1e47501134da3bff3e9e2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 25 Oct 2021 21:46:49 +0200 Subject: [PATCH] Add test for middleware that return early (#409) * Add test for middleware that return early Turns out #408 fixed #380. * changelog --- CHANGELOG.md | 3 +++ src/tests/merge.rs | 30 ++++++++++++++++++++++++++++++ src/tests/mod.rs | 33 +++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38c333cc45..beeb15da86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,6 +123,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Use `Router::fallback` for adding fallback routes ([#408]) - **added:** `Router::fallback` for adding handlers for request that didn't match any routes ([#408]) +- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`) + now no longer catch requests that would otherwise be 404s. They also work + correctly with `Router::merge` (previously called `or`) ([#408]) [#339]: https://github.com/tokio-rs/axum/pull/339 [#286]: https://github.com/tokio-rs/axum/pull/286 diff --git a/src/tests/merge.rs b/src/tests/merge.rs index a91dcb57fd..f4ba00cde0 100644 --- a/src/tests/merge.rs +++ b/src/tests/merge.rs @@ -373,3 +373,33 @@ async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { }) ); } + +#[tokio::test] +async fn middleware_that_return_early() { + let private = Router::new() + .route("/", get(|| async {})) + .layer(RequireAuthorizationLayer::bearer("password")); + + let public = Router::new().route("/public", get(|| async {})); + + let client = TestClient::new(private.merge(public)); + + assert_eq!( + client.get("/").send().await.status(), + StatusCode::UNAUTHORIZED + ); + assert_eq!( + client + .get("/") + .header("authorization", "Bearer password") + .send() + .await + .status(), + StatusCode::OK + ); + assert_eq!( + client.get("/doesnt-exist").send().await.status(), + StatusCode::NOT_FOUND + ); + assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index cf9f3bff44..73b3669226 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -25,8 +25,8 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::timeout::TimeoutLayer; -use tower::{service_fn, ServiceBuilder}; +use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder}; +use tower_http::auth::RequireAuthorizationLayer; use tower_service::Service; pub(crate) use helpers::*; @@ -550,6 +550,35 @@ async fn middleware_applies_to_routes_above() { assert_eq!(res.status(), StatusCode::OK); } +#[tokio::test] +async fn middleware_that_return_early() { + let app = Router::new() + .route("/", get(|| async {})) + .layer(RequireAuthorizationLayer::bearer("password")) + .route("/public", get(|| async {})); + + let client = TestClient::new(app); + + assert_eq!( + client.get("/").send().await.status(), + StatusCode::UNAUTHORIZED + ); + assert_eq!( + client + .get("/") + .header("authorization", "Bearer password") + .send() + .await + .status(), + StatusCode::OK + ); + assert_eq!( + client.get("/doesnt-exist").send().await.status(), + StatusCode::NOT_FOUND + ); + assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); +} + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {}