Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 5 additions & 34 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{
ctx::{MessageCtx, OperationCtx},
error::{WorkflowError, WorkflowResult},
message::Message,
DatabaseHandle, Operation, OperationInput,
};
use crate::{ctx::OperationCtx, error::WorkflowError, DatabaseHandle, Operation, OperationInput};

#[derive(Clone)]
pub struct ActivityCtx {
Expand All @@ -19,21 +14,20 @@ pub struct ActivityCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ActivityCtx {
pub async fn new(
pub fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
) -> WorkflowResult<Self> {
) -> Self {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -49,18 +43,15 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ActivityCtx {
ActivityCtx {
workflow_id,
ray_id,
name,
ts,
db,
conn,
op_ctx,
msg_ctx,
})
}
}
}

Expand Down Expand Up @@ -95,26 +86,6 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}

pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)
}

pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
Expand Down
221 changes: 187 additions & 34 deletions lib/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use uuid::Uuid;

use crate::{
activity::ActivityId,
ctx::{ActivityCtx, MessageCtx},
event::Event,
util::{self, Location},
Activity, ActivityCtx, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow,
RegistryHandle, Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
message::Message,
util::Location,
Activity, ActivityInput, DatabaseHandle, Executable, Listen, PulledWorkflow, RegistryHandle,
Signal, SignalRow, Workflow, WorkflowError, WorkflowInput, WorkflowResult,
};

// Time to delay a worker from retrying after an error
Expand Down Expand Up @@ -55,15 +57,19 @@ pub struct WorkflowCtx {

root_location: Location,
location_idx: usize,

msg_ctx: MessageCtx,
}

impl WorkflowCtx {
pub fn new(
pub async fn new(
registry: RegistryHandle,
db: DatabaseHandle,
conn: rivet_connection::Connection,
workflow: PulledWorkflow,
) -> GlobalResult<Self> {
let msg_ctx = MessageCtx::new(&conn, workflow.workflow_id, workflow.ray_id).await?;

Ok(WorkflowCtx {
workflow_id: workflow.workflow_id,
name: workflow.workflow_name,
Expand All @@ -77,18 +83,13 @@ impl WorkflowCtx {

conn,

event_history: Arc::new(
util::combine_events(
workflow.activity_events,
workflow.signal_events,
workflow.sub_workflow_events,
)
.map_err(GlobalError::raw)?,
),
event_history: Arc::new(workflow.events),
input: Arc::new(workflow.input),

root_location: Box::new([]),
location_idx: 0,

msg_ctx,
})
}

Expand Down Expand Up @@ -117,6 +118,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -258,8 +261,7 @@ impl WorkflowCtx {
self.create_ts,
self.ray_id,
A::NAME,
)
.await?;
);

let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input))
.await
Expand Down Expand Up @@ -568,6 +570,8 @@ impl WorkflowCtx {
.chain(std::iter::once(self.location_idx))
.collect(),
location_idx: 0,

msg_ctx: self.msg_ctx.clone(),
};

self.location_idx += 1;
Expand Down Expand Up @@ -669,19 +673,47 @@ impl WorkflowCtx {
workflow_id: Uuid,
body: T,
) -> GlobalResult<Uuid> {
let signal_id = Uuid::new_v4();
let event = { self.relevant_history().nth(self.location_idx) };

tracing::info!(name=%T::NAME, %workflow_id, %signal_id, "dispatching signal");
// Signal sent before
let signal_id = if let Some(event) = event {
// Validate history is consistent
let Event::SignalSend(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;
tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying signal dispatch");

self.db
.publish_signal(self.ray_id, workflow_id, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;
signal.signal_id
}
// Send signal
else {
let signal_id = Uuid::new_v4();
tracing::info!(id=%self.workflow_id, signal_name=%T::NAME, to_workflow_id=%workflow_id, %signal_id, "dispatching signal");

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_signal_from_workflow(
self.workflow_id,
self.full_location().as_ref(),
self.ray_id,
workflow_id,
signal_id,
T::NAME,
input_val,
)
.await
.map_err(GlobalError::raw)?;

signal_id
};

// Move to next event
self.location_idx += 1;

Ok(signal_id)
}
Expand All @@ -692,19 +724,48 @@ impl WorkflowCtx {
tags: &serde_json::Value,
body: T,
) -> GlobalResult<Uuid> {
let signal_id = Uuid::new_v4();
let event = { self.relevant_history().nth(self.location_idx) };

tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");
// Signal sent before
let signal_id = if let Some(event) = event {
// Validate history is consistent
let Event::SignalSend(signal) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;
tracing::debug!(id=%self.workflow_id, signal_name=%signal.name, signal_id=%signal.signal_id, "replaying tagged signal dispatch");

self.db
.publish_tagged_signal(self.ray_id, tags, signal_id, T::NAME, input_val)
.await
.map_err(GlobalError::raw)?;
signal.signal_id
}
// Send signal
else {
let signal_id = Uuid::new_v4();

tracing::debug!(name=%T::NAME, ?tags, %signal_id, "dispatching tagged signal");

// Serialize input
let input_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeSignalBody)
.map_err(GlobalError::raw)?;

self.db
.publish_tagged_signal_from_workflow(
self.workflow_id,
self.full_location().as_ref(),
self.ray_id,
tags,
signal_id,
T::NAME,
input_val,
)
.await
.map_err(GlobalError::raw)?;

signal_id
};

// Move to next event
self.location_idx += 1;

Ok(signal_id)
}
Expand Down Expand Up @@ -785,6 +846,98 @@ impl WorkflowCtx {
Ok(signal)
}

pub async fn msg<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

// Serialize body
let body_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;
let location = self.full_location();

let (msg, write) = tokio::join!(
self.db.publish_message_from_workflow(
self.workflow_id,
location.as_ref(),
&tags,
M::NAME,
body_val
),
self.msg_ctx.message(tags.clone(), body),
);

msg.map_err(GlobalError::raw)?;
write.map_err(GlobalError::raw)?;
}

// Move to next event
self.location_idx += 1;

Ok(())
}

pub async fn msg_wait<M>(&mut self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
let event = { self.relevant_history().nth(self.location_idx) };

// Message sent before
if let Some(event) = event {
// Validate history is consistent
let Event::MessageSend(msg) = event else {
return Err(WorkflowError::HistoryDiverged).map_err(GlobalError::raw);
};

tracing::debug!(id=%self.workflow_id, msg_name=%msg.name, "replaying message dispatch");
}
// Send message
else {
tracing::info!(id=%self.workflow_id, msg_name=%M::NAME, ?tags, "dispatching message");

// Serialize body
let body_val = serde_json::to_value(&body)
.map_err(WorkflowError::SerializeWorkflowOutput)
.map_err(GlobalError::raw)?;
let location = self.full_location();

let (msg, write) = tokio::join!(
self.db.publish_message_from_workflow(
self.workflow_id,
location.as_ref(),
&tags,
M::NAME,
body_val
),
self.msg_ctx.message_wait(tags.clone(), body),
);

msg.map_err(GlobalError::raw)?;
write.map_err(GlobalError::raw)?;
}

// Move to next event
self.location_idx += 1;

Ok(())
}

// TODO: sleep_for, sleep_until
}

Expand Down
Loading