diff --git a/packages/rivetkit/src/manager/gateway.ts b/packages/rivetkit/src/manager/gateway.ts index 239ce2d91..2027b6476 100644 --- a/packages/rivetkit/src/manager/gateway.ts +++ b/packages/rivetkit/src/manager/gateway.ts @@ -24,7 +24,9 @@ import { logger } from "./log"; * * Routes requests based on the Upgrade header: * - WebSocket requests: Uses sec-websocket-protocol for routing (target.actor, actor.{id}) + * with fallback to query parameters (x_rivet_target, x_rivet_actor, etc.) * - HTTP requests: Uses x-rivet-target and x-rivet-actor headers for routing + * with fallback to query parameters (x_rivet_target, x_rivet_actor) */ export async function actorGateway( runConfig: RunnerConfig, @@ -63,6 +65,7 @@ export async function actorGateway( /** * Handle WebSocket requests using sec-websocket-protocol for routing + * with fallback to query parameters */ async function handleWebSocketGateway( runConfig: RunnerConfig, @@ -75,6 +78,9 @@ async function handleWebSocketGateway( throw new WebSocketsNotEnabled(); } + // Parse query parameters for fallback + const queryParams = parseQueryParams(c.req.url); + // Parse configuration from Sec-WebSocket-Protocol header const protocols = c.req.header("sec-websocket-protocol"); let target: string | undefined; @@ -105,8 +111,19 @@ async function handleWebSocketGateway( } } + // Fallback to query parameters if not provided via protocols + target = target || queryParams.get("x_rivet_target"); + actorId = actorId || queryParams.get("x_rivet_actor"); + encodingRaw = encodingRaw || queryParams.get("x_rivet_encoding"); + connParamsRaw = connParamsRaw || queryParams.get("x_rivet_conn_params"); + connIdRaw = connIdRaw || queryParams.get("x_rivet_conn_id"); + connTokenRaw = connTokenRaw || queryParams.get("x_rivet_conn_token"); + if (target !== "actor") { - return c.text("WebSocket upgrade requires target.actor protocol", 400); + return c.text( + "WebSocket upgrade requires target.actor protocol or x_rivet_target=actor query parameter", + 400, + ); } if (!actorId) { @@ -141,6 +158,7 @@ async function handleWebSocketGateway( /** * Handle HTTP requests using x-rivet headers for routing + * with fallback to query parameters */ async function handleHttpGateway( managerDriver: ManagerDriver, @@ -148,8 +166,14 @@ async function handleHttpGateway( next: Next, strippedPath: string, ) { - const target = c.req.header(HEADER_RIVET_TARGET); - const actorId = c.req.header(HEADER_RIVET_ACTOR); + // Parse query parameters for fallback + const queryParams = parseQueryParams(c.req.url); + + // Try headers first, then fallback to query parameters + const target = + c.req.header(HEADER_RIVET_TARGET) || queryParams.get("x_rivet_target"); + const actorId = + c.req.header(HEADER_RIVET_ACTOR) || queryParams.get("x_rivet_actor"); if (target !== "actor") { return next(); @@ -186,6 +210,22 @@ async function handleHttpGateway( return await managerDriver.proxyRequest(c, proxyRequest, actorId); } +/** + * Parse query parameters from URL + */ +function parseQueryParams(url: string): Map { + const params = new Map(); + try { + const urlObj = new URL(url, "http://dummy"); // Use dummy base for relative URLs + urlObj.searchParams.forEach((value, key) => { + params.set(key, value); + }); + } catch { + // If URL parsing fails, return empty params + } + return params; +} + /** * Creates a WebSocket proxy for test endpoints that forwards messages between server and client WebSockets */