Skip to content

Commit

Permalink
fix(workflows): fix listening traits
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jul 12, 2024
1 parent cdba2f3 commit 7d5d59d
Show file tree
Hide file tree
Showing 14 changed files with 202 additions and 79 deletions.
7 changes: 7 additions & 0 deletions docs/libraries/workflow/GOTCHAS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ This will be the current timestamp on the first execution of the activity and wo
## Randomly generated content

Randomly generated content like UUIDs should be placed in activities for consistent history.

## `WorkflowCtx::spawn`

`WorkflowCtx::spawn` allows you to run workflow steps in a different thread and returns its join handle. Be
**very careful** when using it because it is the developers responsibility to make sure it's result is handled
correctly. If a spawn thread errors but its result is not handled, the main thread may continue as though no
error occurred.
10 changes: 8 additions & 2 deletions lib/bolt/core/src/tasks/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,6 @@ struct SshKey {
id: u64,
}

// TODO: This only deletes linodes and firewalls, the ssh key still remains
async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> {
if ctx.ns().rivet.provisioning.is_none() {
return Ok(());
Expand All @@ -584,6 +583,7 @@ async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> {
rivet_term::status::progress("Cleaning up servers", "");

let ns = ctx.ns_id();
let ns_full = format!("rivet-{ns}");

// Create client
let api_token = ctx.read_secret(&["linode", "token"]).await?;
Expand Down Expand Up @@ -642,7 +642,13 @@ async fn cleanup_servers(ctx: &ProjectContext) -> Result<()> {
.data
.into_iter()
// Only delete test objects from this namespace
.filter(|object| object.data.tags.iter().any(|tag| tag == ns))
.filter(|object| {
object
.data
.tags
.iter()
.any(|tag| tag == ns || tag == &ns_full)
})
.map(|object| {
let client = client.clone();
let obj_type = object._type;
Expand Down
11 changes: 9 additions & 2 deletions lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ impl ApiCtx {
Ok(signal_id)
}

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
pub async fn op<I>(
&self,
input: I,
Expand All @@ -231,6 +232,8 @@ impl ApiCtx {
I: OperationInput,
<I as OperationInput>::Operation: Operation<Input = I>,
{
tracing::info!(?input, "operation call");

let ctx = OperationCtx::new(
self.db.clone(),
&self.conn,
Expand All @@ -240,10 +243,14 @@ impl ApiCtx {
I::Operation::NAME,
);

I::Operation::run(&ctx, &input)
let res = I::Operation::run(&ctx, &input)
.await
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
.map_err(GlobalError::raw);

tracing::info!(?res, "operation response");

res
}

pub async fn subscribe<M>(
Expand Down
44 changes: 44 additions & 0 deletions lib/chirp-workflow/core/src/ctx/listen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate::{
ctx::WorkflowCtx,
db::SignalRow,
error::{WorkflowError, WorkflowResult},
};

/// Indirection struct to prevent invalid implementations of listen traits.
pub struct ListenCtx<'a> {
ctx: &'a mut WorkflowCtx,
}

impl<'a> ListenCtx<'a> {
pub(crate) fn new(ctx: &'a mut WorkflowCtx) -> Self {
ListenCtx { ctx }
}

/// Checks for a signal to this workflow with any of the given signal names.
pub async fn listen_any(&self, signal_names: &[&'static str]) -> WorkflowResult<SignalRow> {
// Fetch new pending signal
let signal = self
.ctx
.db
.pull_next_signal(
self.ctx.workflow_id(),
signal_names,
self.ctx.full_location().as_ref(),
)
.await?;

let Some(signal) = signal else {
return Err(WorkflowError::NoSignalFound(Box::from(signal_names)));
};

tracing::info!(
workflow_name=%self.ctx.name(),
workflow_id=%self.ctx.workflow_id(),
signal_id=%signal.signal_id,
signal_name=%signal.signal_name,
"signal received",
);

Ok(signal)
}
}
2 changes: 2 additions & 0 deletions lib/chirp-workflow/core/src/ctx/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
mod activity;
pub(crate) mod api;
mod listen;
pub mod message;
mod operation;
mod standalone;
mod test;
pub(crate) mod workflow;
pub use activity::ActivityCtx;
pub use api::ApiCtx;
pub use listen::ListenCtx;
pub use message::MessageCtx;
pub use operation::OperationCtx;
pub use standalone::StandaloneCtx;
Expand Down
100 changes: 64 additions & 36 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ use uuid::Uuid;

use crate::{
activity::ActivityId,
ctx::{ActivityCtx, MessageCtx},
ctx::{ActivityCtx, ListenCtx, MessageCtx},
event::Event,
executable::{closure, AsyncResult, Executable},
listen::{CustomListener, Listen},
message::Message,
signal::Signal,
util::Location,
Activity, ActivityInput, DatabaseHandle, Listen, PulledWorkflow, RegistryHandle, Signal,
SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
Activity, ActivityInput, DatabaseHandle, PulledWorkflow, RegistryHandle, Workflow,
WorkflowError, WorkflowInput, WorkflowResult,
};

// Time to delay a worker from retrying after an error
Expand Down Expand Up @@ -43,7 +45,7 @@ pub struct WorkflowCtx {
ray_id: Uuid,

registry: RegistryHandle,
db: DatabaseHandle,
pub(crate) db: DatabaseHandle,

conn: rivet_connection::Connection,

Expand Down Expand Up @@ -148,7 +150,7 @@ impl WorkflowCtx {
.flatten()
}

fn full_location(&self) -> Location {
pub(crate) fn full_location(&self) -> Location {
self.root_location
.iter()
.cloned()
Expand Down Expand Up @@ -330,34 +332,6 @@ impl WorkflowCtx {
}
}
}

/// Checks for a signal to this workflow with any of the given signal names. Meant to be implemented and
/// not used directly in workflows.
pub async fn listen_any(&mut self, signal_names: &[&'static str]) -> WorkflowResult<SignalRow> {
// Fetch new pending signal
let signal = self
.db
.pull_next_signal(
self.workflow_id,
signal_names,
self.full_location().as_ref(),
)
.await?;

let Some(signal) = signal else {
return Err(WorkflowError::NoSignalFound(Box::from(signal_names)));
};

tracing::info!(
workflow_name=%self.name,
workflow_id=%self.workflow_id,
signal_id=%signal.signal_id,
signal_name=%signal.signal_name,
"signal received",
);

Ok(signal)
}
}

impl WorkflowCtx {
Expand Down Expand Up @@ -752,7 +726,7 @@ impl WorkflowCtx {
else {
let signal_id = Uuid::new_v4();

tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");
tracing::info!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");

// Serialize input
let input_val = serde_json::to_value(&body)
Expand Down Expand Up @@ -805,10 +779,62 @@ impl WorkflowCtx {
let mut interval = tokio::time::interval(SIGNAL_RETRY);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

let ctx = ListenCtx::new(self);

loop {
interval.tick().await;

match T::listen(self).await {
match T::listen(&ctx).await {
Ok(res) => break res,
Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => {
if retries > MAX_SIGNAL_RETRIES {
return Err(err).map_err(GlobalError::raw);
}
retries += 1;
}
err => return err.map_err(GlobalError::raw),
}
}
};

// Move to next event
self.location_idx += 1;

Ok(signal)
}

/// Execute a custom listener.
pub async fn custom_listener<T: CustomListener>(
&mut self,
listener: &T,
) -> GlobalResult<<T as CustomListener>::Output> {
let event = { self.relevant_history().nth(self.location_idx) };

// Signal received before
let signal = if let Some(event) = event {
// Validate history is consistent
let Event::Signal(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(name=%self.name, id=%self.workflow_id, signal_name=%signal.name, "replaying signal");

T::parse(&signal.name, signal.body.clone()).map_err(GlobalError::raw)?
}
// Listen for new messages
else {
tracing::info!(name=%self.name, id=%self.workflow_id, "listening for signal");

let mut retries = 0;
let mut interval = tokio::time::interval(SIGNAL_RETRY);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

let ctx = ListenCtx::new(self);

loop {
interval.tick().await;

match listener.listen(&ctx).await {
Ok(res) => break res,
Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => {
if retries > MAX_SIGNAL_RETRIES {
Expand Down Expand Up @@ -844,7 +870,9 @@ impl WorkflowCtx {
}
// Listen for new message
else {
match T::listen(self).await {
let ctx = ListenCtx::new(self);

match T::listen(&ctx).await {
Ok(res) => Some(res),
Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => None,
Err(err) => return Err(err).map_err(GlobalError::raw),
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/db/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ impl Database for DatabasePostgres {
)
RETURNING 1
),
-- After deleting the signal, add it to the events table (i.e. acknowledge it)
-- After acking the signal, add it to the events table
insert_event AS (
INSERT INTO db_workflow.workflow_signal_events (
workflow_id, location, signal_id, signal_name, body, ack_ts
Expand Down
1 change: 1 addition & 0 deletions lib/chirp-workflow/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod db;
mod error;
mod event;
mod executable;
mod listen;
pub mod message;
pub mod operation;
pub mod prelude;
Expand Down
22 changes: 22 additions & 0 deletions lib/chirp-workflow/core/src/listen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use async_trait::async_trait;

use crate::{ctx::ListenCtx, error::WorkflowResult};

/// A trait which allows listening for signals from the workflows database. This is used by
/// `WorkflowCtx::listen` and `WorkflowCtx::query_signal`. If you need a listener with state, use
/// `CustomListener`.
#[async_trait]
pub trait Listen: Sized {
async fn listen(ctx: &ListenCtx) -> WorkflowResult<Self>;
fn parse(name: &str, body: serde_json::Value) -> WorkflowResult<Self>;
}

/// A trait which allows listening for signals with a custom state. This is used by
/// `WorkflowCtx::custom_listener`.
#[async_trait]
pub trait CustomListener: Sized {
type Output;

async fn listen(&self, ctx: &ListenCtx) -> WorkflowResult<Self::Output>;
fn parse(name: &str, body: serde_json::Value) -> WorkflowResult<Self::Output>;
}
6 changes: 4 additions & 2 deletions lib/chirp-workflow/core/src/operation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Debug;

use async_trait::async_trait;
use global_error::GlobalResult;

Expand All @@ -6,14 +8,14 @@ use crate::OperationCtx;
#[async_trait]
pub trait Operation {
type Input: OperationInput;
type Output: Send;
type Output: Debug + Send;

const NAME: &'static str;
const TIMEOUT: std::time::Duration;

async fn run(ctx: &OperationCtx, input: &Self::Input) -> GlobalResult<Self::Output>;
}

pub trait OperationInput: Send {
pub trait OperationInput: Debug + Send {
type Operation: Operation;
}
3 changes: 2 additions & 1 deletion lib/chirp-workflow/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ pub use crate::{
error::{WorkflowError, WorkflowResult},
executable::closure,
executable::Executable,
listen::{CustomListener, Listen},
message::Message,
operation::Operation,
registry::Registry,
signal::{join_signal, Listen, Signal},
signal::{join_signal, Signal},
util::GlobalErrorExt,
worker::Worker,
workflow::Workflow,
Expand Down
14 changes: 1 addition & 13 deletions lib/chirp-workflow/core/src/signal.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,7 @@
use async_trait::async_trait;

use crate::{WorkflowCtx, WorkflowResult};

pub trait Signal {
const NAME: &'static str;
}

/// A trait which allows listening for signals from the workflows database. This is used by
/// `WorkflowCtx::listen` and `WorkflowCtx::query_signal`.
#[async_trait]
pub trait Listen: Sized {
async fn listen(ctx: &mut WorkflowCtx) -> WorkflowResult<Self>;
fn parse(name: &str, body: serde_json::Value) -> WorkflowResult<Self>;
}

/// Creates an enum that implements `Listen` and selects one of X signals.
///
/// Example:
Expand Down Expand Up @@ -68,7 +56,7 @@ macro_rules! join_signal {
(@ $join:ident, [$($signals:ident),*]) => {
#[async_trait::async_trait]
impl Listen for $join {
async fn listen(ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult<Self> {
async fn listen(ctx: &chirp_workflow::prelude::ListenCtx) -> chirp_workflow::prelude::WorkflowResult<Self> {
let row = ctx.listen_any(&[$($signals::NAME),*]).await?;
Self::parse(&row.signal_name, row.body)
}
Expand Down
Loading

0 comments on commit 7d5d59d

Please sign in to comment.