diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index 170af3c9139..4c694ec4b86 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -2,11 +2,11 @@ use axum::extract::{FromRequestParts, Query}; use axum::Json; use axum_extra::json; use axum_extra::response::ErasedJson; -use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; +use diesel::prelude::*; +use diesel_async::{AsyncPgConnection, RunQueryDsl}; use http::request::Parts; -use oauth2::reqwest::http_client; +use oauth2::reqwest::async_http_client; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; -use tokio::runtime::Handle; use crate::app::AppState; use crate::email::Emails; @@ -14,8 +14,7 @@ use crate::middleware::log_request::RequestLogExt; use crate::middleware::session::SessionExtension; use crate::models::{NewUser, User}; use crate::schema::users; -use crate::tasks::spawn_blocking; -use crate::util::diesel::{is_read_only_error, Conn}; +use crate::util::diesel::is_read_only_error; use crate::util::errors::{bad_request, server_error, AppResult}; use crate::views::EncodableMe; use crates_io_github::GithubUser; @@ -89,76 +88,74 @@ pub async fn authorize( session: SessionExtension, req: Parts, ) -> AppResult> { - let app_clone = app.clone(); - let request_log = req.request_log().clone(); - - let conn = app.db_write().await?; - spawn_blocking(move || { - let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); - - // Make sure that the state we just got matches the session state that we - // should have issued earlier. - let session_state = session.remove("github_oauth_state").map(CsrfToken::new); - if !session_state.is_some_and(|state| query.state.secret() == state.secret()) { - return Err(bad_request("invalid state parameter")); - } + // Make sure that the state we just got matches the session state that we + // should have issued earlier. + let session_state = session.remove("github_oauth_state").map(CsrfToken::new); + if !session_state.is_some_and(|state| query.state.secret() == state.secret()) { + return Err(bad_request("invalid state parameter")); + } - // Fetch the access token from GitHub using the code we just got - let token = app - .github_oauth - .exchange_code(query.code) - .request(http_client) - .map_err(|err| { - request_log.add("cause", err); - server_error("Error obtaining token") - })?; + // Fetch the access token from GitHub using the code we just got + let token = app + .github_oauth + .exchange_code(query.code) + .request_async(async_http_client) + .await + .map_err(|err| { + req.request_log().add("cause", err); + server_error("Error obtaining token") + })?; - let token = token.access_token(); + let token = token.access_token(); - // Fetch the user info from GitHub using the access token we just got and create a user record - let ghuser = Handle::current().block_on(app.github.current_user(token))?; - let user = save_user_to_database(&ghuser, token.secret(), &app.emails, conn)?; + // Fetch the user info from GitHub using the access token we just got and create a user record + let ghuser = app.github.current_user(token).await?; - // Log in by setting a cookie and the middleware authentication - session.insert("user_id".to_string(), user.id.to_string()); + let mut conn = app.db_write().await?; + let user = save_user_to_database(&ghuser, token.secret(), &app.emails, &mut conn).await?; - Ok(()) - }) - .await?; + // Log in by setting a cookie and the middleware authentication + session.insert("user_id".to_string(), user.id.to_string()); - super::me::me(app_clone, req).await + super::me::me(app, req).await } -fn save_user_to_database( +async fn save_user_to_database( user: &GithubUser, access_token: &str, emails: &Emails, - conn: &mut impl Conn, + conn: &mut AsyncPgConnection, ) -> AppResult { - use diesel::prelude::*; - - NewUser::new( + let new_user = NewUser::new( user.id, &user.login, user.name.as_deref(), user.avatar_url.as_deref(), access_token, - ) - .create_or_update(user.email.as_deref(), emails, conn) - .or_else(|e| { - // If we're in read only mode, we can't update their details - // just look for an existing user - if is_read_only_error(&e) { - users::table - .filter(users::gh_id.eq(user.id)) - .first(conn) - .optional()? - .ok_or(e) - } else { - Err(e) + ); + + match new_user + .create_or_update(user.email.as_deref(), emails, conn) + .await + { + Ok(user) => Ok(user), + Err(error) if is_read_only_error(&error) => { + // If we're in read only mode, we can't update their details + // just look for an existing user + find_user_by_gh_id(conn, user.id) + .await? + .ok_or_else(|| error.into()) } - }) - .map_err(Into::into) + Err(error) => Err(error.into()), + } +} + +async fn find_user_by_gh_id(conn: &mut AsyncPgConnection, gh_id: i32) -> QueryResult> { + users::table + .filter(users::gh_id.eq(gh_id)) + .first(conn) + .await + .optional() } /// Handles the `DELETE /api/private/session` route. @@ -170,12 +167,16 @@ pub async fn logout(session: SessionExtension) -> Json { #[cfg(test)] mod tests { use super::*; - use crate::test_util::test_db_connection; + use crates_io_test_db::TestDatabase; + use diesel_async::AsyncConnection; - #[test] - fn gh_user_with_invalid_email_doesnt_fail() { + #[tokio::test] + async fn gh_user_with_invalid_email_doesnt_fail() { let emails = Emails::new_in_memory(); - let (_test_db, conn) = &mut test_db_connection(); + + let test_db = TestDatabase::new(); + let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let gh_user = GithubUser { email: Some("String.Format(\"{0}.{1}@live.com\", FirstName, LastName)".into()), name: Some("My Name".into()), @@ -183,7 +184,7 @@ mod tests { id: -1, avatar_url: None, }; - let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, conn); + let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, &mut conn).await; assert!( result.is_ok(), diff --git a/src/models/user.rs b/src/models/user.rs index 561dddcaa3e..1dfe85bed28 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -1,5 +1,6 @@ use chrono::NaiveDateTime; -use diesel_async::AsyncPgConnection; +use diesel_async::scoped_futures::ScopedFutureExt; +use diesel_async::{AsyncConnection, AsyncPgConnection}; use secrecy::SecretString; use crate::app::App; @@ -171,66 +172,72 @@ impl<'a> NewUser<'a> { } /// Inserts the user into the database, or updates an existing one. - pub fn create_or_update( + pub async fn create_or_update( &self, email: Option<&'a str>, emails: &Emails, - conn: &mut impl Conn, + conn: &mut AsyncPgConnection, ) -> QueryResult { use diesel::dsl::sql; use diesel::insert_into; use diesel::pg::upsert::excluded; use diesel::sql_types::Integer; - use diesel::RunQueryDsl; + use diesel_async::RunQueryDsl; conn.transaction(|conn| { - let user: User = insert_into(users::table) - .values(self) - // We need the `WHERE gh_id > 0` condition here because `gh_id` set - // to `-1` indicates that we were unable to find a GitHub ID for - // the associated GitHub login at the time that we backfilled - // GitHub IDs. Therefore, there are multiple records in production - // that have a `gh_id` of `-1` so we need to exclude those when - // considering uniqueness of `gh_id` values. The `> 0` condition isn't - // necessary for most fields in the database to be used as a conflict - // target :) - .on_conflict(sql::("(gh_id) WHERE gh_id > 0")) - .do_update() - .set(( - users::gh_login.eq(excluded(users::gh_login)), - users::name.eq(excluded(users::name)), - users::gh_avatar.eq(excluded(users::gh_avatar)), - users::gh_access_token.eq(excluded(users::gh_access_token)), - )) - .get_result(conn)?; - - // To send the user an account verification email - if let Some(user_email) = email { - let new_email = NewEmail { - user_id: user.id, - email: user_email, - }; - - let token = insert_into(emails::table) - .values(&new_email) - .on_conflict_do_nothing() - .returning(emails::token) - .get_result::(conn) - .optional()? - .map(SecretString::from); - - if let Some(token) = token { - // Swallows any error. Some users might insert an invalid email address here. - let email = UserConfirmEmail { - user_name: &user.gh_login, - domain: &emails.domain, - token, + async move { + let user: User = insert_into(users::table) + .values(self) + // We need the `WHERE gh_id > 0` condition here because `gh_id` set + // to `-1` indicates that we were unable to find a GitHub ID for + // the associated GitHub login at the time that we backfilled + // GitHub IDs. Therefore, there are multiple records in production + // that have a `gh_id` of `-1` so we need to exclude those when + // considering uniqueness of `gh_id` values. The `> 0` condition isn't + // necessary for most fields in the database to be used as a conflict + // target :) + .on_conflict(sql::("(gh_id) WHERE gh_id > 0")) + .do_update() + .set(( + users::gh_login.eq(excluded(users::gh_login)), + users::name.eq(excluded(users::name)), + users::gh_avatar.eq(excluded(users::gh_avatar)), + users::gh_access_token.eq(excluded(users::gh_access_token)), + )) + .get_result(conn) + .await?; + + // To send the user an account verification email + if let Some(user_email) = email { + let new_email = NewEmail { + user_id: user.id, + email: user_email, }; - let _ = emails.send(user_email, email); + + let token = insert_into(emails::table) + .values(&new_email) + .on_conflict_do_nothing() + .returning(emails::token) + .get_result::(conn) + .await + .optional()? + .map(SecretString::from); + + if let Some(token) = token { + // Swallows any error. Some users might insert an invalid email address here. + let email = UserConfirmEmail { + user_name: &user.gh_login, + domain: &emails.domain, + token, + }; + let _ = emails.async_send(user_email, email).await; + } } - } - Ok(user) + Ok(user) + } + .scope_boxed() }) + .await } } diff --git a/src/tests/user.rs b/src/tests/user.rs index c0594b210ff..e4f10f1cd9c 100644 --- a/src/tests/user.rs +++ b/src/tests/user.rs @@ -1,5 +1,4 @@ use crate::models::{ApiToken, Email, NewUser, User}; -use crate::tasks::spawn_blocking; use crate::tests::{ new_user, util::{MockCookieUser, RequestHelper}, @@ -7,6 +6,7 @@ use crate::tests::{ }; use crate::util::token::HashedToken; use diesel::prelude::*; +use diesel_async::RunQueryDsl; use http::StatusCode; use secrecy::ExposeSecret; use serde_json::json; @@ -23,21 +23,20 @@ impl crate::tests::util::MockCookieUser { #[tokio::test(flavor = "multi_thread")] async fn updating_existing_user_doesnt_change_api_token() { let (app, _, user, token) = TestApp::init().with_token(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; let gh_id = user.as_model().gh_id; let token = token.plaintext(); // Reuse gh_id but use new gh_login and gh_access_token assert_ok!( - NewUser::new(gh_id, "bar", None, None, "bar_token").create_or_update( - None, - &app.as_inner().emails, - &mut conn - ) + NewUser::new(gh_id, "bar", None, None, "bar_token") + .create_or_update(None, &app.as_inner().emails, &mut conn) + .await ); // Use the original API token to find the now updated user let hashed_token = assert_ok!(HashedToken::parse(token.expose_secret())); + let mut conn = app.db_conn(); let api_token = assert_ok!(ApiToken::find_by_api_token(&mut conn, &hashed_token)); let user = assert_ok!(User::find(&mut conn, api_token.user_id)); @@ -57,15 +56,16 @@ async fn updating_existing_user_doesnt_change_api_token() { #[tokio::test(flavor = "multi_thread")] async fn github_without_email_does_not_overwrite_email() { let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub with an account that has no email. // Because faking GitHub is terrible, call what GithubUser::save_to_database does directly. // Don't use app.db_new_user because it adds a verified email. - let u = new_user("arbitrary_username"); - let u = u + let u = new_user("arbitrary_username") .create_or_update(None, &app.as_inner().emails, &mut conn) + .await .unwrap(); + let user_without_github_email = MockCookieUser::new(&app, u); let user_without_github_email_model = user_without_github_email.as_model(); @@ -87,6 +87,7 @@ async fn github_without_email_does_not_overwrite_email() { }; let u = u .create_or_update(None, &app.as_inner().emails, &mut conn) + .await .unwrap(); let again_user_without_github_email = MockCookieUser::new(&app, u); @@ -101,12 +102,14 @@ async fn github_with_email_does_not_overwrite_email() { use crate::schema::emails; let (app, _, user) = TestApp::init().with_user(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; + let model = user.as_model(); let original_email: String = Email::belonging_to(model) .select(emails::email) .first(&mut conn) + .await .unwrap(); let new_github_email = "new-email-in-github@example.com"; @@ -121,12 +124,10 @@ async fn github_with_email_does_not_overwrite_email() { // the rest of the fields are arbitrary ..new_user("arbitrary_username") }; - let u = spawn_blocking(move || { - let u = u.create_or_update(Some(new_github_email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>(u) - }) - .await - .unwrap(); + let u = u + .create_or_update(Some(new_github_email), &emails, &mut conn) + .await + .unwrap(); let user_with_different_email_in_github = MockCookieUser::new(&app, u); @@ -162,19 +163,17 @@ async fn test_confirm_user_email() { use crate::schema::emails; let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub. Don't use app.db_new_user because it inserts a verified // email directly into the database and we want to test the verification flow here. let email = "potato2@example.com"; let emails = app.as_inner().emails.clone(); - let (u, mut conn) = spawn_blocking(move || { - let u = new_user("arbitrary_username").create_or_update(Some(email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>((u, conn)) - }) - .await - .unwrap(); + let u = new_user("arbitrary_username") + .create_or_update(Some(email), &emails, &mut conn) + .await + .unwrap(); let user = MockCookieUser::new(&app, u); let user_model = user.as_model(); @@ -182,6 +181,7 @@ async fn test_confirm_user_email() { let email_token: String = Email::belonging_to(user_model) .select(emails::token) .first(&mut conn) + .await .unwrap(); user.confirm_email(&email_token).await; @@ -202,25 +202,24 @@ async fn test_existing_user_email() { use diesel::update; let (app, _) = TestApp::init().empty(); - let mut conn = app.db_conn(); + let mut conn = app.async_db_conn().await; // Simulate logging in via GitHub. Don't use app.db_new_user because it inserts a verified // email directly into the database and we want to test the verification flow here. let email = "potahto@example.com"; let emails = app.as_inner().emails.clone(); - let (u, mut conn) = spawn_blocking(move || { - let u = new_user("arbitrary_username").create_or_update(Some(email), &emails, &mut conn)?; - Ok::<_, anyhow::Error>((u, conn)) - }) - .await - .unwrap(); + let u = new_user("arbitrary_username") + .create_or_update(Some(email), &emails, &mut conn) + .await + .unwrap(); update(Email::belonging_to(&u)) // Users created before we added verification will have // `NULL` in the `token_generated_at` column. .set(emails::token_generated_at.eq(None::)) .execute(&mut conn) + .await .unwrap(); let user = MockCookieUser::new(&app, u); diff --git a/src/worker/jobs/expiry_notification.rs b/src/worker/jobs/expiry_notification.rs index 8f0828c3d4c..e7c31bbefb2 100644 --- a/src/worker/jobs/expiry_notification.rs +++ b/src/worker/jobs/expiry_notification.rs @@ -167,7 +167,6 @@ The crates.io team"#, mod tests { use super::*; use crate::models::NewUser; - use crate::tasks::spawn_blocking; use crate::{models::token::ApiToken, schema::api_tokens, util::token::PlainToken}; use crates_io_test_db::TestDatabase; use diesel::dsl::IntervalDsl; @@ -176,20 +175,15 @@ mod tests { #[tokio::test] async fn test_expiry_notification() -> anyhow::Result<()> { - let emails = Emails::new_in_memory(); - let test_db = TestDatabase::new(); let mut conn = AsyncPgConnection::establish(test_db.url()).await?; // Set up a user and a token that is about to expire. - let mut sync_conn = test_db.connect(); - let user = spawn_blocking(move || { - let user = NewUser::new(0, "a", None, None, "token"); - let emails = Emails::new_in_memory(); - let user = user.create_or_update(Some("testuser@test.com"), &emails, &mut sync_conn)?; - Ok::<_, anyhow::Error>(user) - }) - .await?; + let user = NewUser::new(0, "a", None, None, "token"); + let emails = Emails::new_in_memory(); + let user = user + .create_or_update(Some("testuser@test.com"), &emails, &mut conn) + .await?; let token = PlainToken::generate(); @@ -220,6 +214,8 @@ mod tests { .await?; } + let emails = Emails::new_in_memory(); + // Check that the token is about to expire. check(&emails, &mut conn).await?;