diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs index fca683de7..a504afc6d 100644 --- a/lib/chirp-workflow/core/src/ctx/activity.rs +++ b/lib/chirp-workflow/core/src/ctx/activity.rs @@ -1,9 +1,9 @@ -use global_error::GlobalResult; +use global_error::{GlobalError, GlobalResult}; use rivet_pools::prelude::*; use uuid::Uuid; use crate::{ - ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError, WorkflowResult, + ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError, }; pub struct ActivityCtx { @@ -50,7 +50,7 @@ impl ActivityCtx { pub async fn op( &mut self, input: I, - ) -> WorkflowResult<<::Operation as Operation>::Output> + ) -> GlobalResult<<::Operation as Operation>::Output> where I: OperationInput, ::Operation: Operation, @@ -60,6 +60,7 @@ impl ActivityCtx { I::Operation::run(&mut ctx, &input) .await .map_err(WorkflowError::OperationFailure) + .map_err(GlobalError::raw) } pub fn name(&self) -> &str { diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index 52b5e6b40..a3ab6d06a 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -1,12 +1,12 @@ use std::sync::Arc; +use global_error::{GlobalError, GlobalResult}; use serde::Serialize; use tokio::time::Duration; use uuid::Uuid; use crate::{ DatabaseHandle, DatabasePostgres, Signal, Workflow, WorkflowError, WorkflowInput, - WorkflowResult, }; pub type TestCtxHandle = Arc; @@ -49,7 +49,7 @@ impl TestCtx { } impl TestCtx { - pub async fn dispatch_workflow(&self, input: I) -> WorkflowResult + pub async fn dispatch_workflow(&self, input: I) -> GlobalResult where I: WorkflowInput, ::Workflow: Workflow, @@ -61,21 +61,25 @@ impl TestCtx { let id = Uuid::new_v4(); // Serialize input - let input_str = - serde_json::to_string(&input).map_err(WorkflowError::SerializeWorkflowOutput)?; + let input_str = serde_json::to_string(&input) + .map_err(WorkflowError::SerializeWorkflowOutput) + .map_err(GlobalError::raw)?; - self.db.dispatch_workflow(id, &name, &input_str).await?; + self.db + .dispatch_workflow(id, &name, &input_str) + .await + .map_err(GlobalError::raw)?; tracing::info!(%name, ?id, "workflow dispatched"); - WorkflowResult::Ok(id) + Ok(id) } pub async fn wait_for_workflow( &self, workflow_id: Uuid, - ) -> WorkflowResult { - tracing::info!(name = W::name(), id = ?workflow_id, "waiting for workflow"); + ) -> GlobalResult { + tracing::info!(name=W::name(), id=?workflow_id, "waiting for workflow"); let period = Duration::from_millis(50); let mut interval = tokio::time::interval(period); @@ -86,10 +90,12 @@ impl TestCtx { let workflow = self .db .get_workflow(workflow_id) - .await? - .ok_or(WorkflowError::WorkflowNotFound)?; - if let Some(output) = workflow.parse_output::()? { - return WorkflowResult::Ok(output); + .await + .map_err(GlobalError::raw)? + .ok_or(WorkflowError::WorkflowNotFound) + .map_err(GlobalError::raw)?; + if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { + return Ok(output); } } } @@ -97,33 +103,34 @@ impl TestCtx { pub async fn workflow( &self, input: I, - ) -> WorkflowResult<<::Workflow as Workflow>::Output> + ) -> GlobalResult<<::Workflow as Workflow>::Output> where I: WorkflowInput, ::Workflow: Workflow, { let workflow_id = self.dispatch_workflow(input).await?; let output = self.wait_for_workflow::(workflow_id).await?; - WorkflowResult::Ok(output) + Ok(output) } pub async fn signal( &self, workflow_id: Uuid, input: I, - ) -> WorkflowResult { + ) -> GlobalResult { tracing::debug!(name=%I::name(), %workflow_id, "dispatching signal"); let signal_id = Uuid::new_v4(); // Serialize input let input_str = - serde_json::to_string(&input).map_err(WorkflowError::SerializeSignalBody)?; + serde_json::to_string(&input).map_err(WorkflowError::SerializeSignalBody).map_err(GlobalError::raw)?; self.db .publish_signal(workflow_id, signal_id, I::name(), &input_str) - .await?; + .await + .map_err(GlobalError::raw)?; - WorkflowResult::Ok(signal_id) + Ok(signal_id) } } diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index 7ff4562c5..93fa15b53 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use global_error::{GlobalError, GlobalResult}; use serde::Serialize; use tokio::time::Duration; use uuid::Uuid; @@ -58,8 +59,8 @@ impl WorkflowCtx { db: DatabaseHandle, conn: rivet_connection::Connection, workflow: PulledWorkflow, - ) -> WorkflowResult { - WorkflowResult::Ok(WorkflowCtx { + ) -> GlobalResult { + GlobalResult::Ok(WorkflowCtx { workflow_id: workflow.workflow_id, name: workflow.workflow_name, @@ -68,11 +69,14 @@ impl WorkflowCtx { conn, - event_history: Arc::new(util::combine_events( - workflow.activity_events, - workflow.signal_events, - workflow.sub_workflow_events, - )?), + event_history: Arc::new( + util::combine_events( + workflow.activity_events, + workflow.signal_events, + workflow.sub_workflow_events, + ) + .map_err(GlobalError::raw)?, + ), input: Arc::new(workflow.input), root_location: Box::new([]), @@ -305,13 +309,13 @@ impl WorkflowCtx { "signal received", ); - WorkflowResult::Ok(signal) + Ok(signal) } } impl WorkflowCtx { /// Dispatch another workflow. - pub async fn dispatch_workflow(&mut self, input: I) -> WorkflowResult + pub async fn dispatch_workflow(&mut self, input: I) -> GlobalResult where I: WorkflowInput, ::Workflow: Workflow, @@ -322,11 +326,11 @@ impl WorkflowCtx { let id = if let Some(event) = event { // Validate history is consistent let Event::SubWorkflow(sub_workflow) = event else { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); }; if sub_workflow.sub_workflow_name != I::Workflow::name() { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); } tracing::debug!( @@ -346,8 +350,9 @@ impl WorkflowCtx { let sub_workflow_id = Uuid::new_v4(); // Serialize input - let input_str = - serde_json::to_string(&input).map_err(WorkflowError::SerializeWorkflowOutput)?; + let input_str = serde_json::to_string(&input) + .map_err(WorkflowError::SerializeWorkflowOutput) + .map_err(GlobalError::raw)?; self.db .dispatch_sub_workflow( @@ -357,7 +362,8 @@ impl WorkflowCtx { &name, &input_str, ) - .await?; + .await + .map_err(GlobalError::raw)?; tracing::info!(%name, ?sub_workflow_id, "workflow dispatched"); @@ -367,14 +373,14 @@ impl WorkflowCtx { // Move to next event self.location_idx += 1; - WorkflowResult::Ok(id) + GlobalResult::Ok(id) } /// Wait for another workflow's response. pub async fn wait_for_workflow( &self, sub_workflow_id: Uuid, - ) -> WorkflowResult { + ) -> GlobalResult { tracing::info!(name = W::name(), ?sub_workflow_id, "waiting for workflow"); let mut retries = 0; @@ -388,14 +394,17 @@ impl WorkflowCtx { let workflow = self .db .get_workflow(sub_workflow_id) - .await? - .ok_or(WorkflowError::WorkflowNotFound)?; + .await + .map_err(GlobalError::raw)? + .ok_or(WorkflowError::WorkflowNotFound) + .map_err(GlobalError::raw)?; - if let Some(output) = workflow.parse_output::()? { - return WorkflowResult::Ok(output); + if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { + return Ok(output); } else { if retries > MAX_SUB_WORKFLOW_RETRIES { - return Err(WorkflowError::SubWorkflowIncomplete(sub_workflow_id)); + return Err(WorkflowError::SubWorkflowIncomplete(sub_workflow_id)) + .map_err(GlobalError::raw); } retries += 1; } @@ -407,7 +416,7 @@ impl WorkflowCtx { pub async fn workflow( &mut self, input: I, - ) -> WorkflowResult<<::Workflow as Workflow>::Output> + ) -> GlobalResult<<::Workflow as Workflow>::Output> where I: WorkflowInput, ::Workflow: Workflow, @@ -416,14 +425,14 @@ impl WorkflowCtx { let output = self .wait_for_workflow::(sub_workflow_id) .await?; - WorkflowResult::Ok(output) + Ok(output) } /// Run activity. Will replay on failure. pub async fn activity( &mut self, input: I, - ) -> WorkflowResult<<::Activity as Activity>::Output> + ) -> GlobalResult<<::Activity as Activity>::Output> where I: ActivityInput, ::Activity: Activity, @@ -436,36 +445,38 @@ impl WorkflowCtx { let output = if let Some(event) = event { // Validate history is consistent let Event::Activity(activity) = event else { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); }; if activity.activity_id != activity_id { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); } // Activity succeeded - if let Some(output) = activity.get_output()? { + if let Some(output) = activity.get_output().map_err(GlobalError::raw)? { output } else { // Activity failed, retry self.run_activity::(&input, &activity_id) - .await? + .await + .map_err(GlobalError::raw)? } } // This is a new activity else { self.run_activity::(&input, &activity_id) - .await? + .await + .map_err(GlobalError::raw)? }; // Move to next event self.location_idx += 1; - WorkflowResult::Ok(output) + Ok(output) } /// Joins multiple executable actions (activities, closures) and awaits them simultaneously. - pub async fn join(&mut self, exec: T) -> WorkflowResult { + pub async fn join(&mut self, exec: T) -> GlobalResult { exec.execute(self).await } @@ -474,7 +485,7 @@ impl WorkflowCtx { &mut self, workflow_id: Uuid, body: T, - ) -> WorkflowResult { + ) -> GlobalResult { let id = Uuid::new_v4(); self.db @@ -482,28 +493,31 @@ impl WorkflowCtx { workflow_id, id, T::name(), - &serde_json::to_string(&body).map_err(WorkflowError::SerializeSignalBody)?, + &serde_json::to_string(&body) + .map_err(WorkflowError::SerializeSignalBody) + .map_err(GlobalError::raw)?, ) - .await?; + .await + .map_err(GlobalError::raw)?; - WorkflowResult::Ok(id) + Ok(id) } /// Listens for a signal for a short time before setting the workflow to sleep. Once the signal is /// received, the workflow will be woken up and continue. - pub async fn listen(&mut self) -> WorkflowResult { + pub async fn listen(&mut self) -> GlobalResult { let event = { self.relevant_history().nth(self.location_idx) }; // Signal received before let signal = if let Some(event) = event { // Validate history is consistent let Event::Signal(signal) = event else { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); }; tracing::debug!(id=%self.workflow_id, name=%signal.name, "replaying signal"); - T::parse(&signal.name, &signal.body)? + T::parse(&signal.name, &signal.body).map_err(GlobalError::raw)? } // Listen for new messages else { @@ -517,14 +531,14 @@ impl WorkflowCtx { interval.tick().await; match T::listen(self).await { - WorkflowResult::Ok(res) => break res, + Ok(res) => break res, Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => { if retries > MAX_SIGNAL_RETRIES { - return Err(err); + return Err(err).map_err(GlobalError::raw); } retries += 1; } - err => return err, + err => return err.map_err(GlobalError::raw), } } }; @@ -532,11 +546,11 @@ impl WorkflowCtx { // Move to next event self.location_idx += 1; - WorkflowResult::Ok(signal) + Ok(signal) } /// Checks if the given signal exists in the database. - pub async fn query_signal(&mut self) -> WorkflowResult> { + pub async fn query_signal(&mut self) -> GlobalResult> { let event = { self.relevant_history().nth(self.location_idx) }; // Signal received before @@ -545,24 +559,24 @@ impl WorkflowCtx { // Validate history is consistent let Event::Signal(signal) = event else { - return Err(WorkflowError::HistoryDiverged); + return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw); }; - Some(T::parse(&signal.name, &signal.body)?) + Some(T::parse(&signal.name, &signal.body).map_err(GlobalError::raw)?) } // Listen for new message else { match T::listen(self).await { - WorkflowResult::Ok(res) => Some(res), + Ok(res) => Some(res), Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => None, - Err(err) => return Err(err), + Err(err) => return Err(err).map_err(GlobalError::raw), } }; // Move to next event self.location_idx += 1; - WorkflowResult::Ok(signal) + Ok(signal) } // TODO: sleep_for, sleep_until diff --git a/lib/chirp-workflow/core/src/executable.rs b/lib/chirp-workflow/core/src/executable.rs index daadb04fc..143b570a3 100644 --- a/lib/chirp-workflow/core/src/executable.rs +++ b/lib/chirp-workflow/core/src/executable.rs @@ -1,8 +1,9 @@ use std::{future::Future, pin::Pin}; use async_trait::async_trait; +use global_error::GlobalResult; -use crate::{WorkflowCtx, WorkflowResult}; +use crate::WorkflowCtx; /// Signifies a retryable executable entity in a workflow. For example: activity, tuple of activities (join), /// closure. @@ -10,10 +11,10 @@ use crate::{WorkflowCtx, WorkflowResult}; pub trait Executable: Send { type Output: Send; - async fn execute(self, ctx: &mut WorkflowCtx) -> WorkflowResult; + async fn execute(self, ctx: &mut WorkflowCtx) -> GlobalResult; } -type AsyncResult<'a, T> = Pin> + Send + 'a>>; +type AsyncResult<'a, T> = Pin> + Send + 'a>>; #[async_trait] impl Executable for F @@ -23,7 +24,7 @@ where { type Output = T; - async fn execute(self, ctx: &mut WorkflowCtx) -> WorkflowResult { + async fn execute(self, ctx: &mut WorkflowCtx) -> GlobalResult { let mut branch = ctx.branch(); (self)(&mut branch).await } @@ -36,7 +37,7 @@ macro_rules! impl_tuple { impl<$($args : Executable),*> Executable for ($($args),*) { type Output = ($(<$args as Executable>::Output),*); - async fn execute(self, ctx: &mut WorkflowCtx) -> WorkflowResult { + async fn execute(self, ctx: &mut WorkflowCtx) -> GlobalResult { #[allow(non_snake_case)] let ($($args),*) = self; diff --git a/lib/chirp-workflow/core/src/registry.rs b/lib/chirp-workflow/core/src/registry.rs index 100cb0372..2daf06dbc 100644 --- a/lib/chirp-workflow/core/src/registry.rs +++ b/lib/chirp-workflow/core/src/registry.rs @@ -43,21 +43,19 @@ impl Registry { let output = match W::run(ctx, &input).await { Ok(x) => x, // Differentiate between WorkflowError and user error - Err(err) => { - match err { - GlobalError::Raw(inner_err) => { - match inner_err.downcast::() { - Ok(inner_err) => return Err(*inner_err), - Err(err) => { - return Err(WorkflowError::WorkflowFailure( - GlobalError::Raw(err), - )) - } + Err(err) => match err { + GlobalError::Raw(inner_err) => { + match inner_err.downcast::() { + Ok(inner_err) => return Err(*inner_err), + Err(err) => { + return Err(WorkflowError::WorkflowFailure( + GlobalError::Raw(err), + )) } } - _ => return Err(WorkflowError::WorkflowFailure(err)), } - } + _ => return Err(WorkflowError::WorkflowFailure(err)), + }, }; // Serialize output diff --git a/lib/chirp-workflow/core/src/signal.rs b/lib/chirp-workflow/core/src/signal.rs index 25e8086ca..a4ffcbb78 100644 --- a/lib/chirp-workflow/core/src/signal.rs +++ b/lib/chirp-workflow/core/src/signal.rs @@ -59,7 +59,7 @@ macro_rules! join_signal { fn parse(name: &str, body: &str) -> ::wf::WorkflowResult { $( if name == $signals::name() { - WorkflowResult::Ok( + Ok( Self::$signals( serde_json::from_str(body) .map_err(WorkflowError::DeserializeActivityOutput)? diff --git a/lib/chirp-workflow/macros/src/lib.rs b/lib/chirp-workflow/macros/src/lib.rs index 6a096fec7..d643213ce 100644 --- a/lib/chirp-workflow/macros/src/lib.rs +++ b/lib/chirp-workflow/macros/src/lib.rs @@ -97,7 +97,7 @@ fn trait_fn(attr: TokenStream, item: TokenStream, opts: TraitFnOpts) -> TokenStr impl chirp_workflow::prelude::Executable for #input_type { type Output = <#struct_ident as #trait_ty>::Output; - async fn execute(self, ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult { + async fn execute(self, ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> GlobalResult { ctx.activity(self).await } } diff --git a/svc/pkg/foo/worker/src/workflows/test.rs b/svc/pkg/foo/worker/src/workflows/test.rs index cddb56349..63b924ad3 100644 --- a/svc/pkg/foo/worker/src/workflows/test.rs +++ b/svc/pkg/foo/worker/src/workflows/test.rs @@ -12,8 +12,6 @@ pub struct TestOutput { #[workflow(Test)] async fn test(ctx: &mut WorkflowCtx, input: &TestInput) -> GlobalResult { - tracing::info!("input {}", input.x); - let a = ctx.activity(FooInput {}).await?; Ok(TestOutput { y: a.ids.len() }) @@ -29,7 +27,6 @@ pub struct FooOutput { #[activity(Foo)] pub fn foo(ctx: &mut ActivityCtx, input: &FooInput) -> GlobalResult { - chirp_workflow::util::inject_fault()?; let ids = sql_fetch_all!( [ctx, (Uuid,)] " @@ -42,15 +39,12 @@ pub fn foo(ctx: &mut ActivityCtx, input: &FooInput) -> GlobalResult { .map(|(id,)| id) .collect(); - // let user_id = util::uuid::parse("000b3124-91d9-472e-8104-3dcc41e1a74d").unwrap(); - // let user_get_res = op!([ctx] user_get { - // user_ids: vec![user_id.into()], - // }) - // .await - // .unwrap(); - // let user = user_get_res.users.first().unwrap(); - - // tracing::info!(?user, "-----------"); + let user_id = util::uuid::parse("000b3124-91d9-472e-8104-3dcc41e1a74d")?; + let user_get_res = op!([ctx] user_get { + user_ids: vec![user_id.into()], + }) + .await?; + let user = unwrap!(user_get_res.users.first()); Ok(FooOutput { ids }) }