Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ impl CustomServeTrait for PegboardGateway {
// Determine single result from both tasks
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
// Prefer error
(_, Err(err)) => Err(err),
(Err(err), _) => Err(err),
(_, Err(err)) => Err(err),
// Prefer non aborted result if both succeed
(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),
Expand Down
18 changes: 10 additions & 8 deletions engine/packages/pegboard-runner/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ pub struct TunnelActiveRequest {
}

pub struct Conn {
pub namespace_id: Id,
pub runner_name: String,
pub runner_key: String,
pub runner_id: Id,

pub workflow_id: Id,

pub protocol_version: u16,

pub ws_handle: WebSocketHandle,

pub last_rtt: AtomicU32,

/// Active HTTP & WebSocket requests. They are separate but use the same mechanism to
Expand Down Expand Up @@ -63,7 +62,7 @@ pub async fn init_conn(
let mut ws_rx = ws_rx.lock().await;

// Receive init packet
let (runner_id, workflow_id) = if let Some(msg) =
let (runner_name, runner_id, workflow_id) = if let Some(msg) =
tokio::time::timeout(Duration::from_secs(5), ws_rx.next())
.await
.map_err(|_| WsError::TimedOutWaitingForInit.build())?
Expand All @@ -81,7 +80,7 @@ pub async fn init_conn(
.map_err(|err| WsError::InvalidPacket(err.to_string()).build())
.context("failed to deserialize initial packet from client")?;

let (runner_id, workflow_id) =
let (runner_name, runner_id, workflow_id) =
if let protocol::ToServer::ToServerInit(protocol::ToServerInit {
name,
version,
Expand Down Expand Up @@ -160,7 +159,7 @@ pub async fn init_conn(
)
})?;

(runner_id, workflow_id)
(name.clone(), runner_id, workflow_id)
} else {
tracing::debug!(?packet, "invalid initial packet");
return Err(WsError::InvalidInitialPacket("must be `ToServer::Init`").build());
Expand All @@ -178,12 +177,15 @@ pub async fn init_conn(
)
})?;

(runner_id, workflow_id)
(runner_name, runner_id, workflow_id)
} else {
return Err(WsError::ConnectionClosed.build());
};

Ok(Arc::new(Conn {
namespace_id: namespace.namespace_id,
runner_name,
runner_key,
runner_id,
workflow_id,
protocol_version,
Expand Down
158 changes: 126 additions & 32 deletions engine/packages/pegboard-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,26 @@ use rivet_guard_core::{
};
use rivet_runner_protocol as protocol;
use std::time::Duration;
use tokio::sync::watch;
use tokio_tungstenite::tungstenite::protocol::frame::{CloseFrame, coding::CloseCode};
use universalpubsub::PublishOpts;
use vbare::OwnedVersionedData;

mod client_to_pubsub_task;
mod conn;
mod errors;
mod ping_task;
mod pubsub_to_client_task;
mod tunnel_to_ws_task;
mod utils;
mod ws_to_tunnel_task;

/// Delay the ping task sleeps between iterations of its update loop (3 seconds).
const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);

/// Successful outcome reported by each of the connection's background tasks
/// (tunnel-to-ws, ws-to-tunnel and ping) when it finishes.
#[derive(Debug)]
enum LifecycleResult {
/// NOTE(review): presumably the task ended because the connection closed on its own; no
/// code visible here returns this variant, so confirm against the task implementations.
Closed,
/// The task stopped because its abort `watch` receiver fired.
Aborted,
}

/// Serve handler for runner WebSocket connections; implements `CustomServeTrait`.
pub struct PegboardRunnerWsCustomServe {
/// Context that is cloned into the spawned connection tasks.
ctx: StandaloneCtx,
}
Expand Down Expand Up @@ -79,52 +86,142 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
.await
.context("failed to initialize runner connection")?;

// Subscribe to pubsub topic for this runner before accepting the client websocket so
// that failures can be retried by the proxy.
// Subscribe before accepting the client websocket so that failures can be retried by the proxy.
let topic =
pegboard::pubsub_subjects::RunnerReceiverSubject::new(conn.runner_id).to_string();
tracing::debug!(%topic, "subscribing to runner receiver topic");
let eviction_topic =
pegboard::pubsub_subjects::RunnerEvictionByIdSubject::new(conn.runner_id).to_string();
let eviction_topic2 = pegboard::pubsub_subjects::RunnerEvictionByNameSubject::new(
conn.namespace_id,
&conn.runner_name,
&conn.runner_key,
)
.to_string();

tracing::debug!(%topic, %eviction_topic, %eviction_topic2, "subscribing to runner topics");
let sub = ups
.subscribe(&topic)
.await
.with_context(|| format!("failed to subscribe to runner receiver topic: {}", topic))?;
let mut eviction_sub = ups.subscribe(&eviction_topic).await.with_context(|| {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Join the subscribe futures (e.g. with `tokio::try_join!`) instead of awaiting them one at a time.

format!(
"failed to subscribe to runner eviction topic: {}",
eviction_topic
)
})?;
let mut eviction_sub2 = ups.subscribe(&eviction_topic2).await.with_context(|| {
format!(
"failed to subscribe to runner eviction topic: {}",
eviction_topic2
)
})?;

// Publish eviction message to evict any currently connected runners with the same id or ns id +
// runner name + runner key. This happens after subscribing to prevent race conditions.
tokio::try_join!(
async {
ups.publish(&eviction_topic, &[], PublishOpts::broadcast())
.await?;
// Because we will receive our own message, skip the first message in the sub
eviction_sub.next().await
},
async {
ups.publish(&eviction_topic2, &[], PublishOpts::broadcast())
.await?;
eviction_sub2.next().await
},
)?;

// Forward pubsub -> WebSocket
let mut pubsub_to_client = tokio::spawn(pubsub_to_client_task::task(conn.clone(), sub));
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
let (ping_abort_tx, ping_abort_rx) = watch::channel(());

let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo this naming is confusing. most of this file is tunnel logic, but it's also where all of the other ws<->pb forwarding logic happens.

conn.clone(),
sub,
eviction_sub,
tunnel_to_ws_abort_rx,
));

// Forward WebSocket -> pubsub
let mut client_to_pubsub = tokio::spawn(client_to_pubsub_task::task(
let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
self.ctx.clone(),
conn.clone(),
ws_handle.recv(),
eviction_sub2,
ws_to_tunnel_abort_rx,
));

// Update pings
let mut ping = tokio::spawn(ping_task::task(self.ctx.clone(), conn.clone()));
let ping = tokio::spawn(ping_task::task(
self.ctx.clone(),
conn.clone(),
ping_abort_rx,
));
let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
let ping_abort_tx2 = ping_abort_tx.clone();

// Wait for all tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
async {
let res = tunnel_to_ws.await?;

// Abort others if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "tunnel to ws task completed, aborting others");

let _ = ping_abort_tx.send(());
let _ = ws_to_tunnel_abort_tx.send(());
} else {
tracing::debug!(?res, "tunnel to ws task completed");
}

// Wait for either task to complete
let lifecycle_res = tokio::select! {
res = &mut pubsub_to_client => {
let res = res?;
tracing::debug!(?res, "pubsub to WebSocket task completed");
res
}
res = &mut client_to_pubsub => {
let res = res?;
tracing::debug!(?res, "WebSocket to pubsub task completed");
},
async {
let res = ws_to_tunnel.await?;

// Abort others if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ws to tunnel task completed, aborting others");

let _ = ping_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx.send(());
} else {
tracing::debug!(?res, "ws to tunnel task completed");
}

res
}
res = &mut ping => {
let res = res?;
tracing::debug!(?res, "ping task completed");
},
async {
let res = ping.await?;

// Abort others if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ping task completed, aborting others");

let _ = ws_to_tunnel_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx2.send(());
} else {
tracing::debug!(?res, "ping task completed");
}

res
}
};
);

// Abort remaining tasks
pubsub_to_client.abort();
client_to_pubsub.abort();
ping.abort();
// Determine single result from all three tasks
let lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
// Prefer error
(Err(err), _, _) => Err(err),
(_, Err(err), _) => Err(err),
(_, _, Err(err)) => Err(err),
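// Prefer the non-aborted result of the tunnel-to-ws / ws-to-tunnel tasks if both succeed
// (the ping task's successful result is not considered here)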
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
// Unlikely case
(res, _, _) => res,
};

// Make runner immediately ineligible when it disconnects
let update_alloc_res = self
Expand Down Expand Up @@ -177,10 +274,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe {
.context("failed to serialize tunnel message for gateway")?;

// Publish message to UPS
let res = self
.ctx
.ups()
.context("failed to get UPS instance for tunnel message")?
let res = ups
.publish(&req.gateway_reply_to, &msg_serialized, PublishOpts::one())
.await;

Expand Down
16 changes: 13 additions & 3 deletions engine/packages/pegboard-runner/src/ping_task.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
use gas::prelude::*;
use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility};
use std::sync::{Arc, atomic::Ordering};
use tokio::sync::watch;

use crate::{UPDATE_PING_INTERVAL, conn::Conn};
use crate::{LifecycleResult, UPDATE_PING_INTERVAL, conn::Conn};

/// Updates the ping of all runners requesting a ping update at once.
#[tracing::instrument(skip_all)]
pub async fn task(ctx: StandaloneCtx, conn: Arc<Conn>) -> Result<()> {
pub async fn task(
ctx: StandaloneCtx,
conn: Arc<Conn>,
mut ping_abort_rx: watch::Receiver<()>,
) -> Result<LifecycleResult> {
loop {
tokio::time::sleep(UPDATE_PING_INTERVAL).await;
tokio::select! {
_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
_ = ping_abort_rx.changed() => {
return Ok(LifecycleResult::Aborted);
}
}

let Some(wf) = ctx
.workflow::<pegboard::workflows::runner::Input>(conn.workflow_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,43 @@ use gas::prelude::*;
use hyper_tungstenite::tungstenite::Message as WsMessage;
use rivet_runner_protocol::{self as protocol, versioned};
use std::sync::Arc;
use tokio::sync::watch;
use universalpubsub::{NextOutput, Subscriber};
use vbare::OwnedVersionedData;

use crate::{
LifecycleResult,
conn::{Conn, TunnelActiveRequest},
errors,
};

#[tracing::instrument(skip_all, fields(runner_id=?conn.runner_id, workflow_id=?conn.workflow_id, protocol_version=%conn.protocol_version))]
pub async fn task(conn: Arc<Conn>, mut sub: Subscriber) -> Result<()> {
while let NextOutput::Message(ups_msg) = sub
.next()
.await
.context("pubsub_to_client_task sub failed")?
{
pub async fn task(
conn: Arc<Conn>,
mut sub: Subscriber,
mut eviction_sub: Subscriber,
mut tunnel_to_ws_abort_rx: watch::Receiver<()>,
) -> Result<LifecycleResult> {
loop {
let ups_msg = tokio::select! {
res = sub.next() => {
if let NextOutput::Message(ups_msg) = res.context("pubsub_to_client_task sub failed")? {
ups_msg
} else {
tracing::debug!("tunnel sub closed");
bail!("tunnel sub closed");
}
}
_ = eviction_sub.next() => {
tracing::debug!("runner evicted");
return Err(errors::WsError::Eviction.build());
}
_ = tunnel_to_ws_abort_rx.changed() => {
tracing::debug!("task aborted");
return Ok(LifecycleResult::Aborted);
}
};

tracing::debug!(
payload_len = ups_msg.payload.len(),
"received message from pubsub, forwarding to WebSocket"
Expand Down Expand Up @@ -105,6 +127,4 @@ pub async fn task(conn: Arc<Conn>, mut sub: Subscriber) -> Result<()> {
.await
.context("failed to send message to WebSocket")?
}

Ok(())
}
Loading
Loading