From 75233814344b86ea3484bf9e6f7b5cf049561c67 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Thu, 25 Sep 2025 20:10:08 -0700 Subject: [PATCH 1/3] chore(core): reimpl reconnect logic --- .../cloudflare-workers/src/manager-driver.ts | 10 ++ packages/rivetkit/src/actor/conn-socket.ts | 6 + packages/rivetkit/src/actor/conn.ts | 81 +++------ packages/rivetkit/src/actor/instance.ts | 91 +++++++++- packages/rivetkit/src/actor/persisted.ts | 2 + .../rivetkit/src/actor/router-endpoints.ts | 158 +++++++++++++----- packages/rivetkit/src/actor/router.ts | 10 ++ packages/rivetkit/src/client/actor-conn.ts | 25 ++- .../src/common/actor-router-consts.ts | 2 + packages/rivetkit/src/driver-helpers/mod.ts | 2 + .../test-inline-client-driver.ts | 2 + .../src/drivers/engine/actor-driver.ts | 3 + .../src/drivers/file-system/manager.ts | 7 + packages/rivetkit/src/manager/driver.ts | 2 + .../actor-websocket-client.ts | 21 ++- .../rivetkit/src/remote-manager-driver/mod.ts | 4 + 16 files changed, 311 insertions(+), 115 deletions(-) create mode 100644 packages/rivetkit/src/actor/conn-socket.ts diff --git a/packages/cloudflare-workers/src/manager-driver.ts b/packages/cloudflare-workers/src/manager-driver.ts index 1edc785ab..6948e3858 100644 --- a/packages/cloudflare-workers/src/manager-driver.ts +++ b/packages/cloudflare-workers/src/manager-driver.ts @@ -9,7 +9,9 @@ import { type ManagerDisplayInformation, type ManagerDriver, WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_STANDARD, WS_PROTOCOL_TARGET, @@ -70,6 +72,8 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { const env = getCloudflareAmbientEnv(); @@ -93,6 +97,12 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, ); } + if (connId) { + protocols.push(`${WS_PROTOCOL_CONN_ID}${connId}`); + } + if (connToken) { + protocols.push(`${WS_PROTOCOL_CONN_TOKEN}${connToken}`); + } const headers: Record = { Upgrade: "websocket", diff --git a/packages/rivetkit/src/actor/conn-socket.ts b/packages/rivetkit/src/actor/conn-socket.ts new file mode 100644 index 000000000..ef5b1f095 --- /dev/null +++ b/packages/rivetkit/src/actor/conn-socket.ts @@ -0,0 +1,6 @@ +import type { ConnDriverState } from "./conn-drivers"; + +export interface ConnSocket { + socketId: string; + driverState: ConnDriverState; +} diff --git a/packages/rivetkit/src/actor/conn.ts b/packages/rivetkit/src/actor/conn.ts index 8c1c1a95b..c55a5c59e 100644 --- a/packages/rivetkit/src/actor/conn.ts +++ b/packages/rivetkit/src/actor/conn.ts @@ -9,6 +9,7 @@ import { ConnReadyState, getConnDriverFromState, } from "./conn-drivers"; +import type { ConnSocket } from "./conn-socket"; import type { AnyDatabaseProvider } from "./database"; import * as errors from "./errors"; import type { ActorInstance } from "./instance"; @@ -24,14 +25,16 @@ export function generateConnToken(): string { return generateSecureToken(32); } +export function generateConnSocketId(): string { + return crypto.randomUUID(); +} + export type ConnId = string; export type AnyConn = Conn; export type ConnectionStatus = "connected" | "reconnecting"; -export const CONNECTION_CHECK_LIVENESS_SYMBOL = Symbol("checkLiveness"); - /** * Represents a client connection to a actor. * @@ -45,8 +48,6 @@ export class Conn { // TODO: Remove this cyclical reference #actor: ActorInstance; - #status: ConnectionStatus = "connected"; - /** * The proxied state that notifies of changes automatically. * @@ -54,10 +55,24 @@ export class Conn { */ __persist: PersistedConn; + get __driverState(): ConnDriverState | undefined { + return this.__socket?.driverState; + } + /** - * Driver used to manage connection. If undefined, there is no connection connected. + * Socket connected to this connection. + * + * If undefined, then nothing is connected to this. */ - __driverState?: ConnDriverState; + __socket?: ConnSocket; + + get __status(): ConnectionStatus { + if (this.__socket) { + return "connected"; + } else { + return "reconnecting"; + } + } public get params(): CP { return this.__persist.params; @@ -106,7 +121,7 @@ export class Conn { * Status of the connection. */ public get status(): ConnectionStatus { - return this.#status; + return this.__status; } /** @@ -199,8 +214,6 @@ export class Conn { * @param reason - The reason for disconnection. */ public async disconnect(reason?: string) { - this.#status = "reconnecting"; - if (this.__driverState) { const driver = getConnDriverFromState(this.__driverState); if (driver.disconnect) { @@ -217,55 +230,7 @@ export class Conn { conn: this.id, }); } - } - - /** - * This method checks the connection's liveness by querying the driver for its ready state. - * If the connection is not closed, it updates the last liveness timestamp and returns `true`. - * Otherwise, it returns `false`. - * @internal - */ - [CONNECTION_CHECK_LIVENESS_SYMBOL]() { - let readyState: ConnReadyState | undefined; - - if (this.__driverState) { - const driver = getConnDriverFromState(this.__driverState); - readyState = driver.getConnectionReadyState( - this.#actor, - this, - this.__driverState, - ); - } - - const isConnectionClosed = - readyState === ConnReadyState.CLOSED || - readyState === ConnReadyState.CLOSING || - readyState === undefined; - - const newLastSeen = Date.now(); - const newStatus = isConnectionClosed ? "reconnecting" : "connected"; - - this.#actor.rLog.debug({ - msg: "liveness probe for connection", - connId: this.id, - actorId: this.#actor.id, - readyState, - - status: this.#status, - newStatus, - - lastSeen: this.__persist.lastSeen, - currentTs: newLastSeen, - }); - - if (!isConnectionClosed) { - this.__persist.lastSeen = newLastSeen; - } - this.#status = newStatus; - return { - status: this.#status, - lastSeen: this.__persist.lastSeen, - }; + this.__socket = undefined; } } diff --git a/packages/rivetkit/src/actor/instance.ts b/packages/rivetkit/src/actor/instance.ts index e4c10611b..05ac259fa 100644 --- a/packages/rivetkit/src/actor/instance.ts +++ b/packages/rivetkit/src/actor/instance.ts @@ -22,8 +22,9 @@ import { } from "@/utils"; import type { ActionContext } from "./action"; import type { ActorConfig, OnConnectOptions } from "./config"; -import { CONNECTION_CHECK_LIVENESS_SYMBOL, Conn, type ConnId } from "./conn"; +import { Conn, type ConnId } from "./conn"; import type { ConnDriver, ConnDriverState } from "./conn-drivers"; +import type { ConnSocket } from "./conn-socket"; import { ActorContext } from "./context"; import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; import type { ActorDriver } from "./driver"; @@ -786,7 +787,23 @@ export class ActorInstance { * * If not a clean disconnect, will keep the connection alive for a given interval to wait for reconnect. */ - __connDisconnected(conn: Conn, wasClean: boolean) { + __connDisconnected( + conn: Conn, + wasClean: boolean, + socketId: string, + ) { + // If socket ID is provided, check if it matches the current socket ID + // If it doesn't match, this is a stale disconnect event from an old socket + if (socketId && conn.__socket && socketId !== conn.__socket.socketId) { + this.rLog.debug({ + msg: "ignoring stale disconnect event", + connId: conn.id, + eventSocketId: socketId, + currentSocketId: conn.__socket.socketId, + }); + return; + } + if (wasClean) { // Disconnected cleanly, remove the conn @@ -798,7 +815,11 @@ export class ActorInstance { this.rLog.warn("called conn disconnected without driver state"); } - conn.__driverState = undefined; + // Update last seen so we know when to clean it up + conn.__persist.lastSeen = Date.now(); + + // Remove socket + conn.__socket = undefined; } } @@ -915,7 +936,7 @@ export class ActorInstance { connectionToken: string, params: CP, state: CS, - driverState: ConnDriverState, + socket: ConnSocket, ): Promise> { this.#assertReady(); @@ -933,7 +954,7 @@ export class ActorInstance { subscriptions: [], }; const conn = new Conn(this, persist); - conn.__driverState = driverState; + conn.__socket = socket; this.#connections.set(conn.id, conn); // Update sleep @@ -991,6 +1012,61 @@ export class ActorInstance { return conn; } + /** + * Reconnect an existing connection with a new driver state. + */ + async reconnectConn( + connectionId: string, + connectionToken: string, + socket: ConnSocket, + ): Promise> { + this.#assertReady(); + + // Find existing connection by ID + const existingConn = this.#connections.get(connectionId); + if (!existingConn) { + throw new errors.UserError(`Connection not found: ${connectionId}`); + } + + // Validate connection token + if (existingConn._token !== connectionToken) { + throw new errors.UserError("Invalid connection token"); + } + + // If there's an existing driver state, disconnect it first + if (existingConn.__driverState) { + await existingConn.disconnect("Reconnecting with new driver state"); + } + + // Update with new driver state + existingConn.__socket = socket; + existingConn.__persist.lastSeen = Date.now(); + + // Update sleep timer since connection is now active + this.#resetSleepTimer(); + + this.inspector.emitter.emit("connectionUpdated"); + + // Send init message for reconnection + existingConn._sendMessage( + new CachedSerializer( + { + body: { + tag: "Init", + val: { + actorId: this.id, + connectionId: existingConn.id, + connectionToken: existingConn._token, + }, + }, + }, + TO_CLIENT_VERSIONED, + ), + ); + + return existingConn; + } + // MARK: Messages async processMessage( message: protocol.ToServer, @@ -1119,11 +1195,10 @@ export class ActorInstance { this.#rLog.debug({ msg: "checking connections liveness" }); for (const conn of this.#connections.values()) { - const liveness = conn[CONNECTION_CHECK_LIVENESS_SYMBOL](); - if (liveness.status === "connected") { + if (conn.__status === "connected") { this.#rLog.debug({ msg: "connection is alive", connId: conn.id }); } else { - const lastSeen = liveness.lastSeen; + const lastSeen = conn.__persist.lastSeen; const sinceLastSeen = Date.now() - lastSeen; if (sinceLastSeen < this.#config.options.connectionLivenessTimeout) { this.#rLog.debug({ diff --git a/packages/rivetkit/src/actor/persisted.ts b/packages/rivetkit/src/actor/persisted.ts index 598a0ebed..40bfb66ff 100644 --- a/packages/rivetkit/src/actor/persisted.ts +++ b/packages/rivetkit/src/actor/persisted.ts @@ -14,6 +14,8 @@ export interface PersistedConn { params: CP; state: CS; subscriptions: PersistedSubscription[]; + + /** Last time the socket was seen. This is set when disconencted so we can determine when we need to clean this up. */ lastSeen: number; } diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index 98b782aa6..73687f211 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -5,14 +5,21 @@ import type { WSContext } from "hono/ws"; import invariant from "invariant"; import { ActionContext } from "@/actor/action"; import type { AnyConn } from "@/actor/conn"; -import { generateConnId, generateConnToken } from "@/actor/conn"; +import { + generateConnId, + generateConnSocketId, + generateConnToken, +} from "@/actor/conn"; +import { ConnDriverKind } from "@/actor/conn-drivers"; import * as errors from "@/actor/errors"; import type { AnyActorInstance } from "@/actor/instance"; import type { InputData } from "@/actor/protocol/serde"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { HEADER_ACTOR_QUERY, + HEADER_CONN_ID, HEADER_CONN_PARAMS, + HEADER_CONN_TOKEN, HEADER_ENCODING, } from "@/common/actor-router-consts"; import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2"; @@ -32,7 +39,6 @@ import { serializeWithEncoding, } from "@/serde"; import { bufferToArrayBuffer, promiseWithResolvers } from "@/utils"; -import { ConnDriverKind } from "./conn-drivers"; import type { ActorDriver } from "./driver"; import { loggerWithoutContext } from "./log"; import { parseMessage } from "./protocol/old"; @@ -105,6 +111,8 @@ export async function handleWebSocketConnect( actorId: string, encoding: Encoding, parameters: unknown, + connId?: string, + connToken?: string, ): Promise { const exposeInternalError = req ? getRequestExposeInternalError(req) : false; @@ -147,6 +155,7 @@ export async function handleWebSocketConnect( // Promise used to wait for the websocket close in `disconnect` const closePromise = promiseWithResolvers(); + const socketId = generateConnSocketId(); return { onOpen: (_evt: any, ws: WSContext) => { @@ -155,33 +164,56 @@ export async function handleWebSocketConnect( // Run async operations in background (async () => { try { - const connId = generateConnId(); - const connToken = generateConnToken(); - const connState = await actor.prepareConn(parameters, req); - - // Save socket - actor.rLog.debug({ - msg: "registered websocket for conn", - actorId, - }); + let conn: AnyConn; + + // Check if this is a reconnection + if (connId && connToken) { + // This is a reconnection - use the existing connection + actor.rLog.debug({ msg: "websocket reconnection attempt", connId }); + + conn = await actor.reconnectConn(connId, connToken, { + socketId, + driverState: { + [ConnDriverKind.WEBSOCKET]: { + encoding, + websocket: ws, + closePromise, + }, + }, + }); + } else { + // This is a new connection + const newConnId = generateConnId(); + const newConnToken = generateConnToken(); + const connState = await actor.prepareConn(parameters, req); + + // Save socket + actor.rLog.debug({ + msg: "registered websocket for conn", + actorId, + }); - // Create connection - const conn = await actor.createConn( - connId, - connToken, - parameters, - connState, - { - [ConnDriverKind.WEBSOCKET]: { - encoding, - websocket: ws, - closePromise, + // Create connection + conn = await actor.createConn( + newConnId, + newConnToken, + parameters, + connState, + { + socketId, + driverState: { + [ConnDriverKind.WEBSOCKET]: { + encoding, + websocket: ws, + closePromise, + }, + }, }, - }, - ); + ); + } // Unblock other handlers - handlersResolve({ conn, actor, connId }); + handlersResolve({ conn, actor, connId: conn.id }); } catch (error) { handlersReject(error); @@ -280,7 +312,7 @@ export async function handleWebSocketConnect( // Handle cleanup asynchronously handlersPromise .then(({ conn, actor }) => { - actor.__connDisconnected(conn, event.wasClean); + actor.__connDisconnected(conn, event.wasClean, socketId); }) .catch((error) => { deconstructError( @@ -320,31 +352,60 @@ export async function handleSseConnect( const encoding = getRequestEncoding(c.req); const parameters = getRequestConnParams(c.req); + const socketId = generateConnSocketId(); + + // Check for reconnection parameters + const connId = c.req.header(HEADER_CONN_ID); + const connToken = c.req.header(HEADER_CONN_TOKEN); // Return the main handler with all async work inside return streamSSE(c, async (stream) => { let actor: AnyActorInstance | undefined; - let connId: string | undefined; - let connToken: string | undefined; - let connState: unknown; let conn: AnyConn | undefined; try { // Do all async work inside the handler actor = await actorDriver.loadActor(actorId); - connId = generateConnId(); - connToken = generateConnToken(); - connState = await actor.prepareConn(parameters, c.req.raw); - - actor.rLog.debug("sse open"); - - // Create connection - conn = await actor.createConn(connId, connToken, parameters, connState, { - [ConnDriverKind.SSE]: { - encoding, - stream: stream, - }, - }); + + // Check if this is a reconnection + if (connId && connToken) { + // This is a reconnection - use the existing connection + actor.rLog.debug({ msg: "sse reconnection attempt", connId }); + + conn = await actor.reconnectConn(connId, connToken, { + socketId, + driverState: { + [ConnDriverKind.SSE]: { + encoding, + stream: stream, + }, + }, + }); + } else { + // This is a new connection + const newConnId = generateConnId(); + const newConnToken = generateConnToken(); + const connState = await actor.prepareConn(parameters, c.req.raw); + + actor.rLog.debug("sse open"); + + // Create connection + conn = await actor.createConn( + newConnId, + newConnToken, + parameters, + connState, + { + socketId, + driverState: { + [ConnDriverKind.SSE]: { + encoding, + stream: stream, + }, + }, + }, + ); + } // Wait for close const abortResolver = promiseWithResolvers(); @@ -363,7 +424,7 @@ export async function handleSseConnect( // Cleanup if (conn) { - actor.__connDisconnected(conn, false); + actor.__connDisconnected(conn, false, socketId); } abortResolver.resolve(undefined); @@ -399,7 +460,7 @@ export async function handleSseConnect( // Cleanup on error if (conn && actor !== undefined) { - actor.__connDisconnected(conn, false); + actor.__connDisconnected(conn, false, socketId); } // Close the stream on error @@ -429,6 +490,7 @@ export async function handleAction( HTTP_ACTION_REQUEST_VERSIONED, ); const actionArgs = cbor.decode(new Uint8Array(request.args)); + const socketId = generateConnSocketId(); // Invoke the action let actor: AnyActorInstance | undefined; @@ -446,7 +508,10 @@ export async function handleAction( generateConnToken(), parameters, connState, - { [ConnDriverKind.HTTP]: {} }, + { + socketId, + driverState: { [ConnDriverKind.HTTP]: {} }, + }, ); // Call action @@ -454,7 +519,8 @@ export async function handleAction( output = await actor.executeAction(ctx, actionName, actionArgs); } finally { if (conn) { - actor?.__connDisconnected(conn, true); + // HTTP connections don't have persistent sockets, so no socket ID needed + actor?.__connDisconnected(conn, true, socketId); } } diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index b36c158bf..3079eddd6 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -23,7 +23,9 @@ import { HEADER_ENCODING, PATH_CONNECT_WEBSOCKET, PATH_RAW_WEBSOCKET_PREFIX, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_TOKEN, } from "@/common/actor-router-consts"; @@ -88,6 +90,8 @@ export function createActorRouter( const protocols = c.req.header("sec-websocket-protocol"); let encodingRaw: string | undefined; let connParamsRaw: string | undefined; + let connIdRaw: string | undefined; + let connTokenRaw: string | undefined; if (protocols) { const protocolList = protocols.split(",").map((p) => p.trim()); @@ -98,6 +102,10 @@ export function createActorRouter( connParamsRaw = decodeURIComponent( protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), ); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_ID)) { + connIdRaw = protocol.substring(WS_PROTOCOL_CONN_ID.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_TOKEN)) { + connTokenRaw = protocol.substring(WS_PROTOCOL_CONN_TOKEN.length); } } } @@ -114,6 +122,8 @@ export function createActorRouter( c.env.actorId, encoding, connParams, + connIdRaw, + connTokenRaw, ); })(c, noopNext()); } else { diff --git a/packages/rivetkit/src/client/actor-conn.ts b/packages/rivetkit/src/client/actor-conn.ts index 9a3c21327..584959a46 100644 --- a/packages/rivetkit/src/client/actor-conn.ts +++ b/packages/rivetkit/src/client/actor-conn.ts @@ -99,7 +99,7 @@ export class ActorConnRaw { /** If attempting to connect. Helpful for knowing if in a retry loop when reconnecting. */ #connecting = false; - // These will only be set on SSE driver + // Connection info, used for reconnection and HTTP requests #actorId?: string; #connectionId?: string; #connectionToken?: string; @@ -282,11 +282,24 @@ enc this.#actorQuery, this.#driver, ); + + // Check if we have connection info for reconnection + const isReconnection = this.#connectionId && this.#connectionToken; + if (isReconnection) { + logger().debug({ + msg: "attempting websocket reconnection", + connectionId: this.#connectionId, + }); + } + const ws = await this.#driver.openWebSocket( PATH_CONNECT_WEBSOCKET, actorId, this.#encoding, this.#params, + // Pass connection ID and token for reconnection if available + isReconnection ? this.#connectionId : undefined, + isReconnection ? this.#connectionToken : undefined, ); this.#transport = { websocket: ws }; ws.addEventListener("open", () => { @@ -321,6 +334,8 @@ enc encoding: this.#encoding, }); + const isReconnection = this.#connectionId && this.#connectionToken; + const eventSource = new EventSource("http://actor/connect/sse", { fetch: (input, init) => { return this.#driver.sendRequest( @@ -334,6 +349,12 @@ enc ...(this.#params !== undefined ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.#params) } : {}), + ...(isReconnection + ? { + [HEADER_CONN_ID]: this.#connectionId, + [HEADER_CONN_TOKEN]: this.#connectionToken, + } + : {}), }, }), ); @@ -403,7 +424,7 @@ enc ); if (response.body.tag === "Init") { - // This is only called for SSE + // Store connection info for reconnection this.#actorId = response.body.val.actorId; this.#connectionId = response.body.val.connectionId; this.#connectionToken = response.body.val.connectionToken; diff --git a/packages/rivetkit/src/common/actor-router-consts.ts b/packages/rivetkit/src/common/actor-router-consts.ts index 0abd47204..9f8be14a7 100644 --- a/packages/rivetkit/src/common/actor-router-consts.ts +++ b/packages/rivetkit/src/common/actor-router-consts.ts @@ -31,6 +31,8 @@ export const WS_PROTOCOL_TARGET = "rivet_target."; export const WS_PROTOCOL_ACTOR = "rivet_actor."; export const WS_PROTOCOL_ENCODING = "rivet_encoding."; export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params."; +export const WS_PROTOCOL_CONN_ID = "rivet_conn."; +export const WS_PROTOCOL_CONN_TOKEN = "rivet_conn_token."; export const WS_PROTOCOL_TOKEN = "rivet_token."; // MARK: WebSocket Inline Test Protocol Prefixes diff --git a/packages/rivetkit/src/driver-helpers/mod.ts b/packages/rivetkit/src/driver-helpers/mod.ts index c1ac9f757..b37cb203c 100644 --- a/packages/rivetkit/src/driver-helpers/mod.ts +++ b/packages/rivetkit/src/driver-helpers/mod.ts @@ -13,7 +13,9 @@ export { PATH_CONNECT_WEBSOCKET, PATH_RAW_WEBSOCKET_PREFIX, WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_PATH, WS_PROTOCOL_STANDARD, diff --git a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts b/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts index 95f786bba..efe5da24e 100644 --- a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts +++ b/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts @@ -160,6 +160,8 @@ export function createTestInlineClientDriver( actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { const WebSocket = await importWebSocket(); diff --git a/packages/rivetkit/src/drivers/engine/actor-driver.ts b/packages/rivetkit/src/drivers/engine/actor-driver.ts index 116d4af41..aa4923aa3 100644 --- a/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -322,6 +322,9 @@ export class EngineActorDriver implements ActorDriver { actorId, encoding, connParams, + // Extract connId and connToken from protocols if needed + undefined, + undefined, ); } else if (url.pathname.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) { wsHandlerPromise = handleRawWebSocketHandler( diff --git a/packages/rivetkit/src/drivers/file-system/manager.ts b/packages/rivetkit/src/drivers/file-system/manager.ts index ac187c29a..360ae857a 100644 --- a/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/packages/rivetkit/src/drivers/file-system/manager.ts @@ -137,6 +137,8 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { // Handle raw WebSocket paths const pathOnly = path.split("?")[0]; @@ -150,6 +152,8 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId, encoding, params, + connId, + connToken, ); return new InlineWebSocketAdapter2(wsHandler); } else if ( @@ -202,6 +206,9 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId, encoding, connParams, + // Extract connId and connToken from query parameters or headers if needed + undefined, + undefined, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); } else if ( diff --git a/packages/rivetkit/src/manager/driver.ts b/packages/rivetkit/src/manager/driver.ts index 802e3f210..9d0b32270 100644 --- a/packages/rivetkit/src/manager/driver.ts +++ b/packages/rivetkit/src/manager/driver.ts @@ -21,6 +21,8 @@ export interface ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise; proxyRequest( c: HonoContext, diff --git a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts index 023db8930..fe5fdd933 100644 --- a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts +++ b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -3,7 +3,9 @@ import { HEADER_CONN_PARAMS, HEADER_ENCODING, WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_STANDARD as WS_PROTOCOL_RIVETKIT, WS_PROTOCOL_TARGET, @@ -21,6 +23,8 @@ export async function openWebSocketToActor( actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { const WebSocket = await importWebSocket(); @@ -38,7 +42,14 @@ export async function openWebSocketToActor( // Create WebSocket connection const ws = new WebSocket( guardUrl, - buildWebSocketProtocols(runConfig, actorId, encoding, params), + buildWebSocketProtocols( + runConfig, + actorId, + encoding, + params, + connId, + connToken, + ), ); // Set binary type to arraybuffer for proper encoding support @@ -54,6 +65,8 @@ export function buildWebSocketProtocols( actorId: string, encoding: Encoding, params?: unknown, + connId?: string, + connToken?: string, ): string[] { const protocols: string[] = []; protocols.push(WS_PROTOCOL_RIVETKIT); @@ -68,5 +81,11 @@ export function buildWebSocketProtocols( `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, ); } + if (connId) { + protocols.push(`${WS_PROTOCOL_CONN_ID}${connId}`); + } + if (connToken) { + protocols.push(`${WS_PROTOCOL_CONN_TOKEN}${connToken}`); + } return protocols; } diff --git a/packages/rivetkit/src/remote-manager-driver/mod.ts b/packages/rivetkit/src/remote-manager-driver/mod.ts index ac63db79c..6694c191f 100644 --- a/packages/rivetkit/src/remote-manager-driver/mod.ts +++ b/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -206,6 +206,8 @@ export class RemoteManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { return await openWebSocketToActor( this.#config, @@ -213,6 +215,8 @@ export class RemoteManagerDriver implements ManagerDriver { actorId, encoding, params, + connId, + connToken, ); } From a8283f128e159e09621b5ee0a5ff3f75721684d0 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Thu, 25 Sep 2025 20:59:11 -0700 Subject: [PATCH 2/3] chore(core): update sse protocol to have http request to close stream gracefully --- AGENTS.md | 2 +- packages/rivetkit/src/actor/conn.ts | 4 +- .../rivetkit/src/actor/router-endpoints.ts | 37 +++++++++++++++++++ packages/rivetkit/src/actor/router.ts | 17 +++++++++ packages/rivetkit/src/client/actor-conn.ts | 24 ++++++++++++ 5 files changed, 82 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f56123434..6c785eb0e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -101,7 +101,7 @@ This ensures imports resolve correctly across different build environments and p - Write log messages in lowercase - Use `logger()` to log messages - Do not store `logger()` as a variable, always call it using `logger().info("...")` - - Use structured logging where it makes sense, for example: `logger().info("foo", { bar: 5, baz: 10 })` + - Use structured logging where it makes sense, for example: `logger().info({ msg: "foo", bar: 5, baz: 10 })` - Supported logging methods are: trace, debug, info, warn, error, critical - Instead of returning errors as raw HTTP responses with c.json, use or write an error in packages/rivetkit/src/actor/errors.ts and throw that instead. The middleware will automatically serialize the response for you. diff --git a/packages/rivetkit/src/actor/conn.ts b/packages/rivetkit/src/actor/conn.ts index c55a5c59e..1d0bd6a27 100644 --- a/packages/rivetkit/src/actor/conn.ts +++ b/packages/rivetkit/src/actor/conn.ts @@ -214,7 +214,7 @@ export class Conn { * @param reason - The reason for disconnection. */ public async disconnect(reason?: string) { - if (this.__driverState) { + if (this.__socket && this.__driverState) { const driver = getConnDriverFromState(this.__driverState); if (driver.disconnect) { driver.disconnect(this.#actor, this, this.__driverState, reason); @@ -224,6 +224,8 @@ export class Conn { conn: this.id, }); } + + this.#actor.__connDisconnected(this, true, this.__socket.socketId); } else { this.#actor.rLog.warn({ msg: "missing connection driver state for disconnect", diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index 73687f211..377927ce4 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -580,6 +580,43 @@ export async function handleConnectionMessage( return c.json({}); } +export async function handleConnectionClose( + c: HonoContext, + _runConfig: RunConfig, + actorDriver: ActorDriver, + connId: string, + connToken: string, + actorId: string, +) { + const actor = await actorDriver.loadActor(actorId); + + // Find connection + const conn = actor.conns.get(connId); + if (!conn) { + throw new errors.ConnNotFound(connId); + } + + // Authenticate connection + if (conn._token !== connToken) { + throw new errors.IncorrectConnToken(); + } + + // Check if this is an SSE connection + if ( + !conn.__socket?.driverState || + !(ConnDriverKind.SSE in conn.__socket.driverState) + ) { + throw new errors.UserError( + "Connection close is only supported for SSE connections", + ); + } + + // Close the SSE connection + await conn.disconnect("Connection closed by client request"); + + return c.json({}); +} + export async function handleRawWebSocketHandler( req: Request | undefined, path: string, diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 3079eddd6..2678805ef 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -11,6 +11,7 @@ import { type ConnectWebSocketOutput, type ConnsMessageOpts, handleAction, + handleConnectionClose, handleConnectionMessage, handleRawWebSocketHandler, handleSseConnect, @@ -160,6 +161,22 @@ export function createActorRouter( ); }); + router.post("/connections/close", async (c) => { + const connId = c.req.header(HEADER_CONN_ID); + const connToken = c.req.header(HEADER_CONN_TOKEN); + if (!connId || !connToken) { + throw new Error("Missing required parameters"); + } + return handleConnectionClose( + c, + runConfig, + actorDriver, + connId, + connToken, + c.env.actorId, + ); + }); + // Raw HTTP endpoints - /http/* router.all("/raw/http/*", async (c) => { const actor = await actorDriver.loadActor(c.env.actorId); diff --git a/packages/rivetkit/src/client/actor-conn.ts b/packages/rivetkit/src/client/actor-conn.ts index 584959a46..3eeffc765 100644 --- a/packages/rivetkit/src/client/actor-conn.ts +++ b/packages/rivetkit/src/client/actor-conn.ts @@ -853,6 +853,30 @@ enc await promise; } } else if ("sse" in this.#transport) { + // Send close request to server for SSE connections + if (this.#connectionId && this.#connectionToken) { + try { + await sendHttpRequest({ + url: "http://actor/connections/close", + method: "POST", + headers: { + [HEADER_CONN_ID]: this.#connectionId, + [HEADER_CONN_TOKEN]: this.#connectionToken, + }, + encoding: this.#encoding, + skipParseResponse: true, + customFetch: this.#driver.sendRequest.bind( + this.#driver, + this.#actorId!, + ), + requestVersionedDataHandler: TO_SERVER_VERSIONED, + responseVersionedDataHandler: TO_CLIENT_VERSIONED, + }); + } catch (error) { + // Ignore errors when closing - connection may already be closed + logger().warn({ msg: "failed to send close request", error }); + } + } this.#transport.sse.close(); } else { assertUnreachable(this.#transport); From 53647086400ff7837206dd7db6c5ed45f3943e23 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Thu, 25 Sep 2025 21:43:39 -0700 Subject: [PATCH 3/3] chore(core): test actor reconnect logic --- .../driver-test-suite/counter-conn.ts | 33 ++++ .../fixtures/driver-test-suite/registry.ts | 3 + packages/rivetkit/src/actor/conn-drivers.ts | 10 +- packages/rivetkit/src/actor/conn.ts | 23 ++- packages/rivetkit/src/actor/instance.ts | 181 ++++++++++-------- .../rivetkit/src/actor/router-endpoints.ts | 103 +++------- packages/rivetkit/src/actor/router.ts | 33 ++++ packages/rivetkit/src/actor/utils.ts | 6 +- packages/rivetkit/src/client/actor-conn.ts | 110 +++++++++-- .../rivetkit/src/driver-test-suite/mod.ts | 3 + .../tests/actor-reconnect.ts | 160 ++++++++++++++++ .../src/drivers/file-system/manager.ts | 7 +- packages/rivetkit/src/manager/driver.ts | 2 + packages/rivetkit/src/manager/gateway.ts | 10 + packages/rivetkit/src/manager/router.ts | 55 +++++- .../rivetkit/src/remote-manager-driver/mod.ts | 4 + 16 files changed, 556 insertions(+), 187 deletions(-) create mode 100644 packages/rivetkit/fixtures/driver-test-suite/counter-conn.ts create mode 100644 packages/rivetkit/src/driver-test-suite/tests/actor-reconnect.ts diff --git a/packages/rivetkit/fixtures/driver-test-suite/counter-conn.ts b/packages/rivetkit/fixtures/driver-test-suite/counter-conn.ts new file mode 100644 index 000000000..93d4a6c0a --- /dev/null +++ b/packages/rivetkit/fixtures/driver-test-suite/counter-conn.ts @@ -0,0 +1,33 @@ +import { actor } from "rivetkit"; + +export const counterConn = actor({ + state: { + connectionCount: 0, + }, + connState: { count: 0 }, + onConnect: (c, conn) => { + c.state.connectionCount += 1; + }, + onDisconnect: (c, conn) => { + // Note: We can't determine if disconnect was graceful from here + // For testing purposes, we'll decrement on all disconnects + // In real scenarios, you'd use connection tracking with timeouts + c.state.connectionCount -= 1; + }, + actions: { + increment: (c, x: number) => { + c.conn.state.count += x; + c.broadcast("newCount", c.conn.state.count); + }, + setCount: (c, x: number) => { + c.conn.state.count = x; + c.broadcast("newCount", x); + }, + getCount: (c) => { + return c.conn.state.count; + }, + getConnectionCount: (c) => { + return c.state.connectionCount; + }, + }, +}); diff --git a/packages/rivetkit/fixtures/driver-test-suite/registry.ts b/packages/rivetkit/fixtures/driver-test-suite/registry.ts index 844dc2e73..154decc40 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/registry.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/registry.ts @@ -17,6 +17,7 @@ import { counterWithParams } from "./conn-params"; import { connStateActor } from "./conn-state"; // Import actors from individual files import { counter } from "./counter"; +import { counterConn } from "./counter-conn"; import { customTimeoutActor, errorHandlingActor } from "./error-handling"; import { inlineClientActor } from "./inline-client"; import { counterWithLifecycle } from "./lifecycle"; @@ -51,6 +52,8 @@ export const registry = setup({ use: { // From counter.ts counter, + // From counter-conn.ts + counterConn, // From lifecycle.ts counterWithLifecycle, // From scheduled.ts diff --git a/packages/rivetkit/src/actor/conn-drivers.ts b/packages/rivetkit/src/actor/conn-drivers.ts index c3a31a1e5..94d48ecaa 100644 --- a/packages/rivetkit/src/actor/conn-drivers.ts +++ b/packages/rivetkit/src/actor/conn-drivers.ts @@ -195,11 +195,11 @@ export const CONN_DRIVERS: Record> = { [ConnDriverKind.HTTP]: HTTP_DRIVER, }; -export function getConnDriverFromState( +export function getConnDriverKindFromState( state: ConnDriverState, -): ConnDriver { - if (ConnDriverKind.WEBSOCKET in state) return WEBSOCKET_DRIVER; - else if (ConnDriverKind.SSE in state) return SSE_DRIVER; - else if (ConnDriverKind.HTTP in state) return SSE_DRIVER; +): ConnDriverKind { + if (ConnDriverKind.WEBSOCKET in state) return ConnDriverKind.WEBSOCKET; + else if (ConnDriverKind.SSE in state) return ConnDriverKind.SSE; + else if (ConnDriverKind.HTTP in state) return ConnDriverKind.HTTP; else assertUnreachable(state); } diff --git a/packages/rivetkit/src/actor/conn.ts b/packages/rivetkit/src/actor/conn.ts index 1d0bd6a27..83bdd0a4f 100644 --- a/packages/rivetkit/src/actor/conn.ts +++ b/packages/rivetkit/src/actor/conn.ts @@ -1,4 +1,5 @@ import * as cbor from "cbor-x"; +import invariant from "invariant"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { bufferToArrayBuffer } from "@/utils"; @@ -7,7 +8,7 @@ import { ConnDriverKind, type ConnDriverState, ConnReadyState, - getConnDriverFromState, + getConnDriverKindFromState, } from "./conn-drivers"; import type { ConnSocket } from "./conn-socket"; import type { AnyDatabaseProvider } from "./database"; @@ -161,9 +162,15 @@ export class Conn { */ public _sendMessage(message: CachedSerializer) { if (this.__driverState) { - const driver = getConnDriverFromState(this.__driverState); + const driverKind = getConnDriverKindFromState(this.__driverState); + const driver = CONN_DRIVERS[driverKind]; if (driver.sendMessage) { - driver.sendMessage(this.#actor, this, this.__driverState, message); + driver.sendMessage( + this.#actor, + this, + (this.__driverState as any)[driverKind], + message, + ); } else { this.#actor.rLog.debug({ msg: "conn driver does not support sending messages", @@ -215,9 +222,15 @@ export class Conn { */ public async disconnect(reason?: string) { if (this.__socket && this.__driverState) { - const driver = getConnDriverFromState(this.__driverState); + const driverKind = getConnDriverKindFromState(this.__driverState); + const driver = CONN_DRIVERS[driverKind]; if (driver.disconnect) { - driver.disconnect(this.#actor, this, this.__driverState, reason); + driver.disconnect( + this.#actor, + this, + (this.__driverState as any)[driverKind], + reason, + ); } else { this.#actor.rLog.debug({ msg: "no disconnect handler for conn driver", diff --git a/packages/rivetkit/src/actor/instance.ts b/packages/rivetkit/src/actor/instance.ts index 05ac259fa..9005de913 100644 --- a/packages/rivetkit/src/actor/instance.ts +++ b/packages/rivetkit/src/actor/instance.ts @@ -22,8 +22,13 @@ import { } from "@/utils"; import type { ActionContext } from "./action"; import type { ActorConfig, OnConnectOptions } from "./config"; -import { Conn, type ConnId } from "./conn"; -import type { ConnDriver, ConnDriverState } from "./conn-drivers"; +import { Conn, type ConnId, generateConnId, generateConnToken } from "./conn"; +import { + CONN_DRIVERS, + type ConnDriver, + type ConnDriverState, + getConnDriverKindFromState, +} from "./conn-drivers"; import type { ConnSocket } from "./conn-socket"; import { ActorContext } from "./context"; import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; @@ -811,7 +816,7 @@ export class ActorInstance { } else { // Disconnected uncleanly, allow reconnection - if (conn.__driverState) { + if (!conn.__driverState) { this.rLog.warn("called conn disconnected without driver state"); } @@ -874,12 +879,96 @@ export class ActorInstance { this.#resetSleepTimer(); } - async prepareConn( + /** + * Called to create a new connection or reconnect an existing one. + */ + async createConn( + socket: ConnSocket, // biome-ignore lint/suspicious/noExplicitAny: TypeScript bug with ExtractActorConnParams, params: any, request?: Request, - ): Promise { - // Authenticate connection + connectionId?: string, + connectionToken?: string, + ): Promise> { + this.#assertReady(); + + // If connection ID and token are provided, try to reconnect + if (connectionId && connectionToken) { + this.rLog.debug({ + msg: "checking for existing connection", + connectionId, + }); + const existingConn = this.#connections.get(connectionId); + if (existingConn && existingConn._token === connectionToken) { + // This is a valid reconnection + this.rLog.debug({ + msg: "reconnecting existing connection", + connectionId, + }); + + // If there's an existing driver state, clean it up without marking as clean disconnect + if (existingConn.__driverState) { + const driverKind = getConnDriverKindFromState( + existingConn.__driverState, + ); + const driver = CONN_DRIVERS[driverKind]; + if (driver.disconnect) { + // Call driver disconnect to clean up directly. Don't use Conn.disconnect since that will remove the connection entirely. + driver.disconnect( + this, + existingConn, + (existingConn.__driverState as any)[driverKind], + "Reconnecting with new driver state", + ); + } + } + + // Update with new driver state + existingConn.__socket = socket; + existingConn.__persist.lastSeen = Date.now(); + + // Update sleep timer since connection is now active + this.#resetSleepTimer(); + + this.inspector.emitter.emit("connectionUpdated"); + + // Send init message for reconnection + existingConn._sendMessage( + new CachedSerializer( + { + body: { + tag: "Init", + val: { + actorId: this.id, + connectionId: existingConn.id, + connectionToken: existingConn._token, + }, + }, + }, + TO_CLIENT_VERSIONED, + ), + ); + + return existingConn; + } + + // If we get here, either connection doesn't exist or token doesn't match + // Fall through to create new connection with new IDs + this.rLog.debug({ + msg: "connection not found or token mismatch, creating new connection", + connectionId, + }); + } + + // Generate new connection ID and token if not provided or if reconnection failed + const newConnId = generateConnId(); + const newConnToken = generateConnToken(); + + if (this.#connections.has(newConnId)) { + throw new Error(`Connection already exists: ${newConnId}`); + } + + // Prepare connection state let connState: CS | undefined; const onBeforeConnectOpts = { @@ -925,31 +1014,12 @@ export class ActorInstance { } } - return connState as CS; - } - - /** - * Called after establishing a connection handshake. - */ - async createConn( - connectionId: string, - connectionToken: string, - params: CP, - state: CS, - socket: ConnSocket, - ): Promise> { - this.#assertReady(); - - if (this.#connections.has(connectionId)) { - throw new Error(`Connection already exists: ${connectionId}`); - } - // Create connection const persist: PersistedConn = { - connId: connectionId, - token: connectionToken, + connId: newConnId, + token: newConnToken, params: params, - state: state, + state: connState as CS, lastSeen: Date.now(), subscriptions: [], }; @@ -1012,61 +1082,6 @@ export class ActorInstance { return conn; } - /** - * Reconnect an existing connection with a new driver state. - */ - async reconnectConn( - connectionId: string, - connectionToken: string, - socket: ConnSocket, - ): Promise> { - this.#assertReady(); - - // Find existing connection by ID - const existingConn = this.#connections.get(connectionId); - if (!existingConn) { - throw new errors.UserError(`Connection not found: ${connectionId}`); - } - - // Validate connection token - if (existingConn._token !== connectionToken) { - throw new errors.UserError("Invalid connection token"); - } - - // If there's an existing driver state, disconnect it first - if (existingConn.__driverState) { - await existingConn.disconnect("Reconnecting with new driver state"); - } - - // Update with new driver state - existingConn.__socket = socket; - existingConn.__persist.lastSeen = Date.now(); - - // Update sleep timer since connection is now active - this.#resetSleepTimer(); - - this.inspector.emitter.emit("connectionUpdated"); - - // Send init message for reconnection - existingConn._sendMessage( - new CachedSerializer( - { - body: { - tag: "Init", - val: { - actorId: this.id, - connectionId: existingConn.id, - connectionToken: existingConn._token, - }, - }, - }, - TO_CLIENT_VERSIONED, - ), - ); - - return existingConn; - } - // MARK: Messages async processMessage( message: protocol.ToServer, diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index 377927ce4..8472c9631 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -166,12 +166,17 @@ export async function handleWebSocketConnect( try { let conn: AnyConn; - // Check if this is a reconnection - if (connId && connToken) { - // This is a reconnection - use the existing connection - actor.rLog.debug({ msg: "websocket reconnection attempt", connId }); + // Create or reconnect connection + actor.rLog.debug({ + msg: connId + ? "websocket reconnection attempt" + : "new websocket connection", + connId, + actorId, + }); - conn = await actor.reconnectConn(connId, connToken, { + conn = await actor.createConn( + { socketId, driverState: { [ConnDriverKind.WEBSOCKET]: { @@ -180,37 +185,12 @@ export async function handleWebSocketConnect( closePromise, }, }, - }); - } else { - // This is a new connection - const newConnId = generateConnId(); - const newConnToken = generateConnToken(); - const connState = await actor.prepareConn(parameters, req); - - // Save socket - actor.rLog.debug({ - msg: "registered websocket for conn", - actorId, - }); - - // Create connection - conn = await actor.createConn( - newConnId, - newConnToken, - parameters, - connState, - { - socketId, - driverState: { - [ConnDriverKind.WEBSOCKET]: { - encoding, - websocket: ws, - closePromise, - }, - }, - }, - ); - } + }, + parameters, + req, + connId, + connToken, + ); // Unblock other handlers handlersResolve({ conn, actor, connId: conn.id }); @@ -367,12 +347,14 @@ export async function handleSseConnect( // Do all async work inside the handler actor = await actorDriver.loadActor(actorId); - // Check if this is a reconnection - if (connId && connToken) { - // This is a reconnection - use the existing connection - actor.rLog.debug({ msg: "sse reconnection attempt", connId }); + // Create or reconnect connection + actor.rLog.debug({ + msg: connId ? "sse reconnection attempt" : "sse open", + connId, + }); - conn = await actor.reconnectConn(connId, connToken, { + conn = await actor.createConn( + { socketId, driverState: { [ConnDriverKind.SSE]: { @@ -380,32 +362,12 @@ export async function handleSseConnect( stream: stream, }, }, - }); - } else { - // This is a new connection - const newConnId = generateConnId(); - const newConnToken = generateConnToken(); - const connState = await actor.prepareConn(parameters, c.req.raw); - - actor.rLog.debug("sse open"); - - // Create connection - conn = await actor.createConn( - newConnId, - newConnToken, - parameters, - connState, - { - socketId, - driverState: { - [ConnDriverKind.SSE]: { - encoding, - stream: stream, - }, - }, - }, - ); - } + }, + parameters, + c.req.raw, + connId, + connToken, + ); // Wait for close const abortResolver = promiseWithResolvers(); @@ -502,16 +464,13 @@ export async function handleAction( actor.rLog.debug({ msg: "handling action", actionName, encoding }); // Create conn - const connState = await actor.prepareConn(parameters, c.req.raw); conn = await actor.createConn( - generateConnId(), - generateConnToken(), - parameters, - connState, { socketId, driverState: { [ConnDriverKind.HTTP]: {} }, }, + parameters, + c.req.raw, ); // Call action diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 2678805ef..1b9752b98 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -42,6 +42,7 @@ import { } from "@/inspector/actor"; import { isInspectorEnabled, secureInspector } from "@/inspector/utils"; import type { RunConfig } from "@/registry/run-config"; +import { ConnDriverKind } from "./conn-drivers"; import type { ActorDriver } from "./driver"; import { InternalError } from "./errors"; import { loggerWithoutContext } from "./log"; @@ -83,6 +84,38 @@ export function createActorRouter( return c.text("ok"); }); + // Test endpoint to force disconnect a connection non-cleanly + router.post("/.test/force-disconnect", async (c) => { + const connId = c.req.query("conn"); + + if (!connId) { + return c.text("Missing conn query parameter", 400); + } + + const actor = await actorDriver.loadActor(c.env.actorId); + const conn = actor.__getConnForId(connId); + + if (!conn) { + return c.text(`Connection not found: ${connId}`, 404); + } + + // Force close the websocket/SSE connection without clean shutdown + const driverState = conn.__driverState; + if (driverState && ConnDriverKind.WEBSOCKET in driverState) { + const ws = driverState[ConnDriverKind.WEBSOCKET].websocket; + + // Force close without sending close frame + (ws.raw as any).terminate(); + } else if (driverState && ConnDriverKind.SSE in driverState) { + const stream = driverState[ConnDriverKind.SSE].stream; + + // Force close the SSE stream + stream.abort(); + } + + return c.json({ success: true }); + }); + router.get(PATH_CONNECT_WEBSOCKET, async (c) => { const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); if (upgradeWebSocket) { diff --git a/packages/rivetkit/src/actor/utils.ts b/packages/rivetkit/src/actor/utils.ts index 80d76fcd6..33c7559da 100644 --- a/packages/rivetkit/src/actor/utils.ts +++ b/packages/rivetkit/src/actor/utils.ts @@ -87,7 +87,11 @@ export class Lock { export function generateSecureToken(length = 32) { const array = new Uint8Array(length); crypto.getRandomValues(array); - return btoa(String.fromCharCode(...array)); + // Replace base64 chars that are not URL safe with URL-safe chars and strip padding + return btoa(String.fromCharCode(...array)) + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=/g, ""); } export function generateRandomString(length = 32) { diff --git a/packages/rivetkit/src/client/actor-conn.ts b/packages/rivetkit/src/client/actor-conn.ts index 3eeffc765..e99c29fe8 100644 --- a/packages/rivetkit/src/client/actor-conn.ts +++ b/packages/rivetkit/src/client/actor-conn.ts @@ -301,9 +301,18 @@ enc isReconnection ? this.#connectionId : undefined, isReconnection ? this.#connectionToken : undefined, ); + logger().debug({ + msg: "transport set to new websocket", + connectionId: this.#connectionId, + readyState: ws.readyState, + messageQueueLength: this.#messageQueue.length, + }); this.#transport = { websocket: ws }; ws.addEventListener("open", () => { - logger().debug({ msg: "client websocket open" }); + logger().debug({ + msg: "client websocket open", + connectionId: this.#connectionId, + }); }); ws.addEventListener("message", async (ev) => { this.#handleOnMessage(ev.data); @@ -380,6 +389,7 @@ enc logger().debug({ msg: "socket open", messageQueueLength: this.#messageQueue.length, + connectionId: this.#connectionId, }); // Resolve open promise @@ -399,6 +409,10 @@ enc // If the message fails to send, the message will be re-queued const queue = this.#messageQueue; this.#messageQueue = []; + logger().debug({ + msg: "flushing message queue", + queueLength: queue.length, + }); for (const msg of queue) { this.#sendMessage(msg); } @@ -520,26 +534,46 @@ enc // // These properties will be undefined const closeEvent = event as CloseEvent; - if (closeEvent.wasClean) { - logger().info({ - msg: "socket closed", - code: closeEvent.code, - reason: closeEvent.reason, - wasClean: closeEvent.wasClean, - }); - } else { - logger().warn({ - msg: "socket closed", - code: closeEvent.code, - reason: closeEvent.reason, - wasClean: closeEvent.wasClean, + const wasClean = closeEvent.wasClean; + + logger().info({ + msg: "socket closed", + code: closeEvent.code, + reason: closeEvent.reason, + wasClean: wasClean, + connectionId: this.#connectionId, + messageQueueLength: this.#messageQueue.length, + actionsInFlight: this.#actionsInFlight.size, + }); + + // Reject all in-flight actions + if (this.#actionsInFlight.size > 0) { + logger().debug({ + msg: "rejecting in-flight actions after disconnect", + count: this.#actionsInFlight.size, + connectionId: this.#connectionId, + wasClean, }); + + const disconnectError = new Error( + wasClean ? "Connection closed" : "Connection lost", + ); + + for (const actionInfo of this.#actionsInFlight.values()) { + actionInfo.reject(disconnectError); + } + this.#actionsInFlight.clear(); } this.#transport = undefined; // Automatically reconnect. Skip if already attempting to connect. if (!this.#disposed && !this.#connecting) { + logger().debug({ + msg: "triggering reconnect", + connectionId: this.#connectionId, + messageQueueLength: this.#messageQueue.length, + }); // TODO: Fetch actor to check if it's destroyed // TODO: Add backoff for reconnect // TODO: Add a way of preserving connection ID for connection state @@ -689,9 +723,26 @@ enc let queueMessage = false; if (!this.#transport) { // No transport connected yet + logger().debug({ msg: "no transport, queueing message" }); queueMessage = true; } else if ("websocket" in this.#transport) { - if (this.#transport.websocket.readyState === 1) { + const readyState = this.#transport.websocket.readyState; + logger().debug({ + msg: "websocket send attempt", + readyState, + readyStateString: + readyState === 0 + ? "CONNECTING" + : readyState === 1 + ? "OPEN" + : readyState === 2 + ? "CLOSING" + : "CLOSED", + connectionId: this.#connectionId, + messageType: (message.body as any).tag, + actionName: (message.body as any).val?.name, + }); + if (readyState === 1) { try { const messageSerialized = serializeWithEncoding( this.#encoding, @@ -707,12 +758,17 @@ enc logger().warn({ msg: "failed to send message, added to queue", error, + connectionId: this.#connectionId, }); // Assuming the socket is disconnected and will be reconnected soon queueMessage = true; } } else { + logger().debug({ + msg: "websocket not open, queueing message", + readyState, + }); queueMessage = true; } } else if ("sse" in this.#transport) { @@ -728,7 +784,13 @@ enc if (!opts?.ephemeral && queueMessage) { this.#messageQueue.push(message); - logger().debug({ msg: "queued connection message" }); + logger().debug({ + msg: "queued connection message", + queueLength: this.#messageQueue.length, + connectionId: this.#connectionId, + messageType: (message.body as any).tag, + actionName: (message.body as any).val?.name, + }); } } @@ -807,6 +869,22 @@ enc return deserializeWithEncoding(this.#encoding, buffer, TO_CLIENT_VERSIONED); } + /** + * Get the actor ID (for testing purposes). + * @internal + */ + get actorId(): string | undefined { + return this.#actorId; + } + + /** + * Get the connection ID (for testing purposes). + * @internal + */ + get connectionId(): string | undefined { + return this.#connectionId; + } + /** * Disconnects from the actor. * diff --git a/packages/rivetkit/src/driver-test-suite/mod.ts b/packages/rivetkit/src/driver-test-suite/mod.ts index ff72193f5..5d3d940d1 100644 --- a/packages/rivetkit/src/driver-test-suite/mod.ts +++ b/packages/rivetkit/src/driver-test-suite/mod.ts @@ -23,6 +23,7 @@ import { runActorInlineClientTests } from "./tests/actor-inline-client"; import { runActorInspectorTests } from "./tests/actor-inspector"; import { runActorMetadataTests } from "./tests/actor-metadata"; import { runActorOnStateChangeTests } from "./tests/actor-onstatechange"; +import { runActorReconnectTests } from "./tests/actor-reconnect"; import { runActorVarsTests } from "./tests/actor-vars"; import { runManagerDriverTests } from "./tests/manager-driver"; import { runRawHttpTests } from "./tests/raw-http"; @@ -100,6 +101,8 @@ export function runDriverTests( runActorConnStateTests({ ...driverTestConfig, transport }); + runActorReconnectTests({ ...driverTestConfig, transport }); + runRequestAccessTests({ ...driverTestConfig, transport }); runActorDriverTestsWithTransport({ ...driverTestConfig, transport }); diff --git a/packages/rivetkit/src/driver-test-suite/tests/actor-reconnect.ts b/packages/rivetkit/src/driver-test-suite/tests/actor-reconnect.ts new file mode 100644 index 000000000..e93b7afee --- /dev/null +++ b/packages/rivetkit/src/driver-test-suite/tests/actor-reconnect.ts @@ -0,0 +1,160 @@ +import { describe, expect, test, vi } from "vitest"; +import type { ActorConnRaw } from "@/client/actor-conn"; +import type { DriverTestConfig } from "../mod"; +import { setupDriverTest } from "../utils"; + +export function runActorReconnectTests(driverTestConfig: DriverTestConfig) { + describe("Actor Reconnection Tests", () => { + test("should reconnect and preserve connection state after non-clean disconnect", async (c) => { + const { client, endpoint } = await setupDriverTest(c, driverTestConfig); + + // Create actor and connect + const handle = client.counterConn.getOrCreate(["test-reconnect"]); + const connection = handle.connect(); + + // Set an initial count on the connection + await connection.increment(5); + + // Verify connection count is 1 + const connCount1 = await connection.getConnectionCount(); + expect(connCount1).toBe(1); + + // Force disconnect (non-clean) - simulates network failure + const connRaw = connection as unknown as ActorConnRaw; + await forceUncleanDisconnect( + endpoint, + connRaw.actorId!, + connRaw.connectionId!, + ); + + // Wait a bit for the disconnection to be processed + await vi.waitFor( + async () => { + const countAfterReconnect = await connection.getCount(); + expect(countAfterReconnect).toBe(5); // Should preserve the count + }, + { timeout: 5000, interval: 100 }, + ); + + // Verify connection count is still 1 (same connection reconnected) + const connCount2 = await connection.getConnectionCount(); + expect(connCount2).toBe(1); + + // Verify we can still increment the counter + const newCount = await connection.getCount(); + expect(newCount).toBe(5); + + // Clean up + await connection.dispose(); + }); + + test("should not preserve connection state after clean disconnect", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Create actor and connect + const handle = client.counterConn.getOrCreate(["test-clean-disconnect"]); + const connection = handle.connect(); + + // Set an initial count on the connection + await connection.increment(10); + + // Verify connection count is 1 + const connCount1 = await connection.getConnectionCount(); + expect(connCount1).toBe(1); + + // Clean disconnect + await connection.dispose(); + + // Wait a bit to ensure disconnection is processed + await vi.waitFor( + async () => { + // Check that connection count is now 0 + const handle2 = client.counterConn.get(["test-clean-disconnect"]); + const connCount = await handle2.getConnectionCount(); + // This counts the current action caller + expect(connCount).toBe(1); + }, + { timeout: 5000 }, + ); + + // Create a new connection + const connection2 = handle.connect(); + + // The count should be reset since it's a new connection + const countNewConnection = await connection2.getCount(); + expect(countNewConnection).toBe(0); // Should be reset + + // Verify connection count is 1 again (new connection) + const connCount3 = await connection2.getConnectionCount(); + expect(connCount3).toBe(1); + + // Clean up + await connection2.dispose(); + }); + + test("should handle multiple non-clean disconnects and reconnects", async (c) => { + const { client, endpoint } = await setupDriverTest(c, driverTestConfig); + + // Create actor and connect + const handle = client.counterConn.getOrCreate([ + "test-multiple-reconnect", + ]); + const connection = handle.connect(); + + // Set an initial count + await connection.setCount(100); + + // Perform multiple disconnect-reconnect cycles + for (let i = 0; i < 3; i++) { + // Increment before disconnect + await connection.increment(1); + + // Force disconnect + const connRaw = connection as unknown as ActorConnRaw; + await forceUncleanDisconnect( + endpoint, + connRaw.actorId!, + connRaw.connectionId!, + ); + + // Wait for reconnection and verify state is preserved + await vi.waitFor( + async () => { + const countAfter = await connection.getCount(); + expect(countAfter).toBe(101 + i); + }, + { timeout: 5000 }, + ); + + // Verify connection count remains 1 + const connCount = await connection.getConnectionCount(); + expect(connCount).toBe(1); + } + + // Final verification + const finalCount = await connection.getCount(); + expect(finalCount).toBe(103); + + // Clean up + await connection.dispose(); + }); + }); +} + +async function forceUncleanDisconnect( + endpoint: string, + actorId: string, + connId: string, +): Promise { + const response = await fetch( + `${endpoint}/.test/force-disconnect?actor=${actorId}&conn=${connId}`, + { + method: "POST", + }, + ); + + if (!response.ok) { + const text = await response.text(); + throw new Error(`Failed to force disconnect: ${text}`); + } +} diff --git a/packages/rivetkit/src/drivers/file-system/manager.ts b/packages/rivetkit/src/drivers/file-system/manager.ts index 360ae857a..5d8237f65 100644 --- a/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/packages/rivetkit/src/drivers/file-system/manager.ts @@ -190,6 +190,8 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, connParams: unknown, + connId?: string, + connToken?: string, ): Promise { const upgradeWebSocket = this.#runConfig.getUpgradeWebSocket?.(); invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); @@ -206,9 +208,8 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId, encoding, connParams, - // Extract connId and connToken from query parameters or headers if needed - undefined, - undefined, + connId, + connToken, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); } else if ( diff --git a/packages/rivetkit/src/manager/driver.ts b/packages/rivetkit/src/manager/driver.ts index 9d0b32270..607fc570d 100644 --- a/packages/rivetkit/src/manager/driver.ts +++ b/packages/rivetkit/src/manager/driver.ts @@ -35,6 +35,8 @@ export interface ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise; displayInformation(): ManagerDisplayInformation; diff --git a/packages/rivetkit/src/manager/gateway.ts b/packages/rivetkit/src/manager/gateway.ts index bfa73c070..e88c4aa35 100644 --- a/packages/rivetkit/src/manager/gateway.ts +++ b/packages/rivetkit/src/manager/gateway.ts @@ -6,7 +6,9 @@ import { HEADER_RIVET_ACTOR, HEADER_RIVET_TARGET, WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_TARGET, } from "@/common/actor-router-consts"; @@ -63,6 +65,8 @@ async function handleWebSocketGateway( let actorId: string | undefined; let encodingRaw: string | undefined; let connParamsRaw: string | undefined; + let connIdRaw: string | undefined; + let connTokenRaw: string | undefined; if (protocols) { const protocolList = protocols.split(",").map((p) => p.trim()); @@ -77,6 +81,10 @@ async function handleWebSocketGateway( connParamsRaw = decodeURIComponent( protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), ); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_ID)) { + connIdRaw = protocol.substring(WS_PROTOCOL_CONN_ID.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_TOKEN)) { + connTokenRaw = protocol.substring(WS_PROTOCOL_CONN_TOKEN.length); } } } @@ -110,6 +118,8 @@ async function handleWebSocketGateway( actorId, encoding as any, // Will be validated by driver connParams, + connIdRaw, + connTokenRaw, ); } diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index 902dd8556..ee9348e51 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -16,7 +16,9 @@ import { serializeActorKey } from "@/actor/keys"; import type { Encoding, Transport } from "@/client/mod"; import { WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_ID, WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_ENCODING, WS_PROTOCOL_PATH, WS_PROTOCOL_TRANSPORT, @@ -26,7 +28,7 @@ import { handleRouteNotFound, loggerMiddleware, } from "@/common/router"; -import { deconstructError, noopNext } from "@/common/utils"; +import { deconstructError, noopNext, stringifyError } from "@/common/utils"; import { type ActorDriver, HEADER_ACTOR_ID } from "@/driver-helpers/mod"; import type { TestInlineDriverCallRequest, @@ -50,7 +52,6 @@ import { import { RivetIdSchema } from "@/manager-api/common"; import type { RegistryConfig } from "@/registry/config"; import type { RunConfig } from "@/registry/run-config"; -import { stringifyError } from "@/utils"; import type { ActorOutput, ManagerDriver } from "./driver"; import { actorGateway, createTestWebSocketProxy } from "./gateway"; import { logger } from "./log"; @@ -383,6 +384,8 @@ function addManagerRoutes( let transport: Transport = "websocket"; let path = ""; let params: unknown; + let connId: string | undefined; + let connToken: string | undefined; for (const protocol of protocols) { if (protocol.startsWith(WS_PROTOCOL_ACTOR)) { @@ -404,6 +407,10 @@ function addManagerRoutes( protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), ); params = JSON.parse(paramsRaw); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_ID)) { + connId = protocol.substring(WS_PROTOCOL_CONN_ID.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_TOKEN)) { + connToken = protocol.substring(WS_PROTOCOL_CONN_TOKEN.length); } } @@ -422,6 +429,8 @@ function addManagerRoutes( actorId, encoding, params, + connId, + connToken, ); return await createTestWebSocketProxy(clientWsPromise); @@ -485,6 +494,48 @@ function addManagerRoutes( ); } }); + + // Test endpoint to force disconnect a connection non-cleanly + router.post("/.test/force-disconnect", async (c) => { + const actorId = c.req.query("actor"); + const connId = c.req.query("conn"); + + if (!actorId || !connId) { + return c.text("Missing actor or conn query parameters", 400); + } + + logger().debug({ + msg: "forcing unclean disconnect", + actorId, + connId, + }); + + try { + // Send a special request to the actor to force disconnect the connection + const response = await managerDriver.sendRequest( + actorId, + new Request(`http://actor/.test/force-disconnect?conn=${connId}`, { + method: "POST", + }), + ); + + if (!response.ok) { + const text = await response.text(); + return c.text( + `Failed to force disconnect: ${text}`, + response.status as any, + ); + } + + return c.json({ success: true }); + } catch (error) { + logger().error({ + msg: "error forcing disconnect", + error: stringifyError(error), + }); + return c.text(`Error: ${error}`, 500); + } + }); } router.get("/health", cors, (c) => { diff --git a/packages/rivetkit/src/remote-manager-driver/mod.ts b/packages/rivetkit/src/remote-manager-driver/mod.ts index 6694c191f..d876765a5 100644 --- a/packages/rivetkit/src/remote-manager-driver/mod.ts +++ b/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -234,6 +234,8 @@ export class RemoteManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + connId?: string, + connToken?: string, ): Promise { const upgradeWebSocket = this.#config.getUpgradeWebSocket?.(); invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); @@ -255,6 +257,8 @@ export class RemoteManagerDriver implements ManagerDriver { actorId, encoding, params, + connId, + connToken, ); const args = await createWebSocketProxy(c, wsGuardUrl, protocols);