diff --git a/lib/api-helper/build/src/anchor.rs b/lib/api-helper/build/src/anchor.rs
index 3222290ecc..582a212a0e 100644
--- a/lib/api-helper/build/src/anchor.rs
+++ b/lib/api-helper/build/src/anchor.rs
@@ -25,7 +25,18 @@ impl WatchIndexQuery {
/// Converts the `WatchIndexQuery` into a `TailAnchor` for use with the Chirp client.
pub fn to_consumer(self) -> Result, ClientError> {
if let Some(watch_index) = self.watch_index {
- Ok(Some(chirp_client::TailAnchor {
+ Ok(Some(TailAnchor {
+ start_time: watch_index.parse()?,
+ }))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Converts the `WatchIndexQuery` into a `TailAnchor` for use with Chirp workflows.
+ pub fn to_workflow(self) -> Result , ClientError> {
+ if let Some(watch_index) = self.watch_index {
+ Ok(Some(chirp_workflow::ctx::message::TailAnchor {
start_time: watch_index.parse()?,
}))
} else {
diff --git a/lib/api-helper/build/src/macro_util.rs b/lib/api-helper/build/src/macro_util.rs
index 4e000213e8..83f760741a 100644
--- a/lib/api-helper/build/src/macro_util.rs
+++ b/lib/api-helper/build/src/macro_util.rs
@@ -325,7 +325,7 @@ pub async fn __with_ctx(
);
let conn = rivet_connection::Connection::new(client, pools.clone(), cache.clone());
let db = chirp_workflow::compat::db_from_pools(&pools).await?;
- let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name);
+ let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name).await?;
// Create auth
let rate_limit_ctx = AuthRateLimitCtx {
diff --git a/lib/chirp-workflow/core/Cargo.toml b/lib/chirp-workflow/core/Cargo.toml
index 28ba4c9d53..8f5bc697d5 100644
--- a/lib/chirp-workflow/core/Cargo.toml
+++ b/lib/chirp-workflow/core/Cargo.toml
@@ -9,6 +9,7 @@ license = "Apache-2.0"
async-trait = "0.1.80"
chirp-client = { path = "../../chirp/client" }
chirp-workflow-macros = { path = "../macros" }
+cjson = "0.1"
formatted-error = { path = "../../formatted-error" }
futures-util = "0.3"
global-error = { path = "../../global-error" }
@@ -27,6 +28,7 @@ serde = { version = "1.0.198", features = ["derive"] }
serde_json = "1.0.116"
thiserror = "1.0.59"
tokio = { version = "1.37.0", features = ["full"] }
+tokio-util = "0.7"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs
index 391b6a59b7..858f854141 100644
--- a/lib/chirp-workflow/core/src/ctx/activity.rs
+++ b/lib/chirp-workflow/core/src/ctx/activity.rs
@@ -2,7 +2,12 @@ use global_error::{GlobalError, GlobalResult};
use rivet_pools::prelude::*;
use uuid::Uuid;
-use crate::{ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError};
+use crate::{
+ ctx::{MessageCtx, OperationCtx},
+ error::{WorkflowError, WorkflowResult},
+ message::Message,
+ DatabaseHandle, Operation, OperationInput,
+};
#[derive(Clone)]
pub struct ActivityCtx {
@@ -14,20 +19,21 @@ pub struct ActivityCtx {
db: DatabaseHandle,
conn: rivet_connection::Connection,
+ msg_ctx: MessageCtx,
// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}
impl ActivityCtx {
- pub fn new(
+ pub async fn new(
workflow_id: Uuid,
db: DatabaseHandle,
conn: &rivet_connection::Connection,
activity_create_ts: i64,
ray_id: Uuid,
name: &'static str,
- ) -> Self {
+ ) -> WorkflowResult {
let ts = rivet_util::timestamp::now();
let req_id = Uuid::new_v4();
let conn = conn.wrap(req_id, ray_id, name);
@@ -43,7 +49,9 @@ impl ActivityCtx {
);
op_ctx.from_workflow = true;
- ActivityCtx {
+ let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
+
+ Ok(ActivityCtx {
workflow_id,
ray_id,
name,
@@ -51,7 +59,8 @@ impl ActivityCtx {
db,
conn,
op_ctx,
- }
+ msg_ctx,
+ })
}
}
@@ -86,6 +95,26 @@ impl ActivityCtx {
.await
.map_err(GlobalError::raw)
}
+
+ pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .message(tags, body)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn msg_wait(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .message_wait(tags, body)
+ .await
+ .map_err(GlobalError::raw)
+ }
}
impl ActivityCtx {
diff --git a/lib/chirp-workflow/core/src/ctx/api.rs b/lib/chirp-workflow/core/src/ctx/api.rs
index 8c11c17826..2228724330 100644
--- a/lib/chirp-workflow/core/src/ctx/api.rs
+++ b/lib/chirp-workflow/core/src/ctx/api.rs
@@ -6,8 +6,13 @@ use serde::Serialize;
use uuid::Uuid;
use crate::{
- ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError,
- WorkflowInput,
+ ctx::{
+ message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
+ MessageCtx, OperationCtx,
+ },
+ error::WorkflowResult,
+ message::{Message, ReceivedMessage},
+ DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError, WorkflowInput,
};
pub const WORKFLOW_TIMEOUT: Duration = Duration::from_secs(60);
@@ -20,20 +25,21 @@ pub struct ApiCtx {
db: DatabaseHandle,
conn: rivet_connection::Connection,
+ msg_ctx: MessageCtx,
// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
}
impl ApiCtx {
- pub fn new(
+ pub async fn new(
db: DatabaseHandle,
conn: rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
ts: i64,
name: &'static str,
- ) -> Self {
+ ) -> WorkflowResult {
let op_ctx = rivet_operation::OperationContext::new(
name.to_string(),
std::time::Duration::from_secs(60),
@@ -45,14 +51,17 @@ impl ApiCtx {
(),
);
- ApiCtx {
+ let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
+
+ Ok(ApiCtx {
ray_id,
name,
ts,
db,
conn,
op_ctx,
- }
+ msg_ctx,
+ })
}
}
@@ -243,6 +252,46 @@ impl ApiCtx {
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}
+
+ pub async fn subscribe(
+ &self,
+ tags: &serde_json::Value,
+ ) -> GlobalResult>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .subscribe::(tags)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn tail_read(
+ &self,
+ tags: serde_json::Value,
+ ) -> GlobalResult>>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .tail_read::(tags)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn tail_anchor(
+ &self,
+ tags: serde_json::Value,
+ anchor: &TailAnchor,
+ ) -> GlobalResult>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .tail_anchor::(tags, anchor)
+ .await
+ .map_err(GlobalError::raw)
+ }
}
impl ApiCtx {
diff --git a/lib/chirp-workflow/core/src/ctx/message.rs b/lib/chirp-workflow/core/src/ctx/message.rs
new file mode 100644
index 0000000000..3201adf3d0
--- /dev/null
+++ b/lib/chirp-workflow/core/src/ctx/message.rs
@@ -0,0 +1,607 @@
+use std::{
+ fmt::{self, Debug},
+ marker::PhantomData,
+ sync::Arc,
+};
+
+use futures_util::StreamExt;
+use rivet_pools::prelude::redis::AsyncCommands;
+use rivet_pools::prelude::*;
+use tokio_util::sync::{CancellationToken, DropGuard};
+use tracing::Instrument;
+use uuid::Uuid;
+
+use crate::{
+ error::{WorkflowError, WorkflowResult},
+ message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry},
+};
+
+/// Time (in ms) that we subtract from the anchor grace period in order to
+/// validate that there is not a race condition between the anchor validity and
+/// writing to Redis.
+const TAIL_ANCHOR_VALID_GRACE: i64 = 250;
+
+#[derive(Clone)]
+pub struct MessageCtx {
+ /// The connection used to communicate with NATS.
+ nats: NatsPool,
+
+ /// Used for writing to message tails. This cache is ephemeral.
+ redis_chirp_ephemeral: RedisPool,
+
+ req_id: Uuid,
+ ray_id: Uuid,
+ trace: Vec,
+}
+
+impl MessageCtx {
+ pub async fn new(
+ conn: &rivet_connection::Connection,
+ req_id: Uuid,
+ ray_id: Uuid,
+ ) -> WorkflowResult {
+ Ok(MessageCtx {
+ nats: conn.nats().await?,
+ redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?,
+ req_id,
+ ray_id,
+ trace: conn
+ .chirp()
+ .trace()
+ .iter()
+ .cloned()
+ .map(TryInto::try_into)
+ .collect::>>()?,
+ })
+ }
+}
+
+// MARK: Publishing messages
+impl MessageCtx {
+ /// Publishes a message to NATS and to a durable message stream if a topic is
+ /// set.
+ ///
+ /// Use `subscribe` to consume these messages ephemerally and `tail` to read
+ /// the most recently sent message.
+ ///
+ /// This spawns a background task that calls `message_wait` internally and does not wait for the message to
+ /// finish publishing. This is done since there are very few cases where a
+ /// service should need to wait or fail if a message does not publish
+ /// successfully.
+ #[tracing::instrument(err, skip_all, fields(message = M::NAME))]
+ pub async fn message(&self, tags: serde_json::Value, message_body: M) -> WorkflowResult<()>
+ where
+ M: Message,
+ {
+ let client = self.clone();
+ let spawn_res = tokio::task::Builder::new()
+ .name("chirp_workflow::message_async")
+ .spawn(
+ async move {
+ match client.message_wait::(tags, message_body).await {
+ Ok(_) => {}
+ Err(err) => {
+ tracing::error!(?err, "failed to publish message");
+ }
+ }
+ }
+ .in_current_span(),
+ );
+ if let Err(err) = spawn_res {
+ tracing::error!(?err, "failed to spawn message_async task");
+ }
+
+ Ok(())
+ }
+
+ /// Same as `message` but waits for the message to successfully publish.
+ ///
+ /// This is useful in scenarios where we need to publish a large amount of
+ /// messages at once so we put the messages in a queue instead of submitting
+ /// a large number of tasks to Tokio at once.
+ #[tracing::instrument(err, skip_all, fields(message = M::NAME))]
+ pub async fn message_wait(
+ &self,
+ tags: serde_json::Value,
+ message_body: M,
+ ) -> WorkflowResult<()>
+ where
+ M: Message,
+ {
+ let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?;
+ let nats_subject = message::serialize_message_nats_subject::(&tags_str);
+ let duration_since_epoch = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap_or_else(|err| unreachable!("time is broken: {}", err));
+ let ts = duration_since_epoch.as_millis() as i64;
+
+ // Serialize the body
+ let body_buf =
+ serde_json::to_string(&message_body).map_err(WorkflowError::SerializeMessage)?;
+ let body_buf_len = body_buf.len();
+ let body_buf = serde_json::value::RawValue::from_string(body_buf)
+ .map_err(WorkflowError::SerializeMessage)?;
+
+ // Serialize message
+ let req_id = Uuid::new_v4();
+ let message = MessageWrapper {
+ req_id: req_id,
+ ray_id: self.ray_id,
+ tags: tags.clone(),
+ ts,
+ trace: self.trace.clone(),
+ allow_recursive: false, // TODO:
+ body: &body_buf,
+ };
+ let message_buf = serde_json::to_vec(&message).map_err(WorkflowError::SerializeMessage)?;
+
+ // TODO: opts.dont_log_body
+ if true {
+ tracing::info!(
+ %nats_subject,
+ body_bytes = ?body_buf_len,
+ message_bytes = ?message_buf.len(),
+ "publish message"
+ );
+ } else {
+ tracing::info!(
+ %nats_subject,
+ ?message_body,
+ body_bytes = ?body_buf_len,
+ message_bytes = ?message_buf.len(),
+ "publish message"
+ );
+ }
+
+ // Write to Redis and NATS.
+ //
+ // It's important to write to the stream as fast as possible in order to
+ // ensure messages are handled quickly.
+ let message_buf = Arc::new(message_buf);
+ self.message_write_redis::(&tags_str, message_buf.clone(), req_id, ts)
+ .await;
+ self.message_publish_nats::(&nats_subject, message_buf)
+ .await;
+
+ Ok(())
+ }
+
+ /// Writes a message to a Redis durable stream and tails.
+ #[tracing::instrument(level = "debug", skip_all)]
+ async fn message_write_redis(
+ &self,
+ tags_str: &str,
+ message_buf: Arc>,
+ req_id: Uuid,
+ ts: i64,
+ ) where
+ M: Message,
+ {
+ // Write tail
+ let tail_key = redis_keys::message_tail::(tags_str);
+
+ let mut pipe = redis::pipe();
+
+ // Save message
+ pipe.hset(
+ &tail_key,
+ redis_keys::message_tail::REQUEST_ID,
+ req_id.to_string(),
+ )
+ .ignore();
+ pipe.hset(&tail_key, redis_keys::message_tail::TS, ts)
+ .ignore();
+ pipe.hset(
+ &tail_key,
+ redis_keys::message_tail::BODY,
+ message_buf.as_slice(),
+ )
+ .ignore();
+
+ let mut conn = self.redis_chirp_ephemeral.clone();
+ match pipe.query_async::<_, ()>(&mut conn).await {
+ Ok(_) => {
+ tracing::debug!("write to redis tail succeeded");
+ }
+ Err(err) => {
+ tracing::error!(?err, "failed to write to redis tail");
+ }
+ }
+
+ // Automatically expire
+ pipe.expire(&tail_key, M::TAIL_TTL.as_millis() as usize)
+ .ignore();
+ }
+
+ /// Publishes the message to NATS.
+ #[tracing::instrument(level = "debug", skip_all)]
+ async fn message_publish_nats(&self, nats_subject: &str, message_buf: Arc>)
+ where
+ M: Message,
+ {
+ // Publish message to NATS. Do this after a successful write to
+ // Redis in order to verify that tailing messages doesn't end up in a
+ // race condition that misses a message from the database.
+ //
+ // Infinite backoff since we want to wait until the service reboots.
+ let mut backoff = rivet_util::Backoff::default_infinite();
+ loop {
+ // Ignore for infinite backoff
+ backoff.tick().await;
+
+ let nats_subject = nats_subject.to_owned();
+
+ tracing::trace!(
+ %nats_subject,
+ message_len = message_buf.len(),
+ "publishing message to nats"
+ );
+ if let Err(err) = self
+ .nats
+ .publish(nats_subject.clone(), (*message_buf).clone().into())
+ .await
+ {
+ tracing::warn!(?err, "publish message failed, trying again");
+ continue;
+ }
+
+ tracing::debug!("publish nats message succeeded");
+ break;
+ }
+ }
+}
+
+// MARK: Subscriptions
+impl MessageCtx {
+ /// Listens for Chirp workflow messages globally on NATS.
+ #[tracing::instrument(level = "debug", err, skip_all)]
+ pub async fn subscribe(
+ &self,
+ tags: &serde_json::Value,
+ ) -> WorkflowResult>
+ where
+ M: Message,
+ {
+ self.subscribe_opt::(SubscribeOpts {
+ tags,
+ flush_nats: true,
+ })
+ .await
+ }
+
+ /// Listens for Chirp workflow messages globally on NATS.
+ #[tracing::instrument(err, skip_all, fields(message = M::NAME))]
+ pub async fn subscribe_opt(
+ &self,
+ opts: SubscribeOpts<'_>,
+ ) -> WorkflowResult>
+ where
+ M: Message,
+ {
+ let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?;
+ let nats_subject = message::serialize_message_nats_subject::(&tags_str);
+
+ // Create subscription and flush immediately.
+ tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription");
+ let subscription = self
+ .nats
+ .subscribe(nats_subject.clone())
+ .await
+ .map_err(|x| WorkflowError::CreateSubscription(x.into()))?;
+ if opts.flush_nats {
+ self.nats
+ .flush()
+ .await
+ .map_err(|x| WorkflowError::FlushNats(x.into()))?;
+ }
+
+ // Return handle
+ let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id);
+ Ok(subscription)
+ }
+
+ /// Reads the tail message of a stream without waiting for a message.
+ #[tracing::instrument(err, skip_all, fields(message = M::NAME))]
+ pub async fn tail_read(
+ &self,
+ tags: serde_json::Value,
+ ) -> WorkflowResult>>
+ where
+ M: Message,
+ {
+ let mut conn = self.redis_chirp_ephemeral.clone();
+
+ // Fetch message
+ let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?;
+ let tail_key = redis_keys::message_tail::(&tags_str);
+ let message_buf = conn
+ .hget::<_, _, Option>>(&tail_key, redis_keys::message_tail::BODY)
+ .await?;
+
+ // Deserialize message
+ let message = if let Some(message_buf) = message_buf {
+ let message = ReceivedMessage::::deserialize(message_buf.as_slice())?;
+ tracing::info!(?message, "immediate read tail message");
+ Some(message)
+ } else {
+ tracing::info!("no tail message to read");
+ None
+ };
+
+ Ok(message)
+ }
+
+ /// Used by API services to tail an message (by start time) after a given timestamp.
+ ///
+ /// Because this waits indefinitely until next message, it is recommended to use this inside
+ /// of a `rivet_util::macros::select_with_timeout!` block:
+ /// ```rust
+ /// use rivet_util as util;
+ ///
+ /// let message_sub = tail_anchor!([ctx, anchor] message_test());
+ ///
+ /// // Consumes anchor or times out after 1 minute
+ /// util::macros::select_with_timeout!(
+ /// message = message_sub => {
+ /// let _message = message?;
+ /// }
+ /// );
+ /// ```
+ #[tracing::instrument(err, skip_all, fields(message = M::NAME))]
+ pub async fn tail_anchor(
+ &self,
+ tags: serde_json::Value,
+ anchor: &TailAnchor,
+ ) -> WorkflowResult>
+ where
+ M: Message,
+ {
+ // Validate anchor is valid
+ if !anchor.is_valid(M::TAIL_TTL.as_millis() as i64) {
+ return Ok(TailAnchorResponse::AnchorExpired);
+ }
+
+ // Create subscription. Do this before reading from the log in order to
+ // ensure consistency.
+ //
+ // Leave flush enabled in order to ensure that subscription is
+ // registered with NATS before continuing.
+ let mut sub = self.subscribe(&tags).await?;
+
+ // Read the tail log
+ let tail_read = self.tail_read(tags).await?;
+
+ // Check if valid or wait for subscription
+ let (message, source) = match tail_read {
+ Some(message) if message.ts > anchor.start_time => (message, "tail_read"),
+ _ => {
+ // Wait for next message if tail not present
+ let message = sub.next().await?;
+ (message, "subscription")
+ }
+ };
+
+ tracing::info!(?message, %source, ?anchor, "read tail message");
+
+ Ok(TailAnchorResponse::Message(message))
+ }
+}
+
+#[derive(Debug)]
+pub struct SubscribeOpts<'a> {
+ pub tags: &'a serde_json::Value,
+ pub flush_nats: bool,
+}
+
+/// Used to receive messages from other contexts.
+///
+/// This subscription will automatically close when dropped.
+pub struct SubscriptionHandle
+where
+ M: Message,
+{
+ _message: PhantomData,
+ _guard: DropGuard,
+ subject: String,
+ subscription: nats::Subscriber,
+ req_id: Uuid,
+}
+
+impl Debug for SubscriptionHandle
+where
+ M: Message,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SubscriptionHandle")
+ .field("subject", &self.subject)
+ .finish()
+ }
+}
+
+impl SubscriptionHandle
+where
+ M: Message,
+{
+ #[tracing::instrument(level = "debug", skip_all)]
+ fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self {
+ let token = CancellationToken::new();
+
+ {
+ let token = token.clone();
+ let spawn_res = tokio::task::Builder::new()
+ .name("chirp_workflow::message_wait_drop")
+ .spawn(
+ async move {
+ token.cancelled().await;
+
+ tracing::trace!("closing subscription");
+
+ // We don't worry about calling `subscription.drain()` since the
+ // entire subscription wrapper is dropped anyways, so we can't
+ // call `.recv()`.
+ }
+ .instrument(tracing::trace_span!("subscription_wait_drop")),
+ );
+ if let Err(err) = spawn_res {
+ tracing::error!(?err, "failed to spawn message_wait_drop task");
+ }
+ }
+
+ SubscriptionHandle {
+ _message: Default::default(),
+ _guard: token.drop_guard(),
+ subject,
+ subscription,
+ req_id,
+ }
+ }
+
+ /// Waits for the next message in the subscription.
+ ///
+ /// This future can be safely dropped.
+ #[tracing::instrument]
+ pub async fn next(&mut self) -> WorkflowResult> {
+ self.next_inner(false).await
+ }
+
+ // TODO: Add a full config struct to pass to `next` that impl's `Default`
+ /// Waits for the next message in the subscription that originates from the
+ /// parent request ID via trace.
+ ///
+ /// This future can be safely dropped.
+ #[tracing::instrument]
+ pub async fn next_with_trace(
+ &mut self,
+ filter_trace: bool,
+ ) -> WorkflowResult> {
+ self.next_inner(filter_trace).await
+ }
+
+ /// This future can be safely dropped.
+ #[tracing::instrument(level = "trace")]
+ async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult> {
+ tracing::info!("waiting for message");
+
+ loop {
+ // Poll the subscription.
+ //
+ // Use blocking threads instead of `try_next`, since I'm not sure
+ // try_next works as intended.
+ let nats_message = match self.subscription.next().await {
+ Some(x) => x,
+ None => {
+ tracing::debug!("unsubscribed");
+ return Err(WorkflowError::SubscriptionUnsubscribed);
+ }
+ };
+
+ if filter_trace {
+ let message_wrapper =
+ ReceivedMessage::::deserialize_wrapper(&nats_message.payload[..])?;
+
+ // Check if the message trace stack originates from this client
+ //
+ // We intentionally use the request ID instead of just checking the ray ID because
+ // there may be multiple calls to `message_with_subscribe` within the same ray.
+ // Explicitly checking the parent request ensures the response is unique to this
+ // message.
+ if message_wrapper
+ .trace
+ .iter()
+ .rev()
+ .any(|trace_entry| trace_entry.req_id == self.req_id)
+ {
+ let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?;
+ tracing::info!(?message, "received message");
+
+ return Ok(message);
+ }
+ } else {
+ let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?;
+ tracing::info!(?message, "received message");
+
+ return Ok(message);
+ }
+
+ // Message not from parent, continue with loop
+ }
+ }
+
+ /// Converts the subscription in to a stream.
+ pub fn into_stream(
+ self,
+ ) -> impl futures_util::Stream- >> {
+ futures_util::stream::try_unfold(self, |mut sub| async move {
+ let message = sub.next().await?;
+ Ok(Some((message, sub)))
+ })
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct TailAnchor {
+ pub start_time: i64,
+}
+
+impl TailAnchor {
+ pub fn new(start_time: i64) -> Self {
+ TailAnchor { start_time }
+ }
+
+ pub fn is_valid(&self, ttl: i64) -> bool {
+ self.start_time > rivet_util::timestamp::now() - ttl * 1000 - TAIL_ANCHOR_VALID_GRACE
+ }
+}
+
+#[derive(Debug)]
+pub enum TailAnchorResponse
+where
+ M: Message + Debug,
+{
+ Message(ReceivedMessage),
+
+ /// Anchor was older than the TTL of the message.
+ AnchorExpired,
+}
+
+impl TailAnchorResponse
+where
+ M: Message + Debug,
+{
+ /// Returns the timestamp of the message if exists.
+ ///
+ /// Useful for endpoints that need to return a new anchor.
+ pub fn msg_ts(&self) -> Option {
+ match self {
+ Self::Message(msg) => Some(msg.msg_ts()),
+ Self::AnchorExpired => None,
+ }
+ }
+}
+
+mod redis_keys {
+ use std::{
+ collections::hash_map::DefaultHasher,
+ hash::{Hash, Hasher},
+ };
+
+ use crate::message::Message;
+
+ /// HASH
+ pub fn message_tail(tags_str: &str) -> String
+ where
+ M: Message,
+ {
+ // Get hash of the tags
+ let mut hasher = DefaultHasher::new();
+ tags_str.hash(&mut hasher);
+
+ format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish())
+ }
+
+ pub mod message_tail {
+ pub const REQUEST_ID: &str = "r";
+ pub const TS: &str = "t";
+ pub const BODY: &str = "b";
+ }
+}
diff --git a/lib/chirp-workflow/core/src/ctx/mod.rs b/lib/chirp-workflow/core/src/ctx/mod.rs
index 8d75c427df..00dc5bab1a 100644
--- a/lib/chirp-workflow/core/src/ctx/mod.rs
+++ b/lib/chirp-workflow/core/src/ctx/mod.rs
@@ -1,10 +1,12 @@
mod activity;
pub(crate) mod api;
+pub mod message;
mod operation;
mod test;
mod workflow;
pub use activity::ActivityCtx;
pub use api::ApiCtx;
+pub use message::MessageCtx;
pub use operation::OperationCtx;
pub use test::TestCtx;
pub use workflow::WorkflowCtx;
diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs
index a018f407ec..1c991d6ece 100644
--- a/lib/chirp-workflow/core/src/ctx/test.rs
+++ b/lib/chirp-workflow/core/src/ctx/test.rs
@@ -5,8 +5,13 @@ use tokio::time::Duration;
use uuid::Uuid;
use crate::{
- util, DatabaseHandle, DatabasePostgres, Operation, OperationCtx, OperationInput, Signal,
- Workflow, WorkflowError, WorkflowInput,
+ ctx::{
+ message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
+ MessageCtx, OperationCtx,
+ },
+ message::{Message, ReceivedMessage},
+ util, DatabaseHandle, DatabasePostgres, Operation, OperationInput, Signal, Workflow,
+ WorkflowError, WorkflowInput,
};
pub struct TestCtx {
@@ -17,6 +22,7 @@ pub struct TestCtx {
db: DatabaseHandle,
conn: rivet_connection::Connection,
+ msg_ctx: MessageCtx,
// Backwards compatibility
op_ctx: rivet_operation::OperationContext<()>,
@@ -60,6 +66,7 @@ impl TestCtx {
);
let db = DatabasePostgres::from_pool(pools.crdb().unwrap());
+ let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap();
TestCtx {
name: service_name,
@@ -68,6 +75,7 @@ impl TestCtx {
db,
conn,
op_ctx,
+ msg_ctx,
}
}
}
@@ -262,6 +270,66 @@ impl TestCtx {
.map_err(WorkflowError::OperationFailure)
.map_err(GlobalError::raw)
}
+
+ pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .message(tags, body)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn msg_wait(&self, tags: serde_json::Value, body: M) -> GlobalResult<()>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .message_wait(tags, body)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn subscribe(
+ &self,
+ tags: &serde_json::Value,
+ ) -> GlobalResult>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .subscribe::(tags)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn tail_read(
+ &self,
+ tags: serde_json::Value,
+ ) -> GlobalResult>>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .tail_read::(tags)
+ .await
+ .map_err(GlobalError::raw)
+ }
+
+ pub async fn tail_anchor(
+ &self,
+ tags: serde_json::Value,
+ anchor: &TailAnchor,
+ ) -> GlobalResult>
+ where
+ M: Message,
+ {
+ self.msg_ctx
+ .tail_anchor::(tags, anchor)
+ .await
+ .map_err(GlobalError::raw)
+ }
}
impl TestCtx {
@@ -352,6 +420,7 @@ impl TestCtx {
}
}
+/// Like a subscription handle for messages but for workflows. Should only be used in tests
pub struct ObserveHandle {
db: DatabaseHandle,
name: &'static str,
@@ -362,7 +431,6 @@ pub struct ObserveHandle {
impl ObserveHandle {
pub async fn next(&mut self) -> GlobalResult {
tracing::info!(name=%self.name, input=?self.input, "observing workflow");
- tracing::info!(ts=%self.ts);
let (workflow_id, create_ts) = loop {
if let Some((workflow_id, create_ts)) = self
diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs
index 7dbf319253..a6dbd78f87 100644
--- a/lib/chirp-workflow/core/src/ctx/workflow.rs
+++ b/lib/chirp-workflow/core/src/ctx/workflow.rs
@@ -258,7 +258,8 @@ impl WorkflowCtx {
self.create_ts,
self.ray_id,
A::NAME,
- );
+ )
+ .await?;
let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input))
.await
diff --git a/lib/chirp-workflow/core/src/error.rs b/lib/chirp-workflow/core/src/error.rs
index c41fa25e24..92420a1094 100644
--- a/lib/chirp-workflow/core/src/error.rs
+++ b/lib/chirp-workflow/core/src/error.rs
@@ -42,12 +42,6 @@ pub enum WorkflowError {
#[error("deserialize workflow input: {0}")]
DeserializeWorkflowOutput(serde_json::Error),
- #[error("serialize workflow tags: {0}")]
- SerializeWorkflowTags(serde_json::Error),
-
- #[error("deserialize workflow tags: {0}")]
- DeserializeWorkflowTags(serde_json::Error),
-
#[error("serialize activity input: {0}")]
SerializeActivityInput(serde_json::Error),
@@ -63,6 +57,39 @@ pub enum WorkflowError {
#[error("deserialize signal body: {0}")]
DeserializeSignalBody(serde_json::Error),
+ #[error("serialize message body: {0}")]
+ SerializeMessageBody(serde_json::Error),
+
+ #[error("serialize message: {0}")]
+ SerializeMessage(serde_json::Error),
+
+ #[error("decode message body: {0}")]
+ DeserializeMessageBody(serde_json::Error),
+
+ #[error("decode message: {0}")]
+ DeserializeMessage(serde_json::Error),
+
+ #[error("serialize message tags: {0:?}")]
+ SerializeMessageTags(cjson::Error),
+
+ #[error("create subscription: {0}")]
+ CreateSubscription(rivet_pools::prelude::nats::Error),
+
+ #[error("flush nats: {0}")]
+ FlushNats(rivet_pools::prelude::nats::Error),
+
+ #[error("subscription unsubscribed")]
+ SubscriptionUnsubscribed,
+
+ #[error("missing message data")]
+ MissingMessageData,
+
+ #[error("redis: {source}")]
+ Redis {
+ #[from]
+ source: rivet_pools::prelude::redis::RedisError,
+ },
+
#[error("no signal found: {0:?}")]
NoSignalFound(Box<[&'static str]>),
@@ -78,6 +105,9 @@ pub enum WorkflowError {
#[error("sql: {0}")]
Sqlx(sqlx::Error),
+ #[error("pools: {0}")]
+ Pools(#[from] rivet_pools::Error),
+
#[error("activity timed out")]
ActivityTimeout,
diff --git a/lib/chirp-workflow/core/src/lib.rs b/lib/chirp-workflow/core/src/lib.rs
index e3f5b66c08..9726e0724d 100644
--- a/lib/chirp-workflow/core/src/lib.rs
+++ b/lib/chirp-workflow/core/src/lib.rs
@@ -5,6 +5,7 @@ pub mod db;
mod error;
mod event;
mod executable;
+pub mod message;
pub mod operation;
pub mod prelude;
pub mod registry;
diff --git a/lib/chirp-workflow/core/src/message.rs b/lib/chirp-workflow/core/src/message.rs
new file mode 100644
index 0000000000..d7aa9f2774
--- /dev/null
+++ b/lib/chirp-workflow/core/src/message.rs
@@ -0,0 +1,132 @@
+use std::fmt::Debug;
+
+use rivet_operation::prelude::proto::chirp;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use uuid::Uuid;
+
+use crate::error::{WorkflowError, WorkflowResult};
+
+pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static {
+ const NAME: &'static str;
+ const TAIL_TTL: std::time::Duration;
+}
+
+pub fn serialize_message_nats_subject(tags_str: &str) -> String
+where
+ M: Message,
+{
+ format!("chirp.workflow.msg.{}.{}", M::NAME, tags_str,)
+}
+
+/// A message received from a Chirp subscription.
+#[derive(Debug)]
+pub struct ReceivedMessage
+where
+ M: Message,
+{
+ pub(crate) ray_id: Uuid,
+ pub(crate) req_id: Uuid,
+ pub(crate) ts: i64,
+ pub(crate) trace: Vec,
+ pub(crate) body: M,
+}
+
+impl ReceivedMessage
+where
+ M: Message,
+{
+ #[tracing::instrument(skip(buf))]
+ pub(crate) fn deserialize(buf: &[u8]) -> WorkflowResult {
+ // Deserialize the wrapper
+ let message_wrapper = Self::deserialize_wrapper(buf)?;
+
+ // Deserialize the body
+ let body = serde_json::from_str::(message_wrapper.body.get())
+ .map_err(WorkflowError::DeserializeMessageBody)?;
+
+ Ok(ReceivedMessage {
+ ray_id: message_wrapper.ray_id,
+ req_id: message_wrapper.req_id,
+ ts: message_wrapper.ts,
+ trace: message_wrapper.trace,
+ body,
+ })
+ }
+
+ // Only returns the message wrapper
+ #[tracing::instrument(skip(buf))]
+ pub(crate) fn deserialize_wrapper<'a>(buf: &'a [u8]) -> WorkflowResult> {
+ serde_json::from_slice(buf).map_err(WorkflowError::DeserializeMessage)
+ }
+}
+
+impl std::ops::Deref for ReceivedMessage
+where
+ M: Message,
+{
+ type Target = M;
+
+ fn deref(&self) -> &Self::Target {
+ &self.body
+ }
+}
+
+impl ReceivedMessage
+where
+ M: Message,
+{
+ pub fn ray_id(&self) -> Uuid {
+ self.ray_id
+ }
+
+ pub fn req_id(&self) -> Uuid {
+ self.req_id
+ }
+
+ /// Timestamp at which the message was created.
+ pub fn msg_ts(&self) -> i64 {
+ self.ts
+ }
+
+ pub fn body(&self) -> &M {
+ &self.body
+ }
+
+ pub fn trace(&self) -> &[TraceEntry] {
+ &self.trace
+ }
+}
+
+#[derive(Serialize, Deserialize)]
+pub(crate) struct MessageWrapper<'a> {
+ pub(crate) ray_id: Uuid,
+ pub(crate) req_id: Uuid,
+ pub(crate) tags: serde_json::Value,
+ pub(crate) ts: i64,
+ pub(crate) trace: Vec,
+ #[serde(borrow)]
+ pub(crate) body: &'a serde_json::value::RawValue,
+ pub(crate) allow_recursive: bool,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TraceEntry {
+ context_name: String,
+ pub(crate) req_id: Uuid,
+ ts: i64,
+}
+
+impl TryFrom for TraceEntry {
+ type Error = WorkflowError;
+
+ fn try_from(value: chirp::TraceEntry) -> WorkflowResult {
+ Ok(TraceEntry {
+ context_name: value.context_name.clone(),
+ req_id: value
+ .req_id
+ .map(|id| id.as_uuid())
+ .ok_or(WorkflowError::MissingMessageData)?,
+ ts: value.ts,
+ })
+ }
+}
diff --git a/lib/chirp-workflow/core/src/prelude.rs b/lib/chirp-workflow/core/src/prelude.rs
index a73942c6ea..3832e43628 100644
--- a/lib/chirp-workflow/core/src/prelude.rs
+++ b/lib/chirp-workflow/core/src/prelude.rs
@@ -1,5 +1,5 @@
// Internal types
-pub use chirp_client::prelude::*;
+pub use chirp_client::prelude::{msg, op, rpc, subscribe, tail_all, tail_anchor, tail_read};
pub use formatted_error;
pub use global_error::{ext::*, prelude::*};
#[doc(hidden)]
@@ -20,6 +20,7 @@ pub use crate::{
error::{WorkflowError, WorkflowResult},
executable::closure,
executable::Executable,
+ message::Message,
operation::Operation,
registry::Registry,
signal::{join_signal, Listen, Signal},
@@ -36,6 +37,7 @@ pub use async_trait;
pub use futures_util;
#[doc(hidden)]
pub use indoc::*;
+pub use uuid::Uuid;
// #[doc(hidden)]
// pub use redis;
#[doc(hidden)]
diff --git a/lib/chirp-workflow/macros/src/lib.rs b/lib/chirp-workflow/macros/src/lib.rs
index 316ea44cf0..51500abe50 100644
--- a/lib/chirp-workflow/macros/src/lib.rs
+++ b/lib/chirp-workflow/macros/src/lib.rs
@@ -21,6 +21,16 @@ impl Default for Config {
}
}
+struct MessageConfig {
+ tail_ttl: u64,
+}
+
+impl Default for MessageConfig {
+ fn default() -> Self {
+ MessageConfig { tail_ttl: 90 }
+ }
+}
+
#[proc_macro_attribute]
pub fn workflow(attr: TokenStream, item: TokenStream) -> TokenStream {
let name = parse_macro_input!(attr as OptionalIdent)
@@ -298,6 +308,44 @@ pub fn signal(attr: TokenStream, item: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}
+#[proc_macro_attribute]
+pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
+ let name = parse_macro_input!(attr as LitStr);
+ let item_struct = parse_macro_input!(item as ItemStruct);
+
+ let config = match parse_msg_config(&item_struct.attrs) {
+ Ok(x) => x,
+ Err(err) => return err.into_compile_error().into(),
+ };
+
+ let struct_ident = &item_struct.ident;
+ let tail_ttl = config.tail_ttl;
+
+ let expanded = quote! {
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ #item_struct
+
+ impl Message for #struct_ident {
+ const NAME: &'static str = #name;
+ const TAIL_TTL: std::time::Duration = std::time::Duration::from_secs(#tail_ttl);
+ }
+
+ #[async_trait::async_trait]
+ impl Listen for #struct_ident {
+ async fn listen(ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult {
+ let row = ctx.listen_any(&[Self::NAME]).await?;
+ Self::parse(&row.signal_name, row.body)
+ }
+
+ fn parse(_name: &str, body: serde_json::Value) -> chirp_workflow::prelude::WorkflowResult {
+ serde_json::from_value(body).map_err(WorkflowError::DeserializeActivityOutput)
+ }
+ }
+ };
+
+ TokenStream::from(expanded)
+}
+
#[proc_macro_attribute]
pub fn workflow_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(item as syn::ItemFn);
@@ -393,6 +441,31 @@ fn parse_config(attrs: &[syn::Attribute]) -> syn::Result {
Ok(config)
}
+fn parse_msg_config(attrs: &[syn::Attribute]) -> syn::Result {
+ let mut config = MessageConfig::default();
+
+ for attr in attrs {
+ let syn::Meta::NameValue(name_value) = &attr.meta else {
+ continue;
+ };
+
+ let ident = name_value.path.require_ident()?;
+
+ // Verify config property
+ if ident == "tail_ttl" {
+ config.tail_ttl = syn::parse::(name_value.value.to_token_stream().into())?
+ .base10_parse()?;
+ } else {
+ return Err(syn::Error::new(
+ ident.span(),
+ format!("Unknown config property `{ident}`"),
+ ));
+ }
+ }
+
+ Ok(config)
+}
+
fn parse_empty_config(attrs: &[syn::Attribute]) -> syn::Result<()> {
for attr in attrs {
let syn::Meta::NameValue(name_value) = &attr.meta else {
diff --git a/lib/connection/src/lib.rs b/lib/connection/src/lib.rs
index 9cfa32a5ad..d7198e8381 100644
--- a/lib/connection/src/lib.rs
+++ b/lib/connection/src/lib.rs
@@ -71,6 +71,10 @@ impl Connection {
self.cache.clone()
}
+ pub async fn nats(&self) -> Result {
+ self.pools.nats()
+ }
+
pub async fn crdb(&self) -> Result {
self.pools.crdb()
}
@@ -95,6 +99,10 @@ impl Connection {
self.pools.redis("ephemeral")
}
+ pub async fn redis_chirp_ephemeral(&self) -> Result {
+ self.pools.redis("ephemeral")
+ }
+
pub fn perf(&self) -> &chirp_perf::PerfCtx {
self.client.perf()
}
diff --git a/svc/Cargo.lock b/svc/Cargo.lock
index a09bb2fc4b..d1f36e21cb 100644
--- a/svc/Cargo.lock
+++ b/svc/Cargo.lock
@@ -2164,6 +2164,7 @@ dependencies = [
"async-trait",
"chirp-client",
"chirp-workflow-macros",
+ "cjson",
"formatted-error",
"futures-util",
"global-error",
@@ -2183,6 +2184,7 @@ dependencies = [
"sqlx 0.7.4 (git+https://github.com/rivet-gg/sqlx?rev=08d6e61aa0572e7ec557abbedb72cebb96e1ac5b)",
"thiserror",
"tokio",
+ "tokio-util 0.7.10",
"tracing",
"tracing-subscriber",
"uuid",