Skip to content

Commit

Permalink
fix: cast workflow errors to raw global errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed May 9, 2024
1 parent f6211c3 commit b6b8a23
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 100 deletions.
7 changes: 4 additions & 3 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use global_error::GlobalResult;
use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{
ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError, WorkflowResult,
ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError,
};

pub struct ActivityCtx {
Expand Down Expand Up @@ -50,7 +50,7 @@ impl ActivityCtx {
pub async fn op<I>(
&mut self,
input: I,
) -> WorkflowResult<<<I as OperationInput>::Operation as Operation>::Output>
) -> GlobalResult<<<I as OperationInput>::Operation as Operation>::Output>
where
I: OperationInput,
<I as OperationInput>::Operation: Operation<Input = I>,
Expand All @@ -60,6 +60,7 @@ impl ActivityCtx {
I::Operation::run(&mut ctx, &input)
.await
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}

pub fn name(&self) -> &str {
Expand Down
43 changes: 25 additions & 18 deletions lib/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::sync::Arc;

use global_error::{GlobalError, GlobalResult};
use serde::Serialize;
use tokio::time::Duration;
use uuid::Uuid;

use crate::{
DatabaseHandle, DatabasePostgres, Signal, Workflow, WorkflowError, WorkflowInput,
WorkflowResult,
};

pub type TestCtxHandle = Arc<TestCtx>;
Expand Down Expand Up @@ -49,7 +49,7 @@ impl TestCtx {
}

impl TestCtx {
pub async fn dispatch_workflow<I>(&self, input: I) -> WorkflowResult<Uuid>
pub async fn dispatch_workflow<I>(&self, input: I) -> GlobalResult<Uuid>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
Expand All @@ -61,21 +61,25 @@ impl TestCtx {
let id = Uuid::new_v4();

// Serialize input
let input_str =
serde_json::to_string(&input).map_err(WorkflowError::SerializeWorkflowOutput)?;
let input_str = serde_json::to_string(&input)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;

self.db.dispatch_workflow(id, &name, &input_str).await?;
self.db
.dispatch_workflow(id, &name, &input_str)
.await
.map_err(GlobalError::raw)?;

tracing::info!(%name, ?id, "workflow dispatched");

WorkflowResult::Ok(id)
Ok(id)
}

pub async fn wait_for_workflow<W: Workflow>(
&self,
workflow_id: Uuid,
) -> WorkflowResult<W::Output> {
tracing::info!(name = W::name(), id = ?workflow_id, "waiting for workflow");
) -> GlobalResult<W::Output> {
tracing::info!(name=W::name(), id=?workflow_id, "waiting for workflow");

let period = Duration::from_millis(50);
let mut interval = tokio::time::interval(period);
Expand All @@ -86,44 +90,47 @@ impl TestCtx {
let workflow = self
.db
.get_workflow(workflow_id)
.await?
.ok_or(WorkflowError::WorkflowNotFound)?;
if let Some(output) = workflow.parse_output::<W>()? {
return WorkflowResult::Ok(output);
.await
.map_err(GlobalError::raw)?
.ok_or(WorkflowError::WorkflowNotFound)
.map_err(GlobalError::raw)?;
if let Some(output) = workflow.parse_output::<W>().map_err(GlobalError::raw)? {
return Ok(output);
}
}
}

pub async fn workflow<I>(
&self,
input: I,
) -> WorkflowResult<<<I as WorkflowInput>::Workflow as Workflow>::Output>
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output>
where
I: WorkflowInput,
<I as WorkflowInput>::Workflow: Workflow<Input = I>,
{
let workflow_id = self.dispatch_workflow(input).await?;
let output = self.wait_for_workflow::<I::Workflow>(workflow_id).await?;
WorkflowResult::Ok(output)
Ok(output)
}

pub async fn signal<I: Signal + Serialize>(
&self,
workflow_id: Uuid,
input: I,
) -> WorkflowResult<Uuid> {
) -> GlobalResult<Uuid> {
tracing::debug!(name=%I::name(), %workflow_id, "dispatching signal");

let signal_id = Uuid::new_v4();

// Serialize input
let input_str =
serde_json::to_string(&input).map_err(WorkflowError::SerializeSignalBody)?;
serde_json::to_string(&input).map_err(WorkflowError::SerializeSignalBody).map_err(GlobalError::raw)?;

self.db
.publish_signal(workflow_id, signal_id, I::name(), &input_str)
.await?;
.await
.map_err(GlobalError::raw)?;

WorkflowResult::Ok(signal_id)
Ok(signal_id)
}
}
Loading

0 comments on commit b6b8a23

Please sign in to comment.