diff --git a/Cargo.lock b/Cargo.lock index 04848a80b89..b0a2f476584 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1200,11 +1200,13 @@ name = "crates_io_worker" version = "0.0.0" dependencies = [ "anyhow", + "claims", "crates_io_test_db", "deadpool-diesel", "diesel", "diesel-async", "futures-util", + "insta", "sentry-core", "serde", "serde_json", diff --git a/crates/crates_io_worker/Cargo.toml b/crates/crates_io_worker/Cargo.toml index b3995ce6b31..ba5df876289 100644 --- a/crates/crates_io_worker/Cargo.toml +++ b/crates/crates_io_worker/Cargo.toml @@ -21,5 +21,7 @@ tokio = { version = "=1.40.0", features = ["rt", "time"]} tracing = "=0.1.40" [dev-dependencies] +claims = "=0.7.1" crates_io_test_db = { path = "../crates_io_test_db" } +insta = { version = "=1.40.0", features = ["json"] } tokio = { version = "=1.40.0", features = ["macros", "rt", "rt-multi-thread", "sync"]} diff --git a/crates/crates_io_worker/src/background_job.rs b/crates/crates_io_worker/src/background_job.rs index e977e15ae8f..f121bc1fb58 100644 --- a/crates/crates_io_worker/src/background_job.rs +++ b/crates/crates_io_worker/src/background_job.rs @@ -1,8 +1,10 @@ use crate::errors::EnqueueError; use crate::schema::background_jobs; use diesel::connection::LoadConnection; +use diesel::dsl::{exists, not}; use diesel::pg::Pg; use diesel::prelude::*; +use diesel::sql_types::{Int2, Jsonb, Text}; use serde::de::DeserializeOwned; use serde::Serialize; use std::future::Future; @@ -21,6 +23,12 @@ pub trait BackgroundJob: Serialize + DeserializeOwned + Send + Sync + 'static { /// [Self::enqueue_with_priority] can be used to override the priority value. const PRIORITY: i16 = 0; + /// Whether the job should be deduplicated. + /// + /// If true, the job will not be enqueued if there is already an unstarted + /// job of the same type with the same data and priority. + const DEDUPLICATED: bool = false; + /// Job queue where this job will be executed. 
const QUEUE: &'static str = DEFAULT_QUEUE; @@ -30,7 +38,10 @@ pub trait BackgroundJob: Serialize + DeserializeOwned + Send + Sync + 'static { /// Execute the task. This method should define its logic. fn run(&self, ctx: Self::Context) -> impl Future> + Send; - fn enqueue(&self, conn: &mut impl LoadConnection) -> Result { + fn enqueue( + &self, + conn: &mut impl LoadConnection, + ) -> Result, EnqueueError> { self.enqueue_with_priority(conn, Self::PRIORITY) } @@ -39,16 +50,48 @@ pub trait BackgroundJob: Serialize + DeserializeOwned + Send + Sync + 'static { &self, conn: &mut impl LoadConnection, job_priority: i16, - ) -> Result { + ) -> Result, EnqueueError> { let job_data = serde_json::to_value(self)?; - let id = diesel::insert_into(background_jobs::table) - .values(( - background_jobs::job_type.eq(Self::JOB_NAME), - background_jobs::data.eq(job_data), - background_jobs::priority.eq(job_priority), + + if Self::DEDUPLICATED { + let similar_jobs = background_jobs::table + .select(background_jobs::id) + .filter(background_jobs::job_type.eq(Self::JOB_NAME)) + .filter(background_jobs::data.eq(&job_data)) + .filter(background_jobs::priority.eq(job_priority)) + .for_update() + .skip_locked(); + + let deduplicated_select = diesel::select(( + Self::JOB_NAME.into_sql::(), + (&job_data).into_sql::(), + job_priority.into_sql::(), )) - .returning(background_jobs::id) - .get_result(conn)?; - Ok(id) + .filter(not(exists(similar_jobs))); + + let id = diesel::insert_into(background_jobs::table) + .values(deduplicated_select) + .into_columns(( + background_jobs::job_type, + background_jobs::data, + background_jobs::priority, + )) + .returning(background_jobs::id) + .get_result::(conn) + .optional()?; + + Ok(id) + } else { + let id = diesel::insert_into(background_jobs::table) + .values(( + background_jobs::job_type.eq(Self::JOB_NAME), + background_jobs::data.eq(job_data), + background_jobs::priority.eq(job_priority), + )) + .returning(background_jobs::id) + .get_result(conn)?; + + 
Ok(Some(id)) + } } } diff --git a/crates/crates_io_worker/tests/runner.rs b/crates/crates_io_worker/tests/runner.rs index f40db340bb7..89d4913cc4f 100644 --- a/crates/crates_io_worker/tests/runner.rs +++ b/crates/crates_io_worker/tests/runner.rs @@ -1,3 +1,4 @@ +use claims::{assert_none, assert_some}; use crates_io_test_db::TestDatabase; use crates_io_worker::schema::background_jobs; use crates_io_worker::{BackgroundJob, Runner}; @@ -5,10 +6,20 @@ use diesel::prelude::*; use diesel_async::pooled_connection::deadpool::Pool; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::AsyncPgConnection; +use insta::assert_compact_json_snapshot; use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use tokio::sync::Barrier; +fn all_jobs(conn: &mut PgConnection) -> Vec<(String, Value)> { + background_jobs::table + .select((background_jobs::job_type, background_jobs::data)) + .get_results(conn) + .unwrap() +} + fn job_exists(id: i64, conn: &mut PgConnection) -> bool { background_jobs::table .find(id) @@ -63,7 +74,7 @@ async fn jobs_are_locked_when_fetched() { let runner = runner(test_database.url(), test_context.clone()).register_job_type::(); let mut conn = test_database.connect(); - let job_id = TestJob.enqueue(&mut conn).unwrap(); + let job_id = TestJob.enqueue(&mut conn).unwrap().unwrap(); assert!(job_exists(job_id, &mut conn)); assert!(!job_is_locked(job_id, &mut conn)); @@ -193,7 +204,7 @@ async fn panicking_in_jobs_updates_retry_counter() { let mut conn = test_database.connect(); - let job_id = TestJob.enqueue(&mut conn).unwrap(); + let job_id = TestJob.enqueue(&mut conn).unwrap().unwrap(); let runner = runner.start(); runner.wait_for_shutdown().await; @@ -207,6 +218,85 @@ async fn panicking_in_jobs_updates_retry_counter() { assert_eq!(tries, 1); } +#[tokio::test(flavor = "multi_thread")] +async fn jobs_can_be_deduplicated() { + #[derive(Clone)] + struct TestContext { 
+ runs: Arc, + job_started_barrier: Arc, + assertions_finished_barrier: Arc, + } + + #[derive(Serialize, Deserialize)] + struct TestJob { + value: String, + } + + impl TestJob { + fn new(value: impl Into) -> Self { + let value = value.into(); + Self { value } + } + } + + impl BackgroundJob for TestJob { + const JOB_NAME: &'static str = "test"; + const DEDUPLICATED: bool = true; + type Context = TestContext; + + async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> { + let runs = ctx.runs.fetch_add(1, Ordering::SeqCst); + if runs == 0 { + ctx.job_started_barrier.wait().await; + ctx.assertions_finished_barrier.wait().await; + } + Ok(()) + } + } + + let test_database = TestDatabase::new(); + + let test_context = TestContext { + runs: Arc::new(AtomicU8::new(0)), + job_started_barrier: Arc::new(Barrier::new(2)), + assertions_finished_barrier: Arc::new(Barrier::new(2)), + }; + + let runner = runner(test_database.url(), test_context.clone()).register_job_type::(); + + let mut conn = test_database.connect(); + + // Enqueue first job + assert_some!(TestJob::new("foo").enqueue(&mut conn).unwrap()); + assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}]]"#); + + // Try to enqueue the same job again, which should be deduplicated + assert_none!(TestJob::new("foo").enqueue(&mut conn).unwrap()); + assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}]]"#); + + // Start processing the first job + let runner = runner.start(); + test_context.job_started_barrier.wait().await; + + // Enqueue the same job again, which should NOT be deduplicated, + // since the first job is still running + assert_some!(TestJob::new("foo").enqueue(&mut conn).unwrap()); + assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#); + + // Try to enqueue the same job again, which should be deduplicated again + assert_none!(TestJob::new("foo").enqueue(&mut conn).unwrap()); + 
assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#); + + // Enqueue the same job but with different data, which should + // NOT be deduplicated + assert_some!(TestJob::new("bar").enqueue(&mut conn).unwrap()); + assert_compact_json_snapshot!(all_jobs(&mut conn), @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}], ["test", {"value": "bar"}]]"#); + + // Resolve the final barrier to finish the test + test_context.assertions_finished_barrier.wait().await; + runner.wait_for_shutdown().await; +} + fn runner( database_url: &str, context: Context,