diff --git a/lib/chirp-workflow/core/src/compat.rs b/lib/chirp-workflow/core/src/compat.rs index 002c0e8bb..e08677068 100644 --- a/lib/chirp-workflow/core/src/compat.rs +++ b/lib/chirp-workflow/core/src/compat.rs @@ -13,7 +13,7 @@ use crate::{ workflow::SUB_WORKFLOW_RETRY, OperationCtx, }, - db::{DatabaseHandle, DatabasePostgres}, + db::{DatabaseHandle, DatabasePgNats}, error::WorkflowError, message::Message, operation::{Operation, OperationInput}, @@ -261,13 +261,15 @@ async fn db_from_ctx( ctx: &rivet_operation::OperationContext, ) -> GlobalResult { let crdb = ctx.crdb().await?; + let nats = ctx.conn().nats().await?; - Ok(DatabasePostgres::from_pool(crdb)) + Ok(DatabasePgNats::from_pools(crdb, nats)) } // Get crdb pool as a trait object pub async fn db_from_pools(pools: &rivet_pools::Pools) -> GlobalResult { let crdb = pools.crdb()?; + let nats = pools.nats()?; - Ok(DatabasePostgres::from_pool(crdb)) + Ok(DatabasePgNats::from_pools(crdb, nats)) } diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index 7c4c027dc..aa161fda9 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -10,7 +10,7 @@ use crate::{ workflow::SUB_WORKFLOW_RETRY, MessageCtx, OperationCtx, }, - db::{DatabaseHandle, DatabasePostgres}, + db::{DatabaseHandle, DatabasePgNats}, error::WorkflowError, message::{Message, ReceivedMessage}, operation::{Operation, OperationInput}, @@ -70,7 +70,8 @@ impl TestCtx { (), ); - let db = DatabasePostgres::from_pool(pools.crdb().unwrap()); + let db = + DatabasePgNats::from_pools(pools.crdb().unwrap(), pools.nats_option().clone().unwrap()); let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap(); TestCtx { diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index fbc133c92..ee6953001 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -753,7 +753,10 @@ 
impl WorkflowCtx { /// Tests if the given error is unrecoverable. If it is, allows the user to run recovery code safely. /// Should always be used when trying to handle activity errors manually. - pub fn catch_unrecoverable(&mut self, res: GlobalResult) -> GlobalResult> { + pub fn catch_unrecoverable( + &mut self, + res: GlobalResult, + ) -> GlobalResult> { match res { Err(err) if !err.is_workflow_recoverable() => { self.location_idx += 1; @@ -1082,7 +1085,7 @@ impl WorkflowCtx { let location = self.full_location(); let (msg, write) = tokio::join!( - self.db.publish_message_from_workflow( + self.db.commit_workflow_message_send_event( self.workflow_id, location.as_ref(), &tags, @@ -1143,7 +1146,7 @@ impl WorkflowCtx { let location = self.full_location(); let (msg, write) = tokio::join!( - self.db.publish_message_from_workflow( + self.db.commit_workflow_message_send_event( self.workflow_id, location.as_ref(), &tags, diff --git a/lib/chirp-workflow/core/src/db/mod.rs b/lib/chirp-workflow/core/src/db/mod.rs index 7e9307b56..fdbb77c8c 100644 --- a/lib/chirp-workflow/core/src/db/mod.rs +++ b/lib/chirp-workflow/core/src/db/mod.rs @@ -10,13 +10,14 @@ use crate::{ workflow::Workflow, }; -mod postgres; -pub use postgres::DatabasePostgres; +mod pg_nats; +pub use pg_nats::DatabasePgNats; pub type DatabaseHandle = Arc; #[async_trait::async_trait] pub trait Database: Send { + /// Writes a new workflow to the database. async fn dispatch_workflow( &self, ray_id: Uuid, @@ -26,19 +27,22 @@ pub trait Database: Send { input: serde_json::Value, ) -> WorkflowResult<()>; async fn get_workflow(&self, id: Uuid) -> WorkflowResult>; + + /// Pulls workflows for processing by the worker. Will only pull workflows with names matching the filter. async fn pull_workflows( &self, worker_instance_id: Uuid, filter: &[&str], ) -> WorkflowResult>; - // When a workflow is completed + /// Mark a workflow as completed. 
async fn commit_workflow( &self, workflow_id: Uuid, output: &serde_json::Value, ) -> WorkflowResult<()>; - // When a workflow fails + + /// Write a workflow failure to the database. async fn fail_workflow( &self, workflow_id: Uuid, @@ -48,12 +52,15 @@ pub trait Database: Send { wake_sub_workflow: Option, error: &str, ) -> WorkflowResult<()>; + + /// Updates workflow tags. async fn update_workflow_tags( &self, workflow_id: Uuid, tags: &serde_json::Value, ) -> WorkflowResult<()>; + /// Write a workflow activity event to history. async fn commit_workflow_activity_event( &self, workflow_id: Uuid, @@ -65,6 +72,7 @@ pub trait Database: Send { loop_location: Option<&[usize]>, ) -> WorkflowResult<()>; + /// Pulls the oldest signal with the given filter. Pulls from regular and tagged signals. async fn pull_next_signal( &self, workflow_id: Uuid, @@ -72,6 +80,8 @@ pub trait Database: Send { location: &[usize], loop_location: Option<&[usize]>, ) -> WorkflowResult>; + + /// Write a new signal to the database. async fn publish_signal( &self, ray_id: Uuid, @@ -80,6 +90,8 @@ pub trait Database: Send { signal_name: &str, body: serde_json::Value, ) -> WorkflowResult<()>; + + /// Write a new tagged signal to the database. async fn publish_tagged_signal( &self, ray_id: Uuid, @@ -88,6 +100,8 @@ pub trait Database: Send { signal_name: &str, body: serde_json::Value, ) -> WorkflowResult<()>; + + /// Write a new signal to the database. Contains extra info used to populate the history. async fn publish_signal_from_workflow( &self, from_workflow_id: Uuid, @@ -99,6 +113,8 @@ pub trait Database: Send { body: serde_json::Value, loop_location: Option<&[usize]>, ) -> WorkflowResult<()>; + + /// Write a new tagged signal to the database. Contains extra info used to populate the history. 
async fn publish_tagged_signal_from_workflow( &self, from_workflow_id: Uuid, @@ -111,6 +127,7 @@ pub trait Database: Send { loop_location: Option<&[usize]>, ) -> WorkflowResult<()>; + /// Publish a new workflow from an existing workflow. async fn dispatch_sub_workflow( &self, ray_id: Uuid, @@ -123,7 +140,8 @@ pub trait Database: Send { loop_location: Option<&[usize]>, ) -> WorkflowResult<()>; - /// Fetches a workflow that has the given json as a subset of its input after the given ts. + /// Fetches a workflow that has the given json as a subset of its input after the given ts. Used primarily + /// in tests. async fn poll_workflow( &self, name: &str, @@ -131,7 +149,8 @@ pub trait Database: Send { after_ts: i64, ) -> WorkflowResult>; - async fn publish_message_from_workflow( + /// Writes a message send event to history. + async fn commit_workflow_message_send_event( &self, from_workflow_id: Uuid, location: &[usize], @@ -141,6 +160,7 @@ pub trait Database: Send { loop_location: Option<&[usize]>, ) -> WorkflowResult<()>; + /// Updates a loop event in history and forgets all history items in the previous iteration. 
async fn update_loop( &self, workflow_id: Uuid, diff --git a/lib/chirp-workflow/core/src/db/postgres.rs b/lib/chirp-workflow/core/src/db/pg_nats.rs similarity index 94% rename from lib/chirp-workflow/core/src/db/postgres.rs rename to lib/chirp-workflow/core/src/db/pg_nats.rs index 89b4a19b5..4620bf315 100644 --- a/lib/chirp-workflow/core/src/db/postgres.rs +++ b/lib/chirp-workflow/core/src/db/pg_nats.rs @@ -1,7 +1,9 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use indoc::indoc; +use rivet_pools::prelude::NatsPool; use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres}; +use tracing::Instrument; use uuid::Uuid; use super::{ @@ -13,6 +15,7 @@ use crate::{ activity::ActivityId, error::{WorkflowError, WorkflowResult}, event::combine_events, + message, }; /// Max amount of workflows pulled from the database with each call to `pull_workflows`. @@ -20,37 +23,14 @@ const MAX_PULLED_WORKFLOWS: i64 = 10; /// Maximum times a query ran bu this database adapter is retried. const MAX_QUERY_RETRIES: usize = 16; -pub struct DatabasePostgres { +pub struct DatabasePgNats { pool: PgPool, + nats: NatsPool, } -impl DatabasePostgres { - pub async fn new(url: &str) -> WorkflowResult> { - let pool = sqlx::postgres::PgPoolOptions::new() - // The default connection timeout is too high - .acquire_timeout(Duration::from_secs(15)) - // Increase lifetime to mitigate: https://github.com/launchbadge/sqlx/issues/2854 - // - // See max lifetime https://www.cockroachlabs.com/docs/stable/connection-pooling#set-the-maximum-lifetime-of-connections - .max_lifetime(Duration::from_secs(30 * 60)) - // Remove connections after a while in order to reduce load - // on CRDB after bursts - .idle_timeout(Some(Duration::from_secs(3 * 60))) - // Open connections immediately on startup - .min_connections(1) - // Raise the cap, since this is effectively the amount of - // simultaneous requests we can handle. 
See - // https://www.cockroachlabs.com/docs/stable/connection-pooling.html - .max_connections(4096) - .connect(url) - .await - .map_err(WorkflowError::BuildSqlx)?; - - Ok(Arc::new(DatabasePostgres { pool })) - } - - pub fn from_pool(pool: PgPool) -> Arc { - Arc::new(DatabasePostgres { pool }) +impl DatabasePgNats { + pub fn from_pools(pool: PgPool, nats: NatsPool) -> Arc { + Arc::new(DatabasePgNats { pool, nats }) } async fn conn(&self) -> WorkflowResult> { @@ -63,6 +43,29 @@ impl DatabasePostgres { } } + /// Spawns a new task and publishes a worker wake message to nats. + fn wake_worker(&self) { + let nats = self.nats.clone(); + + let spawn_res = tokio::task::Builder::new() + .name("chirp_workflow::DatabasePgNats::wake") + .spawn( + async move { + // Fail gracefully + if let Err(err) = nats + .publish(message::WORKER_WAKE_SUBJECT, Vec::new().into()) + .await + { + tracing::warn!(?err, "failed to publish wake message"); + } + } + .in_current_span(), + ); + if let Err(err) = spawn_res { + tracing::error!(?err, "failed to spawn wake task"); + } + } + /// Executes queries and explicitly handles retry errors. 
async fn query<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult where @@ -94,7 +97,7 @@ impl DatabasePostgres { } #[async_trait::async_trait] -impl Database for DatabasePostgres { +impl Database for DatabasePgNats { async fn dispatch_workflow( &self, ray_id: Uuid, @@ -124,6 +127,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -378,6 +383,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -417,6 +424,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -648,6 +657,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -678,6 +689,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -726,6 +739,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -774,6 +789,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -824,6 +841,8 @@ impl Database for DatabasePostgres { }) .await?; + self.wake_worker(); + Ok(()) } @@ -852,7 +871,7 @@ impl Database for DatabasePostgres { .map_err(WorkflowError::Sqlx) } - async fn publish_message_from_workflow( + async fn commit_workflow_message_send_event( &self, from_workflow_id: Uuid, location: &[usize], diff --git a/lib/chirp-workflow/core/src/message.rs b/lib/chirp-workflow/core/src/message.rs index d7aa9f277..7f08b1a74 100644 --- a/lib/chirp-workflow/core/src/message.rs +++ b/lib/chirp-workflow/core/src/message.rs @@ -6,6 +6,8 @@ use uuid::Uuid; use crate::error::{WorkflowError, WorkflowResult}; +pub const WORKER_WAKE_SUBJECT: &str = "chirp.workflow.worker.wake"; + pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static { const NAME: &'static str; const TAIL_TTL: std::time::Duration; diff --git a/lib/chirp-workflow/core/src/worker.rs b/lib/chirp-workflow/core/src/worker.rs index 311c6d0d2..5123aaff8 100644 --- a/lib/chirp-workflow/core/src/worker.rs 
+++ b/lib/chirp-workflow/core/src/worker.rs @@ -1,11 +1,15 @@ +use futures_util::StreamExt; use global_error::GlobalResult; use tokio::time::Duration; use tracing::Instrument; use uuid::Uuid; -use crate::{ctx::WorkflowCtx, db::DatabaseHandle, registry::RegistryHandle, util}; +use crate::{ + ctx::WorkflowCtx, db::DatabaseHandle, error::WorkflowError, message, registry::RegistryHandle, + util, +}; -const TICK_INTERVAL: Duration = Duration::from_millis(200); +const TICK_INTERVAL: Duration = Duration::from_secs(30); /// Used to spawn a new thread that indefinitely polls the database for new workflows. Only pulls workflows /// that are registered in its registry. After pulling, the workflows are ran and their state is written to @@ -32,14 +36,54 @@ impl Worker { self.registry.size(), ); + let shared_client = chirp_client::SharedClient::from_env(pools.clone())?; + let cache = rivet_cache::CacheInner::from_env(pools.clone())?; + + // Regular tick interval to poll the database let mut interval = tokio::time::interval(TICK_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + interval.tick().await; + self.tick(&shared_client, &pools, &cache).await?; + } + } + + pub async fn start_with_nats(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> { + tracing::info!( + worker_instance_id=?self.worker_instance_id, + "starting worker instance with {} registered workflows", + self.registry.size(), + ); + let shared_client = chirp_client::SharedClient::from_env(pools.clone())?; let cache = rivet_cache::CacheInner::from_env(pools.clone())?; + // Regular tick interval to poll the database + let mut interval = tokio::time::interval(TICK_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + // Create a subscription to the wake subject which receives messages whenever the worker should be + // awoken + let mut sub = pools + .nats()? 
+ .subscribe(message::WORKER_WAKE_SUBJECT) + .await + .map_err(|x| WorkflowError::CreateSubscription(x.into()))?; + loop { - interval.tick().await; + tokio::select! { + _ = interval.tick() => {}, + msg = sub.next() => { + match msg { + Some(_) => interval.reset(), + None => { + return Err(WorkflowError::SubscriptionUnsubscribed.into()); + } + } + } + } + self.tick(&shared_client, &pools, &cache).await?; } } @@ -53,7 +97,8 @@ impl Worker { ) -> GlobalResult<()> { tracing::trace!("tick"); - let registered_workflows = self + // Create filter from registered workflow names + let filter = self .registry .workflows .keys() @@ -63,7 +108,7 @@ impl Worker { // Query awake workflows let workflows = self .db - .pull_workflows(self.worker_instance_id, &registered_workflows) + .pull_workflows(self.worker_instance_id, &filter) .await?; for workflow in workflows { let conn = util::new_conn( diff --git a/svc/Cargo.toml b/svc/Cargo.toml index bffcdade4..5c18538a1 100644 --- a/svc/Cargo.toml +++ b/svc/Cargo.toml @@ -61,6 +61,7 @@ members = [ "pkg/cloud/ops/version-get", "pkg/cloud/ops/version-publish", "pkg/cloud/worker", + "pkg/cluster", "pkg/cluster/standalone/datacenter-tls-renew", "pkg/cluster/standalone/default-update", "pkg/cluster/standalone/gc", @@ -145,6 +146,7 @@ members = [ "pkg/kv/ops/get", "pkg/kv/ops/list", "pkg/kv/worker", + "pkg/linode", "pkg/linode/standalone/gc", "pkg/load-test/standalone/api-cloud", "pkg/load-test/standalone/mm", diff --git a/svc/pkg/monolith/standalone/workflow-worker/src/lib.rs b/svc/pkg/monolith/standalone/workflow-worker/src/lib.rs index c3b90a2fe..fb9319f1c 100644 --- a/svc/pkg/monolith/standalone/workflow-worker/src/lib.rs +++ b/svc/pkg/monolith/standalone/workflow-worker/src/lib.rs @@ -4,10 +4,10 @@ use chirp_workflow::prelude::*; pub async fn run_from_env(pools: rivet_pools::Pools) -> GlobalResult<()> { let reg = cluster::registry().merge(linode::registry()); - let db = db::DatabasePostgres::from_pool(pools.crdb().unwrap()); + let db = 
db::DatabasePgNats::from_pools(pools.crdb()?, pools.nats()?); let worker = Worker::new(reg.handle(), db.clone()); // Start worker - worker.start(pools).await?; + worker.start_with_nats(pools).await?; bail!("worker exited unexpectedly"); }