Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 39 additions & 93 deletions engine/packages/pegboard-gateway/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

pub struct SharedStateInner {
ups: PubSub,
gateway_id: Uuid,
receiver_subject: String,
in_flight_requests: HashMap<RequestId, InFlightRequest>,
}
Expand All @@ -74,6 +75,7 @@

Self(Arc::new(SharedStateInner {
ups,
gateway_id,
receiver_subject,
in_flight_requests: HashMap::new(),
}))
Expand Down Expand Up @@ -160,6 +162,7 @@
};

let payload = protocol::ToClientTunnelMessage {
gateway_id: *self.gateway_id.as_bytes(),
request_id: request_id.clone(),
message_id,
// Only send reply to subject on the first message for this request. This reduces
Expand All @@ -179,8 +182,8 @@
});

// Send message
let message = protocol::ToClient::ToClientTunnelMessage(payload);
let message_serialized = versioned::ToClient::wrap_latest(message)
let message = protocol::ToRunner::ToClientTunnelMessage(payload);
let message_serialized = versioned::ToRunner::wrap_latest(message)
.serialize_with_embedded_version(PROTOCOL_VERSION)?;

if let (Some(hs), Some(ws_msg_index)) = (&mut req.hibernation_state, ws_msg_index) {
Expand Down Expand Up @@ -221,105 +224,48 @@
);

match versioned::ToGateway::deserialize_with_embedded_version(&msg.payload) {
Ok(protocol::ToGateway { message: msg }) => {
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"successfully deserialized message"
);

let Some(mut in_flight) =
self.in_flight_requests.get_async(&msg.request_id).await
Ok(protocol::ToGateway::ToGatewayKeepAlive) => {
// TODO:
// let prev_len = in_flight.pending_msgs.len();
//
// tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack");
//
// in_flight
// .pending_msgs
// .retain(|m| m.message_id != msg.message_id);
//
// if prev_len == in_flight.pending_msgs.len() {
// tracing::warn!(
// request_id=?Uuid::from_bytes(msg.request_id),
// message_id=?Uuid::from_bytes(msg.message_id),
// "pending message does not exist or ack received after message body"
// )
// } else {
// tracing::debug!(
// request_id=?Uuid::from_bytes(msg.request_id),
// message_id=?Uuid::from_bytes(msg.message_id),
// "received TunnelAck, removed from pending"
// );
// }
}
Ok(protocol::ToGateway::ToServerTunnelMessage(msg)) => {
let Some(in_flight) = self.in_flight_requests.get_async(&msg.request_id).await
else {
tracing::warn!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"in flight has already been disconnected, cannot ack message"
"in flight has already been disconnected, dropping message"
);
continue;
};

if let protocol::ToServerTunnelMessageKind::TunnelAck = &msg.message_kind {
let prev_len = in_flight.pending_msgs.len();

tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack");

in_flight
.pending_msgs
.retain(|m| m.message_id != msg.message_id);

if prev_len == in_flight.pending_msgs.len() {
tracing::warn!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"pending message does not exist or ack received after message body"
)
} else {
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"received TunnelAck, removed from pending"
);
}
} else {
// Send message to the request handler to emulate the real network action
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"forwarding message to request handler"
);
let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await;

// Send ack back to runner
let ups_clone = self.ups.clone();
let receiver_subject = in_flight.receiver_subject.clone();
let request_id = msg.request_id;
let message_id = msg.message_id;
let ack_message = protocol::ToClient::ToClientTunnelMessage(
protocol::ToClientTunnelMessage {
request_id,
message_id,
gateway_reply_to: None,
message_kind: protocol::ToClientTunnelMessageKind::TunnelAck,
},
);
let ack_message_serialized =
match versioned::ToClient::wrap_latest(ack_message)
.serialize_with_embedded_version(PROTOCOL_VERSION)
{
Ok(x) => x,
Err(err) => {
tracing::error!(?err, "failed to serialize ack");
continue;
}
};
tokio::spawn(async move {
match ups_clone
.publish(
&receiver_subject,
&ack_message_serialized,
PublishOpts::one(),
)
.await
{
Ok(_) => {
tracing::debug!(
request_id=?Uuid::from_bytes(request_id),
message_id=?Uuid::from_bytes(message_id),
"sent TunnelAck to runner"
);
}
Err(err) => {
tracing::warn!(
?err,
request_id=?Uuid::from_bytes(request_id),
message_id=?Uuid::from_bytes(message_id),
"failed to send TunnelAck to runner"
);
}
}
});
}
// Send message to the request handler to emulate the real network action
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"forwarding message to request handler"
);
let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await;
}
Err(err) => {
tracing::error!(?err, "failed to parse message");
Expand Down Expand Up @@ -467,9 +413,9 @@
/// Gateway channel is closed and there are no pending messages
GatewayClosed,
/// Any tunnel message not acked (TunnelAck)
MessageNotAcked { message_id: Uuid },

Check failure on line 416 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read

Check failure on line 416 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read

Check failure on line 416 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read

Check failure on line 416 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read
/// WebSocket pending messages (ToServerWebSocketMessageAck)
WebSocketMessageNotAcked { last_ws_msg_index: u16 },

Check failure on line 418 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read

Check failure on line 418 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read

Check failure on line 418 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read

Check failure on line 418 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read
}

let now = Instant::now();
Expand Down
13 changes: 0 additions & 13 deletions engine/packages/pegboard-runner/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,13 @@ use rivet_guard_core::WebSocketHandle;
use rivet_runner_protocol as protocol;
use rivet_runner_protocol::*;
use std::{
collections::HashMap,
sync::{Arc, atomic::AtomicU32},
time::Duration,
};
use tokio::sync::Mutex;
use vbare::OwnedVersionedData;

use crate::{errors::WsError, utils::UrlData};

pub struct TunnelActiveRequest {
/// Subject to send replies to.
pub gateway_reply_to: String,
pub is_ws: bool,
}

pub struct Conn {
pub namespace_id: Id,
pub runner_name: String,
Expand All @@ -32,10 +24,6 @@ pub struct Conn {
pub protocol_version: u16,
pub ws_handle: WebSocketHandle,
pub last_rtt: AtomicU32,

/// Active HTTP & WebSocket requests. They are separate but use the same mechanism to
/// maintain state.
pub tunnel_active_requests: Mutex<HashMap<RequestId, TunnelActiveRequest>>,
}

#[tracing::instrument(skip_all)]
Expand Down Expand Up @@ -191,6 +179,5 @@ pub async fn init_conn(
protocol_version,
ws_handle,
last_rtt: AtomicU32::new(0),
tunnel_active_requests: Mutex::new(HashMap::new()),
}))
}
37 changes: 0 additions & 37 deletions engine/packages/pegboard-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ use rivet_guard_core::{
WebSocketHandle, custom_serve::CustomServeTrait, proxy_service::ResponseBody,
request_context::RequestContext,
};
use rivet_runner_protocol as protocol;
use std::time::Duration;
use tokio::sync::watch;
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
use universalpubsub::PublishOpts;
use vbare::OwnedVersionedData;

mod conn;
mod errors;
Expand Down Expand Up @@ -243,41 +241,6 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
);
}

// Send close messages to all remaining active requests
let active_requests = conn.tunnel_active_requests.lock().await;
for (request_id, req) in &*active_requests {
// Websockets are not ephemeral like requests. If the runner ws closes they are not informed;
// instead they wait for the actor itself to stop.
if req.is_ws {
continue;
}

let close_message = protocol::ToServerTunnelMessage {
request_id: request_id.clone(),
message_id: Uuid::new_v4().into_bytes(),
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseAbort,
};

let msg_serialized = protocol::versioned::ToGateway::wrap_latest(protocol::ToGateway {
message: close_message.clone(),
})
.serialize_with_embedded_version(protocol::PROTOCOL_VERSION)
.context("failed to serialize tunnel message for gateway")?;

// Publish message to UPS
let res = ups
.publish(&req.gateway_reply_to, &msg_serialized, PublishOpts::one())
.await;

if let Err(err) = res {
tracing::warn!(
?err,
%req.gateway_reply_to,
"error sending close message to remaining active requests"
);
}
}

// This will determine the close frame sent back to the runner websocket
lifecycle_res.map(|_| None)
}
Expand Down
Loading
Loading