Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 63 additions & 62 deletions src/controllers/user/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
use axum::Json;
use axum_extra::json;
use axum_extra::response::ErasedJson;
use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
use diesel::prelude::*;
use diesel_async::{AsyncPgConnection, RunQueryDsl};
use http::request::Parts;
use oauth2::reqwest::http_client;
use oauth2::reqwest::async_http_client;
use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
use tokio::runtime::Handle;

use crate::app::AppState;
use crate::email::Emails;
use crate::middleware::log_request::RequestLogExt;
use crate::middleware::session::SessionExtension;
use crate::models::{NewUser, User};
use crate::schema::users;
use crate::tasks::spawn_blocking;
use crate::util::diesel::{is_read_only_error, Conn};
use crate::util::diesel::is_read_only_error;
use crate::util::errors::{bad_request, server_error, AppResult};
use crate::views::EncodableMe;
use crates_io_github::GithubUser;
Expand Down Expand Up @@ -89,76 +88,74 @@
session: SessionExtension,
req: Parts,
) -> AppResult<Json<EncodableMe>> {
let app_clone = app.clone();
let request_log = req.request_log().clone();

let conn = app.db_write().await?;
spawn_blocking(move || {
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

// Make sure that the state we just got matches the session state that we
// should have issued earlier.
let session_state = session.remove("github_oauth_state").map(CsrfToken::new);
if !session_state.is_some_and(|state| query.state.secret() == state.secret()) {
return Err(bad_request("invalid state parameter"));
}
// Make sure that the state we just got matches the session state that we
// should have issued earlier.
let session_state = session.remove("github_oauth_state").map(CsrfToken::new);
if !session_state.is_some_and(|state| query.state.secret() == state.secret()) {
return Err(bad_request("invalid state parameter"));
}

Check warning on line 96 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L91-L96

Added lines #L91 - L96 were not covered by tests

// Fetch the access token from GitHub using the code we just got
let token = app
.github_oauth
.exchange_code(query.code)
.request(http_client)
.map_err(|err| {
request_log.add("cause", err);
server_error("Error obtaining token")
})?;
// Fetch the access token from GitHub using the code we just got
let token = app
.github_oauth
.exchange_code(query.code)
.request_async(async_http_client)
.await
.map_err(|err| {
req.request_log().add("cause", err);
server_error("Error obtaining token")
})?;

Check warning on line 107 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L99-L107

Added lines #L99 - L107 were not covered by tests

let token = token.access_token();
let token = token.access_token();

Check warning on line 109 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L109

Added line #L109 was not covered by tests

// Fetch the user info from GitHub using the access token we just got and create a user record
let ghuser = Handle::current().block_on(app.github.current_user(token))?;
let user = save_user_to_database(&ghuser, token.secret(), &app.emails, conn)?;
// Fetch the user info from GitHub using the access token we just got and create a user record
let ghuser = app.github.current_user(token).await?;

Check warning on line 112 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L112

Added line #L112 was not covered by tests

// Log in by setting a cookie and the middleware authentication
session.insert("user_id".to_string(), user.id.to_string());
let mut conn = app.db_write().await?;
let user = save_user_to_database(&ghuser, token.secret(), &app.emails, &mut conn).await?;

Check warning on line 115 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L114-L115

Added lines #L114 - L115 were not covered by tests

Ok(())
})
.await?;
// Log in by setting a cookie and the middleware authentication
session.insert("user_id".to_string(), user.id.to_string());

Check warning on line 118 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L118

Added line #L118 was not covered by tests

super::me::me(app_clone, req).await
super::me::me(app, req).await

Check warning on line 120 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L120

Added line #L120 was not covered by tests
}

fn save_user_to_database(
async fn save_user_to_database(
user: &GithubUser,
access_token: &str,
emails: &Emails,
conn: &mut impl Conn,
conn: &mut AsyncPgConnection,
) -> AppResult<User> {
use diesel::prelude::*;

NewUser::new(
let new_user = NewUser::new(
user.id,
&user.login,
user.name.as_deref(),
user.avatar_url.as_deref(),
access_token,
)
.create_or_update(user.email.as_deref(), emails, conn)
.or_else(|e| {
// If we're in read only mode, we can't update their details
// just look for an existing user
if is_read_only_error(&e) {
users::table
.filter(users::gh_id.eq(user.id))
.first(conn)
.optional()?
.ok_or(e)
} else {
Err(e)
);

match new_user
.create_or_update(user.email.as_deref(), emails, conn)
.await
{
Ok(user) => Ok(user),
Err(error) if is_read_only_error(&error) => {
// If we're in read only mode, we can't update their details
// just look for an existing user
find_user_by_gh_id(conn, user.id)
.await?
.ok_or_else(|| error.into())

Check warning on line 147 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L142-L147

Added lines #L142 - L147 were not covered by tests
}
})
.map_err(Into::into)
Err(error) => Err(error.into()),

Check warning on line 149 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L149

Added line #L149 was not covered by tests
}
}

async fn find_user_by_gh_id(conn: &mut AsyncPgConnection, gh_id: i32) -> QueryResult<Option<User>> {
users::table
.filter(users::gh_id.eq(gh_id))
.first(conn)
.await
.optional()

Check warning on line 158 in src/controllers/user/session.rs

View check run for this annotation

Codecov / codecov/patch

src/controllers/user/session.rs#L153-L158

Added lines #L153 - L158 were not covered by tests
}

/// Handles the `DELETE /api/private/session` route.
Expand All @@ -170,20 +167,24 @@
#[cfg(test)]
mod tests {
use super::*;
use crate::test_util::test_db_connection;
use crates_io_test_db::TestDatabase;
use diesel_async::AsyncConnection;

#[test]
fn gh_user_with_invalid_email_doesnt_fail() {
#[tokio::test]
async fn gh_user_with_invalid_email_doesnt_fail() {
let emails = Emails::new_in_memory();
let (_test_db, conn) = &mut test_db_connection();

let test_db = TestDatabase::new();
let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap();

let gh_user = GithubUser {
email: Some("String.Format(\"{0}.{1}@live.com\", FirstName, LastName)".into()),
name: Some("My Name".into()),
login: "github_user".into(),
id: -1,
avatar_url: None,
};
let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, conn);
let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, &mut conn).await;

assert!(
result.is_ok(),
Expand Down
103 changes: 55 additions & 48 deletions src/models/user.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use chrono::NaiveDateTime;
use diesel_async::AsyncPgConnection;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::{AsyncConnection, AsyncPgConnection};
use secrecy::SecretString;

use crate::app::App;
Expand Down Expand Up @@ -171,66 +172,72 @@ impl<'a> NewUser<'a> {
}

/// Inserts the user into the database, or updates an existing one.
pub fn create_or_update(
pub async fn create_or_update(
&self,
email: Option<&'a str>,
emails: &Emails,
conn: &mut impl Conn,
conn: &mut AsyncPgConnection,
) -> QueryResult<User> {
use diesel::dsl::sql;
use diesel::insert_into;
use diesel::pg::upsert::excluded;
use diesel::sql_types::Integer;
use diesel::RunQueryDsl;
use diesel_async::RunQueryDsl;

conn.transaction(|conn| {
let user: User = insert_into(users::table)
.values(self)
// We need the `WHERE gh_id > 0` condition here because `gh_id` set
// to `-1` indicates that we were unable to find a GitHub ID for
// the associated GitHub login at the time that we backfilled
// GitHub IDs. Therefore, there are multiple records in production
// that have a `gh_id` of `-1` so we need to exclude those when
// considering uniqueness of `gh_id` values. The `> 0` condition isn't
// necessary for most fields in the database to be used as a conflict
// target :)
.on_conflict(sql::<Integer>("(gh_id) WHERE gh_id > 0"))
.do_update()
.set((
users::gh_login.eq(excluded(users::gh_login)),
users::name.eq(excluded(users::name)),
users::gh_avatar.eq(excluded(users::gh_avatar)),
users::gh_access_token.eq(excluded(users::gh_access_token)),
))
.get_result(conn)?;

// To send the user an account verification email
if let Some(user_email) = email {
let new_email = NewEmail {
user_id: user.id,
email: user_email,
};

let token = insert_into(emails::table)
.values(&new_email)
.on_conflict_do_nothing()
.returning(emails::token)
.get_result::<String>(conn)
.optional()?
.map(SecretString::from);

if let Some(token) = token {
// Swallows any error. Some users might insert an invalid email address here.
let email = UserConfirmEmail {
user_name: &user.gh_login,
domain: &emails.domain,
token,
async move {
let user: User = insert_into(users::table)
.values(self)
// We need the `WHERE gh_id > 0` condition here because `gh_id` set
// to `-1` indicates that we were unable to find a GitHub ID for
// the associated GitHub login at the time that we backfilled
// GitHub IDs. Therefore, there are multiple records in production
// that have a `gh_id` of `-1` so we need to exclude those when
// considering uniqueness of `gh_id` values. The `> 0` condition isn't
// necessary for most fields in the database to be used as a conflict
// target :)
.on_conflict(sql::<Integer>("(gh_id) WHERE gh_id > 0"))
.do_update()
.set((
users::gh_login.eq(excluded(users::gh_login)),
users::name.eq(excluded(users::name)),
users::gh_avatar.eq(excluded(users::gh_avatar)),
users::gh_access_token.eq(excluded(users::gh_access_token)),
))
.get_result(conn)
.await?;

// To send the user an account verification email
if let Some(user_email) = email {
let new_email = NewEmail {
user_id: user.id,
email: user_email,
};
let _ = emails.send(user_email, email);

let token = insert_into(emails::table)
.values(&new_email)
.on_conflict_do_nothing()
.returning(emails::token)
.get_result::<String>(conn)
.await
.optional()?
.map(SecretString::from);

if let Some(token) = token {
// Swallows any error. Some users might insert an invalid email address here.
let email = UserConfirmEmail {
user_name: &user.gh_login,
domain: &emails.domain,
token,
};
let _ = emails.async_send(user_email, email).await;
}
}
}

Ok(user)
Ok(user)
}
.scope_boxed()
})
.await
}
}
Loading