Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type PersistedConnection struct {
state: data
subscriptions: list<PersistedSubscription>
lastSeen: i64
hibernatableRequestId: optional<data>
}

# MARK: Schedule Event
Expand Down
32 changes: 0 additions & 32 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ export interface ConnDriver<State> {
conn: AnyConn,
state: State,
): ConnReadyState | undefined;

/**
* If the underlying connection can hibernate.
*/
isHibernatable(
actor: AnyActorInstance,
conn: AnyConn,
state: State,
): boolean;
}

// MARK: WebSocket
Expand Down Expand Up @@ -159,22 +150,6 @@ const WEBSOCKET_DRIVER: ConnDriver<ConnDriverWebSocketState> = {
): ConnReadyState | undefined => {
return state.websocket.readyState;
},

isHibernatable(
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverWebSocketState,
): boolean {
// Extract isHibernatable from the HonoWebSocketAdapter
if (state.websocket.raw) {
const raw = state.websocket.raw as HonoWebSocketAdapter;
if (typeof raw.isHibernatable === "boolean") {
return raw.isHibernatable;
}
}

return false;
},
};

// MARK: SSE
Expand Down Expand Up @@ -210,10 +185,6 @@ const SSE_DRIVER: ConnDriver<ConnDriverSseState> = {

return ConnReadyState.OPEN;
},

isHibernatable(): boolean {
return false;
},
};

// MARK: HTTP
Expand All @@ -226,9 +197,6 @@ const HTTP_DRIVER: ConnDriver<ConnDriverHttpState> = {
// Noop
// TODO: Abort the request
},
isHibernatable(): boolean {
return false;
},
};

/** List of all connection drivers. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@ import type { ConnDriverState } from "./conn-drivers";

export interface ConnSocket {
requestId: string;
requestIdBuf?: ArrayBuffer;
hibernatable: boolean;
driverState: ConnDriverState;
}
25 changes: 13 additions & 12 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import invariant from "invariant";
import { PersistedHibernatableWebSocket } from "@/schemas/actor-persist/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { bufferToArrayBuffer } from "@/utils";
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
import {
CONN_DRIVERS,
ConnDriverKind,
Expand All @@ -14,7 +14,7 @@ import {
import type { ConnSocket } from "./conn-socket";
import type { AnyDatabaseProvider } from "./database";
import * as errors from "./errors";
import type { ActorInstance } from "./instance";
import { type ActorInstance, PERSIST_SYMBOL } from "./instance";
import type { PersistedConn } from "./persisted";
import { CachedSerializer } from "./protocol/serde";
import { generateSecureToken } from "./utils";
Expand Down Expand Up @@ -69,7 +69,8 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
__socket?: ConnSocket;

get __status(): ConnectionStatus {
if (this.__socket) {
// TODO: isHibernatible might be true while the actual hibernatable websocket has disconnected
if (this.__socket || this.isHibernatable) {
return "connected";
} else {
return "reconnecting";
Expand Down Expand Up @@ -132,17 +133,17 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
* If the underlying connection can hibernate.
*/
public get isHibernatable(): boolean {
if (this.__driverState) {
const driverKind = getConnDriverKindFromState(this.__driverState);
const driver = CONN_DRIVERS[driverKind];
return driver.isHibernatable(
this.#actor,
this,
(this.__driverState as any)[driverKind],
);
} else {
if (!this.__persist.hibernatableRequestId) {
return false;
}
return (
this.#actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((x) =>
arrayBuffersEqual(
x.requestId,
this.__persist.hibernatableRequestId!,
),
) > -1
);
}

/**
Expand Down
106 changes: 98 additions & 8 deletions rivetkit-typescript/packages/rivetkit/src/actor/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
bufferToArrayBuffer,
EXTRA_ERROR_LOG,
getEnvUniversal,
idToStr,
promiseWithResolvers,
SinglePromiseQueue,
} from "@/utils";
Expand Down Expand Up @@ -244,7 +245,10 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
lastSeen: conn.lastSeen,
stateEnabled: conn.__stateEnabled,
isHibernatable: conn.isHibernatable,
requestId: conn.__socket?.requestId,
hibernatableRequestId: conn.__persist
.hibernatableRequestId
? idToStr(conn.__persist.hibernatableRequestId)
: undefined,
driver: conn.__driverState
? getConnDriverKindFromState(conn.__driverState)
: undefined,
Expand All @@ -267,6 +271,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
const conn = await this.createConn(
{
requestId: requestId,
hibernatable: false,
driverState: { [ConnDriverKind.HTTP]: {} },
},
undefined,
Expand Down Expand Up @@ -1016,6 +1021,74 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
): Promise<Conn<S, CP, CS, V, I, DB>> {
this.#assertReady();

// Check for hibernatable websocket reconnection
if (socket.requestIdBuf && socket.hibernatable) {
this.rLog.debug({
msg: "checking for hibernatable websocket connection",
requestId: socket.requestId,
existingConnectionsCount: this.#connections.size,
});

// Find existing connection with matching hibernatableRequestId
const existingConn = Array.from(this.#connections.values()).find(
(conn) =>
conn.__persist.hibernatableRequestId &&
arrayBuffersEqual(
conn.__persist.hibernatableRequestId,
socket.requestIdBuf!,
),
);

if (existingConn) {
this.rLog.debug({
msg: "reconnecting hibernatable websocket connection",
connectionId: existingConn.id,
requestId: socket.requestId,
});

// If there's an existing driver state, clean it up without marking as clean disconnect
if (existingConn.__driverState) {
this.#rLog.warn({
msg: "found existing driver state on hibernatable websocket",
connectionId: existingConn.id,
requestId: socket.requestId,
});
const driverKind = getConnDriverKindFromState(
existingConn.__driverState,
);
const driver = CONN_DRIVERS[driverKind];
if (driver.disconnect) {
// Call driver disconnect to clean up directly. Don't use Conn.disconnect since that will remove the connection entirely.
driver.disconnect(
this,
existingConn,
(existingConn.__driverState as any)[driverKind],
"Reconnecting hibernatable websocket with new driver state",
);
}
}

// Update with new driver state
existingConn.__socket = socket;
existingConn.__persist.lastSeen = Date.now();

// Update sleep timer since connection is now active
this.#resetSleepTimer();

this.inspector.emitter.emit("connectionUpdated");

// We don't need to send a new init message since this is a
// hibernated request that has already been initialized

return existingConn;
} else {
this.rLog.debug({
msg: "no existing hibernatable connection found, creating new connection",
requestId: socket.requestId,
});
}
}

// If connection ID and token are provided, try to reconnect
if (connectionId && connectionToken) {
this.rLog.debug({
Expand Down Expand Up @@ -1074,14 +1147,12 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
);

return existingConn;
} else {
this.rLog.debug({
msg: "connection not found or token mismatch, creating new connection",
connectionId,
});
}

// If we get here, either connection doesn't exist or token doesn't match
// Fall through to create new connection with new IDs
this.rLog.debug({
msg: "connection not found or token mismatch, creating new connection",
connectionId,
});
}

// Generate new connection ID and token if not provided or if reconnection failed
Expand Down Expand Up @@ -1147,6 +1218,19 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
lastSeen: Date.now(),
subscriptions: [],
};

// Check if this connection is for a hibernatable websocket
if (socket.requestIdBuf) {
const isHibernatable =
this.#persist.hibernatableWebSocket.findIndex((ws) =>
arrayBuffersEqual(ws.requestId, socket.requestIdBuf!),
) !== -1;

if (isHibernatable) {
persist.hibernatableRequestId = socket.requestIdBuf;
}
}

const conn = new Conn<S, CP, CS, V, I, DB>(this, persist);
conn.__socket = socket;
this.#connections.set(conn.id, conn);
Expand Down Expand Up @@ -2094,6 +2178,10 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
// Disconnect existing non-hibernatable connections
for (const connection of this.#connections.values()) {
if (!connection.isHibernatable) {
this.#rLog.debug({
msg: "disconnecting non-hibernatable connection on actor stop",
connId: connection.id,
});
promises.push(connection.disconnect());
}

Expand Down Expand Up @@ -2187,6 +2275,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
eventName: sub.eventName,
})),
lastSeen: BigInt(conn.lastSeen),
hibernatableRequestId: conn.hibernatableRequestId ?? null,
})),
scheduledEvents: persist.scheduledEvents.map((event) => ({
eventId: event.eventId,
Expand Down Expand Up @@ -2225,6 +2314,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
eventName: sub.eventName,
})),
lastSeen: Number(conn.lastSeen),
hibernatableRequestId: conn.hibernatableRequestId ?? undefined,
})),
scheduledEvents: bareData.scheduledEvents.map((event) => ({
eventId: event.eventId,
Expand Down
5 changes: 4 additions & 1 deletion rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ export interface PersistedConn<CP, CS> {
state: CS;
subscriptions: PersistedSubscription[];

/** Last time the socket was seen. This is set when disconencted so we can determine when we need to clean this up. */
/** Last time the socket was seen. This is set when disconnected so we can determine when we need to clean this up. */
lastSeen: number;

/** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableWebSocket */
hibernatableRequestId?: ArrayBuffer;
}

export interface PersistedSubscription {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export async function handleWebSocketConnect(
encoding: Encoding,
parameters: unknown,
requestId: string,
requestIdBuf: ArrayBuffer | undefined,
connId: string | undefined,
connToken: string | undefined,
): Promise<UpgradeWebSocketArgs> {
Expand Down Expand Up @@ -184,9 +185,19 @@ export async function handleWebSocketConnect(
actorId,
});

// Check if this is a hibernatable websocket
const isHibernatable =
!!requestIdBuf &&
actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex(
(ws) =>
arrayBuffersEqual(ws.requestId, requestIdBuf),
) !== -1;

conn = await actor.createConn(
{
requestId: requestId,
requestIdBuf: requestIdBuf,
hibernatable: isHibernatable,
driverState: {
[ConnDriverKind.WEBSOCKET]: {
encoding,
Expand Down Expand Up @@ -365,6 +376,7 @@ export async function handleSseConnect(
conn = await actor.createConn(
{
requestId: requestId,
hibernatable: false,
driverState: {
[ConnDriverKind.SSE]: {
encoding,
Expand Down Expand Up @@ -479,6 +491,7 @@ export async function handleAction(
conn = await actor.createConn(
{
requestId: requestId,
hibernatable: false,
driverState: { [ConnDriverKind.HTTP]: {} },
},
parameters,
Expand Down Expand Up @@ -593,6 +606,7 @@ export async function handleRawWebSocketHandler(
path: string,
actorDriver: ActorDriver,
actorId: string,
requestIdBuf: ArrayBuffer | undefined,
): Promise<UpgradeWebSocketArgs> {
const actor = await actorDriver.loadActor(actorId);

Expand Down
2 changes: 2 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ export function createActorRouter(
encoding,
connParams,
generateConnRequestId(),
undefined,
connIdRaw,
connTokenRaw,
);
Expand Down Expand Up @@ -303,6 +304,7 @@ export function createActorRouter(
pathWithQuery,
actorDriver,
c.env.actorId,
undefined,
);
})(c, noopNext());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ export class EngineActorDriver implements ActorDriver {
encoding,
connParams,
requestId,
requestIdBuf,
// Extract connId and connToken from protocols if needed
undefined,
undefined,
Expand All @@ -572,6 +573,7 @@ export class EngineActorDriver implements ActorDriver {
url.pathname + url.search,
this,
actorId,
requestIdBuf,
);
} else {
throw new Error(`Unreachable path: ${url.pathname}`);
Expand Down
Loading
Loading