From b6dfb0d990848f040c4e64d6d1e162206f350fc4 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Thu, 20 Nov 2025 14:25:48 -0800 Subject: [PATCH] fix(gateway): prevent gc from removing hibernating in flight req, check actor started after sub when hibernating --- Cargo.lock | 1 + engine/packages/guard/src/lib.rs | 2 +- engine/packages/guard/src/shared_state.rs | 4 +- engine/packages/pegboard-gateway/Cargo.toml | 1 + .../pegboard-gateway/src/keepalive_task.rs | 20 ++- engine/packages/pegboard-gateway/src/lib.rs | 77 ++++++----- .../pegboard-gateway/src/shared_state.rs | 124 +++++++++++------- .../pegboard-gateway/src/tunnel_to_ws_task.rs | 2 +- 8 files changed, 142 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1920300f8e..22dfbac133 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3488,6 +3488,7 @@ dependencies = [ "lazy_static", "pegboard", "rand 0.8.5", + "rivet-config", "rivet-error", "rivet-guard-core", "rivet-metrics", diff --git a/engine/packages/guard/src/lib.rs b/engine/packages/guard/src/lib.rs index a09e0fd912..e72c7c6edd 100644 --- a/engine/packages/guard/src/lib.rs +++ b/engine/packages/guard/src/lib.rs @@ -28,7 +28,7 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R } // Share shared context - let shared_state = shared_state::SharedState::new(ctx.ups()?); + let shared_state = shared_state::SharedState::new(&config, ctx.ups()?); shared_state.start().await?; // Create handlers diff --git a/engine/packages/guard/src/shared_state.rs b/engine/packages/guard/src/shared_state.rs index 0462fad5c8..cabfe8cd3c 100644 --- a/engine/packages/guard/src/shared_state.rs +++ b/engine/packages/guard/src/shared_state.rs @@ -6,9 +6,9 @@ use universalpubsub::PubSub; pub struct SharedState(Arc); impl SharedState { - pub fn new(pubsub: PubSub) -> SharedState { + pub fn new(config: &rivet_config::Config, pubsub: PubSub) -> SharedState { SharedState(Arc::new(SharedStateInner { - pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(pubsub), + pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(config, pubsub), })) } diff --git a/engine/packages/pegboard-gateway/Cargo.toml b/engine/packages/pegboard-gateway/Cargo.toml index 31f1d9162c..424293844e 100644 --- a/engine/packages/pegboard-gateway/Cargo.toml +++ b/engine/packages/pegboard-gateway/Cargo.toml @@ -18,6 +18,7 @@ hyper-tungstenite.workspace = true lazy_static.workspace = true pegboard.workspace = true rand.workspace = true +rivet-config.workspace = true rivet-error.workspace = true rivet-guard-core.workspace = true rivet-metrics.workspace = true diff --git a/engine/packages/pegboard-gateway/src/keepalive_task.rs b/engine/packages/pegboard-gateway/src/keepalive_task.rs index c726ba42ad..597892019a 100644 --- a/engine/packages/pegboard-gateway/src/keepalive_task.rs +++ b/engine/packages/pegboard-gateway/src/keepalive_task.rs @@ -6,12 +6,15 @@ use std::time::Duration; use tokio::sync::watch; use super::LifecycleResult; +use crate::shared_state::SharedState; /// Periodically pings writes keepalive in UDB. This is used to restore hibernating request IDs on /// next actor start. /// -///Only ran for hibernating requests. +/// Only ran for hibernating requests. 
+ pub async fn task( + shared_state: SharedState, ctx: StandaloneCtx, actor_id: Id, gateway_id: GatewayId, @@ -43,11 +46,14 @@ pub async fn task( let jitter = { rand::thread_rng().gen_range(0..128) }; tokio::time::sleep(Duration::from_millis(jitter)).await; - ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input { - actor_id, - gateway_id, - request_id, - }) - .await?; + tokio::try_join!( + ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input { + actor_id, + gateway_id, + request_id, + }), + // Keep alive in flight req during hibernation + shared_state.keepalive_hws(request_id), + )?; } } diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 1dd4bab1bb..cfcd1da971 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -167,6 +167,7 @@ impl CustomServeTrait for PegboardGateway { let InFlightRequestHandle { mut msg_rx, mut drop_rx, + .. } = self .shared_state .start_in_flight_request(tunnel_subject, request_id) @@ -212,7 +213,7 @@ impl CustomServeTrait for PegboardGateway { } } else { tracing::warn!( - request_id=?tunnel_id::request_id_to_string(&request_id), + request_id=%tunnel_id::request_id_to_string(&request_id), "received no message response during request init", ); break; @@ -267,14 +268,14 @@ impl CustomServeTrait for PegboardGateway { Ok(response) } - #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id))] + #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id, request_id=%tunnel_id::request_id_to_string(&request_id)))] async fn handle_websocket( &self, client_ws: WebSocketHandle, headers: &hyper::HeaderMap, _path: &str, _request_context: &mut RequestContext, - unique_request_id: RequestId, + request_id: RequestId, after_hibernation: bool, ) -> Result> { // Use the actor ID from the gateway instance @@ -298,15 +299,20 @@ impl CustomServeTrait for PegboardGateway { pegboard::pubsub_subjects::RunnerReceiverSubject::new(self.runner_id).to_string(); // Start listening for WebSocket messages - let request_id = unique_request_id; let InFlightRequestHandle { mut msg_rx, mut drop_rx, + new, } = self .shared_state .start_in_flight_request(tunnel_subject.clone(), request_id) .await; + ensure!( + !after_hibernation || !new, + "should not be creating a new in flight entry after hibernation" + ); + // If we are reconnecting after hibernation, don't send an open message let can_hibernate = if after_hibernation { true @@ -348,7 +354,7 @@ impl CustomServeTrait for PegboardGateway { } } else { tracing::warn!( - request_id=?tunnel_id::request_id_to_string(&request_id), + request_id=%tunnel_id::request_id_to_string(&request_id), "received no message response during ws init", ); break; @@ -416,17 +422,23 @@ impl CustomServeTrait for PegboardGateway { request_id, ping_abort_rx, )); + let keepalive = if can_hibernate { + Some(tokio::spawn(keepalive_task::task( + self.shared_state.clone(), + self.ctx.clone(), + self.actor_id, + self.shared_state.gateway_id(), + request_id, + keepalive_abort_rx, + ))) + } else { + None + }; let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone(); let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone(); let ping_abort_tx2 = ping_abort_tx.clone(); - // Clone variables needed for keepalive task - let ctx_clone = self.ctx.clone(); - let actor_id_clone = self.actor_id; - let gateway_id_clone = self.shared_state.gateway_id(); - let request_id_clone = request_id; - // Wait 
for all tasks to complete
 		let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res) = tokio::join!(
 			async {
@@ -478,17 +490,9 @@ impl CustomServeTrait for PegboardGateway {
 				res
 			},
 			async {
-				if !can_hibernate {
+				let Some(keepalive) = keepalive else {
 					return Ok(LifecycleResult::Aborted);
-				}
-
-				let keepalive = tokio::spawn(keepalive_task::task(
-					ctx_clone,
-					actor_id_clone,
-					gateway_id_clone,
-					request_id_clone,
-					keepalive_abort_rx,
-				));
+				};
 
 				let res = keepalive.await?;
 
@@ -568,14 +572,12 @@ impl CustomServeTrait for PegboardGateway {
 		}
 	}
 
-	#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id))]
+	#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, request_id=%tunnel_id::request_id_to_string(&request_id)))]
 	async fn handle_websocket_hibernation(
 		&self,
 		client_ws: WebSocketHandle,
-		unique_request_id: RequestId,
+		request_id: RequestId,
 	) -> Result<HibernationResult> {
-		let request_id = unique_request_id;
-
 		// Insert hibernating request entry before checking for pending messages
 		// This ensures the entry exists even if we immediately rewake the actor
 		self.ctx
@@ -592,10 +594,7 @@ impl CustomServeTrait for PegboardGateway {
 			.has_pending_websocket_messages(request_id)
 			.await?
 		{
-			tracing::debug!(
-				?unique_request_id,
-				"detected pending requests on websocket hibernation, rewaking actor"
-			);
+			tracing::debug!("exiting hibernation due to pending messages");
 
 			return Ok(HibernationResult::Continue);
 		}
@@ -603,10 +602,11 @@ impl CustomServeTrait for PegboardGateway {
 		// Start keepalive task
 		let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
 		let keepalive_handle = tokio::spawn(keepalive_task::task(
+			self.shared_state.clone(),
 			self.ctx.clone(),
 			self.actor_id,
 			self.shared_state.gateway_id(),
-			unique_request_id,
+			request_id,
 			keepalive_abort_rx,
 		));
 
@@ -623,7 +623,7 @@ impl CustomServeTrait for PegboardGateway {
 			.op(pegboard::ops::actor::hibernating_request::delete::Input {
 				actor_id: self.actor_id,
 				gateway_id: self.shared_state.gateway_id(),
-				request_id: unique_request_id,
+				request_id,
 			})
 			.await?;
 	}
@@ -643,6 +643,21 @@ impl PegboardGateway {
 			.subscribe::(("actor_id", self.actor_id))
 			.await?;
 
+		// Fetch actor state after subscribing so we don't race with the actor becoming ready
+		if let Some(actor) = self
+			.ctx
+			.op(pegboard::ops::actor::get_for_gateway::Input {
+				actor_id: self.actor_id,
+			})
+			.await?
+		{
+			if actor.runner_id.is_some() {
+				tracing::debug!("actor became ready during hibernation");
+
+				return Ok(HibernationResult::Continue);
+			}
+		}
+
 		let res = tokio::select! {
 			_ = ready_sub.next() => {
 				tracing::debug!("actor became ready during hibernation");
diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs
index 32b0dfe4a5..57e1fd4cea 100644
--- a/engine/packages/pegboard-gateway/src/shared_state.rs
+++ b/engine/packages/pegboard-gateway/src/shared_state.rs
@@ -27,6 +27,7 @@ pub struct InFlightRequestHandle {
 	/// This is separate from `msg_rx` there may still be messages that need to be sent to the
 	/// request after `msg_rx` has dropped. 
pub drop_rx: watch::Receiver<()>, + pub new: bool, } struct InFlightRequest { @@ -48,6 +49,8 @@ struct InFlightRequest { struct HibernationState { total_pending_ws_msgs_size: u64, pending_ws_msgs: Vec, + // Used to keep hibernating websockets from being GC'd + last_ping: Instant, } pub struct PendingWebsocketMessage { @@ -61,13 +64,14 @@ pub struct SharedStateInner { gateway_id: GatewayId, receiver_subject: String, in_flight_requests: HashMap, + hibernation_timeout: i64, } #[derive(Clone)] pub struct SharedState(Arc); impl SharedState { - pub fn new(ups: PubSub) -> Self { + pub fn new(config: &rivet_config::Config, ups: PubSub) -> Self { let gateway_id = tunnel_id::generate_gateway_id(); let receiver_subject = pegboard::pubsub_subjects::GatewayReceiverSubject::new(gateway_id).to_string(); @@ -77,6 +81,7 @@ impl SharedState { gateway_id, receiver_subject, in_flight_requests: HashMap::new(), + hibernation_timeout: config.pegboard().hibernating_request_eligible_threshold(), })) } @@ -97,7 +102,7 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(%receiver_subject, request_id=?tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(%receiver_subject, request_id=%tunnel_id::request_id_to_string(&request_id)))] pub async fn start_in_flight_request( &self, receiver_subject: String, @@ -106,7 +111,7 @@ impl SharedState { let (msg_tx, msg_rx) = mpsc::channel(128); let (drop_tx, drop_rx) = watch::channel(()); - match self.in_flight_requests.entry_async(request_id).await { + let new = match self.in_flight_requests.entry_async(request_id).await { Entry::Vacant(entry) => { entry.insert_entry(InFlightRequest { receiver_subject, @@ -118,6 +123,8 @@ impl SharedState { stopping: false, last_pong: util::timestamp::now(), }); + + true } Entry::Occupied(mut entry) => { entry.receiver_subject = receiver_subject; @@ -129,13 +136,19 @@ impl SharedState { entry.hibernation_state = None; entry.stopping = false; } + + false } - } + }; - InFlightRequestHandle { msg_rx, drop_rx } + InFlightRequestHandle { + msg_rx, + drop_rx, + new, + } } - #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] pub async fn send_message( &self, request_id: RequestId, @@ -172,6 +185,8 @@ impl SharedState { message_kind, }; + tracing::debug!(?message_id, ?payload, "shared state send message"); + // Send message let message = protocol::ToRunner::ToClientTunnelMessage(payload); let message_serialized = versioned::ToRunner::wrap_latest(message) @@ -211,7 +226,7 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] pub async fn send_and_check_ping(&self, request_id: RequestId) -> Result<()> { let req = self .in_flight_requests @@ -247,6 +262,23 @@ impl SharedState { Ok(()) } + #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] + pub async fn keepalive_hws(&self, request_id: RequestId) -> Result<()> { + let mut req = self + .in_flight_requests + .get_async(&request_id) + .await + .context("request not in flight")?; + + if let Some(hs) = &mut req.hibernation_state { + hs.last_ping = Instant::now(); + } else { + tracing::warn!("should not call keepalive_hws for non-hibernating ws"); + } + + Ok(()) + } + 
#[tracing::instrument(skip_all)]
 	async fn receiver(&self, mut sub: Subscriber) {
 		while let Ok(NextOutput::Message(msg)) = sub.next().await {
@@ -261,7 +293,7 @@ impl SharedState {
 					self.in_flight_requests.get_async(&pong.request_id).await
 				else {
 					tracing::debug!(
-						request_id=?tunnel_id::request_id_to_string(&pong.request_id),
+						request_id=%tunnel_id::request_id_to_string(&pong.request_id),
 						"in flight has already been disconnected, dropping ping"
 					);
 					continue;
@@ -287,8 +319,8 @@ impl SharedState {
 					self.in_flight_requests.get_async(&parts.request_id).await
 				else {
 					tracing::warn!(
-						gateway_id=?tunnel_id::gateway_id_to_string(&parts.gateway_id),
-						request_id=?tunnel_id::request_id_to_string(&parts.request_id),
+						gateway_id=%tunnel_id::gateway_id_to_string(&parts.gateway_id),
+						request_id=%tunnel_id::request_id_to_string(&parts.request_id),
 						message_index=parts.message_index,
 						"in flight has already been disconnected, dropping message"
 					);
@@ -297,8 +329,8 @@ impl SharedState {
 
 				// Send message to the request handler to emulate the real network action
 				tracing::debug!(
-					gateway_id=?tunnel_id::gateway_id_to_string(&parts.gateway_id),
-					request_id=?tunnel_id::request_id_to_string(&parts.request_id),
+					gateway_id=%tunnel_id::gateway_id_to_string(&parts.gateway_id),
+					request_id=%tunnel_id::request_id_to_string(&parts.request_id),
 					message_index=parts.message_index,
 					"forwarding message to request handler"
 				);
@@ -311,7 +343,7 @@ impl SharedState {
 		}
 	}
 
-	#[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id), %enable))]
+	#[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id), %enable))]
 	pub async fn toggle_hibernation(&self, request_id: RequestId, enable: bool) -> Result<()> {
 		let mut req = self
 			.in_flight_requests
@@ -326,6 +358,7 @@ impl SharedState {
 				req.hibernation_state = Some(HibernationState {
 					total_pending_ws_msgs_size: 0,
 					pending_ws_msgs: Vec::new(),
+					last_ping: Instant::now(),
 				});
 			}
 			(false, false) => {}
@@ -334,7 +367,7 @@ impl SharedState {
 		Ok(())
 	}
 
-	#[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))]
+	#[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))]
 	pub async fn resend_pending_websocket_messages(&self, request_id: RequestId) -> Result<()> {
 		let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else {
 			bail!("request not in flight");
@@ -344,9 +377,11 @@ impl SharedState {
 
 		if let Some(hs) = &mut req.hibernation_state {
 			if !hs.pending_ws_msgs.is_empty() {
-				tracing::debug!(request_id=?tunnel_id::request_id_to_string(&request_id), len=?hs.pending_ws_msgs.len(), "resending pending messages");
+				tracing::debug!(len=?hs.pending_ws_msgs.len(), "resending pending messages");
 
 				for pending_msg in &hs.pending_ws_msgs {
+					tracing::debug!(message_index=?pending_msg.message_index, "resending pending ws message");
+
 					self.ups
 						.publish(&receiver_subject, &pending_msg.payload, PublishOpts::one())
 						.await?;
@@ -357,7 +392,7 @@ impl SharedState {
 		Ok(())
 	}
 
-	#[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))]
+	#[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))]
 	pub async fn has_pending_websocket_messages(&self, request_id: RequestId) -> Result<bool> {
 		let Some(req) = self.in_flight_requests.get_async(&request_id).await else {
 			bail!("request not in flight");
@@ -370,7 +405,7 @@ impl SharedState {
 		}
 	}
 
-	#[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id), %ack_index))]
+	#[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id), %ack_index))]
 	pub async fn ack_pending_websocket_messages(
 		&self,
 		request_id: RequestId,
@@ -381,10 +416,7 @@ impl SharedState {
 		};
 
 		let Some(hs) = &mut req.hibernation_state else {
-			tracing::warn!(
-				request_id=?tunnel_id::request_id_to_string(&request_id),
-				"cannot ack ws messages, hibernation is not enabled"
-			);
+			tracing::warn!("cannot ack ws messages, hibernation is not enabled");
 			return Ok(());
 		};
 
@@ -395,7 +427,6 @@ impl SharedState {
 		let len_after = hs.pending_ws_msgs.len();
 
 		tracing::debug!(
-			ack_index,
 			removed_count = len_before - len_after,
 			remaining_count = len_after,
 			"acked pending websocket messages"
@@ -440,7 +471,7 @@ impl SharedState {
 	async fn gc_in_flight_requests(&self) {
 		#[derive(Debug)]
 		enum MsgGcReason {
-			/// Gateway channel is closed and there are no pending messages
+			/// Gateway channel is closed and the request is not hibernating
 			GatewayClosed,
 			/// WebSocket pending messages (ToServerWebSocketMessageAck)
 			WebSocketMessageNotAcked {
@@ -449,9 +480,14 @@ impl SharedState {
 				#[allow(dead_code)]
 				last_msg_index: u16,
 			},
+			/// The gateway did not keep the in-flight request alive during hibernation within the
+			/// configured timeout.
+			HibernationTimeout,
 		}
 
 		let now = Instant::now();
+		let hibernation_timeout =
+			Duration::from_millis(self.hibernation_timeout.try_into().unwrap_or(90_000));
 
 		// First, check if an in flight req is beyond the timeout for tunnel message ack and websocket
 		// message ack
@@ -464,29 +500,23 @@ impl SharedState {
 			}
 
 			let reason = 'reason: {
-				// If we have no pending messages of any kind and the channel is closed, remove the
-				// in flight req
-				if req.msg_tx.is_closed()
-					&& req
-						.hibernation_state
-						.as_ref()
-						.map(|hs| hs.pending_ws_msgs.is_empty())
-						.unwrap_or(true)
-				{
-					break 'reason Some(MsgGcReason::GatewayClosed);
-				}
+				if let Some(hs) = &req.hibernation_state {
+					if let Some(earliest_pending_ws_msg) = hs.pending_ws_msgs.first() {
+						if now.duration_since(earliest_pending_ws_msg.send_instant)
+							> HWS_MESSAGE_ACK_TIMEOUT
+						{
+							break 'reason Some(MsgGcReason::WebSocketMessageNotAcked {
+								first_msg_index: earliest_pending_ws_msg.message_index,
+								last_msg_index: req.message_index
+							});
+						}
+					}
 
-				if let Some(hs) = &req.hibernation_state
-					&& let Some(earliest_pending_ws_msg) = hs.pending_ws_msgs.first()
-				{
-					if now.duration_since(earliest_pending_ws_msg.send_instant)
-						> HWS_MESSAGE_ACK_TIMEOUT
-					{
-						break 'reason Some(MsgGcReason::WebSocketMessageNotAcked {
-							first_msg_index: earliest_pending_ws_msg.message_index,
-							last_msg_index: req.message_index
-						});
+					if hs.last_ping.elapsed() > hibernation_timeout {
+						break 'reason Some(MsgGcReason::HibernationTimeout);
 					}
+				} else if req.msg_tx.is_closed() {
+					break 'reason Some(MsgGcReason::GatewayClosed);
 				}
 
 				None
@@ -494,13 +524,13 @@ impl SharedState {
 
 			if let Some(reason) = &reason {
 				tracing::debug!(
-					request_id=?tunnel_id::request_id_to_string(request_id),
+					request_id=%tunnel_id::request_id_to_string(request_id),
 					?reason,
 					"gc stopping in flight request"
 				);
 
 				if req.drop_tx.send(()).is_err() {
-					tracing::debug!(request_id=?tunnel_id::request_id_to_string(request_id), "failed to send timeout msg to tunnel");
+					tracing::debug!(request_id=%tunnel_id::request_id_to_string(request_id), "failed to send timeout msg to tunnel");
 				}
 
 				// Mark req as stopping to skip this loop next time the gc is run
@@ -518,7 +548,7 @@ impl SharedState {
 			// 
When the websocket reconnects a new channel will be created if req.stopping && req.drop_tx.is_closed() { tracing::debug!( - request_id=?tunnel_id::request_id_to_string(request_id), + request_id=%tunnel_id::request_id_to_string(request_id), "gc removing in flight request" ); diff --git a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs index 8926e4d9f6..571a2ea527 100644 --- a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs @@ -39,7 +39,7 @@ pub async fn task( } protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { tracing::debug!( - request_id=?tunnel_id::request_id_to_string(&request_id), + request_id=%tunnel_id::request_id_to_string(&request_id), ack_index=?ack.index, "received WebSocketMessageAck from runner" );
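Reviewer note (not part of the applied diff): the sketch below is a simplified, self-contained model of the GC decision this patch introduces in gc_in_flight_requests. The types, field names, and timeout values are hypothetical stand-ins rather than the real shared_state types; it only illustrates the intended behavior change, namely that a hibernating request with a fresh keepalive is no longer collected just because its gateway channel closed, while non-hibernating requests with a closed channel are still reaped.

use std::time::{Duration, Instant};

// Hypothetical, stripped-down stand-ins for the real in-flight bookkeeping.
struct HibernationState {
	// Send instants of websocket messages that have not been acked yet.
	pending_ws_msg_send_instants: Vec<Instant>,
	// Refreshed by the keepalive task while the request hibernates.
	last_ping: Instant,
}

struct InFlightRequest {
	msg_tx_closed: bool,
	hibernation_state: Option<HibernationState>,
}

#[derive(Debug, PartialEq)]
enum GcReason {
	GatewayClosed,
	WebSocketMessageNotAcked,
	HibernationTimeout,
}

// Placeholder value; the real constant is HWS_MESSAGE_ACK_TIMEOUT in the gateway crate.
const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30);

// Mirrors the decision order of the patched GC loop: hibernating requests are
// only reaped for unacked ws messages or a stale keepalive; a closed channel
// alone reaps only non-hibernating requests.
fn gc_reason(req: &InFlightRequest, now: Instant, hibernation_timeout: Duration) -> Option<GcReason> {
	if let Some(hs) = &req.hibernation_state {
		if let Some(earliest) = hs.pending_ws_msg_send_instants.first() {
			if now.duration_since(*earliest) > MESSAGE_ACK_TIMEOUT {
				return Some(GcReason::WebSocketMessageNotAcked);
			}
		}
		if now.duration_since(hs.last_ping) > hibernation_timeout {
			return Some(GcReason::HibernationTimeout);
		}
		None
	} else if req.msg_tx_closed {
		Some(GcReason::GatewayClosed)
	} else {
		None
	}
}

fn main() {
	let now = Instant::now();
	let hibernation_timeout = Duration::from_secs(90);

	// Hibernating request whose gateway channel is closed but whose keepalive is
	// fresh: with this patch it survives the GC pass.
	let hibernating = InFlightRequest {
		msg_tx_closed: true,
		hibernation_state: Some(HibernationState {
			pending_ws_msg_send_instants: Vec::new(),
			last_ping: now,
		}),
	};
	assert_eq!(gc_reason(&hibernating, now, hibernation_timeout), None);

	// Non-hibernating request with a closed channel is still collected.
	let closed = InFlightRequest {
		msg_tx_closed: true,
		hibernation_state: None,
	};
	assert_eq!(
		gc_reason(&closed, now, hibernation_timeout),
		Some(GcReason::GatewayClosed)
	);
}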