diff --git a/lib/chirp-workflow/core/src/builder/common/message.rs b/lib/chirp-workflow/core/src/builder/common/message.rs new file mode 100644 index 0000000000..3134908464 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/common/message.rs @@ -0,0 +1,84 @@ +use std::fmt::Display; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; + +use crate::{builder::BuilderError, ctx::MessageCtx, message::Message}; + +pub struct MessageBuilder<'a, M: Message> { + msg_ctx: &'a MessageCtx, + body: M, + tags: serde_json::Map, + wait: bool, + error: Option, +} + +impl<'a, M: Message> MessageBuilder<'a, M> { + pub(crate) fn new(msg_ctx: &'a MessageCtx, body: M) -> Self { + MessageBuilder { + msg_ctx, + body, + tags: serde_json::Map::new(), + wait: false, + error: None, + } + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn wait(mut self) -> Self { + if self.error.is_some() { + return self; + } + + self.wait = true; + + self + } + + pub async fn send(self) -> GlobalResult<()> { + if let Some(err) = self.error { + return Err(err); + } + + tracing::info!(msg_name=%M::NAME, tags=?self.tags, "dispatching message"); + + let tags = serde_json::Value::Object(self.tags); + + if self.wait { + self.msg_ctx.message_wait(tags, self.body).await?; + } else { + self.msg_ctx.message(tags, self.body).await?; + } + + Ok(()) + } +} diff --git a/lib/chirp-workflow/core/src/builder/common/mod.rs b/lib/chirp-workflow/core/src/builder/common/mod.rs new file mode 100644 index 0000000000..24ece4bc11 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/common/mod.rs @@ -0,0 +1,5 @@ +//! This module contains builders used by all ctx's besides the workflow ctx. + +pub mod message; +pub mod signal; +pub mod workflow; diff --git a/lib/chirp-workflow/core/src/builder/common/signal.rs b/lib/chirp-workflow/core/src/builder/common/signal.rs new file mode 100644 index 0000000000..f112754dca --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/common/signal.rs @@ -0,0 +1,111 @@ +use std::fmt::Display; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; +use uuid::Uuid; + +use crate::{builder::BuilderError, db::DatabaseHandle, error::WorkflowError, signal::Signal}; + +pub struct SignalBuilder { + db: DatabaseHandle, + ray_id: Uuid, + body: T, + to_workflow_id: Option, + tags: serde_json::Map, + error: Option, +} + +impl SignalBuilder { + pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, body: T) -> Self { + SignalBuilder { + db, + ray_id, + body, + to_workflow_id: None, + tags: serde_json::Map::new(), + error: None, + } + } + + pub fn to_workflow(mut self, workflow_id: Uuid) -> Self { + if self.error.is_some() { + return self; + } + + self.to_workflow_id = Some(workflow_id); + + self + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn send(self) -> GlobalResult { + if let Some(err) = self.error { + return Err(err); + } + + let signal_id = Uuid::new_v4(); + + // Serialize input + let input_val = serde_json::to_value(&self.body) + .map_err(WorkflowError::SerializeSignalBody) + .map_err(GlobalError::raw)?; + + match (self.to_workflow_id, self.tags.is_empty()) { + (Some(workflow_id), true) => { + tracing::info!(signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal"); + + self.db + .publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val) + .await + .map_err(GlobalError::raw)?; + } + (None, false) => { + tracing::info!(signal_name=%T::NAME, tags=?self.tags, %signal_id, "dispatching tagged signal"); + + self.db + .publish_tagged_signal( + self.ray_id, + &serde_json::Value::Object(self.tags), + signal_id, + T::NAME, + input_val, + ) + .await + .map_err(GlobalError::raw)?; + } + (Some(_), false) => return Err(BuilderError::WorkflowIdAndTags.into()), + (None, true) => return Err(BuilderError::NoWorkflowIdOrTags.into()), + } + + Ok(signal_id) + } +} diff --git a/lib/chirp-workflow/core/src/builder/common/workflow.rs b/lib/chirp-workflow/core/src/builder/common/workflow.rs new file mode 100644 index 0000000000..ffb55655d2 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/common/workflow.rs @@ -0,0 +1,108 @@ +use std::fmt::Display; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; +use uuid::Uuid; + +use crate::{ + builder::BuilderError, + ctx::common, + db::DatabaseHandle, + error::WorkflowError, + workflow::{Workflow, WorkflowInput}, +}; + +pub struct WorkflowBuilder { + db: DatabaseHandle, + ray_id: Uuid, + input: I, + tags: serde_json::Map, + error: Option, +} + +impl WorkflowBuilder +where + ::Workflow: Workflow, +{ + pub(crate) fn new(db: DatabaseHandle, ray_id: Uuid, input: I) -> Self { + WorkflowBuilder { + db, + ray_id, + input, + tags: serde_json::Map::new(), + error: None, + } + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn dispatch(self) -> GlobalResult { + if let Some(err) = self.error { + return Err(err); + } + + let workflow_name = I::Workflow::NAME; + let workflow_id = Uuid::new_v4(); + + let no_tags = self.tags.is_empty(); + let tags = serde_json::Value::Object(self.tags); + let tags = if no_tags { None } else { Some(&tags) }; + + tracing::info!( + %workflow_name, + %workflow_id, + ?tags, + input=?self.input, + "dispatching workflow" + ); + + // Serialize input + let input_val = serde_json::to_value(&self.input) + .map_err(WorkflowError::SerializeWorkflowOutput) + .map_err(GlobalError::raw)?; + + self.db + .dispatch_workflow(self.ray_id, workflow_id, &workflow_name, tags, input_val) + .await + .map_err(GlobalError::raw)?; + + Ok(workflow_id) + } + + pub async fn output( + self, + ) -> GlobalResult<<::Workflow as Workflow>::Output> { + let db = self.db.clone(); + + let workflow_id = self.dispatch().await?; + common::wait_for_workflow::(&db, workflow_id).await + } +} diff --git a/lib/chirp-workflow/core/src/builder/mod.rs b/lib/chirp-workflow/core/src/builder/mod.rs new file mode 100644 index 0000000000..0509add1c7 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/mod.rs @@ -0,0 +1,14 @@ +pub mod common; +pub mod workflow; + +#[derive(thiserror::Error, Debug)] +pub(crate) enum BuilderError { + #[error("tags must be a JSON map")] + TagsNotMap, + #[error("cannot call `to_workflow` and set tags on the same signal")] + WorkflowIdAndTags, + #[error("must call `to_workflow` or set tags on signal")] + NoWorkflowIdOrTags, + #[error("cannot dispatch a workflow/signal from an operation within a workflow execution. trigger it from the workflow's body")] + CannotDispatchFromOpInWorkflow, +} diff --git a/lib/chirp-workflow/core/src/builder/workflow/message.rs b/lib/chirp-workflow/core/src/builder/workflow/message.rs new file mode 100644 index 0000000000..5351723138 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/workflow/message.rs @@ -0,0 +1,143 @@ +use std::fmt::Display; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; + +use crate::{ + builder::BuilderError, ctx::WorkflowCtx, error::WorkflowError, event::Event, message::Message, +}; + +pub struct MessageBuilder<'a, M: Message> { + ctx: &'a mut WorkflowCtx, + body: M, + tags: serde_json::Map, + wait: bool, + error: Option, +} + +impl<'a, M: Message> MessageBuilder<'a, M> { + pub(crate) fn new(ctx: &'a mut WorkflowCtx, body: M) -> Self { + MessageBuilder { + ctx, + body, + tags: serde_json::Map::new(), + wait: false, + error: None, + } + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn wait(mut self) -> Self { + if self.error.is_some() { + return self; + } + + self.wait = true; + + self + } + + pub async fn send(self) -> GlobalResult<()> { + if let Some(err) = self.error { + return Err(err); + } + + let event = self.ctx.current_history_event(); + + // Message sent before + if let Some(event) = event { + // Validate history is consistent + let Event::MessageSend(msg) = event else { + return Err(WorkflowError::HistoryDiverged(format!( + "expected {event} at {}, found message send {}", + self.ctx.loc(), + M::NAME, + ))) + .map_err(GlobalError::raw); + }; + + if msg.name != M::NAME { + return Err(WorkflowError::HistoryDiverged(format!( + "expected {event} at {}, found message send {}", + self.ctx.loc(), + M::NAME, + ))) + .map_err(GlobalError::raw); + } + + tracing::debug!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), msg_name=%msg.name, "replaying message dispatch"); + } + // Send message + else { + tracing::info!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), msg_name=%M::NAME, tags=?self.tags, "dispatching message"); + + // Serialize body + let body_val = serde_json::to_value(&self.body) + .map_err(WorkflowError::SerializeMessageBody) + .map_err(GlobalError::raw)?; + let location = self.ctx.full_location(); + let tags = serde_json::Value::Object(self.tags); + let tags2 = tags.clone(); + + let (msg, write) = tokio::join!( + async { + self.ctx + .db() + .commit_workflow_message_send_event( + self.ctx.workflow_id(), + location.as_ref(), + &tags, + M::NAME, + body_val, + self.ctx.loop_location(), + ) + .await + }, + async { + if self.wait { + self.ctx.msg_ctx().message_wait(tags2, self.body).await + } else { + self.ctx.msg_ctx().message(tags2, self.body).await + } + }, + ); + + msg.map_err(GlobalError::raw)?; + write.map_err(GlobalError::raw)?; + } + + // Move to next event + self.ctx.inc_location(); + + Ok(()) + } +} diff --git a/lib/chirp-workflow/core/src/builder/workflow/mod.rs b/lib/chirp-workflow/core/src/builder/workflow/mod.rs new file mode 100644 index 0000000000..797702456b --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/workflow/mod.rs @@ -0,0 +1,5 @@ +//! This module contains builders used specifically by the workflow ctx. + +pub mod message; +pub mod signal; +pub mod sub_workflow; diff --git a/lib/chirp-workflow/core/src/builder/workflow/signal.rs b/lib/chirp-workflow/core/src/builder/workflow/signal.rs new file mode 100644 index 0000000000..c06489e262 --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/workflow/signal.rs @@ -0,0 +1,158 @@ +use std::fmt::Display; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; +use uuid::Uuid; + +use crate::{ + builder::BuilderError, ctx::WorkflowCtx, error::WorkflowError, event::Event, signal::Signal, +}; + +pub struct SignalBuilder<'a, T: Signal + Serialize> { + ctx: &'a mut WorkflowCtx, + body: T, + to_workflow_id: Option, + tags: serde_json::Map, + error: Option, +} + +impl<'a, T: Signal + Serialize> SignalBuilder<'a, T> { + pub(crate) fn new(ctx: &'a mut WorkflowCtx, body: T) -> Self { + SignalBuilder { + ctx, + body, + to_workflow_id: None, + tags: serde_json::Map::new(), + error: None, + } + } + + pub fn to_workflow(mut self, workflow_id: Uuid) -> Self { + if self.error.is_some() { + return self; + } + + self.to_workflow_id = Some(workflow_id); + + self + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn send(self) -> GlobalResult { + if let Some(err) = self.error { + return Err(err); + } + + let event = self.ctx.current_history_event(); + + // Signal sent before + if let Some(event) = event { + // Validate history is consistent + let Event::SignalSend(signal) = event else { + return Err(WorkflowError::HistoryDiverged(format!( + "expected {event} at {}, found signal send {}", + self.ctx.loc(), + T::NAME + ))) + .map_err(GlobalError::raw); + }; + + if signal.name != T::NAME { + return Err(WorkflowError::HistoryDiverged(format!( + "expected {event} at {}, found signal send {}", + self.ctx.loc(), + T::NAME + ))) + .map_err(GlobalError::raw); + } + + tracing::debug!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch"); + + Ok(signal.signal_id) + } + // Send signal + else { + let signal_id = Uuid::new_v4(); + + // Serialize input + let input_val = serde_json::to_value(&self.body) + .map_err(WorkflowError::SerializeSignalBody) + .map_err(GlobalError::raw)?; + + match (self.to_workflow_id, self.tags.is_empty()) { + (Some(workflow_id), true) => { + tracing::info!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal"); + + self.ctx + .db() + .publish_signal_from_workflow( + self.ctx.workflow_id(), + self.ctx.full_location().as_ref(), + self.ctx.ray_id(), + workflow_id, + signal_id, + T::NAME, + input_val, + self.ctx.loop_location(), + ) + .await + .map_err(GlobalError::raw)?; + } + (None, false) => { + tracing::info!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), signal_name=%T::NAME, tags=?self.tags, %signal_id, "dispatching tagged signal"); + + self.ctx + .db() + .publish_tagged_signal_from_workflow( + self.ctx.workflow_id(), + self.ctx.full_location().as_ref(), + self.ctx.ray_id(), + &serde_json::Value::Object(self.tags), + signal_id, + T::NAME, + input_val, + self.ctx.loop_location(), + ) + .await + .map_err(GlobalError::raw)?; + } + (Some(_), false) => return Err(BuilderError::WorkflowIdAndTags.into()), + (None, true) => return Err(BuilderError::NoWorkflowIdOrTags.into()), + } + + // Move to next event + self.ctx.inc_location(); + + Ok(signal_id) + } + } +} diff --git a/lib/chirp-workflow/core/src/builder/workflow/sub_workflow.rs b/lib/chirp-workflow/core/src/builder/workflow/sub_workflow.rs new file mode 100644 index 0000000000..c0cfe46f3f --- /dev/null +++ b/lib/chirp-workflow/core/src/builder/workflow/sub_workflow.rs @@ -0,0 +1,171 @@ +use std::{fmt::Display, sync::Arc}; + +use global_error::{GlobalError, GlobalResult}; +use serde::Serialize; +use uuid::Uuid; + +use crate::{ + builder::BuilderError, + ctx::WorkflowCtx, + error::WorkflowError, + workflow::{Workflow, WorkflowInput}, +}; + +pub struct SubWorkflowBuilder<'a, I: WorkflowInput> { + ctx: &'a mut WorkflowCtx, + input: I, + tags: serde_json::Map, + error: Option, +} + +impl<'a, I: WorkflowInput> SubWorkflowBuilder<'a, I> +where + ::Workflow: Workflow, +{ + pub(crate) fn new(ctx: &'a mut WorkflowCtx, input: I) -> Self { + SubWorkflowBuilder { + ctx, + input, + tags: serde_json::Map::new(), + error: None, + } + } + + pub fn tags(mut self, tags: serde_json::Value) -> Self { + if self.error.is_some() { + return self; + } + + match tags { + serde_json::Value::Object(map) => { + self.tags.extend(map); + } + _ => self.error = Some(BuilderError::TagsNotMap.into()), + } + + self + } + + pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { + if self.error.is_some() { + return self; + } + + match serde_json::to_value(&v) { + Ok(v) => { + self.tags.insert(k.to_string(), v); + } + Err(err) => self.error = Some(err.into()), + } + + self + } + + pub async fn dispatch(self) -> GlobalResult { + if let Some(err) = self.error { + return Err(err); + } + + let sub_workflow_name = I::Workflow::NAME; + let sub_workflow_id = Uuid::new_v4(); + + let no_tags = self.tags.is_empty(); + let tags = serde_json::Value::Object(self.tags); + let tags = if no_tags { None } else { Some(&tags) }; + + tracing::info!( + name=%self.ctx.name(), + id=%self.ctx.workflow_id(), + %sub_workflow_name, + %sub_workflow_id, + ?tags, + input=?self.input, + "dispatching sub workflow" + ); + + // Serialize input + let input_val = serde_json::to_value(&self.input) + .map_err(WorkflowError::SerializeWorkflowOutput) + .map_err(GlobalError::raw)?; + + self.ctx + .db() + .dispatch_sub_workflow( + self.ctx.ray_id(), + self.ctx.workflow_id(), + self.ctx.full_location().as_ref(), + sub_workflow_id, + &sub_workflow_name, + tags, + input_val, + self.ctx.loop_location(), + ) + .await + .map_err(GlobalError::raw)?; + + tracing::info!( + name=%self.ctx.name(), + id=%self.ctx.workflow_id(), + %sub_workflow_name, + ?sub_workflow_id, + "sub workflow dispatched" + ); + + Ok(sub_workflow_id) + } + + pub async fn output( + self, + ) -> GlobalResult<<::Workflow as Workflow>::Output> { + if let Some(err) = self.error { + return Err(err); + } + + let no_tags = self.tags.is_empty(); + let tags = serde_json::Value::Object(self.tags); + let tags = if no_tags { None } else { Some(&tags) }; + + // Lookup workflow + let Ok(workflow) = self.ctx.registry().get_workflow(I::Workflow::NAME) else { + tracing::warn!( + name=%self.ctx.name(), + id=%self.ctx.workflow_id(), + sub_workflow_name=%I::Workflow::NAME, + "sub workflow not found in current registry", + ); + + // TODO(RVT-3755): If a sub workflow is dispatched, then the worker is updated to include the sub + // worker in the registry, this will diverge in history because it will try to run the sub worker + // in-process during the replay + // If the workflow isn't in the current registry, dispatch the workflow instead + let sub_workflow_id = self.ctx.dispatch_workflow_inner(tags, self.input).await?; + let output = self + .ctx + .wait_for_workflow::(sub_workflow_id) + .await?; + + return Ok(output); + }; + + tracing::info!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), sub_workflow_name=%I::Workflow::NAME, "running sub workflow"); + + // TODO(RVT-3756): This is redundant with the deserialization in `workflow.run` in the registry + // Create a new branched workflow context for the sub workflow + let mut ctx = self + .ctx + .with_input(Arc::new(serde_json::to_value(&self.input)?)); + + // Run workflow + let output = (workflow.run)(&mut ctx).await.map_err(GlobalError::raw)?; + + // TODO: RVT-3756 + // Deserialize output + let output = serde_json::from_value(output) + .map_err(WorkflowError::DeserializeWorkflowOutput) + .map_err(GlobalError::raw)?; + + self.ctx.inc_location(); + + Ok(output) + } +} diff --git a/lib/chirp-workflow/core/src/compat.rs b/lib/chirp-workflow/core/src/compat.rs index e086770684..ae84824ed5 100644 --- a/lib/chirp-workflow/core/src/compat.rs +++ b/lib/chirp-workflow/core/src/compat.rs @@ -1,5 +1,4 @@ -// Forwards compatibility from old operation ctx to new workflows - +/// Forwards compatibility from old operation ctx to new workflows. use std::fmt::Debug; use global_error::prelude::*; @@ -7,121 +6,32 @@ use serde::Serialize; use uuid::Uuid; use crate::{ + builder::common as builder, + builder::BuilderError, ctx::{ api::WORKFLOW_TIMEOUT, + common, message::{MessageCtx, SubscriptionHandle}, - workflow::SUB_WORKFLOW_RETRY, - OperationCtx, }, db::{DatabaseHandle, DatabasePgNats}, - error::WorkflowError, message::Message, operation::{Operation, OperationInput}, signal::Signal, workflow::{Workflow, WorkflowInput}, }; -pub async fn dispatch_workflow( - ctx: &rivet_operation::OperationContext, - input: I, -) -> GlobalResult -where - I: WorkflowInput, - ::Workflow: Workflow, - B: Debug + Clone, -{ - if ctx.from_workflow { - bail!("cannot dispatch a workflow from an operation within a workflow execution. trigger it from the workflow's body."); - } - - let workflow_name = I::Workflow::NAME; - let workflow_id = Uuid::new_v4(); - - tracing::info!(%workflow_name, %workflow_id, ?input, "dispatching workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - db_from_ctx(ctx) - .await? - .dispatch_workflow(ctx.ray_id(), workflow_id, &workflow_name, None, input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(%workflow_name, ?workflow_id, "workflow dispatched"); - - Ok(workflow_id) -} - -pub async fn dispatch_tagged_workflow( - ctx: &rivet_operation::OperationContext, - tags: &serde_json::Value, - input: I, -) -> GlobalResult -where - I: WorkflowInput, - ::Workflow: Workflow, - B: Debug + Clone, -{ - if ctx.from_workflow { - bail!("cannot dispatch a workflow from an operation within a workflow execution. trigger it from the workflow's body."); - } - - let workflow_name = I::Workflow::NAME; - let workflow_id = Uuid::new_v4(); - - tracing::info!(%workflow_name, %workflow_id, ?input, "dispatching tagged workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - db_from_ctx(ctx) - .await? - .dispatch_workflow( - ctx.ray_id(), - workflow_id, - &workflow_name, - Some(tags), - input_val, - ) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(%workflow_name, ?workflow_id, "workflow tagged dispatched"); - - Ok(workflow_id) -} - /// Wait for a given workflow to complete. /// 60 second timeout. pub async fn wait_for_workflow( ctx: &rivet_operation::OperationContext, workflow_id: Uuid, ) -> GlobalResult { - tracing::info!(sub_workflow_name=W::NAME, sub_workflow_id=?workflow_id, "waiting for workflow"); - - tokio::time::timeout(WORKFLOW_TIMEOUT, async move { - let mut interval = tokio::time::interval(SUB_WORKFLOW_RETRY); - loop { - interval.tick().await; + let db = db_from_ctx(ctx).await?; - // Check if state finished - let workflow = db_from_ctx(ctx) - .await? - .get_workflow(workflow_id) - .await - .map_err(GlobalError::raw)? - .ok_or(WorkflowError::WorkflowNotFound) - .map_err(GlobalError::raw)?; - if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { - return Ok(output); - } - } - }) + tokio::time::timeout( + WORKFLOW_TIMEOUT, + common::wait_for_workflow::(&db, workflow_id), + ) .await? } @@ -129,85 +39,37 @@ pub async fn wait_for_workflow( pub async fn workflow( ctx: &rivet_operation::OperationContext, input: I, -) -> GlobalResult<<::Workflow as Workflow>::Output> +) -> GlobalResult> where I: WorkflowInput, ::Workflow: Workflow, B: Debug + Clone, { - let workflow_id = dispatch_workflow(ctx, input).await?; - - wait_for_workflow::(ctx, workflow_id).await -} - -/// Dispatch a new workflow and wait for it to complete. Has a 60s timeout. -pub async fn tagged_workflow( - ctx: &rivet_operation::OperationContext, - tags: &serde_json::Value, - input: I, -) -> GlobalResult<<::Workflow as Workflow>::Output> -where - I: WorkflowInput, - ::Workflow: Workflow, - B: Debug + Clone, -{ - let workflow_id = dispatch_tagged_workflow(ctx, tags, input).await?; - - wait_for_workflow::(ctx, workflow_id).await -} - -pub async fn signal( - ctx: &rivet_operation::OperationContext, - workflow_id: Uuid, - input: I, -) -> GlobalResult { if ctx.from_workflow { - bail!("cannot dispatch a signal from an operation within a workflow execution. trigger it from the workflow's body."); + return Err(BuilderError::CannotDispatchFromOpInWorkflow.into()); } - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%I::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal"); + let db = db_from_ctx(ctx).await?; - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - db_from_ctx(ctx) - .await? - .publish_signal(ctx.ray_id(), workflow_id, signal_id, I::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + Ok(builder::workflow::WorkflowBuilder::new( + db, + ctx.ray_id(), + input, + )) } -pub async fn tagged_signal( +/// Creates a signal builder. +pub async fn signal( ctx: &rivet_operation::OperationContext, - tags: &serde_json::Value, - input: I, -) -> GlobalResult { + body: T, +) -> GlobalResult> { if ctx.from_workflow { - bail!("cannot dispatch a signal from an operation within a workflow execution. trigger it from the workflow's body."); + return Err(BuilderError::CannotDispatchFromOpInWorkflow.into()); } - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%I::NAME, ?tags, %signal_id, "dispatching tagged signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; + let db = db_from_ctx(ctx).await?; - db_from_ctx(ctx) - .await? - .publish_tagged_signal(ctx.ray_id(), tags, signal_id, I::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + Ok(builder::signal::SignalBuilder::new(db, ctx.ray_id(), body)) } #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] @@ -220,25 +82,16 @@ where ::Operation: Operation, B: Debug + Clone, { - tracing::info!(?input, "operation call"); - - let ctx = OperationCtx::new( - db_from_ctx(ctx).await?, + let db = db_from_ctx(ctx).await?; + common::op( + &db, ctx.conn(), ctx.ray_id(), ctx.req_ts(), ctx.from_workflow(), - I::Operation::NAME, - ); - - let res = I::Operation::run(&ctx, &input) - .await - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(?res, "operation response"); - - res + input, + ) + .await } pub async fn subscribe( @@ -256,7 +109,7 @@ where msg_ctx.subscribe::(tags).await.map_err(GlobalError::raw) } -// Get crdb pool as a trait object +// Get pool as a trait object async fn db_from_ctx( ctx: &rivet_operation::OperationContext, ) -> GlobalResult { diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs index ddb8b8020c..07b82c987d 100644 --- a/lib/chirp-workflow/core/src/ctx/activity.rs +++ b/lib/chirp-workflow/core/src/ctx/activity.rs @@ -3,9 +3,8 @@ use rivet_pools::prelude::*; use uuid::Uuid; use crate::{ - ctx::OperationCtx, + ctx::common, db::DatabaseHandle, - error::WorkflowError, operation::{Operation, OperationInput}, }; @@ -70,26 +69,15 @@ impl ActivityCtx { I: OperationInput, ::Operation: Operation, { - tracing::info!(activity_name=%self.name, ?input, "operation call"); - - let ctx = OperationCtx::new( - self.db.clone(), + common::op( + &self.db, &self.conn, self.ray_id, self.op_ctx.req_ts(), true, - I::Operation::NAME, - ); - - let res = tokio::time::timeout(I::Operation::TIMEOUT, I::Operation::run(&ctx, &input)) - .await - .map_err(|_| WorkflowError::OperationTimeout)? - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(activity_name=%self.name, ?res, "operation response"); - - res + input, + ) + .await } pub async fn update_workflow_tags(&self, tags: &serde_json::Value) -> GlobalResult<()> { diff --git a/lib/chirp-workflow/core/src/ctx/api.rs b/lib/chirp-workflow/core/src/ctx/api.rs index 33a1115b1a..699487b4f2 100644 --- a/lib/chirp-workflow/core/src/ctx/api.rs +++ b/lib/chirp-workflow/core/src/ctx/api.rs @@ -6,13 +6,12 @@ use serde::Serialize; use uuid::Uuid; use crate::{ + builder::common as builder, ctx::{ + common, message::{MessageCtx, SubscriptionHandle, TailAnchor, TailAnchorResponse}, - workflow::SUB_WORKFLOW_RETRY, - OperationCtx, }, db::DatabaseHandle, - error::WorkflowError, error::WorkflowResult, message::{Message, ReceivedMessage}, operation::{Operation, OperationInput}, @@ -71,158 +70,27 @@ impl ApiCtx { } impl ApiCtx { - pub async fn dispatch_workflow(&self, input: I) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?input, "dispatching workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, None, input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "workflow dispatched"); - - Ok(id) - } - - pub async fn dispatch_tagged_workflow( - &self, - tags: &serde_json::Value, - input: I, - ) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?tags, ?input, "dispatching tagged workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, Some(tags), input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "tagged workflow dispatched"); - - Ok(id) - } - /// Wait for a given workflow to complete. /// 60 second timeout. pub async fn wait_for_workflow( &self, workflow_id: Uuid, ) -> GlobalResult { - tracing::info!(workflow_name=%W::NAME, %workflow_id, "waiting for workflow"); - - tokio::time::timeout(WORKFLOW_TIMEOUT, async move { - let mut interval = tokio::time::interval(SUB_WORKFLOW_RETRY); - loop { - interval.tick().await; - - // Check if state finished - let workflow = self - .db - .get_workflow(workflow_id) - .await - .map_err(GlobalError::raw)? - .ok_or(WorkflowError::WorkflowNotFound) - .map_err(GlobalError::raw)?; - if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { - return Ok(output); - } - } - }) - .await? + common::wait_for_workflow::(&self.db, workflow_id).await } - /// Dispatch a new workflow and wait for it to complete. Has a 60s timeout. - pub async fn workflow( - &self, - input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> + /// Creates a workflow builder. + pub fn workflow(&self, input: I) -> builder::workflow::WorkflowBuilder where I: WorkflowInput, ::Workflow: Workflow, { - let workflow_id = self.dispatch_workflow(input).await?; - self.wait_for_workflow::(workflow_id).await + builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input) } - /// Dispatch a new workflow with tags and wait for it to complete. Has a 60s timeout. - pub async fn tagged_workflow( - &self, - tags: &serde_json::Value, - input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let workflow_id = self.dispatch_tagged_workflow(tags, input).await?; - self.wait_for_workflow::(workflow_id).await - } - - pub async fn signal( - &self, - workflow_id: Uuid, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) - } - - pub async fn tagged_signal( - &self, - tags: &serde_json::Value, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + /// Creates a signal builder. + pub fn signal(&self, body: T) -> builder::signal::SignalBuilder { + builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body) } #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] @@ -234,25 +102,15 @@ impl ApiCtx { I: OperationInput, ::Operation: Operation, { - tracing::info!(?input, "operation call"); - - let ctx = OperationCtx::new( - self.db.clone(), + common::op( + &self.db, &self.conn, self.ray_id, self.op_ctx.req_ts(), false, - I::Operation::NAME, - ); - - let res = I::Operation::run(&ctx, &input) - .await - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(?res, "operation response"); - - res + input, + ) + .await } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/common.rs b/lib/chirp-workflow/core/src/ctx/common.rs new file mode 100644 index 0000000000..6cd2726b37 --- /dev/null +++ b/lib/chirp-workflow/core/src/ctx/common.rs @@ -0,0 +1,74 @@ +use std::time::Duration; + +use global_error::{GlobalError, GlobalResult}; +use uuid::Uuid; + +/// Poll interval when polling for a sub workflow in-process +pub const SUB_WORKFLOW_RETRY: Duration = Duration::from_millis(150); +/// Time to delay a workflow from retrying after an error +pub const RETRY_TIMEOUT_MS: usize = 2000; + +use crate::{ + ctx::OperationCtx, + db::DatabaseHandle, + error::WorkflowError, + operation::{Operation, OperationInput}, + workflow::Workflow, +}; + +/// Polls the database for the workflow +pub async fn wait_for_workflow( + db: &DatabaseHandle, + workflow_id: Uuid, +) -> GlobalResult { + tracing::info!(workflow_name=%W::NAME, %workflow_id, "waiting for workflow"); + + let mut interval = tokio::time::interval(SUB_WORKFLOW_RETRY); + loop { + interval.tick().await; + + // Check if state finished + let workflow = db + .get_workflow(workflow_id) + .await + .map_err(GlobalError::raw)? + .ok_or(WorkflowError::WorkflowNotFound) + .map_err(GlobalError::raw)?; + if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { + return Ok(output); + } + } +} + +pub async fn op( + db: &DatabaseHandle, + conn: &rivet_connection::Connection, + ray_id: Uuid, + req_ts: i64, + from_workflow: bool, + input: I, +) -> GlobalResult<<::Operation as Operation>::Output> +where + I: OperationInput, + ::Operation: Operation, +{ + tracing::info!(?input, "operation call"); + + let ctx = OperationCtx::new( + db.clone(), + conn, + ray_id, + req_ts, + from_workflow, + I::Operation::NAME, + ); + + let res = I::Operation::run(&ctx, &input) + .await + .map_err(WorkflowError::OperationFailure) + .map_err(GlobalError::raw); + + tracing::info!(?res, "operation response"); + + res +} diff --git a/lib/chirp-workflow/core/src/ctx/listen.rs b/lib/chirp-workflow/core/src/ctx/listen.rs index d38f319333..b7832825e1 100644 --- a/lib/chirp-workflow/core/src/ctx/listen.rs +++ b/lib/chirp-workflow/core/src/ctx/listen.rs @@ -19,7 +19,7 @@ impl<'a> ListenCtx<'a> { // Fetch new pending signal let signal = self .ctx - .db + .db() .pull_next_signal( self.ctx.workflow_id(), signal_names, diff --git a/lib/chirp-workflow/core/src/ctx/mod.rs b/lib/chirp-workflow/core/src/ctx/mod.rs index 97d6695cc8..fc5774cb80 100644 --- a/lib/chirp-workflow/core/src/ctx/mod.rs +++ b/lib/chirp-workflow/core/src/ctx/mod.rs @@ -1,6 +1,7 @@ mod activity; pub(crate) mod api; mod backfill; +pub(crate) mod common; mod listen; pub mod message; mod operation; diff --git a/lib/chirp-workflow/core/src/ctx/operation.rs b/lib/chirp-workflow/core/src/ctx/operation.rs index 55041efac0..923bc455c6 100644 --- a/lib/chirp-workflow/core/src/ctx/operation.rs +++ b/lib/chirp-workflow/core/src/ctx/operation.rs @@ -1,11 +1,12 @@ -use global_error::{GlobalError, GlobalResult}; +use global_error::GlobalResult; use rivet_pools::prelude::*; use serde::Serialize; use uuid::Uuid; use crate::{ + builder::common as builder, + ctx::common, db::DatabaseHandle, - error::WorkflowError, operation::{Operation, OperationInput}, signal::Signal, }; @@ -69,69 +70,20 @@ impl OperationCtx { I: OperationInput, ::Operation: Operation, { - tracing::info!(?input, "operation call"); - - let ctx = OperationCtx::new( - self.db.clone(), + common::op( + &self.db, &self.conn, self.ray_id, self.op_ctx.req_ts(), self.op_ctx.from_workflow(), - I::Operation::NAME, - ); - - let res = I::Operation::run(&ctx, &input) - .await - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(?res, "operation response"); - - res + input, + ) + .await } - pub async fn signal( - &self, - workflow_id: Uuid, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, %workflow_id, %signal_id, "dispatching signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) - } - - pub async fn tagged_signal( - &self, - tags: &serde_json::Value, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + /// Creates a signal builder. + pub fn signal(&self, body: T) -> builder::signal::SignalBuilder { + builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body) } } diff --git a/lib/chirp-workflow/core/src/ctx/standalone.rs b/lib/chirp-workflow/core/src/ctx/standalone.rs index dee5c021a3..705176a56e 100644 --- a/lib/chirp-workflow/core/src/ctx/standalone.rs +++ b/lib/chirp-workflow/core/src/ctx/standalone.rs @@ -4,14 +4,13 @@ use serde::Serialize; use uuid::Uuid; use crate::{ + builder::common as builder, ctx::{ - api::WORKFLOW_TIMEOUT, + common, message::{SubscriptionHandle, TailAnchor, TailAnchorResponse}, - workflow::SUB_WORKFLOW_RETRY, - MessageCtx, OperationCtx, + MessageCtx, }, db::DatabaseHandle, - error::WorkflowError, error::WorkflowResult, message::{Message, ReceivedMessage}, operation::{Operation, OperationInput}, @@ -70,158 +69,27 @@ impl StandaloneCtx { } impl StandaloneCtx { - pub async fn dispatch_workflow(&self, input: I) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?input, "dispatching workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, None, input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "workflow dispatched"); - - Ok(id) - } - - pub async fn dispatch_tagged_workflow( - &self, - tags: &serde_json::Value, - input: I, - ) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?tags, ?input, "dispatching tagged workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, Some(tags), input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "workflow dispatched"); - - Ok(id) - } - /// Wait for a given workflow to complete. /// 60 second timeout. pub async fn wait_for_workflow( &self, workflow_id: Uuid, ) -> GlobalResult { - tracing::info!(workflow_name=%W::NAME, id=?workflow_id, "waiting for workflow"); - - tokio::time::timeout(WORKFLOW_TIMEOUT, async move { - let mut interval = tokio::time::interval(SUB_WORKFLOW_RETRY); - loop { - interval.tick().await; - - // Check if state finished - let workflow = self - .db - .get_workflow(workflow_id) - .await - .map_err(GlobalError::raw)? - .ok_or(WorkflowError::WorkflowNotFound) - .map_err(GlobalError::raw)?; - if let Some(output) = workflow.parse_output::().map_err(GlobalError::raw)? { - return Ok(output); - } - } - }) - .await? + common::wait_for_workflow::(&self.db, workflow_id).await } - /// Dispatch a new workflow and wait for it to complete. Has a 60s timeout. - pub async fn workflow( - &self, - input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> + /// Creates a workflow builder. + pub fn workflow(&self, input: I) -> builder::workflow::WorkflowBuilder where I: WorkflowInput, ::Workflow: Workflow, { - let workflow_id = self.dispatch_workflow(input).await?; - self.wait_for_workflow::(workflow_id).await + builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input) } - /// Dispatch a new workflow with tags and wait for it to complete. Has a 60s timeout. - pub async fn tagged_workflow( - &self, - tags: &serde_json::Value, - input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let workflow_id = self.dispatch_tagged_workflow(tags, input).await?; - self.wait_for_workflow::(workflow_id).await - } - - pub async fn signal( - &self, - workflow_id: Uuid, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, %workflow_id, %signal_id, "dispatching signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) - } - - pub async fn tagged_signal( - &self, - tags: &serde_json::Value, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + /// Creates a signal builder. + pub fn signal(&self, body: T) -> builder::signal::SignalBuilder { + builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body) } #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] @@ -233,25 +101,15 @@ impl StandaloneCtx { I: OperationInput, ::Operation: Operation, { - tracing::info!(?input, "operation call"); - - let ctx = OperationCtx::new( - self.db.clone(), + common::op( + &self.db, &self.conn, self.ray_id, self.op_ctx.req_ts(), false, - I::Operation::NAME, - ); - - let res = I::Operation::run(&ctx, &input) - .await - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(?res, "operation response"); - - res + input, + ) + .await } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index aa161fda9e..7585824c3a 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -5,10 +5,11 @@ use tokio::time::Duration; use uuid::Uuid; use crate::{ + builder::common as builder, ctx::{ + common::{self, SUB_WORKFLOW_RETRY}, message::{SubscriptionHandle, TailAnchor, TailAnchorResponse}, - workflow::SUB_WORKFLOW_RETRY, - MessageCtx, OperationCtx, + MessageCtx, }, db::{DatabaseHandle, DatabasePgNats}, error::WorkflowError, @@ -87,60 +88,6 @@ impl TestCtx { } impl TestCtx { - pub async fn dispatch_workflow(&self, input: I) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?input, "dispatching workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, None, input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "workflow dispatched"); - - Ok(id) - } - - pub async fn dispatch_tagged_workflow( - &self, - tags: &serde_json::Value, - input: I, - ) -> GlobalResult - where - I: WorkflowInput, - ::Workflow: Workflow, - { - let name = I::Workflow::NAME; - let id = Uuid::new_v4(); - - tracing::info!(workflow_name=%name, workflow_id=%id, ?tags, ?input, "dispatching tagged workflow"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.db - .dispatch_workflow(self.ray_id, id, &name, Some(tags), input_val) - .await - .map_err(GlobalError::raw)?; - - tracing::info!(workflow_name=%name, workflow_id=%id, "workflow dispatched"); - - Ok(id) - } - pub async fn wait_for_workflow( &self, workflow_id: Uuid, @@ -165,75 +112,30 @@ impl TestCtx { } } - pub async fn workflow( - &self, - input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> + /// Creates a workflow builder. + pub fn workflow(&self, input: I) -> builder::workflow::WorkflowBuilder where I: WorkflowInput, ::Workflow: Workflow, { - let workflow_id = self.dispatch_workflow(input).await?; - let output = self.wait_for_workflow::(workflow_id).await?; - Ok(output) + builder::workflow::WorkflowBuilder::new(self.db.clone(), self.ray_id, input) } - pub async fn tagged_workflow( + /// Creates a signal builder. + pub fn signal(&self, body: T) -> builder::signal::SignalBuilder { + builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body) + } + + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] + pub async fn op( &self, - tags: &serde_json::Value, input: I, - ) -> GlobalResult<<::Workflow as Workflow>::Output> + ) -> GlobalResult<<::Operation as Operation>::Output> where - I: WorkflowInput, - ::Workflow: Workflow, + I: OperationInput, + ::Operation: Operation, { - let workflow_id = self.dispatch_tagged_workflow(tags, input).await?; - let output = self.wait_for_workflow::(workflow_id).await?; - Ok(output) - } - - pub async fn signal( - &self, - workflow_id: Uuid, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, %workflow_id, %signal_id, "dispatching signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) - } - - pub async fn tagged_signal( - &self, - tags: &serde_json::Value, - input: T, - ) -> GlobalResult { - let signal_id = Uuid::new_v4(); - - tracing::info!(signal_name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal"); - - // Serialize input - let input_val = serde_json::to_value(input) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - self.db - .publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val) - .await - .map_err(GlobalError::raw)?; - - Ok(signal_id) + common::op(&self.db, &self.conn, self.ray_id, self.ts, false, input).await } /// Waits for a workflow to be triggered with a superset of given input. Strictly for tests. @@ -251,54 +153,11 @@ impl TestCtx { }) } - #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] - pub async fn op( - &self, - input: I, - ) -> GlobalResult<<::Operation as Operation>::Output> - where - I: OperationInput, - ::Operation: Operation, - { - tracing::info!(?input, "operation call"); - - let ctx = OperationCtx::new( - self.db.clone(), - &self.conn, - self.ray_id, - self.ts, - false, - I::Operation::NAME, - ); - - let res = I::Operation::run(&ctx, &input) - .await - .map_err(WorkflowError::OperationFailure) - .map_err(GlobalError::raw); - - tracing::info!(?res, "operation response"); - - res - } - - pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> - where - M: Message, - { - self.msg_ctx - .message(tags, body) - .await - .map_err(GlobalError::raw) - } - - pub async fn msg_wait(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> + pub async fn msg(&self, body: M) -> builder::message::MessageBuilder where M: Message, { - self.msg_ctx - .message_wait(tags, body) - .await - .map_err(GlobalError::raw) + builder::message::MessageBuilder::new(&self.msg_ctx, body) } pub async fn subscribe( diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index e4c4296032..639d5e41d3 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -8,7 +8,11 @@ use uuid::Uuid; use crate::{ activity::ActivityId, activity::{Activity, ActivityInput}, - ctx::{ActivityCtx, ListenCtx, MessageCtx}, + builder::workflow as builder, + ctx::{ + common::{RETRY_TIMEOUT_MS, SUB_WORKFLOW_RETRY}, + ActivityCtx, ListenCtx, MessageCtx, + }, db::{DatabaseHandle, PulledWorkflow}, error::{WorkflowError, WorkflowResult}, event::Event, @@ -27,19 +31,15 @@ use crate::{ workflow::{Workflow, WorkflowInput}, }; -// Time to delay a workflow from retrying after an error -pub const RETRY_TIMEOUT_MS: usize = 2000; -// Poll interval when polling for signals in-process +/// Poll interval when polling for signals in-process const SIGNAL_RETRY: Duration = Duration::from_millis(100); -// Most in-process signal poll tries +/// Most in-process signal poll tries const MAX_SIGNAL_RETRIES: usize = 16; -// Poll interval when polling for a sub workflow in-process -pub const SUB_WORKFLOW_RETRY: Duration = Duration::from_millis(150); -// Most in-process sub workflow poll tries +/// Most in-process sub workflow poll tries const MAX_SUB_WORKFLOW_RETRIES: usize = 4; -// Retry interval for failed db actions +/// Retry interval for failed db actions const DB_ACTION_RETRY: Duration = Duration::from_millis(150); -// Most db action retries +/// Most db action retries const MAX_DB_ACTION_RETRIES: usize = 5; // TODO: Use generics to store input instead of a json value @@ -54,7 +54,7 @@ pub struct WorkflowCtx { ray_id: Uuid, registry: RegistryHandle, - pub(crate) db: DatabaseHandle, + db: DatabaseHandle, conn: rivet_connection::Connection, @@ -62,7 +62,7 @@ pub struct WorkflowCtx { /// The reason this type is a hashmap is to allow querying by location. event_history: Arc>>, /// Input data passed to this workflow. - pub(crate) input: Arc, + input: Arc, root_location: Location, location_idx: usize, @@ -109,7 +109,16 @@ impl WorkflowCtx { /// Creates a new workflow run with one more depth in the location. Meant to be implemented and not used /// directly in workflows. pub fn branch(&mut self) -> Self { - let branch = WorkflowCtx { + let branch = self.with_input(self.input.clone()); + + self.inc_location(); + + branch + } + + /// Clones the current ctx but with a different input. + pub(crate) fn with_input(&self, input: Arc) -> Self { + WorkflowCtx { workflow_id: self.workflow_id, name: self.name.clone(), create_ts: self.create_ts, @@ -122,7 +131,7 @@ impl WorkflowCtx { conn: self.conn.clone(), event_history: self.event_history.clone(), - input: self.input.clone(), + input, root_location: self .root_location @@ -134,11 +143,7 @@ impl WorkflowCtx { loop_location: self.loop_location.clone(), msg_ctx: self.msg_ctx.clone(), - }; - - self.inc_location(); - - branch + } } /// Like `branch` but it does not add another layer of depth. Meant to be implemented and not used @@ -160,6 +165,11 @@ impl WorkflowCtx { .flatten() } + /// Returns the event at the current location index. + pub(crate) fn current_history_event(&self) -> Option<&Event> { + self.relevant_history().nth(self.location_idx) + } + pub(crate) fn full_location(&self) -> Location { self.root_location .iter() @@ -172,15 +182,6 @@ impl WorkflowCtx { self.location_idx += 1; } - pub(crate) fn loop_location(&self) -> Option<&[usize]> { - self.loop_location.as_deref() - } - - /// For debugging, pretty prints the current location - fn loc(&self) -> String { - util::format_location(&self.full_location()) - } - pub(crate) async fn run(mut self) -> WorkflowResult<()> { tracing::info!(name=%self.name, id=%self.workflow_id, "running workflow"); @@ -190,7 +191,7 @@ impl WorkflowCtx { // Run workflow match (workflow.run)(&mut self).await { Ok(output) => { - tracing::info!(name=%self.name, id=%self.workflow_id, "workflow success"); + tracing::info!(name=%self.name, id=%self.workflow_id, "workflow completed"); let mut retries = 0; let mut interval = tokio::time::interval(DB_ACTION_RETRY); @@ -394,7 +395,8 @@ impl WorkflowCtx { } impl WorkflowCtx { - async fn dispatch_workflow_inner( + /// Used internally to dispatch a workflow. Use `WorkflowCtx::workflow` instead. + pub(crate) async fn dispatch_workflow_inner( &mut self, tags: Option<&serde_json::Value>, input: I, @@ -403,7 +405,7 @@ impl WorkflowCtx { I: WorkflowInput, ::Workflow: Workflow, { - let event = self.relevant_history().nth(self.location_idx); + let event = self.current_history_event(); // Signal received before let id = if let Some(event) = event { @@ -523,14 +525,13 @@ impl WorkflowCtx { } } - /// Runs a sub workflow in the same process as the current workflow (if possible) and returns its - /// response. - pub fn workflow(&mut self, input: I) -> builder::SubWorkflowBuilder + /// Creates a sub workflow builder. + pub fn workflow(&mut self, input: I) -> builder::sub_workflow::SubWorkflowBuilder where I: WorkflowInput, ::Workflow: Workflow, { - builder::SubWorkflowBuilder::new(self, input) + builder::sub_workflow::SubWorkflowBuilder::new(self, input) } /// Run activity. Will replay on failure. @@ -544,7 +545,7 @@ impl WorkflowCtx { { let activity_id = ActivityId::new::(&input); - let event = self.relevant_history().nth(self.location_idx); + let event = self.current_history_event(); // Activity was ran before let output = if let Some(event) = event { @@ -672,15 +673,15 @@ impl WorkflowCtx { } } - /// Starts building a signal. - pub fn signal(&mut self, body: T) -> builder::SignalBuilder { - builder::SignalBuilder::new(self, body) + /// Creates a signal builder. + pub fn signal(&mut self, body: T) -> builder::signal::SignalBuilder { + builder::signal::SignalBuilder::new(self, body) } /// Listens for a signal for a short time before setting the workflow to sleep. Once the signal is /// received, the workflow will be woken up and continue. pub async fn listen(&mut self) -> GlobalResult { - let event = self.relevant_history().nth(self.location_idx); + let event = self.current_history_event(); // Signal received before let signal = if let Some(event) = event { @@ -734,7 +735,7 @@ impl WorkflowCtx { &mut self, listener: &T, ) -> GlobalResult<::Output> { - let event = self.relevant_history().nth(self.location_idx); + let event = self.current_history_event(); // Signal received before let signal = if let Some(event) = event { @@ -787,7 +788,7 @@ impl WorkflowCtx { // database so that upon replay it again receives no signal // /// Checks if the given signal exists in the database. // pub async fn query_signal(&mut self) -> GlobalResult> { - // let event = self.relevant_history().nth(self.location_idx); + // let event = self.current_history_event(); // // Signal received before // let signal = if let Some(event) = event { @@ -821,11 +822,11 @@ impl WorkflowCtx { // Ok(signal) // } - pub fn msg(&mut self, body: M) -> builder::MessageBuilder + pub fn msg(&mut self, body: M) -> builder::message::MessageBuilder where M: Message, { - builder::MessageBuilder::new(self, body) + builder::message::MessageBuilder::new(self, body) } /// Runs workflow steps in a loop. **Ensure that there are no side effects caused by the code in this @@ -920,7 +921,7 @@ impl WorkflowCtx { } pub async fn sleep_until(&mut self, time: T) -> GlobalResult<()> { - let event = self.relevant_history().nth(self.location_idx); + let event = self.current_history_event(); // Slept before let (deadline_ts, replay) = if let Some(event) = event { @@ -958,7 +959,7 @@ impl WorkflowCtx { // No-op if duration < 0 { if !replay { - tracing::warn!("tried to sleep for a negative duration"); + tracing::warn!(name=%self.name, id=%self.workflow_id, %duration, "tried to sleep for a negative duration"); } } else if duration < worker::TICK_INTERVAL.as_millis() as i64 + 1 { tracing::info!(name=%self.name, id=%self.workflow_id, %deadline_ts, "sleeping in memory"); @@ -979,6 +980,31 @@ impl WorkflowCtx { } impl WorkflowCtx { + pub(crate) fn registry(&self) -> &RegistryHandle { + &self.registry + } + + pub(crate) fn input(&self) -> &Arc { + &self.input + } + + pub(crate) fn loop_location(&self) -> Option<&[usize]> { + self.loop_location.as_deref() + } + + pub(crate) fn db(&self) -> &DatabaseHandle { + &self.db + } + + pub(crate) fn msg_ctx(&self) -> &MessageCtx { + &self.msg_ctx + } + + /// For debugging, pretty prints the current location + pub(crate) fn loc(&self) -> String { + util::format_location(&self.full_location()) + } + pub fn name(&self) -> &str { &self.name } @@ -987,6 +1013,10 @@ impl WorkflowCtx { self.workflow_id } + pub fn ray_id(&self) -> Uuid { + self.ray_id + } + /// Timestamp at which this workflow run started. pub fn ts(&self) -> i64 { self.ts @@ -1007,502 +1037,3 @@ pub enum Loop { Continue, Break(T), } - -// Tightly ingrained with the workflow ctx -pub mod builder { - use std::{fmt::Display, sync::Arc}; - - use global_error::{GlobalError, GlobalResult}; - use serde::Serialize; - use uuid::Uuid; - - use crate::{ - ctx::WorkflowCtx, - error::WorkflowError, - event::Event, - message::Message, - signal::Signal, - workflow::{Workflow, WorkflowInput}, - }; - - #[derive(thiserror::Error, Debug)] - enum BuilderError { - #[error("tags must be a JSON map")] - TagsNotMap, - #[error("cannot call `to_workflow` and set tags on the same signal")] - WorkflowIdAndTags, - #[error("must call `to_workflow` or set tags on signal")] - NoWorkflowIdOrTags, - } - - pub struct SignalBuilder<'a, T: Signal + Serialize> { - ctx: &'a mut WorkflowCtx, - body: T, - to_workflow_id: Option, - tags: serde_json::Map, - error: Option, - } - - impl<'a, T: Signal + Serialize> SignalBuilder<'a, T> { - pub(crate) fn new(ctx: &'a mut WorkflowCtx, body: T) -> Self { - SignalBuilder { - ctx, - body, - to_workflow_id: None, - tags: serde_json::Map::new(), - error: None, - } - } - - pub fn to_workflow(mut self, workflow_id: Uuid) -> Self { - if self.error.is_some() { - return self; - } - - self.to_workflow_id = Some(workflow_id); - - self - } - - pub fn tags(mut self, tags: serde_json::Value) -> Self { - if self.error.is_some() { - return self; - } - - match tags { - serde_json::Value::Object(map) => { - self.tags.extend(map); - } - _ => self.error = Some(BuilderError::TagsNotMap.into()), - } - - self - } - - pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { - if self.error.is_some() { - return self; - } - - match serde_json::to_value(&v) { - Ok(v) => { - self.tags.insert(k.to_string(), v); - } - Err(err) => self.error = Some(err.into()), - } - - self - } - - pub async fn send(self) -> GlobalResult { - if let Some(err) = self.error { - return Err(err); - } - - let event = self.ctx.relevant_history().nth(self.ctx.location_idx); - - // Signal sent before - if let Some(event) = event { - // Validate history is consistent - let Event::SignalSend(signal) = event else { - return Err(WorkflowError::HistoryDiverged(format!( - "expected {event} at {}, found signal send {}", - self.ctx.loc(), - T::NAME - ))) - .map_err(GlobalError::raw); - }; - - if signal.name != T::NAME { - return Err(WorkflowError::HistoryDiverged(format!( - "expected {event} at {}, found signal send {}", - self.ctx.loc(), - T::NAME - ))) - .map_err(GlobalError::raw); - } - - tracing::debug!(name=%self.ctx.name, id=%self.ctx.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch"); - - Ok(signal.signal_id) - } - // Send signal - else { - let signal_id = Uuid::new_v4(); - - // Serialize input - let input_val = serde_json::to_value(&self.body) - .map_err(WorkflowError::SerializeSignalBody) - .map_err(GlobalError::raw)?; - - match (self.to_workflow_id, self.tags.is_empty()) { - (Some(workflow_id), true) => { - tracing::info!(name=%self.ctx.name, id=%self.ctx.workflow_id, signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal"); - - self.ctx - .db - .publish_signal_from_workflow( - self.ctx.workflow_id, - self.ctx.full_location().as_ref(), - self.ctx.ray_id, - workflow_id, - signal_id, - T::NAME, - input_val, - self.ctx.loop_location(), - ) - .await - .map_err(GlobalError::raw)?; - } - (None, false) => { - tracing::info!(name=%self.ctx.name, id=%self.ctx.workflow_id, signal_name=%T::NAME, tags=?self.tags, %signal_id, "dispatching tagged signal"); - - self.ctx - .db - .publish_tagged_signal_from_workflow( - self.ctx.workflow_id, - self.ctx.full_location().as_ref(), - self.ctx.ray_id, - &serde_json::Value::Object(self.tags), - signal_id, - T::NAME, - input_val, - self.ctx.loop_location(), - ) - .await - .map_err(GlobalError::raw)?; - } - (Some(_), false) => return Err(BuilderError::WorkflowIdAndTags.into()), - (None, true) => return Err(BuilderError::NoWorkflowIdOrTags.into()), - } - - // Move to next event - self.ctx.inc_location(); - - Ok(signal_id) - } - } - } - - pub struct MessageBuilder<'a, M: Message> { - ctx: &'a mut WorkflowCtx, - body: M, - tags: serde_json::Map, - wait: bool, - error: Option, - } - - impl<'a, M: Message> MessageBuilder<'a, M> { - pub(crate) fn new(ctx: &'a mut WorkflowCtx, body: M) -> Self { - MessageBuilder { - ctx, - body, - tags: serde_json::Map::new(), - wait: false, - error: None, - } - } - - pub fn tags(mut self, tags: serde_json::Value) -> Self { - if self.error.is_some() { - return self; - } - - match tags { - serde_json::Value::Object(map) => { - self.tags.extend(map); - } - _ => self.error = Some(BuilderError::TagsNotMap.into()), - } - - self - } - - pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { - if self.error.is_some() { - return self; - } - - match serde_json::to_value(&v) { - Ok(v) => { - self.tags.insert(k.to_string(), v); - } - Err(err) => self.error = Some(err.into()), - } - - self - } - - pub async fn wait(mut self) -> Self { - if self.error.is_some() { - return self; - } - - self.wait = true; - - self - } - - pub async fn send(self) -> GlobalResult<()> { - if let Some(err) = self.error { - return Err(err); - } - - let event = self.ctx.relevant_history().nth(self.ctx.location_idx); - - // Message sent before - if let Some(event) = event { - // Validate history is consistent - let Event::MessageSend(msg) = event else { - return Err(WorkflowError::HistoryDiverged(format!( - "expected {event} at {}, found message send {}", - self.ctx.loc(), - M::NAME, - ))) - .map_err(GlobalError::raw); - }; - - if msg.name != M::NAME { - return Err(WorkflowError::HistoryDiverged(format!( - "expected {event} at {}, found message send {}", - self.ctx.loc(), - M::NAME, - ))) - .map_err(GlobalError::raw); - } - - tracing::debug!(name=%self.ctx.name, id=%self.ctx.workflow_id, msg_name=%msg.name, "replaying message dispatch"); - } - // Send message - else { - tracing::info!(name=%self.ctx.name, id=%self.ctx.workflow_id, msg_name=%M::NAME, tags=?self.tags, "dispatching message"); - - // Serialize body - let body_val = serde_json::to_value(&self.body) - .map_err(WorkflowError::SerializeMessageBody) - .map_err(GlobalError::raw)?; - let location = self.ctx.full_location(); - let tags = serde_json::Value::Object(self.tags); - let tags2 = tags.clone(); - - let (msg, write) = tokio::join!( - async { - self.ctx - .db - .commit_workflow_message_send_event( - self.ctx.workflow_id, - location.as_ref(), - &tags, - M::NAME, - body_val, - self.ctx.loop_location(), - ) - .await - }, - async { - if self.wait { - self.ctx.msg_ctx.message_wait(tags2, self.body).await - } else { - self.ctx.msg_ctx.message(tags2, self.body).await - } - }, - ); - - msg.map_err(GlobalError::raw)?; - write.map_err(GlobalError::raw)?; - } - - // Move to next event - self.ctx.inc_location(); - - Ok(()) - } - } - - pub struct SubWorkflowBuilder<'a, I: WorkflowInput> { - ctx: &'a mut WorkflowCtx, - input: I, - tags: serde_json::Map, - error: Option, - } - - impl<'a, I: WorkflowInput> SubWorkflowBuilder<'a, I> - where - ::Workflow: Workflow, - { - pub(crate) fn new(ctx: &'a mut WorkflowCtx, input: I) -> Self { - SubWorkflowBuilder { - ctx, - input, - tags: serde_json::Map::new(), - error: None, - } - } - - pub fn tags(mut self, tags: serde_json::Value) -> Self { - if self.error.is_some() { - return self; - } - - match tags { - serde_json::Value::Object(map) => { - self.tags.extend(map); - } - _ => self.error = Some(BuilderError::TagsNotMap.into()), - } - - self - } - - pub fn tag(mut self, k: impl Display, v: impl Serialize) -> Self { - if self.error.is_some() { - return self; - } - - match serde_json::to_value(&v) { - Ok(v) => { - self.tags.insert(k.to_string(), v); - } - Err(err) => self.error = Some(err.into()), - } - - self - } - - pub async fn dispatch(self) -> GlobalResult { - if let Some(err) = self.error { - return Err(err); - } - - let sub_workflow_name = I::Workflow::NAME; - let sub_workflow_id = Uuid::new_v4(); - - let no_tags = self.tags.is_empty(); - let tags = serde_json::Value::Object(self.tags); - let tags = if no_tags { None } else { Some(&tags) }; - - tracing::info!( - name=%self.ctx.name, - id=%self.ctx.workflow_id, - %sub_workflow_name, - %sub_workflow_id, - ?tags, - input=?self.input, - "dispatching sub workflow" - ); - - // Serialize input - let input_val = serde_json::to_value(&self.input) - .map_err(WorkflowError::SerializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.ctx - .db - .dispatch_sub_workflow( - self.ctx.ray_id, - self.ctx.workflow_id, - self.ctx.full_location().as_ref(), - sub_workflow_id, - &sub_workflow_name, - tags, - input_val, - self.ctx.loop_location(), - ) - .await - .map_err(GlobalError::raw)?; - - tracing::info!( - name=%self.ctx.name, - id=%self.ctx.workflow_id, - %sub_workflow_name, - ?sub_workflow_id, - "sub workflow dispatched" - ); - - Ok(sub_workflow_id) - } - - pub async fn run( - self, - ) -> GlobalResult<<::Workflow as Workflow>::Output> { - if let Some(err) = self.error { - return Err(err); - } - - let no_tags = self.tags.is_empty(); - let tags = serde_json::Value::Object(self.tags); - let tags = if no_tags { None } else { Some(&tags) }; - - // Lookup workflow - let Ok(workflow) = self.ctx.registry.get_workflow(I::Workflow::NAME) else { - tracing::warn!( - name=%self.ctx.name, - id=%self.ctx.workflow_id, - sub_workflow_name=%I::Workflow::NAME, - "sub workflow not found in current registry", - ); - - // TODO(RVT-3755): If a sub workflow is dispatched, then the worker is updated to include the sub - // worker in the registry, this will diverge in history because it will try to run the sub worker - // in-process during the replay - // If the workflow isn't in the current registry, dispatch the workflow instead - let sub_workflow_id = self.ctx.dispatch_workflow_inner(tags, self.input).await?; - let output = self - .ctx - .wait_for_workflow::(sub_workflow_id) - .await?; - - return Ok(output); - }; - - tracing::info!(name=%self.ctx.name, id=%self.ctx.workflow_id, sub_workflow_name=%I::Workflow::NAME, "running sub workflow"); - - // Create a new branched workflow context for the sub workflow - let mut ctx = WorkflowCtx { - workflow_id: self.ctx.workflow_id, - name: I::Workflow::NAME.to_string(), - create_ts: rivet_util::timestamp::now(), - ts: rivet_util::timestamp::now(), - ray_id: self.ctx.ray_id, - - registry: self.ctx.registry.clone(), - db: self.ctx.db.clone(), - - conn: self - .ctx - .conn - .wrap(Uuid::new_v4(), self.ctx.ray_id, I::Workflow::NAME), - - event_history: self.ctx.event_history.clone(), - - // TODO(RVT-3756): This is redundant with the deserialization in `workflow.run` in the registry - input: Arc::new(serde_json::to_value(&self.input)?), - - root_location: self - .ctx - .root_location - .iter() - .cloned() - .chain(std::iter::once(self.ctx.location_idx)) - .collect(), - location_idx: 0, - loop_location: self.ctx.loop_location.clone(), - - msg_ctx: self.ctx.msg_ctx.clone(), - }; - - // Run workflow - let output = (workflow.run)(&mut ctx).await.map_err(GlobalError::raw)?; - - // TODO: RVT-3756 - // Deserialize output - let output = serde_json::from_value(output) - .map_err(WorkflowError::DeserializeWorkflowOutput) - .map_err(GlobalError::raw)?; - - self.ctx.inc_location(); - - Ok(output) - } - } -} diff --git a/lib/chirp-workflow/core/src/error.rs b/lib/chirp-workflow/core/src/error.rs index f3113e0f67..65205d31a5 100644 --- a/lib/chirp-workflow/core/src/error.rs +++ b/lib/chirp-workflow/core/src/error.rs @@ -4,7 +4,7 @@ use global_error::GlobalError; use tokio::time::Instant; use uuid::Uuid; -use crate::ctx::workflow::RETRY_TIMEOUT_MS; +use crate::ctx::common::RETRY_TIMEOUT_MS; pub type WorkflowResult = Result; diff --git a/lib/chirp-workflow/core/src/lib.rs b/lib/chirp-workflow/core/src/lib.rs index 36871b71f1..e8e16f5eb3 100644 --- a/lib/chirp-workflow/core/src/lib.rs +++ b/lib/chirp-workflow/core/src/lib.rs @@ -1,4 +1,5 @@ pub mod activity; +pub mod builder; pub mod compat; pub mod ctx; pub mod db; diff --git a/lib/chirp-workflow/core/src/registry.rs b/lib/chirp-workflow/core/src/registry.rs index 6a1343ddbc..f0bf1eae95 100644 --- a/lib/chirp-workflow/core/src/registry.rs +++ b/lib/chirp-workflow/core/src/registry.rs @@ -56,7 +56,7 @@ impl Registry { run: |ctx| { async move { // Deserialize input - let input = serde_json::from_value(ctx.input.as_ref().clone()) + let input = serde_json::from_value(ctx.input().as_ref().clone()) .map_err(WorkflowError::DeserializeWorkflowInput)?; // Run workflow