Skip to content

Commit

Permalink
feat(workflows): add messages
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jul 5, 2024
1 parent 5cea2f5 commit 64b4054
Show file tree
Hide file tree
Showing 16 changed files with 1,041 additions and 24 deletions.
13 changes: 12 additions & 1 deletion lib/api-helper/build/src/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,18 @@ impl WatchIndexQuery {
/// Converts the `WatchIndexQuery` into a `TailAnchor` for use with the Chirp client.
pub fn to_consumer(self) -> Result<Option<TailAnchor>, ClientError> {
if let Some(watch_index) = self.watch_index {
Ok(Some(chirp_client::TailAnchor {
Ok(Some(TailAnchor {
start_time: watch_index.parse()?,
}))
} else {
Ok(None)
}
}

/// Converts the `WatchIndexQuery` into a `TailAnchor` for use with Chirp workflows.
pub fn to_workflow(self) -> Result<Option<chirp_workflow::ctx::message::TailAnchor>, ClientError> {
if let Some(watch_index) = self.watch_index {
Ok(Some(chirp_workflow::ctx::message::TailAnchor {
start_time: watch_index.parse()?,
}))
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/api-helper/build/src/macro_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub async fn __with_ctx<A: auth::ApiAuth + Send>(
);
let conn = rivet_connection::Connection::new(client, pools.clone(), cache.clone());
let db = chirp_workflow::compat::db_from_pools(&pools).await?;
let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name);
let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name).await?;

// Create auth
let rate_limit_ctx = AuthRateLimitCtx {
Expand Down
2 changes: 2 additions & 0 deletions lib/chirp-workflow/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license = "Apache-2.0"
async-trait = "0.1.80"
chirp-client = { path = "../../chirp/client" }
chirp-workflow-macros = { path = "../macros" }
cjson = "0.1"
formatted-error = { path = "../../formatted-error" }
futures-util = "0.3"
global-error = { path = "../../global-error" }
Expand All @@ -27,6 +28,7 @@ serde = { version = "1.0.198", features = ["derive"] }
serde_json = "1.0.116"
thiserror = "1.0.59"
tokio = { version = "1.37.0", features = ["full"] }
tokio-util = "0.7"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
Expand Down
39 changes: 34 additions & 5 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError};
use crate::{
ctx::{MessageCtx, OperationCtx},
error::{WorkflowError, WorkflowResult},
message::Message,
DatabaseHandle, Operation, OperationInput,
};

#[derive(Clone)]
pub struct ActivityCtx {
Expand All @@ -14,20 +19,21 @@ pub struct ActivityCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ActivityCtx {
pub fn new(
pub async fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
) -> Self {
) -> WorkflowResult<Self> {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -43,15 +49,18 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;

ActivityCtx {
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ActivityCtx {
workflow_id,
ray_id,
name,
ts,
db,
conn,
op_ctx,
}
msg_ctx,
})
}
}

Expand Down Expand Up @@ -86,6 +95,26 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}

pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)
}

pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
Expand Down
61 changes: 55 additions & 6 deletions lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ use serde::Serialize;
use uuid::Uuid;

use crate::{
ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError,
WorkflowInput,
ctx::{
message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
MessageCtx, OperationCtx,
},
error::WorkflowResult,
message::{Message, ReceivedMessage},
DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError, WorkflowInput,
};

pub const WORKFLOW_TIMEOUT: Duration = Duration::from_secs(60);
Expand All @@ -20,20 +25,21 @@ pub struct ApiCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ApiCtx {
pub fn new(
pub async fn new(
db: DatabaseHandle,
conn: rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
ts: i64,
name: &'static str,
) -> Self {
) -> WorkflowResult<Self> {
let op_ctx = rivet_operation::OperationContext::new(
name.to_string(),
std::time::Duration::from_secs(60),
Expand All @@ -45,14 +51,17 @@ impl ApiCtx {
(),
);

ApiCtx {
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ApiCtx {
ray_id,
name,
ts,
db,
conn,
op_ctx,
}
msg_ctx,
})
}
}

Expand Down Expand Up @@ -243,6 +252,46 @@ impl ApiCtx {
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}

pub async fn subscribe<M>(
&self,
tags: &serde_json::Value,
) -> GlobalResult<SubscriptionHandle<M>>
where
M: Message,
{
self.msg_ctx
.subscribe::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
where
M: Message,
{
self.msg_ctx
.tail_read::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_anchor<M>(
&self,
tags: serde_json::Value,
anchor: &TailAnchor,
) -> GlobalResult<TailAnchorResponse<M>>
where
M: Message,
{
self.msg_ctx
.tail_anchor::<M>(tags, anchor)
.await
.map_err(GlobalError::raw)
}
}

impl ApiCtx {
Expand Down
Loading

0 comments on commit 64b4054

Please sign in to comment.