Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support State with #[derive(FromRequest[Parts])] #1391

Merged
merged 6 commits into from Sep 23, 2022
Merged

Conversation

davidpdrsn
Copy link
Member

@davidpdrsn davidpdrsn commented Sep 18, 2022

Fixes #1314

This makes it possible to extract things via State in #[derive(FromRequet)]:

#[derive(FromRequet)]
#[from_request(state(AppState))]
struct Foo {
    state: State<AppState>,
}

The state can also be inferred in a lot of cases so you only need to write:

#[derive(FromRequet)]
struct Foo {
    // since we're using `State<AppState>` we know the state has to be
    // `AppState`
    state: State<AppState>,
}

Same for

#[derive(FromRequet)]
struct Foo {
    #[from_request(via(State))]
    state: AppState,
}

And

#[derive(FromRequet)]
#[from_request(via(State))]
struct AppState {}

I think I've covered all the edge cases but there are (unsurprisingly) a few.

Todo

  • Test that these extractors can be combined with others in the same handler, ie that we don't get errors because of M

Fixes #1314

This makes it possible to extract things via `State` in
`#[derive(FromRequet)]`:

```rust
struct Foo {
    state: State<AppState>,
}
```

The state can also be inferred in a lot of cases so you only need to
write:

```rust
struct Foo {
    // since we're using `State<AppState>` we know the state has to be
    // `AppState`
    state: State<AppState>,
}
```

Same for

```rust
struct Foo {
    #[from_request(via(State))]
    state: AppState,
}
```

And

```rust
struct AppState {}
```

I think I've covered all the edge cases but there are (unsurprisingly) a
few.
@davidpdrsn davidpdrsn added this to the 0.6 milestone Sep 18, 2022
@genusistimelord
Copy link
Contributor

I will give this a Test since I use a ton of different extractors in my full state struct.

@davidpdrsn
Copy link
Member Author

@genusistimelord Thanks!

Comment on lines +173 to +176
None => (|| {
let via = via.as_ref().map(|(_, via)| via)?;
state_from_via(&ident, via).map(State::Custom)
})()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
None => (|| {
let via = via.as_ref().map(|(_, via)| via)?;
state_from_via(&ident, via).map(State::Custom)
})()
None => via
.as_ref()
.and_then(|(_, via)| state_from_via(&ident, via))
.map(State::Custom)

Comment on lines +624 to +628
if let Type::Path(path) = ty {
Some(&path.path)
} else {
None
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be a bit shorter if converted to match.

Comment on lines +647 to +651
if let syn::GenericArgument::Type(ty) = generic_arg {
Some(ty)
} else {
None
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@jplatte
Copy link
Member

jplatte commented Sep 20, 2022

Will have to come back to this, but left a few first comments.

@genusistimelord
Copy link
Contributor

It worked very nicely.

So you know here is my State

use axum::extract::FromRef;

#[derive(Clone)]
pub struct SystemState {
    pub odbc: axum_odbc::ODBCConnectionManager,
    pub flash_config: axum_flash::Config,
    pub csrf: axum_csrf::CsrfConfig,
}

impl SystemState {
    pub fn new(
        odbc: axum_odbc::ODBCConnectionManager,
        flash_config: axum_flash::Config,
        csrf: axum_csrf::CsrfConfig,
    ) -> Self {
        Self {
            odbc,
            flash_config,
            csrf,
        }
    }
}

impl FromRef<SystemState> for axum_odbc::ODBCConnectionManager {
    fn from_ref(input: &SystemState) -> Self {
        input.odbc.clone()
    }
}

impl FromRef<SystemState> for axum_flash::Config {
    fn from_ref(input: &SystemState) -> Self {
        input.flash_config.clone()
    }
}

impl FromRef<SystemState> for axum_csrf::CsrfConfig {
    fn from_ref(input: &SystemState) -> Self {
        input.csrf.clone()
    }
}

and here is my FullState which uses the derive

#[derive(FromRequestParts, Clone)]
#[from_request(state(SystemState))]
pub struct FullState {
    pub auth: AuthSession<User, i64, AxumPgPool, PgPool>,
    pub flashes: Flashes,
    pub method: Method,
    pub odbc: ODBCConnectionManager,
    pub outgoingflash: Flash,
    #[from_request(via(Extension))]
    pub pool: PgPool,
    #[from_request(via(Extension))]
    pub state: ServerState,
    pub token: axum_csrf::CsrfToken,
}

and here is the Generated code from the macro

impl ::axum::extract::FromRequestParts<SystemState> for FullState {
    type Rejection = ::axum::response::Response;
    #[allow(
        clippy::let_unit_value,
        clippy::no_effect_underscore_binding,
        clippy::shadow_same,
        clippy::type_complexity,
        clippy::type_repetition_in_bounds,
        clippy::used_underscore_binding
    )]
    fn from_request_parts<'life0, 'life1, 'async_trait>(
        parts: &'life0 mut ::axum::http::request::Parts,
        state: &'life1 SystemState,
    ) -> ::core::pin::Pin<
        Box<
            dyn ::core::future::Future<
                Output = ::std::result::Result<Self, Self::Rejection>,
            > + ::core::marker::Send + 'async_trait,
        >,
    >
    where
        'life0: 'async_trait,
        'life1: 'async_trait,
        Self: 'async_trait,
    {
        Box::pin(async move {
            if let ::core::option::Option::Some(__ret)
                = ::core::option::Option::None::<
                    ::std::result::Result<Self, Self::Rejection>,
                > {
                return __ret;
            }
            let parts = parts;
            let state = state;
            let __ret: ::std::result::Result<Self, Self::Rejection> = {
                ::std::result::Result::Ok(Self {
                    auth: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(::std::convert::identity)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    flashes: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(::std::convert::identity)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    method: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(::std::convert::identity)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    odbc: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(|Extension(inner)| inner)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    outgoingflash: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(::std::convert::identity)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    pool: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(|Extension(inner)| inner)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    state: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(|Extension(inner)| inner)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                    token: {
                        ::axum::extract::FromRequestParts::from_request_parts(
                                parts,
                                state,
                            )
                            .await
                            .map(::std::convert::identity)
                            .map_err(::axum::response::IntoResponse::into_response)?
                    },
                })
            };
            #[allow(unreachable_code)] __ret
        })
    }
}

Copy link
Member

@jplatte jplatte left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one more comment, have now reviewed everything.

Comment on lines +1 to +23
error[E0277]: the trait bound `AppState: FromRef<S>` is not satisfied
--> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18
|
6 | inner_state: State<AppState>,
| ^^^^^ the trait `FromRef<S>` is not implemented for `AppState`
|
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<AppState>`
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
4 | #[derive(FromRequest, AppState: FromRef<S>)]
| ++++++++++++++++++++++

error[E0277]: the trait bound `OtherState: FromRef<S>` is not satisfied
--> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18
|
7 | other_state: State<OtherState>,
| ^^^^^ the trait `FromRef<S>` is not implemented for `OtherState`
|
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<OtherState>`
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
4 | #[derive(FromRequest, OtherState: FromRef<S>)]
| ++++++++++++++++++++++++
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's room for improvement on this error message, but the important bit is that we have a test for it not working. I can work on the error message after this is merged.

@davidpdrsn davidpdrsn merged commit c3f3db7 into main Sep 23, 2022
@davidpdrsn davidpdrsn deleted the from-request-state branch September 23, 2022 21:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

#[derive(FromRequest[Parts])] doesn't work with State
3 participants