diff --git a/docs/libraries/workflow/GOTCHAS.md b/docs/libraries/workflow/GOTCHAS.md
index 4528dad44..1d81bca87 100644
--- a/docs/libraries/workflow/GOTCHAS.md
+++ b/docs/libraries/workflow/GOTCHAS.md
@@ -18,3 +18,93 @@ This will be the current timestamp on the first execution of the activity and wo
 ## Randomly generated content
 
 Randomly generated content like UUIDs should be placed in activities for consistent history.
+
+## Stale data
+
+When fetching data for use in a workflow, you will most often put the fetch in an activity so that it can be
+retried. However, depending on how much later the activity's result is used, the data may have become stale.
+Add another activity to re-fetch the data wherever you need up-to-date information (see the example at the
+end of this page).
+
+## `WorkflowCtx::spawn`
+
+`WorkflowCtx::spawn` allows you to run workflow steps in a different thread and returns its join handle. Be
+**very careful** when using it: it is the developer's responsibility to make sure its result is handled
+correctly. If a spawned thread errors but its result is not handled, the main thread may continue as though no
+error occurred. This will result in a corrupt workflow state and a divergent history. (See the example at the
+end of this page.)
+
+Also see [Consistency with concurrency](#consistency-with-concurrency).
+
+## Consistency with concurrency
+
+When you need to run multiple workflow events (like activities or signals) in parallel, make sure the state of
+the context stays consistent between replays.
+
+Take this example trying to run multiple activities concurrently:
+
+```rust
+let iter = actions.into_iter().map(|action| {
+    let ctx = ctx.clone();
+
+    async move {
+        ctx.activity(MyActivityInput {
+            action,
+        }).await
+    }
+    .boxed()
+});
+
+futures_util::stream::iter(iter)
+    .buffer_unordered(16)
+    .try_collect::<Vec<_>>()
+    .await?;
+```
+
+This will error because of the `ctx.clone()`: each activity ends up with the same internal location because
+none of the cloned contexts know about each other\*.
+
+Instead, you can increment the location preemptively with `ctx.step()`:
+
+```rust
+let iter = actions.into_iter().map(|action| {
+    let ctx = ctx.step();
+
+    async move {
+        ctx.activity(MyActivityInput {
+            action,
+        }).await
+    }
+    .boxed()
+});
+
+futures_util::stream::iter(iter)
+    .buffer_unordered(16)
+    .try_collect::<Vec<_>>()
+    .await?;
+```
+
+If you plan on running more than one workflow step in each future, use a branch instead:
+
+```rust
+let iter = actions.into_iter().map(|action| {
+    let ctx = ctx.branch();
+
+    async move {
+        ctx.activity(MyActivityInput {
+            action,
+        }).await
+    }
+    .boxed()
+});
+
+futures_util::stream::iter(iter)
+    .buffer_unordered(16)
+    .try_collect::<Vec<_>>()
+    .await?;
+```
+
+Note that the first example would also work with a branch, but it's a bit overkill as it creates a new layer
+in the internal location.
+
+> **\*** Even if they did know about each other via atomics, there is no guarantee of consistency from
+> `buffer_unordered`. Preemptively incrementing the location ensures consistency regardless of the order or
+> completion time of the futures.
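+
+## Example: re-fetching stale data
+
+A minimal sketch of the pattern described in the Stale data section above. The `GetDataInput`, `DoWorkInput`,
+and `ApplyInput` names are illustrative only; substitute your own activities and workflows.
+
+```rust
+// Fetch once for the early steps.
+let data = ctx.activity(GetDataInput { id }).await?;
+
+// Long-running step; `data` may be stale by the time it finishes.
+ctx.workflow(DoWorkInput { value: data.value }).await?;
+
+// Re-fetch before acting on anything that may have changed in the meantime.
+let data = ctx.activity(GetDataInput { id }).await?;
+ctx.activity(ApplyInput { value: data.value }).await?;
+```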
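+
+## Example: handling a `WorkflowCtx::spawn` result
+
+A minimal sketch of the point made in the `WorkflowCtx::spawn` section above. The exact closure shape `spawn`
+expects is assumed here; the important part is that the join handle's result is always awaited and propagated.
+
+```rust
+// Illustrative only -- the closure signature `spawn` accepts may differ.
+let handle = ctx.spawn(|ctx| async move {
+    ctx.activity(MyActivityInput { action }).await
+}.boxed());
+
+// ... other workflow steps on the main thread ...
+
+// Always surface the spawned step's result: the outer `?` covers the join
+// error, the inner `?` covers the workflow step's own error. Dropping this
+// result would let the workflow continue past a failed step and corrupt its
+// history.
+handle.await??;
+```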
diff --git a/lib/bolt/core/src/tasks/test.rs b/lib/bolt/core/src/tasks/test.rs index 68cb68844..156d587e2 100644 --- a/lib/bolt/core/src/tasks/test.rs +++ b/lib/bolt/core/src/tasks/test.rs @@ -574,7 +574,6 @@ struct SshKey { id: u64, } -// TODO: This only deletes linodes and firewalls, the ssh key still remains async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> { if ctx.ns().rivet.provisioning.is_none() { return Ok(()); @@ -584,6 +583,7 @@ async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> { rivet_term::status::progress("Cleaning up servers", ""); let ns = ctx.ns_id(); + let ns_full = format!("rivet-{ns}"); // Create client let api_token = ctx.read_secret(&["linode", "token"]).await?; @@ -642,7 +642,13 @@ async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> { .data .into_iter() // Only delete test objects from this namespace - .filter(|object| object.data.tags.iter().any(|tag| tag == ns)) + .filter(|object| { + object + .data + .tags + .iter() + .any(|tag| tag == ns || tag == &ns_full) + }) .map(|object| { let client = client.clone(); let obj_type = object._type; diff --git a/lib/chirp-workflow/core/src/compat.rs b/lib/chirp-workflow/core/src/compat.rs index 8f24b69fd..2f4b5753f 100644 --- a/lib/chirp-workflow/core/src/compat.rs +++ b/lib/chirp-workflow/core/src/compat.rs @@ -31,11 +31,10 @@ where } let name = I::Workflow::NAME; - - tracing::info!(%name, ?input, "dispatching workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?input, "dispatching workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -67,11 +66,10 @@ where } let name = I::Workflow::NAME; - - tracing::info!(%name, ?input, "dispatching workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?input, "dispatching workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -202,6 +200,7 @@ pub async fn tagged_signal( Ok(signal_id) } +#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( ctx: &rivet_operation::OperationContext, input: I, @@ -211,6 +210,8 @@ where ::Operation: Operation, B: Debug + Clone, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( db_from_ctx(ctx).await?, ctx.conn(), @@ -220,10 +221,14 @@ where I::Operation::NAME, ); - I::Operation::run(&ctx, &input) + let res = I::Operation::run(&ctx, &input) .await .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs index 6a4119616..21681fe92 100644 --- a/lib/chirp-workflow/core/src/ctx/activity.rs +++ b/lib/chirp-workflow/core/src/ctx/activity.rs @@ -56,6 +56,7 @@ impl ActivityCtx { } impl ActivityCtx { + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, input: I, @@ -64,6 +65,8 @@ impl ActivityCtx { I: OperationInput, ::Operation: Operation, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( self.db.clone(), &self.conn, @@ -73,11 +76,15 @@ impl ActivityCtx { I::Operation::NAME, ); - tokio::time::timeout(I::Operation::TIMEOUT, I::Operation::run(&ctx, &input)) + let res = tokio::time::timeout(I::Operation::TIMEOUT, I::Operation::run(&ctx, &input)) .await .map_err(|_| WorkflowError::OperationTimeout)? 
.map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn update_workflow_tags(&self, tags: &serde_json::Value) -> GlobalResult<()> { diff --git a/lib/chirp-workflow/core/src/ctx/api.rs b/lib/chirp-workflow/core/src/ctx/api.rs index 1c91b83f7..c2715478b 100644 --- a/lib/chirp-workflow/core/src/ctx/api.rs +++ b/lib/chirp-workflow/core/src/ctx/api.rs @@ -73,11 +73,10 @@ impl ApiCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?input, "dispatching workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?input, "dispatching workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -103,11 +102,10 @@ impl ApiCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?tags, ?input, "dispatching tagged workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?tags, ?input, "dispatching tagged workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -223,6 +221,7 @@ impl ApiCtx { Ok(signal_id) } + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, input: I, @@ -231,6 +230,8 @@ impl ApiCtx { I: OperationInput, ::Operation: Operation, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( self.db.clone(), &self.conn, @@ -240,10 +241,14 @@ impl ApiCtx { I::Operation::NAME, ); - I::Operation::run(&ctx, &input) + let res = I::Operation::run(&ctx, &input) .await .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/listen.rs b/lib/chirp-workflow/core/src/ctx/listen.rs new file mode 100644 index 000000000..92232160c --- /dev/null +++ b/lib/chirp-workflow/core/src/ctx/listen.rs @@ -0,0 +1,44 @@ +use crate::{ + ctx::WorkflowCtx, + db::SignalRow, + error::{WorkflowError, WorkflowResult}, +}; + +/// Indirection struct to prevent invalid implementations of listen traits. +pub struct ListenCtx<'a> { + ctx: &'a mut WorkflowCtx, +} + +impl<'a> ListenCtx<'a> { + pub(crate) fn new(ctx: &'a mut WorkflowCtx) -> Self { + ListenCtx { ctx } + } + + /// Checks for a signal to this workflow with any of the given signal names. 
+ pub async fn listen_any(&self, signal_names: &[&'static str]) -> WorkflowResult { + // Fetch new pending signal + let signal = self + .ctx + .db + .pull_next_signal( + self.ctx.workflow_id(), + signal_names, + self.ctx.full_location().as_ref(), + ) + .await?; + + let Some(signal) = signal else { + return Err(WorkflowError::NoSignalFound(Box::from(signal_names))); + }; + + tracing::info!( + workflow_name=%self.ctx.name(), + workflow_id=%self.ctx.workflow_id(), + signal_id=%signal.signal_id, + signal_name=%signal.signal_name, + "signal received", + ); + + Ok(signal) + } +} diff --git a/lib/chirp-workflow/core/src/ctx/mod.rs b/lib/chirp-workflow/core/src/ctx/mod.rs index 68c2f2848..7850baad9 100644 --- a/lib/chirp-workflow/core/src/ctx/mod.rs +++ b/lib/chirp-workflow/core/src/ctx/mod.rs @@ -1,5 +1,6 @@ mod activity; pub(crate) mod api; +mod listen; pub mod message; mod operation; mod standalone; @@ -7,6 +8,7 @@ mod test; pub(crate) mod workflow; pub use activity::ActivityCtx; pub use api::ApiCtx; +pub use listen::ListenCtx; pub use message::MessageCtx; pub use operation::OperationCtx; pub use standalone::StandaloneCtx; diff --git a/lib/chirp-workflow/core/src/ctx/operation.rs b/lib/chirp-workflow/core/src/ctx/operation.rs index 104c683cb..c0db0e258 100644 --- a/lib/chirp-workflow/core/src/ctx/operation.rs +++ b/lib/chirp-workflow/core/src/ctx/operation.rs @@ -55,6 +55,7 @@ impl OperationCtx { } impl OperationCtx { + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, input: I, @@ -63,6 +64,8 @@ impl OperationCtx { I: OperationInput, ::Operation: Operation, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( self.db.clone(), &self.conn, @@ -72,10 +75,14 @@ impl OperationCtx { I::Operation::NAME, ); - I::Operation::run(&ctx, &input) + let res = I::Operation::run(&ctx, &input) .await .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn signal( diff --git a/lib/chirp-workflow/core/src/ctx/standalone.rs b/lib/chirp-workflow/core/src/ctx/standalone.rs index f3bd819f4..c1494d175 100644 --- a/lib/chirp-workflow/core/src/ctx/standalone.rs +++ b/lib/chirp-workflow/core/src/ctx/standalone.rs @@ -72,11 +72,10 @@ impl StandaloneCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?input, "dispatching workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?input, "dispatching workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -102,11 +101,10 @@ impl StandaloneCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?tags, ?input, "dispatching tagged workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?tags, ?input, "dispatching tagged workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -222,6 +220,7 @@ impl StandaloneCtx { Ok(signal_id) } + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, input: I, @@ -230,6 +229,8 @@ impl StandaloneCtx { I: OperationInput, ::Operation: Operation, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( self.db.clone(), &self.conn, @@ -239,10 +240,14 @@ impl StandaloneCtx { I::Operation::NAME, ); - I::Operation::run(&ctx, &input) + let res = I::Operation::run(&ctx, 
&input) .await .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index dc42556a5..98c40d084 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -88,11 +88,10 @@ impl TestCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?input, "dispatching workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?input, "dispatching workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -118,11 +117,10 @@ impl TestCtx { ::Workflow: Workflow, { let name = I::Workflow::NAME; - - tracing::info!(%name, ?tags, ?input, "dispatching tagged workflow"); - let id = Uuid::new_v4(); + tracing::info!(%name, %id, ?tags, ?input, "dispatching tagged workflow"); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -248,6 +246,7 @@ impl TestCtx { }) } + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, input: I, @@ -256,6 +255,8 @@ impl TestCtx { I: OperationInput, ::Operation: Operation, { + tracing::info!(?input, "operation call"); + let ctx = OperationCtx::new( self.db.clone(), &self.conn, @@ -265,10 +266,14 @@ impl TestCtx { I::Operation::NAME, ); - I::Operation::run(&ctx, &input) + let res = I::Operation::run(&ctx, &input) .await .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res } pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index 2959d2794..fbd8bc333 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -7,13 +7,15 @@ use uuid::Uuid; use crate::{ activity::ActivityId, - ctx::{ActivityCtx, MessageCtx}, + ctx::{ActivityCtx, ListenCtx, MessageCtx}, event::Event, executable::{closure, AsyncResult, Executable}, + listen::{CustomListener, Listen}, message::Message, + signal::Signal, util::Location, - Activity, ActivityInput, DatabaseHandle, Listen, PulledWorkflow, RegistryHandle, Signal, - SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult, + Activity, ActivityInput, DatabaseHandle, PulledWorkflow, RegistryHandle, Workflow, + WorkflowError, WorkflowInput, WorkflowResult, }; // Time to delay a worker from retrying after an error @@ -43,7 +45,7 @@ pub struct WorkflowCtx { ray_id: Uuid, registry: RegistryHandle, - db: DatabaseHandle, + pub(crate) db: DatabaseHandle, conn: rivet_connection::Connection, @@ -148,7 +150,7 @@ impl WorkflowCtx { .flatten() } - fn full_location(&self) -> Location { + pub(crate) fn full_location(&self) -> Location { self.root_location .iter() .cloned() @@ -330,34 +332,6 @@ impl WorkflowCtx { } } } - - /// Checks for a signal to this workflow with any of the given signal names. Meant to be implemented and - /// not used directly in workflows. 
- pub async fn listen_any(&mut self, signal_names: &[&'static str]) -> WorkflowResult { - // Fetch new pending signal - let signal = self - .db - .pull_next_signal( - self.workflow_id, - signal_names, - self.full_location().as_ref(), - ) - .await?; - - let Some(signal) = signal else { - return Err(WorkflowError::NoSignalFound(Box::from(signal_names))); - }; - - tracing::info!( - workflow_name=%self.name, - workflow_id=%self.workflow_id, - signal_id=%signal.signal_id, - signal_name=%signal.signal_name, - "signal received", - ); - - Ok(signal) - } } impl WorkflowCtx { @@ -416,12 +390,19 @@ impl WorkflowCtx { } // Dispatch new workflow else { - let name = I::Workflow::NAME; - - tracing::info!(%name, ?tags, ?input, "dispatching workflow"); - + let sub_workflow_name = I::Workflow::NAME; let sub_workflow_id = Uuid::new_v4(); + tracing::info!( + name=%self.name, + id=%self.workflow_id, + %sub_workflow_name, + %sub_workflow_id, + ?tags, + ?input, + "dispatching sub workflow" + ); + // Serialize input let input_val = serde_json::to_value(input) .map_err(WorkflowError::SerializeWorkflowOutput) @@ -433,14 +414,20 @@ impl WorkflowCtx { self.workflow_id, self.full_location().as_ref(), sub_workflow_id, - &name, + &sub_workflow_name, tags, input_val, ) .await .map_err(GlobalError::raw)?; - tracing::info!(%name, ?sub_workflow_id, "workflow dispatched"); + tracing::info!( + name=%self.name, + id=%self.workflow_id, + %sub_workflow_name, + ?sub_workflow_id, + "workflow dispatched" + ); sub_workflow_id }; @@ -752,7 +739,7 @@ impl WorkflowCtx { else { let signal_id = Uuid::new_v4(); - tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); + tracing::info!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); // Serialize input let input_val = serde_json::to_value(&body) @@ -805,10 +792,62 @@ impl WorkflowCtx { let mut interval = tokio::time::interval(SIGNAL_RETRY); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + let ctx = ListenCtx::new(self); + + loop { + interval.tick().await; + + match T::listen(&ctx).await { + Ok(res) => break res, + Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => { + if retries > MAX_SIGNAL_RETRIES { + return Err(err).map_err(GlobalError::raw); + } + retries += 1; + } + err => return err.map_err(GlobalError::raw), + } + } + }; + + // Move to next event + self.location_idx += 1; + + Ok(signal) + } + + /// Execute a custom listener. + pub async fn custom_listener( + &mut self, + listener: &T, + ) -> GlobalResult<::Output> { + let event = { self.relevant_history().nth(self.location_idx) }; + + // Signal received before + let signal = if let Some(event) = event { + // Validate history is consistent + let Event::Signal(signal) = event else { + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); + }; + + tracing::debug!(name=%self.name, id=%self.workflow_id, signal_name=%signal.name, "replaying signal"); + + T::parse(&signal.name, signal.body.clone()).map_err(GlobalError::raw)? 
+ } + // Listen for new messages + else { + tracing::info!(name=%self.name, id=%self.workflow_id, "listening for signal"); + + let mut retries = 0; + let mut interval = tokio::time::interval(SIGNAL_RETRY); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let ctx = ListenCtx::new(self); + loop { interval.tick().await; - match T::listen(self).await { + match listener.listen(&ctx).await { Ok(res) => break res, Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => { if retries > MAX_SIGNAL_RETRIES { @@ -844,7 +883,9 @@ impl WorkflowCtx { } // Listen for new message else { - match T::listen(self).await { + let ctx = ListenCtx::new(self); + + match T::listen(&ctx).await { Ok(res) => Some(res), Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => None, Err(err) => return Err(err).map_err(GlobalError::raw), diff --git a/lib/chirp-workflow/core/src/db/postgres.rs b/lib/chirp-workflow/core/src/db/postgres.rs index e6882caf0..767ca7815 100644 --- a/lib/chirp-workflow/core/src/db/postgres.rs +++ b/lib/chirp-workflow/core/src/db/postgres.rs @@ -495,7 +495,7 @@ impl Database for DatabasePostgres { ) RETURNING 1 ), - -- After deleting the signal, add it to the events table (i.e. acknowledge it) + -- After acking the signal, add it to the events table insert_event AS ( INSERT INTO db_workflow.workflow_signal_events ( workflow_id, location, signal_id, signal_name, body, ack_ts diff --git a/lib/chirp-workflow/core/src/lib.rs b/lib/chirp-workflow/core/src/lib.rs index e83964b4c..5f260ec0a 100644 --- a/lib/chirp-workflow/core/src/lib.rs +++ b/lib/chirp-workflow/core/src/lib.rs @@ -5,6 +5,7 @@ pub mod db; mod error; mod event; mod executable; +mod listen; pub mod message; pub mod operation; pub mod prelude; diff --git a/lib/chirp-workflow/core/src/listen.rs b/lib/chirp-workflow/core/src/listen.rs new file mode 100644 index 000000000..f21328470 --- /dev/null +++ b/lib/chirp-workflow/core/src/listen.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; + +use crate::{ctx::ListenCtx, error::WorkflowResult}; + +/// A trait which allows listening for signals from the workflows database. This is used by +/// `WorkflowCtx::listen` and `WorkflowCtx::query_signal`. If you need a listener with state, use +/// `CustomListener`. +#[async_trait] +pub trait Listen: Sized { + async fn listen(ctx: &ListenCtx) -> WorkflowResult; + fn parse(name: &str, body: serde_json::Value) -> WorkflowResult; +} + +/// A trait which allows listening for signals with a custom state. This is used by +/// `WorkflowCtx::custom_listener`. 
+#[async_trait] +pub trait CustomListener: Sized { + type Output; + + async fn listen(&self, ctx: &ListenCtx) -> WorkflowResult; + fn parse(name: &str, body: serde_json::Value) -> WorkflowResult; +} diff --git a/lib/chirp-workflow/core/src/operation.rs b/lib/chirp-workflow/core/src/operation.rs index 17a447e8b..c33f5e261 100644 --- a/lib/chirp-workflow/core/src/operation.rs +++ b/lib/chirp-workflow/core/src/operation.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use async_trait::async_trait; use global_error::GlobalResult; @@ -6,7 +8,7 @@ use crate::OperationCtx; #[async_trait] pub trait Operation { type Input: OperationInput; - type Output: Send; + type Output: Debug + Send; const NAME: &'static str; const TIMEOUT: std::time::Duration; @@ -14,6 +16,6 @@ pub trait Operation { async fn run(ctx: &OperationCtx, input: &Self::Input) -> GlobalResult; } -pub trait OperationInput: Send { +pub trait OperationInput: Debug + Send { type Operation: Operation; } diff --git a/lib/chirp-workflow/core/src/prelude.rs b/lib/chirp-workflow/core/src/prelude.rs index 3832e4362..433d05503 100644 --- a/lib/chirp-workflow/core/src/prelude.rs +++ b/lib/chirp-workflow/core/src/prelude.rs @@ -20,10 +20,11 @@ pub use crate::{ error::{WorkflowError, WorkflowResult}, executable::closure, executable::Executable, + listen::{CustomListener, Listen}, message::Message, operation::Operation, registry::Registry, - signal::{join_signal, Listen, Signal}, + signal::{join_signal, Signal}, util::GlobalErrorExt, worker::Worker, workflow::Workflow, diff --git a/lib/chirp-workflow/core/src/signal.rs b/lib/chirp-workflow/core/src/signal.rs index f0e4c4cf2..0791f6fee 100644 --- a/lib/chirp-workflow/core/src/signal.rs +++ b/lib/chirp-workflow/core/src/signal.rs @@ -1,19 +1,7 @@ -use async_trait::async_trait; - -use crate::{WorkflowCtx, WorkflowResult}; - pub trait Signal { const NAME: &'static str; } -/// A trait which allows listening for signals from the workflows database. This is used by -/// `WorkflowCtx::listen` and `WorkflowCtx::query_signal`. -#[async_trait] -pub trait Listen: Sized { - async fn listen(ctx: &mut WorkflowCtx) -> WorkflowResult; - fn parse(name: &str, body: serde_json::Value) -> WorkflowResult; -} - /// Creates an enum that implements `Listen` and selects one of X signals. /// /// Example: @@ -68,7 +56,7 @@ macro_rules! 
join_signal { (@ $join:ident, [$($signals:ident),*]) => { #[async_trait::async_trait] impl Listen for $join { - async fn listen(ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult { + async fn listen(ctx: &chirp_workflow::prelude::ListenCtx) -> chirp_workflow::prelude::WorkflowResult { let row = ctx.listen_any(&[$($signals::NAME),*]).await?; Self::parse(&row.signal_name, row.body) } diff --git a/lib/chirp-workflow/macros/src/lib.rs b/lib/chirp-workflow/macros/src/lib.rs index 31a7bc139..f668eb108 100644 --- a/lib/chirp-workflow/macros/src/lib.rs +++ b/lib/chirp-workflow/macros/src/lib.rs @@ -306,7 +306,7 @@ pub fn signal(attr: TokenStream, item: TokenStream) -> TokenStream { #[async_trait::async_trait] impl Listen for #struct_ident { - async fn listen(ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult { + async fn listen(ctx: &chirp_workflow::prelude::ListenCtx) -> chirp_workflow::prelude::WorkflowResult { let row = ctx.listen_any(&[::NAME]).await?; Self::parse(&row.signal_name, row.body) } diff --git a/svc/pkg/cluster/src/workflows/server/mod.rs b/svc/pkg/cluster/src/workflows/server/mod.rs index 4475a9b34..92a0cb85d 100644 --- a/svc/pkg/cluster/src/workflows/server/mod.rs +++ b/svc/pkg/cluster/src/workflows/server/mod.rs @@ -199,7 +199,7 @@ pub(crate) async fn cluster_server(ctx: &mut WorkflowCtx, input: &Input) -> Glob let mut state = State::default(); loop { - match state.listen(ctx).await? { + match state.run(ctx).await? { Main::DnsCreate(_) => { ctx.workflow(dns_create::Input { server_id: input.server_id, @@ -266,10 +266,12 @@ pub(crate) async fn cluster_server(ctx: &mut WorkflowCtx, input: &Input) -> Glob } Main::Taint(_) => {} // Only for state Main::Destroy(_) => { - ctx.workflow(dns_delete::Input { - server_id: input.server_id, - }) - .await?; + if let PoolType::Gg = input.pool_type { + ctx.workflow(dns_delete::Input { + server_id: input.server_id, + }) + .await?; + } match input.provider { Provider::Linode => { @@ -671,6 +673,31 @@ struct State { } impl State { + async fn run(&mut self, ctx: &mut WorkflowCtx) -> GlobalResult
{ + let signal = ctx.custom_listener(self).await?; + + // Update state + self.transition(&signal); + + Ok(signal) + } + + fn transition(&mut self, signal: &Main) { + match signal { + Main::Drain(_) => self.draining = true, + Main::Undrain(_) => self.draining = false, + Main::Taint(_) => self.is_tainted = true, + Main::DnsCreate(_) => self.has_dns = true, + Main::DnsDelete(_) => self.has_dns = false, + _ => {} + } + } +} + +#[async_trait::async_trait] +impl CustomListener for State { + type Output = Main; + /* ==== BINARY CONDITION DECOMPOSITION ==== // state @@ -693,7 +720,7 @@ impl State { nomad registered // always nomad drain complete // if drain */ - async fn listen(&mut self, ctx: &mut WorkflowCtx) -> WorkflowResult
{ + async fn listen(&self, ctx: &ListenCtx) -> WorkflowResult { // Determine which signals to listen to let mut signals = vec![Destroy::NAME, NomadRegistered::NAME]; @@ -720,23 +747,11 @@ impl State { } let row = ctx.listen_any(&signals).await?; - let signal = Main::parse(&row.signal_name, row.body)?; - - // Update state - self.transition(&signal); - - Ok(signal) + Self::parse(&row.signal_name, row.body) } - fn transition(&mut self, signal: &Main) { - match signal { - Main::Drain(_) => self.draining = true, - Main::Undrain(_) => self.draining = false, - Main::Taint(_) => self.is_tainted = true, - Main::DnsCreate(_) => self.has_dns = true, - Main::DnsDelete(_) => self.has_dns = false, - _ => {} - } + fn parse(name: &str, body: serde_json::Value) -> WorkflowResult { + Main::parse(name, body) } }