diff --git a/Cargo.lock b/Cargo.lock index 252a4b0eef..a941174f1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3500,6 +3500,7 @@ dependencies = [ "tokio", "tokio-tungstenite", "tracing", + "universaldb", "universalpubsub", "vbare", ] diff --git a/engine/packages/actor-kv/src/entry.rs b/engine/packages/actor-kv/src/entry.rs index 81a57d0d96..adc43433ba 100644 --- a/engine/packages/actor-kv/src/entry.rs +++ b/engine/packages/actor-kv/src/entry.rs @@ -3,7 +3,7 @@ use std::result::Result::Ok; use anyhow::*; use universaldb::prelude::*; -use rivet_runner_protocol as rp; +use rivet_runner_protocol::mk2 as rp; use crate::key::KeyWrapper; diff --git a/engine/packages/actor-kv/src/key.rs b/engine/packages/actor-kv/src/key.rs index 71847d8ab5..455fb7a463 100644 --- a/engine/packages/actor-kv/src/key.rs +++ b/engine/packages/actor-kv/src/key.rs @@ -1,4 +1,4 @@ -use rivet_runner_protocol as rp; +use rivet_runner_protocol::mk2 as rp; use universaldb::tuple::{ Bytes, PackResult, TupleDepth, TuplePack, TupleUnpack, VersionstampOffset, }; diff --git a/engine/packages/actor-kv/src/lib.rs b/engine/packages/actor-kv/src/lib.rs index 7133362400..46e9735bfa 100644 --- a/engine/packages/actor-kv/src/lib.rs +++ b/engine/packages/actor-kv/src/lib.rs @@ -5,7 +5,7 @@ use entry::{EntryBaseKey, EntryBuilder, EntryMetadataKey, EntryValueChunkKey}; use futures_util::{StreamExt, TryStreamExt}; use gas::prelude::*; use key::{KeyWrapper, ListKeyWrapper}; -use rivet_runner_protocol as rp; +use rivet_runner_protocol::mk2 as rp; use universaldb::prelude::*; use universaldb::tuple::Subspace; use utils::{validate_entries, validate_keys}; diff --git a/engine/packages/actor-kv/src/utils.rs b/engine/packages/actor-kv/src/utils.rs index 50bf471ce3..c412771478 100644 --- a/engine/packages/actor-kv/src/utils.rs +++ b/engine/packages/actor-kv/src/utils.rs @@ -1,7 +1,7 @@ use std::result::Result::Ok; use anyhow::*; -use rivet_runner_protocol as rp; +use rivet_runner_protocol::mk2 as rp; use crate::{ MAX_KEY_SIZE, MAX_KEYS, MAX_PUT_PAYLOAD_SIZE, MAX_STORAGE_SIZE, MAX_VALUE_SIZE, key::KeyWrapper, diff --git a/engine/packages/pegboard-gateway/Cargo.toml b/engine/packages/pegboard-gateway/Cargo.toml index 424293844e..af832e5967 100644 --- a/engine/packages/pegboard-gateway/Cargo.toml +++ b/engine/packages/pegboard-gateway/Cargo.toml @@ -31,5 +31,6 @@ thiserror.workspace = true tokio-tungstenite.workspace = true tokio.workspace = true tracing.workspace = true +universaldb.workspace = true universalpubsub.workspace = true vbare.workspace = true diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 9fb07566d0..2019bcb584 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -22,6 +22,7 @@ use tokio_tungstenite::tungstenite::{ Message, protocol::frame::{CloseFrame, coding::CloseCode}, }; +use universaldb::utils::IsolationLevel::*; use crate::shared_state::{InFlightRequestHandle, SharedState}; @@ -46,7 +47,7 @@ pub struct WebsocketPendingLimitReached; #[derive(Debug)] enum LifecycleResult { - ServerClose(protocol::ToServerWebSocketClose), + ServerClose(protocol::mk2::ToServerWebSocketClose), ClientClose(Option), Aborted, } @@ -153,10 +154,22 @@ impl CustomServeTrait for PegboardGateway { .context("failed to read body")? .to_bytes(); - let mut stopped_sub = self - .ctx - .subscribe::(("actor_id", self.actor_id)) - .await?; + let udb = self.ctx.udb()?; + let runner_id = self.runner_id; + let (mut stopped_sub, runner_protocol_version) = tokio::try_join!( + self.ctx + .subscribe::(("actor_id", self.actor_id)), + // Read runner protocol version + udb.run(|tx| async move { + tx.with_subspace(pegboard::keys::subspace()); + + tx.read( + &pegboard::keys::runner::ProtocolVersionKey::new(runner_id), + Serializable, + ) + .await + }) + )?; // Build subject to publish to let tunnel_subject = @@ -169,12 +182,12 @@ impl CustomServeTrait for PegboardGateway { .. } = self .shared_state - .start_in_flight_request(tunnel_subject, request_id) + .start_in_flight_request(tunnel_subject, runner_protocol_version, request_id) .await; // Start request - let message = protocol::ToClientTunnelMessageKind::ToClientRequestStart( - protocol::ToClientRequestStart { + let message = protocol::mk2::ToClientTunnelMessageKind::ToClientRequestStart( + protocol::mk2::ToClientRequestStart { actor_id: actor_id.clone(), method, path: self.path.clone(), @@ -197,12 +210,12 @@ impl CustomServeTrait for PegboardGateway { res = msg_rx.recv() => { if let Some(msg) = res { match msg { - protocol::ToServerTunnelMessageKind::ToServerResponseStart( + protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart( response_start, ) => { return anyhow::Ok(response_start); } - protocol::ToServerTunnelMessageKind::ToServerResponseAbort => { + protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort => { tracing::warn!("request aborted"); return Err(ServiceUnavailable.build()); } @@ -277,9 +290,6 @@ impl CustomServeTrait for PegboardGateway { request_id: protocol::RequestId, after_hibernation: bool, ) -> Result> { - // Use the actor ID from the gateway instance - let actor_id = self.actor_id.to_string(); - // Extract headers let mut request_headers = HashableMap::new(); for (name, value) in headers { @@ -288,10 +298,22 @@ impl CustomServeTrait for PegboardGateway { } } - let mut stopped_sub = self - .ctx - .subscribe::(("actor_id", self.actor_id)) - .await?; + let udb = self.ctx.udb()?; + let runner_id = self.runner_id; + let (mut stopped_sub, runner_protocol_version) = tokio::try_join!( + self.ctx + .subscribe::(("actor_id", self.actor_id)), + // Read runner protocol version + udb.run(|tx| async move { + tx.with_subspace(pegboard::keys::subspace()); + + tx.read( + &pegboard::keys::runner::ProtocolVersionKey::new(runner_id), + Serializable, + ) + .await + }) + )?; // Build subject to publish to let tunnel_subject = @@ -304,7 +326,7 @@ impl CustomServeTrait for PegboardGateway { new, } = self .shared_state - .start_in_flight_request(tunnel_subject.clone(), request_id) + .start_in_flight_request(tunnel_subject.clone(), runner_protocol_version, request_id) .await; ensure!( @@ -317,9 +339,9 @@ impl CustomServeTrait for PegboardGateway { true } else { // Send WebSocket open message - let open_message = protocol::ToClientTunnelMessageKind::ToClientWebSocketOpen( - protocol::ToClientWebSocketOpen { - actor_id: actor_id.clone(), + let open_message = protocol::mk2::ToClientTunnelMessageKind::ToClientWebSocketOpen( + protocol::mk2::ToClientWebSocketOpen { + actor_id: self.actor_id.to_string(), path: self.path.clone(), headers: request_headers, }, @@ -338,10 +360,10 @@ impl CustomServeTrait for PegboardGateway { res = msg_rx.recv() => { if let Some(msg) = res { match msg { - protocol::ToServerTunnelMessageKind::ToServerWebSocketOpen(msg) => { + protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen(msg) => { return anyhow::Ok(msg); } - protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { + protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { tracing::warn!(?close, "websocket closed before opening"); return Err(WebSocketServiceUnavailable.build()); } @@ -538,8 +560,8 @@ impl CustomServeTrait for PegboardGateway { Ok(_) => (CloseCode::Normal.into(), None), Err(_) => (CloseCode::Error.into(), Some("ws.downstream_closed".into())), }; - let close_message = protocol::ToClientTunnelMessageKind::ToClientWebSocketClose( - protocol::ToClientWebSocketClose { + let close_message = protocol::mk2::ToClientTunnelMessageKind::ToClientWebSocketClose( + protocol::mk2::ToClientWebSocketClose { code: Some(close_code.into()), reason: close_reason.map(|x| x.as_str().to_string()), }, diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index 568bbf5f6e..535b3f907d 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -1,7 +1,9 @@ use anyhow::Result; use gas::prelude::*; use rivet_guard_core::errors::WebSocketServiceTimeout; -use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_runner_protocol::{ + self as protocol, PROTOCOL_MK1_VERSION, PROTOCOL_MK2_VERSION, versioned, +}; use scc::{HashMap, hash_map::Entry}; use std::{ ops::Deref, @@ -20,7 +22,7 @@ const HWS_MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30); const HWS_MAX_PENDING_MSGS_SIZE_PER_REQ: u64 = util::size::mebibytes(1); pub struct InFlightRequestHandle { - pub msg_rx: mpsc::Receiver, + pub msg_rx: mpsc::Receiver, /// Used to check if the request handler has been dropped. /// /// This is separate from `msg_rx` there may still be messages that need to be sent to the @@ -32,14 +34,15 @@ pub struct InFlightRequestHandle { struct InFlightRequest { /// UPS subject to send messages to for this request. receiver_subject: String, + protocol_version: u16, /// Sender for incoming messages to this request. - msg_tx: mpsc::Sender, + msg_tx: mpsc::Sender, /// Used to check if the request handler has been dropped. drop_tx: watch::Sender<()>, /// True once first message for this request has been sent (so runner learned reply_to). opened: bool, /// Message index counter for this request. - message_index: protocol::MessageIndex, + message_index: protocol::mk2::MessageIndex, hibernation_state: Option, stopping: bool, last_pong: i64, @@ -55,14 +58,14 @@ struct HibernationState { pub struct PendingWebsocketMessage { payload: Vec, send_instant: Instant, - message_index: protocol::MessageIndex, + message_index: protocol::mk2::MessageIndex, } pub struct SharedStateInner { ups: PubSub, - gateway_id: protocol::GatewayId, + gateway_id: protocol::mk2::GatewayId, receiver_subject: String, - in_flight_requests: HashMap, + in_flight_requests: HashMap, hibernation_timeout: i64, } @@ -85,7 +88,7 @@ impl SharedState { })) } - pub fn gateway_id(&self) -> protocol::GatewayId { + pub fn gateway_id(&self) -> protocol::mk2::GatewayId { self.gateway_id } @@ -106,7 +109,8 @@ impl SharedState { pub async fn start_in_flight_request( &self, receiver_subject: String, - request_id: protocol::RequestId, + protocol_version: u16, + request_id: protocol::mk2::RequestId, ) -> InFlightRequestHandle { let (msg_tx, msg_rx) = mpsc::channel(128); let (drop_tx, drop_rx) = watch::channel(()); @@ -115,6 +119,7 @@ impl SharedState { Entry::Vacant(entry) => { entry.insert_entry(InFlightRequest { receiver_subject, + protocol_version, msg_tx, drop_tx, opened: false, @@ -151,8 +156,8 @@ impl SharedState { #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] pub async fn send_message( &self, - request_id: protocol::RequestId, - message_kind: protocol::ToClientTunnelMessageKind, + request_id: protocol::mk2::RequestId, + message_kind: protocol::mk2::ToClientTunnelMessageKind, ) -> Result<()> { let mut req = self .in_flight_requests @@ -161,7 +166,7 @@ impl SharedState { .context("request not in flight")?; // Generate message ID - let message_id = protocol::MessageId { + let message_id = protocol::mk2::MessageId { gateway_id: self.gateway_id, request_id, message_index: req.message_index, @@ -180,18 +185,25 @@ impl SharedState { // Check if this is a WebSocket message for hibernation tracking let is_ws_message = matches!( message_kind, - protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(_) + protocol::mk2::ToClientTunnelMessageKind::ToClientWebSocketMessage(_) ); - let payload = protocol::ToClientTunnelMessage { + let payload = protocol::mk2::ToClientTunnelMessage { message_id, message_kind, }; - // Send message - let message = protocol::ToRunner::ToClientTunnelMessage(payload); - let message_serialized = versioned::ToRunner::wrap_latest(message) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; + let message_serialized = if protocol::is_mk2(req.protocol_version) { + let message = protocol::mk2::ToRunner::ToClientTunnelMessage(payload); + versioned::ToRunnerMk2::wrap_latest(message) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION)? + } else { + let message = protocol::ToRunner::ToClientTunnelMessage( + versioned::to_client_tunnel_message_v4_to_v3(payload), + ); + versioned::ToRunner::wrap_latest(message) + .serialize_with_embedded_version(PROTOCOL_MK1_VERSION)? + }; if let (Some(hs), true) = (&mut req.hibernation_state, is_ws_message) { hs.total_pending_ws_msgs_size += message_serialized.len() as u64; @@ -228,7 +240,7 @@ impl SharedState { } #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] - pub async fn send_and_check_ping(&self, request_id: protocol::RequestId) -> Result<()> { + pub async fn send_and_check_ping(&self, request_id: protocol::mk2::RequestId) -> Result<()> { let req = self .in_flight_requests .get_async(&request_id) @@ -243,14 +255,23 @@ impl SharedState { return Err(WebSocketServiceTimeout.build()); } - // Send message - let message = protocol::ToRunner::ToRunnerPing(protocol::ToRunnerPing { - gateway_id: self.gateway_id, - request_id, - ts: now, - }); - let message_serialized = versioned::ToRunner::wrap_latest(message) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; + let message_serialized = if protocol::is_mk2(req.protocol_version) { + let message = protocol::mk2::ToRunner::ToRunnerPing(protocol::mk2::ToRunnerPing { + gateway_id: self.gateway_id, + request_id, + ts: now, + }); + versioned::ToRunnerMk2::wrap_latest(message) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION)? + } else { + let message = protocol::ToRunner::ToRunnerPing(protocol::ToRunnerPing { + gateway_id: self.gateway_id, + request_id, + ts: now, + }); + versioned::ToRunner::wrap_latest(message) + .serialize_with_embedded_version(PROTOCOL_MK1_VERSION)? + }; self.ups .publish( @@ -264,7 +285,7 @@ impl SharedState { } #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] - pub async fn keepalive_hws(&self, request_id: protocol::RequestId) -> Result<()> { + pub async fn keepalive_hws(&self, request_id: protocol::mk2::RequestId) -> Result<()> { let mut req = self .in_flight_requests .get_async(&request_id) @@ -289,7 +310,7 @@ impl SharedState { ); match versioned::ToGateway::deserialize_with_embedded_version(&msg.payload) { - Ok(protocol::ToGateway::ToGatewayPong(pong)) => { + Ok(protocol::mk2::ToGateway::ToGatewayPong(pong)) => { let Some(mut in_flight) = self.in_flight_requests.get_async(&pong.request_id).await else { @@ -306,7 +327,7 @@ impl SharedState { let rtt = now.saturating_sub(pong.ts); metrics::TUNNEL_PING_DURATION.record(rtt as f64 * 0.001, &[]); } - Ok(protocol::ToGateway::ToServerTunnelMessage(msg)) => { + Ok(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) => { let message_id = msg.message_id; let Some(in_flight) = self @@ -342,7 +363,7 @@ impl SharedState { #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id), %enable))] pub async fn toggle_hibernation( &self, - request_id: protocol::RequestId, + request_id: protocol::mk2::RequestId, enable: bool, ) -> Result<()> { let mut req = self @@ -370,7 +391,7 @@ impl SharedState { #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] pub async fn resend_pending_websocket_messages( &self, - request_id: protocol::RequestId, + request_id: protocol::mk2::RequestId, ) -> Result<()> { let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); @@ -396,7 +417,7 @@ impl SharedState { #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] pub async fn has_pending_websocket_messages( &self, - request_id: protocol::RequestId, + request_id: protocol::mk2::RequestId, ) -> Result { let Some(req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); @@ -412,7 +433,7 @@ impl SharedState { #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id), %ack_index))] pub async fn ack_pending_websocket_messages( &self, - request_id: protocol::RequestId, + request_id: protocol::mk2::RequestId, ack_index: u16, ) -> Result<()> { let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else { diff --git a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs index 4d03b47287..90d580023c 100644 --- a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs @@ -16,7 +16,7 @@ pub async fn task( client_ws: WebSocketHandle, request_id: protocol::RequestId, mut stopped_sub: message::SubscriptionHandle, - mut msg_rx: mpsc::Receiver, + mut msg_rx: mpsc::Receiver, mut drop_rx: watch::Receiver<()>, can_hibernate: bool, mut tunnel_to_ws_abort_rx: watch::Receiver<()>, @@ -26,7 +26,7 @@ pub async fn task( res = msg_rx.recv() => { if let Some(msg) = res { match msg { - protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => { + protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => { let msg = if ws_msg.binary { Message::Binary(ws_msg.data.into()) } else { @@ -36,7 +36,7 @@ pub async fn task( }; client_ws.send(msg).await?; } - protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { + protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { tracing::debug!( request_id=%protocol::util::id_to_string(&request_id), ack_index=?ack.index, @@ -46,7 +46,7 @@ pub async fn task( .ack_pending_websocket_messages(request_id, ack.index) .await?; } - protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { + protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { tracing::debug!(?close, "server closed websocket"); if can_hibernate && close.hibernate { diff --git a/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs index f28ce389e6..378802e322 100644 --- a/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs @@ -11,7 +11,7 @@ use crate::shared_state::SharedState; pub async fn task( shared_state: SharedState, - request_id: protocol::RequestId, + request_id: protocol::mk2::RequestId, ws_rx: Arc>, mut ws_to_tunnel_abort_rx: watch::Receiver<()>, ) -> Result { @@ -24,8 +24,8 @@ pub async fn task( match msg { Message::Binary(data) => { let ws_message = - protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( - protocol::ToClientWebSocketMessage { + protocol::mk2::ToClientTunnelMessageKind::ToClientWebSocketMessage( + protocol::mk2::ToClientWebSocketMessage { data: data.into(), binary: true, }, @@ -36,8 +36,8 @@ pub async fn task( } Message::Text(text) => { let ws_message = - protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( - protocol::ToClientWebSocketMessage { + protocol::mk2::ToClientTunnelMessageKind::ToClientWebSocketMessage( + protocol::mk2::ToClientWebSocketMessage { data: text.as_bytes().to_vec(), binary: false, }, diff --git a/engine/packages/pegboard-runner/src/actor_event_demuxer.rs b/engine/packages/pegboard-runner/src/actor_event_demuxer.rs index 22216d832e..8244d3fcb7 100644 --- a/engine/packages/pegboard-runner/src/actor_event_demuxer.rs +++ b/engine/packages/pegboard-runner/src/actor_event_demuxer.rs @@ -11,7 +11,7 @@ const GC_INTERVAL: Duration = Duration::from_secs(30); const MAX_LAST_SEEN: Duration = Duration::from_secs(30); struct Channel { - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, handle: JoinHandle<()>, last_seen: Instant, } @@ -33,7 +33,7 @@ impl ActorEventDemuxer { /// Process an event by routing it to the appropriate actor's queue #[tracing::instrument(skip_all)] - pub fn ingest(&mut self, actor_id: Id, event: protocol::Event) { + pub fn ingest(&mut self, actor_id: Id, event: protocol::mk2::Event) { if let Some(channel) = self.channels.get(&actor_id) { let _ = channel.tx.send(event); } else { @@ -107,7 +107,7 @@ impl ActorEventDemuxer { async fn dispatch_events( ctx: &StandaloneCtx, actor_id: Id, - events: Vec, + events: Vec, ) -> Result<()> { let res = ctx .signal(pegboard::workflows::actor::Events { inner: events }) diff --git a/engine/packages/pegboard-runner/src/conn.rs b/engine/packages/pegboard-runner/src/conn.rs index 4bdd2aa38c..573ff52d44 100644 --- a/engine/packages/pegboard-runner/src/conn.rs +++ b/engine/packages/pegboard-runner/src/conn.rs @@ -127,7 +127,7 @@ pub async fn init_conn( }; // Spawn a new runner workflow if one doesn't already exist - let workflow_id = if protocol::is_new(protocol_version) { + let workflow_id = if protocol::is_mk2(protocol_version) { ctx.workflow(pegboard::workflows::runner2::Input { runner_id, namespace_id: namespace.namespace_id, @@ -174,8 +174,17 @@ pub async fn init_conn( return Err(WsError::InvalidInitialPacket("must be `ToServer::Init`").build()); }; - if protocol::is_new(protocol_version) { - ctx.signal(Init); + if protocol::is_mk2(protocol_version) { + ctx.signal(pegboard::workflows::runner2::Init {}) + .to_workflow_id(workflow_id) + .send() + .await + .with_context(|| { + format!( + "failed to forward initial packet to workflow: {}", + workflow_id + ) + })?; } else { // Forward to runner wf ctx.signal(pegboard::workflows::runner::Forward { inner: init_packet }) diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 44b3d9b9e3..b2eeb5bd26 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -137,14 +137,12 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(()); let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(()); let (ping_abort_tx, ping_abort_rx) = watch::channel(()); - let (init_tx, init_rx) = watch::channel(()); let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task( self.ctx.clone(), conn.clone(), sub, eviction_sub, - init_rx, tunnel_to_ws_abort_rx, )); @@ -153,7 +151,6 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { conn.clone(), ws_handle.recv(), eviction_sub2, - init_tx, ws_to_tunnel_abort_rx, )); diff --git a/engine/packages/pegboard-runner/src/ping_task.rs b/engine/packages/pegboard-runner/src/ping_task.rs index e29c1ce8d6..7f8f25b9b0 100644 --- a/engine/packages/pegboard-runner/src/ping_task.rs +++ b/engine/packages/pegboard-runner/src/ping_task.rs @@ -1,7 +1,10 @@ use gas::prelude::*; +use hyper_tungstenite::tungstenite::Message; use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use rivet_runner_protocol::{self as protocol, versioned}; use std::sync::{Arc, atomic::Ordering}; use tokio::sync::watch; +use vbare::OwnedVersionedData; use crate::{LifecycleResult, UPDATE_PING_INTERVAL, conn::Conn}; @@ -23,15 +26,17 @@ pub async fn task( update_runner_ping(&ctx, &conn).await?; // Send ping to runner - let ping_msg = versioned::ToClient::wrap_latest(protocol::ToClient::ToClientPing( - protocol::ToClientPing { - ts: util::timestamp::now(), - }, - )); - let ping_msg_serialized = ping_msg.serialize(conn.protocol_version)?; - conn.ws_handle - .send(Message::Binary(ping_msg_serialized.into())) - .await?; + if protocol::is_mk2(conn.protocol_version) { + let ping_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientPing(protocol::mk2::ToClientPing { + ts: util::timestamp::now(), + }), + ); + let ping_msg_serialized = ping_msg.serialize(conn.protocol_version)?; + conn.ws_handle + .send(Message::Binary(ping_msg_serialized.into())) + .await?; + } } } diff --git a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs index 3854f46de4..5920570ad4 100644 --- a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs @@ -1,10 +1,13 @@ use anyhow::Result; use gas::prelude::*; -use hyper_tungstenite::tungstenite::Message as WsMessage; +use hyper_tungstenite::tungstenite::Message; use pegboard::pubsub_subjects::GatewayReceiverSubject; -use rivet_runner_protocol::{self as protocol, versioned}; +use rivet_runner_protocol::{ + self as protocol, PROTOCOL_MK1_VERSION, PROTOCOL_MK2_VERSION, versioned, +}; use std::sync::Arc; use tokio::sync::watch; +use universalpubsub as ups; use universalpubsub::{NextOutput, PublishOpts, Subscriber}; use vbare::OwnedVersionedData; @@ -16,40 +19,33 @@ pub async fn task( conn: Arc, mut tunnel_sub: Subscriber, mut eviction_sub: Subscriber, - mut init_rx: watch::Receiver<()>, mut tunnel_to_ws_abort_rx: watch::Receiver<()>, ) -> Result { - if protocol::is_mk2(conn.protocol) { - // Must first receive init from adjacent task before processing messages - tokio::select! { - _ = init_rx.changed() => {} - _ = eviction_sub.next() => { - tracing::debug!("runner evicted"); - return Err(errors::WsError::Eviction.build()); - } - _ = tunnel_to_ws_abort_rx.changed() => { - tracing::debug!("task aborted"); - return Ok(LifecycleResult::Aborted); - } - }; - - loop { - match recv_msg().await? { - Ok(msg) => handle_message_mk2(msg).await?, - Err(lifecycle_res) => return Ok(lifecycle_res), - } - } - } else { - loop { - match recv_msg().await? { - Ok(msg) => handle_message_mk1(msg).await?, - Err(lifecycle_res) => return Ok(lifecycle_res), + loop { + match recv_msg( + &mut tunnel_sub, + &mut eviction_sub, + &mut tunnel_to_ws_abort_rx, + ) + .await? + { + Ok(msg) => { + if protocol::is_mk2(conn.protocol_version) { + handle_message_mk2(&ctx, &conn, msg).await?; + } else { + handle_message_mk1(&ctx, &conn, msg).await?; + } } + Err(lifecycle_res) => return Ok(lifecycle_res), } } } -async fn recv_msg() -> Result> { +async fn recv_msg( + tunnel_sub: &mut Subscriber, + eviction_sub: &mut Subscriber, + tunnel_to_ws_abort_rx: &mut watch::Receiver<()>, +) -> Result> { let tunnel_msg = tokio::select! { res = tunnel_sub.next() => { if let NextOutput::Message(tunnel_msg) = res.context("pubsub_to_client_task sub failed")? { @@ -77,9 +73,13 @@ async fn recv_msg() -> Result> { Ok(Ok(tunnel_msg)) } -async fn handle_message_mk2() -> Result<()> { +async fn handle_message_mk2( + ctx: &StandaloneCtx, + conn: &Conn, + tunnel_msg: ups::Message, +) -> Result<()> { // Parse message - let msg = match versioned::ToRunner2::deserialize_with_embedded_version(&tunnel_msg.payload) { + let msg = match versioned::ToRunnerMk2::deserialize_with_embedded_version(&tunnel_msg.payload) { Result::Ok(x) => x, Err(err) => { tracing::error!(?err, "failed to parse tunnel message"); @@ -93,12 +93,12 @@ async fn handle_message_mk2() -> Result<()> { // Publish pong to UPS let gateway_reply_to = GatewayReceiverSubject::new(ping.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest( - protocol::ToGateway::ToGatewayPong(protocol::ToGatewayPong { + protocol::mk2::ToGateway::ToGatewayPong(protocol::mk2::ToGatewayPong { request_id: ping.request_id, ts: ping.ts, }), ) - .serialize_with_embedded_version(protocol::PROTOCOL_VERSION) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION) .context("failed to serialize pong message for gateway")?; ctx.ups() .context("failed to get UPS instance for tunnel message")? @@ -119,7 +119,6 @@ async fn handle_message_mk2() -> Result<()> { for command_wrapper in &mut command_wrappers { if let protocol::mk2::Command::CommandStartActor( protocol::mk2::CommandStartActor { - actor_id, hibernating_requests, .. }, @@ -127,7 +126,7 @@ async fn handle_message_mk2() -> Result<()> { { let ids = ctx .op(pegboard::ops::actor::hibernating_request::list::Input { - actor_id: Id::parse(actor_id)?, + actor_id: Id::parse(&command_wrapper.checkpoint.actor_id)?, }) .await?; @@ -146,6 +145,9 @@ async fn handle_message_mk2() -> Result<()> { // ToRunner data protocol::mk2::ToClient::ToClientCommands(command_wrappers) } + protocol::mk2::ToRunner::ToClientAckEvents(x) => { + protocol::mk2::ToClient::ToClientAckEvents(x) + } protocol::mk2::ToRunner::ToClientTunnelMessage(x) => { protocol::mk2::ToClient::ToClientTunnelMessage(x) } @@ -154,7 +156,7 @@ async fn handle_message_mk2() -> Result<()> { // Forward raw message to WebSocket let serialized_msg = versioned::ToClientMk2::wrap_latest(to_client_msg).serialize(conn.protocol_version)?; - let ws_msg = WsMessage::Binary(serialized_msg.into()); + let ws_msg = Message::Binary(serialized_msg.into()); conn.ws_handle .send(ws_msg) .await @@ -163,7 +165,11 @@ async fn handle_message_mk2() -> Result<()> { Ok(()) } -async fn handle_message_mk1() -> Result<()> { +async fn handle_message_mk1( + ctx: &StandaloneCtx, + conn: &Conn, + tunnel_msg: ups::Message, +) -> Result<()> { // Parse message let msg = match versioned::ToRunner::deserialize_with_embedded_version(&tunnel_msg.payload) { Result::Ok(x) => x, @@ -179,13 +185,14 @@ async fn handle_message_mk1() -> Result<()> { // Publish pong to UPS let gateway_reply_to = GatewayReceiverSubject::new(ping.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest( - protocol::ToGateway::ToGatewayPong(protocol::ToGatewayPong { + protocol::mk2::ToGateway::ToGatewayPong(protocol::mk2::ToGatewayPong { request_id: ping.request_id, ts: ping.ts, }), ) - .serialize_with_embedded_version(protocol::PROTOCOL_VERSION) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION) .context("failed to serialize pong message for gateway")?; + ctx.ups() .context("failed to get UPS instance for tunnel message")? .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) @@ -198,7 +205,7 @@ async fn handle_message_mk1() -> Result<()> { })?; // Not sent to client - continue; + return Ok(()); } protocol::ToRunner::ToClientInit(x) => protocol::ToClient::ToClientInit(x), protocol::ToRunner::ToClientClose => return Err(errors::WsError::Eviction.build()), @@ -242,7 +249,7 @@ async fn handle_message_mk1() -> Result<()> { tracing::debug!(?to_client_msg, "sending runner message to client"); let serialized_msg = versioned::ToClient::wrap_latest(to_client_msg).serialize(conn.protocol_version)?; - let ws_msg = WsMessage::Binary(serialized_msg.into()); + let ws_msg = Message::Binary(serialized_msg.into()); conn.ws_handle .send(ws_msg) .await diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index d004b878a7..7bb98b9466 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -1,4 +1,5 @@ use anyhow::Context; +use bytes::Bytes; use futures_util::TryStreamExt; use gas::prelude::Id; use gas::prelude::*; @@ -8,9 +9,11 @@ use pegboard::pubsub_subjects::GatewayReceiverSubject; use pegboard::utils::event_actor_id; use pegboard_actor_kv as kv; use rivet_guard_core::websocket_handle::WebSocketReceiver; -use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_runner_protocol::{ + self as protocol, PROTOCOL_MK1_VERSION, PROTOCOL_MK2_VERSION, versioned, +}; use std::sync::{Arc, atomic::Ordering}; -use tokio::sync::{Mutex, watch}; +use tokio::sync::{Mutex, MutexGuard, watch}; use universalpubsub::PublishOpts; use universalpubsub::Subscriber; use vbare::OwnedVersionedData; @@ -23,7 +26,6 @@ pub async fn task( conn: Arc, ws_rx: Arc>, eviction_sub2: Subscriber, - init_tx: watch::Sender<()>, ws_to_tunnel_abort_rx: watch::Receiver<()>, ) -> Result { let mut event_demuxer = ActorEventDemuxer::new(ctx.clone()); @@ -33,7 +35,6 @@ pub async fn task( conn, ws_rx, eviction_sub2, - init_tx, ws_to_tunnel_abort_rx, &mut event_demuxer, ) @@ -51,32 +52,31 @@ pub async fn task_inner( conn: Arc, ws_rx: Arc>, mut eviction_sub2: Subscriber, - init_tx: watch::Sender<()>, mut ws_to_tunnel_abort_rx: watch::Receiver<()>, event_demuxer: &mut ActorEventDemuxer, ) -> Result { let mut ws_rx = ws_rx.lock().await; - if protocol::is_mk2(conn.protocol_version) { - loop { - match recv_msg().await? { - Ok(Some(msg)) => handle_message_mk2(msg).await?, - Ok(None) => {} - Err(lifecycle_res) => return Ok(lifecycle_res), - } - } - } else { - loop { - match recv_msg().await? { - Ok(Some(msg)) => handle_message_mk1(msg).await?, - Ok(None) => {} - Err(lifecycle_res) => return Ok(lifecycle_res), + loop { + match recv_msg(&mut ws_rx, &mut eviction_sub2, &mut ws_to_tunnel_abort_rx).await? { + Ok(Some(msg)) => { + if protocol::is_mk2(conn.protocol_version) { + handle_message_mk2(&ctx, &conn, event_demuxer, msg).await?; + } else { + handle_message_mk1(&ctx, &conn, msg).await?; + } } + Ok(None) => {} + Err(lifecycle_res) => return Ok(lifecycle_res), } } } -async fn recv_msg() -> Result, LifecycleResult>> { +async fn recv_msg( + ws_rx: &mut MutexGuard<'_, WebSocketReceiver>, + eviction_sub2: &mut Subscriber, + ws_to_tunnel_abort_rx: &mut watch::Receiver<()>, +) -> Result, LifecycleResult>> { let msg = tokio::select! { res = ws_rx.try_next() => { if let Some(msg) = res? { @@ -103,7 +103,7 @@ async fn recv_msg() -> Result, LifecycleResult>> "received binary message from WebSocket" ); - Ok(Some(data)) + Ok(Ok(Some(data))) } WsMessage::Close(_) => { tracing::debug!("websocket closed"); @@ -111,7 +111,7 @@ async fn recv_msg() -> Result, LifecycleResult>> } _ => { // Ignore other message types - Ok(None) + Ok(Ok(None)) } } } @@ -119,22 +119,21 @@ async fn recv_msg() -> Result, LifecycleResult>> #[tracing::instrument(skip_all)] async fn handle_message_mk2( ctx: &StandaloneCtx, - conn: &Arc, - init_tx: &watch::Sender<()>, + conn: &Conn, event_demuxer: &mut ActorEventDemuxer, - msg: (), + msg: Bytes, ) -> Result<()> { // Parse message - let msg = match versioned::ToServerMk2::deserialize(&data, conn.protocol_version) { + let msg = match versioned::ToServerMk2::deserialize(&msg, conn.protocol_version) { Ok(x) => x, Err(err) => { - tracing::warn!(?err, data_len = data.len(), "failed to deserialize message"); + tracing::warn!(?err, msg_len = msg.len(), "failed to deserialize message"); return Ok(()); } }; match msg { - protocol::ToServer::ToServerPong(pong) => { + protocol::mk2::ToServer::ToServerPong(pong) => { let now = util::timestamp::now(); let rtt = now.saturating_sub(pong.ts); @@ -148,19 +147,21 @@ async fn handle_message_mk2( conn.last_rtt.store(rtt, Ordering::Relaxed); } // Process KV request - protocol::ToServer::ToServerKvRequest(req) => { + protocol::mk2::ToServer::ToServerKvRequest(req) => { let actor_id = match Id::parse(&req.actor_id) { Ok(actor_id) => actor_id, Err(err) => { - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - message: err.to_string(), - }, - ), - }), + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + message: err.to_string(), + }, + ), + }, + ), ); let res_msg_serialized = res_msg @@ -189,15 +190,17 @@ async fn handle_message_mk2( // Verify actor belongs to this runner if !actor_belongs { - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - message: "given actor does not belong to runner".to_string(), - }, - ), - }), + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + message: "given actor does not belong to runner".to_string(), + }, + ), + }, + ), ); let res_msg_serialized = res_msg @@ -214,30 +217,32 @@ async fn handle_message_mk2( // TODO: Add queue and bg thread for processing kv ops // Run kv operation match req.data { - protocol::KvRequestData::KvGetRequest(body) => { + protocol::mk2::KvRequestData::KvGetRequest(body) => { let res = kv::get(&*ctx.udb()?, actor_id, body.keys).await; - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok((keys, values, metadata)) => { - protocol::KvResponseData::KvGetResponse( - protocol::KvGetResponse { - keys, - values, - metadata, + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + protocol::mk2::KvResponseData::KvGetResponse( + protocol::mk2::KvGetResponse { + keys, + values, + metadata, + }, + ) + } + Err(err) => protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), }, - ) - } - Err(err) => protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }, - ), + ), + }, }, - }), + ), ); let res_msg_serialized = res_msg @@ -248,7 +253,7 @@ async fn handle_message_mk2( .await .context("failed to send KV get response to client")?; } - protocol::KvRequestData::KvListRequest(body) => { + protocol::mk2::KvRequestData::KvListRequest(body) => { let res = kv::list( &*ctx.udb()?, actor_id, @@ -261,27 +266,29 @@ async fn handle_message_mk2( ) .await; - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok((keys, values, metadata)) => { - protocol::KvResponseData::KvListResponse( - protocol::KvListResponse { - keys, - values, - metadata, + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + protocol::mk2::KvResponseData::KvListResponse( + protocol::mk2::KvListResponse { + keys, + values, + metadata, + }, + ) + } + Err(err) => protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), }, - ) - } - Err(err) => protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }, - ), + ), + }, }, - }), + ), ); let res_msg_serialized = res_msg @@ -292,24 +299,26 @@ async fn handle_message_mk2( .await .context("failed to send KV list response to client")?; } - protocol::KvRequestData::KvPutRequest(body) => { + protocol::mk2::KvRequestData::KvPutRequest(body) => { let res = kv::put(&*ctx.udb()?, actor_id, body.keys, body.values).await; - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => protocol::KvResponseData::KvPutResponse, - Err(err) => { - protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }, - ) - } + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::mk2::KvResponseData::KvPutResponse, + Err(err) => { + protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ) + } + }, }, - }), + ), ); let res_msg_serialized = res_msg @@ -320,22 +329,24 @@ async fn handle_message_mk2( .await .context("failed to send KV put response to client")?; } - protocol::KvRequestData::KvDeleteRequest(body) => { + protocol::mk2::KvRequestData::KvDeleteRequest(body) => { let res = kv::delete(&*ctx.udb()?, actor_id, body.keys).await; - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => protocol::KvResponseData::KvDeleteResponse, - Err(err) => protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }, - ), + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::mk2::KvResponseData::KvDeleteResponse, + Err(err) => protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, }, - }), + ), ); let res_msg_serialized = res_msg @@ -346,22 +357,24 @@ async fn handle_message_mk2( .await .context("failed to send KV delete response to client")?; } - protocol::KvRequestData::KvDropRequest => { + protocol::mk2::KvRequestData::KvDropRequest => { let res = kv::delete_all(&*ctx.udb()?, actor_id).await; - let res_msg = versioned::ToClient::wrap_latest( - protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => protocol::KvResponseData::KvDropResponse, - Err(err) => protocol::KvResponseData::KvErrorResponse( - protocol::KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }, - ), + let res_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientKvResponse( + protocol::mk2::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::mk2::KvResponseData::KvDropResponse, + Err(err) => protocol::mk2::KvResponseData::KvErrorResponse( + protocol::mk2::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, }, - }), + ), ); let res_msg_serialized = res_msg @@ -374,12 +387,12 @@ async fn handle_message_mk2( } } } - protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message(&ctx, &conn, tunnel_msg) + protocol::mk2::ToServer::ToServerTunnelMessage(tunnel_msg) => { + handle_tunnel_message_mk2(&ctx, &conn, tunnel_msg) .await .context("failed to handle tunnel message")?; } - protocol::ToServer::ToServerInit(_) => { + protocol::mk2::ToServer::ToServerInit(init) => { // We send the signal first because we don't want to continue if this fails ctx.signal(pegboard::workflows::runner2::Init {}) .to_workflow_id(conn.workflow_id) @@ -394,60 +407,50 @@ async fn handle_message_mk2( let init_data = ctx .activity(ProcessInitInput { - runner_id: input.runner_id, - namespace_id: input.namespace_id, - last_command_idx: last_command_idx.unwrap_or(-1), - prepopulate_actor_names, - metadata, + runner_id: conn.runner_id, + namespace_id: conn.namespace_id, + last_command_idx: init.last_command_idx.unwrap_or(-1), + prepopulate_actor_names: init.prepopulate_actor_names, + metadata: init.metadata, }) .await?; // Send init packet - let msg = versioned::ToClient::wrap_latest(protocol::ToRunner::ToClientInit( - protocol::ToClientInit { - runner_id: input.runner_id.to_string(), + let init_msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientInit(protocol::mk2::ToClientInit { + runner_id: conn.runner_id.to_string(), last_event_idx: init_data.last_event_idx, - metadata: protocol::ProtocolMetadata { - runner_lost_threshold: runner_lost_threshold, + metadata: protocol::mk2::ProtocolMetadata { + runner_lost_threshold: ctx.config().pegboard().runner_lost_threshold(), }, - }, - )); - let msg_serialized = res_msg - .serialize(conn.protocol_version) - .context("failed to serialize KV delete response")?; + }), + ); + let init_msg_serialized = init_msg.serialize(conn.protocol_version)?; conn.ws_handle - .send(Message::Binary(res_msg_serialized.into())) - .await - .context("failed to send KV delete response to client")?; + .send(Message::Binary(init_msg_serialized.into())) + .await?; // Send missed commands if !init_data.missed_commands.is_empty() { - let msg = versioned::ToClient::wrap_latest(protocol::ToRunner::ToClientCommands( - init_data.missed_commands, - )); - let msg_serialized = res_msg - .serialize(conn.protocol_version) - .context("failed to serialize KV delete response")?; + let msg = versioned::ToClientMk2::wrap_latest( + protocol::mk2::ToClient::ToClientCommands(init_data.missed_commands), + ); + let msg_serialized = msg.serialize(conn.protocol_version)?; conn.ws_handle - .send(Message::Binary(res_msg_serialized.into())) - .await - .context("failed to send KV delete response to client")?; + .send(Message::Binary(msg_serialized.into())) + .await?; } - - // Inform adjacent task that we have processed and sent the init packet. This will allow it to - // start accepting commands - let _ = init_tx.send(()); } // Forward to actor wf - protocol::ToServer::ToServerEvents(events) => { + protocol::mk2::ToServer::ToServerEvents(events) => { for event in events { event_demuxer.ingest(Id::parse(event_actor_id(&event.inner))?, event.inner); } } - protocol::ToServer::ToServerAckCommands(_) => { + protocol::mk2::ToServer::ToServerAckCommands(_) => { ack_commands(&ctx).await?; } - protocol::ToServer::ToServerStopping => { + protocol::mk2::ToServer::ToServerStopping => { ctx.signal(pegboard::workflows::runner2::Stop { reset_actor_rescheduling: false, }) @@ -464,10 +467,10 @@ async fn handle_message_mk2( } #[tracing::instrument(skip_all)] -async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> Result<()> { +async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Conn, msg: Bytes) -> Result<()> { // HACK: Decode v2 to handle tunnel ack if rivet_runner_protocol::compat::version_needs_tunnel_ack(conn.protocol_version) { - match compat_ack_tunnel_message(&conn, &data[..]).await { + match compat_ack_tunnel_message(&conn, &msg[..]).await { Ok(_) => {} Err(err) => { tracing::error!(?err, "failed to send compat ack tunnel message") @@ -476,10 +479,10 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R } // Parse message - let msg = match versioned::ToServer::deserialize(&data, conn.protocol_version) { + let msg = match versioned::ToServer::deserialize(&msg, conn.protocol_version) { Ok(x) => x, Err(err) => { - tracing::warn!(?err, data_len = data.len(), "failed to deserialize message"); + tracing::warn!(?err, msg_len = msg.len(), "failed to deserialize message"); return Ok(()); } }; @@ -586,7 +589,13 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R protocol::KvGetResponse { keys, values, - metadata, + metadata: metadata + .into_iter() + .map(|x| protocol::KvMetadata { + version: x.version, + create_ts: x.create_ts, + }) + .collect(), }, ) } @@ -612,7 +621,25 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R let res = kv::list( &*ctx.udb()?, actor_id, - body.query, + match body.query { + protocol::KvListQuery::KvListAllQuery => { + protocol::mk2::KvListQuery::KvListAllQuery + } + protocol::KvListQuery::KvListRangeQuery(q) => { + protocol::mk2::KvListQuery::KvListRangeQuery( + protocol::mk2::KvListRangeQuery { + start: q.start, + end: q.end, + exclusive: q.exclusive, + }, + ) + } + protocol::KvListQuery::KvListPrefixQuery(q) => { + protocol::mk2::KvListQuery::KvListPrefixQuery( + protocol::mk2::KvListPrefixQuery { key: q.key }, + ) + } + }, body.reverse.unwrap_or_default(), body.limit .map(TryInto::try_into) @@ -630,7 +657,13 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R protocol::KvListResponse { keys, values, - metadata, + metadata: metadata + .into_iter() + .map(|x| protocol::KvMetadata { + version: x.version, + create_ts: x.create_ts, + }) + .collect(), }, ) } @@ -735,7 +768,7 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R } } protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message(&ctx, tunnel_msg) + handle_tunnel_message_mk1(&ctx, tunnel_msg) .await .context("failed to handle tunnel message")?; } @@ -756,6 +789,8 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Arc, msg: ()) -> R })?; } } + + Ok(()) } async fn ack_commands(ctx: &StandaloneCtx) -> Result<()> { @@ -765,52 +800,21 @@ async fn ack_commands(ctx: &StandaloneCtx) -> Result<()> { // limit: // }); // }).await?; + + todo!(); } #[tracing::instrument(skip_all)] async fn handle_tunnel_message_mk2( ctx: &StandaloneCtx, - conn: &Arc, + conn: &Conn, msg: protocol::mk2::ToServerTunnelMessage, ) -> Result<()> { - // Ignore DeprecatedTunnelAck messages (used only for backwards compatibility) - if matches!( - msg.message_kind, - protocol::mk2::ToServerTunnelMessageKind::DeprecatedTunnelAck - ) { - return Ok(()); - } - - // Send DeprecatedTunnelAck back to runner for older protocol versions - if protocol::mk2::compat::version_needs_tunnel_ack(conn.protocol_version) { - let ack_msg = versioned::ToClientMk2::wrap_latest( - protocol::mk2::ToClient::ToClientTunnelMessage(protocol::mk2::ToClientTunnelMessage { - message_id: msg.message_id, - message_kind: protocol::mk2::ToClientTunnelMessageKind::DeprecatedTunnelAck, - }), - ); - - let ack_serialized = ack_msg - .serialize(conn.protocol_version) - .context("failed to serialize DeprecatedTunnelAck response")?; - - conn.ws_handle - .send(hyper_tungstenite::tungstenite::Message::Binary( - ack_serialized.into(), - )) - .await - .context("failed to send DeprecatedTunnelAck to runner")?; - } - - // Parse message ID to extract gateway_id - let parts = - tunnel_id::parse_message_id(msg.message_id).context("failed to parse message id")?; - // Publish message to UPS - let gateway_reply_to = GatewayReceiverSubject::new(parts.gateway_id).to_string(); + let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) - .serialize_with_embedded_version(PROTOCOL_VERSION) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION) .context("failed to serialize tunnel message for gateway")?; ctx.ups() .context("failed to get UPS instance for tunnel message")? @@ -841,10 +845,11 @@ async fn handle_tunnel_message_mk1( // Publish message to UPS let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); - let msg_serialized = - versioned::ToGateway::wrap_latest(protocol::ToGateway::ToServerTunnelMessage(msg)) - .serialize_with_embedded_version(PROTOCOL_VERSION) - .context("failed to serialize tunnel message for gateway")?; + let msg_serialized = versioned::ToGateway::v3_to_v4(versioned::ToGateway::V3( + protocol::ToGateway::ToServerTunnelMessage(msg), + ))? + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION) + .context("failed to serialize tunnel message for gateway")?; ctx.ups() .context("failed to get UPS instance for tunnel message")? .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) @@ -863,7 +868,7 @@ async fn handle_tunnel_message_mk1( /// /// We have to parse as specifically a v2 message since we need the exact request & message ID /// provided by the user and not available in v3. -async fn compat_ack_tunnel_message(conn: &Arc, payload: &[u8]) -> Result<()> { +async fn compat_ack_tunnel_message(conn: &Conn, payload: &[u8]) -> Result<()> { use rivet_runner_protocol::generated::v2 as protocol_v2; // Parse payload diff --git a/engine/packages/pegboard-serverless/src/lib.rs b/engine/packages/pegboard-serverless/src/lib.rs index 9faa65950f..720d43461d 100644 --- a/engine/packages/pegboard-serverless/src/lib.rs +++ b/engine/packages/pegboard-serverless/src/lib.rs @@ -360,7 +360,9 @@ async fn outbound_handler( let mut source = sse::EventSource::new(req).context("failed creating event source")?; let mut runner_id = None; + let mut runner_protocol_version = None; + let runner_protocol_version2 = &mut runner_protocol_version; let stream_handler = async { while let Some(event) = source.next().await { match event { @@ -375,9 +377,10 @@ async fn outbound_handler( .context("invalid payload")?; match payload { - protocol::ToServerlessServer::ToServerlessServerInit(init) => { + protocol::mk2::ToServerlessServer::ToServerlessServerInit(init) => { runner_id = Some(Id::parse(&init.runner_id).context("invalid runner id")?); + *runner_protocol_version2 = Some(init.runner_protocol_version); } } } @@ -422,6 +425,7 @@ async fn outbound_handler( } // Continue waiting on req while draining + let runner_protocol_version2 = &mut runner_protocol_version; let wait_for_shutdown_fut = async move { while let Some(event) = source.next().await { match event { @@ -440,10 +444,11 @@ async fn outbound_handler( .context("invalid payload")?; match payload { - protocol::ToServerlessServer::ToServerlessServerInit(init) => { + protocol::mk2::ToServerlessServer::ToServerlessServerInit(init) => { let runner_id_local = Id::parse(&init.runner_id).context("invalid runner id")?; runner_id = Some(runner_id_local); + *runner_protocol_version2 = Some(init.runner_protocol_version); drain_runner(ctx, runner_id_local).await?; } } @@ -469,8 +474,8 @@ async fn outbound_handler( // // This will force the runner to stop the request in order to avoid hitting the serverless // timeout threshold - if let Some(runner_id) = runner_id { - publish_to_client_stop(ctx, runner_id).await?; + if let (Some(runner_id), Some(runner_protocol_version)) = (runner_id, runner_protocol_version) { + publish_to_client_stop(ctx, runner_id, runner_protocol_version).await?; } tracing::debug!(?runner_id, "outbound req stopped"); @@ -515,13 +520,21 @@ async fn drain_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> { /// Send a stop message to the client. /// /// This will close the runner's WebSocket. -async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> { +async fn publish_to_client_stop( + ctx: &StandaloneCtx, + runner_id: Id, + runner_protocol_version: u16, +) -> Result<()> { let receiver_subject = pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); - let message_serialized = - protocol::versioned::ToRunner::wrap_latest(protocol::ToRunner::ToRunnerClose) - .serialize_with_embedded_version(protocol::PROTOCOL_VERSION)?; + let message_serialized = if protocol::is_mk2(runner_protocol_version) { + protocol::versioned::ToRunnerMk2::wrap_latest(protocol::mk2::ToRunner::ToRunnerClose) + .serialize_with_embedded_version(protocol::PROTOCOL_MK2_VERSION)? + } else { + protocol::versioned::ToRunner::wrap_latest(protocol::ToRunner::ToClientClose) + .serialize_with_embedded_version(protocol::PROTOCOL_MK1_VERSION)? + }; ctx.ups()? .publish(&receiver_subject, &message_serialized, PublishOpts::one()) diff --git a/engine/packages/pegboard/src/keys/runner.rs b/engine/packages/pegboard/src/keys/runner.rs index 9c038b4a73..e51555e4c4 100644 --- a/engine/packages/pegboard/src/keys/runner.rs +++ b/engine/packages/pegboard/src/keys/runner.rs @@ -183,6 +183,50 @@ impl<'de> TupleUnpack<'de> for TotalSlotsKey { } } +#[derive(Debug)] +pub struct ProtocolVersionKey { + runner_id: Id, +} + +impl ProtocolVersionKey { + pub fn new(runner_id: Id) -> Self { + ProtocolVersionKey { runner_id } + } +} + +impl FormalKey for ProtocolVersionKey { + type Value = u16; + + fn deserialize(&self, raw: &[u8]) -> Result { + Ok(u16::from_be_bytes(raw.try_into()?)) + } + + fn serialize(&self, value: Self::Value) -> Result> { + Ok(value.to_be_bytes().to_vec()) + } +} + +impl TuplePack for ProtocolVersionKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = (RUNNER, DATA, self.runner_id, PROTOCOL_VERSION); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for ProtocolVersionKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, _, runner_id, _)) = + <(usize, usize, Id, usize)>::unpack(input, tuple_depth)?; + let v = ProtocolVersionKey { runner_id }; + + Ok((input, v)) + } +} + #[derive(Debug)] pub struct ActorKey { runner_id: Id, diff --git a/engine/packages/pegboard/src/lib.rs b/engine/packages/pegboard/src/lib.rs index 50d941a7fc..5030e95d15 100644 --- a/engine/packages/pegboard/src/lib.rs +++ b/engine/packages/pegboard/src/lib.rs @@ -5,7 +5,6 @@ pub mod keys; mod metrics; pub mod ops; pub mod pubsub_subjects; -pub mod tunnel; pub mod utils; pub mod workflows; diff --git a/engine/packages/pegboard/src/ops/runner/update_alloc_idx.rs b/engine/packages/pegboard/src/ops/runner/update_alloc_idx.rs index be31d4c962..c720ad831f 100644 --- a/engine/packages/pegboard/src/ops/runner/update_alloc_idx.rs +++ b/engine/packages/pegboard/src/ops/runner/update_alloc_idx.rs @@ -65,6 +65,8 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) let remaining_slots_key = keys::runner::RemainingSlotsKey::new(runner.runner_id); let total_slots_key = keys::runner::TotalSlotsKey::new(runner.runner_id); + let protocol_version_key = + keys::runner::ProtocolVersionKey::new(runner.runner_id); let last_ping_ts_key = keys::runner::LastPingTsKey::new(runner.runner_id); let expired_ts_key = keys::runner::ExpiredTsKey::new(runner.runner_id); @@ -75,6 +77,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) version_entry, remaining_slots_entry, total_slots_entry, + protocol_version_entry, last_ping_ts_entry, expired_ts_entry, ) = tokio::try_join!( @@ -84,6 +87,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) tx.read_opt(&version_key, Serializable), tx.read_opt(&remaining_slots_key, Serializable), tx.read_opt(&total_slots_key, Serializable), + tx.read_opt(&protocol_version_key, Serializable), tx.read_opt(&last_ping_ts_key, Serializable), tx.read_opt(&expired_ts_key, Serializable), )?; @@ -95,6 +99,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) Some(version), Some(remaining_slots), Some(total_slots), + Some(protocol_version), Some(old_last_ping_ts), ) = ( workflow_id_entry, @@ -103,6 +108,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) version_entry, remaining_slots_entry, total_slots_entry, + protocol_version_entry, last_ping_ts_entry, ) else { @@ -151,6 +157,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) workflow_id, remaining_slots, total_slots, + protocol_version, }, )?; } @@ -181,6 +188,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) workflow_id, remaining_slots, total_slots, + protocol_version, }, )?; diff --git a/engine/packages/pegboard/src/utils.rs b/engine/packages/pegboard/src/utils.rs index 61bba809ae..0bf3e91be7 100644 --- a/engine/packages/pegboard/src/utils.rs +++ b/engine/packages/pegboard/src/utils.rs @@ -1,6 +1,23 @@ use rivet_runner_protocol as protocol; -pub fn event_actor_id(event: &protocol::Event) -> &str { +pub fn event_actor_id(event: &protocol::mk2::Event) -> &str { + match event { + protocol::mk2::Event::EventActorIntent(protocol::mk2::EventActorIntent { + actor_id, + .. + }) => actor_id, + protocol::mk2::Event::EventActorStateUpdate(protocol::mk2::EventActorStateUpdate { + actor_id, + .. + }) => actor_id, + protocol::mk2::Event::EventActorSetAlarm(protocol::mk2::EventActorSetAlarm { + actor_id, + .. + }) => actor_id, + } +} + +pub fn event_actor_id_mk1(event: &protocol::Event) -> &str { match event { protocol::Event::EventActorIntent(protocol::EventActorIntent { actor_id, .. }) => actor_id, protocol::Event::EventActorStateUpdate(protocol::EventActorStateUpdate { diff --git a/engine/packages/pegboard/src/workflows/actor/destroy.rs b/engine/packages/pegboard/src/workflows/actor/destroy.rs index f5d5afc7cf..e7391a6164 100644 --- a/engine/packages/pegboard/src/workflows/actor/destroy.rs +++ b/engine/packages/pegboard/src/workflows/actor/destroy.rs @@ -1,5 +1,6 @@ use gas::prelude::*; use rivet_data::converted::ActorByKeyKeyData; +use rivet_runner_protocol::PROTOCOL_MK1_VERSION; use universaldb::options::MutationType; use universaldb::utils::IsolationLevel::*; @@ -194,6 +195,7 @@ pub(crate) async fn clear_slot( let runner_remaining_slots_key = keys::runner::RemainingSlotsKey::new(runner_id); let runner_total_slots_key = keys::runner::TotalSlotsKey::new(runner_id); let runner_last_ping_ts_key = keys::runner::LastPingTsKey::new(runner_id); + let runner_protocol_version_key = keys::runner::ProtocolVersionKey::new(runner_id); let ( runner_workflow_id, diff --git a/engine/packages/pegboard/src/workflows/actor/mod.rs b/engine/packages/pegboard/src/workflows/actor/mod.rs index 2e5587fcae..79df340e42 100644 --- a/engine/packages/pegboard/src/workflows/actor/mod.rs +++ b/engine/packages/pegboard/src/workflows/actor/mod.rs @@ -218,9 +218,11 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> runtime::SpawnActorOutput::Allocated { runner_id, runner_workflow_id, + runner_protocol_version, } => runtime::LifecycleState::new( runner_id, runner_workflow_id, + runner_protocol_version, ctx.config().pegboard().actor_start_threshold(), ), runtime::SpawnActorOutput::Sleep => { @@ -291,8 +293,8 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> return Ok(Loop::Continue); } - let (Some(runner_id), Some(runner_workflow_id)) = - (state.runner_id, state.runner_workflow_id) + let (Some(runner_id), Some(runner_workflow_id), Some(runner_protocol_version)) = + (state.runner_id, state.runner_workflow_id, state.runner_protocol_version) else { tracing::warn!("actor not allocated, ignoring event"); return Ok(Loop::Continue); @@ -319,19 +321,22 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> }) .await?; - // TODO: Send message to tunnel - // // Send signal to stop actor now that we know it will be sleeping - // ctx.signal(crate::workflows::runner::Command { - // inner: protocol::Command::CommandStopActor( - // protocol::CommandStopActor { - // actor_id: input.actor_id.to_string(), - // generation: state.generation, - // }, - // ), - // }) - // .to_workflow_id(runner_workflow_id) - // .send() - // .await?; + if protocol::is_mk2(runner_protocol_version) { + // TODO: Send message to tunnel + } else { + // Send signal to stop actor now that we know it will be sleeping + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor( + protocol::CommandStopActor { + actor_id: input.actor_id.to_string(), + generation: state.generation, + }, + ), + }) + .to_workflow_id(runner_workflow_id) + .send() + .await?; + } } } protocol::ActorIntent::ActorIntentStop => { @@ -350,18 +355,21 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> }) .await?; - // TODO: Send message to tunnel - // ctx.signal(crate::workflows::runner::Command { - // inner: protocol::Command::CommandStopActor( - // protocol::CommandStopActor { - // actor_id: input.actor_id.to_string(), - // generation: state.generation, - // }, - // ), - // }) - // .to_workflow_id(runner_workflow_id) - // .send() - // .await?; + if protocol::is_mk2(runner_protocol_version) { + // TODO: Send message to tunnel + } else { + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor( + protocol::CommandStopActor { + actor_id: input.actor_id.to_string(), + generation: state.generation, + }, + ), + }) + .to_workflow_id(runner_workflow_id) + .send() + .await?; + } } } }, @@ -479,7 +487,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> } if !state.going_away { - let Some(runner_workflow_id) = state.runner_workflow_id else { + let (Some(runner_workflow_id), Some(runner_protocol_version)) = (state.runner_workflow_id, state.runner_protocol_version) else { return Ok(Loop::Continue); }; @@ -494,31 +502,37 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> }) .await?; - // TODO: Send message to tunnel - // ctx.signal(crate::workflows::runner::Command { - // inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { - // actor_id: input.actor_id.to_string(), - // generation: state.generation, - // }), - // }) - // .to_workflow_id(runner_workflow_id) - // .send() - // .await?; + if protocol::is_mk2(runner_protocol_version) { + // TODO: Send message to tunnel + } else { + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { + actor_id: input.actor_id.to_string(), + generation: state.generation, + }), + }) + .to_workflow_id(runner_workflow_id) + .send() + .await?; + } } } Main::Destroy(_) => { // If allocated, send stop actor command - if let Some(runner_workflow_id) = state.runner_workflow_id { - // TODO: Send message to tunnel - // ctx.signal(crate::workflows::runner::Command { - // inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { - // actor_id: input.actor_id.to_string(), - // generation: state.generation, - // }), - // }) - // .to_workflow_id(runner_workflow_id) - // .send() - // .await?; + if let (Some(runner_workflow_id), Some(runner_protocol_version)) = (state.runner_workflow_id, state.runner_protocol_version) { + if protocol::is_mk2(runner_protocol_version) { + // TODO: Send message to tunnel + } else { + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { + actor_id: input.actor_id.to_string(), + generation: state.generation, + }), + }) + .to_workflow_id(runner_workflow_id) + .send() + .await?; + } } return Ok(Loop::Break(runtime::LifecycleResult { @@ -586,6 +600,7 @@ async fn handle_stopped( state.stopping = false; state.runner_id = None; let old_runner_workflow_id = state.runner_workflow_id.take(); + let old_runner_protocol_version = state.runner_protocol_version.take(); let deallocate_res = ctx .activity(runtime::DeallocateInput { @@ -624,19 +639,28 @@ async fn handle_stopped( // We don't know the state of the previous generation of this actor actor if it becomes lost, send stop // command in case it ended up allocating - if let (StoppedVariant::Lost { .. }, Some(old_runner_workflow_id)) = - (&variant, old_runner_workflow_id) - { - // TODO: Send message to tunnel - // ctx.signal(crate::workflows::runner::Command { - // inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { - // actor_id: input.actor_id.to_string(), - // generation: state.generation, - // }), - // }) - // .to_workflow_id(old_runner_workflow_id) - // .send() - // .await?; + if let ( + StoppedVariant::Lost { .. }, + Some(old_runner_workflow_id), + Some(old_runner_protocol_version), + ) = ( + &variant, + old_runner_workflow_id, + old_runner_protocol_version, + ) { + if protocol::is_mk2(old_runner_protocol_version) { + // TODO: Send message to tunnel + } else { + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { + actor_id: input.actor_id.to_string(), + generation: state.generation, + }), + }) + .to_workflow_id(old_runner_workflow_id) + .send() + .await?; + } } // Reschedule no matter what @@ -752,6 +776,8 @@ pub struct Stopped {} pub struct Allocate { pub runner_id: Id, pub runner_workflow_id: Id, + #[serde(default)] + pub runner_protocol_version: Option, } #[signal("pegboard_actor_event")] diff --git a/engine/packages/pegboard/src/workflows/actor/runtime.rs b/engine/packages/pegboard/src/workflows/actor/runtime.rs index 78e1959984..299e8da410 100644 --- a/engine/packages/pegboard/src/workflows/actor/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor/runtime.rs @@ -4,7 +4,7 @@ use futures_util::StreamExt; use futures_util::TryStreamExt; use gas::prelude::*; use rivet_metrics::KeyValue; -use rivet_runner_protocol as protocol; +use rivet_runner_protocol::{self as protocol, PROTOCOL_MK1_VERSION}; use rivet_types::{ actors::CrashPolicy, keys::namespace::runner_config::RunnerConfigVariant, runner_configs::RunnerConfigKind, @@ -26,6 +26,7 @@ pub struct LifecycleState { // Set when currently running (not rescheduling or sleeping) pub runner_id: Option, pub runner_workflow_id: Option, + pub runner_protocol_version: Option, pub sleeping: bool, #[serde(default)] @@ -46,11 +47,17 @@ pub struct LifecycleState { } impl LifecycleState { - pub fn new(runner_id: Id, runner_workflow_id: Id, actor_start_threshold: i64) -> Self { + pub fn new( + runner_id: Id, + runner_workflow_id: Id, + runner_protocol_version: u16, + actor_start_threshold: i64, + ) -> Self { LifecycleState { generation: 0, runner_id: Some(runner_id), runner_workflow_id: Some(runner_workflow_id), + runner_protocol_version: Some(runner_protocol_version), sleeping: false, stopping: false, going_away: false, @@ -66,6 +73,7 @@ impl LifecycleState { generation: 0, runner_id: None, runner_workflow_id: None, + runner_protocol_version: None, sleeping: true, stopping: false, going_away: false, @@ -121,6 +129,7 @@ pub enum AllocateActorOutput { Allocated { runner_id: Id, runner_workflow_id: Id, + #[serde(default)] runner_protocol_version: Option, }, Pending { @@ -319,8 +328,9 @@ async fn allocate_actor( AllocateActorOutput::Allocated { runner_id: old_runner_alloc_key.runner_id, runner_workflow_id: old_runner_alloc_key_data.workflow_id, - runner_protocol_version: old_runner_alloc_key_data - .runner_protocol_version, + runner_protocol_version: Some( + old_runner_alloc_key_data.protocol_version, + ), }, )); } @@ -491,6 +501,7 @@ pub enum SpawnActorOutput { Allocated { runner_id: Id, runner_workflow_id: Id, + runner_protocol_version: u16, }, Sleep, Destroy, @@ -518,12 +529,14 @@ pub async fn spawn_actor( runner_workflow_id, runner_protocol_version, } => { + let runner_protocol_version = runner_protocol_version.unwrap_or(PROTOCOL_MK1_VERSION); + // Bump the autoscaler so it can scale up ctx.msg(rivet_types::msgs::pegboard::BumpServerlessAutoscaler {}) .send() .await?; - if protocol::is_new(runner_protocol_version) { + if protocol::is_mk2(runner_protocol_version) { // TODO: Send message to tunnel } else { ctx.signal(crate::workflows::runner::Command { @@ -580,6 +593,9 @@ pub async fn spawn_actor( // an `Allocate` signal match signal { Some(PendingAllocation::Allocate(sig)) => { + let runner_protocol_version = + sig.runner_protocol_version.unwrap_or(PROTOCOL_MK1_VERSION); + ctx.activity(UpdateRunnerInput { actor_id: input.actor_id, runner_id: sig.runner_id, @@ -587,7 +603,7 @@ pub async fn spawn_actor( }) .await?; - if protocol::is_new(sig.runner_protocol_version) { + if protocol::is_mk2(runner_protocol_version) { // TODO: Send message to tunnel } else { ctx.signal(crate::workflows::runner::Command { @@ -619,7 +635,7 @@ pub async fn spawn_actor( Ok(SpawnActorOutput::Allocated { runner_id: sig.runner_id, runner_workflow_id: sig.runner_workflow_id, - runner_protocol_version: sig.runner_protocol_version, + runner_protocol_version, }) } Some(PendingAllocation::Destroy(_)) => { @@ -665,6 +681,8 @@ pub async fn spawn_actor( // wait for the allocated signal to prevent a race condition. if !cleared { let sig = ctx.listen::().await?; + let runner_protocol_version = + sig.runner_protocol_version.unwrap_or(PROTOCOL_MK1_VERSION); ctx.activity(UpdateRunnerInput { actor_id: input.actor_id, @@ -673,34 +691,39 @@ pub async fn spawn_actor( }) .await?; - ctx.signal(crate::workflows::runner::Command { - inner: protocol::Command::CommandStartActor( - protocol::CommandStartActor { - actor_id: input.actor_id.to_string(), - generation, - config: protocol::ActorConfig { - name: input.name.clone(), - key: input.key.clone(), - create_ts: util::timestamp::now(), - input: input - .input - .as_ref() - .map(|x| BASE64_STANDARD.decode(x)) - .transpose()?, + if protocol::is_mk2(runner_protocol_version) { + // TODO: Send message to tunnel + } else { + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStartActor( + protocol::CommandStartActor { + actor_id: input.actor_id.to_string(), + generation, + config: protocol::ActorConfig { + name: input.name.clone(), + key: input.key.clone(), + create_ts: util::timestamp::now(), + input: input + .input + .as_ref() + .map(|x| BASE64_STANDARD.decode(x)) + .transpose()?, + }, + // Empty because request ids are ephemeral. This is intercepted by guard and + // populated before it reaches the runner + hibernating_requests: Vec::new(), }, - // Empty because request ids are ephemeral. This is intercepted by guard and - // populated before it reaches the runner - hibernating_requests: Vec::new(), - }, - ), - }) - .to_workflow_id(sig.runner_workflow_id) - .send() - .await?; + ), + }) + .to_workflow_id(sig.runner_workflow_id) + .send() + .await?; + } Ok(SpawnActorOutput::Allocated { runner_id: sig.runner_id, runner_workflow_id: sig.runner_workflow_id, + runner_protocol_version, }) } else { Ok(SpawnActorOutput::Sleep) @@ -771,7 +794,7 @@ pub async fn reschedule_actor( state.generation = next_generation; state.runner_id = Some(*runner_id); state.runner_workflow_id = Some(*runner_workflow_id); - state.runner_protocol_version = runner_protocol_version; + state.runner_protocol_version = Some(*runner_protocol_version); // Reset gc timeout once allocated state.gc_timeout_ts = diff --git a/engine/packages/pegboard/src/workflows/runner.rs b/engine/packages/pegboard/src/workflows/runner.rs index b980f8b6b5..7d665b480f 100644 --- a/engine/packages/pegboard/src/workflows/runner.rs +++ b/engine/packages/pegboard/src/workflows/runner.rs @@ -2,7 +2,7 @@ use futures_util::{FutureExt, StreamExt, TryStreamExt}; use gas::prelude::*; use rivet_data::converted::{ActorNameKeyData, MetadataKeyData, RunnerByKeyKeyData}; use rivet_metrics::KeyValue; -use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_runner_protocol::{self as protocol, PROTOCOL_MK1_VERSION, versioned}; use universaldb::{ options::{ConflictRangeType, StreamingMode}, utils::{FormalChunkedKey, IsolationLevel::*}, @@ -157,7 +157,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Forward to actor workflows for event in new_events.clone() { let actor_id = - crate::utils::event_actor_id(&event.inner).to_string(); + crate::utils::event_actor_id_mk1(&event.inner).to_string(); let res = ctx .signal(crate::workflows::actor::Event { inner: event.inner.clone(), @@ -1096,6 +1096,9 @@ pub(crate) async fn allocate_pending_actors( signal: Allocate { runner_id: old_runner_alloc_key.runner_id, runner_workflow_id: old_runner_alloc_key_data.workflow_id, + runner_protocol_version: Some( + old_runner_alloc_key_data.protocol_version, + ), }, }); @@ -1132,7 +1135,7 @@ async fn send_message_to_runner(ctx: &ActivityCtx, input: &SendMessageToRunnerIn crate::pubsub_subjects::RunnerReceiverSubject::new(input.runner_id).to_string(); let message_serialized = versioned::ToRunner::wrap_latest(input.message.clone()) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; + .serialize_with_embedded_version(PROTOCOL_MK1_VERSION)?; ctx.ups()? .publish(&receiver_subject, &message_serialized, PublishOpts::one()) diff --git a/engine/packages/pegboard/src/workflows/runner2.rs b/engine/packages/pegboard/src/workflows/runner2.rs index 375b1d1baf..c7ac61e7e7 100644 --- a/engine/packages/pegboard/src/workflows/runner2.rs +++ b/engine/packages/pegboard/src/workflows/runner2.rs @@ -1,20 +1,17 @@ use futures_util::{FutureExt, StreamExt, TryStreamExt}; use gas::prelude::*; -use rivet_data::converted::{ActorNameKeyData, MetadataKeyData, RunnerByKeyKeyData}; +use rivet_data::converted::RunnerByKeyKeyData; use rivet_metrics::KeyValue; -use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_runner_protocol::{self as protocol, PROTOCOL_MK2_VERSION, versioned}; use universaldb::{ options::{ConflictRangeType, StreamingMode}, - utils::{FormalChunkedKey, IsolationLevel::*}, + utils::IsolationLevel::*, }; use universalpubsub::PublishOpts; use vbare::OwnedVersionedData; use crate::{keys, metrics, workflows::actor::Allocate}; -/// Batch size of how many events to ack. -const EVENT_ACK_BATCH_SIZE: i64 = 500; - #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Input { pub runner_id: Id, @@ -41,13 +38,6 @@ impl State { } } -#[derive(Debug, Serialize, Deserialize)] -struct CommandRow { - index: i64, - command: protocol::Command, - create_ts: i64, -} - #[workflow] pub async fn pegboard_runner2(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> { ctx.activity(InitInput { @@ -78,6 +68,7 @@ pub async fn pegboard_runner2(ctx: &mut WorkflowCtx, input: &Input) -> Result<() key: input.key.clone(), version: input.version, total_slots: input.total_slots, + protocol_version: input.protocol_version, create_ts: ctx.create_ts(), }) .await?; @@ -182,7 +173,7 @@ pub async fn pegboard_runner2(ctx: &mut WorkflowCtx, input: &Input) -> Result<() // Close websocket connection (its unlikely to be open) ctx.activity(SendMessagesToRunnerInput { runner_id: input.runner_id, - messages: vec![protocol::ToRunner::ToRunnerClose], + messages: vec![protocol::mk2::ToRunner::ToRunnerClose], }) .await?; @@ -253,7 +244,7 @@ struct InitInput { create_ts: i64, } -#[activity(Init)] +#[activity(InitActivity)] async fn init(ctx: &ActivityCtx, input: &InitInput) -> Result<()> { let mut state = ctx.state::>()?; @@ -751,6 +742,9 @@ pub(crate) async fn allocate_pending_actors( signal: Allocate { runner_id: old_runner_alloc_key.runner_id, runner_workflow_id: old_runner_alloc_key_data.workflow_id, + runner_protocol_version: Some( + old_runner_alloc_key_data.protocol_version, + ), }, }); @@ -778,7 +772,7 @@ pub(crate) async fn allocate_pending_actors( #[derive(Debug, Serialize, Deserialize, Hash)] struct SendMessagesToRunnerInput { runner_id: Id, - messages: Vec, + messages: Vec, } #[activity(SendMessagesToRunner)] @@ -790,8 +784,8 @@ async fn send_messages_to_runner( crate::pubsub_subjects::RunnerReceiverSubject::new(input.runner_id).to_string(); for message in &input.messages { - let message_serialized = versioned::ToRunner::wrap_latest(message.clone()) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; + let message_serialized = versioned::ToRunnerMk2::wrap_latest(message.clone()) + .serialize_with_embedded_version(PROTOCOL_MK2_VERSION)?; ctx.ups()? .publish(&receiver_subject, &message_serialized, PublishOpts::one()) @@ -801,6 +795,9 @@ async fn send_messages_to_runner( Ok(()) } +#[signal("pegboard_runner_init")] +pub struct Init {} + #[signal("pegboard_runner_check_queue")] pub struct CheckQueue {} @@ -809,14 +806,8 @@ pub struct Stop { pub reset_actor_rescheduling: bool, } -#[signal("pegboard_runner_forward")] -pub struct Forward { - pub inner: protocol::ToServer, -} - join_signal!(Main { - // Forwarded from the ws to this workflow - Forward(Forward), + Init, CheckQueue, Stop, }); diff --git a/engine/packages/universaldb/src/utils/keys.rs b/engine/packages/universaldb/src/utils/keys.rs index 727ad382f5..2f3ceee6ac 100644 --- a/engine/packages/universaldb/src/utils/keys.rs +++ b/engine/packages/universaldb/src/utils/keys.rs @@ -131,4 +131,5 @@ define_keys! { // 103 - RESERVED BY EE // 104 - RESERVED BY EE // 105 - RESERVED BY EE + (106, PROTOCOL_VERSION, "protocol_version"), } diff --git a/engine/sdks/rust/data/src/converted.rs b/engine/sdks/rust/data/src/converted.rs index bfed1e8100..1fcf655cdd 100644 --- a/engine/sdks/rust/data/src/converted.rs +++ b/engine/sdks/rust/data/src/converted.rs @@ -7,6 +7,7 @@ pub struct RunnerAllocIdxKeyData { pub workflow_id: Id, pub remaining_slots: u32, pub total_slots: u32, + pub protocol_version: u16, } impl TryFrom for RunnerAllocIdxKeyData { diff --git a/engine/sdks/rust/data/src/versioned/mod.rs b/engine/sdks/rust/data/src/versioned/mod.rs index b70d52f0f8..b10c0bcb6f 100644 --- a/engine/sdks/rust/data/src/versioned/mod.rs +++ b/engine/sdks/rust/data/src/versioned/mod.rs @@ -52,6 +52,38 @@ impl OwnedVersionedData for RunnerAllocIdxKeyData { } } +impl RunnerAllocIdxKeyData { + fn v1_to_v2(self) -> Result { + if let RunnerAllocIdxKeyData::V1(x) = self { + Ok(RunnerAllocIdxKeyData::V2( + pegboard_namespace_runner_alloc_idx_v2::Data { + workflow_id: x.workflow_id, + remaining_slots: x.remaining_slots, + total_slots: x.total_slots, + // Default to mk1 + protocol_version: rivet_runner_protocol::PROTOCOL_MK1_VERSION, + }, + )) + } else { + bail!("unexpected version"); + } + } + + fn v2_to_v1(self) -> Result { + if let RunnerAllocIdxKeyData::V2(x) = self { + Ok(RunnerAllocIdxKeyData::V1( + pegboard_namespace_runner_alloc_idx_v1::Data { + workflow_id: x.workflow_id, + remaining_slots: x.remaining_slots, + total_slots: x.total_slots, + }, + )) + } else { + bail!("unexpected version"); + } + } +} + pub enum MetadataKeyData { V1(pegboard_runner_metadata_v1::Data), } diff --git a/engine/sdks/rust/runner-protocol/src/versioned.rs b/engine/sdks/rust/runner-protocol/src/versioned.rs index 54b9707090..ad1dcc9bbb 100644 --- a/engine/sdks/rust/runner-protocol/src/versioned.rs +++ b/engine/sdks/rust/runner-protocol/src/versioned.rs @@ -1,6 +1,7 @@ use anyhow::{Ok, Result, bail}; use vbare::OwnedVersionedData; +use crate::PROTOCOL_MK1_VERSION; use crate::generated::{v1, v2, v3, v4}; use crate::uuid_compat::{decode_bytes_from_uuid, encode_bytes_to_uuid}; @@ -155,73 +156,70 @@ impl OwnedVersionedData for ToClient { impl ToClient { fn v1_to_v2(self) -> Result { - match self { - ToClient::V1(x) => { - let inner = match x { - v1::ToClient::ToClientInit(init) => { - v2::ToClient::ToClientInit(v2::ToClientInit { - runner_id: init.runner_id, - last_event_idx: init.last_event_idx, - metadata: v2::ProtocolMetadata { - runner_lost_threshold: init.metadata.runner_lost_threshold, + if let ToClient::V1(x) = self { + let inner = match x { + v1::ToClient::ToClientInit(init) => v2::ToClient::ToClientInit(v2::ToClientInit { + runner_id: init.runner_id, + last_event_idx: init.last_event_idx, + metadata: v2::ProtocolMetadata { + runner_lost_threshold: init.metadata.runner_lost_threshold, + }, + }), + v1::ToClient::ToClientClose => v2::ToClient::ToClientClose, + v1::ToClient::ToClientCommands(commands) => v2::ToClient::ToClientCommands( + commands + .into_iter() + .map(|cmd| v2::CommandWrapper { + index: cmd.index, + inner: match cmd.inner { + v1::Command::CommandStartActor(start) => { + v2::Command::CommandStartActor(v2::CommandStartActor { + actor_id: start.actor_id, + generation: start.generation, + config: v2::ActorConfig { + name: start.config.name, + key: start.config.key, + create_ts: start.config.create_ts, + input: start.config.input, + }, + }) + } + v1::Command::CommandStopActor(stop) => { + v2::Command::CommandStopActor(v2::CommandStopActor { + actor_id: stop.actor_id, + generation: stop.generation, + }) + } }, }) - } - v1::ToClient::ToClientClose => v2::ToClient::ToClientClose, - v1::ToClient::ToClientCommands(commands) => v2::ToClient::ToClientCommands( - commands - .into_iter() - .map(|cmd| v2::CommandWrapper { - index: cmd.index, - inner: match cmd.inner { - v1::Command::CommandStartActor(start) => { - v2::Command::CommandStartActor(v2::CommandStartActor { - actor_id: start.actor_id, - generation: start.generation, - config: v2::ActorConfig { - name: start.config.name, - key: start.config.key, - create_ts: start.config.create_ts, - input: start.config.input, - }, - }) - } - v1::Command::CommandStopActor(stop) => { - v2::Command::CommandStopActor(v2::CommandStopActor { - actor_id: stop.actor_id, - generation: stop.generation, - }) - } - }, - }) - .collect(), - ), - v1::ToClient::ToClientAckEvents(ack) => { - v2::ToClient::ToClientAckEvents(v2::ToClientAckEvents { - last_event_idx: ack.last_event_idx, - }) - } - v1::ToClient::ToClientKvResponse(resp) => { - v2::ToClient::ToClientKvResponse(v2::ToClientKvResponse { - request_id: resp.request_id, - data: convert_kv_response_data_v1_to_v2(resp.data), - }) - } - v1::ToClient::ToClientTunnelMessage(msg) => { - v2::ToClient::ToClientTunnelMessage(v2::ToClientTunnelMessage { - request_id: msg.request_id, - message_id: msg.message_id, - message_kind: convert_to_client_tunnel_message_kind_v1_to_v2( - msg.message_kind, - ), - gateway_reply_to: msg.gateway_reply_to, - }) - } - }; + .collect(), + ), + v1::ToClient::ToClientAckEvents(ack) => { + v2::ToClient::ToClientAckEvents(v2::ToClientAckEvents { + last_event_idx: ack.last_event_idx, + }) + } + v1::ToClient::ToClientKvResponse(resp) => { + v2::ToClient::ToClientKvResponse(v2::ToClientKvResponse { + request_id: resp.request_id, + data: convert_kv_response_data_v1_to_v2(resp.data), + }) + } + v1::ToClient::ToClientTunnelMessage(msg) => { + v2::ToClient::ToClientTunnelMessage(v2::ToClientTunnelMessage { + request_id: msg.request_id, + message_id: msg.message_id, + message_kind: convert_to_client_tunnel_message_kind_v1_to_v2( + msg.message_kind, + ), + gateway_reply_to: msg.gateway_reply_to, + }) + } + }; - Ok(ToClient::V2(inner)) - } - _ => bail!("unexpected version"), + Ok(ToClient::V2(inner)) + } else { + bail!("unexpected version"); } } @@ -827,18 +825,18 @@ impl OwnedVersionedData for ToRunner { } fn deserialize_converters() -> Vec Result> { - // No changes between v1 and v3 - vec![Ok, Ok] + // No changes between v1 and v4 + vec![Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - // No changes between v1 and v3 - vec![Ok, Ok] + // No changes between v1 and v4 + vec![Ok, Ok, Ok] } } pub enum ToGateway { - // No change between v1 and v4 + V3(v3::ToGateway), V4(v4::ToGateway), } @@ -860,30 +858,92 @@ impl OwnedVersionedData for ToGateway { fn deserialize_version(payload: &[u8], version: u16) -> Result { match version { - 1 | 2 | 4 => Ok(ToGateway::V4(serde_bare::from_slice(payload)?)), + 1 | 2 | 3 => Ok(ToGateway::V3(serde_bare::from_slice(payload)?)), + 4 => Ok(ToGateway::V4(serde_bare::from_slice(payload)?)), _ => bail!("invalid version: {version}"), } } fn serialize_version(self, _version: u16) -> Result> { match self { + ToGateway::V3(data) => serde_bare::to_vec(&data).map_err(Into::into), ToGateway::V4(data) => serde_bare::to_vec(&data).map_err(Into::into), } } fn deserialize_converters() -> Vec Result> { - // No changes between v1 and v3 - vec![Ok, Ok] + // No changes between v1 and v4 but we need a converter to bridge mk1 to mk2 + vec![Ok, Ok, Self::v3_to_v4] } fn serialize_converters() -> Vec Result> { - // No changes between v1 and v3 - vec![Ok, Ok] + // No changes between v1 and v4 but we need a converter to bridge mk2 to mk1 + vec![Self::v4_to_v3, Ok, Ok] + } +} + +impl ToGateway { + pub fn v3_to_v4(self) -> Result { + if let ToGateway::V3(x) = self { + let inner = match x { + v3::ToGateway::ToGatewayPong(pong) => { + v4::ToGateway::ToGatewayPong(v4::ToGatewayPong { + request_id: pong.request_id, + ts: pong.ts, + }) + } + v3::ToGateway::ToServerTunnelMessage(msg) => { + v4::ToGateway::ToServerTunnelMessage(v4::ToServerTunnelMessage { + message_id: v4::MessageId { + gateway_id: msg.message_id.gateway_id, + request_id: msg.message_id.request_id, + message_index: msg.message_id.message_index, + }, + message_kind: convert_to_server_tunnel_message_kind_v3_to_v4( + msg.message_kind, + ), + }) + } + }; + + Ok(ToGateway::V4(inner)) + } else { + bail!("unexpected version"); + } + } + + fn v4_to_v3(self) -> Result { + if let ToGateway::V4(x) = self { + let inner = match x { + v4::ToGateway::ToGatewayPong(pong) => { + v3::ToGateway::ToGatewayPong(v3::ToGatewayPong { + request_id: pong.request_id, + ts: pong.ts, + }) + } + v4::ToGateway::ToServerTunnelMessage(msg) => { + v3::ToGateway::ToServerTunnelMessage(v3::ToServerTunnelMessage { + message_id: v3::MessageId { + gateway_id: msg.message_id.gateway_id, + request_id: msg.message_id.request_id, + message_index: msg.message_id.message_index, + }, + message_kind: convert_to_server_tunnel_message_kind_v4_to_v3( + msg.message_kind, + )?, + }) + } + }; + + Ok(ToGateway::V3(inner)) + } else { + bail!("unexpected version"); + } } } pub enum ToServerlessServer { - // No change between v1 and v4 + V3(v3::ToServerlessServer), V4(v4::ToServerlessServer), } @@ -905,25 +965,62 @@ impl OwnedVersionedData for ToServerlessServer { fn deserialize_version(payload: &[u8], version: u16) -> Result { match version { - 1 | 2 | 3 | 4 => Ok(ToServerlessServer::V4(serde_bare::from_slice(payload)?)), + 1 | 2 | 3 => Ok(ToServerlessServer::V3(serde_bare::from_slice(payload)?)), + 4 => Ok(ToServerlessServer::V4(serde_bare::from_slice(payload)?)), _ => bail!("invalid version: {version}"), } } fn serialize_version(self, _version: u16) -> Result> { match self { + ToServerlessServer::V3(data) => serde_bare::to_vec(&data).map_err(Into::into), ToServerlessServer::V4(data) => serde_bare::to_vec(&data).map_err(Into::into), } } fn deserialize_converters() -> Vec Result> { // No changes between v1 and v3 - vec![Ok, Ok] + vec![Ok, Ok, Ok, Self::v3_to_v4] } fn serialize_converters() -> Vec Result> { // No changes between v1 and v3 - vec![Ok, Ok] + vec![Self::v4_to_v3, Ok, Ok, Ok] + } +} + +impl ToServerlessServer { + fn v3_to_v4(self) -> Result { + if let ToServerlessServer::V3(x) = self { + let inner = match x { + v3::ToServerlessServer::ToServerlessServerInit(init) => { + v4::ToServerlessServer::ToServerlessServerInit(v4::ToServerlessServerInit { + runner_id: init.runner_id, + runner_protocol_version: PROTOCOL_MK1_VERSION, + }) + } + }; + + Ok(ToServerlessServer::V4(inner)) + } else { + bail!("unexpected version"); + } + } + + fn v4_to_v3(self) -> Result { + if let ToServerlessServer::V4(x) = self { + let inner = match x { + v4::ToServerlessServer::ToServerlessServerInit(init) => { + v3::ToServerlessServer::ToServerlessServerInit(v3::ToServerlessServerInit { + runner_id: init.runner_id, + }) + } + }; + + Ok(ToServerlessServer::V3(inner)) + } else { + bail!("unexpected version"); + } } } @@ -1834,3 +1931,160 @@ fn convert_kv_metadata_v3_to_v2(metadata: v3::KvMetadata) -> v2::KvMetadata { create_ts: metadata.create_ts, } } + +fn convert_to_server_tunnel_message_kind_v3_to_v4( + kind: v3::ToServerTunnelMessageKind, +) -> v4::ToServerTunnelMessageKind { + match kind { + v3::ToServerTunnelMessageKind::ToServerResponseStart(resp) => { + v4::ToServerTunnelMessageKind::ToServerResponseStart(v4::ToServerResponseStart { + status: resp.status, + headers: resp.headers, + body: resp.body, + stream: resp.stream, + }) + } + v3::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => { + v4::ToServerTunnelMessageKind::ToServerResponseChunk(v4::ToServerResponseChunk { + body: chunk.body, + finish: chunk.finish, + }) + } + v3::ToServerTunnelMessageKind::ToServerResponseAbort => { + v4::ToServerTunnelMessageKind::ToServerResponseAbort + } + v3::ToServerTunnelMessageKind::ToServerWebSocketOpen(open) => { + v4::ToServerTunnelMessageKind::ToServerWebSocketOpen(v4::ToServerWebSocketOpen { + can_hibernate: open.can_hibernate, + }) + } + v3::ToServerTunnelMessageKind::ToServerWebSocketMessage(msg) => { + v4::ToServerTunnelMessageKind::ToServerWebSocketMessage(v4::ToServerWebSocketMessage { + data: msg.data, + binary: msg.binary, + }) + } + v3::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { + v4::ToServerTunnelMessageKind::ToServerWebSocketMessageAck( + v4::ToServerWebSocketMessageAck { index: ack.index }, + ) + } + v3::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { + v4::ToServerTunnelMessageKind::ToServerWebSocketClose(v4::ToServerWebSocketClose { + code: close.code, + reason: close.reason, + hibernate: close.hibernate, + }) + } + v3::ToServerTunnelMessageKind::DeprecatedTunnelAck => { + // v4 removed DeprecatedTunnelAck, this should not occur in practice + // but if it does, we'll convert it to a response abort as a safe fallback + v4::ToServerTunnelMessageKind::ToServerResponseAbort + } + } +} + +fn convert_to_server_tunnel_message_kind_v4_to_v3( + kind: v4::ToServerTunnelMessageKind, +) -> Result { + Ok(match kind { + v4::ToServerTunnelMessageKind::ToServerResponseStart(resp) => { + v3::ToServerTunnelMessageKind::ToServerResponseStart(v3::ToServerResponseStart { + status: resp.status, + headers: resp.headers, + body: resp.body, + stream: resp.stream, + }) + } + v4::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => { + v3::ToServerTunnelMessageKind::ToServerResponseChunk(v3::ToServerResponseChunk { + body: chunk.body, + finish: chunk.finish, + }) + } + v4::ToServerTunnelMessageKind::ToServerResponseAbort => { + v3::ToServerTunnelMessageKind::ToServerResponseAbort + } + v4::ToServerTunnelMessageKind::ToServerWebSocketOpen(open) => { + v3::ToServerTunnelMessageKind::ToServerWebSocketOpen(v3::ToServerWebSocketOpen { + can_hibernate: open.can_hibernate, + }) + } + v4::ToServerTunnelMessageKind::ToServerWebSocketMessage(msg) => { + v3::ToServerTunnelMessageKind::ToServerWebSocketMessage(v3::ToServerWebSocketMessage { + data: msg.data, + binary: msg.binary, + }) + } + v4::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { + v3::ToServerTunnelMessageKind::ToServerWebSocketMessageAck( + v3::ToServerWebSocketMessageAck { index: ack.index }, + ) + } + v4::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { + v3::ToServerTunnelMessageKind::ToServerWebSocketClose(v3::ToServerWebSocketClose { + code: close.code, + reason: close.reason, + hibernate: close.hibernate, + }) + } + }) +} + +pub fn to_client_tunnel_message_v4_to_v3( + msg: v4::ToClientTunnelMessage, +) -> v3::ToClientTunnelMessage { + v3::ToClientTunnelMessage { + message_id: v3::MessageId { + gateway_id: msg.message_id.gateway_id, + request_id: msg.message_id.request_id, + message_index: msg.message_id.message_index, + }, + message_kind: convert_to_client_tunnel_message_kind_v4_to_v3(msg.message_kind), + } +} + +fn convert_to_client_tunnel_message_kind_v4_to_v3( + kind: v4::ToClientTunnelMessageKind, +) -> v3::ToClientTunnelMessageKind { + match kind { + v4::ToClientTunnelMessageKind::ToClientRequestStart(req) => { + v3::ToClientTunnelMessageKind::ToClientRequestStart(v3::ToClientRequestStart { + actor_id: req.actor_id, + method: req.method, + path: req.path, + headers: req.headers, + body: req.body, + stream: req.stream, + }) + } + v4::ToClientTunnelMessageKind::ToClientRequestChunk(chunk) => { + v3::ToClientTunnelMessageKind::ToClientRequestChunk(v3::ToClientRequestChunk { + body: chunk.body, + finish: chunk.finish, + }) + } + v4::ToClientTunnelMessageKind::ToClientRequestAbort => { + v3::ToClientTunnelMessageKind::ToClientRequestAbort + } + v4::ToClientTunnelMessageKind::ToClientWebSocketOpen(ws) => { + v3::ToClientTunnelMessageKind::ToClientWebSocketOpen(v3::ToClientWebSocketOpen { + actor_id: ws.actor_id, + path: ws.path, + headers: ws.headers, + }) + } + v4::ToClientTunnelMessageKind::ToClientWebSocketMessage(msg) => { + v3::ToClientTunnelMessageKind::ToClientWebSocketMessage(v3::ToClientWebSocketMessage { + data: msg.data, + binary: msg.binary, + }) + } + v4::ToClientTunnelMessageKind::ToClientWebSocketClose(close) => { + v3::ToClientTunnelMessageKind::ToClientWebSocketClose(v3::ToClientWebSocketClose { + code: close.code, + reason: close.reason, + }) + } + } +} diff --git a/engine/sdks/schemas/runner-protocol/v4.bare b/engine/sdks/schemas/runner-protocol/v4.bare index e363bcbe1f..ff1a4bba50 100644 --- a/engine/sdks/schemas/runner-protocol/v4.bare +++ b/engine/sdks/schemas/runner-protocol/v4.bare @@ -207,7 +207,7 @@ type CommandWrapper struct { # Message ID -type MessageIdParts struct { +type MessageId struct { # Globally unique ID gatewayId: GatewayId # Unique ID to the gateway @@ -216,7 +216,6 @@ type MessageIdParts struct { messageIndex: MessageIndex } -type MessageId data[12] # HTTP type ToClientRequestStart struct { @@ -322,6 +321,10 @@ type ToClientTunnelMessage struct { messageKind: ToClientTunnelMessageKind } +type ToClientPing struct { + ts: i64 +} + # MARK: To Server type ToServerInit struct { name: str @@ -424,6 +427,7 @@ type ToGateway union { # MARK: Serverless type ToServerlessServerInit struct { runnerId: Id + runnerProtocolVersion: u16 } type ToServerlessServer union { diff --git a/engine/sdks/typescript/runner-protocol/src/index.ts b/engine/sdks/typescript/runner-protocol/src/index.ts index 88b43e602d..7e619fa402 100644 --- a/engine/sdks/typescript/runner-protocol/src/index.ts +++ b/engine/sdks/typescript/runner-protocol/src/index.ts @@ -1007,8 +1007,6 @@ export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { writeMessageIndex(bc, x.messageIndex) } -export type DeprecatedTunnelAck = null - function read9(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() @@ -1448,6 +1446,20 @@ export function writeToClientTunnelMessage(bc: bare.ByteCursor, x: ToClientTunne writeToClientTunnelMessageKind(bc, x.messageKind) } +export type ToClientPing = { + readonly ts: i64 +} + +export function readToClientPing(bc: bare.ByteCursor): ToClientPing { + return { + ts: bare.readI64(bc), + } +} + +export function writeToClientPing(bc: bare.ByteCursor, x: ToClientPing): void { + bare.writeI64(bc, x.ts) +} + function read11(bc: bare.ByteCursor): readonly ActorCheckpoint[] { const len = bare.readUintSafe(bc) if (len === 0) { @@ -1581,17 +1593,17 @@ export function writeToServerAckCommands(bc: bare.ByteCursor, x: ToServerAckComm export type ToServerStopping = null -export type ToServerPing = { +export type ToServerPong = { readonly ts: i64 } -export function readToServerPing(bc: bare.ByteCursor): ToServerPing { +export function readToServerPong(bc: bare.ByteCursor): ToServerPong { return { ts: bare.readI64(bc), } } -export function writeToServerPing(bc: bare.ByteCursor, x: ToServerPing): void { +export function writeToServerPong(bc: bare.ByteCursor, x: ToServerPong): void { bare.writeI64(bc, x.ts) } @@ -1620,7 +1632,7 @@ export type ToServer = | { readonly tag: "ToServerEvents"; readonly val: ToServerEvents } | { readonly tag: "ToServerAckCommands"; readonly val: ToServerAckCommands } | { readonly tag: "ToServerStopping"; readonly val: ToServerStopping } - | { readonly tag: "ToServerPing"; readonly val: ToServerPing } + | { readonly tag: "ToServerPong"; readonly val: ToServerPong } | { readonly tag: "ToServerKvRequest"; readonly val: ToServerKvRequest } | { readonly tag: "ToServerTunnelMessage"; readonly val: ToServerTunnelMessage } @@ -1637,7 +1649,7 @@ export function readToServer(bc: bare.ByteCursor): ToServer { case 3: return { tag: "ToServerStopping", val: null } case 4: - return { tag: "ToServerPing", val: readToServerPing(bc) } + return { tag: "ToServerPong", val: readToServerPong(bc) } case 5: return { tag: "ToServerKvRequest", val: readToServerKvRequest(bc) } case 6: @@ -1670,9 +1682,9 @@ export function writeToServer(bc: bare.ByteCursor, x: ToServer): void { bare.writeU8(bc, 3) break } - case "ToServerPing": { + case "ToServerPong": { bare.writeU8(bc, 4) - writeToServerPing(bc, x.val) + writeToServerPong(bc, x.val) break } case "ToServerKvRequest": { @@ -1802,6 +1814,7 @@ export type ToClient = | { readonly tag: "ToClientAckEvents"; readonly val: ToClientAckEvents } | { readonly tag: "ToClientKvResponse"; readonly val: ToClientKvResponse } | { readonly tag: "ToClientTunnelMessage"; readonly val: ToClientTunnelMessage } + | { readonly tag: "ToClientPing"; readonly val: ToClientPing } export function readToClient(bc: bare.ByteCursor): ToClient { const offset = bc.offset @@ -1817,6 +1830,8 @@ export function readToClient(bc: bare.ByteCursor): ToClient { return { tag: "ToClientKvResponse", val: readToClientKvResponse(bc) } case 4: return { tag: "ToClientTunnelMessage", val: readToClientTunnelMessage(bc) } + case 5: + return { tag: "ToClientPing", val: readToClientPing(bc) } default: { bc.offset = offset throw new bare.BareError(offset, "invalid tag") @@ -1851,6 +1866,11 @@ export function writeToClient(bc: bare.ByteCursor, x: ToClient): void { writeToClientTunnelMessage(bc, x.val) break } + case "ToClientPing": { + bare.writeU8(bc, 5) + writeToClientPing(bc, x.val) + break + } } } @@ -2056,16 +2076,19 @@ export function decodeToGateway(bytes: Uint8Array): ToGateway { */ export type ToServerlessServerInit = { readonly runnerId: Id + readonly runnerProtocolVersion: u16 } export function readToServerlessServerInit(bc: bare.ByteCursor): ToServerlessServerInit { return { runnerId: readId(bc), + runnerProtocolVersion: bare.readU16(bc), } } export function writeToServerlessServerInit(bc: bare.ByteCursor, x: ToServerlessServerInit): void { writeId(bc, x.runnerId) + bare.writeU16(bc, x.runnerProtocolVersion) } export type ToServerlessServer = diff --git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index 9f41be6ad8..926bd93c77 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -17,7 +17,7 @@ export { RunnerActor, type ActorConfig }; export { idToStr } from "./utils"; const KV_EXPIRE: number = 30_000; -const PROTOCOL_VERSION: number = 3; +const PROTOCOL_VERSION: number = 4; const RUNNER_PING_INTERVAL = 3_000; /** Warn once the backlog significantly exceeds the server's ack batch size. */