diff --git a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare index ff0ab26b32..836c67cf06 100644 --- a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare +++ b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare @@ -12,6 +12,7 @@ type PersistedConnection struct { state: data subscriptions: list lastSeen: i64 + hibernatableRequestId: optional } # MARK: Schedule Event diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts index c0a4c276e4..ff7fae5016 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts @@ -68,15 +68,6 @@ export interface ConnDriver { conn: AnyConn, state: State, ): ConnReadyState | undefined; - - /** - * If the underlying connection can hibernate. - */ - isHibernatable( - actor: AnyActorInstance, - conn: AnyConn, - state: State, - ): boolean; } // MARK: WebSocket @@ -159,22 +150,6 @@ const WEBSOCKET_DRIVER: ConnDriver = { ): ConnReadyState | undefined => { return state.websocket.readyState; }, - - isHibernatable( - _actor: AnyActorInstance, - _conn: AnyConn, - state: ConnDriverWebSocketState, - ): boolean { - // Extract isHibernatable from the HonoWebSocketAdapter - if (state.websocket.raw) { - const raw = state.websocket.raw as HonoWebSocketAdapter; - if (typeof raw.isHibernatable === "boolean") { - return raw.isHibernatable; - } - } - - return false; - }, }; // MARK: SSE @@ -210,10 +185,6 @@ const SSE_DRIVER: ConnDriver = { return ConnReadyState.OPEN; }, - - isHibernatable(): boolean { - return false; - }, }; // MARK: HTTP @@ -226,9 +197,6 @@ const HTTP_DRIVER: ConnDriver = { // Noop // TODO: Abort the request }, - isHibernatable(): boolean { - return false; - }, }; /** List of all connection drivers. */ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts index c41c3ca197..c4157d216b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts @@ -2,5 +2,7 @@ import type { ConnDriverState } from "./conn-drivers"; export interface ConnSocket { requestId: string; + requestIdBuf?: ArrayBuffer; + hibernatable: boolean; driverState: ConnDriverState; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts index 0e7556be19..c771e14289 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts @@ -3,7 +3,7 @@ import invariant from "invariant"; import { PersistedHibernatableWebSocket } from "@/schemas/actor-persist/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; -import { bufferToArrayBuffer } from "@/utils"; +import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; import { CONN_DRIVERS, ConnDriverKind, @@ -14,7 +14,7 @@ import { import type { ConnSocket } from "./conn-socket"; import type { AnyDatabaseProvider } from "./database"; import * as errors from "./errors"; -import type { ActorInstance } from "./instance"; +import { type ActorInstance, PERSIST_SYMBOL } from "./instance"; import type { PersistedConn } from "./persisted"; import { CachedSerializer } from "./protocol/serde"; import { generateSecureToken } from "./utils"; @@ -69,7 +69,8 @@ export class Conn { __socket?: ConnSocket; get __status(): ConnectionStatus { - if (this.__socket) { + // TODO: isHibernatible might be true while the actual hibernatable websocket has disconnected + if (this.__socket || this.isHibernatable) { return "connected"; } else { return "reconnecting"; @@ -132,17 +133,17 @@ export class Conn { * If the underlying connection can hibernate. */ public get isHibernatable(): boolean { - if (this.__driverState) { - const driverKind = getConnDriverKindFromState(this.__driverState); - const driver = CONN_DRIVERS[driverKind]; - return driver.isHibernatable( - this.#actor, - this, - (this.__driverState as any)[driverKind], - ); - } else { + if (!this.__persist.hibernatableRequestId) { return false; } + return ( + this.#actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((x) => + arrayBuffersEqual( + x.requestId, + this.__persist.hibernatableRequestId!, + ), + ) > -1 + ); } /** diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts index 5b0dec8d30..e84786de7f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts @@ -19,6 +19,7 @@ import { bufferToArrayBuffer, EXTRA_ERROR_LOG, getEnvUniversal, + idToStr, promiseWithResolvers, SinglePromiseQueue, } from "@/utils"; @@ -244,7 +245,10 @@ export class ActorInstance { lastSeen: conn.lastSeen, stateEnabled: conn.__stateEnabled, isHibernatable: conn.isHibernatable, - requestId: conn.__socket?.requestId, + hibernatableRequestId: conn.__persist + .hibernatableRequestId + ? idToStr(conn.__persist.hibernatableRequestId) + : undefined, driver: conn.__driverState ? getConnDriverKindFromState(conn.__driverState) : undefined, @@ -267,6 +271,7 @@ export class ActorInstance { const conn = await this.createConn( { requestId: requestId, + hibernatable: false, driverState: { [ConnDriverKind.HTTP]: {} }, }, undefined, @@ -1016,6 +1021,74 @@ export class ActorInstance { ): Promise> { this.#assertReady(); + // Check for hibernatable websocket reconnection + if (socket.requestIdBuf && socket.hibernatable) { + this.rLog.debug({ + msg: "checking for hibernatable websocket connection", + requestId: socket.requestId, + existingConnectionsCount: this.#connections.size, + }); + + // Find existing connection with matching hibernatableRequestId + const existingConn = Array.from(this.#connections.values()).find( + (conn) => + conn.__persist.hibernatableRequestId && + arrayBuffersEqual( + conn.__persist.hibernatableRequestId, + socket.requestIdBuf!, + ), + ); + + if (existingConn) { + this.rLog.debug({ + msg: "reconnecting hibernatable websocket connection", + connectionId: existingConn.id, + requestId: socket.requestId, + }); + + // If there's an existing driver state, clean it up without marking as clean disconnect + if (existingConn.__driverState) { + this.#rLog.warn({ + msg: "found existing driver state on hibernatable websocket", + connectionId: existingConn.id, + requestId: socket.requestId, + }); + const driverKind = getConnDriverKindFromState( + existingConn.__driverState, + ); + const driver = CONN_DRIVERS[driverKind]; + if (driver.disconnect) { + // Call driver disconnect to clean up directly. Don't use Conn.disconnect since that will remove the connection entirely. + driver.disconnect( + this, + existingConn, + (existingConn.__driverState as any)[driverKind], + "Reconnecting hibernatable websocket with new driver state", + ); + } + } + + // Update with new driver state + existingConn.__socket = socket; + existingConn.__persist.lastSeen = Date.now(); + + // Update sleep timer since connection is now active + this.#resetSleepTimer(); + + this.inspector.emitter.emit("connectionUpdated"); + + // We don't need to send a new init message since this is a + // hibernated request that has already been initialized + + return existingConn; + } else { + this.rLog.debug({ + msg: "no existing hibernatable connection found, creating new connection", + requestId: socket.requestId, + }); + } + } + // If connection ID and token are provided, try to reconnect if (connectionId && connectionToken) { this.rLog.debug({ @@ -1074,14 +1147,12 @@ export class ActorInstance { ); return existingConn; + } else { + this.rLog.debug({ + msg: "connection not found or token mismatch, creating new connection", + connectionId, + }); } - - // If we get here, either connection doesn't exist or token doesn't match - // Fall through to create new connection with new IDs - this.rLog.debug({ - msg: "connection not found or token mismatch, creating new connection", - connectionId, - }); } // Generate new connection ID and token if not provided or if reconnection failed @@ -1147,6 +1218,19 @@ export class ActorInstance { lastSeen: Date.now(), subscriptions: [], }; + + // Check if this connection is for a hibernatable websocket + if (socket.requestIdBuf) { + const isHibernatable = + this.#persist.hibernatableWebSocket.findIndex((ws) => + arrayBuffersEqual(ws.requestId, socket.requestIdBuf!), + ) !== -1; + + if (isHibernatable) { + persist.hibernatableRequestId = socket.requestIdBuf; + } + } + const conn = new Conn(this, persist); conn.__socket = socket; this.#connections.set(conn.id, conn); @@ -2094,6 +2178,10 @@ export class ActorInstance { // Disconnect existing non-hibernatable connections for (const connection of this.#connections.values()) { if (!connection.isHibernatable) { + this.#rLog.debug({ + msg: "disconnecting non-hibernatable connection on actor stop", + connId: connection.id, + }); promises.push(connection.disconnect()); } @@ -2187,6 +2275,7 @@ export class ActorInstance { eventName: sub.eventName, })), lastSeen: BigInt(conn.lastSeen), + hibernatableRequestId: conn.hibernatableRequestId ?? null, })), scheduledEvents: persist.scheduledEvents.map((event) => ({ eventId: event.eventId, @@ -2225,6 +2314,7 @@ export class ActorInstance { eventName: sub.eventName, })), lastSeen: Number(conn.lastSeen), + hibernatableRequestId: conn.hibernatableRequestId ?? undefined, })), scheduledEvents: bareData.scheduledEvents.map((event) => ({ eventId: event.eventId, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts index fb2203e8a1..762739b7f5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts @@ -16,8 +16,11 @@ export interface PersistedConn { state: CS; subscriptions: PersistedSubscription[]; - /** Last time the socket was seen. This is set when disconencted so we can determine when we need to clean this up. */ + /** Last time the socket was seen. This is set when disconnected so we can determine when we need to clean this up. */ lastSeen: number; + + /** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableWebSocket */ + hibernatableRequestId?: ArrayBuffer; } export interface PersistedSubscription { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 6ea33604b8..38fc64cd89 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -116,6 +116,7 @@ export async function handleWebSocketConnect( encoding: Encoding, parameters: unknown, requestId: string, + requestIdBuf: ArrayBuffer | undefined, connId: string | undefined, connToken: string | undefined, ): Promise { @@ -184,9 +185,19 @@ export async function handleWebSocketConnect( actorId, }); + // Check if this is a hibernatable websocket + const isHibernatable = + !!requestIdBuf && + actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex( + (ws) => + arrayBuffersEqual(ws.requestId, requestIdBuf), + ) !== -1; + conn = await actor.createConn( { requestId: requestId, + requestIdBuf: requestIdBuf, + hibernatable: isHibernatable, driverState: { [ConnDriverKind.WEBSOCKET]: { encoding, @@ -365,6 +376,7 @@ export async function handleSseConnect( conn = await actor.createConn( { requestId: requestId, + hibernatable: false, driverState: { [ConnDriverKind.SSE]: { encoding, @@ -479,6 +491,7 @@ export async function handleAction( conn = await actor.createConn( { requestId: requestId, + hibernatable: false, driverState: { [ConnDriverKind.HTTP]: {} }, }, parameters, @@ -593,6 +606,7 @@ export async function handleRawWebSocketHandler( path: string, actorDriver: ActorDriver, actorId: string, + requestIdBuf: ArrayBuffer | undefined, ): Promise { const actor = await actorDriver.loadActor(actorId); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index 1a3c863e6e..4cb69f177e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -187,6 +187,7 @@ export function createActorRouter( encoding, connParams, generateConnRequestId(), + undefined, connIdRaw, connTokenRaw, ); @@ -303,6 +304,7 @@ export function createActorRouter( pathWithQuery, actorDriver, c.env.actorId, + undefined, ); })(c, noopNext()); } else { diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 9da4cfd9bd..c577b58915 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -562,6 +562,7 @@ export class EngineActorDriver implements ActorDriver { encoding, connParams, requestId, + requestIdBuf, // Extract connId and connToken from protocols if needed undefined, undefined, @@ -572,6 +573,7 @@ export class EngineActorDriver implements ActorDriver { url.pathname + url.search, this, actorId, + requestIdBuf, ); } else { throw new Error(`Unreachable path: ${url.pathname}`); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts index 0891cb3138..3a5534be47 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts @@ -175,6 +175,7 @@ export class FileSystemManagerDriver implements ManagerDriver { encoding, params, generateConnRequestId(), + undefined, connId, connToken, ); @@ -190,6 +191,7 @@ export class FileSystemManagerDriver implements ManagerDriver { path, this.#actorDriver, actorId, + undefined, ); return new InlineWebSocketAdapter2(wsHandler); } else { @@ -234,6 +236,7 @@ export class FileSystemManagerDriver implements ManagerDriver { encoding, connParams, generateConnRequestId(), + undefined, connId, connToken, ); @@ -249,6 +252,7 @@ export class FileSystemManagerDriver implements ManagerDriver { path, this.#actorDriver, actorId, + undefined, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); } else { diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts index b6eaff8b3c..5a4a63a0af 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts @@ -24,6 +24,10 @@ migrations.set( 1, (v1Data: v1.PersistedActor): v2.PersistedActor => ({ ...v1Data, + connections: v1Data.connections.map((conn) => ({ + ...conn, + hibernatableRequestId: null, + })), hibernatableWebSocket: [], }), );