diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index 2a01f10b44..49c610f0a7 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -59,6 +59,7 @@ pub struct PendingWebsocketMessage { pub struct SharedStateInner { ups: PubSub, + gateway_id: Uuid, receiver_subject: String, in_flight_requests: HashMap, } @@ -74,6 +75,7 @@ impl SharedState { Self(Arc::new(SharedStateInner { ups, + gateway_id, receiver_subject, in_flight_requests: HashMap::new(), })) @@ -160,6 +162,7 @@ impl SharedState { }; let payload = protocol::ToClientTunnelMessage { + gateway_id: *self.gateway_id.as_bytes(), request_id: request_id.clone(), message_id, // Only send reply to subject on the first message for this request. This reduces @@ -179,8 +182,8 @@ impl SharedState { }); // Send message - let message = protocol::ToClient::ToClientTunnelMessage(payload); - let message_serialized = versioned::ToClient::wrap_latest(message) + let message = protocol::ToRunner::ToClientTunnelMessage(payload); + let message_serialized = versioned::ToRunner::wrap_latest(message) .serialize_with_embedded_version(PROTOCOL_VERSION)?; if let (Some(hs), Some(ws_msg_index)) = (&mut req.hibernation_state, ws_msg_index) { @@ -221,105 +224,48 @@ impl SharedState { ); match versioned::ToGateway::deserialize_with_embedded_version(&msg.payload) { - Ok(protocol::ToGateway { message: msg }) => { - tracing::debug!( - request_id=?Uuid::from_bytes(msg.request_id), - message_id=?Uuid::from_bytes(msg.message_id), - "successfully deserialized message" - ); - - let Some(mut in_flight) = - self.in_flight_requests.get_async(&msg.request_id).await + Ok(protocol::ToGateway::ToGatewayKeepAlive) => { + // TODO: + // let prev_len = in_flight.pending_msgs.len(); + // + // tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack"); + // + // in_flight + // .pending_msgs + // .retain(|m| m.message_id != msg.message_id); + // + // if prev_len == in_flight.pending_msgs.len() { + // tracing::warn!( + // request_id=?Uuid::from_bytes(msg.request_id), + // message_id=?Uuid::from_bytes(msg.message_id), + // "pending message does not exist or ack received after message body" + // ) + // } else { + // tracing::debug!( + // request_id=?Uuid::from_bytes(msg.request_id), + // message_id=?Uuid::from_bytes(msg.message_id), + // "received TunnelAck, removed from pending" + // ); + // } + } + Ok(protocol::ToGateway::ToServerTunnelMessage(msg)) => { + let Some(in_flight) = self.in_flight_requests.get_async(&msg.request_id).await else { tracing::warn!( request_id=?Uuid::from_bytes(msg.request_id), message_id=?Uuid::from_bytes(msg.message_id), - "in flight has already been disconnected, cannot ack message" + "in flight has already been disconnected, dropping message" ); continue; }; - if let protocol::ToServerTunnelMessageKind::TunnelAck = &msg.message_kind { - let prev_len = in_flight.pending_msgs.len(); - - tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack"); - - in_flight - .pending_msgs - .retain(|m| m.message_id != msg.message_id); - - if prev_len == in_flight.pending_msgs.len() { - tracing::warn!( - request_id=?Uuid::from_bytes(msg.request_id), - message_id=?Uuid::from_bytes(msg.message_id), - "pending message does not exist or ack received after message body" - ) - } else { - tracing::debug!( - request_id=?Uuid::from_bytes(msg.request_id), - message_id=?Uuid::from_bytes(msg.message_id), - "received TunnelAck, removed from pending" - ); - } - } else { - // Send message to the request handler to emulate the real network action - tracing::debug!( - request_id=?Uuid::from_bytes(msg.request_id), - message_id=?Uuid::from_bytes(msg.message_id), - "forwarding message to request handler" - ); - let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await; - - // Send ack back to runner - let ups_clone = self.ups.clone(); - let receiver_subject = in_flight.receiver_subject.clone(); - let request_id = msg.request_id; - let message_id = msg.message_id; - let ack_message = protocol::ToClient::ToClientTunnelMessage( - protocol::ToClientTunnelMessage { - request_id, - message_id, - gateway_reply_to: None, - message_kind: protocol::ToClientTunnelMessageKind::TunnelAck, - }, - ); - let ack_message_serialized = - match versioned::ToClient::wrap_latest(ack_message) - .serialize_with_embedded_version(PROTOCOL_VERSION) - { - Ok(x) => x, - Err(err) => { - tracing::error!(?err, "failed to serialize ack"); - continue; - } - }; - tokio::spawn(async move { - match ups_clone - .publish( - &receiver_subject, - &ack_message_serialized, - PublishOpts::one(), - ) - .await - { - Ok(_) => { - tracing::debug!( - request_id=?Uuid::from_bytes(request_id), - message_id=?Uuid::from_bytes(message_id), - "sent TunnelAck to runner" - ); - } - Err(err) => { - tracing::warn!( - ?err, - request_id=?Uuid::from_bytes(request_id), - message_id=?Uuid::from_bytes(message_id), - "failed to send TunnelAck to runner" - ); - } - } - }); - } + // Send message to the request handler to emulate the real network action + tracing::debug!( + request_id=?Uuid::from_bytes(msg.request_id), + message_id=?Uuid::from_bytes(msg.message_id), + "forwarding message to request handler" + ); + let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await; } Err(err) => { tracing::error!(?err, "failed to parse message"); diff --git a/engine/packages/pegboard-runner/src/conn.rs b/engine/packages/pegboard-runner/src/conn.rs index 8a47743f2e..ba5e3ebb24 100644 --- a/engine/packages/pegboard-runner/src/conn.rs +++ b/engine/packages/pegboard-runner/src/conn.rs @@ -8,21 +8,13 @@ use rivet_guard_core::WebSocketHandle; use rivet_runner_protocol as protocol; use rivet_runner_protocol::*; use std::{ - collections::HashMap, sync::{Arc, atomic::AtomicU32}, time::Duration, }; -use tokio::sync::Mutex; use vbare::OwnedVersionedData; use crate::{errors::WsError, utils::UrlData}; -pub struct TunnelActiveRequest { - /// Subject to send replies to. - pub gateway_reply_to: String, - pub is_ws: bool, -} - pub struct Conn { pub namespace_id: Id, pub runner_name: String, @@ -32,10 +24,6 @@ pub struct Conn { pub protocol_version: u16, pub ws_handle: WebSocketHandle, pub last_rtt: AtomicU32, - - /// Active HTTP & WebSocket requests. They are separate but use the same mechanism to - /// maintain state. - pub tunnel_active_requests: Mutex>, } #[tracing::instrument(skip_all)] @@ -191,6 +179,5 @@ pub async fn init_conn( protocol_version, ws_handle, last_rtt: AtomicU32::new(0), - tunnel_active_requests: Mutex::new(HashMap::new()), })) } diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 05e72fe8f0..99fed0928e 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -9,12 +9,10 @@ use rivet_guard_core::{ WebSocketHandle, custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, }; -use rivet_runner_protocol as protocol; use std::time::Duration; use tokio::sync::watch; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use universalpubsub::PublishOpts; -use vbare::OwnedVersionedData; mod conn; mod errors; @@ -243,41 +241,6 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { ); } - // Send close messages to all remaining active requests - let active_requests = conn.tunnel_active_requests.lock().await; - for (request_id, req) in &*active_requests { - // Websockets are not ephemeral like requests. If the runner ws closes they are not informed; - // instead they wait for the actor itself to stop. - if req.is_ws { - continue; - } - - let close_message = protocol::ToServerTunnelMessage { - request_id: request_id.clone(), - message_id: Uuid::new_v4().into_bytes(), - message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseAbort, - }; - - let msg_serialized = protocol::versioned::ToGateway::wrap_latest(protocol::ToGateway { - message: close_message.clone(), - }) - .serialize_with_embedded_version(protocol::PROTOCOL_VERSION) - .context("failed to serialize tunnel message for gateway")?; - - // Publish message to UPS - let res = ups - .publish(&req.gateway_reply_to, &msg_serialized, PublishOpts::one()) - .await; - - if let Err(err) = res { - tracing::warn!( - ?err, - %req.gateway_reply_to, - "error sending close message to remaining active requests" - ); - } - } - // This will determine the close frame sent back to the runner websocket lifecycle_res.map(|_| None) } diff --git a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs index bf92171beb..842bf99e46 100644 --- a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs @@ -7,11 +7,7 @@ use tokio::sync::watch; use universalpubsub::{NextOutput, Subscriber}; use vbare::OwnedVersionedData; -use crate::{ - LifecycleResult, - conn::{Conn, TunnelActiveRequest}, - errors, -}; +use crate::{LifecycleResult, conn::Conn, errors}; #[tracing::instrument(skip_all, fields(runner_id=?conn.runner_id, workflow_id=?conn.workflow_id, protocol_version=%conn.protocol_version))] pub async fn task( @@ -47,8 +43,7 @@ pub async fn task( ); // Parse message - let mut msg = match versioned::ToClient::deserialize_with_embedded_version(&ups_msg.payload) - { + let msg = match versioned::ToRunner::deserialize_with_embedded_version(&ups_msg.payload) { Result::Ok(x) => x, Err(err) => { tracing::error!(?err, "failed to parse tunnel message"); @@ -56,11 +51,17 @@ pub async fn task( } }; - match &mut msg { - protocol::ToClient::ToClientClose => return Err(errors::WsError::Eviction.build()), + // Convert to ToClient types + let to_client_msg = match msg { + protocol::ToRunner::ToRunnerKeepAlive(_) => { + // TODO: + continue; + } + protocol::ToRunner::ToClientInit(x) => protocol::ToClient::ToClientInit(x), + protocol::ToRunner::ToClientClose => return Err(errors::WsError::Eviction.build()), // Dynamically populate hibernating request ids - protocol::ToClient::ToClientCommands(command_wrappers) => { - for command_wrapper in command_wrappers { + protocol::ToRunner::ToClientCommands(mut command_wrappers) => { + for command_wrapper in &mut command_wrappers { if let protocol::Command::CommandStartActor(protocol::CommandStartActor { actor_id, hibernating_request_ids, @@ -77,71 +78,28 @@ pub async fn task( ids.into_iter().map(|x| x.into_bytes().to_vec()).collect(); } } + + // NOTE: `command_wrappers` is mutated in this match arm, it is not the same as the + // ToRunner data + protocol::ToClient::ToClientCommands(command_wrappers) } - // Handle tunnel messages - protocol::ToClient::ToClientTunnelMessage(tunnel_msg) => { - match tunnel_msg.message_kind { - protocol::ToClientTunnelMessageKind::ToClientRequestStart(_) => { - // Save active request - // - // This will remove gateway_reply_to from the message since it does not need to be sent to the - // client - if let Some(reply_to) = tunnel_msg.gateway_reply_to.take() { - tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), ?reply_to, "creating active request"); - let mut active_requests = conn.tunnel_active_requests.lock().await; - active_requests.insert( - tunnel_msg.request_id, - TunnelActiveRequest { - gateway_reply_to: reply_to, - is_ws: false, - }, - ); - } - } - // If terminal, remove active request tracking - protocol::ToClientTunnelMessageKind::ToClientRequestAbort => { - tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), "removing active conn due to close message"); - let mut active_requests = conn.tunnel_active_requests.lock().await; - active_requests.remove(&tunnel_msg.request_id); - } - protocol::ToClientTunnelMessageKind::ToClientWebSocketOpen(_) => { - // Save active request - // - // This will remove gateway_reply_to from the message since it does not need to be sent to the - // client - if let Some(reply_to) = tunnel_msg.gateway_reply_to.take() { - tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), ?reply_to, "creating active request"); - let mut active_requests = conn.tunnel_active_requests.lock().await; - active_requests.insert( - tunnel_msg.request_id, - TunnelActiveRequest { - gateway_reply_to: reply_to, - is_ws: true, - }, - ); - } - } - // If terminal, remove active request tracking - protocol::ToClientTunnelMessageKind::ToClientWebSocketClose(_) => { - tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), "removing active conn due to close message"); - let mut active_requests = conn.tunnel_active_requests.lock().await; - active_requests.remove(&tunnel_msg.request_id); - } - _ => {} - } + protocol::ToRunner::ToClientAckEvents(x) => protocol::ToClient::ToClientAckEvents(x), + protocol::ToRunner::ToClientKvResponse(x) => protocol::ToClient::ToClientKvResponse(x), + protocol::ToRunner::ToClientTunnelMessage(x) => { + protocol::ToClient::ToClientTunnelMessage(x) } - _ => {} - } + }; // Forward raw message to WebSocket - let serialized_msg = - match versioned::ToClient::wrap_latest(msg).serialize(conn.protocol_version) { - Result::Ok(x) => x, - Err(err) => { - tracing::error!(?err, "failed to serialize tunnel message"); - continue; - } - }; + let serialized_msg = match versioned::ToClient::wrap_latest(to_client_msg) + .serialize(conn.protocol_version) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel message"); + continue; + } + }; let ws_msg = WsMessage::Binary(serialized_msg.into()); conn.ws_handle .send(ws_msg) diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index 21411cefdd..3c5511823f 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -4,6 +4,7 @@ use gas::prelude::Id; use gas::prelude::*; use hyper_tungstenite::tungstenite::Message as WsMessage; use hyper_tungstenite::tungstenite::Message; +use pegboard::pubsub_subjects::GatewayReceiverSubject; use pegboard_actor_kv as kv; use rivet_guard_core::websocket_handle::WebSocketReceiver; use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; @@ -333,7 +334,7 @@ async fn handle_message( } } protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message(&ctx, &conn, tunnel_msg) + handle_tunnel_message(&ctx, tunnel_msg) .await .context("failed to handle tunnel message")?; } @@ -361,31 +362,15 @@ async fn handle_message( #[tracing::instrument(skip_all)] async fn handle_tunnel_message( ctx: &StandaloneCtx, - conn: &Arc, msg: protocol::ToServerTunnelMessage, ) -> Result<()> { - // Determine reply to subject - let request_id = msg.request_id; - let gateway_reply_to = { - let active_requests = conn.tunnel_active_requests.lock().await; - if let Some(req) = active_requests.get(&request_id) { - req.gateway_reply_to.clone() - } else { - tracing::warn!(request_id=?Uuid::from_bytes(msg.request_id), message_id=?Uuid::from_bytes(msg.message_id), "no active request for tunnel message, may have timed out"); - return Ok(()); - } - }; - - // Remove active request entries when terminal - if is_to_server_tunnel_message_kind_request_close(&msg.message_kind) { - let mut active_requests = conn.tunnel_active_requests.lock().await; - active_requests.remove(&request_id); - } - // Publish message to UPS - let msg_serialized = versioned::ToGateway::wrap_latest(protocol::ToGateway { message: msg }) - .serialize_with_embedded_version(PROTOCOL_VERSION) - .context("failed to serialize tunnel message for gateway")?; + let gateway_reply_to = + GatewayReceiverSubject::new(Uuid::from_bytes(msg.gateway_id)).to_string(); + let msg_serialized = + versioned::ToGateway::wrap_latest(protocol::ToGateway::ToServerTunnelMessage(msg)) + .serialize_with_embedded_version(PROTOCOL_VERSION) + .context("failed to serialize tunnel message for gateway")?; ctx.ups() .context("failed to get UPS instance for tunnel message")? .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) @@ -399,18 +384,3 @@ async fn handle_tunnel_message( Ok(()) } - -/// Determines if a given message kind will terminate the request. -fn is_to_server_tunnel_message_kind_request_close( - kind: &protocol::ToServerTunnelMessageKind, -) -> bool { - match kind { - // HTTP terminal states - protocol::ToServerTunnelMessageKind::ToServerResponseStart(resp) => !resp.stream, - protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish, - protocol::ToServerTunnelMessageKind::ToServerResponseAbort => true, - // WebSocket terminal states (either side closes) - protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true, - _ => false, - } -} diff --git a/engine/sdks/rust/runner-protocol/src/versioned.rs b/engine/sdks/rust/runner-protocol/src/versioned.rs index 6254a336c8..db2a24e300 100644 --- a/engine/sdks/rust/runner-protocol/src/versioned.rs +++ b/engine/sdks/rust/runner-protocol/src/versioned.rs @@ -175,6 +175,7 @@ impl ToClient { } v2::ToClient::ToClientTunnelMessage(msg) => { v3::ToClient::ToClientTunnelMessage(v3::ToClientTunnelMessage { + gateway_id: [0; 16], request_id: msg.request_id, message_id: msg.message_id, message_kind: convert_to_client_tunnel_message_kind_v2_to_v3( @@ -489,6 +490,7 @@ impl ToServer { } v2::ToServer::ToServerTunnelMessage(msg) => { v3::ToServer::ToServerTunnelMessage(v3::ToServerTunnelMessage { + gateway_id: [0; 16], request_id: msg.request_id, message_id: msg.message_id, message_kind: convert_to_server_tunnel_message_kind_v2_to_v3( @@ -633,6 +635,41 @@ impl ToServer { } } +pub enum ToRunner { + // Only in v3 + V3(v3::ToRunner), +} + +impl OwnedVersionedData for ToRunner { + type Latest = v3::ToRunner; + + fn wrap_latest(latest: v3::ToRunner) -> Self { + ToRunner::V3(latest) + } + + fn unwrap_latest(self) -> Result { + #[allow(irrefutable_let_patterns)] + if let ToRunner::V3(data) = self { + Ok(data) + } else { + bail!("version not latest"); + } + } + + fn deserialize_version(payload: &[u8], version: u16) -> Result { + match version { + 1 | 2 | 3 => Ok(ToRunner::V3(serde_bare::from_slice(payload)?)), + _ => bail!("invalid version: {version}"), + } + } + + fn serialize_version(self, _version: u16) -> Result> { + match self { + ToRunner::V3(data) => serde_bare::to_vec(&data).map_err(Into::into), + } + } +} + pub enum ToGateway { // No change between v1 and v3 V3(v3::ToGateway), @@ -1152,7 +1189,6 @@ fn convert_to_client_tunnel_message_kind_v2_to_v3( kind: v2::ToClientTunnelMessageKind, ) -> v3::ToClientTunnelMessageKind { match kind { - v2::ToClientTunnelMessageKind::TunnelAck => v3::ToClientTunnelMessageKind::TunnelAck, v2::ToClientTunnelMessageKind::ToClientRequestStart(req) => { v3::ToClientTunnelMessageKind::ToClientRequestStart(v3::ToClientRequestStart { actor_id: req.actor_id, @@ -1181,8 +1217,7 @@ fn convert_to_client_tunnel_message_kind_v2_to_v3( } v2::ToClientTunnelMessageKind::ToClientWebSocketMessage(msg) => { v3::ToClientTunnelMessageKind::ToClientWebSocketMessage(v3::ToClientWebSocketMessage { - // Default to 0 for v2 messages (hibernation disabled by default) - index: 0, + index: msg.index, data: msg.data, binary: msg.binary, }) @@ -1193,6 +1228,13 @@ fn convert_to_client_tunnel_message_kind_v2_to_v3( reason: close.reason, }) } + // TunnelAck was removed in v3 + v2::ToClientTunnelMessageKind::TunnelAck => { + // TunnelAck is deprecated and should not be used + // For backwards compatibility, we skip it + // This shouldn't happen in practice as TunnelAck was removed + v3::ToClientTunnelMessageKind::ToClientRequestAbort + } } } @@ -1200,7 +1242,6 @@ fn convert_to_client_tunnel_message_kind_v3_to_v2( kind: v3::ToClientTunnelMessageKind, ) -> Result { Ok(match kind { - v3::ToClientTunnelMessageKind::TunnelAck => v2::ToClientTunnelMessageKind::TunnelAck, v3::ToClientTunnelMessageKind::ToClientRequestStart(req) => { v2::ToClientTunnelMessageKind::ToClientRequestStart(v2::ToClientRequestStart { actor_id: req.actor_id, @@ -1247,7 +1288,6 @@ fn convert_to_server_tunnel_message_kind_v2_to_v3( kind: v2::ToServerTunnelMessageKind, ) -> v3::ToServerTunnelMessageKind { match kind { - v2::ToServerTunnelMessageKind::TunnelAck => v3::ToServerTunnelMessageKind::TunnelAck, v2::ToServerTunnelMessageKind::ToServerResponseStart(resp) => { v3::ToServerTunnelMessageKind::ToServerResponseStart(v3::ToServerResponseStart { status: resp.status, @@ -1288,6 +1328,13 @@ fn convert_to_server_tunnel_message_kind_v2_to_v3( hibernate: close.retry, }) } + // TunnelAck was removed in v3 + v2::ToServerTunnelMessageKind::TunnelAck => { + // TunnelAck is deprecated and should not be used + // For backwards compatibility, we skip it + // This shouldn't happen in practice as TunnelAck was removed + v3::ToServerTunnelMessageKind::ToServerResponseAbort + } } } @@ -1295,7 +1342,6 @@ fn convert_to_server_tunnel_message_kind_v3_to_v2( kind: v3::ToServerTunnelMessageKind, ) -> Result { Ok(match kind { - v3::ToServerTunnelMessageKind::TunnelAck => v2::ToServerTunnelMessageKind::TunnelAck, v3::ToServerTunnelMessageKind::ToServerResponseStart(resp) => { v2::ToServerTunnelMessageKind::ToServerResponseStart(v2::ToServerResponseStart { status: resp.status, diff --git a/engine/sdks/schemas/runner-protocol/v3.bare b/engine/sdks/schemas/runner-protocol/v3.bare index 3c6040b698..d2085e59eb 100644 --- a/engine/sdks/schemas/runner-protocol/v3.bare +++ b/engine/sdks/schemas/runner-protocol/v3.bare @@ -193,13 +193,10 @@ type CommandWrapper struct { # MARK: Tunnel +type GatewayId data[16] # UUIDv4 type RequestId data[16] # UUIDv4 type MessageId data[16] # UUIDv4 - -# Ack -type TunnelAck void - # HTTP type ToClientRequestStart struct { actorId: Id @@ -270,8 +267,6 @@ type ToServerWebSocketClose struct { # To Server type ToServerTunnelMessageKind union { - TunnelAck | - # HTTP ToServerResponseStart | ToServerResponseChunk | @@ -285,6 +280,7 @@ type ToServerTunnelMessageKind union { } type ToServerTunnelMessage struct { + gatewayId: GatewayId requestId: RequestId messageId: MessageId messageKind: ToServerTunnelMessageKind @@ -292,8 +288,6 @@ type ToServerTunnelMessage struct { # To Client type ToClientTunnelMessageKind union { - TunnelAck | - # HTTP ToClientRequestStart | ToClientRequestChunk | @@ -306,6 +300,7 @@ type ToClientTunnelMessageKind union { } type ToClientTunnelMessage struct { + gatewayId: GatewayId requestId: RequestId messageId: MessageId messageKind: ToClientTunnelMessageKind @@ -389,9 +384,29 @@ type ToClient union { ToClientTunnelMessage } +# MARK: To Runner +type ToRunnerKeepAlive struct { + requestId: RequestId +} + +# We have to re-declare the entire union since BARE will not generate the +# ser/de for ToClient if it's not a top-level type +type ToRunner union { + ToRunnerKeepAlive | + ToClientInit | + ToClientClose | + ToClientCommands | + ToClientAckEvents | + ToClientKvResponse | + ToClientTunnelMessage +} + # MARK: To Gateway -type ToGateway struct { - message: ToServerTunnelMessage +type ToGatewayKeepAlive void + +type ToGateway union { + ToGatewayKeepAlive | + ToServerTunnelMessage } # MARK: Serverless diff --git a/engine/sdks/typescript/runner-protocol/src/index.ts b/engine/sdks/typescript/runner-protocol/src/index.ts index aa3065f134..0531f2c169 100644 --- a/engine/sdks/typescript/runner-protocol/src/index.ts +++ b/engine/sdks/typescript/runner-protocol/src/index.ts @@ -918,6 +918,20 @@ export function writeCommandWrapper(bc: bare.ByteCursor, x: CommandWrapper): voi writeCommand(bc, x.inner) } +export type GatewayId = ArrayBuffer + +export function readGatewayId(bc: bare.ByteCursor): GatewayId { + return bare.readFixedData(bc, 16) +} + +export function writeGatewayId(bc: bare.ByteCursor, x: GatewayId): void { + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) +} + +/** + * UUIDv4 + */ export type RequestId = ArrayBuffer export function readRequestId(bc: bare.ByteCursor): RequestId { @@ -943,11 +957,6 @@ export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { bare.writeFixedData(bc, x) } -/** - * Ack - */ -export type TunnelAck = null - function read9(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() @@ -1204,7 +1213,6 @@ export function writeToServerWebSocketClose(bc: bare.ByteCursor, x: ToServerWebS * To Server */ export type ToServerTunnelMessageKind = - | { readonly tag: "TunnelAck"; readonly val: TunnelAck } /** * HTTP */ @@ -1224,20 +1232,18 @@ export function readToServerTunnelMessageKind(bc: bare.ByteCursor): ToServerTunn const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "TunnelAck", val: null } - case 1: return { tag: "ToServerResponseStart", val: readToServerResponseStart(bc) } - case 2: + case 1: return { tag: "ToServerResponseChunk", val: readToServerResponseChunk(bc) } - case 3: + case 2: return { tag: "ToServerResponseAbort", val: null } - case 4: + case 3: return { tag: "ToServerWebSocketOpen", val: readToServerWebSocketOpen(bc) } - case 5: + case 4: return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } - case 6: + case 5: return { tag: "ToServerWebSocketMessageAck", val: readToServerWebSocketMessageAck(bc) } - case 7: + case 6: return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } default: { bc.offset = offset @@ -1248,41 +1254,37 @@ export function readToServerTunnelMessageKind(bc: bare.ByteCursor): ToServerTunn export function writeToServerTunnelMessageKind(bc: bare.ByteCursor, x: ToServerTunnelMessageKind): void { switch (x.tag) { - case "TunnelAck": { - bare.writeU8(bc, 0) - break - } case "ToServerResponseStart": { - bare.writeU8(bc, 1) + bare.writeU8(bc, 0) writeToServerResponseStart(bc, x.val) break } case "ToServerResponseChunk": { - bare.writeU8(bc, 2) + bare.writeU8(bc, 1) writeToServerResponseChunk(bc, x.val) break } case "ToServerResponseAbort": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 2) break } case "ToServerWebSocketOpen": { - bare.writeU8(bc, 4) + bare.writeU8(bc, 3) writeToServerWebSocketOpen(bc, x.val) break } case "ToServerWebSocketMessage": { - bare.writeU8(bc, 5) + bare.writeU8(bc, 4) writeToServerWebSocketMessage(bc, x.val) break } case "ToServerWebSocketMessageAck": { - bare.writeU8(bc, 6) + bare.writeU8(bc, 5) writeToServerWebSocketMessageAck(bc, x.val) break } case "ToServerWebSocketClose": { - bare.writeU8(bc, 7) + bare.writeU8(bc, 6) writeToServerWebSocketClose(bc, x.val) break } @@ -1290,6 +1292,7 @@ export function writeToServerTunnelMessageKind(bc: bare.ByteCursor, x: ToServerT } export type ToServerTunnelMessage = { + readonly gatewayId: GatewayId readonly requestId: RequestId readonly messageId: MessageId readonly messageKind: ToServerTunnelMessageKind @@ -1297,6 +1300,7 @@ export type ToServerTunnelMessage = { export function readToServerTunnelMessage(bc: bare.ByteCursor): ToServerTunnelMessage { return { + gatewayId: readGatewayId(bc), requestId: readRequestId(bc), messageId: readMessageId(bc), messageKind: readToServerTunnelMessageKind(bc), @@ -1304,6 +1308,7 @@ export function readToServerTunnelMessage(bc: bare.ByteCursor): ToServerTunnelMe } export function writeToServerTunnelMessage(bc: bare.ByteCursor, x: ToServerTunnelMessage): void { + writeGatewayId(bc, x.gatewayId) writeRequestId(bc, x.requestId) writeMessageId(bc, x.messageId) writeToServerTunnelMessageKind(bc, x.messageKind) @@ -1313,7 +1318,6 @@ export function writeToServerTunnelMessage(bc: bare.ByteCursor, x: ToServerTunne * To Client */ export type ToClientTunnelMessageKind = - | { readonly tag: "TunnelAck"; readonly val: TunnelAck } /** * HTTP */ @@ -1332,18 +1336,16 @@ export function readToClientTunnelMessageKind(bc: bare.ByteCursor): ToClientTunn const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "TunnelAck", val: null } - case 1: return { tag: "ToClientRequestStart", val: readToClientRequestStart(bc) } - case 2: + case 1: return { tag: "ToClientRequestChunk", val: readToClientRequestChunk(bc) } - case 3: + case 2: return { tag: "ToClientRequestAbort", val: null } - case 4: + case 3: return { tag: "ToClientWebSocketOpen", val: readToClientWebSocketOpen(bc) } - case 5: + case 4: return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } - case 6: + case 5: return { tag: "ToClientWebSocketClose", val: readToClientWebSocketClose(bc) } default: { bc.offset = offset @@ -1354,36 +1356,32 @@ export function readToClientTunnelMessageKind(bc: bare.ByteCursor): ToClientTunn export function writeToClientTunnelMessageKind(bc: bare.ByteCursor, x: ToClientTunnelMessageKind): void { switch (x.tag) { - case "TunnelAck": { - bare.writeU8(bc, 0) - break - } case "ToClientRequestStart": { - bare.writeU8(bc, 1) + bare.writeU8(bc, 0) writeToClientRequestStart(bc, x.val) break } case "ToClientRequestChunk": { - bare.writeU8(bc, 2) + bare.writeU8(bc, 1) writeToClientRequestChunk(bc, x.val) break } case "ToClientRequestAbort": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 2) break } case "ToClientWebSocketOpen": { - bare.writeU8(bc, 4) + bare.writeU8(bc, 3) writeToClientWebSocketOpen(bc, x.val) break } case "ToClientWebSocketMessage": { - bare.writeU8(bc, 5) + bare.writeU8(bc, 4) writeToClientWebSocketMessage(bc, x.val) break } case "ToClientWebSocketClose": { - bare.writeU8(bc, 6) + bare.writeU8(bc, 5) writeToClientWebSocketClose(bc, x.val) break } @@ -1391,6 +1389,7 @@ export function writeToClientTunnelMessageKind(bc: bare.ByteCursor, x: ToClientT } export type ToClientTunnelMessage = { + readonly gatewayId: GatewayId readonly requestId: RequestId readonly messageId: MessageId readonly messageKind: ToClientTunnelMessageKind @@ -1402,6 +1401,7 @@ export type ToClientTunnelMessage = { export function readToClientTunnelMessage(bc: bare.ByteCursor): ToClientTunnelMessage { return { + gatewayId: readGatewayId(bc), requestId: readRequestId(bc), messageId: readMessageId(bc), messageKind: readToClientTunnelMessageKind(bc), @@ -1410,6 +1410,7 @@ export function readToClientTunnelMessage(bc: bare.ByteCursor): ToClientTunnelMe } export function writeToClientTunnelMessage(bc: bare.ByteCursor, x: ToClientTunnelMessage): void { + writeGatewayId(bc, x.gatewayId) writeRequestId(bc, x.requestId) writeMessageId(bc, x.messageId) writeToClientTunnelMessageKind(bc, x.messageKind) @@ -1832,20 +1833,154 @@ export function decodeToClient(bytes: Uint8Array): ToClient { } /** - * MARK: To Gateway + * MARK: To Runner */ -export type ToGateway = { - readonly message: ToServerTunnelMessage +export type ToRunnerKeepAlive = { + readonly requestId: RequestId } -export function readToGateway(bc: bare.ByteCursor): ToGateway { +export function readToRunnerKeepAlive(bc: bare.ByteCursor): ToRunnerKeepAlive { return { - message: readToServerTunnelMessage(bc), + requestId: readRequestId(bc), + } +} + +export function writeToRunnerKeepAlive(bc: bare.ByteCursor, x: ToRunnerKeepAlive): void { + writeRequestId(bc, x.requestId) +} + +/** + * We have to re-declare the entire union since BARE will not generate the + * ser/de for ToClient if it's not a top-level type + */ +export type ToRunner = + | { readonly tag: "ToRunnerKeepAlive"; readonly val: ToRunnerKeepAlive } + | { readonly tag: "ToClientInit"; readonly val: ToClientInit } + | { readonly tag: "ToClientClose"; readonly val: ToClientClose } + | { readonly tag: "ToClientCommands"; readonly val: ToClientCommands } + | { readonly tag: "ToClientAckEvents"; readonly val: ToClientAckEvents } + | { readonly tag: "ToClientKvResponse"; readonly val: ToClientKvResponse } + | { readonly tag: "ToClientTunnelMessage"; readonly val: ToClientTunnelMessage } + +export function readToRunner(bc: bare.ByteCursor): ToRunner { + const offset = bc.offset + const tag = bare.readU8(bc) + switch (tag) { + case 0: + return { tag: "ToRunnerKeepAlive", val: readToRunnerKeepAlive(bc) } + case 1: + return { tag: "ToClientInit", val: readToClientInit(bc) } + case 2: + return { tag: "ToClientClose", val: null } + case 3: + return { tag: "ToClientCommands", val: readToClientCommands(bc) } + case 4: + return { tag: "ToClientAckEvents", val: readToClientAckEvents(bc) } + case 5: + return { tag: "ToClientKvResponse", val: readToClientKvResponse(bc) } + case 6: + return { tag: "ToClientTunnelMessage", val: readToClientTunnelMessage(bc) } + default: { + bc.offset = offset + throw new bare.BareError(offset, "invalid tag") + } + } +} + +export function writeToRunner(bc: bare.ByteCursor, x: ToRunner): void { + switch (x.tag) { + case "ToRunnerKeepAlive": { + bare.writeU8(bc, 0) + writeToRunnerKeepAlive(bc, x.val) + break + } + case "ToClientInit": { + bare.writeU8(bc, 1) + writeToClientInit(bc, x.val) + break + } + case "ToClientClose": { + bare.writeU8(bc, 2) + break + } + case "ToClientCommands": { + bare.writeU8(bc, 3) + writeToClientCommands(bc, x.val) + break + } + case "ToClientAckEvents": { + bare.writeU8(bc, 4) + writeToClientAckEvents(bc, x.val) + break + } + case "ToClientKvResponse": { + bare.writeU8(bc, 5) + writeToClientKvResponse(bc, x.val) + break + } + case "ToClientTunnelMessage": { + bare.writeU8(bc, 6) + writeToClientTunnelMessage(bc, x.val) + break + } + } +} + +export function encodeToRunner(x: ToRunner, config?: Partial): Uint8Array { + const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG + const bc = new bare.ByteCursor( + new Uint8Array(fullConfig.initialBufferLength), + fullConfig, + ) + writeToRunner(bc, x) + return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset) +} + +export function decodeToRunner(bytes: Uint8Array): ToRunner { + const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG) + const result = readToRunner(bc) + if (bc.offset < bc.view.byteLength) { + throw new bare.BareError(bc.offset, "remaining bytes") + } + return result +} + +/** + * MARK: To Gateway + */ +export type ToGatewayKeepAlive = null + +export type ToGateway = + | { readonly tag: "ToGatewayKeepAlive"; readonly val: ToGatewayKeepAlive } + | { readonly tag: "ToServerTunnelMessage"; readonly val: ToServerTunnelMessage } + +export function readToGateway(bc: bare.ByteCursor): ToGateway { + const offset = bc.offset + const tag = bare.readU8(bc) + switch (tag) { + case 0: + return { tag: "ToGatewayKeepAlive", val: null } + case 1: + return { tag: "ToServerTunnelMessage", val: readToServerTunnelMessage(bc) } + default: { + bc.offset = offset + throw new bare.BareError(offset, "invalid tag") + } } } export function writeToGateway(bc: bare.ByteCursor, x: ToGateway): void { - writeToServerTunnelMessage(bc, x.message) + switch (x.tag) { + case "ToGatewayKeepAlive": { + bare.writeU8(bc, 0) + break + } + case "ToServerTunnelMessage": { + bare.writeU8(bc, 1) + writeToServerTunnelMessage(bc, x.val) + break + } + } } export function encodeToGateway(x: ToGateway, config?: Partial): Uint8Array { diff --git a/engine/sdks/typescript/runner/src/tunnel.ts b/engine/sdks/typescript/runner/src/tunnel.ts index f261ae5ffa..f71c9cb89e 100644 --- a/engine/sdks/typescript/runner/src/tunnel.ts +++ b/engine/sdks/typescript/runner/src/tunnel.ts @@ -438,29 +438,6 @@ export class Tunnel { this.#runner.__sendToServer(message); } - #sendAck(requestId: RequestId, messageId: MessageId) { - if (!this.#runner.__webSocketReady()) { - return; - } - - const message: protocol.ToServer = { - tag: "ToServerTunnelMessage", - val: { - requestId, - messageId, - messageKind: { tag: "TunnelAck", val: null }, - }, - }; - - this.log?.debug({ - msg: "ack tunnel msg", - requestId: idToStr(requestId), - messageId: idToStr(messageId), - }); - - this.#runner.__sendToServer(message); - } - #startGarbageCollector() { if (this.#gcInterval) { clearInterval(this.#gcInterval); @@ -609,71 +586,43 @@ export class Tunnel { message: stringifyToClientTunnelMessageKind(message.messageKind), }); - if (message.messageKind.tag === "TunnelAck") { - // Mark pending message as acknowledged and remove it - const actor = this.getRequestActor(requestIdStr); - if (actor) { - const didDelete = - actor.pendingTunnelMessages.delete(messageIdStr); - if (!didDelete) { - this.log?.warn({ - msg: "received tunnel ack for nonexistent message", - requestId: requestIdStr, - messageId: messageIdStr, - }); - } - } - } else { - switch (message.messageKind.tag) { - case "ToClientRequestStart": - this.#sendAck(message.requestId, message.messageId); - - await this.#handleRequestStart( - message.requestId, - message.messageKind.val, - ); - break; - case "ToClientRequestChunk": - this.#sendAck(message.requestId, message.messageId); - - await this.#handleRequestChunk( - message.requestId, - message.messageKind.val, - ); - break; - case "ToClientRequestAbort": - this.#sendAck(message.requestId, message.messageId); - - await this.#handleRequestAbort(message.requestId); - break; - case "ToClientWebSocketOpen": - this.#sendAck(message.requestId, message.messageId); - - await this.#handleWebSocketOpen( - message.requestId, - message.messageKind.val, - ); - break; - case "ToClientWebSocketMessage": { - this.#sendAck(message.requestId, message.messageId); - - this.#handleWebSocketMessage( - message.requestId, - message.messageKind.val, - ); - break; - } - case "ToClientWebSocketClose": - this.#sendAck(message.requestId, message.messageId); - - await this.#handleWebSocketClose( - message.requestId, - message.messageKind.val, - ); - break; - default: - unreachable(message.messageKind); + switch (message.messageKind.tag) { + case "ToClientRequestStart": + await this.#handleRequestStart( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientRequestChunk": + await this.#handleRequestChunk( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientRequestAbort": + await this.#handleRequestAbort(message.requestId); + break; + case "ToClientWebSocketOpen": + await this.#handleWebSocketOpen( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientWebSocketMessage": { + this.#handleWebSocketMessage( + message.requestId, + message.messageKind.val, + ); + break; } + case "ToClientWebSocketClose": + await this.#handleWebSocketClose( + message.requestId, + message.messageKind.val, + ); + break; + default: + unreachable(message.messageKind); } }