Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/libraries/workflow/GOTCHAS.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,7 @@ the internal location.
> **\*** Even if they did know about each other via atomics, there is no guarantee of consistency from
> `buffer_unordered`. Preemptively incrementing the location ensures consistency regardless of the order or
> completion time of the futures.

## Loops

TODO
9 changes: 9 additions & 0 deletions docs/libraries/workflow/LOOPS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Loops

TODO

## Differences between "Continue As New"

https://docs.temporal.io/develop/go/continue-as-new

TODO
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub trait Activity {
type Output: Serialize + DeserializeOwned + Debug + Send;

const NAME: &'static str;
const MAX_RETRIES: u32;
const MAX_RETRIES: usize;
const TIMEOUT: std::time::Duration;

async fn run(ctx: &ActivityCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
Expand Down
1 change: 1 addition & 0 deletions lib/chirp-workflow/core/src/ctx/listen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl<'a> ListenCtx<'a> {
self.ctx.workflow_id(),
signal_names,
self.ctx.full_location().as_ref(),
self.ctx.loop_location(),
)
.await?;

Expand Down
116 changes: 112 additions & 4 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use global_error::{GlobalError, GlobalResult};
use serde::Serialize;
use serde::{de::DeserializeOwned, Serialize};
use tokio::time::Duration;
use uuid::Uuid;

Expand Down Expand Up @@ -65,6 +65,9 @@ pub struct WorkflowCtx {
root_location: Location,
location_idx: usize,

/// If this context is currently in a loop, this is the location of the where the loop started.
loop_location: Option<Box<[usize]>>,

msg_ctx: MessageCtx,
}

Expand Down Expand Up @@ -95,6 +98,7 @@ impl WorkflowCtx {

root_location: Box::new([]),
location_idx: 0,
loop_location: None,

msg_ctx,
})
Expand Down Expand Up @@ -125,6 +129,7 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,
loop_location: self.loop_location.clone(),

msg_ctx: self.msg_ctx.clone(),
};
Expand Down Expand Up @@ -161,6 +166,10 @@ impl WorkflowCtx {
.collect()
}

pub(crate) fn loop_location(&self) -> Option<&[usize]> {
self.loop_location.as_deref()
}

// Purposefully infallible
pub(crate) async fn run(mut self) {
if let Err(err) = Self::run_inner(&mut self).await {
Expand Down Expand Up @@ -216,7 +225,8 @@ impl WorkflowCtx {
// finish. This workflow will be retried when the sub workflow completes
let wake_sub_workflow = err.sub_workflow();

if deadline_ts.is_some() || !wake_signals.is_empty() || wake_sub_workflow.is_some() {
if deadline_ts.is_some() || !wake_signals.is_empty() || wake_sub_workflow.is_some()
{
tracing::info!(name=%self.name, id=%self.workflow_id, ?err, "workflow sleeping");
} else {
tracing::error!(name=%self.name, id=%self.workflow_id, ?err, "workflow error");
Expand Down Expand Up @@ -299,6 +309,7 @@ impl WorkflowCtx {
create_ts,
input_val,
Ok(output_val),
self.loop_location(),
)
.await?;

Expand All @@ -318,6 +329,7 @@ impl WorkflowCtx {
create_ts,
input_val,
Err(&err.to_string()),
self.loop_location(),
)
.await?;

Expand All @@ -336,6 +348,7 @@ impl WorkflowCtx {
create_ts,
input_val,
Err(&err.to_string()),
self.loop_location(),
)
.await?;

Expand Down Expand Up @@ -437,6 +450,7 @@ impl WorkflowCtx {
&sub_workflow_name,
tags,
input_val,
self.loop_location(),
)
.await
.map_err(GlobalError::raw)?;
Expand Down Expand Up @@ -579,6 +593,7 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,
loop_location: self.loop_location.clone(),

msg_ctx: self.msg_ctx.clone(),
};
Expand Down Expand Up @@ -746,6 +761,7 @@ impl WorkflowCtx {
signal_id,
T::NAME,
input_val,
self.loop_location(),
)
.await
.map_err(GlobalError::raw)?;
Expand Down Expand Up @@ -810,6 +826,7 @@ impl WorkflowCtx {
signal_id,
T::NAME,
input_val,
self.loop_location(),
)
.await
.map_err(GlobalError::raw)?;
Expand Down Expand Up @@ -1005,7 +1022,8 @@ impl WorkflowCtx {
location.as_ref(),
&tags,
M::NAME,
body_val
body_val,
self.loop_location(),
),
self.msg_ctx.message(tags.clone(), body),
);
Expand Down Expand Up @@ -1063,7 +1081,8 @@ impl WorkflowCtx {
location.as_ref(),
&tags,
M::NAME,
body_val
body_val,
self.loop_location(),
),
self.msg_ctx.message_wait(tags.clone(), body),
);
Expand All @@ -1078,6 +1097,90 @@ impl WorkflowCtx {
Ok(())
}

/// Runs workflow steps in a loop. **Ensure that there are no side effects caused by the code in this
/// callback**. If you need side causes or side effects, use a native rust loop.
pub async fn repeat<F, T>(&mut self, mut cb: F) -> GlobalResult<T>
where
F: for<'a> FnMut(&'a mut WorkflowCtx) -> AsyncResult<'a, Loop<T>>,
T: Serialize + DeserializeOwned,
{
let loop_location = self.full_location();
let mut loop_branch = self.branch();

let event = { self.relevant_history().nth(self.location_idx) };

// Loop existed before
let output = if let Some(event) = event {
// Validate history is consistent
let Event::Loop(loop_event) = event else {
return Err(WorkflowError::HistoryDiverged(format!(
"expected {event}, found loop"
)))
.map_err(GlobalError::raw);
};

let output = loop_event.parse_output().map_err(GlobalError::raw)?;

// Shift by iteration count
loop_branch.location_idx = loop_event.iteration;

output
} else {
None
};

// Loop complete
let output = if let Some(output) = output {
tracing::debug!(name=%self.name, id=%self.workflow_id, "replaying loop");

output
}
// Run loop
else {
tracing::info!(name=%self.name, id=%self.workflow_id, "running loop");

loop {
let iteration_idx = loop_branch.location_idx;

let mut iteration_branch = loop_branch.branch();
iteration_branch.loop_location = Some(loop_location.clone());

match cb(&mut iteration_branch).await? {
Loop::Continue => {
self.db
.update_loop(
self.workflow_id,
loop_location.as_ref(),
iteration_idx,
None,
self.loop_location(),
)
.await?;
}
Loop::Break(res) => {
let output_val = serde_json::to_value(&res)
.map_err(WorkflowError::SerializeLoopOutput)
.map_err(GlobalError::raw)?;

self.db
.update_loop(
self.workflow_id,
loop_location.as_ref(),
iteration_idx,
Some(output_val),
self.loop_location(),
)
.await?;

break res;
}
}
}
};

Ok(output)
}

// TODO: sleep_for, sleep_until
}

Expand Down Expand Up @@ -1105,3 +1208,8 @@ impl WorkflowCtx {
self.ts.saturating_sub(self.create_ts)
}
}

pub enum Loop<T> {
Continue,
Break(T),
}
23 changes: 23 additions & 0 deletions lib/chirp-workflow/core/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ pub trait Database: Send {
create_ts: i64,
input: serde_json::Value,
output: Result<serde_json::Value, &str>,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;

async fn pull_next_signal(
&self,
workflow_id: Uuid,
filter: &[&str],
location: &[usize],
loop_location: Option<&[usize]>,
) -> WorkflowResult<Option<SignalRow>>;
async fn publish_signal(
&self,
Expand All @@ -95,6 +97,7 @@ pub trait Database: Send {
signal_id: Uuid,
signal_name: &str,
body: serde_json::Value,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;
async fn publish_tagged_signal_from_workflow(
&self,
Expand All @@ -105,6 +108,7 @@ pub trait Database: Send {
signal_id: Uuid,
signal_name: &str,
body: serde_json::Value,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;

async fn dispatch_sub_workflow(
Expand All @@ -116,6 +120,7 @@ pub trait Database: Send {
sub_workflow_name: &str,
tags: Option<&serde_json::Value>,
input: serde_json::Value,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;

/// Fetches a workflow that has the given json as a subset of its input after the given ts.
Expand All @@ -133,6 +138,16 @@ pub trait Database: Send {
tags: &serde_json::Value,
message_name: &str,
body: serde_json::Value,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;

async fn update_loop(
&self,
workflow_id: Uuid,
location: &[usize],
iteration: usize,
output: Option<serde_json::Value>,
loop_location: Option<&[usize]>,
) -> WorkflowResult<()>;
}

Expand Down Expand Up @@ -222,3 +237,11 @@ pub struct SignalRow {
pub signal_name: String,
pub body: serde_json::Value,
}

#[derive(sqlx::FromRow)]
pub struct LoopEventRow {
pub workflow_id: Uuid,
pub location: Vec<i64>,
pub output: Option<serde_json::Value>,
pub iteration: i64,
}
Loading