Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion engine/packages/guard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R
}

// Create shared state
let shared_state = shared_state::SharedState::new(ctx.ups()?);
let shared_state = shared_state::SharedState::new(&config, ctx.ups()?);
shared_state.start().await?;

// Create handlers
Expand Down
4 changes: 2 additions & 2 deletions engine/packages/guard/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use universalpubsub::PubSub;
pub struct SharedState(Arc<SharedStateInner>);

impl SharedState {
pub fn new(pubsub: PubSub) -> SharedState {
pub fn new(config: &rivet_config::Config, pubsub: PubSub) -> SharedState {
SharedState(Arc::new(SharedStateInner {
pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(pubsub),
pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(config, pubsub),
}))
}

Expand Down
1 change: 1 addition & 0 deletions engine/packages/pegboard-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ hyper-tungstenite.workspace = true
lazy_static.workspace = true
pegboard.workspace = true
rand.workspace = true
rivet-config.workspace = true
rivet-error.workspace = true
rivet-guard-core.workspace = true
rivet-metrics.workspace = true
Expand Down
20 changes: 13 additions & 7 deletions engine/packages/pegboard-gateway/src/keepalive_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ use std::time::Duration;
use tokio::sync::watch;

use super::LifecycleResult;
use crate::shared_state::SharedState;

/// Periodically writes a keepalive to UDB. This is used to restore hibernating request IDs on
/// next actor start.
///
///Only ran for hibernating requests.
/// Only run for hibernating requests.

pub async fn task(
shared_state: SharedState,
ctx: StandaloneCtx,
actor_id: Id,
gateway_id: GatewayId,
Expand Down Expand Up @@ -43,11 +46,14 @@ pub async fn task(
let jitter = { rand::thread_rng().gen_range(0..128) };
tokio::time::sleep(Duration::from_millis(jitter)).await;

ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id,
gateway_id,
request_id,
})
.await?;
tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id,
gateway_id,
request_id,
}),
// Keep alive in flight req during hibernation
shared_state.keepalive_hws(request_id),
)?;
}
}
77 changes: 46 additions & 31 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl CustomServeTrait for PegboardGateway {
let InFlightRequestHandle {
mut msg_rx,
mut drop_rx,
..
} = self
.shared_state
.start_in_flight_request(tunnel_subject, request_id)
Expand Down Expand Up @@ -212,7 +213,7 @@ impl CustomServeTrait for PegboardGateway {
}
} else {
tracing::warn!(
request_id=?tunnel_id::request_id_to_string(&request_id),
request_id=%tunnel_id::request_id_to_string(&request_id),
"received no message response during request init",
);
break;
Expand Down Expand Up @@ -267,14 +268,14 @@ impl CustomServeTrait for PegboardGateway {
Ok(response)
}

#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id))]
#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id, request_id=%tunnel_id::request_id_to_string(&request_id)))]
async fn handle_websocket(
&self,
client_ws: WebSocketHandle,
headers: &hyper::HeaderMap,
_path: &str,
_request_context: &mut RequestContext,
unique_request_id: RequestId,
request_id: RequestId,
after_hibernation: bool,
) -> Result<Option<CloseFrame>> {
// Use the actor ID from the gateway instance
Expand All @@ -298,15 +299,20 @@ impl CustomServeTrait for PegboardGateway {
pegboard::pubsub_subjects::RunnerReceiverSubject::new(self.runner_id).to_string();

// Start listening for WebSocket messages
let request_id = unique_request_id;
let InFlightRequestHandle {
mut msg_rx,
mut drop_rx,
new,
} = self
.shared_state
.start_in_flight_request(tunnel_subject.clone(), request_id)
.await;

ensure!(
!after_hibernation || !new,
"should not be creating a new in flight entry after hibernation"
);

// If we are reconnecting after hibernation, don't send an open message
let can_hibernate = if after_hibernation {
true
Expand Down Expand Up @@ -348,7 +354,7 @@ impl CustomServeTrait for PegboardGateway {
}
} else {
tracing::warn!(
request_id=?tunnel_id::request_id_to_string(&request_id),
request_id=%tunnel_id::request_id_to_string(&request_id),
"received no message response during ws init",
);
break;
Expand Down Expand Up @@ -416,17 +422,23 @@ impl CustomServeTrait for PegboardGateway {
request_id,
ping_abort_rx,
));
let keepalive = if can_hibernate {
Some(tokio::spawn(keepalive_task::task(
self.shared_state.clone(),
self.ctx.clone(),
self.actor_id,
self.shared_state.gateway_id(),
request_id,
keepalive_abort_rx,
)))
} else {
None
};

let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
let ping_abort_tx2 = ping_abort_tx.clone();

// Clone variables needed for keepalive task
let ctx_clone = self.ctx.clone();
let actor_id_clone = self.actor_id;
let gateway_id_clone = self.shared_state.gateway_id();
let request_id_clone = request_id;

// Wait for all tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res) = tokio::join!(
async {
Expand Down Expand Up @@ -478,17 +490,9 @@ impl CustomServeTrait for PegboardGateway {
res
},
async {
if !can_hibernate {
let Some(keepalive) = keepalive else {
return Ok(LifecycleResult::Aborted);
}

let keepalive = tokio::spawn(keepalive_task::task(
ctx_clone,
actor_id_clone,
gateway_id_clone,
request_id_clone,
keepalive_abort_rx,
));
};

let res = keepalive.await?;

Expand Down Expand Up @@ -568,14 +572,12 @@ impl CustomServeTrait for PegboardGateway {
}
}

#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id))]
#[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, request_id=%tunnel_id::request_id_to_string(&request_id)))]
async fn handle_websocket_hibernation(
&self,
client_ws: WebSocketHandle,
unique_request_id: RequestId,
request_id: RequestId,
) -> Result<HibernationResult> {
let request_id = unique_request_id;

// Insert hibernating request entry before checking for pending messages
// This ensures the entry exists even if we immediately rewake the actor
self.ctx
Expand All @@ -592,21 +594,19 @@ impl CustomServeTrait for PegboardGateway {
.has_pending_websocket_messages(request_id)
.await?
{
tracing::debug!(
?unique_request_id,
"detected pending requests on websocket hibernation, rewaking actor"
);
tracing::debug!("exiting hibernating due to pending messages");

return Ok(HibernationResult::Continue);
}

// Start keepalive task
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
let keepalive_handle = tokio::spawn(keepalive_task::task(
self.shared_state.clone(),
self.ctx.clone(),
self.actor_id,
self.shared_state.gateway_id(),
unique_request_id,
request_id,
keepalive_abort_rx,
));

Expand All @@ -623,7 +623,7 @@ impl CustomServeTrait for PegboardGateway {
.op(pegboard::ops::actor::hibernating_request::delete::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id: unique_request_id,
request_id,
})
.await?;
}
Expand All @@ -643,6 +643,21 @@ impl PegboardGateway {
.subscribe::<pegboard::workflows::actor::Ready>(("actor_id", self.actor_id))
.await?;

// Fetch actor info after sub to prevent race condition
if let Some(actor) = self
.ctx
.op(pegboard::ops::actor::get_for_gateway::Input {
actor_id: self.actor_id,
})
.await?
{
if actor.runner_id.is_some() {
tracing::debug!("actor became ready during hibernation");

return Ok(HibernationResult::Continue);
}
}

let res = tokio::select! {
_ = ready_sub.next() => {
tracing::debug!("actor became ready during hibernation");
Expand Down
Loading
Loading