From dafcccf79f138e4e94c44e3483d74135c3d38764 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Fri, 24 Oct 2025 18:59:50 -0700 Subject: [PATCH] fix: improve reconnection with sleeping and tunnel --- Cargo.lock | 24 ++ Cargo.toml | 1 + ...guard.websocket_pending_limit_reached.json | 5 + engine/packages/epoxy/src/ops/kv/get_local.rs | 15 +- .../epoxy/src/ops/kv/get_optimistic.rs | 49 ++-- .../packages/gasoline/src/ctx/standalone.rs | 6 +- .../packages/guard-core/src/custom_serve.rs | 3 + .../packages/guard-core/src/proxy_service.rs | 106 +++++--- .../guard-core/src/websocket_handle.rs | 105 ++------ .../packages/guard/src/routing/api_public.rs | 1 + engine/packages/pegboard-gateway/Cargo.toml | 3 + engine/packages/pegboard-gateway/src/lib.rs | 34 ++- .../pegboard-gateway/src/shared_state.rs | 228 ++++++++++++------ .../src/client_to_pubsub_task.rs | 10 +- engine/packages/pegboard-runner/src/conn.rs | 6 +- engine/packages/pegboard-runner/src/lib.rs | 11 +- .../src/pubsub_to_client_task.rs | 4 +- .../packages/pegboard-serverless/src/lib.rs | 8 +- .../pegboard/src/workflows/actor/runtime.rs | 30 +-- .../src/driver/postgres/mod.rs | 6 +- engine/sdks/typescript/runner/src/mod.ts | 32 +-- engine/sdks/typescript/runner/src/tunnel.ts | 57 ++++- .../runner/src/websocket-tunnel-adapter.ts | 31 ++- scripts/tests/actor_sleep.ts | 24 +- 24 files changed, 466 insertions(+), 333 deletions(-) create mode 100644 engine/artifacts/errors/guard.websocket_pending_limit_reached.json diff --git a/Cargo.lock b/Cargo.lock index c3fdd146ac..750c97cc2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3275,7 +3275,9 @@ dependencies = [ "rivet-guard-core", "rivet-runner-protocol", "rivet-util", + "scc", "serde", + "serde_json", "thiserror 1.0.69", "tokio", "tokio-tungstenite", @@ -4929,6 +4931,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "saa" +version = "5.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f895faf11c46e98547f4de603a113ca76708d4b6832dbbe3c26528b7b81aca3b" + [[package]] name = "safe_arch" version = "0.7.4" @@ -4938,6 +4946,16 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "scc" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0b9e1890c5b17833a779c68a974f04170dfa36e3789395d17845418cc779ac" +dependencies = [ + "saa", + "sdd", +] + [[package]] name = "schannel" version = "0.1.27" @@ -5019,6 +5037,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sdd" +version = "4.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a8729f5224c38cb041e72fa9968dd4e379d3487b85359539d31d75ed95992d8" + [[package]] name = "sealed" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 4b9d52c8ce..62fff1c624 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ regex = "1.4" rstest = "0.26.1" rustls-pemfile = "2.2.0" rustyline = "15.0.0" +scc = "3.3.2" serde_bare = "0.5.0" serde_html_form = "0.2.7" serde_yaml = "0.9.34" diff --git a/engine/artifacts/errors/guard.websocket_pending_limit_reached.json b/engine/artifacts/errors/guard.websocket_pending_limit_reached.json new file mode 100644 index 0000000000..53e830a61b --- /dev/null +++ b/engine/artifacts/errors/guard.websocket_pending_limit_reached.json @@ -0,0 +1,5 @@ +{ + "code": "websocket_pending_limit_reached", + "group": "guard", + "message": "Reached limit on pending websocket messages, aborting connection." +} diff --git a/engine/packages/epoxy/src/ops/kv/get_local.rs b/engine/packages/epoxy/src/ops/kv/get_local.rs index 64b3f95116..52987054c3 100644 --- a/engine/packages/epoxy/src/ops/kv/get_local.rs +++ b/engine/packages/epoxy/src/ops/kv/get_local.rs @@ -30,15 +30,12 @@ pub async fn epoxy_kv_get_local(ctx: &OperationCtx, input: &Input) -> Result Resul let kv_key = kv_key.clone(); let cache_key = cache_key.clone(); async move { - (async move { - let (value, cache_value) = tokio::try_join!( - async { - let v = tx.get(&packed_key, Serializable).await?; - if let Some(ref bytes) = v { - Ok(Some(kv_key.deserialize(bytes)?)) - } else { - Ok(None) - } - }, - async { - let v = tx.get(&packed_cache_key, Serializable).await?; - if let Some(ref bytes) = v { - Ok(Some(cache_key.deserialize(bytes)?)) - } else { - Ok(None) - } + let (value, cache_value) = tokio::try_join!( + async { + let v = tx.get(&packed_key, Serializable).await?; + if let Some(ref bytes) = v { + Ok(Some(kv_key.deserialize(bytes)?)) + } else { + Ok(None) } - )?; + }, + async { + let v = tx.get(&packed_cache_key, Serializable).await?; + if let Some(ref bytes) = v { + Ok(Some(cache_key.deserialize(bytes)?)) + } else { + Ok(None) + } + } + )?; - Ok(value.or(cache_value)) - }) - .await + Ok(value.or(cache_value)) } }) .custom_instrument(tracing::info_span!("get_optimistic_tx")) @@ -134,13 +131,11 @@ pub async fn epoxy_kv_get_optimistic(ctx: &OperationCtx, input: &Input) -> Resul let packed_cache_key = packed_cache_key.clone(); let cache_key = cache_key.clone(); let value_to_cache = value.clone(); + async move { - (async move { - let serialized = cache_key.serialize(value_to_cache)?; - tx.set(&packed_cache_key, &serialized); - Ok(()) - }) - .await + let serialized = cache_key.serialize(value_to_cache)?; + tx.set(&packed_cache_key, &serialized); + Ok(()) } }) .custom_instrument(tracing::info_span!("cache_value_tx")) diff --git a/engine/packages/gasoline/src/ctx/standalone.rs b/engine/packages/gasoline/src/ctx/standalone.rs index 25e08c796e..222e764618 100644 --- a/engine/packages/gasoline/src/ctx/standalone.rs +++ b/engine/packages/gasoline/src/ctx/standalone.rs @@ -46,9 +46,9 @@ impl StandaloneCtx { ) -> WorkflowResult { let ts = rivet_util::timestamp::now(); - let span = tracing::Span::current(); - span.record("req_id", req_id.to_string()); - span.record("ray_id", ray_id.to_string()); + tracing::Span::current() + .record("req_id", req_id.to_string()) + .record("ray_id", ray_id.to_string()); let msg_ctx = MessageCtx::new(&config, &pools, &cache, ray_id)?; diff --git a/engine/packages/guard-core/src/custom_serve.rs b/engine/packages/guard-core/src/custom_serve.rs index 3d54fdaeaa..f6343a98a5 100644 --- a/engine/packages/guard-core/src/custom_serve.rs +++ b/engine/packages/guard-core/src/custom_serve.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use bytes::Bytes; use http_body_util::Full; use hyper::{Request, Response}; +use uuid::Uuid; use crate::WebSocketHandle; use crate::proxy_service::ResponseBody; @@ -25,5 +26,7 @@ pub trait CustomServeTrait: Send + Sync { headers: &hyper::HeaderMap, path: &str, request_context: &mut RequestContext, + // Identifies the websocket across retries. + unique_request_id: Uuid, ) -> Result<()>; } diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index f4498099f3..296a804bd8 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -28,6 +28,7 @@ use tokio_tungstenite::tungstenite::{ }; use tracing::Instrument; use url::Url; +use uuid::Uuid; use crate::{ WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics, @@ -1171,7 +1172,7 @@ impl ProxyService { } // Handle WebSocket upgrade properly with hyper_tungstenite - tracing::debug!("Upgrading client connection to WebSocket"); + tracing::debug!(%req_path, "Upgrading client connection to WebSocket"); let (client_response, client_ws) = match hyper_tungstenite::upgrade(req, None) { Ok(x) => { tracing::debug!("Client WebSocket upgrade successful"); @@ -1782,18 +1783,20 @@ impl ProxyService { } ResolveRouteOutput::Response(_) => unreachable!(), ResolveRouteOutput::CustomServe(mut handlers) => { - tracing::debug!("Spawning task to handle WebSocket communication"); + tracing::debug!(%req_path, "Spawning task to handle WebSocket communication"); let mut request_context = request_context.clone(); let req_headers = req_headers.clone(); let req_path = req_path.clone(); let req_host = req_host.clone(); - // TODO: Handle errors here, the error message is lost tokio::spawn( async move { + let request_id = Uuid::new_v4(); let mut attempts = 0u32; - let ws_handle = WebSocketHandle::new(client_ws); + let ws_handle = WebSocketHandle::new(client_ws) + .await + .context("failed initiating websocket handle")?; loop { match handlers @@ -1802,6 +1805,7 @@ impl ProxyService { &req_headers, &req_path, &mut request_context, + request_id, ) .await { @@ -1825,13 +1829,17 @@ impl ProxyService { break; } Err(err) => { + tracing::debug!(?err, "websocket handler error"); + attempts += 1; if attempts > max_attempts || !is_retryable_ws_error(&err) { + tracing::debug!(?attempts, "WebSocket failed to reconnect"); + // Close WebSocket with error ws_handle - .accept_and_send(to_hyper_close(Some( - err_to_close_frame(err, ray_id), - ))) + .send(to_hyper_close(Some(err_to_close_frame( + err, ray_id, + )))) .await?; // Flush to ensure close frame is sent @@ -1846,6 +1854,13 @@ impl ProxyService { attempts, initial_interval, ); + let backoff = Duration::from_millis(100); + + tracing::debug!( + ?backoff, + "WebSocket attempt {attempts} failed (service unavailable)" + ); + tokio::time::sleep(backoff).await; match state @@ -1864,11 +1879,9 @@ impl ProxyService { } Ok(ResolveRouteOutput::Response(response)) => { ws_handle - .accept_and_send(to_hyper_close(Some( - str_to_close_frame( - response.message.as_ref(), - ), - ))) + .send(to_hyper_close(Some(str_to_close_frame( + response.message.as_ref(), + )))) .await?; // Flush to ensure close frame is sent @@ -1879,12 +1892,10 @@ impl ProxyService { } Ok(ResolveRouteOutput::Target(_)) => { ws_handle - .accept_and_send(to_hyper_close(Some( - err_to_close_frame( - errors::WebSocketTargetChanged.build(), - ray_id, - ), - ))) + .send(to_hyper_close(Some(err_to_close_frame( + errors::WebSocketTargetChanged.build(), + ray_id, + )))) .await?; // Flush to ensure close frame is sent @@ -1897,9 +1908,9 @@ impl ProxyService { } Err(err) => { ws_handle - .accept_and_send(to_hyper_close(Some( - err_to_close_frame(err, ray_id), - ))) + .send(to_hyper_close(Some(err_to_close_frame( + err, ray_id, + )))) .await?; // Flush to ensure close frame is sent @@ -1947,13 +1958,17 @@ impl ProxyService { impl ProxyService { // Process an individual request - #[tracing::instrument(name = "guard_request", skip_all)] + #[tracing::instrument(name = "guard_request", skip_all, fields(ray_id, req_id))] pub async fn process(&self, mut req: Request) -> Result> { let start_time = Instant::now(); let request_ids = RequestIds::new(self.state.config.dc_label()); req.extensions_mut().insert(request_ids); + tracing::Span::current() + .record("req_id", request_ids.req_id.to_string()) + .record("ray_id", request_ids.ray_id.to_string()); + // Create request context for analytics tracking let mut request_context = RequestContext::new(self.state.clickhouse_inserter.clone(), request_ids); @@ -2063,35 +2078,50 @@ impl ProxyService { // If we receive an error during a websocket request, we attempt to open the websocket anyway // so we can send the error via websocket instead of http. Most websocket clients don't handle - // HTTP errors in a meaningful way for the user resulting in unhelpful errors + // HTTP errors in a meaningful way resulting in unhelpful errors for the user if is_websocket { tracing::debug!("Upgrading client connection to WebSocket for error proxy"); match hyper_tungstenite::upgrade(mock_req, None) { Ok((client_response, client_ws)) => { tracing::debug!("Client WebSocket upgrade for error proxy successful"); - tokio::spawn(async move { - let ws_handle = WebSocketHandle::new(client_ws); - let frame = err_to_close_frame(err, Some(request_ids.ray_id)); + tokio::spawn( + async move { + let ws_handle = match WebSocketHandle::new(client_ws).await { + Ok(ws_handle) => ws_handle, + Err(err) => { + tracing::debug!( + ?err, + "failed initiating websocket handle for error proxy" + ); + return; + } + }; + let frame = err_to_close_frame(err, Some(request_ids.ray_id)); - // Manual conversion to handle different tungstenite versions - let code_num: u16 = frame.code.into(); - let reason = frame.reason.clone(); + // Manual conversion to handle different tungstenite versions + let code_num: u16 = frame.code.into(); + let reason = frame.reason.clone(); - if let Err(err) = ws_handle - .accept_and_send( - tokio_tungstenite::tungstenite::Message::Close(Some( + if let Err(err) = ws_handle + .send(tokio_tungstenite::tungstenite::Message::Close(Some( tokio_tungstenite::tungstenite::protocol::CloseFrame { code: code_num.into(), reason, }, - )), - ) - .await - { - tracing::debug!(?err, "failed sending error proxy"); + ))) + .await + { + tracing::debug!( + ?err, + "failed sending websocket error proxy" + ); + } } - }); + .instrument( + tracing::info_span!("ws_error_proxy_task", ?request_ids.ray_id), + ), + ); // Return the response that will upgrade the client connection // For proper WebSocket handshaking, we need to preserve the original response diff --git a/engine/packages/guard-core/src/websocket_handle.rs b/engine/packages/guard-core/src/websocket_handle.rs index bb17d2df3b..763f337b20 100644 --- a/engine/packages/guard-core/src/websocket_handle.rs +++ b/engine/packages/guard-core/src/websocket_handle.rs @@ -4,7 +4,6 @@ use hyper::upgrade::Upgraded; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::tungstenite::Message as WsMessage; use hyper_util::rt::TokioIo; -use std::ops::Deref; use std::sync::Arc; use tokio::sync::Mutex; use tokio_tungstenite::WebSocketStream; @@ -14,104 +13,34 @@ pub type WebSocketReceiver = futures_util::stream::SplitStream>, WsMessage>; -enum WebSocketState { - Unaccepted { websocket: HyperWebsocket }, - Accepting, - Split { ws_tx: WebSocketSender }, -} - #[derive(Clone)] -pub struct WebSocketHandle(Arc); - -impl WebSocketHandle { - pub fn new(websocket: HyperWebsocket) -> Self { - Self(Arc::new(WebSocketHandleInner { - state: Mutex::new(WebSocketState::Unaccepted { websocket }), - })) - } +pub struct WebSocketHandle { + ws_tx: Arc>, + ws_rx: Arc>, } -impl Deref for WebSocketHandle { - type Target = WebSocketHandleInner; - - fn deref(&self) -> &Self::Target { - &*self.0 - } -} - -pub struct WebSocketHandleInner { - state: Mutex, -} +impl WebSocketHandle { + pub async fn new(websocket: HyperWebsocket) -> Result { + let ws_stream = websocket.await?; + let (ws_tx, ws_rx) = ws_stream.split(); -impl WebSocketHandleInner { - pub async fn accept(&self) -> Result { - let mut state = self.state.lock().await; - Self::accept_inner(&mut *state).await + Ok(Self { + ws_tx: Arc::new(Mutex::new(ws_tx)), + ws_rx: Arc::new(Mutex::new(ws_rx)), + }) } pub async fn send(&self, message: WsMessage) -> Result<()> { - let mut state = self.state.lock().await; - match &mut *state { - WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => { - bail!("websocket has not been accepted"); - } - WebSocketState::Split { ws_tx } => { - ws_tx.send(message).await?; - Ok(()) - } - } - } - - pub async fn accept_and_send(&self, message: WsMessage) -> Result<()> { - let mut state = self.state.lock().await; - match &mut *state { - WebSocketState::Unaccepted { .. } => { - let _ = Self::accept_inner(&mut *state).await?; - let WebSocketState::Split { ws_tx } = &mut *state else { - bail!("websocket should be accepted"); - }; - ws_tx.send(message).await?; - Ok(()) - } - WebSocketState::Accepting => { - bail!("in accepting state") - } - WebSocketState::Split { ws_tx } => { - ws_tx.send(message).await?; - Ok(()) - } - } + self.ws_tx.lock().await.send(message).await?; + Ok(()) } pub async fn flush(&self) -> Result<()> { - let mut state = self.state.lock().await; - match &mut *state { - WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => { - bail!("websocket has not been accepted"); - } - WebSocketState::Split { ws_tx } => { - ws_tx.flush().await?; - Ok(()) - } - } + self.ws_tx.lock().await.flush().await?; + Ok(()) } - async fn accept_inner(state: &mut WebSocketState) -> Result { - if !matches!(*state, WebSocketState::Unaccepted { .. }) { - bail!("websocket already accepted") - } - - // Accept websocket - let old_state = std::mem::replace(&mut *state, WebSocketState::Accepting); - let WebSocketState::Unaccepted { websocket } = old_state else { - bail!("should be in unaccepted state"); - }; - - // Accept WS - let ws_stream = websocket.await?; - let (ws_tx, ws_rx) = ws_stream.split(); - *state = WebSocketState::Split { ws_tx }; - - Ok(ws_rx) + pub fn recv(&self) -> Arc> { + self.ws_rx.clone() } } diff --git a/engine/packages/guard/src/routing/api_public.rs b/engine/packages/guard/src/routing/api_public.rs index 43415122da..18a79e2162 100644 --- a/engine/packages/guard/src/routing/api_public.rs +++ b/engine/packages/guard/src/routing/api_public.rs @@ -50,6 +50,7 @@ impl CustomServeTrait for ApiPublicService { _headers: &hyper::HeaderMap, _path: &str, _request_context: &mut RequestContext, + _unique_request_id: Uuid, ) -> Result<()> { bail!("api-public does not support WebSocket connections") } diff --git a/engine/packages/pegboard-gateway/Cargo.toml b/engine/packages/pegboard-gateway/Cargo.toml index ec5d7df480..693bf8de57 100644 --- a/engine/packages/pegboard-gateway/Cargo.toml +++ b/engine/packages/pegboard-gateway/Cargo.toml @@ -12,6 +12,7 @@ bytes.workspace = true futures-util.workspace = true gas.workspace = true http-body-util.workspace = true +# TODO: Doesn't match workspace version hyper = "1.6" hyper-tungstenite.workspace = true pegboard.workspace = true @@ -20,7 +21,9 @@ rivet-error.workspace = true rivet-guard-core.workspace = true rivet-runner-protocol.workspace = true rivet-util.workspace = true +scc.workspace = true serde.workspace = true +serde_json.workspace = true thiserror.workspace = true tokio-tungstenite.workspace = true tokio.workspace = true diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 230afa357c..9a753f71ce 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -5,6 +5,7 @@ use futures_util::TryStreamExt; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response, StatusCode, header::HeaderName}; +use rivet_error::*; use rivet_guard_core::{ WebSocketHandle, custom_serve::CustomServeTrait, @@ -25,6 +26,16 @@ const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(2); const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); const WS_PROTOCOL_ACTOR: &str = "rivet_actor."; +#[derive(RivetError, Serialize, Deserialize)] +#[error( + "guard", + "websocket_pending_limit_reached", + "Reached limit on pending websocket messages, aborting connection." +)] +pub struct WebsocketPendingLimitReached { + limit: usize, +} + pub struct PegboardGateway { shared_state: SharedState, runner_id: Id, @@ -78,9 +89,10 @@ impl CustomServeTrait for PegboardGateway { pegboard::pubsub_subjects::RunnerReceiverSubject::new(self.runner_id).to_string(); // Start listening for request responses - let (request_id, mut msg_rx) = self + let request_id = Uuid::new_v4().into_bytes(); + let mut msg_rx = self .shared_state - .start_in_flight_request(tunnel_subject) + .start_in_flight_request(tunnel_subject, request_id) .await; // Start request @@ -157,6 +169,7 @@ impl CustomServeTrait for PegboardGateway { headers: &hyper::HeaderMap, _path: &str, _request_context: &mut RequestContext, + unique_request_id: Uuid, ) -> Result<()> { // Use the actor ID from the gateway instance let actor_id = self.actor_id.to_string(); @@ -174,9 +187,10 @@ impl CustomServeTrait for PegboardGateway { pegboard::pubsub_subjects::RunnerReceiverSubject::new(self.runner_id).to_string(); // Start listening for WebSocket messages - let (request_id, mut msg_rx) = self + let request_id = unique_request_id.into_bytes(); + let mut msg_rx = self .shared_state - .start_in_flight_request(tunnel_subject.clone()) + .start_in_flight_request(tunnel_subject.clone(), request_id) .await; // Send WebSocket open message @@ -232,8 +246,12 @@ impl CustomServeTrait for PegboardGateway { WebSocketServiceUnavailable.build() })??; - // Accept the WebSocket - let mut ws_rx = client_ws.accept().await?; + // Send reclaimed messages + self.shared_state + .send_reclaimed_messages(request_id) + .await?; + + let ws_rx = client_ws.recv(); // Spawn task to forward messages from server to client let mut server_to_client = tokio::spawn( @@ -266,7 +284,7 @@ impl CustomServeTrait for PegboardGateway { } } - tracing::debug!("sub closed"); + tracing::debug!("tunnel sub closed"); Err(WebSocketServiceUnavailable.build()) } @@ -277,6 +295,8 @@ impl CustomServeTrait for PegboardGateway { let shared_state_clone = self.shared_state.clone(); let mut client_to_server = tokio::spawn( async move { + let mut ws_rx = ws_rx.lock().await; + while let Some(msg) = ws_rx.try_next().await? { match msg { Message::Binary(data) => { diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index 7d93e4e93d..781402d1e0 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -1,18 +1,21 @@ use anyhow::Result; use gas::prelude::*; -use rivet_runner_protocol::{self as protocol, MessageId, PROTOCOL_VERSION, RequestId, versioned}; +use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, RequestId, versioned}; +use scc::HashMap; use std::{ - collections::HashMap, ops::Deref, sync::Arc, time::{Duration, Instant}, }; -use tokio::sync::{Mutex, mpsc}; +use tokio::sync::mpsc; use universalpubsub::{NextOutput, PubSub, PublishOpts, Subscriber}; use vbare::OwnedVersionedData; +use crate::WebsocketPendingLimitReached; + const GC_INTERVAL: Duration = Duration::from_secs(60); -const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(5); +const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30); +const MAX_PENDING_MSGS_PER_REQ: usize = 1024; struct InFlightRequest { /// UPS subject to send messages to for this request. @@ -23,9 +26,15 @@ struct InFlightRequest { opened: bool, } -struct PendingMessage { - request_id: RequestId, +pub struct PendingMessage { send_instant: Instant, + payload: protocol::ToClientTunnelMessage, +} + +impl Into for PendingMessage { + fn into(self) -> protocol::ToClientTunnelMessage { + self.payload + } } pub enum TunnelMessageData { @@ -36,8 +45,8 @@ pub enum TunnelMessageData { pub struct SharedStateInner { ups: PubSub, receiver_subject: String, - requests_in_flight: Mutex>, - pending_messages: Mutex>, + requests_in_flight: HashMap, + pending_messages: HashMap>, } #[derive(Clone)] @@ -52,8 +61,8 @@ impl SharedState { Self(Arc::new(SharedStateInner { ups, receiver_subject, - requests_in_flight: Mutex::new(HashMap::new()), - pending_messages: Mutex::new(HashMap::new()), + requests_in_flight: HashMap::new(), + pending_messages: HashMap::new(), })) } @@ -69,6 +78,27 @@ impl SharedState { Ok(()) } + pub async fn start_in_flight_request( + &self, + receiver_subject: String, + request_id: RequestId, + ) -> mpsc::Receiver { + let (msg_tx, msg_rx) = mpsc::channel(128); + + self.requests_in_flight + .upsert_async( + request_id, + InFlightRequest { + receiver_subject, + msg_tx, + opened: false, + }, + ) + .await; + + msg_rx + } + pub async fn send_message( &self, request_id: RequestId, @@ -78,35 +108,23 @@ impl SharedState { // Get subject and whether this is the first message for this request let (tunnel_receiver_subject, include_reply_to) = { - let mut requests_in_flight = self.requests_in_flight.lock().await; - if let Some(req) = requests_in_flight.get_mut(&request_id) { + if let Some(mut req) = self.requests_in_flight.get_async(&request_id).await { let receiver_subject = req.receiver_subject.clone(); + let include_reply_to = !req.opened; if include_reply_to { // Mark as opened so subsequent messages skip reply_to req.opened = true; } + (receiver_subject, include_reply_to) } else { - bail!("request not in flight") + bail!("request not in flight"); } }; - // Save pending message - { - let mut pending_messages = self.pending_messages.lock().await; - pending_messages.insert( - message_id, - PendingMessage { - request_id, - send_instant: Instant::now(), - }, - ); - } - - // Send message - let message = protocol::ToClient::ToClientTunnelMessage(protocol::ToClientTunnelMessage { - request_id, + let payload = protocol::ToClientTunnelMessage { + request_id: request_id.clone(), message_id, // Only send reply to subject on the first message for this request. This reduces // overhead of subsequent messages. @@ -116,7 +134,37 @@ impl SharedState { None }, message_kind, - }); + }; + + // Save pending message + { + let pending_msg = PendingMessage { + send_instant: Instant::now(), + payload: payload.clone(), + }; + let mut pending_msgs_by_req_id = self + .pending_messages + .entry_async(request_id) + .await + .or_insert_with(Vec::new); + let pending_msgs_by_req_id = pending_msgs_by_req_id.get_mut(); + + tracing::info!(l=?pending_msgs_by_req_id.len(), message_id=?Uuid::from_bytes(payload.message_id), request_id=?Uuid::from_bytes(payload.request_id), "new msg -----------"); + + if pending_msgs_by_req_id.len() >= MAX_PENDING_MSGS_PER_REQ { + self.pending_messages.remove_async(&request_id).await; + + return Err(WebsocketPendingLimitReached { + limit: MAX_PENDING_MSGS_PER_REQ, + } + .build()); + } + + pending_msgs_by_req_id.push(pending_msg); + } + + // Send message + let message = protocol::ToClient::ToClientTunnelMessage(payload); let message_serialized = versioned::ToClient::latest(message) .serialize_with_embedded_version(PROTOCOL_VERSION)?; self.ups @@ -130,23 +178,6 @@ impl SharedState { Ok(()) } - pub async fn start_in_flight_request( - &self, - receiver_subject: String, - ) -> (RequestId, mpsc::Receiver) { - let id = Uuid::new_v4().into_bytes(); - let (msg_tx, msg_rx) = mpsc::channel(128); - self.requests_in_flight.lock().await.insert( - id, - InFlightRequest { - receiver_subject, - msg_tx, - opened: false, - }, - ); - (id, msg_rx) - } - async fn receiver(&self, mut sub: Subscriber) { while let Ok(NextOutput::Message(msg)) = sub.next().await { tracing::trace!( @@ -157,23 +188,27 @@ impl SharedState { match versioned::ToGateway::deserialize_with_embedded_version(&msg.payload) { Ok(protocol::ToGateway { message: msg }) => { tracing::debug!( - ?msg.request_id, - ?msg.message_id, + request_id=?Uuid::from_bytes(msg.request_id), + message_id=?Uuid::from_bytes(msg.message_id), "successfully deserialized message" ); if let protocol::ToServerTunnelMessageKind::TunnelAck = &msg.message_kind { + tracing::info!(message_id=?Uuid::from_bytes(msg.message_id), request_id=?Uuid::from_bytes(msg.request_id), "ack -----------"); // Handle ack message - - let mut pending_messages = self.pending_messages.lock().await; - if pending_messages.remove(&msg.message_id).is_none() { + if let Some(mut pending_msgs) = + self.pending_messages.get_async(&msg.request_id).await + { + pending_msgs.retain(|m| m.payload.message_id != msg.message_id); + } else { tracing::warn!( "pending message does not exist or ack received after message body" - ); - } + ) + }; } else { // Send message to the request handler to emulate the real network action - let requests_in_flight = self.requests_in_flight.lock().await; - let Some(in_flight) = requests_in_flight.get(&msg.request_id) else { + let Some(in_flight) = + self.requests_in_flight.get_async(&msg.request_id).await + else { tracing::debug!( ?msg.request_id, "in flight has already been disconnected" @@ -230,6 +265,38 @@ impl SharedState { } } + pub async fn send_reclaimed_messages(&self, request_id: RequestId) -> Result<()> { + let receiver_subject = + if let Some(req) = self.requests_in_flight.get_async(&request_id).await { + req.receiver_subject.clone() + } else { + bail!("request not in flight"); + }; + + // When a request is started again, read all of its pending messages and send them to the new receiver + let Some(entry) = self.pending_messages.get_async(&request_id).await else { + return Ok(()); + }; + let reclaimed_pending_msgs = entry.get(); + + if !reclaimed_pending_msgs.is_empty() { + tracing::debug!(request_id=?Uuid::from_bytes(request_id.clone()), "resending pending messages"); + + for pending_msg in reclaimed_pending_msgs { + // Send message + let message = + protocol::ToClient::ToClientTunnelMessage(pending_msg.payload.clone()); + let message_serialized = versioned::ToClient::latest(message) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + self.ups + .publish(&receiver_subject, &message_serialized, PublishOpts::one()) + .await?; + } + } + + Ok(()) + } + async fn gc(&self) { let mut interval = tokio::time::interval(GC_INTERVAL); loop { @@ -238,38 +305,39 @@ impl SharedState { let now = Instant::now(); // Purge unacked messages - { - let mut pending_messages = self.pending_messages.lock().await; - let mut removed_req_ids = Vec::new(); - pending_messages.retain(|_k, v| { - if now.duration_since(v.send_instant) > MESSAGE_ACK_TIMEOUT { - // Expired - removed_req_ids.push(v.request_id.clone()); - false + let mut expired_req_ids = Vec::new(); + self.pending_messages + .retain_async(|request_id, pending_msgs| { + if let Some(pending_msg) = pending_msgs.first() { + if now.duration_since(pending_msg.send_instant) > MESSAGE_ACK_TIMEOUT { + // Expired + expired_req_ids.push(request_id.clone()); + false + } else { + true + } } else { - true + false } - }); + }) + .await; - // Close in-flight messages - let requests_in_flight = self.requests_in_flight.lock().await; - for req_id in removed_req_ids { - if let Some(x) = requests_in_flight.get(&req_id) { - let _ = x.msg_tx.send(TunnelMessageData::Timeout); - } else { - tracing::warn!( - ?req_id, - "message expired for in flight that does not exist" - ); - } + // Close in-flight requests for expired messages + for request_id in expired_req_ids { + if let Some(x) = self.requests_in_flight.get_async(&request_id).await { + let _ = x.msg_tx.send(TunnelMessageData::Timeout); + } else { + tracing::debug!( + request_id=?Uuid::from_bytes(request_id), + "message expired for in flight that does not exist" + ); } } // Purge no longer in flight - { - let mut requests_in_flight = self.requests_in_flight.lock().await; - requests_in_flight.retain(|_k, v| !v.msg_tx.is_closed()); - } + self.requests_in_flight + .retain_async(|_k, v| !v.msg_tx.is_closed()) + .await; } } } diff --git a/engine/packages/pegboard-runner/src/client_to_pubsub_task.rs b/engine/packages/pegboard-runner/src/client_to_pubsub_task.rs index 99b9ec6e1c..c830c5d06a 100644 --- a/engine/packages/pegboard-runner/src/client_to_pubsub_task.rs +++ b/engine/packages/pegboard-runner/src/client_to_pubsub_task.rs @@ -8,14 +8,22 @@ use pegboard_actor_kv as kv; use rivet_guard_core::websocket_handle::WebSocketReceiver; use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; use std::sync::{Arc, atomic::Ordering}; +use tokio::sync::Mutex; use universalpubsub::PublishOpts; use vbare::OwnedVersionedData; use crate::conn::Conn; #[tracing::instrument(skip_all, fields(runner_id=?conn.runner_id, workflow_id=?conn.workflow_id, protocol_version=%conn.protocol_version))] -pub async fn task(ctx: StandaloneCtx, conn: Arc, mut ws_rx: WebSocketReceiver) -> Result<()> { +pub async fn task( + ctx: StandaloneCtx, + conn: Arc, + ws_rx: Arc>, +) -> Result<()> { tracing::debug!("starting WebSocket to pubsub forwarding task"); + + let mut ws_rx = ws_rx.lock().await; + while let Some(msg) = ws_rx.try_next().await? { match msg { WsMessage::Binary(data) => { diff --git a/engine/packages/pegboard-runner/src/conn.rs b/engine/packages/pegboard-runner/src/conn.rs index 7649717cd5..5b0b0e99a9 100644 --- a/engine/packages/pegboard-runner/src/conn.rs +++ b/engine/packages/pegboard-runner/src/conn.rs @@ -4,7 +4,7 @@ use gas::prelude::Id; use gas::prelude::*; use hyper_tungstenite::tungstenite::Message; use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; -use rivet_guard_core::{WebSocketHandle, websocket_handle::WebSocketReceiver}; +use rivet_guard_core::WebSocketHandle; use rivet_runner_protocol as protocol; use rivet_runner_protocol::*; use std::{ @@ -42,7 +42,6 @@ pub struct Conn { pub async fn init_conn( ctx: &StandaloneCtx, ws_handle: WebSocketHandle, - ws_rx: &mut WebSocketReceiver, UrlData { protocol_version, namespace, @@ -59,6 +58,9 @@ pub async fn init_conn( tracing::debug!("new runner connection"); + let ws_rx = ws_handle.recv(); + let mut ws_rx = ws_rx.lock().await; + // Receive init packet let (runner_id, workflow_id) = if let Some(msg) = tokio::time::timeout(Duration::from_secs(5), ws_rx.next()) diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 95b4a1591d..dd30703bba 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -61,6 +61,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { _headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, + _unique_request_id: Uuid, ) -> Result<()> { // Get UPS let ups = self.ctx.ups().context("failed to get UPS instance")?; @@ -73,14 +74,8 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { tracing::debug!(?path, "tunnel ws connection established"); - // Accept WS - let mut ws_rx = ws_handle - .accept() - .await - .context("failed to accept WebSocket connection")?; - // Create connection - let conn = conn::init_conn(&self.ctx, ws_handle.clone(), &mut ws_rx, url_data) + let conn = conn::init_conn(&self.ctx, ws_handle.clone(), url_data) .await .context("failed to initialize runner connection")?; @@ -101,7 +96,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { let mut client_to_pubsub = tokio::spawn(client_to_pubsub_task::task( self.ctx.clone(), conn.clone(), - ws_rx, + ws_handle.recv(), )); // Update pings diff --git a/engine/packages/pegboard-runner/src/pubsub_to_client_task.rs b/engine/packages/pegboard-runner/src/pubsub_to_client_task.rs index 9dc4179a2a..6424454ed9 100644 --- a/engine/packages/pegboard-runner/src/pubsub_to_client_task.rs +++ b/engine/packages/pegboard-runner/src/pubsub_to_client_task.rs @@ -42,7 +42,7 @@ pub async fn task(conn: Arc, mut sub: Subscriber) -> Result<()> { // This will remove gateway_reply_to from the message since it does not need to be sent to the // client if let Some(reply_to) = tunnel_msg.gateway_reply_to.take() { - tracing::debug!(?tunnel_msg.request_id, ?reply_to, "creating active request"); + tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), ?reply_to, "creating active request"); let mut active_requests = conn.tunnel_active_requests.lock().await; active_requests.insert( tunnel_msg.request_id, @@ -55,7 +55,7 @@ pub async fn task(conn: Arc, mut sub: Subscriber) -> Result<()> { match tunnel_msg.message_kind { // If terminal, remove active request tracking protocol::ToClientTunnelMessageKind::ToClientWebSocketClose(_) => { - tracing::debug!(?tunnel_msg.request_id, "removing active conn due to close message"); + tracing::debug!(request_id=?Uuid::from_bytes(tunnel_msg.request_id), "removing active conn due to close message"); let mut active_requests = conn.tunnel_active_requests.lock().await; active_requests.remove(&tunnel_msg.request_id); } diff --git a/engine/packages/pegboard-serverless/src/lib.rs b/engine/packages/pegboard-serverless/src/lib.rs index fb58597f64..ef1b94c822 100644 --- a/engine/packages/pegboard-serverless/src/lib.rs +++ b/engine/packages/pegboard-serverless/src/lib.rs @@ -383,7 +383,7 @@ async fn outbound_handler( } } Err(sse::Error::StreamEnded) => { - tracing::debug!("outbound req stopped early"); + tracing::debug!(?runner_id, "outbound req stopped early"); return Ok(()); } @@ -417,7 +417,7 @@ async fn outbound_handler( match event { Ok(sse::Event::Open) => {} Ok(sse::Event::Message(msg)) => { - tracing::debug!(%msg.data, "received outbound req message"); + tracing::debug!(%msg.data, ?runner_id, "received outbound req message"); // If runner_id is none at this point it means we did not send the stopping signal yet, so // send it now @@ -451,7 +451,7 @@ async fn outbound_handler( tokio::select! { res = wait_for_shutdown_fut => return res.map_err(Into::into), _ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => { - tracing::debug!("reached drain grace period before runner shut down") + tracing::debug!(?runner_id, "reached drain grace period before runner shut down") } } @@ -463,7 +463,7 @@ async fn outbound_handler( publish_to_client_stop(ctx, runner_id).await?; } - tracing::debug!("outbound req stopped"); + tracing::debug!(?runner_id, "outbound req stopped"); Ok(()) } diff --git a/engine/packages/pegboard/src/workflows/actor/runtime.rs b/engine/packages/pegboard/src/workflows/actor/runtime.rs index e98eb5146a..c2985111a4 100644 --- a/engine/packages/pegboard/src/workflows/actor/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor/runtime.rs @@ -640,20 +640,20 @@ pub async fn reschedule_actor( }; state.reschedule_state.last_retry_ts = now; - // Don't sleep for first retry - if state.reschedule_state.retry_count > 0 { - let next = backoff.step().expect("should not have max retry"); - - // Sleep for backoff or destroy early - if let Some(_sig) = ctx - .listen_with_timeout::(Instant::from(next) - Instant::now()) - .await? - { - tracing::debug!("destroying before actor start"); - - return Ok(SpawnActorOutput::Destroy); - } - } + // // Don't sleep for first retry + // if state.reschedule_state.retry_count > 0 { + // let next = backoff.step().expect("should not have max retry"); + + // // Sleep for backoff or destroy early + // if let Some(_sig) = ctx + // .listen_with_timeout::(Instant::from(next) - Instant::now()) + // .await? + // { + // tracing::debug!("destroying before actor start"); + + // return Ok(SpawnActorOutput::Destroy); + // } + // } let next_generation = state.generation + 1; let spawn_res = spawn_actor( @@ -726,7 +726,7 @@ struct CompareRetryInput { async fn compare_retry(ctx: &ActivityCtx, input: &CompareRetryInput) -> Result<(i64, bool)> { let now = util::timestamp::now(); - // If the last retry ts is more than RETRY_RESET_DURATION_MS, reset retry count + // If the last retry ts is more than RETRY_RESET_DURATION_MS ago, reset retry count Ok((now, input.last_retry_ts < now - RETRY_RESET_DURATION_MS)) } diff --git a/engine/packages/universalpubsub/src/driver/postgres/mod.rs b/engine/packages/universalpubsub/src/driver/postgres/mod.rs index c2f20b68d0..016f6ad17b 100644 --- a/engine/packages/universalpubsub/src/driver/postgres/mod.rs +++ b/engine/packages/universalpubsub/src/driver/postgres/mod.rs @@ -292,10 +292,9 @@ impl PubSubDriver for PostgresDriver { // Try to LISTEN if client is available, but don't fail if disconnected // The reconnection logic will handle re-subscribing if let Some(client) = self.client.lock().await.clone() { - let span = tracing::trace_span!("pg_listen"); match client .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(span) + .instrument(tracing::trace_span!("pg_listen")) .await { Result::Ok(_) => { @@ -368,10 +367,9 @@ impl PubSubDriver for PostgresDriver { match conn.execute("SELECT 1", &[]).await { Result::Ok(_) => { // Connection is good, use it for NOTIFY - let span = tracing::trace_span!("pg_notify"); match conn .execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[]) - .instrument(span) + .instrument(tracing::trace_span!("pg_notify")) .await { Result::Ok(_) => return Ok(()), diff --git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index 8f9d2ff2c8..5e8524d408 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -155,9 +155,6 @@ export class Runner { const actor = this.#removeActor(actorId, generation); if (!actor) return; - // Unregister actor from tunnel - this.#tunnel?.unregisterActor(actor); - // If onActorStop times out, Pegboard will handle this timeout with ACTOR_STOP_THRESHOLD_DURATION_MS try { await this.#config.onActorStop(actorId, actor.generation); @@ -246,23 +243,8 @@ export class Runner { this.#actors.delete(actorId); - // Close all WebSocket connections for this actor - const actorWebSockets = this.#actorWebSockets.get(actorId); - if (actorWebSockets) { - for (const ws of actorWebSockets) { - try { - ws.close(1000, "Actor stopped"); - } catch (err) { - logger()?.error({ - msg: "error closing websocket for actor", - runnerId: this.runnerId, - actorId, - err, - }); - } - } - this.#actorWebSockets.delete(actorId); - } + // Unregister actor from tunnel + this.#tunnel?.unregisterActor(actor); return actor; } @@ -1376,12 +1358,22 @@ export class Runner { return; } + logger()?.debug({ + msg: "------------ SEND", + }); + const encoded = protocol.encodeToServer(message); if ( this.#pegboardWebSocket && this.#pegboardWebSocket.readyState === 1 ) { + logger()?.debug({ + msg: "------------ SEND 2", + }); this.#pegboardWebSocket.send(encoded); + logger()?.debug({ + msg: "------------ SEND 3", + }); } else { logger()?.error({ msg: "WebSocket not available or not open for sending data", diff --git a/engine/sdks/typescript/runner/src/tunnel.ts b/engine/sdks/typescript/runner/src/tunnel.ts index 3e9dfc24e2..09e32a1c40 100644 --- a/engine/sdks/typescript/runner/src/tunnel.ts +++ b/engine/sdks/typescript/runner/src/tunnel.ts @@ -1,6 +1,6 @@ import type * as protocol from "@rivetkit/engine-runner-protocol"; import type { MessageId, RequestId } from "@rivetkit/engine-runner-protocol"; -import { v4 as uuidv4 } from "uuid"; +import { v4 as uuidv4, stringify as uuidstringify } from "uuid"; import { logger } from "./log"; import type { ActorInstance, Runner } from "./mod"; import { unreachable } from "./utils"; @@ -95,6 +95,13 @@ export class Tunnel { } #sendAck(requestId: RequestId, messageId: MessageId) { + logger()?.debug({ + msg: "------------ tunnel ws ready", + ready: this.#runner.__webSocketReady(), + requestId: uuidstringify(new Uint8Array(requestId)), + messageId: uuidstringify(new Uint8Array(messageId)), + }); + if (!this.#runner.__webSocketReady()) { return; } @@ -108,6 +115,12 @@ export class Tunnel { }, }; + logger()?.debug({ + msg: "ack tunnel msg", + requestId: uuidstringify(new Uint8Array(requestId)), + messageId: uuidstringify(new Uint8Array(messageId)), + }); + this.#runner.__sendToServer(message); } @@ -224,6 +237,13 @@ export class Tunnel { } async handleTunnelMessage(message: protocol.ToClientTunnelMessage) { + logger()?.debug({ + msg: "tunnel msg", + requestId: uuidstringify(new Uint8Array(message.requestId)), + messageId: uuidstringify(new Uint8Array(message.messageId)), + message: message.messageKind, + }); + if (message.messageKind.tag === "TunnelAck") { // Mark pending message as acknowledged and remove it const msgIdStr = bufferToString(message.messageId); @@ -232,36 +252,55 @@ export class Tunnel { this.#pendingTunnelMessages.delete(msgIdStr); } } else { - this.#sendAck(message.requestId, message.messageId); switch (message.messageKind.tag) { case "ToClientRequestStart": + this.#sendAck(message.requestId, message.messageId); + await this.#handleRequestStart( message.requestId, message.messageKind.val, ); break; case "ToClientRequestChunk": + this.#sendAck(message.requestId, message.messageId); + await this.#handleRequestChunk( message.requestId, message.messageKind.val, ); break; case "ToClientRequestAbort": + this.#sendAck(message.requestId, message.messageId); + await this.#handleRequestAbort(message.requestId); break; case "ToClientWebSocketOpen": + this.#sendAck(message.requestId, message.messageId); + await this.#handleWebSocketOpen( message.requestId, message.messageKind.val, ); break; case "ToClientWebSocketMessage": - await this.#handleWebSocketMessage( + let unhandled = await this.#handleWebSocketMessage( message.requestId, message.messageKind.val, ); + logger()?.debug({ + msg: "------------ unhandled", + unhandled, + requestId: uuidstringify(new Uint8Array(message.requestId)), + messageId: uuidstringify(new Uint8Array(message.messageId)), + }); + + if (!unhandled) { + this.#sendAck(message.requestId, message.messageId); + } break; case "ToClientWebSocketClose": + this.#sendAck(message.requestId, message.messageId); + await this.#handleWebSocketClose( message.requestId, message.messageKind.val, @@ -569,10 +608,16 @@ export class Tunnel { } } + /// Returns false if the message was sent off async #handleWebSocketMessage( requestId: ArrayBuffer, msg: protocol.ToServerWebSocketMessage, - ) { + ): Promise { + logger()?.debug({ + msg: "adapter handle msg", + requestId: uuidstringify(new Uint8Array(requestId)), + }); + const webSocketId = bufferToString(requestId); const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { @@ -580,7 +625,9 @@ export class Tunnel { ? new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - adapter._handleMessage(data, msg.binary); + return adapter._handleMessage(data, msg.binary); + } else { + return true; } } diff --git a/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts b/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts index eb46758d94..3d5ca3af55 100644 --- a/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts +++ b/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts @@ -254,6 +254,11 @@ export class WebSocketTunnelAdapter { let hasListeners = false; if (listeners && listeners.size > 0) { + logger()?.debug({ + msg: "------------- listeners", + l: listeners?.size + }); + hasListeners = true; for (const listener of listeners) { try { @@ -311,6 +316,10 @@ export class WebSocketTunnelAdapter { break; case "message": if (this.#onmessage) { + logger()?.debug({ + msg: "---------------- msg", + }); + hasListeners = true; try { this.#onmessage.call(this, event); @@ -326,6 +335,11 @@ export class WebSocketTunnelAdapter { // Buffer the event if no listeners are registered if (!hasListeners) { + logger()?.debug({ + msg: "------------- no listeners", + type, + }); + this.#bufferedEvents.push({ type, event }); } } @@ -338,6 +352,12 @@ export class WebSocketTunnelAdapter { (buffered) => buffered.type !== type, ); + logger()?.debug({ + msg: "------------- flush", + type, + l: eventsToFlush.length + }); + for (const { event } of eventsToFlush) { // Re-fire the event, which will now have listeners const listeners = this.#eventListeners.get(type); @@ -426,10 +446,11 @@ export class WebSocketTunnelAdapter { this.#fireEvent("open", event); } - _handleMessage(data: string | Uint8Array, isBinary: boolean): void { + /// Returns false if the message was sent off. + _handleMessage(data: string | Uint8Array, isBinary: boolean): boolean { if (this.#readyState !== 1) { // OPEN - return; + return true; } let messageData: any; @@ -465,7 +486,13 @@ export class WebSocketTunnelAdapter { target: this, }; + logger()?.debug({ + msg: "------------ tunnel fire event", + }); + this.#fireEvent("message", event); + + return false; } _handleClose(code?: number, reason?: string): void { diff --git a/scripts/tests/actor_sleep.ts b/scripts/tests/actor_sleep.ts index e39765b47a..fc8d8d0102 100755 --- a/scripts/tests/actor_sleep.ts +++ b/scripts/tests/actor_sleep.ts @@ -17,6 +17,8 @@ async function main() { console.log("Actor created:", actorResponse.actor); for (let i = 0; i < 10; i++) { + await testWebSocket(actorResponse.actor.actor_id); + console.log("Sleeping actor..."); const actorSleepResponse = await fetch(`${RIVET_ENDPOINT}/sleep`, { method: "GET", @@ -38,7 +40,6 @@ async function main() { // await new Promise(resolve => setTimeout(resolve, 2000)); } - // Make a request to the actor console.log("Making request to actor..."); const actorPingResponse = await fetch(`${RIVET_ENDPOINT}/ping`, { @@ -59,8 +60,6 @@ async function main() { } console.log("Actor ping response:", pingResult); - - // await testWebSocket(actorResponse.actor.actor_id); } catch (error) { console.error(`Actor test failed:`, error); } @@ -89,14 +88,6 @@ function testWebSocket(actorId: string): Promise { let pingReceived = false; let echoReceived = false; - const timeout = setTimeout(() => { - console.log( - "No response received within timeout, but connection was established", - ); - // Connection was established, that's enough for the test - ws.close(); - resolve(); - }, 2000); ws.addEventListener("open", () => { console.log("WebSocket connected"); @@ -126,21 +117,18 @@ function testWebSocket(actorId: string): Promise { console.log("Echo test successful!"); // All tests passed - clearTimeout(timeout); ws.close(); resolve(); } }); - ws.addEventListener("error", (error) => { - clearTimeout(timeout); - reject(new Error(`WebSocket error: ${error.message}`)); + ws.addEventListener("error", (event) => { + reject(new Error(`WebSocket error: ${event}`)); }); - ws.addEventListener("close", () => { - clearTimeout(timeout); + ws.addEventListener("close", event => { if (!pingReceived || !echoReceived) { - reject(new Error("WebSocket closed before completing tests")); + reject(new Error(`WebSocket closed before completing tests: ${event.code} (${event.reason}) ${new Date().toISOString()}`)); } }); });