Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
use crate::middleware::session::RequestSession;
use crate::models::token::{CrateScope, EndpointScope};
use crate::models::{ApiToken, User};
use crate::util::diesel::Conn;
use crate::util::errors::{
account_locked, forbidden, internal, AppResult, InsecurelyGeneratedTokenRevoked,
};
use crate::util::token::HashedToken;
use chrono::Utc;
use diesel_async::AsyncPgConnection;
use http::header;
use http::request::Parts;

Expand Down Expand Up @@ -58,8 +58,12 @@
}

#[instrument(name = "auth.check", skip_all)]
pub fn check(&self, parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
let auth = authenticate(parts, conn)?;
pub async fn check(
&self,
parts: &Parts,
conn: &mut AsyncPgConnection,
) -> AppResult<Authentication> {
let auth = authenticate(parts, conn).await?;

if let Some(token) = auth.api_token() {
if !self.allow_token {
Expand Down Expand Up @@ -168,9 +172,9 @@
}

#[instrument(skip_all)]
fn authenticate_via_cookie(
async fn authenticate_via_cookie(
parts: &Parts,
conn: &mut impl Conn,
conn: &mut AsyncPgConnection,
) -> AppResult<Option<CookieAuthentication>> {
let user_id_from_session = parts
.session()
Expand All @@ -181,7 +185,7 @@
return Ok(None);
};

let user = User::find(conn, id).map_err(|err| {
let user = User::async_find(conn, id).await.map_err(|err| {
parts.request_log().add("cause", err);
internal("user_id from cookie not found in database")
})?;
Expand All @@ -194,9 +198,9 @@
}

#[instrument(skip_all)]
fn authenticate_via_token(
async fn authenticate_via_token(
parts: &Parts,
conn: &mut impl Conn,
conn: &mut AsyncPgConnection,
) -> AppResult<Option<TokenAuthentication>> {
let maybe_authorization = parts
.headers()
Expand All @@ -210,14 +214,16 @@
let token =
HashedToken::parse(header_value).map_err(|_| InsecurelyGeneratedTokenRevoked::boxed())?;

let token = ApiToken::find_by_api_token(conn, &token).map_err(|e| {
let cause = format!("invalid token caused by {e}");
parts.request_log().add("cause", cause);
let token = ApiToken::async_find_by_api_token(conn, &token)
.await
.map_err(|e| {
let cause = format!("invalid token caused by {e}");
parts.request_log().add("cause", cause);

forbidden("authentication failed")
})?;
forbidden("authentication failed")
})?;

let user = User::find(conn, token.user_id).map_err(|err| {
let user = User::async_find(conn, token.user_id).await.map_err(|err| {

Check warning on line 226 in src/auth.rs

View check run for this annotation

Codecov / codecov/patch

src/auth.rs#L226

Added line #L226 was not covered by tests
parts.request_log().add("cause", err);
internal("user_id from token not found in database")
})?;
Expand All @@ -231,16 +237,16 @@
}

#[instrument(skip_all)]
fn authenticate(parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
async fn authenticate(parts: &Parts, conn: &mut AsyncPgConnection) -> AppResult<Authentication> {
controllers::util::verify_origin(parts)?;

match authenticate_via_cookie(parts, conn) {
match authenticate_via_cookie(parts, conn).await {
Ok(None) => {}
Ok(Some(auth)) => return Ok(Authentication::Cookie(auth)),
Err(err) => return Err(err),
}

match authenticate_via_token(parts, conn) {
match authenticate_via_token(parts, conn).await {
Ok(None) => {}
Ok(Some(auth)) => return Ok(Authentication::Token(auth)),
Err(err) => return Err(err),
Expand Down
13 changes: 6 additions & 7 deletions src/controllers/crate_owner_invitation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ use tokio::runtime::Handle;

/// Handles the `GET /api/v1/me/crate_owner_invitations` route.
pub async fn list(app: AppState, req: Parts) -> AppResult<Json<Value>> {
let conn = app.db_read().await?;
let mut conn = app.db_read().await?;
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
spawn_blocking(move || {
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::only_cookie().check(&req, conn)?;
let user_id = auth.user_id();

let PrivateListResponse {
Expand Down Expand Up @@ -69,12 +69,11 @@ pub async fn list(app: AppState, req: Parts) -> AppResult<Json<Value>> {

/// Handles the `GET /api/private/crate_owner_invitations` route.
pub async fn private_list(app: AppState, req: Parts) -> AppResult<Json<PrivateListResponse>> {
let conn = app.db_read().await?;
let mut conn = app.db_read().await?;
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
spawn_blocking(move || {
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::only_cookie().check(&req, conn)?;

let filter = if let Some(crate_name) = req.query().get("crate_name") {
ListFilter::CrateName(crate_name.clone())
} else if let Some(id) = req.query().get("invitee_id").and_then(|i| i.parse().ok()) {
Expand Down Expand Up @@ -284,11 +283,11 @@ pub async fn handle_invite(state: AppState, req: BytesRequest) -> AppResult<Json

let crate_invite = crate_invite.crate_owner_invite;

let conn = state.db_write().await?;
let mut conn = state.db_write().await?;
let auth = AuthCheck::default().check(&parts, &mut conn).await?;
spawn_blocking(move || {
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default().check(&parts, conn)?;
let user_id = auth.user_id();

let config = &state.config;
Expand Down
15 changes: 9 additions & 6 deletions src/controllers/krate/follow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ pub async fn follow(
Path(crate_name): Path<String>,
req: Parts,
) -> AppResult<Response> {
let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let user_id = AuthCheck::default().check(&req, &mut conn).await?.user_id();
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let user_id = AuthCheck::default().check(&req, conn)?.user_id();
let follow = follow_target(&crate_name, conn, user_id)?;
diesel::insert_into(follows::table)
.values(&follow)
Expand All @@ -58,13 +58,13 @@ pub async fn unfollow(
Path(crate_name): Path<String>,
req: Parts,
) -> AppResult<Response> {
let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let user_id = AuthCheck::default().check(&req, &mut conn).await?.user_id();
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let user_id = AuthCheck::default().check(&req, conn)?.user_id();
let follow = follow_target(&crate_name, conn, user_id)?;
diesel::delete(&follow).execute(conn)?;

Expand All @@ -79,15 +79,18 @@ pub async fn following(
Path(crate_name): Path<String>,
req: Parts,
) -> AppResult<Json<Value>> {
let conn = app.db_read_prefer_primary().await?;
let mut conn = app.db_read_prefer_primary().await?;
let user_id = AuthCheck::only_cookie()
.check(&req, &mut conn)
.await?
.user_id();
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

use diesel::dsl::exists;

let user_id = AuthCheck::only_cookie().check(&req, conn)?.user_id();
let follow = follow_target(&crate_name, conn, user_id)?;
let following =
diesel::select(exists(follows::table.find(follow.id()))).get_result::<bool>(conn)?;
Expand Down
12 changes: 6 additions & 6 deletions src/controllers/krate/owners.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ async fn modify_owners(
));
}

let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let auth = AuthCheck::default()
.with_endpoint_scope(EndpointScope::ChangeOwners)
.for_crate(&crate_name)
.check(&parts, &mut conn)
.await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default()
.with_endpoint_scope(EndpointScope::ChangeOwners)
.for_crate(&crate_name)
.check(&parts, conn)?;

let user = auth.user();

// The set of emails to send out after invite processing is complete and
Expand Down
20 changes: 14 additions & 6 deletions src/controllers/krate/publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ pub async fn publish(app: AppState, req: BytesRequest) -> AppResult<Json<GoodCra
request_log.add("crate_name", &*metadata.name);
request_log.add("crate_version", &version_string);

let conn = app.db_write().await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;
let mut conn = app.db_write().await?;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
let (existing_crate, auth) = {
use diesel_async::RunQueryDsl;

// this query should only be used for the endpoint scope calculation
// since a race condition there would only cause `publish-new` instead of
// `publish-update` to be used.
let existing_crate: Option<Crate> = Crate::by_name(&metadata.name)
.first::<Crate>(conn)
.first::<Crate>(&mut conn)
.await
.optional()?;

let endpoint_scope = match existing_crate {
Expand All @@ -102,7 +102,15 @@ pub async fn publish(app: AppState, req: BytesRequest) -> AppResult<Json<GoodCra
let auth = AuthCheck::default()
.with_endpoint_scope(endpoint_scope)
.for_crate(&metadata.name)
.check(&req, conn)?;
.check(&req, &mut conn)
.await?;
(existing_crate, auth)
};

spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let api_token_id = auth.api_token_id();
let user = auth.user();
Expand Down
11 changes: 7 additions & 4 deletions src/controllers/krate/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use axum::Json;
use diesel::dsl::{exists, sql, InnerJoinQuerySource, LeftJoinQuerySource};
use diesel::sql_types::{Array, Bool, Text};
use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
use diesel_async::AsyncPgConnection;
use diesel_full_text_search::*;
use http::request::Parts;
use serde_json::Value;
use std::cell::OnceCell;
use tokio::runtime::Handle;

use crate::app::AppState;
use crate::controllers::helpers::Paginate;
Expand All @@ -22,7 +24,6 @@ use crate::controllers::helpers::pagination::{Page, Paginated, PaginationOptions
use crate::models::krate::ALL_COLUMNS;
use crate::sql::{array_agg, canon_crate_name, lower};
use crate::tasks::spawn_blocking;
use crate::util::diesel::Conn;
use crate::util::RequestUtils;

/// Handles the `GET /crates` route.
Expand Down Expand Up @@ -303,12 +304,14 @@ impl<'a> FilterParams<'a> {
.as_deref()
}

fn authed_user_id(&self, req: &Parts, conn: &mut impl Conn) -> AppResult<i32> {
fn authed_user_id(&self, req: &Parts, conn: &mut AsyncPgConnection) -> AppResult<i32> {
if let Some(val) = self._auth_user_id.get() {
return Ok(*val);
}

let user_id = AuthCheck::default().check(req, conn)?.user_id();
let user_id = Handle::current()
.block_on(AuthCheck::default().check(req, conn))?
.user_id();

// This should not fail, because of the `get()` check above
let _ = self._auth_user_id.set(user_id);
Expand All @@ -319,7 +322,7 @@ impl<'a> FilterParams<'a> {
fn make_query(
&'a self,
req: &Parts,
conn: &mut impl Conn,
conn: &mut AsyncPgConnection,
) -> AppResult<crates::BoxedQuery<'a, diesel::pg::Pg>> {
let mut query = crates::table.into_boxed();

Expand Down
20 changes: 10 additions & 10 deletions src/controllers/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ pub async fn list(
Query(params): Query<GetParams>,
req: Parts,
) -> AppResult<Json<Value>> {
let conn = app.db_read_prefer_primary().await?;
let mut conn = app.db_read_prefer_primary().await?;
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::only_cookie().check(&req, conn)?;
let user = auth.user();

let tokens: Vec<ApiToken> = ApiToken::belonging_to(user)
Expand Down Expand Up @@ -92,13 +92,13 @@ pub async fn new(
return Err(bad_request("name must have a value"));
}

let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let auth = AuthCheck::default().check(&parts, &mut conn).await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default().check(&parts, conn)?;
if auth.api_token_id().is_some() {
return Err(bad_request(
"cannot use an API token to create a new API token",
Expand Down Expand Up @@ -175,13 +175,13 @@ pub async fn new(

/// Handles the `GET /me/tokens/:id` route.
pub async fn show(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<Json<Value>> {
let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let auth = AuthCheck::default().check(&req, &mut conn).await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default().check(&req, conn)?;
let user = auth.user();
let token = ApiToken::belonging_to(user)
.find(id)
Expand All @@ -195,13 +195,13 @@ pub async fn show(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<J

/// Handles the `DELETE /me/tokens/:id` route.
pub async fn revoke(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<Json<Value>> {
let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let auth = AuthCheck::default().check(&req, &mut conn).await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default().check(&req, conn)?;
let user = auth.user();
diesel::update(ApiToken::belonging_to(user).find(id))
.set(api_tokens::revoked.eq(true))
Expand All @@ -214,13 +214,13 @@ pub async fn revoke(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult

/// Handles the `DELETE /tokens/current` route.
pub async fn revoke_current(app: AppState, req: Parts) -> AppResult<Response> {
let conn = app.db_write().await?;
let mut conn = app.db_write().await?;
let auth = AuthCheck::default().check(&req, &mut conn).await?;
spawn_blocking(move || {
use diesel::RunQueryDsl;

let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();

let auth = AuthCheck::default().check(&req, conn)?;
let api_token_id = auth
.api_token_id()
.ok_or_else(|| bad_request("token not provided"))?;
Expand Down
Loading