Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor storing URL params in extensions #833

Merged
merged 3 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions axum/src/extract/path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ mod de;

use crate::{
extract::{rejection::*, FromRequest, RequestParts},
routing::{InvalidUtf8InPathParam, UrlParams},
routing::url_params::UrlParams,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use http::StatusCode;
use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
fmt,
ops::{Deref, DerefMut},
};
Expand Down Expand Up @@ -162,9 +161,9 @@ where
type Rejection = PathRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<Option<UrlParams>>() {
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
let params = match req.extensions_mut().get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.as_str().to_owned(),
Expand All @@ -173,7 +172,6 @@ where
let err = FailedToDeserializePathParams(err);
return Err(err.into());
}
Some(None) => Cow::Owned(Vec::new()),
None => {
return Err(MissingPathParams.into());
}
Expand Down
55 changes: 3 additions & 52 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::{IntoResponse, Redirect, Response},
routing::strip_prefix::StripPrefix,
util::{try_downcast, ByteStr, PercentDecodedByteStr},
util::try_downcast,
BoxError,
};
use http::{Request, Uri};
Expand All @@ -31,6 +31,7 @@ mod method_routing;
mod not_found;
mod route;
mod strip_prefix;
pub(crate) mod url_params;

#[cfg(test)]
mod tests;
Expand Down Expand Up @@ -446,14 +447,7 @@ where
panic!("should always have a matched path for a route id");
}

let params = match_
.params
.iter()
.filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM))
.map(|(key, value)| (key.to_owned(), value.to_owned()))
.collect::<Vec<_>>();

insert_url_params(&mut req, params);
url_params::insert_url_params(req.extensions_mut(), match_.params);

let mut route = self
.routes
Expand Down Expand Up @@ -557,49 +551,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri {
Uri::from_parts(parts).unwrap()
}

// we store the potential error here such that users can handle invalid path
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
// returned an error immediately when decoding the param
pub(crate) struct UrlParams(
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
);

fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
let params = params
.into_iter()
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(InvalidUtf8InPathParam {
key: ByteStr::new(k),
})
}
})
.collect::<Result<Vec<_>, _>>();

if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
match params {
Ok(params) => {
let mut current = current.take().unwrap();
if let Ok(current) = &mut current.0 {
current.extend(params);
}
req.extensions_mut().insert(Some(current));
}
Err(err) => {
req.extensions_mut().insert(Some(UrlParams(Err(err))));
}
}
} else {
req.extensions_mut().insert(Some(UrlParams(params)));
}
}

pub(crate) struct InvalidUtf8InPathParam {
pub(crate) key: ByteStr,
}

/// Wrapper around `matchit::Node` that supports merging two `Node`s.
#[derive(Clone, Default)]
struct Node {
Expand Down
45 changes: 45 additions & 0 deletions axum/src/routing/url_params.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::util::{ByteStr, PercentDecodedByteStr};
use http::Extensions;
use matchit::Params;

pub(crate) enum UrlParams {
Params(Vec<(ByteStr, PercentDecodedByteStr)>),
InvalidUtf8InPathParam { key: ByteStr },
}

pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
let current_params = extensions.get_mut();

if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
// nothing to do here since an error was stored earlier
return;
}

let params = params
.iter()
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit weird, will this special-cased parameter name go away with the next PR that normalizes path params?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has to do with nested opaque services. Its implemented as a wildcard route like /path/*NEST_TAIL_PARAM. So NEST_TAIL_PARAM is reserved by axum and shouldn't pollute the path params the user might see if using Path<HashMap<String, String>>.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I think it could use being more obviously special in that case, something like __private__axum_nest_tail_param.

.map(|(key, value)| (key.to_owned(), value.to_owned()))
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(ByteStr::new(k))
}
})
.collect::<Result<Vec<_>, _>>();

match (current_params, params) {
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
unreachable!("we check for this state earlier in this method")
}
(_, Err(invalid_key)) => {
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
}
(Some(UrlParams::Params(current)), Ok(params)) => {
current.extend(params);
}
(None, Ok(params)) => {
extensions.insert(UrlParams::Params(params));
}
}
}