Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/pkg-pr-new.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
on:
pull_request:
push:
branches:
- main

jobs:
publish:
Expand All @@ -9,5 +12,5 @@ jobs:
- run: corepack enable
- uses: actions/setup-node@v4
- run: pnpm install
- run: pnpm build
- run: pnpm build -F '@rivetkit/*'
- run: pnpm dlx pkg-pr-new publish 'engine/sdks/typescript/runner/' 'engine/sdks/typescript/runner-protocol/' 'rivetkit-typescript/packages/*' --packageManager pnpm --template './examples/*'
13 changes: 7 additions & 6 deletions biome.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
"$schema": "https://biomejs.dev/schemas/2.1.1/schema.json",
"files": {
"includes": [
"**/*.js",
"**/*.json",
"**/*.ts",
"**/*.js",
"!examples/snippets",
"!rivetkit-openapi/openapi.json",
"!engine/artifacts",
"!website",
"!scripts",
"!frontend",
"!engine/sdks",
"!engine/sdks/typescript/api-full",
"!engine/sdks/typescript/runner-protocol"
"!examples/snippets",
"!frontend",
"!rivetkit-openapi/openapi.json",
"!scripts",
"!website"
],
"ignoreUnknown": true
},
Expand Down
3 changes: 0 additions & 3 deletions engine/docker/template/src/docker-compose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,11 @@ export function generateDockerCompose(context: TemplateContext) {
// If host networking is requested, set network_mode for all services
if (context.config.networkMode === "host") {
for (const svc of Object.values(dockerComposeConfig.services)) {
// @ts-expect-error - mutate dynamic service objects
svc.network_mode = "host";
// Remove networks field as it's incompatible with host networking
// @ts-expect-error
if (svc.networks) delete svc.networks;
// Remove ports since published ports are ignored with host networking
// and produce warnings in Docker Compose output.
// @ts-expect-error
if (svc.ports) delete svc.ports;
}
}
Expand Down
146 changes: 122 additions & 24 deletions engine/packages/guard/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
HeaderName::from_static("sec-websocket-protocol");
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";

#[derive(Debug, Clone)]
pub struct ActorPathInfo {
pub actor_id: String,
pub token: Option<String>,
pub remaining_path: String,
}

/// Creates the main routing function that handles all incoming requests
#[tracing::instrument(skip_all)]
pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn {
Expand All @@ -35,17 +42,35 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->

tracing::debug!("Routing request for hostname: {host}, path: {path}");

// Parse query parameters
let query_params = parse_query_params(path);

// Check if this is a WebSocket upgrade request
let is_websocket = headers
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

// Extract target from WebSocket protocol, HTTP header, or query param
// First, check if this is an actor path-based route
if let Some(actor_path_info) = parse_actor_path(path) {
tracing::debug!(?actor_path_info, "routing using path-based actor routing");

// Route to pegboard gateway with the extracted information
if let Some(routing_output) = pegboard_gateway::route_request_path_based(
&ctx,
&shared_state,
&actor_path_info.actor_id,
actor_path_info.token.as_deref(),
&actor_path_info.remaining_path,
headers,
is_websocket,
)
.await?
{
return Ok(routing_output);
}
}

// Fallback to header-based routing
// Extract target from WebSocket protocol or HTTP header
let target = if is_websocket {
// For WebSocket, parse the sec-websocket-protocol header
headers
Expand All @@ -58,21 +83,15 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
})
// Fallback to query parameter if protocol not provided
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
} else {
// For HTTP, use the x-rivet-target header, fallback to query param
headers
.get(X_RIVET_TARGET)
.and_then(|x| x.to_str().ok())
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
// For HTTP, use the x-rivet-target header
headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
};

// Read target
if let Some(target) = target {
if let Some(routing_output) =
runner::route_request(&ctx, target, host, path, headers, &query_params)
.await?
runner::route_request(&ctx, target, host, path, headers).await?
{
return Ok(routing_output);
}
Expand All @@ -85,7 +104,6 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
path,
headers,
is_websocket,
&query_params,
)
.await?
{
Expand Down Expand Up @@ -120,18 +138,98 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
)
}

/// Parse query parameters from a path string
fn parse_query_params(path: &str) -> std::collections::HashMap<String, String> {
let mut params = std::collections::HashMap::new();
/// Parse actor routing information from path
/// Matches patterns:
/// - /gateway/actors/{actor_id}/tokens/{token}/route/{...path}
/// - /gateway/actors/{actor_id}/route/{...path}
pub fn parse_actor_path(path: &str) -> Option<ActorPathInfo> {
// Find query string position (everything from ? onwards, but before fragment)
let query_pos = path.find('?');
let fragment_pos = path.find('#');

// Extract query string (excluding fragment)
let query_string = match (query_pos, fragment_pos) {
(Some(q), Some(f)) if q < f => &path[q..f],
(Some(q), None) => &path[q..],
_ => "",
};

// Extract base path (before query and fragment)
let base_path = match query_pos {
Some(pos) => &path[..pos],
None => match fragment_pos {
Some(pos) => &path[..pos],
None => path,
},
};

// Check for double slashes (invalid path)
if base_path.contains("//") {
return None;
}

// Split the path into segments
let segments: Vec<&str> = base_path.split('/').filter(|s| !s.is_empty()).collect();

// Check minimum required segments: gateway, actors, {actor_id}, route
if segments.len() < 4 {
return None;
}

// Verify the fixed segments
if segments[0] != "gateway" || segments[1] != "actors" {
return None;
}

// Check for empty actor_id
if segments[2].is_empty() {
return None;
}

if let Some(query_start) = path.find('?') {
// Strip fragment if present
let query = &path[query_start + 1..].split('#').next().unwrap_or("");
// Use url::form_urlencoded to properly decode query parameters
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
params.insert(key.into_owned(), value.into_owned());
let actor_id = segments[2].to_string();

// Check for token or direct route
let (token, remaining_path_start_idx) =
if segments.len() >= 6 && segments[3] == "tokens" && segments[5] == "route" {
// Pattern with token: /gateway/actors/{actor_id}/tokens/{token}/route/{...path}
// Check for empty token
if segments[4].is_empty() {
return None;
}
(Some(segments[4].to_string()), 6)
} else if segments.len() >= 4 && segments[3] == "route" {
// Pattern without token: /gateway/actors/{actor_id}/route/{...path}
(None, 4)
} else {
return None;
};

// Calculate the position in the original path where remaining path starts
let mut prefix_len = 0;
for (i, segment) in segments.iter().enumerate() {
if i >= remaining_path_start_idx {
break;
}
prefix_len += 1 + segment.len(); // +1 for the slash
}

params
// Extract the remaining path preserving trailing slashes
let remaining_base = if prefix_len < base_path.len() {
&base_path[prefix_len..]
} else {
"/"
};

// Ensure remaining path starts with /
let remaining_path = if remaining_base.is_empty() || !remaining_base.starts_with('/') {
format!("/{}{}", remaining_base, query_string)
} else {
format!("{}{}", remaining_base, query_string)
};

Some(ActorPathInfo {
actor_id,
token,
remaining_path,
})
}
48 changes: 37 additions & 11 deletions engine/packages/guard/src/routing/pegboard_gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,30 @@ use crate::{errors, shared_state::SharedState};

const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
pub const X_RIVET_AMESPACE: HeaderName = HeaderName::from_static("x-rivet-namespace");
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
const WS_PROTOCOL_TOKEN: &str = "rivet_token.";

/// Route requests to actor services based on hostname and path
/// Route requests to actor services using path-based routing
#[tracing::instrument(skip_all)]
pub async fn route_request_path_based(
ctx: &StandaloneCtx,
shared_state: &SharedState,
actor_id_str: &str,
_token: Option<&str>,
path: &str,
_headers: &hyper::HeaderMap,
_is_websocket: bool,
) -> Result<Option<RoutingOutput>> {
// NOTE: Token validation implemented in EE

// Parse actor ID
let actor_id = Id::parse(actor_id_str).context("invalid actor id in path")?;

route_request_inner(ctx, shared_state, actor_id, path).await
}

/// Route requests to actor services based on headers
#[tracing::instrument(skip_all)]
pub async fn route_request(
ctx: &StandaloneCtx,
Expand All @@ -22,14 +43,13 @@ pub async fn route_request(
path: &str,
headers: &hyper::HeaderMap,
is_websocket: bool,
query_params: &std::collections::HashMap<String, String>,
) -> Result<Option<RoutingOutput>> {
// Check target
if target != "actor" {
return Ok(None);
}

// Extract actor ID from WebSocket protocol, HTTP header, or query param
// Extract actor ID from WebSocket protocol or HTTP header
let actor_id_str = if is_websocket {
// For WebSocket, parse the sec-websocket-protocol header
headers
Expand All @@ -42,26 +62,22 @@ pub async fn route_request(
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
})
// Fallback to query parameter if protocol not provided
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
.ok_or_else(|| {
crate::errors::MissingHeader {
header: "`rivet_actor.*` protocol in sec-websocket-protocol or x_rivet_actor query parameter".to_string(),
header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(),
}
.build()
})?
} else {
// For HTTP, use the x-rivet-actor header, fallback to query param
// For HTTP, use the x-rivet-actor header
headers
.get(X_RIVET_ACTOR)
.map(|x| x.to_str())
.transpose()
.context("invalid x-rivet-actor header")?
// Fallback to query parameter if header not provided
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
.ok_or_else(|| {
crate::errors::MissingHeader {
header: format!("{} header or x_rivet_actor query parameter", X_RIVET_ACTOR),
header: X_RIVET_ACTOR.to_string(),
}
.build()
})?
Expand All @@ -70,6 +86,15 @@ pub async fn route_request(
// Find actor to route to
let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?;

route_request_inner(ctx, shared_state, actor_id, path).await
}

async fn route_request_inner(
ctx: &StandaloneCtx,
shared_state: &SharedState,
actor_id: Id,
path: &str,
) -> Result<Option<RoutingOutput>> {
// Route to peer dc where the actor lives
if actor_id.label() != ctx.config().dc_label() {
tracing::debug!(peer_dc_label=?actor_id.label(), "re-routing actor to peer dc");
Expand Down Expand Up @@ -189,11 +214,12 @@ pub async fn route_request(

tracing::debug!(?actor_id, ?runner_id, "actor ready");

// Return pegboard-gateway instance
// Return pegboard-gateway instance with path
let gateway = pegboard_gateway::PegboardGateway::new(
shared_state.pegboard_gateway.clone(),
runner_id,
actor_id,
path.to_string(),
);
Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new(
gateway,
Expand Down
Loading
Loading