From e0ef641e5fc12048b1632cf2feebddac19a2f8ad Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 8 Nov 2022 21:31:06 +0100 Subject: [PATCH] Rework `Form` and `Query` rejections (#1496) * Change `FailedToDeserializeQueryString` rejection for `Form` Its now called `FailedToDeserializeForm`. * changelog * Make dedicate rejection type for axum-extra's `Form` * update trybuild test * Make dedicate rejection type for axum-extra's `Query` --- axum-extra/src/extract/form.rs | 94 +++++++++++-------- axum-extra/src/extract/mod.rs | 4 +- axum-extra/src/extract/query.rs | 52 ++++++++-- .../fail/wrong_return_type.stderr | 2 +- axum/CHANGELOG.md | 3 + axum/src/extract/query.rs | 4 +- axum/src/extract/rejection.rs | 46 +++------ axum/src/form.rs | 2 +- 8 files changed, 122 insertions(+), 85 deletions(-) diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 254c7cce41..df33e5cade 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -1,16 +1,13 @@ use axum::{ async_trait, body::HttpBody, - extract::{ - rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType}, - FromRequest, - }, - BoxError, + extract::{rejection::RawFormRejection, FromRequest, RawForm}, + response::{IntoResponse, Response}, + BoxError, Error, RequestExt, }; -use bytes::Bytes; -use http::{header, HeaderMap, Method, Request}; +use http::{Request, StatusCode}; use serde::de::DeserializeOwned; -use std::ops::Deref; +use std::{fmt, ops::Deref}; /// Extractor that deserializes `application/x-www-form-urlencoded` requests /// into some type. @@ -65,41 +62,60 @@ where { type Rejection = FormRejection; - async fn from_request(req: Request, state: &S) -> Result { - if req.method() == Method::GET { - let query = req.uri().query().unwrap_or_default(); - let value = serde_html_form::from_str(query) - .map_err(FailedToDeserializeQueryString::__private_new)?; - Ok(Form(value)) - } else { - if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) { - return Err(InvalidFormContentType::default().into()); - } - - let bytes = Bytes::from_request(req, state).await?; - let value = serde_html_form::from_bytes(&bytes) - .map_err(FailedToDeserializeQueryString::__private_new)?; - - Ok(Form(value)) + async fn from_request(req: Request, _state: &S) -> Result { + let RawForm(bytes) = req + .extract() + .await + .map_err(FormRejection::RawFormRejection)?; + + serde_html_form::from_bytes::(&bytes) + .map(Self) + .map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err))) + } +} + +/// Rejection used for [`Form`]. +/// +/// Contains one variant for each way the [`Form`] extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +#[cfg(feature = "form")] +pub enum FormRejection { + #[allow(missing_docs)] + RawFormRejection(RawFormRejection), + #[allow(missing_docs)] + FailedToDeserializeForm(Error), +} + +impl IntoResponse for FormRejection { + fn into_response(self) -> Response { + match self { + Self::RawFormRejection(inner) => inner.into_response(), + Self::FailedToDeserializeForm(inner) => ( + StatusCode::BAD_REQUEST, + format!("Failed to deserialize form: {}", inner), + ) + .into_response(), } } } -// this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool { - let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - content_type - } else { - return false; - }; - - let content_type = if let Ok(content_type) = content_type.to_str() { - content_type - } else { - return false; - }; - - content_type.starts_with(expected_content_type.as_ref()) +impl fmt::Display for FormRejection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::RawFormRejection(inner) => inner.fmt(f), + Self::FailedToDeserializeForm(inner) => inner.fmt(f), + } + } +} + +impl std::error::Error for FormRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::RawFormRejection(inner) => Some(inner), + Self::FailedToDeserializeForm(inner) => Some(inner), + } + } } #[cfg(test)] diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index e795cfb7bc..eaed0a4d3f 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -25,10 +25,10 @@ pub use self::cookie::PrivateCookieJar; pub use self::cookie::SignedCookieJar; #[cfg(feature = "form")] -pub use self::form::Form; +pub use self::form::{Form, FormRejection}; #[cfg(feature = "query")] -pub use self::query::Query; +pub use self::query::{Query, QueryRejection}; #[cfg(feature = "json-lines")] #[doc(no_inline)] diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 4a8d6f8676..8c829dd5fc 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,13 +1,12 @@ use axum::{ async_trait, - extract::{ - rejection::{FailedToDeserializeQueryString, QueryRejection}, - FromRequestParts, - }, + extract::FromRequestParts, + response::{IntoResponse, Response}, + Error, }; -use http::request::Parts; +use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; -use std::ops::Deref; +use std::{fmt, ops::Deref}; /// Extractor that deserializes query strings into some type. /// @@ -69,7 +68,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let query = parts.uri.query().unwrap_or_default(); let value = serde_html_form::from_str(query) - .map_err(FailedToDeserializeQueryString::__private_new)?; + .map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?; Ok(Query(value)) } } @@ -82,6 +81,45 @@ impl Deref for Query { } } +/// Rejection used for [`Query`]. +/// +/// Contains one variant for each way the [`Query`] extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +#[cfg(feature = "query")] +pub enum QueryRejection { + #[allow(missing_docs)] + FailedToDeserializeQueryString(Error), +} + +impl IntoResponse for QueryRejection { + fn into_response(self) -> Response { + match self { + Self::FailedToDeserializeQueryString(inner) => ( + StatusCode::BAD_REQUEST, + format!("Failed to deserialize query string: {}", inner), + ) + .into_response(), + } + } +} + +impl fmt::Display for QueryRejection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), + } + } +} + +impl std::error::Error for QueryRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::FailedToDeserializeQueryString(inner) => Some(inner), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr index 4781a22cd9..303ae1862f 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied (Response<()>, T1, T2, R) (Response<()>, T1, T2, T3, R) (Response<()>, T1, T2, T3, T4, R) - and 119 others + and 120 others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 | diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index c9cf061a75..1596f37b90 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -47,6 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from [`axum-macros`] behind the `macros` feature ([#1352]) - **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487]) +- **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to + `FormRejection::FailedToDeserializeForm` ([#1496]) [#1352]: https://github.com/tokio-rs/axum/pull/1352 [#1368]: https://github.com/tokio-rs/axum/pull/1368 @@ -63,6 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1420]: https://github.com/tokio-rs/axum/pull/1420 [#1421]: https://github.com/tokio-rs/axum/pull/1421 [#1487]: https://github.com/tokio-rs/axum/pull/1487 +[#1496]: https://github.com/tokio-rs/axum/pull/1496 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index a520853a16..23cb02dc7a 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -59,8 +59,8 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let query = parts.uri.query().unwrap_or_default(); - let value = serde_urlencoded::from_str(query) - .map_err(FailedToDeserializeQueryString::__private_new)?; + let value = + serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; Ok(Query(value)) } } diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index e49ac17fac..0b3829362d 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -1,8 +1,5 @@ //! Rejection response types. -use crate::{BoxError, Error}; -use axum_core::response::{IntoResponse, Response}; - pub use crate::extract::path::FailedToDeserializePathParams; pub use axum_core::extract::rejection::*; @@ -73,39 +70,22 @@ define_rejection! { pub struct FailedToResolveHost; } -/// Rejection type for extractors that deserialize query strings if the input -/// couldn't be deserialized into the target type. -#[derive(Debug)] -pub struct FailedToDeserializeQueryString { - error: Error, -} - -impl FailedToDeserializeQueryString { - #[doc(hidden)] - pub fn __private_new(error: E) -> Self - where - E: Into, - { - FailedToDeserializeQueryString { - error: Error::new(error), - } - } -} - -impl IntoResponse for FailedToDeserializeQueryString { - fn into_response(self) -> Response { - (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() - } +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to deserialize form"] + /// Rejection type used if the [`Form`](super::Form) extractor is unable to + /// deserialize the form into the target type. + pub struct FailedToDeserializeForm(Error); } -impl std::fmt::Display for FailedToDeserializeQueryString { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Failed to deserialize query string: {}", self.error,) - } +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to deserialize query string"] + /// Rejection type used if the [`Query`](super::Query) extractor is unable to + /// deserialize the form into the target type. + pub struct FailedToDeserializeQueryString(Error); } -impl std::error::Error for FailedToDeserializeQueryString {} - composite_rejection! { /// Rejection used for [`Query`](super::Query). /// @@ -123,7 +103,7 @@ composite_rejection! { /// can fail. pub enum FormRejection { InvalidFormContentType, - FailedToDeserializeQueryString, + FailedToDeserializeForm, BytesRejection, } } diff --git a/axum/src/form.rs b/axum/src/form.rs index 3ead68a4b0..c30b2260f4 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -76,7 +76,7 @@ where match req.extract().await { Ok(RawForm(bytes)) => { let value = serde_urlencoded::from_bytes(&bytes) - .map_err(FailedToDeserializeQueryString::__private_new)?; + .map_err(FailedToDeserializeForm::from_err)?; Ok(Form(value)) } Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),