Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RequestExt::{with_limited_body, into_limited_body} #1420

Merged
merged 4 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions axum/src/ext_traits/mod.rs → axum-core/src/ext_traits/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
pub(crate) mod request;
pub(crate) mod request_parts;
pub(crate) mod service;

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use crate::extract::{FromRef, FromRequestParts};
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;

#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct State<S>(pub(crate) S);

#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(
_parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}

// some extractor that requires the state, such as `SignedCookieJar`
pub(crate) struct RequiresState(pub(crate) String);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use axum_core::extract::{FromRequest, FromRequestParts};
use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use http_body::Limited;

mod sealed {
pub trait Sealed<B> {}
Expand Down Expand Up @@ -48,6 +49,16 @@ pub trait RequestExt<B>: sealed::Sealed<B> + Sized {
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;

/// Apply the [default body limit](crate::extract::DefaultBodyLimit).
///
/// If it is disabled, return the request as-is in `Err`.
fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>;

/// Consumes the request, returning the body wrapped in [`Limited`] if a
/// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
/// default limit is disabled.
fn into_limited_body(self) -> Result<Limited<B>, B>;
}

impl<B> RequestExt<B> for Request<B>
Expand Down Expand Up @@ -105,14 +116,36 @@ where
result
})
}

fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb

match self.extensions().get::<DefaultBodyLimitKind>().copied() {
Some(DefaultBodyLimitKind::Disable) => Err(self),
Some(DefaultBodyLimitKind::Limit(limit)) => {
Ok(self.map(|b| http_body::Limited::new(b, limit)))
}
None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))),
}
}

fn into_limited_body(self) -> Result<Limited<B>, B> {
self.with_limited_body()
.map(Request::into_body)
.map_err(Request::into_body)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::Method;
use hyper::Body;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use axum_core::extract::FromRequestParts;
use crate::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::request::Parts;

Expand Down Expand Up @@ -53,9 +53,11 @@ mod tests {
use std::convert::Infallible;

use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::{Method, Request};

#[tokio::test]
Expand All @@ -73,7 +75,10 @@ mod tests {

let state = "state".to_owned();

let State(extracted_state): State<String> = parts.extract_with_state(&state).await.unwrap();
let State(extracted_state): State<String> = parts
.extract_with_state::<State<String>, String>(&state)
.await
.unwrap();

assert_eq!(extracted_state, state);
}
Expand Down
1 change: 1 addition & 0 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod from_ref;
mod request_parts;
mod tuple;

pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};

mod private {
Expand Down
30 changes: 7 additions & 23 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use super::{
default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts,
};
use crate::BoxError;
use super::{rejection::*, FromRequest, FromRequestParts};
use crate::{BoxError, RequestExt};
use async_trait::async_trait;
use bytes::Bytes;
use http::{request::Parts, HeaderMap, Method, Request, Uri, Version};
Expand Down Expand Up @@ -84,27 +82,13 @@ where
type Rejection = BytesRejection;

async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb

let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied();
let bytes = match limit_kind {
Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body())
let bytes = match req.into_limited_body() {
Ok(limited_body) => crate::body::to_bytes(limited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Err(unlimited_body) => crate::body::to_bytes(unlimited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Some(DefaultBodyLimitKind::Limit(limit)) => {
let body = http_body::Limited::new(req.into_body(), limit);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
None => {
let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
};

Ok(bytes)
Expand Down
3 changes: 3 additions & 0 deletions axum-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
pub(crate) mod macros;

mod error;
mod ext_traits;
pub use self::error::Error;

pub mod body;
Expand All @@ -60,3 +61,5 @@ pub mod response;

/// Alias for a type-erased error type.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
you likely need to re-enable the `tokio` feature ([#1382])
- **breaking:** `handler::{WithState, IntoService}` are merged into one type,
named `HandlerService` ([#1418])
- **changed:** The default body limit now applies to the `Multipart` extractor ([#1420])
- **added:** String and binary `From` impls have been added to `extract::ws::Message`
to be more inline with `tungstenite` ([#1421])

Expand All @@ -54,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1408]: https://github.com/tokio-rs/axum/pull/1408
[#1414]: https://github.com/tokio-rs/axum/pull/1414
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421

# 0.6.0-rc.2 (10. September, 2022)
Expand Down
12 changes: 5 additions & 7 deletions axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
Expand Down Expand Up @@ -47,10 +48,6 @@ use std::{
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// For security reasons it's recommended to combine this with
/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer)
/// to limit the size of the request payload.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {
Expand All @@ -69,10 +66,11 @@ where

async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let stream = match BodyStream::from_request(req, state).await {
Ok(stream) => stream,
Err(err) => match err {},
let stream_result = match req.with_limited_body() {
Ok(limited) => BodyStream::from_request(limited, state).await,
Err(unlimited) => BodyStream::from_request(unlimited, state).await,
};
let stream = stream_result.unwrap_or_else(|err| match err {});
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}
Expand Down
8 changes: 3 additions & 5 deletions axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,12 @@
#[macro_use]
pub(crate) mod macros;

mod ext_traits;
mod extension;
#[cfg(feature = "form")]
mod form;
#[cfg(feature = "json")]
mod json;
mod service_ext;
#[cfg(feature = "headers")]
mod typed_header;
mod util;
Expand Down Expand Up @@ -483,11 +483,9 @@ pub use self::typed_header::TypedHeader;
pub use self::form::Form;

#[doc(inline)]
pub use axum_core::{BoxError, Error};
pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt};

#[cfg(feature = "macros")]
pub use axum_macros::debug_handler;

pub use self::ext_traits::{
request::RequestExt, request_parts::RequestPartsExt, service::ServiceExt,
};
pub use self::service_ext::ServiceExt;
File renamed without changes.