Skip to content

Commit

Permalink
feat(workflows): add message and signal history
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jul 11, 2024
1 parent 64b4054 commit c6c91e3
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 137 deletions.
39 changes: 5 additions & 34 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{
ctx::{MessageCtx, OperationCtx},
error::{WorkflowError, WorkflowResult},
message::Message,
DatabaseHandle, Operation, OperationInput,
};
use crate::{ctx::OperationCtx, error::WorkflowError, DatabaseHandle, Operation, OperationInput};

#[derive(Clone)]
pub struct ActivityCtx {
Expand All @@ -19,21 +14,20 @@ pub struct ActivityCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ActivityCtx {
pub async fn new(
pub fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
) -> WorkflowResult<Self> {
) -> Self {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -49,18 +43,15 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ActivityCtx {
ActivityCtx {
workflow_id,
ray_id,
name,
ts,
db,
conn,
op_ctx,
msg_ctx,
})
}
}
}

Expand Down Expand Up @@ -95,26 +86,6 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}

pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)
}

pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
Expand Down
134 changes: 110 additions & 24 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use uuid::Uuid;

use crate::{
activity::ActivityId,
ctx::{ActivityCtx, MessageCtx},
event::Event,
util::{self, Location},
Activity, ActivityCtx, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow,
RegistryHandle, Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
message::Message,
util::Location,
Activity, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow, RegistryHandle,
Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
};

// Time to delay a worker from retrying after an error
Expand Down Expand Up @@ -55,15 +57,19 @@ pub struct WorkflowCtx {

root_location: Location,
location_idx: usize,

msg_ctx: MessageCtx,
}

impl WorkflowCtx {
pub fn new(
pub async fn new(
registry: RegistryHandle,
db: DatabaseHandle,
conn: rivet_connection::Connection,
workflow: PulledWorkflow,
) -> GlobalResult<Self> {
let msg_ctx = MessageCtx::new(&conn, workflow.workflow_id, workflow.ray_id).await?;

Ok(WorkflowCtx {
workflow_id: workflow.workflow_id,
name: workflow.workflow_name,
Expand All @@ -77,18 +83,13 @@ impl WorkflowCtx {

conn,

event_history: Arc::new(
util::combine_events(
workflow.activity_events,
workflow.signal_events,
workflow.sub_workflow_events,
)
.map_err(GlobalError::raw)?,
),
event_history: Arc::new(workflow.events),
input: Arc::new(workflow.input),

root_location: Box::new([]),
location_idx: 0,

msg_ctx,
})
}

Expand Down Expand Up @@ -117,6 +118,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -258,8 +261,7 @@ impl WorkflowCtx {
self.create_ts,
self.ray_id,
A::NAME,
)
.await?;
);

let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input))
.await
Expand Down Expand Up @@ -568,6 +570,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -669,19 +673,39 @@ impl WorkflowCtx {
workflow_id: Uuid,
body: T,
) -> GlobalResult<Uuid> {
let signal_id = Uuid::new_v4();
let event = { self.relevant_history().nth(self.location_idx) };

tracing::info!(name=%T::NAME, %workflow_id, %signal_id, "dispatching signal");
// Signal sent before
let signal_id = if let Some(event) = event {
// Validate history is consistent
let Event::SignalSend(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;
tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch");

self.db
.publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;
signal.signal_id
}
// Send signal
else {
let signal_id = Uuid::new_v4();
tracing::info!(id=%self.workflow_id, signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal");

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;

signal_id
};

// Move to next event
self.location_idx += 1;

Ok(signal_id)
}
Expand Down Expand Up @@ -785,6 +809,68 @@ impl WorkflowCtx {
Ok(signal)
}

pub async fn msg<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)?
}

// Move to next event
self.location_idx += 1;

Ok(())
}

pub async fn msg_wait<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)?
}

// Move to next event
self.location_idx += 1;

Ok(())
}

// TODO: sleep_for, sleep_until
}

Expand Down
25 changes: 20 additions & 5 deletions lib/chirp-workflow/core/src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use uuid::Uuid;

use crate::{activity::ActivityId, Workflow, WorkflowError, WorkflowResult};
use crate::{
activity::ActivityId, event::Event, util::Location, Workflow, WorkflowError, WorkflowResult,
};

mod postgres;
pub use postgres::DatabasePostgres;
Expand Down Expand Up @@ -136,9 +138,7 @@ pub struct PulledWorkflow {
pub input: serde_json::Value,
pub wake_deadline_ts: Option<i64>,

pub activity_events: Vec<ActivityEventRow>,
pub signal_events: Vec<SignalEventRow>,
pub sub_workflow_events: Vec<SubWorkflowEventRow>,
pub events: HashMap<Location, Vec<Event>>,
}

#[derive(sqlx::FromRow)]
Expand All @@ -160,6 +160,21 @@ pub struct SignalEventRow {
pub body: serde_json::Value,
}

#[derive(sqlx::FromRow)]
pub struct SignalSendEventRow {
pub workflow_id: Uuid,
pub location: Vec<i64>,
pub signal_id: Uuid,
pub signal_name: String,
}

#[derive(sqlx::FromRow)]
pub struct MessageSendEventRow {
pub workflow_id: Uuid,
pub location: Vec<i64>,
pub message_name: String,
}

#[derive(sqlx::FromRow)]
pub struct SubWorkflowEventRow {
pub workflow_id: Uuid,
Expand Down
Loading

0 comments on commit c6c91e3

Please sign in to comment.