From 54990fa11a1a407569754a12bf0dae5114f2a3e6 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Sat, 16 Nov 2024 20:35:57 +0100 Subject: [PATCH 1/5] controllers/user/session: Extract `find_user_by_gh_id()` fn --- src/controllers/user/session.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index 170af3c9139..adbeda5ef32 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -2,6 +2,7 @@ use axum::extract::{FromRequestParts, Query}; use axum::Json; use axum_extra::json; use axum_extra::response::ErasedJson; +use diesel::QueryResult; use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; use http::request::Parts; use oauth2::reqwest::http_client; @@ -149,11 +150,7 @@ fn save_user_to_database( // If we're in read only mode, we can't update their details // just look for an existing user if is_read_only_error(&e) { - users::table - .filter(users::gh_id.eq(user.id)) - .first(conn) - .optional()? - .ok_or(e) + find_user_by_gh_id(conn, user.id)?.ok_or(e) } else { Err(e) } @@ -161,6 +158,15 @@ fn save_user_to_database( .map_err(Into::into) } +fn find_user_by_gh_id(conn: &mut impl Conn, gh_id: i32) -> QueryResult> { + use diesel::prelude::*; + + users::table + .filter(users::gh_id.eq(gh_id)) + .first(conn) + .optional() +} + /// Handles the `DELETE /api/private/session` route. pub async fn logout(session: SessionExtension) -> Json { session.remove("user_id"); From dd5e3a1bf527157a4fd79718191b7ef21af63e2a Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Sat, 16 Nov 2024 20:38:55 +0100 Subject: [PATCH 2/5] controllers/user/session: Replace `or_else()` call with `match` statement --- src/controllers/user/session.rs | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index adbeda5ef32..850885e2721 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -138,24 +138,23 @@ fn save_user_to_database( ) -> AppResult { use diesel::prelude::*; - NewUser::new( + let new_user = NewUser::new( user.id, &user.login, user.name.as_deref(), user.avatar_url.as_deref(), access_token, - ) - .create_or_update(user.email.as_deref(), emails, conn) - .or_else(|e| { - // If we're in read only mode, we can't update their details - // just look for an existing user - if is_read_only_error(&e) { - find_user_by_gh_id(conn, user.id)?.ok_or(e) - } else { - Err(e) + ); + + match new_user.create_or_update(user.email.as_deref(), emails, conn) { + Ok(user) => Ok(user), + Err(error) if is_read_only_error(&error) => { + // If we're in read only mode, we can't update their details + // just look for an existing user + find_user_by_gh_id(conn, user.id)?.ok_or_else(|| error.into()) } - }) - .map_err(Into::into) + Err(error) => Err(error.into()), + } } fn find_user_by_gh_id(conn: &mut impl Conn, gh_id: i32) -> QueryResult> { From 738be5882143ab452da70cec46cff6c47aef09c6 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Sat, 16 Nov 2024 20:12:25 +0100 Subject: [PATCH 3/5] controllers/user/session: Migrate to `diesel-async` database queries --- src/controllers/user/session.rs | 84 ++++++++++---------- src/models/user.rs | 103 +++++++++++++------------ src/tests/user.rs | 61 +++++++-------- src/worker/jobs/expiry_notification.rs | 18 ++--- 4 files changed, 136 insertions(+), 130 deletions(-) diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index 850885e2721..c0fa02241e7 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -3,11 +3,10 @@ use axum::Json; use axum_extra::json; use axum_extra::response::ErasedJson; use diesel::QueryResult; -use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; +use diesel_async::AsyncPgConnection; use http::request::Parts; use oauth2::reqwest::http_client; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; -use tokio::runtime::Handle; use crate::app::AppState; use crate::email::Emails; @@ -16,7 +15,7 @@ use crate::middleware::session::SessionExtension; use crate::models::{NewUser, User}; use crate::schema::users; use crate::tasks::spawn_blocking; -use crate::util::diesel::{is_read_only_error, Conn}; +use crate::util::diesel::is_read_only_error; use crate::util::errors::{bad_request, server_error, AppResult}; use crate::views::EncodableMe; use crates_io_github::GithubUser; @@ -90,54 +89,48 @@ pub async fn authorize( session: SessionExtension, req: Parts, ) -> AppResult> { + // Make sure that the state we just got matches the session state that we + // should have issued earlier. + let session_state = session.remove("github_oauth_state").map(CsrfToken::new); + if !session_state.is_some_and(|state| query.state.secret() == state.secret()) { + return Err(bad_request("invalid state parameter")); + } + + // Fetch the access token from GitHub using the code we just got let app_clone = app.clone(); let request_log = req.request_log().clone(); - - let conn = app.db_write().await?; - spawn_blocking(move || { - let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); - - // Make sure that the state we just got matches the session state that we - // should have issued earlier. - let session_state = session.remove("github_oauth_state").map(CsrfToken::new); - if !session_state.is_some_and(|state| query.state.secret() == state.secret()) { - return Err(bad_request("invalid state parameter")); - } - - // Fetch the access token from GitHub using the code we just got - let token = app + let token = spawn_blocking(move || { + app_clone .github_oauth .exchange_code(query.code) .request(http_client) .map_err(|err| { request_log.add("cause", err); server_error("Error obtaining token") - })?; + }) + }) + .await?; - let token = token.access_token(); + let token = token.access_token(); - // Fetch the user info from GitHub using the access token we just got and create a user record - let ghuser = Handle::current().block_on(app.github.current_user(token))?; - let user = save_user_to_database(&ghuser, token.secret(), &app.emails, conn)?; + // Fetch the user info from GitHub using the access token we just got and create a user record + let ghuser = app.github.current_user(token).await?; - // Log in by setting a cookie and the middleware authentication - session.insert("user_id".to_string(), user.id.to_string()); + let mut conn = app.db_write().await?; + let user = save_user_to_database(&ghuser, token.secret(), &app.emails, &mut conn).await?; - Ok(()) - }) - .await?; + // Log in by setting a cookie and the middleware authentication + session.insert("user_id".to_string(), user.id.to_string()); - super::me::me(app_clone, req).await + super::me::me(app, req).await } -fn save_user_to_database( +async fn save_user_to_database( user: &GithubUser, access_token: &str, emails: &Emails, - conn: &mut impl Conn, + conn: &mut AsyncPgConnection, ) -> AppResult { - use diesel::prelude::*; - let new_user = NewUser::new( user.id, &user.login, @@ -146,23 +139,30 @@ fn save_user_to_database( access_token, ); - match new_user.create_or_update(user.email.as_deref(), emails, conn) { + match new_user + .create_or_update(user.email.as_deref(), emails, conn) + .await + { Ok(user) => Ok(user), Err(error) if is_read_only_error(&error) => { // If we're in read only mode, we can't update their details // just look for an existing user - find_user_by_gh_id(conn, user.id)?.ok_or_else(|| error.into()) + find_user_by_gh_id(conn, user.id) + .await? + .ok_or_else(|| error.into()) } Err(error) => Err(error.into()), } } -fn find_user_by_gh_id(conn: &mut impl Conn, gh_id: i32) -> QueryResult> { +async fn find_user_by_gh_id(conn: &mut AsyncPgConnection, gh_id: i32) -> QueryResult> { use diesel::prelude::*; + use diesel_async::RunQueryDsl; users::table .filter(users::gh_id.eq(gh_id)) .first(conn) + .await .optional() } @@ -175,12 +175,16 @@ pub async fn logout(session: SessionExtension) -> Json { #[cfg(test)] mod tests { use super::*; - use crate::test_util::test_db_connection; + use crates_io_test_db::TestDatabase; + use diesel_async::AsyncConnection; - #[test] - fn gh_user_with_invalid_email_doesnt_fail() { + #[tokio::test] + async fn gh_user_with_invalid_email_doesnt_fail() { let emails = Emails::new_in_memory(); - let (_test_db, conn) = &mut test_db_connection(); + + let test_db = TestDatabase::new(); + let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let gh_user = GithubUser { email: Some("String.Format(\"{0}.{1}@live.com\", FirstName, LastName)".into()), name: Some("My Name".into()), @@ -188,7 +192,7 @@ mod tests { id: -1, avatar_url: None, }; - let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, conn); + let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, &mut conn).await; assert!( result.is_ok(), diff --git a/src/models/user.rs b/src/models/user.rs index 561dddcaa3e..1dfe85bed28 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -1,5 +1,6 @@ use chrono::NaiveDateTime; -use diesel_async::AsyncPgConnection; +use diesel_async::scoped_futures::ScopedFutureExt; +use diesel_async::{AsyncConnection, AsyncPgConnection}; use secrecy::SecretString; use crate::app::App; @@ -171,66 +172,72 @@ impl<'a> NewUser<'a> { } /// Inserts the user into the database, or updates an existing one. - pub fn create_or_update( + pub async fn create_or_update( &self, email: Option<&'a str>, emails: &Emails, - conn: &mut impl Conn, + conn: &mut AsyncPgConnection, ) -> QueryResult { use diesel::dsl::sql; use diesel::insert_into; use diesel::pg::upsert::excluded; use diesel::sql_types::Integer; - use diesel::RunQueryDsl; + use diesel_async::RunQueryDsl; conn.transaction(|conn| { - let user: User = insert_into(users::table) - .values(self) - // We need the `WHERE gh_id > 0` condition here because `gh_id` set - // to `-1` indicates that we were unable to find a GitHub ID for - // the associated GitHub login at the time that we backfilled - // GitHub IDs. Therefore, there are multiple records in production - // that have a `gh_id` of `-1` so we need to exclude those when - // considering uniqueness of `gh_id` values. The `> 0` condition isn't - // necessary for most fields in the database to be used as a conflict - // target :) - .on_conflict(sql::("(gh_id) WHERE gh_id > 0")) - .do_update() - .set(( - users::gh_login.eq(excluded(users::gh_login)), - users::name.eq(excluded(users::name)), - users::gh_avatar.eq(excluded(users::gh_avatar)), - users::gh_access_token.eq(excluded(users::gh_access_token)), - )) - .get_result(conn)?; - - // To send the user an account verification email - if let Some(user_email) = email { - let new_email = NewEmail { - user_id: user.id, - email: user_email, - }; - - let token = insert_into(emails::table) - .values(&new_email) - .on_conflict_do_nothing() - .returning(emails::token) - .get_result::(conn) - .optional()? - .map(SecretString::from); - - if let Some(token) = token { - // Swallows any error. Some users might insert an invalid email address here. - let email = UserConfirmEmail { - user_name: &user.gh_login, - domain: &emails.domain, - token, + async move { + let user: User = insert_into(users::table) + .values(self) + // We need the `WHERE gh_id > 0` condition here because `gh_id` set + // to `-1` indicates that we were unable to find a GitHub ID for + // the associated GitHub login at the time that we backfilled + // GitHub IDs. Therefore, there are multiple records in production + // that have a `gh_id` of `-1` so we need to exclude those when + // considering uniqueness of `gh_id` values. The `> 0` condition isn't + // necessary for most fields in the database to be used as a conflict + // target :) + .on_conflict(sql::("(gh_id) WHERE gh_id > 0")) + .do_update() + .set(( + users::gh_login.eq(excluded(users::gh_login)), + users::name.eq(excluded(users::name)), + users::gh_avatar.eq(excluded(users::gh_avatar)), + users::gh_access_token.eq(excluded(users::gh_access_token)), + )) + .get_result(conn) + .await?; + + // To send the user an account verification email + if let Some(user_email) = email { + let new_email = NewEmail { + user_id: user.id, + email: user_email, }; - let _ = emails.send(user_email, email); + + let token = insert_into(emails::table) + .values(&new_email) + .on_conflict_do_nothing() + .returning(emails::token) + .get_result::(conn) + .await + .optional()? + .map(SecretString::from); + + if let Some(token) = token { + // Swallows any error. Some users might insert an invalid email address here. + let email = UserConfirmEmail { + user_name: &user.gh_login, + domain: &emails.domain, + token, + }; + let _ = emails.async_send(user_email, email).await; + } } - } - Ok(user) + Ok(user) + } + .scope_boxed() }) + .await } } diff --git a/src/tests/user.rs b/src/tests/user.rs index c0594b210ff..e4f10f1cd9c 100644 --- a/src/tests/user.rs +++ b/src/tests/user.rs @@ -1,5 +1,4 @@ use crate::models::{ApiToken, Email, NewUser, User}; -use crate::tasks::spawn_blocking; use crate::tests::{ new_user, util::{MockCookieUser, RequestHelper}, @@ -7,6 +6,7 @@ use crate::tests::{ }; use crate::util::token::HashedToken; use diesel::prelude::*; +use diesel_async::RunQueryDsl; use http::StatusCode; use secrecy::ExposeSecret; use serde_json::json; @@ -23,21 +23,20 @@ impl crate::tests::util::MockCookieUser { #[tokio::test(flavor = "multi_thread")] async fn updating_existing_user_doesnt_change_api_token() { let (app, _, user, token) = TestApp::init().with_token(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; let gh_id = user.as_model().gh_id; let token = token.plaintext(); // Reuse gh_id but use new gh_login and gh_access_token assert_ok!( - NewUser::new(gh_id, "bar", None, None, "bar_token").create_or_update( - None, - &app.as_inner().emails, - &mut conn - ) + NewUser::new(gh_id, "bar", None, None, "bar_token") + .create_or_update(None, &app.as_inner().emails, &mut conn) + .await ); // Use the original API token to find the now updated user let hashed_token = assert_ok!(HashedToken::parse(token.expose_secret())); + let mut conn = app.db_conn(); let api_token = assert_ok!(ApiToken::find_by_api_token(&mut conn, &hashed_token)); let user = assert_ok!(User::find(&mut conn, api_token.user_id)); @@ -57,15 +56,16 @@ async fn updating_existing_user_doesnt_change_api_token() { #[tokio::test(flavor = "multi_thread")] async fn github_without_email_does_not_overwrite_email() { let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub with an account that has no email. // Because faking GitHub is terrible, call what GithubUser::save_to_database does directly. // Don't use app.db_new_user because it adds a verified email. - let u = new_user("arbitrary_username"); - let u = u + let u = new_user("arbitrary_username") .create_or_update(None, &app.as_inner().emails, &mut conn) + .await .unwrap(); + let user_without_github_email = MockCookieUser::new(&app, u); let user_without_github_email_model = user_without_github_email.as_model(); @@ -87,6 +87,7 @@ async fn github_without_email_does_not_overwrite_email() { }; let u = u .create_or_update(None, &app.as_inner().emails, &mut conn) + .await .unwrap(); let again_user_without_github_email = MockCookieUser::new(&app, u); @@ -101,12 +102,14 @@ async fn github_with_email_does_not_overwrite_email() { use crate::schema::emails; let (app, _, user) = TestApp::init().with_user(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; + let model = user.as_model(); let original_email: String = Email::belonging_to(model) .select(emails::email) .first(&mut conn) + .await .unwrap(); let new_github_email = "new-email-in-github@example.com"; @@ -121,12 +124,10 @@ async fn github_with_email_does_not_overwrite_email() { // the rest of the fields are arbitrary ..new_user("arbitrary_username") }; - let u = spawn_blocking(move || { - let u = u.create_or_update(Some(new_github_email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>(u) - }) - .await - .unwrap(); + let u = u + .create_or_update(Some(new_github_email), &emails, &mut conn) + .await + .unwrap(); let user_with_different_email_in_github = MockCookieUser::new(&app, u); @@ -162,19 +163,17 @@ async fn test_confirm_user_email() { use crate::schema::emails; let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub. Don't use app.db_new_user because it inserts a verified // email directly into the database and we want to test the verification flow here. let email = "potato2@example.com"; let emails = app.as_inner().emails.clone(); - let (u, mut conn) = spawn_blocking(move || { - let u = new_user("arbitrary_username").create_or_update(Some(email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>((u, conn)) - }) - .await - .unwrap(); + let u = new_user("arbitrary_username") + .create_or_update(Some(email), &emails, &mut conn) + .await + .unwrap(); let user = MockCookieUser::new(&app, u); let user_model = user.as_model(); @@ -182,6 +181,7 @@ async fn test_confirm_user_email() { let email_token: String = Email::belonging_to(user_model) .select(emails::token) .first(&mut conn) + .await .unwrap(); user.confirm_email(&email_token).await; @@ -202,25 +202,24 @@ async fn test_existing_user_email() { use diesel::update; let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub. Don't use app.db_new_user because it inserts a verified // email directly into the database and we want to test the verification flow here. let email = "potahto@example.com"; let emails = app.as_inner().emails.clone(); - let (u, mut conn) = spawn_blocking(move || { - let u = new_user("arbitrary_username").create_or_update(Some(email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>((u, conn)) - }) - .await - .unwrap(); + let u = new_user("arbitrary_username") + .create_or_update(Some(email), &emails, &mut conn) + .await + .unwrap(); update(Email::belonging_to(&u)) // Users created before we added verification will have // `NULL` in the `token_generated_at` column. .set(emails::token_generated_at.eq(None::)) .execute(&mut conn) + .await .unwrap(); let user = MockCookieUser::new(&app, u); diff --git a/src/worker/jobs/expiry_notification.rs b/src/worker/jobs/expiry_notification.rs index 8f0828c3d4c..e7c31bbefb2 100644 --- a/src/worker/jobs/expiry_notification.rs +++ b/src/worker/jobs/expiry_notification.rs @@ -167,7 +167,6 @@ The crates.io team"#, mod tests { use super::*; use crate::models::NewUser; - use crate::tasks::spawn_blocking; use crate::{models::token::ApiToken, schema::api_tokens, util::token::PlainToken}; use crates_io_test_db::TestDatabase; use diesel::dsl::IntervalDsl; @@ -176,20 +175,15 @@ mod tests { #[tokio::test] async fn test_expiry_notification() -> anyhow::Result<()> { - let emails = Emails::new_in_memory(); - let test_db = TestDatabase::new(); let mut conn = AsyncPgConnection::establish(test_db.url()).await?; // Set up a user and a token that is about to expire. - let mut sync_conn = test_db.connect(); - let user = spawn_blocking(move || { - let user = NewUser::new(0, "a", None, None, "token"); - let emails = Emails::new_in_memory(); - let user = user.create_or_update(Some("testuser@test.com"), &emails, &mut sync_conn)?; - Ok::<_, anyhow::Error>(user) - }) - .await?; + let user = NewUser::new(0, "a", None, None, "token"); + let emails = Emails::new_in_memory(); + let user = user + .create_or_update(Some("testuser@test.com"), &emails, &mut conn) + .await?; let token = PlainToken::generate(); @@ -220,6 +214,8 @@ mod tests { .await?; } + let emails = Emails::new_in_memory(); + // Check that the token is about to expire. check(&emails, &mut conn).await?; From edcf8abb34478681a27f5e0e54eda485ccf4ede9 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Sat, 16 Nov 2024 22:08:26 +0100 Subject: [PATCH 4/5] controllers/user/session: Migrate to async `oauth2` client --- src/controllers/user/session.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index c0fa02241e7..0e2b2012f16 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -5,7 +5,7 @@ use axum_extra::response::ErasedJson; use diesel::QueryResult; use diesel_async::AsyncPgConnection; use http::request::Parts; -use oauth2::reqwest::http_client; +use oauth2::reqwest::async_http_client; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; use crate::app::AppState; @@ -14,7 +14,6 @@ use crate::middleware::log_request::RequestLogExt; use crate::middleware::session::SessionExtension; use crate::models::{NewUser, User}; use crate::schema::users; -use crate::tasks::spawn_blocking; use crate::util::diesel::is_read_only_error; use crate::util::errors::{bad_request, server_error, AppResult}; use crate::views::EncodableMe; @@ -97,19 +96,15 @@ pub async fn authorize( } // Fetch the access token from GitHub using the code we just got - let app_clone = app.clone(); - let request_log = req.request_log().clone(); - let token = spawn_blocking(move || { - app_clone - .github_oauth - .exchange_code(query.code) - .request(http_client) - .map_err(|err| { - request_log.add("cause", err); - server_error("Error obtaining token") - }) - }) - .await?; + let token = app + .github_oauth + .exchange_code(query.code) + .request_async(async_http_client) + .await + .map_err(|err| { + req.request_log().add("cause", err); + server_error("Error obtaining token") + })?; let token = token.access_token(); From 09bf3031ad2831febff416b45981f281a385f79b Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Sat, 16 Nov 2024 22:43:14 +0100 Subject: [PATCH 5/5] controllers/user/session: Simplify imports --- src/controllers/user/session.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index 0e2b2012f16..4c694ec4b86 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -2,8 +2,8 @@ use axum::extract::{FromRequestParts, Query}; use axum::Json; use axum_extra::json; use axum_extra::response::ErasedJson; -use diesel::QueryResult; -use diesel_async::AsyncPgConnection; +use diesel::prelude::*; +use diesel_async::{AsyncPgConnection, RunQueryDsl}; use http::request::Parts; use oauth2::reqwest::async_http_client; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; @@ -151,9 +151,6 @@ async fn save_user_to_database( } async fn find_user_by_gh_id(conn: &mut AsyncPgConnection, gh_id: i32) -> QueryResult> { - use diesel::prelude::*; - use diesel_async::RunQueryDsl; - users::table .filter(users::gh_id.eq(gh_id)) .first(conn)