From 64b4054d47d6346cd5eb99968847d7d726bc4837 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Wed, 3 Jul 2024 00:28:47 +0000 Subject: [PATCH] feat(workflows): add messages --- lib/api-helper/build/src/anchor.rs | 13 +- lib/api-helper/build/src/macro_util.rs | 2 +- lib/chirp-workflow/core/Cargo.toml | 2 + lib/chirp-workflow/core/src/ctx/activity.rs | 39 +- lib/chirp-workflow/core/src/ctx/api.rs | 61 +- lib/chirp-workflow/core/src/ctx/message.rs | 607 ++++++++++++++++++++ lib/chirp-workflow/core/src/ctx/mod.rs | 2 + lib/chirp-workflow/core/src/ctx/test.rs | 74 ++- lib/chirp-workflow/core/src/ctx/workflow.rs | 3 +- lib/chirp-workflow/core/src/error.rs | 42 +- lib/chirp-workflow/core/src/lib.rs | 1 + lib/chirp-workflow/core/src/message.rs | 132 +++++ lib/chirp-workflow/core/src/prelude.rs | 4 +- lib/chirp-workflow/macros/src/lib.rs | 73 +++ lib/connection/src/lib.rs | 8 + svc/Cargo.lock | 2 + 16 files changed, 1041 insertions(+), 24 deletions(-) create mode 100644 lib/chirp-workflow/core/src/ctx/message.rs create mode 100644 lib/chirp-workflow/core/src/message.rs diff --git a/lib/api-helper/build/src/anchor.rs b/lib/api-helper/build/src/anchor.rs index 3222290ec..582a212a0 100644 --- a/lib/api-helper/build/src/anchor.rs +++ b/lib/api-helper/build/src/anchor.rs @@ -25,7 +25,18 @@ impl WatchIndexQuery { /// Converts the `WatchIndexQuery` into a `TailAnchor` for use with the Chirp client. pub fn to_consumer(self) -> Result, ClientError> { if let Some(watch_index) = self.watch_index { - Ok(Some(chirp_client::TailAnchor { + Ok(Some(TailAnchor { + start_time: watch_index.parse()?, + })) + } else { + Ok(None) + } + } + + /// Converts the `WatchIndexQuery` into a `TailAnchor` for use with Chirp workflows. + pub fn to_workflow(self) -> Result, ClientError> { + if let Some(watch_index) = self.watch_index { + Ok(Some(chirp_workflow::ctx::message::TailAnchor { start_time: watch_index.parse()?, })) } else { diff --git a/lib/api-helper/build/src/macro_util.rs b/lib/api-helper/build/src/macro_util.rs index 4e000213e..83f760741 100644 --- a/lib/api-helper/build/src/macro_util.rs +++ b/lib/api-helper/build/src/macro_util.rs @@ -325,7 +325,7 @@ pub async fn __with_ctx( ); let conn = rivet_connection::Connection::new(client, pools.clone(), cache.clone()); let db = chirp_workflow::compat::db_from_pools(&pools).await?; - let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name); + let internal_ctx = ApiCtx::new(db, conn, req_id, ray_id, ts, svc_name).await?; // Create auth let rate_limit_ctx = AuthRateLimitCtx { diff --git a/lib/chirp-workflow/core/Cargo.toml b/lib/chirp-workflow/core/Cargo.toml index 28ba4c9d5..8f5bc697d 100644 --- a/lib/chirp-workflow/core/Cargo.toml +++ b/lib/chirp-workflow/core/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" async-trait = "0.1.80" chirp-client = { path = "../../chirp/client" } chirp-workflow-macros = { path = "../macros" } +cjson = "0.1" formatted-error = { path = "../../formatted-error" } futures-util = "0.3" global-error = { path = "../../global-error" } @@ -27,6 +28,7 @@ serde = { version = "1.0.198", features = ["derive"] } serde_json = "1.0.116" thiserror = "1.0.59" tokio = { version = "1.37.0", features = ["full"] } +tokio-util = "0.7" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } diff --git a/lib/chirp-workflow/core/src/ctx/activity.rs b/lib/chirp-workflow/core/src/ctx/activity.rs index 391b6a59b..858f85414 100644 --- a/lib/chirp-workflow/core/src/ctx/activity.rs +++ b/lib/chirp-workflow/core/src/ctx/activity.rs @@ -2,7 +2,12 @@ use global_error::{GlobalError, GlobalResult}; use rivet_pools::prelude::*; use uuid::Uuid; -use crate::{ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, WorkflowError}; +use crate::{ + ctx::{MessageCtx, OperationCtx}, + error::{WorkflowError, WorkflowResult}, + message::Message, + DatabaseHandle, Operation, OperationInput, +}; #[derive(Clone)] pub struct ActivityCtx { @@ -14,20 +19,21 @@ pub struct ActivityCtx { db: DatabaseHandle, conn: rivet_connection::Connection, + msg_ctx: MessageCtx, // Backwards compatibility op_ctx: rivet_operation::OperationContext<()>, } impl ActivityCtx { - pub fn new( + pub async fn new( workflow_id: Uuid, db: DatabaseHandle, conn: &rivet_connection::Connection, activity_create_ts: i64, ray_id: Uuid, name: &'static str, - ) -> Self { + ) -> WorkflowResult { let ts = rivet_util::timestamp::now(); let req_id = Uuid::new_v4(); let conn = conn.wrap(req_id, ray_id, name); @@ -43,7 +49,9 @@ impl ActivityCtx { ); op_ctx.from_workflow = true; - ActivityCtx { + let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?; + + Ok(ActivityCtx { workflow_id, ray_id, name, @@ -51,7 +59,8 @@ impl ActivityCtx { db, conn, op_ctx, - } + msg_ctx, + }) } } @@ -86,6 +95,26 @@ impl ActivityCtx { .await .map_err(GlobalError::raw) } + + pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> + where + M: Message, + { + self.msg_ctx + .message(tags, body) + .await + .map_err(GlobalError::raw) + } + + pub async fn msg_wait(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> + where + M: Message, + { + self.msg_ctx + .message_wait(tags, body) + .await + .map_err(GlobalError::raw) + } } impl ActivityCtx { diff --git a/lib/chirp-workflow/core/src/ctx/api.rs b/lib/chirp-workflow/core/src/ctx/api.rs index 8c11c1782..222872433 100644 --- a/lib/chirp-workflow/core/src/ctx/api.rs +++ b/lib/chirp-workflow/core/src/ctx/api.rs @@ -6,8 +6,13 @@ use serde::Serialize; use uuid::Uuid; use crate::{ - ctx::OperationCtx, DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError, - WorkflowInput, + ctx::{ + message::{SubscriptionHandle, TailAnchor, TailAnchorResponse}, + MessageCtx, OperationCtx, + }, + error::WorkflowResult, + message::{Message, ReceivedMessage}, + DatabaseHandle, Operation, OperationInput, Signal, Workflow, WorkflowError, WorkflowInput, }; pub const WORKFLOW_TIMEOUT: Duration = Duration::from_secs(60); @@ -20,20 +25,21 @@ pub struct ApiCtx { db: DatabaseHandle, conn: rivet_connection::Connection, + msg_ctx: MessageCtx, // Backwards compatibility op_ctx: rivet_operation::OperationContext<()>, } impl ApiCtx { - pub fn new( + pub async fn new( db: DatabaseHandle, conn: rivet_connection::Connection, req_id: Uuid, ray_id: Uuid, ts: i64, name: &'static str, - ) -> Self { + ) -> WorkflowResult { let op_ctx = rivet_operation::OperationContext::new( name.to_string(), std::time::Duration::from_secs(60), @@ -45,14 +51,17 @@ impl ApiCtx { (), ); - ApiCtx { + let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?; + + Ok(ApiCtx { ray_id, name, ts, db, conn, op_ctx, - } + msg_ctx, + }) } } @@ -243,6 +252,46 @@ impl ApiCtx { .map_err(WorkflowError::OperationFailure) .map_err(GlobalError::raw) } + + pub async fn subscribe( + &self, + tags: &serde_json::Value, + ) -> GlobalResult> + where + M: Message, + { + self.msg_ctx + .subscribe::(tags) + .await + .map_err(GlobalError::raw) + } + + pub async fn tail_read( + &self, + tags: serde_json::Value, + ) -> GlobalResult>> + where + M: Message, + { + self.msg_ctx + .tail_read::(tags) + .await + .map_err(GlobalError::raw) + } + + pub async fn tail_anchor( + &self, + tags: serde_json::Value, + anchor: &TailAnchor, + ) -> GlobalResult> + where + M: Message, + { + self.msg_ctx + .tail_anchor::(tags, anchor) + .await + .map_err(GlobalError::raw) + } } impl ApiCtx { diff --git a/lib/chirp-workflow/core/src/ctx/message.rs b/lib/chirp-workflow/core/src/ctx/message.rs new file mode 100644 index 000000000..3201adf3d --- /dev/null +++ b/lib/chirp-workflow/core/src/ctx/message.rs @@ -0,0 +1,607 @@ +use std::{ + fmt::{self, Debug}, + marker::PhantomData, + sync::Arc, +}; + +use futures_util::StreamExt; +use rivet_pools::prelude::redis::AsyncCommands; +use rivet_pools::prelude::*; +use tokio_util::sync::{CancellationToken, DropGuard}; +use tracing::Instrument; +use uuid::Uuid; + +use crate::{ + error::{WorkflowError, WorkflowResult}, + message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry}, +}; + +/// Time (in ms) that we subtract from the anchor grace period in order to +/// validate that there is not a race condition between the anchor validity and +/// writing to Redis. +const TAIL_ANCHOR_VALID_GRACE: i64 = 250; + +#[derive(Clone)] +pub struct MessageCtx { + /// The connection used to communicate with NATS. + nats: NatsPool, + + /// Used for writing to message tails. This cache is ephemeral. + redis_chirp_ephemeral: RedisPool, + + req_id: Uuid, + ray_id: Uuid, + trace: Vec, +} + +impl MessageCtx { + pub async fn new( + conn: &rivet_connection::Connection, + req_id: Uuid, + ray_id: Uuid, + ) -> WorkflowResult { + Ok(MessageCtx { + nats: conn.nats().await?, + redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?, + req_id, + ray_id, + trace: conn + .chirp() + .trace() + .iter() + .cloned() + .map(TryInto::try_into) + .collect::>>()?, + }) + } +} + +// MARK: Publishing messages +impl MessageCtx { + /// Publishes a message to NATS and to a durable message stream if a topic is + /// set. + /// + /// Use `subscribe` to consume these messages ephemerally and `tail` to read + /// the most recently sent message. + /// + /// This spawns a background task that calls `message_wait` internally and does not wait for the message to + /// finish publishing. This is done since there are very few cases where a + /// service should need to wait or fail if a message does not publish + /// successfully. + #[tracing::instrument(err, skip_all, fields(message = M::NAME))] + pub async fn message(&self, tags: serde_json::Value, message_body: M) -> WorkflowResult<()> + where + M: Message, + { + let client = self.clone(); + let spawn_res = tokio::task::Builder::new() + .name("chirp_workflow::message_async") + .spawn( + async move { + match client.message_wait::(tags, message_body).await { + Ok(_) => {} + Err(err) => { + tracing::error!(?err, "failed to publish message"); + } + } + } + .in_current_span(), + ); + if let Err(err) = spawn_res { + tracing::error!(?err, "failed to spawn message_async task"); + } + + Ok(()) + } + + /// Same as `message` but waits for the message to successfully publish. + /// + /// This is useful in scenarios where we need to publish a large amount of + /// messages at once so we put the messages in a queue instead of submitting + /// a large number of tasks to Tokio at once. + #[tracing::instrument(err, skip_all, fields(message = M::NAME))] + pub async fn message_wait( + &self, + tags: serde_json::Value, + message_body: M, + ) -> WorkflowResult<()> + where + M: Message, + { + let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?; + let nats_subject = message::serialize_message_nats_subject::(&tags_str); + let duration_since_epoch = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|err| unreachable!("time is broken: {}", err)); + let ts = duration_since_epoch.as_millis() as i64; + + // Serialize the body + let body_buf = + serde_json::to_string(&message_body).map_err(WorkflowError::SerializeMessage)?; + let body_buf_len = body_buf.len(); + let body_buf = serde_json::value::RawValue::from_string(body_buf) + .map_err(WorkflowError::SerializeMessage)?; + + // Serialize message + let req_id = Uuid::new_v4(); + let message = MessageWrapper { + req_id: req_id, + ray_id: self.ray_id, + tags: tags.clone(), + ts, + trace: self.trace.clone(), + allow_recursive: false, // TODO: + body: &body_buf, + }; + let message_buf = serde_json::to_vec(&message).map_err(WorkflowError::SerializeMessage)?; + + // TODO: opts.dont_log_body + if true { + tracing::info!( + %nats_subject, + body_bytes = ?body_buf_len, + message_bytes = ?message_buf.len(), + "publish message" + ); + } else { + tracing::info!( + %nats_subject, + ?message_body, + body_bytes = ?body_buf_len, + message_bytes = ?message_buf.len(), + "publish message" + ); + } + + // Write to Redis and NATS. + // + // It's important to write to the stream as fast as possible in order to + // ensure messages are handled quickly. + let message_buf = Arc::new(message_buf); + self.message_write_redis::(&tags_str, message_buf.clone(), req_id, ts) + .await; + self.message_publish_nats::(&nats_subject, message_buf) + .await; + + Ok(()) + } + + /// Writes a message to a Redis durable stream and tails. + #[tracing::instrument(level = "debug", skip_all)] + async fn message_write_redis( + &self, + tags_str: &str, + message_buf: Arc>, + req_id: Uuid, + ts: i64, + ) where + M: Message, + { + // Write tail + let tail_key = redis_keys::message_tail::(tags_str); + + let mut pipe = redis::pipe(); + + // Save message + pipe.hset( + &tail_key, + redis_keys::message_tail::REQUEST_ID, + req_id.to_string(), + ) + .ignore(); + pipe.hset(&tail_key, redis_keys::message_tail::TS, ts) + .ignore(); + pipe.hset( + &tail_key, + redis_keys::message_tail::BODY, + message_buf.as_slice(), + ) + .ignore(); + + let mut conn = self.redis_chirp_ephemeral.clone(); + match pipe.query_async::<_, ()>(&mut conn).await { + Ok(_) => { + tracing::debug!("write to redis tail succeeded"); + } + Err(err) => { + tracing::error!(?err, "failed to write to redis tail"); + } + } + + // Automatically expire + pipe.expire(&tail_key, M::TAIL_TTL.as_millis() as usize) + .ignore(); + } + + /// Publishes the message to NATS. + #[tracing::instrument(level = "debug", skip_all)] + async fn message_publish_nats(&self, nats_subject: &str, message_buf: Arc>) + where + M: Message, + { + // Publish message to NATS. Do this after a successful write to + // Redis in order to verify that tailing messages doesn't end up in a + // race condition that misses a message from the database. + // + // Infinite backoff since we want to wait until the service reboots. + let mut backoff = rivet_util::Backoff::default_infinite(); + loop { + // Ignore for infinite backoff + backoff.tick().await; + + let nats_subject = nats_subject.to_owned(); + + tracing::trace!( + %nats_subject, + message_len = message_buf.len(), + "publishing message to nats" + ); + if let Err(err) = self + .nats + .publish(nats_subject.clone(), (*message_buf).clone().into()) + .await + { + tracing::warn!(?err, "publish message failed, trying again"); + continue; + } + + tracing::debug!("publish nats message succeeded"); + break; + } + } +} + +// MARK: Subscriptions +impl MessageCtx { + /// Listens for Chirp workflow messages globally on NATS. + #[tracing::instrument(level = "debug", err, skip_all)] + pub async fn subscribe( + &self, + tags: &serde_json::Value, + ) -> WorkflowResult> + where + M: Message, + { + self.subscribe_opt::(SubscribeOpts { + tags, + flush_nats: true, + }) + .await + } + + /// Listens for Chirp workflow messages globally on NATS. + #[tracing::instrument(err, skip_all, fields(message = M::NAME))] + pub async fn subscribe_opt( + &self, + opts: SubscribeOpts<'_>, + ) -> WorkflowResult> + where + M: Message, + { + let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?; + let nats_subject = message::serialize_message_nats_subject::(&tags_str); + + // Create subscription and flush immediately. + tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription"); + let subscription = self + .nats + .subscribe(nats_subject.clone()) + .await + .map_err(|x| WorkflowError::CreateSubscription(x.into()))?; + if opts.flush_nats { + self.nats + .flush() + .await + .map_err(|x| WorkflowError::FlushNats(x.into()))?; + } + + // Return handle + let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id); + Ok(subscription) + } + + /// Reads the tail message of a stream without waiting for a message. + #[tracing::instrument(err, skip_all, fields(message = M::NAME))] + pub async fn tail_read( + &self, + tags: serde_json::Value, + ) -> WorkflowResult>> + where + M: Message, + { + let mut conn = self.redis_chirp_ephemeral.clone(); + + // Fetch message + let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?; + let tail_key = redis_keys::message_tail::(&tags_str); + let message_buf = conn + .hget::<_, _, Option>>(&tail_key, redis_keys::message_tail::BODY) + .await?; + + // Deserialize message + let message = if let Some(message_buf) = message_buf { + let message = ReceivedMessage::::deserialize(message_buf.as_slice())?; + tracing::info!(?message, "immediate read tail message"); + Some(message) + } else { + tracing::info!("no tail message to read"); + None + }; + + Ok(message) + } + + /// Used by API services to tail an message (by start time) after a given timestamp. + /// + /// Because this waits indefinitely until next message, it is recommended to use this inside + /// of a `rivet_util::macros::select_with_timeout!` block: + /// ```rust + /// use rivet_util as util; + /// + /// let message_sub = tail_anchor!([ctx, anchor] message_test()); + /// + /// // Consumes anchor or times out after 1 minute + /// util::macros::select_with_timeout!( + /// message = message_sub => { + /// let _message = message?; + /// } + /// ); + /// ``` + #[tracing::instrument(err, skip_all, fields(message = M::NAME))] + pub async fn tail_anchor( + &self, + tags: serde_json::Value, + anchor: &TailAnchor, + ) -> WorkflowResult> + where + M: Message, + { + // Validate anchor is valid + if !anchor.is_valid(M::TAIL_TTL.as_millis() as i64) { + return Ok(TailAnchorResponse::AnchorExpired); + } + + // Create subscription. Do this before reading from the log in order to + // ensure consistency. + // + // Leave flush enabled in order to ensure that subscription is + // registered with NATS before continuing. + let mut sub = self.subscribe(&tags).await?; + + // Read the tail log + let tail_read = self.tail_read(tags).await?; + + // Check if valid or wait for subscription + let (message, source) = match tail_read { + Some(message) if message.ts > anchor.start_time => (message, "tail_read"), + _ => { + // Wait for next message if tail not present + let message = sub.next().await?; + (message, "subscription") + } + }; + + tracing::info!(?message, %source, ?anchor, "read tail message"); + + Ok(TailAnchorResponse::Message(message)) + } +} + +#[derive(Debug)] +pub struct SubscribeOpts<'a> { + pub tags: &'a serde_json::Value, + pub flush_nats: bool, +} + +/// Used to receive messages from other contexts. +/// +/// This subscription will automatically close when dropped. +pub struct SubscriptionHandle +where + M: Message, +{ + _message: PhantomData, + _guard: DropGuard, + subject: String, + subscription: nats::Subscriber, + req_id: Uuid, +} + +impl Debug for SubscriptionHandle +where + M: Message, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SubscriptionHandle") + .field("subject", &self.subject) + .finish() + } +} + +impl SubscriptionHandle +where + M: Message, +{ + #[tracing::instrument(level = "debug", skip_all)] + fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self { + let token = CancellationToken::new(); + + { + let token = token.clone(); + let spawn_res = tokio::task::Builder::new() + .name("chirp_workflow::message_wait_drop") + .spawn( + async move { + token.cancelled().await; + + tracing::trace!("closing subscription"); + + // We don't worry about calling `subscription.drain()` since the + // entire subscription wrapper is dropped anyways, so we can't + // call `.recv()`. + } + .instrument(tracing::trace_span!("subscription_wait_drop")), + ); + if let Err(err) = spawn_res { + tracing::error!(?err, "failed to spawn message_wait_drop task"); + } + } + + SubscriptionHandle { + _message: Default::default(), + _guard: token.drop_guard(), + subject, + subscription, + req_id, + } + } + + /// Waits for the next message in the subscription. + /// + /// This future can be safely dropped. + #[tracing::instrument] + pub async fn next(&mut self) -> WorkflowResult> { + self.next_inner(false).await + } + + // TODO: Add a full config struct to pass to `next` that impl's `Default` + /// Waits for the next message in the subscription that originates from the + /// parent request ID via trace. + /// + /// This future can be safely dropped. + #[tracing::instrument] + pub async fn next_with_trace( + &mut self, + filter_trace: bool, + ) -> WorkflowResult> { + self.next_inner(filter_trace).await + } + + /// This future can be safely dropped. + #[tracing::instrument(level = "trace")] + async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult> { + tracing::info!("waiting for message"); + + loop { + // Poll the subscription. + // + // Use blocking threads instead of `try_next`, since I'm not sure + // try_next works as intended. + let nats_message = match self.subscription.next().await { + Some(x) => x, + None => { + tracing::debug!("unsubscribed"); + return Err(WorkflowError::SubscriptionUnsubscribed); + } + }; + + if filter_trace { + let message_wrapper = + ReceivedMessage::::deserialize_wrapper(&nats_message.payload[..])?; + + // Check if the message trace stack originates from this client + // + // We intentionally use the request ID instead of just checking the ray ID because + // there may be multiple calls to `message_with_subscribe` within the same ray. + // Explicitly checking the parent request ensures the response is unique to this + // message. + if message_wrapper + .trace + .iter() + .rev() + .any(|trace_entry| trace_entry.req_id == self.req_id) + { + let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?; + tracing::info!(?message, "received message"); + + return Ok(message); + } + } else { + let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?; + tracing::info!(?message, "received message"); + + return Ok(message); + } + + // Message not from parent, continue with loop + } + } + + /// Converts the subscription in to a stream. + pub fn into_stream( + self, + ) -> impl futures_util::Stream>> { + futures_util::stream::try_unfold(self, |mut sub| async move { + let message = sub.next().await?; + Ok(Some((message, sub))) + }) + } +} + +#[derive(Debug, Clone)] +pub struct TailAnchor { + pub start_time: i64, +} + +impl TailAnchor { + pub fn new(start_time: i64) -> Self { + TailAnchor { start_time } + } + + pub fn is_valid(&self, ttl: i64) -> bool { + self.start_time > rivet_util::timestamp::now() - ttl * 1000 - TAIL_ANCHOR_VALID_GRACE + } +} + +#[derive(Debug)] +pub enum TailAnchorResponse +where + M: Message + Debug, +{ + Message(ReceivedMessage), + + /// Anchor was older than the TTL of the message. + AnchorExpired, +} + +impl TailAnchorResponse +where + M: Message + Debug, +{ + /// Returns the timestamp of the message if exists. + /// + /// Useful for endpoints that need to return a new anchor. + pub fn msg_ts(&self) -> Option { + match self { + Self::Message(msg) => Some(msg.msg_ts()), + Self::AnchorExpired => None, + } + } +} + +mod redis_keys { + use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + }; + + use crate::message::Message; + + /// HASH + pub fn message_tail(tags_str: &str) -> String + where + M: Message, + { + // Get hash of the tags + let mut hasher = DefaultHasher::new(); + tags_str.hash(&mut hasher); + + format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish()) + } + + pub mod message_tail { + pub const REQUEST_ID: &str = "r"; + pub const TS: &str = "t"; + pub const BODY: &str = "b"; + } +} diff --git a/lib/chirp-workflow/core/src/ctx/mod.rs b/lib/chirp-workflow/core/src/ctx/mod.rs index 8d75c427d..00dc5bab1 100644 --- a/lib/chirp-workflow/core/src/ctx/mod.rs +++ b/lib/chirp-workflow/core/src/ctx/mod.rs @@ -1,10 +1,12 @@ mod activity; pub(crate) mod api; +pub mod message; mod operation; mod test; mod workflow; pub use activity::ActivityCtx; pub use api::ApiCtx; +pub use message::MessageCtx; pub use operation::OperationCtx; pub use test::TestCtx; pub use workflow::WorkflowCtx; diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index a018f407e..1c991d6ec 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -5,8 +5,13 @@ use tokio::time::Duration; use uuid::Uuid; use crate::{ - util, DatabaseHandle, DatabasePostgres, Operation, OperationCtx, OperationInput, Signal, - Workflow, WorkflowError, WorkflowInput, + ctx::{ + message::{SubscriptionHandle, TailAnchor, TailAnchorResponse}, + MessageCtx, OperationCtx, + }, + message::{Message, ReceivedMessage}, + util, DatabaseHandle, DatabasePostgres, Operation, OperationInput, Signal, Workflow, + WorkflowError, WorkflowInput, }; pub struct TestCtx { @@ -17,6 +22,7 @@ pub struct TestCtx { db: DatabaseHandle, conn: rivet_connection::Connection, + msg_ctx: MessageCtx, // Backwards compatibility op_ctx: rivet_operation::OperationContext<()>, @@ -60,6 +66,7 @@ impl TestCtx { ); let db = DatabasePostgres::from_pool(pools.crdb().unwrap()); + let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap(); TestCtx { name: service_name, @@ -68,6 +75,7 @@ impl TestCtx { db, conn, op_ctx, + msg_ctx, } } } @@ -262,6 +270,66 @@ impl TestCtx { .map_err(WorkflowError::OperationFailure) .map_err(GlobalError::raw) } + + pub async fn msg(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> + where + M: Message, + { + self.msg_ctx + .message(tags, body) + .await + .map_err(GlobalError::raw) + } + + pub async fn msg_wait(&self, tags: serde_json::Value, body: M) -> GlobalResult<()> + where + M: Message, + { + self.msg_ctx + .message_wait(tags, body) + .await + .map_err(GlobalError::raw) + } + + pub async fn subscribe( + &self, + tags: &serde_json::Value, + ) -> GlobalResult> + where + M: Message, + { + self.msg_ctx + .subscribe::(tags) + .await + .map_err(GlobalError::raw) + } + + pub async fn tail_read( + &self, + tags: serde_json::Value, + ) -> GlobalResult>> + where + M: Message, + { + self.msg_ctx + .tail_read::(tags) + .await + .map_err(GlobalError::raw) + } + + pub async fn tail_anchor( + &self, + tags: serde_json::Value, + anchor: &TailAnchor, + ) -> GlobalResult> + where + M: Message, + { + self.msg_ctx + .tail_anchor::(tags, anchor) + .await + .map_err(GlobalError::raw) + } } impl TestCtx { @@ -352,6 +420,7 @@ impl TestCtx { } } +/// Like a subscription handle for messages but for workflows. Should only be used in tests pub struct ObserveHandle { db: DatabaseHandle, name: &'static str, @@ -362,7 +431,6 @@ pub struct ObserveHandle { impl ObserveHandle { pub async fn next(&mut self) -> GlobalResult { tracing::info!(name=%self.name, input=?self.input, "observing workflow"); - tracing::info!(ts=%self.ts); let (workflow_id, create_ts) = loop { if let Some((workflow_id, create_ts)) = self diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index 7dbf31925..a6dbd78f8 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -258,7 +258,8 @@ impl WorkflowCtx { self.create_ts, self.ray_id, A::NAME, - ); + ) + .await?; let res = tokio::time::timeout(A::TIMEOUT, A::run(&ctx, input)) .await diff --git a/lib/chirp-workflow/core/src/error.rs b/lib/chirp-workflow/core/src/error.rs index c41fa25e2..92420a109 100644 --- a/lib/chirp-workflow/core/src/error.rs +++ b/lib/chirp-workflow/core/src/error.rs @@ -42,12 +42,6 @@ pub enum WorkflowError { #[error("deserialize workflow input: {0}")] DeserializeWorkflowOutput(serde_json::Error), - #[error("serialize workflow tags: {0}")] - SerializeWorkflowTags(serde_json::Error), - - #[error("deserialize workflow tags: {0}")] - DeserializeWorkflowTags(serde_json::Error), - #[error("serialize activity input: {0}")] SerializeActivityInput(serde_json::Error), @@ -63,6 +57,39 @@ pub enum WorkflowError { #[error("deserialize signal body: {0}")] DeserializeSignalBody(serde_json::Error), + #[error("serialize message body: {0}")] + SerializeMessageBody(serde_json::Error), + + #[error("serialize message: {0}")] + SerializeMessage(serde_json::Error), + + #[error("decode message body: {0}")] + DeserializeMessageBody(serde_json::Error), + + #[error("decode message: {0}")] + DeserializeMessage(serde_json::Error), + + #[error("serialize message tags: {0:?}")] + SerializeMessageTags(cjson::Error), + + #[error("create subscription: {0}")] + CreateSubscription(rivet_pools::prelude::nats::Error), + + #[error("flush nats: {0}")] + FlushNats(rivet_pools::prelude::nats::Error), + + #[error("subscription unsubscribed")] + SubscriptionUnsubscribed, + + #[error("missing message data")] + MissingMessageData, + + #[error("redis: {source}")] + Redis { + #[from] + source: rivet_pools::prelude::redis::RedisError, + }, + #[error("no signal found: {0:?}")] NoSignalFound(Box<[&'static str]>), @@ -78,6 +105,9 @@ pub enum WorkflowError { #[error("sql: {0}")] Sqlx(sqlx::Error), + #[error("pools: {0}")] + Pools(#[from] rivet_pools::Error), + #[error("activity timed out")] ActivityTimeout, diff --git a/lib/chirp-workflow/core/src/lib.rs b/lib/chirp-workflow/core/src/lib.rs index e3f5b66c0..9726e0724 100644 --- a/lib/chirp-workflow/core/src/lib.rs +++ b/lib/chirp-workflow/core/src/lib.rs @@ -5,6 +5,7 @@ pub mod db; mod error; mod event; mod executable; +pub mod message; pub mod operation; pub mod prelude; pub mod registry; diff --git a/lib/chirp-workflow/core/src/message.rs b/lib/chirp-workflow/core/src/message.rs new file mode 100644 index 000000000..d7aa9f277 --- /dev/null +++ b/lib/chirp-workflow/core/src/message.rs @@ -0,0 +1,132 @@ +use std::fmt::Debug; + +use rivet_operation::prelude::proto::chirp; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use uuid::Uuid; + +use crate::error::{WorkflowError, WorkflowResult}; + +pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static { + const NAME: &'static str; + const TAIL_TTL: std::time::Duration; +} + +pub fn serialize_message_nats_subject(tags_str: &str) -> String +where + M: Message, +{ + format!("chirp.workflow.msg.{}.{}", M::NAME, tags_str,) +} + +/// A message received from a Chirp subscription. +#[derive(Debug)] +pub struct ReceivedMessage +where + M: Message, +{ + pub(crate) ray_id: Uuid, + pub(crate) req_id: Uuid, + pub(crate) ts: i64, + pub(crate) trace: Vec, + pub(crate) body: M, +} + +impl ReceivedMessage +where + M: Message, +{ + #[tracing::instrument(skip(buf))] + pub(crate) fn deserialize(buf: &[u8]) -> WorkflowResult { + // Deserialize the wrapper + let message_wrapper = Self::deserialize_wrapper(buf)?; + + // Deserialize the body + let body = serde_json::from_str::(message_wrapper.body.get()) + .map_err(WorkflowError::DeserializeMessageBody)?; + + Ok(ReceivedMessage { + ray_id: message_wrapper.ray_id, + req_id: message_wrapper.req_id, + ts: message_wrapper.ts, + trace: message_wrapper.trace, + body, + }) + } + + // Only returns the message wrapper + #[tracing::instrument(skip(buf))] + pub(crate) fn deserialize_wrapper<'a>(buf: &'a [u8]) -> WorkflowResult> { + serde_json::from_slice(buf).map_err(WorkflowError::DeserializeMessage) + } +} + +impl std::ops::Deref for ReceivedMessage +where + M: Message, +{ + type Target = M; + + fn deref(&self) -> &Self::Target { + &self.body + } +} + +impl ReceivedMessage +where + M: Message, +{ + pub fn ray_id(&self) -> Uuid { + self.ray_id + } + + pub fn req_id(&self) -> Uuid { + self.req_id + } + + /// Timestamp at which the message was created. + pub fn msg_ts(&self) -> i64 { + self.ts + } + + pub fn body(&self) -> &M { + &self.body + } + + pub fn trace(&self) -> &[TraceEntry] { + &self.trace + } +} + +#[derive(Serialize, Deserialize)] +pub(crate) struct MessageWrapper<'a> { + pub(crate) ray_id: Uuid, + pub(crate) req_id: Uuid, + pub(crate) tags: serde_json::Value, + pub(crate) ts: i64, + pub(crate) trace: Vec, + #[serde(borrow)] + pub(crate) body: &'a serde_json::value::RawValue, + pub(crate) allow_recursive: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TraceEntry { + context_name: String, + pub(crate) req_id: Uuid, + ts: i64, +} + +impl TryFrom for TraceEntry { + type Error = WorkflowError; + + fn try_from(value: chirp::TraceEntry) -> WorkflowResult { + Ok(TraceEntry { + context_name: value.context_name.clone(), + req_id: value + .req_id + .map(|id| id.as_uuid()) + .ok_or(WorkflowError::MissingMessageData)?, + ts: value.ts, + }) + } +} diff --git a/lib/chirp-workflow/core/src/prelude.rs b/lib/chirp-workflow/core/src/prelude.rs index a73942c6e..3832e4362 100644 --- a/lib/chirp-workflow/core/src/prelude.rs +++ b/lib/chirp-workflow/core/src/prelude.rs @@ -1,5 +1,5 @@ // Internal types -pub use chirp_client::prelude::*; +pub use chirp_client::prelude::{msg, op, rpc, subscribe, tail_all, tail_anchor, tail_read}; pub use formatted_error; pub use global_error::{ext::*, prelude::*}; #[doc(hidden)] @@ -20,6 +20,7 @@ pub use crate::{ error::{WorkflowError, WorkflowResult}, executable::closure, executable::Executable, + message::Message, operation::Operation, registry::Registry, signal::{join_signal, Listen, Signal}, @@ -36,6 +37,7 @@ pub use async_trait; pub use futures_util; #[doc(hidden)] pub use indoc::*; +pub use uuid::Uuid; // #[doc(hidden)] // pub use redis; #[doc(hidden)] diff --git a/lib/chirp-workflow/macros/src/lib.rs b/lib/chirp-workflow/macros/src/lib.rs index 316ea44cf..51500abe5 100644 --- a/lib/chirp-workflow/macros/src/lib.rs +++ b/lib/chirp-workflow/macros/src/lib.rs @@ -21,6 +21,16 @@ impl Default for Config { } } +struct MessageConfig { + tail_ttl: u64, +} + +impl Default for MessageConfig { + fn default() -> Self { + MessageConfig { tail_ttl: 90 } + } +} + #[proc_macro_attribute] pub fn workflow(attr: TokenStream, item: TokenStream) -> TokenStream { let name = parse_macro_input!(attr as OptionalIdent) @@ -298,6 +308,44 @@ pub fn signal(attr: TokenStream, item: TokenStream) -> TokenStream { TokenStream::from(expanded) } +#[proc_macro_attribute] +pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream { + let name = parse_macro_input!(attr as LitStr); + let item_struct = parse_macro_input!(item as ItemStruct); + + let config = match parse_msg_config(&item_struct.attrs) { + Ok(x) => x, + Err(err) => return err.into_compile_error().into(), + }; + + let struct_ident = &item_struct.ident; + let tail_ttl = config.tail_ttl; + + let expanded = quote! { + #[derive(Debug, serde::Serialize, serde::Deserialize)] + #item_struct + + impl Message for #struct_ident { + const NAME: &'static str = #name; + const TAIL_TTL: std::time::Duration = std::time::Duration::from_secs(#tail_ttl); + } + + #[async_trait::async_trait] + impl Listen for #struct_ident { + async fn listen(ctx: &mut chirp_workflow::prelude::WorkflowCtx) -> chirp_workflow::prelude::WorkflowResult { + let row = ctx.listen_any(&[Self::NAME]).await?; + Self::parse(&row.signal_name, row.body) + } + + fn parse(_name: &str, body: serde_json::Value) -> chirp_workflow::prelude::WorkflowResult { + serde_json::from_value(body).map_err(WorkflowError::DeserializeActivityOutput) + } + } + }; + + TokenStream::from(expanded) +} + #[proc_macro_attribute] pub fn workflow_test(_attr: TokenStream, item: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(item as syn::ItemFn); @@ -393,6 +441,31 @@ fn parse_config(attrs: &[syn::Attribute]) -> syn::Result { Ok(config) } +fn parse_msg_config(attrs: &[syn::Attribute]) -> syn::Result { + let mut config = MessageConfig::default(); + + for attr in attrs { + let syn::Meta::NameValue(name_value) = &attr.meta else { + continue; + }; + + let ident = name_value.path.require_ident()?; + + // Verify config property + if ident == "tail_ttl" { + config.tail_ttl = syn::parse::(name_value.value.to_token_stream().into())? + .base10_parse()?; + } else { + return Err(syn::Error::new( + ident.span(), + format!("Unknown config property `{ident}`"), + )); + } + } + + Ok(config) +} + fn parse_empty_config(attrs: &[syn::Attribute]) -> syn::Result<()> { for attr in attrs { let syn::Meta::NameValue(name_value) = &attr.meta else { diff --git a/lib/connection/src/lib.rs b/lib/connection/src/lib.rs index 9cfa32a5a..d7198e838 100644 --- a/lib/connection/src/lib.rs +++ b/lib/connection/src/lib.rs @@ -71,6 +71,10 @@ impl Connection { self.cache.clone() } + pub async fn nats(&self) -> Result { + self.pools.nats() + } + pub async fn crdb(&self) -> Result { self.pools.crdb() } @@ -95,6 +99,10 @@ impl Connection { self.pools.redis("ephemeral") } + pub async fn redis_chirp_ephemeral(&self) -> Result { + self.pools.redis("ephemeral") + } + pub fn perf(&self) -> &chirp_perf::PerfCtx { self.client.perf() } diff --git a/svc/Cargo.lock b/svc/Cargo.lock index 4686de5b8..275cac882 100644 --- a/svc/Cargo.lock +++ b/svc/Cargo.lock @@ -2085,6 +2085,7 @@ dependencies = [ "async-trait", "chirp-client", "chirp-workflow-macros", + "cjson", "formatted-error", "futures-util", "global-error", @@ -2104,6 +2105,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tokio-util 0.7.10", "tracing", "tracing-subscriber", "uuid",