diff --git a/crates/crates_io_worker/Cargo.toml b/crates/crates_io_worker/Cargo.toml
index 0620f0d7693..fb61851d1fc 100644
--- a/crates/crates_io_worker/Cargo.toml
+++ b/crates/crates_io_worker/Cargo.toml
@@ -24,4 +24,4 @@ tracing = "=0.1.40"
 claims = "=0.7.1"
 crates_io_test_db = { path = "../crates_io_test_db" }
 insta = { version = "=1.41.1", features = ["json"] }
-tokio = { version = "=1.41.0", features = ["macros", "rt", "rt-multi-thread", "sync"]}
+tokio = { version = "=1.41.0", features = ["macros", "sync"]}
diff --git a/crates/crates_io_worker/tests/runner.rs b/crates/crates_io_worker/tests/runner.rs
index 89d4913cc4f..97ae792dc5f 100644
--- a/crates/crates_io_worker/tests/runner.rs
+++ b/crates/crates_io_worker/tests/runner.rs
@@ -5,7 +5,7 @@ use crates_io_worker::{BackgroundJob, Runner};
 use diesel::prelude::*;
 use diesel_async::pooled_connection::deadpool::Pool;
 use diesel_async::pooled_connection::AsyncDieselConnectionManager;
-use diesel_async::AsyncPgConnection;
+use diesel_async::{AsyncPgConnection, RunQueryDsl};
 use insta::assert_compact_json_snapshot;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -13,37 +13,37 @@ use std::sync::atomic::{AtomicU8, Ordering};
 use std::sync::Arc;
 use tokio::sync::Barrier;
 
-fn all_jobs(conn: &mut PgConnection) -> Vec<(String, Value)> {
+async fn all_jobs(conn: &mut AsyncPgConnection) -> QueryResult<Vec<(String, Value)>> {
     background_jobs::table
         .select((background_jobs::job_type, background_jobs::data))
         .get_results(conn)
-        .unwrap()
+        .await
 }
 
-fn job_exists(id: i64, conn: &mut PgConnection) -> bool {
-    background_jobs::table
+async fn job_exists(id: i64, conn: &mut AsyncPgConnection) -> QueryResult<bool> {
+    Ok(background_jobs::table
         .find(id)
         .select(background_jobs::id)
         .get_result::<i64>(conn)
-        .optional()
-        .unwrap()
-        .is_some()
+        .await
+        .optional()?
+        .is_some())
 }
 
-fn job_is_locked(id: i64, conn: &mut PgConnection) -> bool {
-    background_jobs::table
+async fn job_is_locked(id: i64, conn: &mut AsyncPgConnection) -> QueryResult<bool> {
+    Ok(background_jobs::table
         .find(id)
         .select(background_jobs::id)
         .for_update()
        .skip_locked()
         .get_result::<i64>(conn)
-        .optional()
-        .unwrap()
-        .is_none()
+        .await
+        .optional()?
+        .is_none())
 }
 
-#[tokio::test(flavor = "multi_thread")]
-async fn jobs_are_locked_when_fetched() {
+#[tokio::test]
+async fn jobs_are_locked_when_fetched() -> anyhow::Result<()> {
     #[derive(Clone)]
     struct TestContext {
         job_started_barrier: Arc<Barrier>,
@@ -71,28 +71,32 @@ async fn jobs_are_locked_when_fetched() {
         assertions_finished_barrier: Arc::new(Barrier::new(2)),
     };
 
-    let runner = runner(test_database.url(), test_context.clone()).register_job_type::<TestJob>();
+    let pool = pool(test_database.url())?;
+    let mut conn = pool.get().await?;
 
-    let mut conn = test_database.connect();
-    let job_id = TestJob.enqueue(&mut conn).unwrap().unwrap();
+    let runner = runner(pool, test_context.clone()).register_job_type::<TestJob>();
 
-    assert!(job_exists(job_id, &mut conn));
-    assert!(!job_is_locked(job_id, &mut conn));
+    let job_id = assert_some!(TestJob.async_enqueue(&mut conn).await?);
+
+    assert!(job_exists(job_id, &mut conn).await?);
+    assert!(!job_is_locked(job_id, &mut conn).await?);
 
     let runner = runner.start();
     test_context.job_started_barrier.wait().await;
 
-    assert!(job_exists(job_id, &mut conn));
-    assert!(job_is_locked(job_id, &mut conn));
+    assert!(job_exists(job_id, &mut conn).await?);
+    assert!(job_is_locked(job_id, &mut conn).await?);
 
     test_context.assertions_finished_barrier.wait().await;
 
     runner.wait_for_shutdown().await;
 
-    assert!(!job_exists(job_id, &mut conn));
+    assert!(!job_exists(job_id, &mut conn).await?);
+
+    Ok(())
 }
 
-#[tokio::test(flavor = "multi_thread")]
-async fn jobs_are_deleted_when_successfully_run() {
+#[tokio::test]
+async fn jobs_are_deleted_when_successfully_run() -> anyhow::Result<()> {
     #[derive(Serialize, Deserialize)]
     struct TestJob;
@@ -105,30 +109,31 @@ async fn jobs_are_deleted_when_successfully_run() {
         }
     }
 
-    fn remaining_jobs(conn: &mut PgConnection) -> i64 {
-        background_jobs::table
-            .count()
-            .get_result(&mut *conn)
-            .unwrap()
+    async fn remaining_jobs(conn: &mut AsyncPgConnection) -> QueryResult<i64> {
+        background_jobs::table.count().get_result(conn).await
     }
 
     let test_database = TestDatabase::new();
 
-    let runner = runner(test_database.url(), ()).register_job_type::<TestJob>();
+    let pool = pool(test_database.url())?;
+    let mut conn = pool.get().await?;
 
-    let mut conn = test_database.connect();
-    assert_eq!(remaining_jobs(&mut conn), 0);
+    let runner = runner(pool, ()).register_job_type::<TestJob>();
 
-    TestJob.enqueue(&mut conn).unwrap();
-    assert_eq!(remaining_jobs(&mut conn), 1);
+    assert_eq!(remaining_jobs(&mut conn).await?, 0);
+
+    TestJob.async_enqueue(&mut conn).await?;
+    assert_eq!(remaining_jobs(&mut conn).await?, 1);
 
     let runner = runner.start();
     runner.wait_for_shutdown().await;
 
-    assert_eq!(remaining_jobs(&mut conn), 0);
+    assert_eq!(remaining_jobs(&mut conn).await?, 0);
+
+    Ok(())
 }
 
-#[tokio::test(flavor = "multi_thread")]
-async fn failed_jobs_do_not_release_lock_before_updating_retry_time() {
+#[tokio::test]
+async fn failed_jobs_do_not_release_lock_before_updating_retry_time() -> anyhow::Result<()> {
     #[derive(Clone)]
     struct TestContext {
         job_started_barrier: Arc<Barrier>,
@@ -153,10 +158,12 @@
         job_started_barrier: Arc::new(Barrier::new(2)),
     };
 
-    let runner = runner(test_database.url(), test_context.clone()).register_job_type::<TestJob>();
+    let pool = pool(test_database.url())?;
+    let mut conn = pool.get().await?;
 
-    let mut conn = test_database.connect();
-    TestJob.enqueue(&mut conn).unwrap();
+    let runner = runner(pool, test_context.clone()).register_job_type::<TestJob>();
+
+    TestJob.async_enqueue(&mut conn).await?;
 
     let runner = runner.start();
     test_context.job_started_barrier.wait().await;
@@ -169,23 +176,25 @@ async fn failed_jobs_do_not_release_lock_before_updating_retry_time() {
         .select(background_jobs::id)
         .filter(background_jobs::retries.eq(0))
         .for_update()
-        .load::<i64>(&mut *conn)
-        .unwrap();
+        .load::<i64>(&mut conn)
+        .await?;
     assert_eq!(available_jobs.len(), 0);
 
     // Sanity check to make sure the job actually is there
     let total_jobs_including_failed = background_jobs::table
         .select(background_jobs::id)
         .for_update()
-        .load::<i64>(&mut *conn)
-        .unwrap();
+        .load::<i64>(&mut conn)
+        .await?;
     assert_eq!(total_jobs_including_failed.len(), 1);
 
     runner.wait_for_shutdown().await;
+
+    Ok(())
 }
 
-#[tokio::test(flavor = "multi_thread")]
-async fn panicking_in_jobs_updates_retry_counter() {
+#[tokio::test]
+async fn panicking_in_jobs_updates_retry_counter() -> anyhow::Result<()> {
     #[derive(Serialize, Deserialize)]
     struct TestJob;
 
@@ -200,11 +209,12 @@
 
     let test_database = TestDatabase::new();
 
-    let runner = runner(test_database.url(), ()).register_job_type::<TestJob>();
+    let pool = pool(test_database.url())?;
+    let mut conn = pool.get().await?;
 
-    let mut conn = test_database.connect();
+    let runner = runner(pool, ()).register_job_type::<TestJob>();
 
-    let job_id = TestJob.enqueue(&mut conn).unwrap().unwrap();
+    let job_id = assert_some!(TestJob.async_enqueue(&mut conn).await?);
 
     let runner = runner.start();
     runner.wait_for_shutdown().await;
@@ -213,13 +223,15 @@
         .find(job_id)
         .select(background_jobs::retries)
         .for_update()
-        .first::<i32>(&mut *conn)
-        .unwrap();
+        .first::<i32>(&mut conn)
+        .await?;
     assert_eq!(tries, 1);
+
+    Ok(())
 }
 
-#[tokio::test(flavor = "multi_thread")]
-async fn jobs_can_be_deduplicated() {
+#[tokio::test]
+async fn jobs_can_be_deduplicated() -> anyhow::Result<()> {
     #[derive(Clone)]
     struct TestContext {
         runs: Arc<AtomicU8>,
@@ -262,17 +274,18 @@
         assertions_finished_barrier: Arc::new(Barrier::new(2)),
     };
 
-    let runner = runner(test_database.url(), test_context.clone()).register_job_type::<TestJob>();
+    let pool = pool(test_database.url())?;
+    let mut conn = pool.get().await?;
 
-    let mut conn = test_database.connect();
+    let runner = runner(pool, test_context.clone()).register_job_type::<TestJob>();
 
     // Enqueue first job
-    assert_some!(TestJob::new("foo").enqueue(&mut conn).unwrap());
-    assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}]]"#);
+    assert_some!(TestJob::new("foo").async_enqueue(&mut conn).await?);
+    assert_compact_json_snapshot!(all_jobs(&mut conn).await?, @r#"[["test", {"value": "foo"}]]"#);
 
     // Try to enqueue the same job again, which should be deduplicated
-    assert_none!(TestJob::new("foo").enqueue(&mut conn).unwrap());
-    assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}]]"#);
+    assert_none!(TestJob::new("foo").async_enqueue(&mut conn).await?);
+    assert_compact_json_snapshot!(all_jobs(&mut conn).await?, @r#"[["test", {"value": "foo"}]]"#);
 
     // Start processing the first job
     let runner = runner.start();
@@ -280,30 +293,34 @@
     // Enqueue the same job again, which should NOT be deduplicated,
     // since the first job already still running
-    assert_some!(TestJob::new("foo").enqueue(&mut conn).unwrap());
-    assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
+    assert_some!(TestJob::new("foo").async_enqueue(&mut conn).await?);
+    assert_compact_json_snapshot!(all_jobs(&mut conn).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
 
     // Try to enqueue the same job again, which should be deduplicated again
-    assert_none!(TestJob::new("foo").enqueue(&mut conn).unwrap());
-    assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
+    assert_none!(TestJob::new("foo").async_enqueue(&mut conn).await?);
+    assert_compact_json_snapshot!(all_jobs(&mut conn).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
 
     // Enqueue the same job but with different data, which should
     // NOT be deduplicated
-    assert_some!(TestJob::new("bar").enqueue(&mut conn).unwrap());
-    assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}], ["test", {"value": "bar"}]]"#);
+    assert_some!(TestJob::new("bar").async_enqueue(&mut conn).await?);
+    assert_compact_json_snapshot!(all_jobs(&mut conn).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}], ["test", {"value": "bar"}]]"#);
 
     // Resolve the final barrier to finish the test
     test_context.assertions_finished_barrier.wait().await;
 
     runner.wait_for_shutdown().await;
+
+    Ok(())
+}
+
+fn pool(database_url: &str) -> anyhow::Result<Pool<AsyncPgConnection>> {
+    let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(database_url);
+    Ok(Pool::builder(manager).max_size(4).build()?)
 }
 
 fn runner<Context: Clone + Send + Sync + 'static>(
-    database_url: &str,
+    deadpool: Pool<AsyncPgConnection>,
     context: Context,
 ) -> Runner<Context> {
-    let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(database_url);
-    let deadpool = Pool::builder(manager).max_size(4).build().unwrap();
-
     Runner::new(deadpool, context)
         .configure_default_queue(|queue| queue.num_workers(2))
         .shutdown_when_queue_empty()