Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/chirp-workflow/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ edition = "2021"
license = "Apache-2.0"

[dependencies]
anyhow = "1.0.82"
async-trait = "0.1.80"
chirp-client = { path = "../../chirp/client" }
chirp-workflow-macros = { path = "../macros" }
Expand All @@ -32,3 +31,6 @@ tokio = { version = "1.37.0", features = ["full"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }

[dev-dependencies]
anyhow = "1.0.82"
4 changes: 2 additions & 2 deletions lib/chirp-workflow/core/src/activity.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{fmt::Debug, hash::Hash};

use anyhow::*;
use async_trait::async_trait;
use global_error::GlobalResult;
use serde::{de::DeserializeOwned, Serialize};

use crate::ActivityCtx;
Expand All @@ -13,7 +13,7 @@ pub trait Activity {

fn name() -> &'static str;

async fn run(ctx: &mut ActivityCtx, input: &Self::Input) -> Result<Self::Output>;
async fn run(ctx: &mut ActivityCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
}

pub trait ActivityInput: Serialize + DeserializeOwned + Debug + Hash + Send {
Expand Down
4 changes: 2 additions & 2 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ impl ActivityCtx {
name.to_string(),
std::time::Duration::from_secs(60),
conn.clone(),
// TODO: req_id
Uuid::new_v4(),
workflow_id,
// TODO: ray_id
Uuid::new_v4(),
rivet_util::timestamp::now(),
// TODO: req_ts
rivet_util::timestamp::now(),
Expand Down
3 changes: 1 addition & 2 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc};

use anyhow::*;
use serde::Serialize;
use tokio::time::Duration;
use uuid::Uuid;
Expand Down Expand Up @@ -142,7 +141,7 @@ impl WorkflowCtx {
}
}

async fn run_workflow_inner(&mut self) -> Result<()> {
async fn run_workflow_inner(&mut self) -> WorkflowResult<()> {
tracing::info!(id=%self.workflow_id, "running workflow");

// Lookup workflow
Expand Down
8 changes: 4 additions & 4 deletions lib/chirp-workflow/core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::*;
use global_error::GlobalError;
use uuid::Uuid;

pub type WorkflowResult<T> = Result<T, WorkflowError>;
Expand All @@ -10,13 +10,13 @@ pub type WorkflowResult<T> = Result<T, WorkflowError>;
#[derive(thiserror::Error, Debug)]
pub enum WorkflowError {
#[error("workflow failure: {0:?}")]
WorkflowFailure(Error),
WorkflowFailure(GlobalError),

#[error("activity failure: {0:?}")]
ActivityFailure(Error),
ActivityFailure(GlobalError),

#[error("operation failure: {0:?}")]
OperationFailure(Error),
OperationFailure(GlobalError),

#[error("workflow missing from registry: {0}")]
WorkflowMissingFromRegistry(String),
Expand Down
4 changes: 2 additions & 2 deletions lib/chirp-workflow/core/src/operation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::*;
use async_trait::async_trait;
use global_error::GlobalResult;

use crate::OperationCtx;

Expand All @@ -10,7 +10,7 @@ pub trait Operation {

fn name() -> &'static str;

async fn run(ctx: &mut OperationCtx, input: &Self::Input) -> Result<Self::Output>;
async fn run(ctx: &mut OperationCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
}

pub trait OperationInput: Send {
Expand Down
2 changes: 0 additions & 2 deletions lib/chirp-workflow/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ pub use chirp_workflow_macros::*;

// External libraries
#[doc(hidden)]
pub use anyhow::{self, Result};
#[doc(hidden)]
pub use async_trait;
#[doc(hidden)]
pub use futures_util;
Expand Down
18 changes: 14 additions & 4 deletions lib/chirp-workflow/core/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};

use futures_util::FutureExt;
use global_error::GlobalError;

use crate::{Workflow, WorkflowCtx, WorkflowError, WorkflowResult};

Expand Down Expand Up @@ -41,11 +42,20 @@ impl Registry {
// Run workflow
let output = match W::run(ctx, &input).await {
Ok(x) => x,
// Differentiate between WorkflowError and user error
Err(err) => {
// Differentiate between WorkflowError and user error
match err.downcast::<WorkflowError>() {
Ok(err) => return Err(err),
Err(err) => return Err(WorkflowError::WorkflowFailure(err)),
match err {
GlobalError::Raw(inner_err) => {
match inner_err.downcast::<WorkflowError>() {
Ok(inner_err) => return Err(*inner_err),
Err(err) => {
return Err(WorkflowError::WorkflowFailure(
GlobalError::Raw(err),
))
}
}
}
_ => return Err(WorkflowError::WorkflowFailure(err)),
}
}
};
Expand Down
9 changes: 5 additions & 4 deletions lib/chirp-workflow/core/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
time::{SystemTime, UNIX_EPOCH},
};

use anyhow::*;
use global_error::{macros::*, GlobalResult};
use rand::Rng;
use tokio::time::{self, Duration};

Expand All @@ -16,7 +16,7 @@ const FAULT_RATE: usize = 80;

pub async fn sleep_until_ts(ts: i64) {
let target_time = UNIX_EPOCH + Duration::from_millis(ts as u64);
if let std::result::Result::Ok(sleep_duration) = target_time.duration_since(SystemTime::now()) {
if let Ok(sleep_duration) = target_time.duration_since(SystemTime::now()) {
time::sleep(sleep_duration).await;
}
}
Expand Down Expand Up @@ -100,12 +100,13 @@ pub fn combine_events(
.map(|(k, v)| (k, v.into_iter().map(|(_, v)| v).collect()))
.collect();

WorkflowResult::Ok(event_history)
Ok(event_history)
}

pub fn inject_fault() -> Result<()> {
pub fn inject_fault() -> GlobalResult<()> {
if rand::thread_rng().gen_range(0..100) < FAULT_RATE {
bail!("This is a random panic!");
}

Ok(())
}
4 changes: 2 additions & 2 deletions lib/chirp-workflow/core/src/workflow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::*;
use async_trait::async_trait;
use global_error::GlobalResult;
use serde::{de::DeserializeOwned, Serialize};
use std::fmt::Debug;

Expand All @@ -13,7 +13,7 @@ pub trait Workflow {
fn name() -> &'static str;

// TODO: Is there any reason for input to be a reference?
async fn run(ctx: &mut WorkflowCtx, input: &Self::Input) -> Result<Self::Output>;
async fn run(ctx: &mut WorkflowCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
}

pub trait WorkflowInput: Serialize + DeserializeOwned + Debug + Send {
Expand Down
6 changes: 3 additions & 3 deletions lib/chirp-workflow/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn trait_fn(attr: TokenStream, item: TokenStream, opts: TraitFnOpts) -> TokenStr
ReturnType::Type(_, ty) => match ty.as_ref() {
Type::Path(path) => {
let segment = path.path.segments.last().unwrap();
if segment.ident == "Result" {
if segment.ident == "GlobalResult" {
match &segment.arguments {
PathArguments::AngleBracketed(args) => {
if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first()
Expand All @@ -68,7 +68,7 @@ fn trait_fn(attr: TokenStream, item: TokenStream, opts: TraitFnOpts) -> TokenStr
}
} else {
panic!(
"{} function must return a Result type",
"{} function must return a GlobalResult type",
opts.trait_ty.to_token_stream().to_string()
);
}
Expand Down Expand Up @@ -124,7 +124,7 @@ fn trait_fn(attr: TokenStream, item: TokenStream, opts: TraitFnOpts) -> TokenStr
#fn_name
}

async fn run(#ctx_ident: #ctx_ty, #input_ident: &Self::Input) -> anyhow::Result<Self::Output> {
async fn run(#ctx_ident: #ctx_ty, #input_ident: &Self::Input) -> GlobalResult<Self::Output> {
#fn_body
}
}
Expand Down
8 changes: 4 additions & 4 deletions lib/global-error/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt::Display, sync::Arc};
use std::{collections::HashMap, fmt::Display};

use http::StatusCode;
use serde::Serialize;
Expand All @@ -7,7 +7,7 @@ use types::rivet::chirp;

pub type GlobalResult<T> = Result<T, GlobalError>;

#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum GlobalError {
/// Errors thrown by any part of the code, such as from sql queries, api calls, etc.
Internal {
Expand Down Expand Up @@ -35,7 +35,7 @@ pub enum GlobalError {
},
/// Any kind of error, but stored dynamically. This is used to downcast the error back into its original
/// type if needed.
Raw(Arc<dyn std::error::Error + Send + Sync>),
Raw(Box<dyn std::error::Error + Send + Sync>),
}

impl Display for GlobalError {
Expand Down Expand Up @@ -91,7 +91,7 @@ impl GlobalError {
}

pub fn raw<T: std::error::Error + Send + Sync + 'static>(err: T) -> GlobalError {
GlobalError::Raw(Arc::new(err))
GlobalError::Raw(Box::new(err))
}

pub fn bad_request_builder(code: &'static str) -> BadRequestBuilder {
Expand Down
1 change: 0 additions & 1 deletion svc/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 12 additions & 12 deletions svc/pkg/foo/worker/src/workflows/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct TestOutput {
}

#[workflow(Test)]
async fn test(ctx: &mut WorkflowCtx, input: &TestInput) -> Result<TestOutput> {
async fn test(ctx: &mut WorkflowCtx, input: &TestInput) -> GlobalResult<TestOutput> {
tracing::info!("input {}", input.x);

let a = ctx.activity(FooInput {}).await?;
Expand All @@ -28,29 +28,29 @@ pub struct FooOutput {
}

#[activity(Foo)]
pub fn foo(ctx: &mut ActivityCtx, input: &FooInput) -> Result<FooOutput> {
pub fn foo(ctx: &mut ActivityCtx, input: &FooInput) -> GlobalResult<FooOutput> {
chirp_workflow::util::inject_fault()?;
let ids = sql_fetch_all!(
[ctx, (Uuid,)]
"
SELECT datacenter_id
FROM db_cluster.datacenters
",
)
.await
.unwrap()
.await?
.into_iter()
.map(|(id,)| id)
.collect();

let user_id = util::uuid::parse("000b3124-91d9-472e-8104-3dcc41e1a74d").unwrap();
let user_get_res = op!([ctx] user_get {
user_ids: vec![user_id.into()],
})
.await
.unwrap();
let user = user_get_res.users.first().unwrap();
// let user_id = util::uuid::parse("000b3124-91d9-472e-8104-3dcc41e1a74d").unwrap();
// let user_get_res = op!([ctx] user_get {
// user_ids: vec![user_id.into()],
// })
// .await
// .unwrap();
// let user = user_get_res.users.first().unwrap();

tracing::info!(?user, "-----------");
// tracing::info!(?user, "-----------");

Ok(FooOutput { ids })
}