Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions packages/core/api-public/src/runner_configs/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Ok(response) => Json(response).into_response(),
Err(err) => ApiError::from(err).into_response(),
}
}

Check warning on line 35 in packages/core/api-public/src/runner_configs/delete.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/engine/engine/packages/core/api-public/src/runner_configs/delete.rs

#[tracing::instrument(skip_all)]
async fn delete_inner(
Expand Down Expand Up @@ -67,5 +67,23 @@
}
}

// Resolve namespace
let namespace = ctx
.op(namespace::ops::resolve_for_name_global::Input {
name: query.namespace.clone(),
})
.await?
.ok_or_else(|| namespace::errors::Namespace::NotFound.build())?;

// Purge cache
ctx.cache()
.clone()
.request()
.purge(
"namespace.runner_config.get",
vec![(namespace.namespace_id, path.runner_name.clone())],
)
.await?;

Ok(DeleteResponse {})
}
26 changes: 18 additions & 8 deletions packages/core/api-public/src/runner_configs/upsert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,18 @@ async fn upsert_inner(
}
}

// Resolve namespace
let namespace = ctx
.op(namespace::ops::resolve_for_name_global::Input {
name: query.namespace.clone(),
})
.await?
.ok_or_else(|| namespace::errors::Namespace::NotFound.build())?;

// Update runner metadata
//
// This allows us to populate the actor names immediately upon configuring a serverless runner
if let Some((url, metadata_headers)) = serverless_config {
// Resolve namespace
let namespace = ctx
.op(namespace::ops::resolve_for_name_global::Input {
name: query.namespace.clone(),
})
.await?
.ok_or_else(|| namespace::errors::Namespace::NotFound.build())?;

if let Err(err) = utils::refresh_runner_config_metadata(
ctx.clone(),
namespace.namespace_id,
Expand All @@ -150,5 +150,15 @@ async fn upsert_inner(
}
}

// Purge cache
ctx.cache()
.clone()
.request()
.purge(
"namespace.runner_config.get",
vec![(namespace.namespace_id, path.runner_name.clone())],
)
.await?;

Ok(UpsertResponse {})
}
35 changes: 34 additions & 1 deletion packages/core/guard/core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for
pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error");
const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
const WEBSOCKET_CLOSE_LINGER: Duration = Duration::from_millis(100); // Keep TCP connection open briefly after WebSocket close

/// Response body type that can handle both streaming and buffered responses
#[derive(Debug)]
Expand Down Expand Up @@ -944,7 +945,7 @@ impl ProxyService {
if !err.is_connect() || attempts >= max_attempts {
tracing::error!(?err, "Request error after {} attempts", attempts);
return Err(errors::UpstreamError(
"failed to connect to runner. Make sure your runners are healthy and the provided runner address is reachable by Rivet."
"Failed to connect to runner. Make sure your runners are healthy and do not have any crash logs."
.to_string(),
)
.build());
Expand Down Expand Up @@ -1799,6 +1800,12 @@ impl ProxyService {
})))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
}
Err(err) => {
Expand All @@ -1811,6 +1818,12 @@ impl ProxyService {
)))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
} else {
let backoff = ProxyService::calculate_backoff(
Expand Down Expand Up @@ -1841,6 +1854,12 @@ impl ProxyService {
),
)))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
}
Ok(ResolveRouteOutput::Target(_)) => {
ws_handle
Expand All @@ -1850,6 +1869,13 @@ impl ProxyService {
),
)))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
}
Err(err) => {
Expand All @@ -1858,6 +1884,13 @@ impl ProxyService {
err_to_close_frame(err),
)))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
}
}
Expand Down
13 changes: 13 additions & 0 deletions packages/core/guard/core/src/websocket_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ impl WebSocketHandleInner {
}
}

pub async fn flush(&self) -> Result<()> {
let mut state = self.state.lock().await;
match &mut *state {
WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => {
bail!("websocket has not been accepted");
}
WebSocketState::Split { ws_tx } => {
ws_tx.flush().await?;
Ok(())
}
}
}

async fn accept_inner(state: &mut WebSocketState) -> Result<WebSocketReceiver> {
if !matches!(*state, WebSocketState::Unaccepted { .. }) {
bail!("websocket already accepted")
Expand Down
1 change: 1 addition & 0 deletions packages/core/pegboard-serverless/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rivet-types.workspace = true
rivet-util.workspace = true
tracing.workspace = true
universaldb.workspace = true
universalpubsub.workspace = true
vbare.workspace = true

namespace.workspace = true
Expand Down
84 changes: 65 additions & 19 deletions packages/core/pegboard-serverless/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rivet_types::runner_configs::RunnerConfigKind;
use tokio::{sync::oneshot, task::JoinHandle, time::Duration};
use universaldb::options::StreamingMode;
use universaldb::utils::IsolationLevel::*;
use universalpubsub::PublishOpts;
use vbare::OwnedVersionedData;

const X_RIVET_ENDPOINT: HeaderName = HeaderName::from_static("x-rivet-endpoint");
Expand All @@ -27,6 +28,8 @@ const X_RIVET_TOTAL_SLOTS: HeaderName = HeaderName::from_static("x-rivet-total-s
const X_RIVET_RUNNER_NAME: HeaderName = HeaderName::from_static("x-rivet-runner-name");
const X_RIVET_NAMESPACE_ID: HeaderName = HeaderName::from_static("x-rivet-namespace-id");

const DRAIN_GRACE_PERIOD: Duration = Duration::from_secs(10);

struct OutboundConnection {
handle: JoinHandle<()>,
shutdown_tx: oneshot::Sender<()>,
Expand Down Expand Up @@ -377,12 +380,14 @@ async fn outbound_handler(
anyhow::Ok(())
};

let sleep_until_drop = request_lifespan.saturating_sub(DRAIN_GRACE_PERIOD);
tokio::select! {
res = stream_handler => return res.map_err(Into::into),
_ = tokio::time::sleep(request_lifespan) => {}
_ = tokio::time::sleep(sleep_until_drop) => {}
_ = shutdown_rx => {}
}

// Stop runner
draining.store(true, Ordering::SeqCst);

ctx.msg(rivet_types::msgs::pegboard::BumpServerlessAutoscaler {})
Expand All @@ -394,34 +399,56 @@ async fn outbound_handler(
}

// Continue waiting on req while draining
while let Some(event) = source.next().await {
match event {
Ok(sse::Event::Open) => {}
Ok(sse::Event::Message(msg)) => {
tracing::debug!(%msg.data, "received outbound req message");

// If runner_id is none at this point it means we did not send the stopping signal yet, so
// send it now
if runner_id.is_none() {
let data = BASE64.decode(msg.data).context("invalid base64 message")?;
let payload =
let wait_for_shutdown_fut = async move {
while let Some(event) = source.next().await {
match event {
Ok(sse::Event::Open) => {}
Ok(sse::Event::Message(msg)) => {
tracing::debug!(%msg.data, "received outbound req message");

// If runner_id is none at this point it means we did not send the stopping signal yet, so
// send it now
if runner_id.is_none() {
let data = BASE64.decode(msg.data).context("invalid base64 message")?;
let payload =
protocol::versioned::ToServerlessServer::deserialize_with_embedded_version(
&data,
)
.context("invalid payload")?;

match payload {
protocol::ToServerlessServer::ToServerlessServerInit(init) => {
let runner_id =
Id::parse(&init.runner_id).context("invalid runner id")?;
stop_runner(ctx, runner_id).await?;
match payload {
protocol::ToServerlessServer::ToServerlessServerInit(init) => {
let runner_id_local =
Id::parse(&init.runner_id).context("invalid runner id")?;
runner_id = Some(runner_id_local);
stop_runner(ctx, runner_id_local).await?;
}
}
}
}
Err(sse::Error::StreamEnded) => break,
Err(err) => return Err(err.into()),
}
Err(sse::Error::StreamEnded) => break,
Err(err) => return Err(err.into()),
}

Result::<()>::Ok(())
};

// Wait for runner to shut down
tokio::select! {
res = wait_for_shutdown_fut => return res.map_err(Into::into),
_ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => {
tracing::debug!("reached drain grace period before runner shut down")
}

}

// Close connection
//
// This will force the runner to stop the request in order to avoid hitting the serverless
// timeout threshold
if let Some(runner_id) = runner_id {
publish_to_client_stop(ctx, runner_id).await?;
}

tracing::debug!("outbound req stopped");
Expand Down Expand Up @@ -454,3 +481,22 @@ async fn stop_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> {

Ok(())
}

/// Send a stop message to the client.
///
/// This will close the runner's WebSocket..
async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> {
let receiver_subject =
pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string();

let message_serialized = rivet_runner_protocol::versioned::ToClient::latest(
rivet_runner_protocol::ToClient::ToClientClose,
)
.serialize_with_embedded_version(rivet_runner_protocol::PROTOCOL_VERSION)?;

ctx.ups()?
.publish(&receiver_subject, &message_serialized, PublishOpts::one())
.await?;

Ok(())
}
43 changes: 23 additions & 20 deletions packages/infra/engine/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Opts {

/// Exclude the specified services instead of including them
#[arg(long)]
exclude_services: bool,
except_services: Vec<ServiceKind>,
}

#[derive(clap::ValueEnum, Clone, PartialEq)]
Expand Down Expand Up @@ -55,34 +55,37 @@ impl Opts {
}

// Select services to run
let services = if self.services.is_empty() {
let services = if self.services.is_empty() && self.except_services.is_empty() {
// Run all services
run_config.services.clone()
} else if !self.except_services.is_empty() {
// Exclude specified services
let except_service_kinds = self
.except_services
.iter()
.map(|x| x.clone().into())
.collect::<Vec<rivet_service_manager::ServiceKind>>();

run_config
.services
.iter()
.filter(|x| !except_service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
} else {
// Filter services
// Include only specified services
let service_kinds = self
.services
.iter()
.map(|x| x.clone().into())
.collect::<Vec<rivet_service_manager::ServiceKind>>();

if self.exclude_services {
// Exclude specified services
run_config
.services
.iter()
.filter(|x| !service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
} else {
// Include only specified services
run_config
.services
.iter()
.filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
}
run_config
.services
.iter()
.filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
};

// Start server
Expand Down
Loading
Loading