Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion lib/api-helper/build/src/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,18 @@ impl WatchIndexQuery {
/// Converts the `WatchIndexQuery` into a `TailAnchor` for use with the Chirp client.
pub fn to_consumer(self) -> Result<Option<TailAnchor>, ClientError> {
if let Some(watch_index) = self.watch_index {
Ok(Some(chirp_client::TailAnchor {
Ok(Some(TailAnchor {
start_time: watch_index.parse()?,
}))
} else {
Ok(None)
}
}

/// Converts the `WatchIndexQuery` into a `TailAnchor` for use with Chirp workflows.
pub fn to_workflow(self) -> Result<Option<chirp_workflow::ctx::message::TailAnchor>, ClientError> {
if let Some(watch_index) = self.watch_index {
Ok(Some(chirp_workflow::ctx::message::TailAnchor {
start_time: watch_index.parse()?,
}))
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/api-helper/build/src/macro_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub async fn __with_ctx<A: auth::ApiAuth + Send>(
);
let conn = rivet_connection::Connection::new(client, pools.clone(), cache.clone());
let db = chirp_workflow::compat::db_from_pools(&pools).await?;
let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name);
let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name).await?;

// Create auth
let rate_limit_ctx = AuthRateLimitCtx {
Expand Down
2 changes: 2 additions & 0 deletions lib/chirp-workflow/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license = "Apache-2.0"
async-trait = "0.1.80"
chirp-client = { path = "../../chirp/client" }
chirp-workflow-macros = { path = "../macros" }
cjson = "0.1"
formatted-error = { path = "../../formatted-error" }
futures-util = "0.3"
global-error = { path = "../../global-error" }
Expand All @@ -27,6 +28,7 @@ serde = { version = "1.0.198", features = ["derive"] }
serde_json = "1.0.116"
thiserror = "1.0.59"
tokio = { version = "1.37.0", features = ["full"] }
tokio-util = "0.7"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
Expand Down
39 changes: 34 additions & 5 deletions lib/chirp-workflow/core/src/ctx/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;

use crate::{ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError};
use crate::{
ctx::{MessageCtx, OperationCtx},
error::{WorkflowError, WorkflowResult},
message::Message,
DatabaseHandle, Operation, OperationInput,
};

#[derive(Clone)]
pub struct ActivityCtx {
Expand All @@ -14,20 +19,21 @@ pub struct ActivityCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ActivityCtx {
pub fn new(
pub async fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
) -> Self {
) -> WorkflowResult<Self> {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
Expand All @@ -43,15 +49,18 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;

ActivityCtx {
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ActivityCtx {
workflow_id,
ray_id,
name,
ts,
db,
conn,
op_ctx,
}
msg_ctx,
})
}
}

Expand Down Expand Up @@ -86,6 +95,26 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}

pub async fn msg<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message(tags, body)
.await
.map_err(GlobalError::raw)
}

pub async fn msg_wait<M>(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
where
M: Message,
{
self.msg_ctx
.message_wait(tags, body)
.await
.map_err(GlobalError::raw)
}
}

impl ActivityCtx {
Expand Down
61 changes: 55 additions & 6 deletions lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ use serde::Serialize;
use uuid::Uuid;

use crate::{
ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError,
WorkflowInput,
ctx::{
message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
MessageCtx, OperationCtx,
},
error::WorkflowResult,
message::{Message, ReceivedMessage},
DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError, WorkflowInput,
};

pub const WORKFLOW_TIMEOUT: Duration = Duration::from_secs(60);
Expand All @@ -20,20 +25,21 @@ pub struct ApiCtx {
db: DatabaseHandle,

conn: rivet_connection::Connection,
msg_ctx: MessageCtx,

// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}

impl ApiCtx {
pub fn new(
pub async fn new(
db: DatabaseHandle,
conn: rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
ts: i64,
name: &'static str,
) -> Self {
) -> WorkflowResult<Self> {
let op_ctx = rivet_operation::OperationContext::new(
name.to_string(),
std::time::Duration::from_secs(60),
Expand All @@ -45,14 +51,17 @@ impl ApiCtx {
(),
);

ApiCtx {
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;

Ok(ApiCtx {
ray_id,
name,
ts,
db,
conn,
op_ctx,
}
msg_ctx,
})
}
}

Expand Down Expand Up @@ -243,6 +252,46 @@ impl ApiCtx {
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}

pub async fn subscribe<M>(
&self,
tags: &serde_json::Value,
) -> GlobalResult<SubscriptionHandle<M>>
where
M: Message,
{
self.msg_ctx
.subscribe::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
where
M: Message,
{
self.msg_ctx
.tail_read::<M>(tags)
.await
.map_err(GlobalError::raw)
}

pub async fn tail_anchor<M>(
&self,
tags: serde_json::Value,
anchor: &TailAnchor,
) -> GlobalResult<TailAnchorResponse<M>>
where
M: Message,
{
self.msg_ctx
.tail_anchor::<M>(tags, anchor)
.await
.map_err(GlobalError::raw)
}
}

impl ApiCtx {
Expand Down
Loading