From 9aa14d0e24dcd16c28c67031e7aec1882e4e6e26 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Mon, 9 May 2022 11:38:49 -0500 Subject: [PATCH 1/7] Add support for checking which route routed a request - Adds RouteType and CatcherType traits to identify routes and catchers - RouteType and CatcherType are implemented via codegen for attribute macros - Adds routed_by and caught_by methods to local client response - Adds catcher to RequestState - Updates route in RequestState to None if a catcher is run - examples/hello tests now also check which route generated the reponse - Adds DefaultCatcher type to represent Rocket's default catcher - FileServer now implements RouteType --- core/codegen/src/attribute/catch/mod.rs | 3 ++ core/codegen/src/attribute/route/mod.rs | 3 ++ core/codegen/src/exports.rs | 2 + core/lib/src/catcher/catcher.rs | 23 ++++++++++- core/lib/src/fs/server.rs | 7 +++- core/lib/src/local/asynchronous/response.rs | 6 +++ core/lib/src/local/blocking/response.rs | 8 +++- core/lib/src/local/response.rs | 44 +++++++++++++++++++++ core/lib/src/request/request.rs | 32 +++++++++++++-- core/lib/src/route/route.rs | 42 +++++++++++++++++--- core/lib/src/router/router.rs | 1 + core/lib/src/server.rs | 6 ++- examples/hello/src/tests.rs | 6 +++ 13 files changed, 169 insertions(+), 14 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 09528c71e2..0911936edd 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -62,6 +62,8 @@ pub fn _catch( /// Rocket code generated proxy structure. #deprecated #vis struct #user_catcher_fn_name { } + impl #CatcherType for #user_catcher_fn_name { } + /// Rocket code generated proxy static conversion implementations. #[allow(nonstandard_style, deprecated, clippy::style)] impl #user_catcher_fn_name { @@ -83,6 +85,7 @@ pub fn _catch( name: stringify!(#user_catcher_fn_name), code: #status_code, handler: monomorphized_function, + route_type: #_Box::new(self), } } diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 1079f3066b..c4611a063a 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -342,6 +342,8 @@ fn codegen_route(route: Route) -> Result { /// Rocket code generated proxy structure. #deprecated #vis struct #handler_fn_name { } + impl #RouteType for #handler_fn_name {} + /// Rocket code generated proxy static conversion implementations. #[allow(nonstandard_style, deprecated, clippy::style)] impl #handler_fn_name { @@ -368,6 +370,7 @@ fn codegen_route(route: Route) -> Result { format: #format, rank: #rank, sentinels: #sentinels, + route_type: #_Box::new(self), } } diff --git a/core/codegen/src/exports.rs b/core/codegen/src/exports.rs index d6c1f4d911..ee8d553e92 100644 --- a/core/codegen/src/exports.rs +++ b/core/codegen/src/exports.rs @@ -98,7 +98,9 @@ define_exported_paths! { StaticRouteInfo => ::rocket::StaticRouteInfo, StaticCatcherInfo => ::rocket::StaticCatcherInfo, Route => ::rocket::Route, + RouteType => ::rocket::route::RouteType, Catcher => ::rocket::Catcher, + CatcherType => ::rocket::catcher::CatcherType, SmallVec => ::rocket::http::private::SmallVec, Status => ::rocket::http::Status, } diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 3ed8ff327d..1407bd0f68 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -1,3 +1,4 @@ +use std::any::{TypeId, Any}; use std::fmt; use std::io::Cursor; @@ -10,6 +11,14 @@ use crate::catcher::{Handler, BoxFuture}; use yansi::Paint; +// We could also choose to require a Debug impl? +/// A generic trait for route types. This should be automatically implemented on the structs +/// generated by the codegen for each route. +/// +/// It may also be desirable to add an option for other routes to define a RouteType. This +/// would likely just be a case of adding an alternate constructor to the Route type. +pub trait CatcherType: Any + Send + Sync + 'static { } + /// An error catching route. /// /// Catchers are routes that run when errors are produced by the application. @@ -127,6 +136,9 @@ pub struct Catcher { /// /// This is -(number of nonempty segments in base). pub(crate) rank: isize, + + /// A unique route type to identify this route + pub(crate) catcher_type: Option, } // The rank is computed as -(number of nonempty segments in base) => catchers @@ -185,7 +197,8 @@ impl Catcher { base: uri::Origin::ROOT, handler: Box::new(handler), rank: rank(uri::Origin::ROOT.path()), - code + code, + catcher_type: None, } } @@ -307,6 +320,10 @@ impl Catcher { } } +/// Catcher type of the default catcher created by Rocket +pub struct DefaultCatcher { _priv: () } +impl CatcherType for DefaultCatcher {} + impl Default for Catcher { fn default() -> Self { fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> { @@ -315,6 +332,7 @@ impl Default for Catcher { let mut catcher = Catcher::new(None, handler); catcher.name = Some("".into()); + catcher.catcher_type = Some(TypeId::of::()); catcher } } @@ -328,6 +346,8 @@ pub struct StaticInfo { pub code: Option, /// The catcher's handler, i.e, the annotated function. pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>, + /// A unique route type to identify this route + pub catcher_type: Box, } #[doc(hidden)] @@ -336,6 +356,7 @@ impl From for Catcher { fn from(info: StaticInfo) -> Catcher { let mut catcher = Catcher::new(info.code, info.handler); catcher.name = Some(info.name.into()); + catcher.catcher_type = Some(info.catcher_type.as_ref().type_id()); catcher } } diff --git a/core/lib/src/fs/server.rs b/core/lib/src/fs/server.rs index da78ec3374..e1b66e0482 100644 --- a/core/lib/src/fs/server.rs +++ b/core/lib/src/fs/server.rs @@ -2,7 +2,7 @@ use std::path::{PathBuf, Path}; use crate::{Request, Data}; use crate::http::{Method, uri::Segments, ext::IntoOwned}; -use crate::route::{Route, Handler, Outcome}; +use crate::route::{Route, Handler, Outcome, RouteType}; use crate::response::Redirect; use crate::fs::NamedFile; @@ -180,10 +180,13 @@ impl FileServer { } } +impl RouteType for FileServer {} + impl From for Vec { fn from(server: FileServer) -> Self { let source = figment::Source::File(server.root.clone()); - let mut route = Route::ranked(server.rank, Method::Get, "/", server); + let mut route = Route::ranked(server.rank, Method::Get, "/", server) + .with_type::(); route.name = Some(format!("FileServer: {}", source).into()); vec![route] } diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index cabbdccc21..8ba00eb767 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -98,6 +98,12 @@ impl<'c> LocalResponse<'c> { } } +impl<'r> LocalResponse<'r> { + pub(crate) fn _request(&self) -> &Request<'r> { + &self._request + } +} + impl LocalResponse<'_> { pub(crate) fn _response(&self) -> &Response<'_> { &self.response diff --git a/core/lib/src/local/blocking/response.rs b/core/lib/src/local/blocking/response.rs index fc0093984d..86dbdfced9 100644 --- a/core/lib/src/local/blocking/response.rs +++ b/core/lib/src/local/blocking/response.rs @@ -1,7 +1,7 @@ use std::io; use tokio::io::AsyncReadExt; -use crate::{Response, local::asynchronous, http::CookieJar}; +use crate::{Response, local::asynchronous, http::CookieJar, Request}; use super::Client; @@ -54,6 +54,12 @@ pub struct LocalResponse<'c> { pub(in super) client: &'c Client, } +impl<'r> LocalResponse<'r> { + pub(crate) fn _request(&self) -> &Request<'r> { + &self.inner._request() + } +} + impl LocalResponse<'_> { fn _response(&self) -> &Response<'_> { &self.inner._response() diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index 411be73f1f..38d972a94d 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -180,6 +180,50 @@ macro_rules! pub_response_impl { self._into_msgpack() $(.$suffix)? } + /// Checks if a route was routed by a specific route type + /// + /// # Example + /// + /// ```rust + /// # use rocket::get; + /// #[get("/")] + /// fn index() -> &'static str { "Hello World" } + #[doc = $doc_prelude] + /// # Client::_test(|_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.routed_by::()) + /// # }); + /// ``` + pub fn routed_by(&self) -> bool { + if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { + route_type == std::any::TypeId::of::() + } else { + false + } + } + + /// Checks if a route was caught by a specific route type + /// + /// # Example + /// + /// ```rust + /// # use rocket::get; + /// #[get("/")] + /// fn index() -> &'static str { "Hello World" } + #[doc = $doc_prelude] + /// # Client::_test(|_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.routed_by::()) + /// # }); + /// ``` + pub fn caught_by(&self) -> bool { + if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() { + catcher_type == std::any::TypeId::of::() + } else { + false + } + } + #[cfg(test)] #[allow(dead_code)] fn _ensure_impls_exist() { diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 7f7e50e7eb..324fd943ff 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -8,7 +8,7 @@ use state::{TypeMap, InitCell}; use futures::future::BoxFuture; use atomic::{Atomic, Ordering}; -use crate::{Rocket, Route, Orbit}; +use crate::{Rocket, Route, Orbit, Catcher}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::form::{self, ValueField, FromForm}; use crate::data::Limits; @@ -45,6 +45,7 @@ pub(crate) struct ConnectionMeta { pub(crate) struct RequestState<'r> { pub rocket: &'r Rocket, pub route: Atomic>, + pub catcher: Atomic>, pub cookies: CookieJar<'r>, pub accept: InitCell>, pub content_type: InitCell>, @@ -69,6 +70,7 @@ impl RequestState<'_> { RequestState { rocket: self.rocket, route: Atomic::new(self.route.load(Ordering::Acquire)), + catcher: Atomic::new(self.catcher.load(Ordering::Acquire)), cookies: self.cookies.clone(), accept: self.accept.clone(), content_type: self.content_type.clone(), @@ -97,6 +99,7 @@ impl<'r> Request<'r> { state: RequestState { rocket, route: Atomic::new(None), + catcher: Atomic::new(None), cookies: CookieJar::new(rocket.config()), accept: InitCell::new(), content_type: InitCell::new(), @@ -691,6 +694,22 @@ impl<'r> Request<'r> { self.state.route.load(Ordering::Acquire) } + /// Get the presently matched catcher, if any. + /// + /// This method returns `Some` while a catcher is running. + /// + /// # Example + /// + /// ```rust + /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); + /// # let request = c.get("/"); + /// let catcher = request.catcher(); + /// ``` + #[inline(always)] + pub fn catcher(&self) -> Option<&'r Catcher> { + self.state.catcher.load(Ordering::Acquire) + } + /// Invokes the request guard implementation for `T`, returning its outcome. /// /// # Example @@ -969,8 +988,15 @@ impl<'r> Request<'r> { /// Set `self`'s parameters given that the route used to reach this request /// was `route`. Use during routing when attempting a given route. #[inline(always)] - pub(crate) fn set_route(&self, route: &'r Route) { - self.state.route.store(Some(route), Ordering::Release) + pub(crate) fn set_route(&self, route: Option<&'r Route>) { + self.state.route.store(route, Ordering::Release) + } + + /// Set `self`'s parameters given that the route used to reach this request + /// was `catcher`. Use during routing when attempting a given catcher. + #[inline(always)] + pub(crate) fn set_catcher(&self, catcher: Option<&'r Catcher>) { + self.state.catcher.store(catcher, Ordering::Release) } /// Set the method of `self`, even when `self` is a shared reference. Used diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index 24853d9517..f51e4dc94b 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -1,12 +1,22 @@ use std::fmt; +use std::any::{Any, TypeId}; use std::borrow::Cow; +use std::convert::From; use yansi::Paint; -use crate::http::{uri, Method, MediaType}; -use crate::route::{Handler, RouteUri, BoxFuture}; +use crate::http::{uri, MediaType, Method}; +use crate::route::{BoxFuture, Handler, RouteUri}; use crate::sentinel::Sentry; +// We could also choose to require a Debug impl? +/// A generic trait for route types. This should be automatically implemented on the structs +/// generated by the codegen for each route. +/// +/// It may also be desirable to add an option for other routes to define a RouteType. This +/// would likely just be a case of adding an alternate constructor to the Route type. +pub trait RouteType: Any + Send + Sync + 'static { } + /// A request handling route. /// /// A route consists of exactly the information in its fields. While a `Route` @@ -16,7 +26,8 @@ use crate::sentinel::Sentry; /// /// ```rust /// # #[macro_use] extern crate rocket; -/// # use std::path::PathBuf; +/// # +/// use std::path::PathBuf; /// #[get("/route/?query", rank = 2, format = "json")] /// fn route_name(path: PathBuf) { /* handler procedure */ } /// @@ -178,6 +189,8 @@ pub struct Route { pub format: Option, /// The discovered sentinels. pub(crate) sentinels: Vec, + /// A unique route type to identify this route + pub(crate) route_type: Option, } impl Route { @@ -243,7 +256,9 @@ impl Route { /// ``` #[track_caller] pub fn ranked(rank: R, method: Method, uri: &str, handler: H) -> Route - where H: Handler + 'static, R: Into>, + where + H: Handler + 'static, + R: Into>, { let uri = RouteUri::new("/", uri); let rank = rank.into().unwrap_or_else(|| uri.default_rank()); @@ -252,7 +267,10 @@ impl Route { format: None, sentinels: Vec::new(), handler: Box::new(handler), - rank, uri, method, + rank, + uri, + method, + route_type: None, } } @@ -297,6 +315,12 @@ impl Route { self } + /// Marks this route with the specified type + pub fn with_type(mut self) -> Self { + self.route_type = Some(TypeId::of::()); + self + } + /// Maps the `base` of this route using `mapper`, returning a new `Route` /// with the returned base. /// @@ -335,7 +359,8 @@ impl Route { /// assert_eq!(rebased.uri.path(), "/boo/foo/bar"); /// ``` pub fn map_base<'a, F>(mut self, mapper: F) -> Result> - where F: FnOnce(uri::Origin<'a>) -> String + where + F: FnOnce(uri::Origin<'a>) -> String, { let base = mapper(self.uri.base); self.uri = RouteUri::try_new(&base, &self.uri.unmounted_origin.to_string())?; @@ -394,6 +419,8 @@ pub struct StaticInfo { /// Route-derived sentinels, if any. /// This isn't `&'static [SentryInfo]` because `type_name()` isn't `const`. pub sentinels: Vec, + /// A unique route type to identify this route + pub route_type: Box, } #[doc(hidden)] @@ -410,6 +437,9 @@ impl From for Route { format: info.format, sentinels: info.sentinels.into_iter().collect(), uri, + // Uses `.as_ref()` to get the type id if the internal type, rather than the type id of + // the box + route_type: Some(info.route_type.as_ref().type_id()), } } } diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 5617f4fbcd..a042930789 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -10,6 +10,7 @@ use crate::router::Collide; pub(crate) struct Router { routes: HashMap>, catchers: HashMap, Vec>, + pub default_catcher: Catcher, } #[derive(Debug)] diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index e3836984fd..666685a18c 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -325,7 +325,7 @@ impl Rocket { for route in self.router.route(request) { // Retrieve and set the requests parameters. info_!("Matched: {}", route); - request.set_route(route); + request.set_route(Some(route)); let name = route.name.as_deref(); let outcome = handle(name, || route.handler.handle(request, data)).await @@ -364,8 +364,10 @@ impl Rocket { // from earlier, unsuccessful paths from being reflected in error // response. We may wish to relax this in the future. req.cookies().reset_delta(); + req.set_route(None); if let Some(catcher) = self.router.catch(status, req) { + req.set_catcher(Some(catcher)); warn_!("Responding with registered {} catcher.", catcher); let name = catcher.name.as_deref(); handle(name, || catcher.handler.handle(status, req)).await @@ -374,6 +376,7 @@ impl Rocket { } else { let code = status.code.blue().bold(); warn_!("No {} catcher registered. Using Rocket default.", code); + req.set_catcher(Some(&self.router.default_catcher)); Ok(crate::catcher::default_handler(status, req)) } } @@ -401,6 +404,7 @@ impl Rocket { } } + req.set_catcher(Some(&self.router.default_catcher)); // If it failed again or if it was already a 500, use Rocket's default. error_!("{} catcher failed. Using Rocket default 500.", status.code); crate::catcher::default_handler(Status::InternalServerError, req) diff --git a/examples/hello/src/tests.rs b/examples/hello/src/tests.rs index fd5b628d96..f089a1f9c0 100644 --- a/examples/hello/src/tests.rs +++ b/examples/hello/src/tests.rs @@ -30,10 +30,12 @@ fn hello() { let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name)); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `hello` route"); assert_eq!(response.into_string().unwrap(), expected); let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang)); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `hello` route"); assert_eq!(response.into_string().unwrap(), expected); } } @@ -42,6 +44,7 @@ fn hello() { fn hello_world() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/world").dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `world` route"); assert_eq!(response.into_string(), Some("Hello, world!".into())); } @@ -49,6 +52,7 @@ fn hello_world() { fn hello_mir() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/%D0%BC%D0%B8%D1%80").dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `mir` route"); assert_eq!(response.into_string(), Some("Привет, мир!".into())); } @@ -60,11 +64,13 @@ fn wave() { let real_name = RawStr::new(name).percent_decode_lossy(); let expected = format!("👋 Hello, {} year old named {}!", age, real_name); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `wave` route"); assert_eq!(response.into_string().unwrap(), expected); for bad_age in &["1000", "-1", "bird", "?"] { let bad_uri = format!("/wave/{}/{}", name, bad_age); let response = client.get(bad_uri).dispatch(); + assert!(response.caught_by::(), "Response was not generated by the default catcher"); assert_eq!(response.status(), Status::NotFound); } } From 78aecfdd41198aeb3bef4f9bd96c358563e8d3a7 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Mon, 9 May 2022 12:02:44 -0500 Subject: [PATCH 2/7] Documentation and minor fixes - Add with_type to catcher - Documents CatcherType, RouteType, and the associated methods in LocalResponse. --- core/lib/src/catcher/catcher.rs | 17 +++++++++++----- core/lib/src/local/response.rs | 35 +++++++++++++++++++++++++++++---- core/lib/src/route/route.rs | 13 +++++++----- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 1407bd0f68..4a60f791d3 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -11,12 +11,10 @@ use crate::catcher::{Handler, BoxFuture}; use yansi::Paint; -// We could also choose to require a Debug impl? -/// A generic trait for route types. This should be automatically implemented on the structs -/// generated by the codegen for each route. +/// A generic trait for catcher types. This should be automatically implemented on the structs +/// generated by the codegen for each catcher, and manually implemented on custom cater types. /// -/// It may also be desirable to add an option for other routes to define a RouteType. This -/// would likely just be a case of adding an alternate constructor to the Route type. +/// Use the `Catcher::with_type::()` method to set the catcher type. pub trait CatcherType: Any + Send + Sync + 'static { } /// An error catching route. @@ -183,6 +181,8 @@ impl Catcher { /// /// Panics if `code` is not in the HTTP status code error range `[400, /// 600)`. + /// + /// If applicable, `with_type` should also be called to set the route type for testing #[inline(always)] pub fn new(code: S, handler: H) -> Catcher where S: Into>, H: Handler @@ -274,6 +274,13 @@ impl Catcher { self } + /// Marks this catcher with the specified type. For a custom catcher type, i.e. something that can + /// be passed to `.register()`, it should be that type to make identification easier. + pub fn with_type(mut self) -> Self { + self.catcher_type = Some(TypeId::of::()); + self + } + /// Maps the `base` of this catcher using `mapper`, returning a new /// `Catcher` with the returned base. /// diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index 38d972a94d..f0accbe8f6 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -180,7 +180,8 @@ macro_rules! pub_response_impl { self._into_msgpack() $(.$suffix)? } - /// Checks if a route was routed by a specific route type + /// Checks if a route was routed by a specific route type. This only returns true if the route + /// actually generated a response, and a catcher was not run. /// /// # Example /// @@ -194,6 +195,12 @@ macro_rules! pub_response_impl { /// assert!(response.routed_by::()) /// # }); /// ``` + /// + /// # Other Route types + /// + /// [`FileServer`](crate::fs::FileServer) implementes `RouteType`, so a route that should + /// return a static file can be checked against it. Libraries which provide a Route type should + /// implement `RouteType`, see [`RouteType`](crate::route::RouteType) for more information. pub fn routed_by(&self) -> bool { if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { route_type == std::any::TypeId::of::() @@ -208,14 +215,18 @@ macro_rules! pub_response_impl { /// /// ```rust /// # use rocket::get; - /// #[get("/")] - /// fn index() -> &'static str { "Hello World" } + /// #[catch(404)] + /// fn default_404() -> &'static str { "Hello World" } #[doc = $doc_prelude] /// # Client::_test(|_, _, response| { /// let response: LocalResponse = response; - /// assert!(response.routed_by::()) + /// assert!(response.caught_by::()) /// # }); /// ``` + /// + /// # Rocket's default catcher + /// + /// The default catcher has a `CatcherType` of [`DefaultCatcher`](crate::catcher::DefaultCatcher) pub fn caught_by(&self) -> bool { if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() { catcher_type == std::any::TypeId::of::() @@ -224,6 +235,22 @@ macro_rules! pub_response_impl { } } + /// Checks if a route was caught by a catcher + /// + /// # Example + /// + /// ```rust + /// # use rocket::get; + #[doc = $doc_prelude] + /// # Client::_test(|_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.was_caught()) + /// # }); + /// ``` + pub fn was_caught(&self) -> bool { + self._request().catcher().is_some() + } + #[cfg(test)] #[allow(dead_code)] fn _ensure_impls_exist() { diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index f51e4dc94b..fc2b355762 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -9,12 +9,10 @@ use crate::http::{uri, MediaType, Method}; use crate::route::{BoxFuture, Handler, RouteUri}; use crate::sentinel::Sentry; -// We could also choose to require a Debug impl? /// A generic trait for route types. This should be automatically implemented on the structs -/// generated by the codegen for each route. +/// generated by the codegen for each route, and manually implemented on custom route types. /// -/// It may also be desirable to add an option for other routes to define a RouteType. This -/// would likely just be a case of adding an alternate constructor to the Route type. +/// Use the `Route::with_type::()` method to set the route type. pub trait RouteType: Any + Send + Sync + 'static { } /// A request handling route. @@ -219,6 +217,8 @@ impl Route { /// assert_eq!(index.method, Method::Get); /// assert_eq!(index.uri, "/"); /// ``` + /// + /// If applicable, `with_type` should also be called to set the route type for testing #[track_caller] pub fn new(method: Method, uri: &str, handler: H) -> Route { Route::ranked(None, method, uri, handler) @@ -254,6 +254,8 @@ impl Route { /// assert_eq!(foo.method, Method::Post); /// assert_eq!(foo.uri, "/foo?bar"); /// ``` + /// + /// If applicable, `with_type` should also be called to set the route type for testing #[track_caller] pub fn ranked(rank: R, method: Method, uri: &str, handler: H) -> Route where @@ -315,7 +317,8 @@ impl Route { self } - /// Marks this route with the specified type + /// Marks this route with the specified type. For a custom route type, i.e. something that can + /// be passed to `.mount()`, it should be that type to make identification easier. pub fn with_type(mut self) -> Self { self.route_type = Some(TypeId::of::()); self From 2bfdb537ba3e4715ad84dfcb14a82c27b63dbe8d Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Mon, 9 May 2022 12:17:02 -0500 Subject: [PATCH 3/7] Fix formatting --- core/codegen/src/attribute/catch/mod.rs | 2 +- core/lib/src/local/asynchronous/client.rs | 17 ++++++++++++++- core/lib/src/local/blocking/client.rs | 15 +++++++++++++- core/lib/src/local/response.rs | 12 +++++------ examples/hello/src/tests.rs | 25 ++++++++++++++++++----- 5 files changed, 57 insertions(+), 14 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 0911936edd..8bf16ad85d 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -85,7 +85,7 @@ pub fn _catch( name: stringify!(#user_catcher_fn_name), code: #status_code, handler: monomorphized_function, - route_type: #_Box::new(self), + catcher_type: #_Box::new(self), } } diff --git a/core/lib/src/local/asynchronous/client.rs b/core/lib/src/local/asynchronous/client.rs index ecec4527cc..f7aacfedaa 100644 --- a/core/lib/src/local/asynchronous/client.rs +++ b/core/lib/src/local/asynchronous/client.rs @@ -2,7 +2,7 @@ use std::fmt; use parking_lot::RwLock; -use crate::{Rocket, Phase, Orbit, Ignite, Error}; +use crate::{Rocket, Phase, Orbit, Ignite, Error, Build}; use crate::local::asynchronous::{LocalRequest, LocalResponse}; use crate::http::{Method, uri::Origin, private::cookie}; @@ -76,6 +76,21 @@ impl Client { }) } + // WARNING: This is unstable! Do not use this method outside of Rocket! + // This is used by the `Client` doctests. + #[doc(hidden)] + pub fn _test_with(mods: M, f: F) -> T + where F: FnOnce(&Self, LocalRequest<'_>, LocalResponse<'_>) -> T + Send, + M: FnOnce(Rocket) -> Rocket + { + crate::async_test(async { + let client = Client::debug(mods(crate::build())).await.unwrap(); + let request = client.get("/"); + let response = request.clone().dispatch().await; + f(&client, request, response) + }) + } + #[inline(always)] pub(crate) fn _rocket(&self) -> &Rocket { &self.rocket diff --git a/core/lib/src/local/blocking/client.rs b/core/lib/src/local/blocking/client.rs index d3a8b0ef94..1a9c95e1dc 100644 --- a/core/lib/src/local/blocking/client.rs +++ b/core/lib/src/local/blocking/client.rs @@ -1,7 +1,7 @@ use std::fmt; use std::cell::RefCell; -use crate::{Rocket, Phase, Orbit, Ignite, Error}; +use crate::{Rocket, Phase, Orbit, Ignite, Error, Build}; use crate::local::{asynchronous, blocking::{LocalRequest, LocalResponse}}; use crate::http::{Method, uri::Origin}; @@ -54,6 +54,19 @@ impl Client { f(&client, request, response) } + // WARNING: This is unstable! Do not use this method outside of Rocket! + // This is used by the `Client` doctests. + #[doc(hidden)] + pub fn _test_with(mods: M, f: F) -> T + where F: FnOnce(&Self, LocalRequest<'_>, LocalResponse<'_>) -> T + Send, + M: FnOnce(Rocket) -> Rocket + { + let client = Client::debug(mods(crate::build())).unwrap(); + let request = client.get("/"); + let response = request.clone().dispatch(); + f(&client, request, response) + } + #[inline(always)] pub(crate) fn inner(&self) -> &asynchronous::Client { self.inner.as_ref().expect("internal invariant broken: self.inner is Some") diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index f0accbe8f6..605750856b 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -186,13 +186,13 @@ macro_rules! pub_response_impl { /// # Example /// /// ```rust - /// # use rocket::get; + /// # use rocket::{get, routes}; /// #[get("/")] /// fn index() -> &'static str { "Hello World" } #[doc = $doc_prelude] - /// # Client::_test(|_, _, response| { + /// # Client::_test_with(|r| r.mount("/", routes![index]), |_, _, response| { /// let response: LocalResponse = response; - /// assert!(response.routed_by::()) + /// assert!(response.routed_by::()); /// # }); /// ``` /// @@ -214,13 +214,13 @@ macro_rules! pub_response_impl { /// # Example /// /// ```rust - /// # use rocket::get; + /// # use rocket::{catch, catchers}; /// #[catch(404)] /// fn default_404() -> &'static str { "Hello World" } #[doc = $doc_prelude] - /// # Client::_test(|_, _, response| { + /// # Client::_test_with(|r| r.register("/", catchers![default_404]), |_, _, response| { /// let response: LocalResponse = response; - /// assert!(response.caught_by::()) + /// assert!(response.caught_by::()); /// # }); /// ``` /// diff --git a/examples/hello/src/tests.rs b/examples/hello/src/tests.rs index f089a1f9c0..bb66dc87c5 100644 --- a/examples/hello/src/tests.rs +++ b/examples/hello/src/tests.rs @@ -30,12 +30,18 @@ fn hello() { let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name)); let response = client.get(uri).dispatch(); - assert!(response.routed_by::(), "Response was not generated by the `hello` route"); + assert!( + response.routed_by::(), + "Response was not generated by the `hello` route" + ); assert_eq!(response.into_string().unwrap(), expected); let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang)); let response = client.get(uri).dispatch(); - assert!(response.routed_by::(), "Response was not generated by the `hello` route"); + assert!( + response.routed_by::(), + "Response was not generated by the `hello` route" + ); assert_eq!(response.into_string().unwrap(), expected); } } @@ -44,7 +50,10 @@ fn hello() { fn hello_world() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/world").dispatch(); - assert!(response.routed_by::(), "Response was not generated by the `world` route"); + assert!( + response.routed_by::(), + "Response was not generated by the `world` route" + ); assert_eq!(response.into_string(), Some("Hello, world!".into())); } @@ -64,13 +73,19 @@ fn wave() { let real_name = RawStr::new(name).percent_decode_lossy(); let expected = format!("👋 Hello, {} year old named {}!", age, real_name); let response = client.get(uri).dispatch(); - assert!(response.routed_by::(), "Response was not generated by the `wave` route"); + assert!( + response.routed_by::(), + "Response was not generated by the `wave` route" + ); assert_eq!(response.into_string().unwrap(), expected); for bad_age in &["1000", "-1", "bird", "?"] { let bad_uri = format!("/wave/{}/{}", name, bad_age); let response = client.get(bad_uri).dispatch(); - assert!(response.caught_by::(), "Response was not generated by the default catcher"); + assert!( + response.caught_by::(), + "Response was not generated by the default catcher" + ); assert_eq!(response.status(), Status::NotFound); } } From c41087140052a906f8238d3aa95419a68cf9b43e Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Wed, 11 May 2022 11:38:29 -0500 Subject: [PATCH 4/7] Fix typos and minor mistakes --- core/lib/src/catcher/catcher.rs | 2 +- core/lib/src/route/route.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 4a60f791d3..3cf70e5f95 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -12,7 +12,7 @@ use crate::catcher::{Handler, BoxFuture}; use yansi::Paint; /// A generic trait for catcher types. This should be automatically implemented on the structs -/// generated by the codegen for each catcher, and manually implemented on custom cater types. +/// generated by the codegen for each catcher, and manually implemented on custom catcher types. /// /// Use the `Catcher::with_type::()` method to set the catcher type. pub trait CatcherType: Any + Send + Sync + 'static { } diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index fc2b355762..da1b6c8e43 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -24,8 +24,7 @@ pub trait RouteType: Any + Send + Sync + 'static { } /// /// ```rust /// # #[macro_use] extern crate rocket; -/// # -/// use std::path::PathBuf; +/// # use std::path::PathBuf; /// #[get("/route/?query", rank = 2, format = "json")] /// fn route_name(path: PathBuf) { /* handler procedure */ } /// From ca8e40963dded81f893794ddc85d71d25f141ff0 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Thu, 31 Aug 2023 17:44:20 -0500 Subject: [PATCH 5/7] Rename methods to make them easier to read --- core/lib/src/local/response.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index 605750856b..4fa40406c9 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -192,16 +192,16 @@ macro_rules! pub_response_impl { #[doc = $doc_prelude] /// # Client::_test_with(|r| r.mount("/", routes![index]), |_, _, response| { /// let response: LocalResponse = response; - /// assert!(response.routed_by::()); + /// assert!(response.was_routed_by::()); /// # }); /// ``` /// /// # Other Route types /// /// [`FileServer`](crate::fs::FileServer) implementes `RouteType`, so a route that should - /// return a static file can be checked against it. Libraries which provide a Route type should + /// return a static file can be checked against it. Libraries which provide custom Routes should /// implement `RouteType`, see [`RouteType`](crate::route::RouteType) for more information. - pub fn routed_by(&self) -> bool { + pub fn was_routed_by(&self) -> bool { if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { route_type == std::any::TypeId::of::() } else { @@ -220,14 +220,14 @@ macro_rules! pub_response_impl { #[doc = $doc_prelude] /// # Client::_test_with(|r| r.register("/", catchers![default_404]), |_, _, response| { /// let response: LocalResponse = response; - /// assert!(response.caught_by::()); + /// assert!(response.was_caught_by::()); /// # }); /// ``` /// /// # Rocket's default catcher /// /// The default catcher has a `CatcherType` of [`DefaultCatcher`](crate::catcher::DefaultCatcher) - pub fn caught_by(&self) -> bool { + pub fn was_caught_by(&self) -> bool { if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() { catcher_type == std::any::TypeId::of::() } else { From bd1d0f5c6c00bd42edb069dadf4fc6c227516f44 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 10 Sep 2023 15:02:02 -0500 Subject: [PATCH 6/7] Implement tracing through the list of routes --- core/lib/src/local/response.rs | 63 +++++++++++++++++++++++++++++++-- core/lib/src/request/request.rs | 25 ++++++++++++- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index 4fa40406c9..a6b1239aa2 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -180,8 +180,9 @@ macro_rules! pub_response_impl { self._into_msgpack() $(.$suffix)? } - /// Checks if a route was routed by a specific route type. This only returns true if the route - /// actually generated a response, and a catcher was not run. + /// Checks if a response was generted by a specific route type. This only returns true if the route + /// actually generated the response, and a catcher was _not_ run. See [`was_attempted_by`] to + /// check if a route was attempted, but may not have generated the response /// /// # Example /// @@ -202,13 +203,69 @@ macro_rules! pub_response_impl { /// return a static file can be checked against it. Libraries which provide custom Routes should /// implement `RouteType`, see [`RouteType`](crate::route::RouteType) for more information. pub fn was_routed_by(&self) -> bool { - if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { + // If this request was caught, the route in `.route()` did NOT generate this response. + if self._request().catcher().is_some() { + false + } else if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { route_type == std::any::TypeId::of::() } else { false } } + /// Checks if a request was routed to a specific route type. This will return true for routes + /// that were attempted, _but not actually called_. This enables a test to verify that a route + /// was attempted, even if another route actually generated the response, e.g. an + /// authenticated route will typically defer to an error catcher if the request does not have + /// the proper authentication. This makes it possible to verify that a request was routed to + /// the authentication route, even if the response was eventaully generated by another route or + /// a catcher. + /// + /// # Example + /// + // WARNING: this doc-test is NOT run, because cargo test --doc does not run doc-tests for items + // only available during tests. + /// ```rust + /// # use rocket::{get, routes, async_trait, request::{Request, Outcome, FromRequest}}; + /// # struct WillFail {} + /// # #[async_trait] + /// # impl<'r> FromRequest<'r> for WillFail { + /// # type Error = (); + /// # async fn from_request(request: &'r Request<'_>) -> Outcome { + /// # Outcome::Forward(()) + /// # } + /// # } + /// #[get("/", rank = 2)] + /// fn index1(guard: WillFail) -> &'static str { "Hello World" } + /// #[get("/")] + /// fn index2() -> &'static str { "Hello World" } + #[doc = $doc_prelude] + /// # Client::_test_with(|r| r.mount("/", routes![index1, index2]), |_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.was_attempted_by::()); + /// assert!(response.was_attempted_by::()); + /// assert!(response.was_routed_by::()); + /// # }); + /// ``` + /// + /// # Other Route types + /// + /// [`FileServer`](crate::fs::FileServer) implementes `RouteType`, so a route that should + /// return a static file can be checked against it. Libraries which provide custom Routes should + /// implement `RouteType`, see [`RouteType`](crate::route::RouteType) for more information. + /// + /// # Note + /// + /// This method is marked as `cfg(test)`, and is therefore only available in unit and + /// integration tests. This is because the list of routes attempted is only collected in these + /// testing environments, to minimize performance impacts during normal operation. + #[cfg(test)] + pub fn was_attempted_by(&self) -> bool { + self._request().route_path(|path| path.iter().any(|r| + r.route_type == Some(std::any::TypeId::of::()) + )) + } + /// Checks if a route was caught by a specific route type /// /// # Example diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 324fd943ff..e6f94ede72 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -19,6 +19,9 @@ use crate::http::uncased::UncasedStr; use crate::http::private::Certificates; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; +#[cfg(test)] +use parking_lot::Mutex; + /// The type of an incoming web request. /// /// This should be used sparingly in Rocket applications. In particular, it @@ -51,6 +54,8 @@ pub(crate) struct RequestState<'r> { pub content_type: InitCell>, pub cache: Arc, pub host: Option>, + #[cfg(test)] + pub route_path: Arc>>, } impl Request<'_> { @@ -76,6 +81,8 @@ impl RequestState<'_> { content_type: self.content_type.clone(), cache: self.cache.clone(), host: self.host.clone(), + #[cfg(test)] + route_path: self.route_path.clone(), } } } @@ -105,6 +112,8 @@ impl<'r> Request<'r> { content_type: InitCell::new(), cache: Arc::new(::new()), host: None, + #[cfg(test)] + route_path: Arc::new(Mutex::new(vec![])), } } } @@ -989,7 +998,21 @@ impl<'r> Request<'r> { /// was `route`. Use during routing when attempting a given route. #[inline(always)] pub(crate) fn set_route(&self, route: Option<&'r Route>) { - self.state.route.store(route, Ordering::Release) + self.state.route.store(route, Ordering::Release); + #[cfg(test)] + if let Some(route) = route { + self.state.route_path.lock().push(route); + } + } + + /// Compute a value using the route path of this request + /// + /// This doesn't simply return a refernce, since the reference is held behind a Arc. + /// This method is only intended to be used internally, and is therefore NOT pub. + #[inline(always)] + #[cfg(test)] + pub(crate) fn route_path(&self, operation: impl FnOnce(&[&'r Route]) -> R) -> R { + operation(self.state.route_path.lock().as_ref()) } /// Set `self`'s parameters given that the route used to reach this request From 09ebdbd6f9243a89fa241d92f2b6b15b24c64419 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Thu, 14 Sep 2023 01:18:55 -0500 Subject: [PATCH 7/7] Fix examples to use new names For some reason, the scripts/test.sh does not successfully run all tests - it terminates on the doctest step with a sigkill. This means I have to push changes to github to test them. --- examples/hello/src/tests.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/hello/src/tests.rs b/examples/hello/src/tests.rs index bb66dc87c5..8f0b191ce6 100644 --- a/examples/hello/src/tests.rs +++ b/examples/hello/src/tests.rs @@ -31,7 +31,7 @@ fn hello() { let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name)); let response = client.get(uri).dispatch(); assert!( - response.routed_by::(), + response.was_routed_by::(), "Response was not generated by the `hello` route" ); assert_eq!(response.into_string().unwrap(), expected); @@ -39,7 +39,7 @@ fn hello() { let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang)); let response = client.get(uri).dispatch(); assert!( - response.routed_by::(), + response.was_routed_by::(), "Response was not generated by the `hello` route" ); assert_eq!(response.into_string().unwrap(), expected); @@ -51,7 +51,7 @@ fn hello_world() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/world").dispatch(); assert!( - response.routed_by::(), + response.was_routed_by::(), "Response was not generated by the `world` route" ); assert_eq!(response.into_string(), Some("Hello, world!".into())); @@ -61,7 +61,10 @@ fn hello_world() { fn hello_mir() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/%D0%BC%D0%B8%D1%80").dispatch(); - assert!(response.routed_by::(), "Response was not generated by the `mir` route"); + assert!( + response.was_routed_by::(), + "Response was not generated by the `mir` route" + ); assert_eq!(response.into_string(), Some("Привет, мир!".into())); } @@ -74,7 +77,7 @@ fn wave() { let expected = format!("👋 Hello, {} year old named {}!", age, real_name); let response = client.get(uri).dispatch(); assert!( - response.routed_by::(), + response.was_routed_by::(), "Response was not generated by the `wave` route" ); assert_eq!(response.into_string().unwrap(), expected); @@ -83,7 +86,7 @@ fn wave() { let bad_uri = format!("/wave/{}/{}", name, bad_age); let response = client.get(bad_uri).dispatch(); assert!( - response.caught_by::(), + response.was_caught_by::(), "Response was not generated by the default catcher" ); assert_eq!(response.status(), Status::NotFound);