Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ This ensures imports resolve correctly across different build environments and p
- Write log messages in lowercase
- Use `logger()` to log messages
- Do not store `logger()` as a variable, always call it using `logger().info("...")`
- Use structured logging where it makes sense, for example: `logger().info("foo", { bar: 5, baz: 10 })`
- Use structured logging where it makes sense, for example: `logger().info({ msg: "foo", bar: 5, baz: 10 })`
- Supported logging methods are: trace, debug, info, warn, error, critical
- Instead of returning errors as raw HTTP responses with c.json, use or write an error in packages/rivetkit/src/actor/errors.ts and throw that instead. The middleware will automatically serialize the response for you.

Expand Down
10 changes: 10 additions & 0 deletions packages/cloudflare-workers/src/manager-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {
type ManagerDisplayInformation,
type ManagerDriver,
WS_PROTOCOL_ACTOR,
WS_PROTOCOL_CONN_ID,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_CONN_TOKEN,
WS_PROTOCOL_ENCODING,
WS_PROTOCOL_STANDARD,
WS_PROTOCOL_TARGET,
Expand Down Expand Up @@ -70,6 +72,8 @@ export class CloudflareActorsManagerDriver implements ManagerDriver {
actorId: string,
encoding: Encoding,
params: unknown,
connId?: string,
connToken?: string,
): Promise<UniversalWebSocket> {
const env = getCloudflareAmbientEnv();

Expand All @@ -93,6 +97,12 @@ export class CloudflareActorsManagerDriver implements ManagerDriver {
`${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
);
}
if (connId) {
protocols.push(`${WS_PROTOCOL_CONN_ID}${connId}`);
}
if (connToken) {
protocols.push(`${WS_PROTOCOL_CONN_TOKEN}${connToken}`);
}

const headers: Record<string, string> = {
Upgrade: "websocket",
Expand Down
33 changes: 33 additions & 0 deletions packages/rivetkit/fixtures/driver-test-suite/counter-conn.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { actor } from "rivetkit";

export const counterConn = actor({
state: {
connectionCount: 0,
},
connState: { count: 0 },
onConnect: (c, conn) => {
c.state.connectionCount += 1;
},
onDisconnect: (c, conn) => {
// Note: We can't determine if disconnect was graceful from here
// For testing purposes, we'll decrement on all disconnects
// In real scenarios, you'd use connection tracking with timeouts
c.state.connectionCount -= 1;
},
actions: {
increment: (c, x: number) => {
c.conn.state.count += x;
c.broadcast("newCount", c.conn.state.count);
},
setCount: (c, x: number) => {
c.conn.state.count = x;
c.broadcast("newCount", x);
},
getCount: (c) => {
return c.conn.state.count;
},
getConnectionCount: (c) => {
return c.state.connectionCount;
},
},
});
3 changes: 3 additions & 0 deletions packages/rivetkit/fixtures/driver-test-suite/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { counterWithParams } from "./conn-params";
import { connStateActor } from "./conn-state";
// Import actors from individual files
import { counter } from "./counter";
import { counterConn } from "./counter-conn";
import { customTimeoutActor, errorHandlingActor } from "./error-handling";
import { inlineClientActor } from "./inline-client";
import { counterWithLifecycle } from "./lifecycle";
Expand Down Expand Up @@ -51,6 +52,8 @@ export const registry = setup({
use: {
// From counter.ts
counter,
// From counter-conn.ts
counterConn,
// From lifecycle.ts
counterWithLifecycle,
// From scheduled.ts
Expand Down
10 changes: 5 additions & 5 deletions packages/rivetkit/src/actor/conn-drivers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ export const CONN_DRIVERS: Record<ConnDriverKind, ConnDriver<unknown>> = {
[ConnDriverKind.HTTP]: HTTP_DRIVER,
};

export function getConnDriverFromState(
export function getConnDriverKindFromState(
state: ConnDriverState,
): ConnDriver<unknown> {
if (ConnDriverKind.WEBSOCKET in state) return WEBSOCKET_DRIVER;
else if (ConnDriverKind.SSE in state) return SSE_DRIVER;
else if (ConnDriverKind.HTTP in state) return SSE_DRIVER;
): ConnDriverKind {
if (ConnDriverKind.WEBSOCKET in state) return ConnDriverKind.WEBSOCKET;
else if (ConnDriverKind.SSE in state) return ConnDriverKind.SSE;
else if (ConnDriverKind.HTTP in state) return ConnDriverKind.HTTP;
else assertUnreachable(state);
}
6 changes: 6 additions & 0 deletions packages/rivetkit/src/actor/conn-socket.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import type { ConnDriverState } from "./conn-drivers";

export interface ConnSocket {
socketId: string;
driverState: ConnDriverState;
}
108 changes: 44 additions & 64 deletions packages/rivetkit/src/actor/conn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { bufferToArrayBuffer } from "@/utils";
Expand All @@ -7,8 +8,9 @@ import {
ConnDriverKind,
type ConnDriverState,
ConnReadyState,
getConnDriverFromState,
getConnDriverKindFromState,
} from "./conn-drivers";
import type { ConnSocket } from "./conn-socket";
import type { AnyDatabaseProvider } from "./database";
import * as errors from "./errors";
import type { ActorInstance } from "./instance";
Expand All @@ -24,14 +26,16 @@ export function generateConnToken(): string {
return generateSecureToken(32);
}

export function generateConnSocketId(): string {
return crypto.randomUUID();
}

export type ConnId = string;

export type AnyConn = Conn<any, any, any, any, any, any>;

export type ConnectionStatus = "connected" | "reconnecting";

export const CONNECTION_CHECK_LIVENESS_SYMBOL = Symbol("checkLiveness");

/**
* Represents a client connection to a actor.
*
Expand All @@ -45,19 +49,31 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
// TODO: Remove this cyclical reference
#actor: ActorInstance<S, CP, CS, V, I, DB>;

#status: ConnectionStatus = "connected";

/**
* The proxied state that notifies of changes automatically.
*
* Any data that should be stored indefinitely should be held within this object.
*/
__persist: PersistedConn<CP, CS>;

get __driverState(): ConnDriverState | undefined {
return this.__socket?.driverState;
}

/**
* Driver used to manage connection. If undefined, there is no connection connected.
* Socket connected to this connection.
*
* If undefined, then nothing is connected to this.
*/
__driverState?: ConnDriverState;
__socket?: ConnSocket;

get __status(): ConnectionStatus {
if (this.__socket) {
return "connected";
} else {
return "reconnecting";
}
}

public get params(): CP {
return this.__persist.params;
Expand Down Expand Up @@ -106,7 +122,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
* Status of the connection.
*/
public get status(): ConnectionStatus {
return this.#status;
return this.__status;
}

/**
Expand Down Expand Up @@ -146,9 +162,15 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
*/
public _sendMessage(message: CachedSerializer<protocol.ToClient>) {
if (this.__driverState) {
const driver = getConnDriverFromState(this.__driverState);
const driverKind = getConnDriverKindFromState(this.__driverState);
const driver = CONN_DRIVERS[driverKind];
if (driver.sendMessage) {
driver.sendMessage(this.#actor, this, this.__driverState, message);
driver.sendMessage(
this.#actor,
this,
(this.__driverState as any)[driverKind],
message,
);
} else {
this.#actor.rLog.debug({
msg: "conn driver does not support sending messages",
Expand Down Expand Up @@ -199,73 +221,31 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
* @param reason - The reason for disconnection.
*/
public async disconnect(reason?: string) {
this.#status = "reconnecting";

if (this.__driverState) {
const driver = getConnDriverFromState(this.__driverState);
if (this.__socket && this.__driverState) {
const driverKind = getConnDriverKindFromState(this.__driverState);
const driver = CONN_DRIVERS[driverKind];
if (driver.disconnect) {
driver.disconnect(this.#actor, this, this.__driverState, reason);
driver.disconnect(
this.#actor,
this,
(this.__driverState as any)[driverKind],
reason,
);
} else {
this.#actor.rLog.debug({
msg: "no disconnect handler for conn driver",
conn: this.id,
});
}

this.#actor.__connDisconnected(this, true, this.__socket.socketId);
} else {
this.#actor.rLog.warn({
msg: "missing connection driver state for disconnect",
conn: this.id,
});
}
}

/**
* This method checks the connection's liveness by querying the driver for its ready state.
* If the connection is not closed, it updates the last liveness timestamp and returns `true`.
* Otherwise, it returns `false`.
* @internal
*/
[CONNECTION_CHECK_LIVENESS_SYMBOL]() {
let readyState: ConnReadyState | undefined;

if (this.__driverState) {
const driver = getConnDriverFromState(this.__driverState);
readyState = driver.getConnectionReadyState(
this.#actor,
this,
this.__driverState,
);
}

const isConnectionClosed =
readyState === ConnReadyState.CLOSED ||
readyState === ConnReadyState.CLOSING ||
readyState === undefined;

const newLastSeen = Date.now();
const newStatus = isConnectionClosed ? "reconnecting" : "connected";

this.#actor.rLog.debug({
msg: "liveness probe for connection",
connId: this.id,
actorId: this.#actor.id,
readyState,

status: this.#status,
newStatus,

lastSeen: this.__persist.lastSeen,
currentTs: newLastSeen,
});

if (!isConnectionClosed) {
this.__persist.lastSeen = newLastSeen;
}

this.#status = newStatus;
return {
status: this.#status,
lastSeen: this.__persist.lastSeen,
};
this.__socket = undefined;
}
}
Loading
Loading