From 45c785f63c727130efbe0cdee46d32add7b03ecf Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sun, 14 Sep 2025 16:52:05 -0700 Subject: [PATCH 1/3] chore(core): remove manager driver --- clients/openapi/openapi.json | 769 ++----- examples/ai-agent/src/backend/registry.ts | 1 - .../src/backend/registry.ts | 14 - examples/chat-room/src/backend/registry.ts | 1 - .../cloudflare-workers-hono/src/registry.ts | 3 - examples/cloudflare-workers/src/registry.ts | 3 - examples/counter/scripts/connect.ts | 29 +- examples/counter/src/registry.ts | 3 - examples/crdt/src/backend/registry.ts | 1 - examples/database/src/backend/registry.ts | 1 - examples/drizzle/src/registry.ts | 3 - examples/elysia/src/registry.ts | 3 - examples/express/src/registry.ts | 3 - examples/game/src/backend/registry.ts | 1 - examples/hono-react/src/backend/registry.ts | 3 - examples/hono/src/registry.ts | 4 +- examples/next-js/src/rivet/registry.ts | 3 - examples/rate/src/backend/registry.ts | 1 - .../raw-fetch-handler/src/backend/registry.ts | 4 - .../src/backend/registry.ts | 4 - examples/react/src/backend/registry.ts | 3 - examples/starter/src/registry.ts | 3 - examples/stream/src/backend/registry.ts | 1 - examples/sync/src/backend/registry.ts | 1 - examples/tenant/src/backend/registry.ts | 1 - examples/trpc/src/registry.ts | 3 - .../src/actor-handler-do.ts | 6 +- .../cloudflare-workers/src/manager-driver.ts | 7 +- .../driver-test-suite/action-inputs.ts | 1 - .../driver-test-suite/action-timeout.ts | 4 - .../driver-test-suite/action-types.ts | 3 - .../driver-test-suite/actor-onstatechange.ts | 1 - .../fixtures/driver-test-suite/auth.ts | 103 - .../driver-test-suite/conn-liveness.ts | 1 - .../fixtures/driver-test-suite/conn-params.ts | 1 - .../fixtures/driver-test-suite/conn-state.ts | 1 - .../fixtures/driver-test-suite/counter.ts | 1 - .../driver-test-suite/error-handling.ts | 1 - .../driver-test-suite/inline-client.ts | 1 - .../fixtures/driver-test-suite/lifecycle.ts | 1 - .../fixtures/driver-test-suite/metadata.ts | 1 - .../driver-test-suite/raw-http-auth.ts | 51 +- .../raw-http-request-properties.ts | 9 +- .../fixtures/driver-test-suite/raw-http.ts | 27 +- .../driver-test-suite/raw-websocket-auth.ts | 31 +- .../driver-test-suite/raw-websocket.ts | 8 - .../fixtures/driver-test-suite/registry.ts | 15 - .../driver-test-suite/request-access-auth.ts | 48 - .../driver-test-suite/request-access.ts | 1 - .../fixtures/driver-test-suite/scheduled.ts | 1 - .../fixtures/driver-test-suite/sleep.ts | 5 - .../fixtures/driver-test-suite/vars.ts | 5 - .../rivetkit/schemas/actor-persist/v1.bare | 2 - .../rivetkit/schemas/client-protocol/v1.bare | 2 + packages/rivetkit/scripts/dump-openapi.ts | 12 - packages/rivetkit/src/actor/action.ts | 6 +- packages/rivetkit/src/actor/config.ts | 322 +-- packages/rivetkit/src/actor/connection.ts | 12 +- packages/rivetkit/src/actor/context.ts | 5 +- packages/rivetkit/src/actor/definition.ts | 18 +- packages/rivetkit/src/actor/errors.ts | 106 +- packages/rivetkit/src/actor/instance.ts | 65 +- .../{drivers/engine => actor}/keys.test.ts | 0 .../src/{drivers/engine => actor}/keys.ts | 0 packages/rivetkit/src/actor/mod.ts | 5 - packages/rivetkit/src/actor/protocol/old.ts | 31 +- .../rivetkit/src/actor/router-endpoints.ts | 14 +- packages/rivetkit/src/actor/router.ts | 7 +- packages/rivetkit/src/client/actor-common.ts | 2 +- packages/rivetkit/src/client/actor-conn.ts | 188 +- packages/rivetkit/src/client/actor-handle.ts | 107 +- packages/rivetkit/src/client/actor-query.ts | 65 + packages/rivetkit/src/client/client.ts | 98 +- packages/rivetkit/src/client/config.ts | 44 + packages/rivetkit/src/client/errors.ts | 1 + .../rivetkit/src/client/http-client-driver.ts | 329 --- packages/rivetkit/src/client/mod.ts | 28 +- packages/rivetkit/src/client/raw-utils.ts | 95 +- packages/rivetkit/src/client/utils.ts | 5 +- packages/rivetkit/src/common/router.ts | 29 +- packages/rivetkit/src/common/utils.ts | 10 + packages/rivetkit/src/driver-helpers/mod.ts | 1 - .../rivetkit/src/driver-test-suite/mod.ts | 138 +- .../test-inline-client-driver.ts | 404 ---- .../src/driver-test-suite/tests/actor-auth.ts | 591 ----- .../driver-test-suite/tests/actor-handle.ts | 33 + .../driver-test-suite/tests/manager-driver.ts | 6 +- .../rivetkit/src/driver-test-suite/utils.ts | 23 +- packages/rivetkit/src/drivers/default.ts | 2 +- .../src/drivers/engine/actor-driver.ts | 2 +- .../src/drivers/engine/api-endpoints.ts | 128 -- .../rivetkit/src/drivers/engine/api-utils.ts | 71 - .../src/drivers/engine/manager-driver.ts | 405 ---- packages/rivetkit/src/drivers/engine/mod.ts | 5 +- .../src/drivers/file-system/manager.ts | 6 +- .../rivetkit/src/inline-client-driver/mod.ts | 389 ---- .../src/manager-api/routes/actors-create.ts | 16 + .../src/manager-api/routes/actors-delete.ts | 4 + .../manager-api/routes/actors-get-by-id.ts | 7 + .../routes/actors-get-or-create-by-id.ts | 29 + .../src/manager-api/routes/actors-get.ts | 7 + .../rivetkit/src/manager-api/routes/common.ts | 18 + packages/rivetkit/src/manager/auth.ts | 124 -- packages/rivetkit/src/manager/driver.ts | 4 +- packages/rivetkit/src/manager/router.ts | 1922 +++-------------- packages/rivetkit/src/mod.ts | 2 - packages/rivetkit/src/registry/config.ts | 2 +- packages/rivetkit/src/registry/mod.ts | 26 +- packages/rivetkit/src/registry/run-config.ts | 73 +- packages/rivetkit/src/registry/serve.ts | 7 +- .../actor-http-client.ts | 72 + .../actor-websocket-client.ts | 60 + .../remote-manager-driver/api-endpoints.ts | 79 + .../src/remote-manager-driver/api-utils.ts | 43 + .../log.ts | 2 +- .../rivetkit/src/remote-manager-driver/mod.ts | 274 +++ .../ws-proxy.ts | 4 +- packages/rivetkit/src/serde.ts | 10 +- packages/rivetkit/src/test/mod.ts | 30 +- packages/rivetkit/tests/actor-types.test.ts | 7 - packages/rivetkit/tests/driver-engine.test.ts | 12 +- vitest.base.ts | 2 +- 122 files changed, 1949 insertions(+), 5899 deletions(-) delete mode 100644 packages/rivetkit/fixtures/driver-test-suite/auth.ts delete mode 100644 packages/rivetkit/fixtures/driver-test-suite/request-access-auth.ts rename packages/rivetkit/src/{drivers/engine => actor}/keys.test.ts (100%) rename packages/rivetkit/src/{drivers/engine => actor}/keys.ts (100%) create mode 100644 packages/rivetkit/src/client/actor-query.ts create mode 100644 packages/rivetkit/src/client/config.ts delete mode 100644 packages/rivetkit/src/client/http-client-driver.ts delete mode 100644 packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts delete mode 100644 packages/rivetkit/src/driver-test-suite/tests/actor-auth.ts delete mode 100644 packages/rivetkit/src/drivers/engine/api-endpoints.ts delete mode 100644 packages/rivetkit/src/drivers/engine/api-utils.ts delete mode 100644 packages/rivetkit/src/drivers/engine/manager-driver.ts delete mode 100644 packages/rivetkit/src/inline-client-driver/mod.ts create mode 100644 packages/rivetkit/src/manager-api/routes/actors-create.ts create mode 100644 packages/rivetkit/src/manager-api/routes/actors-delete.ts create mode 100644 packages/rivetkit/src/manager-api/routes/actors-get-by-id.ts create mode 100644 packages/rivetkit/src/manager-api/routes/actors-get-or-create-by-id.ts create mode 100644 packages/rivetkit/src/manager-api/routes/actors-get.ts create mode 100644 packages/rivetkit/src/manager-api/routes/common.ts delete mode 100644 packages/rivetkit/src/manager/auth.ts create mode 100644 packages/rivetkit/src/remote-manager-driver/actor-http-client.ts create mode 100644 packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts create mode 100644 packages/rivetkit/src/remote-manager-driver/api-endpoints.ts create mode 100644 packages/rivetkit/src/remote-manager-driver/api-utils.ts rename packages/rivetkit/src/{inline-client-driver => remote-manager-driver}/log.ts (62%) create mode 100644 packages/rivetkit/src/remote-manager-driver/mod.ts rename packages/rivetkit/src/{drivers/engine => remote-manager-driver}/ws-proxy.ts (97%) diff --git a/clients/openapi/openapi.json b/clients/openapi/openapi.json index 025772556..5e993868b 100644 --- a/clients/openapi/openapi.json +++ b/clients/openapi/openapi.json @@ -5,219 +5,43 @@ "title": "RivetKit API" }, "components": { - "schemas": { - "ResolveResponse": { - "type": "object", - "properties": { - "i": { - "type": "string", - "example": "actor-123" - } - }, - "required": [ - "i" - ] - }, - "ResolveQuery": { - "type": "object", - "properties": { - "query": { - "nullable": true, - "example": { - "getForId": { - "actorId": "actor-123" - } - } - } - } - }, - "ActionResponse": { - "nullable": true - }, - "ActionRequest": { - "type": "object", - "properties": { - "query": { - "nullable": true, - "example": { - "getForId": { - "actorId": "actor-123" - } - } - }, - "body": { - "nullable": true, - "example": { - "param1": "value1", - "param2": 123 - } - } - } - }, - "ConnectionMessageResponse": { - "nullable": true - }, - "ConnectionMessageRequest": { - "type": "object", - "properties": { - "message": { - "nullable": true, - "example": { - "type": "message", - "content": "Hello, actor!" - } - } - } - } - }, + "schemas": {}, "parameters": {} }, "paths": { - "/actors/resolve": { - "post": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": true, - "name": "X-RivetKit-Query", - "in": "header" - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ResolveQuery" - } - } - } - }, - "responses": { - "200": { - "description": "Success", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ResolveResponse" - } - } - } - }, - "400": { - "description": "User error" - }, - "500": { - "description": "Internal error" - } - } - } - }, - "/actors/connect/websocket": { - "get": { - "responses": { - "101": { - "description": "WebSocket upgrade" - } - } - } - }, - "/actors/connect/sse": { + "/actors/by-id": { "get": { "parameters": [ { "schema": { - "type": "string", - "description": "The encoding format to use for the response (json, cbor)", - "example": "json" + "type": "string" }, "required": true, - "name": "X-RivetKit-Encoding", - "in": "header" + "name": "name", + "in": "query" }, { "schema": { - "type": "string", - "description": "Actor query information" + "type": "string" }, "required": true, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" + "name": "key", + "in": "query" } ], - "responses": { - "200": { - "description": "SSE stream", - "content": { - "text/event-stream": { - "schema": { - "nullable": true - } - } - } - } - } - } - }, - "/actors/actions/{action}": { - "post": { - "parameters": [ - { - "schema": { - "type": "string", - "example": "myAction" - }, - "required": true, - "name": "action", - "in": "path" - }, - { - "schema": { - "type": "string", - "description": "The encoding format to use for the response (json, cbor)", - "example": "json" - }, - "required": true, - "name": "X-RivetKit-Encoding", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ActionRequest" - } - } - } - }, "responses": { "200": { "description": "Success", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ActionResponse" + "type": "object", + "properties": { + "actor_id": { + "type": "string", + "nullable": true + } + } } } } @@ -229,56 +53,37 @@ "description": "Internal error" } } - } - }, - "/actors/message": { - "post": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor ID (used in some endpoints)", - "example": "actor-123456" - }, - "required": true, - "name": "X-RivetKit-Actor", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection ID", - "example": "conn-123456" - }, - "required": true, - "name": "X-RivetKit-Conn", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "The encoding format to use for the response (json, cbor)", - "example": "json" - }, - "required": true, - "name": "X-RivetKit-Encoding", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection token" - }, - "required": true, - "name": "X-RivetKit-Conn-Token", - "in": "header" - } - ], + }, + "put": { "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ConnectionMessageRequest" + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "key": { + "type": "string" + }, + "runner_name_selector": { + "type": "string" + }, + "crash_policy": { + "type": "string" + }, + "input": { + "type": "string", + "nullable": true + } + }, + "required": [ + "name", + "key", + "runner_name_selector", + "crash_policy" + ] } } } @@ -289,7 +94,19 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ConnectionMessageResponse" + "type": "object", + "properties": { + "actor_id": { + "type": "string" + }, + "created": { + "type": "boolean" + } + }, + "required": [ + "actor_id", + "created" + ] } } } @@ -303,157 +120,86 @@ } } }, - "/actors/raw/http/*": { + "/actors/{actor_id}": { "get": { "parameters": [ { "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, - "responses": { - "200": { - "description": "Success - response from actor's onFetch handler", - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" - } - } - } - }, - "404": { - "description": "Actor does not have an onFetch handler" - }, - "500": { - "description": "Internal server error or invalid response from actor" - } - } - }, - "post": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, - "responses": { - "200": { - "description": "Success - response from actor's onFetch handler", - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" - } - } - } - }, - "404": { - "description": "Actor does not have an onFetch handler" - }, - "500": { - "description": "Internal server error or invalid response from actor" - } - } - }, - "put": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" + "type": "string" }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" + "required": true, + "name": "actor_id", + "in": "path" } ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, "responses": { "200": { - "description": "Success - response from actor's onFetch handler", + "description": "Success", "content": { - "*/*": { + "application/json": { "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" + "type": "object", + "properties": { + "actor": { + "type": "object", + "properties": { + "actor_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "key": { + "type": "string" + }, + "namespace_id": { + "type": "string" + }, + "runner_name_selector": { + "type": "string" + }, + "create_ts": { + "type": "number" + }, + "connectable_ts": { + "type": "number", + "nullable": true + }, + "destroy_ts": { + "type": "number", + "nullable": true + }, + "sleep_ts": { + "type": "number", + "nullable": true + }, + "start_ts": { + "type": "number", + "nullable": true + } + }, + "required": [ + "actor_id", + "name", + "key", + "namespace_id", + "runner_name_selector", + "create_ts" + ] + } + }, + "required": [ + "actor" + ] } } } }, - "404": { - "description": "Actor does not have an onFetch handler" + "400": { + "description": "User error" }, "500": { - "description": "Internal server error or invalid response from actor" + "description": "Internal error" } } }, @@ -461,218 +207,137 @@ "parameters": [ { "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" + "type": "string" }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, - "responses": { - "200": { - "description": "Success - response from actor's onFetch handler", - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" - } - } - } - }, - "404": { - "description": "Actor does not have an onFetch handler" - }, - "500": { - "description": "Internal server error or invalid response from actor" - } - } - }, - "patch": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" + "required": true, + "name": "actor_id", + "in": "path" } ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, "responses": { "200": { - "description": "Success - response from actor's onFetch handler", - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" - } - } - } - }, - "404": { - "description": "Actor does not have an onFetch handler" - }, - "500": { - "description": "Internal server error or invalid response from actor" - } - } - }, - "head": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], - "requestBody": { - "content": { - "*/*": { - "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" - } - } - } - }, - "responses": { - "200": { - "description": "Success - response from actor's onFetch handler", + "description": "Success", "content": { - "*/*": { + "application/json": { "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" + "type": "object", + "properties": {} } } } }, - "404": { - "description": "Actor does not have an onFetch handler" + "400": { + "description": "User error" }, "500": { - "description": "Internal server error or invalid response from actor" + "description": "Internal error" } } - }, - "options": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "Actor query information" - }, - "required": false, - "name": "X-RivetKit-Query", - "in": "header" - }, - { - "schema": { - "type": "string", - "description": "Connection parameters" - }, - "required": false, - "name": "X-RivetKit-Conn-Params", - "in": "header" - } - ], + } + }, + "/actors": { + "post": { "requestBody": { "content": { - "*/*": { + "application/json": { "schema": { - "nullable": true, - "description": "Raw request body (can be any content type)" + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "runner_name_selector": { + "type": "string" + }, + "crash_policy": { + "type": "string" + }, + "key": { + "type": "string", + "nullable": true + }, + "input": { + "type": "string", + "nullable": true + } + }, + "required": [ + "name", + "runner_name_selector", + "crash_policy" + ] } } } }, "responses": { "200": { - "description": "Success - response from actor's onFetch handler", + "description": "Success", "content": { - "*/*": { + "application/json": { "schema": { - "nullable": true, - "description": "Raw response from actor's onFetch handler" + "type": "object", + "properties": { + "actor": { + "type": "object", + "properties": { + "actor_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "key": { + "type": "string" + }, + "namespace_id": { + "type": "string" + }, + "runner_name_selector": { + "type": "string" + }, + "create_ts": { + "type": "number" + }, + "connectable_ts": { + "type": "number", + "nullable": true + }, + "destroy_ts": { + "type": "number", + "nullable": true + }, + "sleep_ts": { + "type": "number", + "nullable": true + }, + "start_ts": { + "type": "number", + "nullable": true + } + }, + "required": [ + "actor_id", + "name", + "key", + "namespace_id", + "runner_name_selector", + "create_ts" + ] + } + }, + "required": [ + "actor" + ] } } } }, - "404": { - "description": "Actor does not have an onFetch handler" - }, - "500": { - "description": "Internal server error or invalid response from actor" - } - } - } - }, - "/actors/raw/websocket/*": { - "get": { - "responses": { - "101": { - "description": "WebSocket upgrade successful" - }, "400": { - "description": "WebSockets not enabled or invalid request" + "description": "User error" }, - "404": { - "description": "Actor does not have an onWebSocket handler" + "500": { + "description": "Internal error" } } } diff --git a/examples/ai-agent/src/backend/registry.ts b/examples/ai-agent/src/backend/registry.ts index 8fc133614..baf8ce89a 100644 --- a/examples/ai-agent/src/backend/registry.ts +++ b/examples/ai-agent/src/backend/registry.ts @@ -11,7 +11,6 @@ export type Message = { }; export const aiAgent = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { messages: [] as Message[], diff --git a/examples/better-auth-external-db/src/backend/registry.ts b/examples/better-auth-external-db/src/backend/registry.ts index 1eb8514c4..6fabaa335 100644 --- a/examples/better-auth-external-db/src/backend/registry.ts +++ b/examples/better-auth-external-db/src/backend/registry.ts @@ -15,20 +15,6 @@ interface Message { } export const chatRoom = actor({ - // onAuth runs on the server & before connecting to the actor - onAuth: async (opts: OnAuthOptions) => { - // Access Better Auth session - const authResult = await auth.api.getSession({ - headers: opts.request.headers, - }); - if (!authResult) throw new Unauthorized(); - - // Passes auth data to the actor (c.conn.auth) - return { - user: authResult.user, - session: authResult.session, - }; - }, state: { messages: [], } as State, diff --git a/examples/chat-room/src/backend/registry.ts b/examples/chat-room/src/backend/registry.ts index 88f678fed..635be8257 100644 --- a/examples/chat-room/src/backend/registry.ts +++ b/examples/chat-room/src/backend/registry.ts @@ -3,7 +3,6 @@ import { actor, setup } from "rivetkit"; export type Message = { sender: string; text: string; timestamp: number }; export const chatRoom = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { messages: [] as Message[], diff --git a/examples/cloudflare-workers-hono/src/registry.ts b/examples/cloudflare-workers-hono/src/registry.ts index aab6f2e02..3e3d61c37 100644 --- a/examples/cloudflare-workers-hono/src/registry.ts +++ b/examples/cloudflare-workers-hono/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/cloudflare-workers/src/registry.ts b/examples/cloudflare-workers/src/registry.ts index 0ddb65983..eb5c0be87 100644 --- a/examples/cloudflare-workers/src/registry.ts +++ b/examples/cloudflare-workers/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0, connectionCount: 0, messageCount: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/counter/scripts/connect.ts b/examples/counter/scripts/connect.ts index ef1a24d32..251eb0cb8 100644 --- a/examples/counter/scripts/connect.ts +++ b/examples/counter/scripts/connect.ts @@ -1,14 +1,28 @@ import { createClient } from "rivetkit/client"; import type { Registry } from "../src/registry"; -async function main() { - const client = createClient( - process.env.ENDPOINT ?? "http://127.0.0.1:8080", - ); +// async function main() { +// const client = createClient(); +// +// const counter = await client.counter.getOrCreate().connect(); +// +// counter.on("newCount", (count: number) => console.log("Event:", count)); +// +// for (let i = 0; i < 5; i++) { +// const out = await counter.increment(5); +// console.log("RPC:", out); +// +// await new Promise((resolve) => setTimeout(resolve, 1000)); +// } +// +// await new Promise((resolve) => setTimeout(resolve, 10000)); +// await counter.dispose(); +// } - const counter = await client.counter.getOrCreate().connect(); +async function main() { + const client = createClient(); - counter.on("newCount", (count: number) => console.log("Event:", count)); + const counter = await client.counter.getOrCreate(); for (let i = 0; i < 5; i++) { const out = await counter.increment(5); @@ -16,9 +30,6 @@ async function main() { await new Promise((resolve) => setTimeout(resolve, 1000)); } - - await new Promise((resolve) => setTimeout(resolve, 10000)); - await counter.dispose(); } main(); diff --git a/examples/counter/src/registry.ts b/examples/counter/src/registry.ts index c91b51244..44707e2fe 100644 --- a/examples/counter/src/registry.ts +++ b/examples/counter/src/registry.ts @@ -4,9 +4,6 @@ const counter = actor({ state: { count: 0, }, - onAuth: () => { - return true; - }, actions: { increment: (c, x: number) => { c.state.count += x; diff --git a/examples/crdt/src/backend/registry.ts b/examples/crdt/src/backend/registry.ts index 333293f08..80fa91843 100644 --- a/examples/crdt/src/backend/registry.ts +++ b/examples/crdt/src/backend/registry.ts @@ -3,7 +3,6 @@ import * as Y from "yjs"; import { applyUpdate, encodeStateAsUpdate } from "yjs"; export const yjsDocument = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { docData: "", // Base64 encoded Yjs document diff --git a/examples/database/src/backend/registry.ts b/examples/database/src/backend/registry.ts index 54afa682b..eed1d304b 100644 --- a/examples/database/src/backend/registry.ts +++ b/examples/database/src/backend/registry.ts @@ -4,7 +4,6 @@ import { authenticate } from "./my-utils"; export type Note = { id: string; content: string; updatedAt: number }; export const notes = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { notes: [] as Note[], diff --git a/examples/drizzle/src/registry.ts b/examples/drizzle/src/registry.ts index f4dce5845..cabe7f889 100644 --- a/examples/drizzle/src/registry.ts +++ b/examples/drizzle/src/registry.ts @@ -8,9 +8,6 @@ // state: { // count: 0, // }, -// onAuth: () => { -// // Configure auth here -// }, // actions: { // increment: (c, x: number) => { // // createState or state fix fix fix diff --git a/examples/elysia/src/registry.ts b/examples/elysia/src/registry.ts index aab6f2e02..3e3d61c37 100644 --- a/examples/elysia/src/registry.ts +++ b/examples/elysia/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/express/src/registry.ts b/examples/express/src/registry.ts index aab6f2e02..3e3d61c37 100644 --- a/examples/express/src/registry.ts +++ b/examples/express/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/game/src/backend/registry.ts b/examples/game/src/backend/registry.ts index 0c4c665bf..067300182 100644 --- a/examples/game/src/backend/registry.ts +++ b/examples/game/src/backend/registry.ts @@ -5,7 +5,6 @@ export type Input = { x: number; y: number }; export type Player = { id: string; position: Position; input: Input }; const gameRoom = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { players: {} as Record, diff --git a/examples/hono-react/src/backend/registry.ts b/examples/hono-react/src/backend/registry.ts index 95f78dff4..4afe732a3 100644 --- a/examples/hono-react/src/backend/registry.ts +++ b/examples/hono-react/src/backend/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/hono/src/registry.ts b/examples/hono/src/registry.ts index aab6f2e02..592707834 100644 --- a/examples/hono/src/registry.ts +++ b/examples/hono/src/registry.ts @@ -1,9 +1,7 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, + , state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/next-js/src/rivet/registry.ts b/examples/next-js/src/rivet/registry.ts index 95f78dff4..4afe732a3 100644 --- a/examples/next-js/src/rivet/registry.ts +++ b/examples/next-js/src/rivet/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/rate/src/backend/registry.ts b/examples/rate/src/backend/registry.ts index 180850c42..6567f0b69 100644 --- a/examples/rate/src/backend/registry.ts +++ b/examples/rate/src/backend/registry.ts @@ -7,7 +7,6 @@ export type RateLimitResult = { }; export const rateLimiter = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { count: 0, diff --git a/examples/raw-fetch-handler/src/backend/registry.ts b/examples/raw-fetch-handler/src/backend/registry.ts index 52469d754..761c1cc3d 100644 --- a/examples/raw-fetch-handler/src/backend/registry.ts +++ b/examples/raw-fetch-handler/src/backend/registry.ts @@ -5,10 +5,6 @@ export const counter = actor({ state: { count: 0, }, - onAuth: () => { - // Skip auth, make onFetch public - return {}; - }, createVars: () => { // Setup router return { router: createCounterRouter() }; diff --git a/examples/raw-websocket-handler-proxy/src/backend/registry.ts b/examples/raw-websocket-handler-proxy/src/backend/registry.ts index 8660241a5..e6741c395 100644 --- a/examples/raw-websocket-handler-proxy/src/backend/registry.ts +++ b/examples/raw-websocket-handler-proxy/src/backend/registry.ts @@ -8,10 +8,6 @@ export const chatRoom = actor({ timestamp: number; }>, }, - onAuth: () => { - // Skip auth, make WebSocket handler public - return {}; - }, createVars: () => { return { sockets: new Set(), diff --git a/examples/react/src/backend/registry.ts b/examples/react/src/backend/registry.ts index 95f78dff4..4afe732a3 100644 --- a/examples/react/src/backend/registry.ts +++ b/examples/react/src/backend/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/starter/src/registry.ts b/examples/starter/src/registry.ts index aab6f2e02..3e3d61c37 100644 --- a/examples/starter/src/registry.ts +++ b/examples/starter/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/examples/stream/src/backend/registry.ts b/examples/stream/src/backend/registry.ts index 91ca7785f..2af941db7 100644 --- a/examples/stream/src/backend/registry.ts +++ b/examples/stream/src/backend/registry.ts @@ -5,7 +5,6 @@ export type StreamState = { }; const streamProcessor = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { topValues: [] as number[], diff --git a/examples/sync/src/backend/registry.ts b/examples/sync/src/backend/registry.ts index 39f162913..d07410713 100644 --- a/examples/sync/src/backend/registry.ts +++ b/examples/sync/src/backend/registry.ts @@ -9,7 +9,6 @@ export type Contact = { }; const contacts = actor({ - onAuth: () => {}, // State is automatically persisted // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { diff --git a/examples/tenant/src/backend/registry.ts b/examples/tenant/src/backend/registry.ts index 3d8f5165f..3a836537d 100644 --- a/examples/tenant/src/backend/registry.ts +++ b/examples/tenant/src/backend/registry.ts @@ -21,7 +21,6 @@ export type ConnState = { }; const tenant = actor({ - onAuth: () => {}, // Persistent state that survives restarts: https://rivet.gg/docs/actors/state state: { orgId: "org-1", diff --git a/examples/trpc/src/registry.ts b/examples/trpc/src/registry.ts index aab6f2e02..3e3d61c37 100644 --- a/examples/trpc/src/registry.ts +++ b/examples/trpc/src/registry.ts @@ -1,9 +1,6 @@ import { actor, setup } from "rivetkit"; export const counter = actor({ - onAuth: () => { - // Configure auth here - }, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/packages/cloudflare-workers/src/actor-handler-do.ts b/packages/cloudflare-workers/src/actor-handler-do.ts index 892f10992..63ad249c6 100644 --- a/packages/cloudflare-workers/src/actor-handler-do.ts +++ b/packages/cloudflare-workers/src/actor-handler-do.ts @@ -2,11 +2,7 @@ import { DurableObject, env } from "cloudflare:workers"; import type { ExecutionContext } from "hono"; import invariant from "invariant"; import type { ActorKey, ActorRouter, Registry, RunConfig } from "rivetkit"; -import { - createActorRouter, - createClientWithDriver, - createInlineClientDriver, -} from "rivetkit"; +import { createActorRouter, createClientWithDriver } from "rivetkit"; import { serializeEmptyPersistData } from "rivetkit/driver-helpers"; import { CloudflareDurableObjectGlobalState, diff --git a/packages/cloudflare-workers/src/manager-driver.ts b/packages/cloudflare-workers/src/manager-driver.ts index 8be947e71..79e4cb9dd 100644 --- a/packages/cloudflare-workers/src/manager-driver.ts +++ b/packages/cloudflare-workers/src/manager-driver.ts @@ -1,5 +1,5 @@ import type { Context as HonoContext } from "hono"; -import type { Encoding } from "rivetkit"; +import type { Encoding, UniversalWebSocket } from "rivetkit"; import { type ActorOutput, type CreateInput, @@ -9,7 +9,6 @@ import { HEADER_AUTH_DATA, HEADER_CONN_PARAMS, HEADER_ENCODING, - HEADER_EXPOSE_INTERNAL_ERROR, type ManagerDisplayInformation, type ManagerDriver, } from "rivetkit/driver-helpers"; @@ -69,7 +68,7 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - ): Promise { + ): Promise { const env = getCloudflareAmbientEnv(); logger().debug({ @@ -85,7 +84,6 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { const headers: Record = { Upgrade: "websocket", Connection: "Upgrade", - [HEADER_EXPOSE_INTERNAL_ERROR]: "true", [HEADER_ENCODING]: encoding, }; if (params) { @@ -190,7 +188,6 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { } // Add RivetKit headers - actorRequest.headers.set(HEADER_EXPOSE_INTERNAL_ERROR, "true"); actorRequest.headers.set(HEADER_ENCODING, encoding); if (params) { actorRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(params)); diff --git a/packages/rivetkit/fixtures/driver-test-suite/action-inputs.ts b/packages/rivetkit/fixtures/driver-test-suite/action-inputs.ts index 98f085509..42e566457 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/action-inputs.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/action-inputs.ts @@ -7,7 +7,6 @@ export interface State { // Test actor that can capture input during creation export const inputActor = actor({ - onAuth: () => {}, createState: (c, input): State => { return { initialInput: input, diff --git a/packages/rivetkit/fixtures/driver-test-suite/action-timeout.ts b/packages/rivetkit/fixtures/driver-test-suite/action-timeout.ts index c66314f36..b1ee4b3cd 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/action-timeout.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/action-timeout.ts @@ -2,7 +2,6 @@ import { actor } from "rivetkit"; // Short timeout actor export const shortTimeoutActor = actor({ - onAuth: () => {}, state: { value: 0 }, options: { actionTimeout: 50, // 50ms timeout @@ -21,7 +20,6 @@ export const shortTimeoutActor = actor({ // Long timeout actor export const longTimeoutActor = actor({ - onAuth: () => {}, state: { value: 0 }, options: { actionTimeout: 200, // 200ms timeout @@ -37,7 +35,6 @@ export const longTimeoutActor = actor({ // Default timeout actor export const defaultTimeoutActor = actor({ - onAuth: () => {}, state: { value: 0 }, actions: { normalAction: async (c) => { @@ -49,7 +46,6 @@ export const defaultTimeoutActor = actor({ // Sync actor (timeout shouldn't apply) export const syncTimeoutActor = actor({ - onAuth: () => {}, state: { value: 0 }, options: { actionTimeout: 50, // 50ms timeout diff --git a/packages/rivetkit/fixtures/driver-test-suite/action-types.ts b/packages/rivetkit/fixtures/driver-test-suite/action-types.ts index d343e9995..ad0707971 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/action-types.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/action-types.ts @@ -2,7 +2,6 @@ import { actor, UserError } from "rivetkit"; // Actor with synchronous actions export const syncActionActor = actor({ - onAuth: () => {}, state: { value: 0 }, actions: { // Simple synchronous action that returns a value directly @@ -26,7 +25,6 @@ export const syncActionActor = actor({ // Actor with asynchronous actions export const asyncActionActor = actor({ - onAuth: () => {}, state: { value: 0, data: null as any }, actions: { // Async action with a delay @@ -59,7 +57,6 @@ export const asyncActionActor = actor({ // Actor with promise actions export const promiseActor = actor({ - onAuth: () => {}, state: { results: [] as string[] }, actions: { // Action that returns a resolved promise diff --git a/packages/rivetkit/fixtures/driver-test-suite/actor-onstatechange.ts b/packages/rivetkit/fixtures/driver-test-suite/actor-onstatechange.ts index f953270f3..1a51841a3 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/actor-onstatechange.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/actor-onstatechange.ts @@ -1,7 +1,6 @@ import { actor } from "rivetkit"; export const onStateChangeActor = actor({ - onAuth: () => {}, state: { value: 0, changeCount: 0, diff --git a/packages/rivetkit/fixtures/driver-test-suite/auth.ts b/packages/rivetkit/fixtures/driver-test-suite/auth.ts deleted file mode 100644 index aa02530ae..000000000 --- a/packages/rivetkit/fixtures/driver-test-suite/auth.ts +++ /dev/null @@ -1,103 +0,0 @@ -import { actor, UserError } from "rivetkit"; - -// Basic auth actor - requires API key -export const authActor = actor({ - state: { requests: 0 }, - onAuth: (opts, params: { apiKey?: string } | undefined) => { - const apiKey = params?.apiKey; - if (!apiKey) { - throw new UserError("API key required", { code: "missing_auth" }); - } - - if (apiKey !== "valid-api-key") { - throw new UserError("Invalid API key", { code: "invalid_auth" }); - } - - return { userId: "user123", token: apiKey }; - }, - actions: { - getRequests: (c) => { - c.state.requests++; - return c.state.requests; - }, - getUserAuth: (c) => c.conn.auth, - }, -}); - -// Intent-specific auth actor - checks different permissions for different intents -export const intentAuthActor = actor({ - state: { value: 0 }, - onAuth: ({ request, intents }, params: { role: string }) => { - console.log("intents", intents, params); - const role = params.role; - - if (intents.has("create") && role !== "admin") { - throw new UserError("Admin role required for create operations", { - code: "insufficient_permissions", - }); - } - - if (intents.has("action") && !["admin", "user"].includes(role || "")) { - throw new UserError("User or admin role required for actions", { - code: "insufficient_permissions", - }); - } - - return { role, timestamp: Date.now() }; - }, - actions: { - getValue: (c) => c.state.value, - setValue: (c, value: number) => { - c.state.value = value; - return value; - }, - getAuth: (c) => c.conn.auth, - }, -}); - -// Public actor - empty onAuth to allow public access -export const publicActor = actor({ - state: { visitors: 0 }, - onAuth: () => { - return null; // Allow public access - }, - actions: { - visit: (c) => { - c.state.visitors++; - return c.state.visitors; - }, - }, -}); - -// No auth actor - should fail when accessed publicly (no onAuth defined) -export const noAuthActor = actor({ - state: { value: 42 }, - actions: { - getValue: (c) => c.state.value, - }, -}); - -// Async auth actor - tests promise-based authentication -export const asyncAuthActor = actor({ - state: { count: 0 }, - onAuth: async (opts, params: { token?: string } | undefined) => { - const token = params?.token; - if (!token) { - throw new UserError("Token required", { code: "missing_token" }); - } - - // Simulate token validation - if (token === "invalid") { - throw new UserError("Token is invalid", { code: "invalid_token" }); - } - - return { userId: `user-${token}`, validated: true }; - }, - actions: { - increment: (c) => { - c.state.count++; - return c.state.count; - }, - getAuthData: (c) => c.conn.auth, - }, -}); diff --git a/packages/rivetkit/fixtures/driver-test-suite/conn-liveness.ts b/packages/rivetkit/fixtures/driver-test-suite/conn-liveness.ts index 9e730e468..6dc77ba02 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/conn-liveness.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/conn-liveness.ts @@ -1,7 +1,6 @@ import { actor, CONNECTION_DRIVER_WEBSOCKET } from "rivetkit"; export const connLivenessActor = actor({ - onAuth: () => {}, state: { counter: 0, acceptingConnections: true, diff --git a/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts b/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts index f89507534..422b50887 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts @@ -1,7 +1,6 @@ import { actor } from "rivetkit"; export const counterWithParams = actor({ - onAuth: () => {}, state: { count: 0, initializers: [] as string[] }, createConnState: (c, opts, params: { name?: string }) => { return { diff --git a/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts b/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts index 723739cc7..668d6f6a7 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts @@ -8,7 +8,6 @@ export type ConnState = { }; export const connStateActor = actor({ - onAuth: () => {}, state: { sharedCounter: 0, disconnectionCount: 0, diff --git a/packages/rivetkit/fixtures/driver-test-suite/counter.ts b/packages/rivetkit/fixtures/driver-test-suite/counter.ts index 3ee625ca3..fd653c007 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/counter.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/counter.ts @@ -1,7 +1,6 @@ import { actor } from "rivetkit"; export const counter = actor({ - onAuth: () => {}, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/packages/rivetkit/fixtures/driver-test-suite/error-handling.ts b/packages/rivetkit/fixtures/driver-test-suite/error-handling.ts index 8f67b7ed3..e1501a0a0 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/error-handling.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/error-handling.ts @@ -1,7 +1,6 @@ import { actor, UserError } from "rivetkit"; export const errorHandlingActor = actor({ - onAuth: () => {}, state: { errorLog: [] as string[], }, diff --git a/packages/rivetkit/fixtures/driver-test-suite/inline-client.ts b/packages/rivetkit/fixtures/driver-test-suite/inline-client.ts index f6c96aabc..596eb735b 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/inline-client.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/inline-client.ts @@ -2,7 +2,6 @@ import { actor } from "rivetkit"; import type { registry } from "./registry"; export const inlineClientActor = actor({ - onAuth: () => {}, state: { messages: [] as string[] }, actions: { // Action that uses client to call another actor (stateless) diff --git a/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts b/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts index d146e683e..0cdcf059c 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts @@ -3,7 +3,6 @@ import { actor } from "rivetkit"; type ConnParams = { trackLifecycle?: boolean } | undefined; export const counterWithLifecycle = actor({ - onAuth: () => {}, state: { count: 0, events: [] as string[], diff --git a/packages/rivetkit/fixtures/driver-test-suite/metadata.ts b/packages/rivetkit/fixtures/driver-test-suite/metadata.ts index 552731572..0330f43f8 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/metadata.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/metadata.ts @@ -3,7 +3,6 @@ import { actor } from "rivetkit"; // Note: For testing only - metadata API will need to be mocked // in tests since this is implementation-specific export const metadataActor = actor({ - onAuth: () => {}, state: { lastMetadata: null as any, actorName: "", diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-http-auth.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-http-auth.ts index 3c260b13b..f2327295d 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-http-auth.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-http-auth.ts @@ -5,22 +5,19 @@ export const rawHttpAuthActor = actor({ state: { requestCount: 0, }, - onAuth: (opts, params: { apiKey?: string }) => { - const apiKey = params.apiKey; - if (!apiKey) { - throw new UserError("API key required", { code: "missing_auth" }); - } - - if (apiKey !== "valid-api-key") { - throw new UserError("Invalid API key", { code: "invalid_auth" }); - } - - return { userId: "user123", token: apiKey }; - }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + // onAuth: (opts, params: { apiKey?: string }) => { + // const apiKey = params.apiKey; + // if (!apiKey) { + // throw new UserError("API key required", { code: "missing_auth" }); + // } + // + // if (apiKey !== "valid-api-key") { + // throw new UserError("Invalid API key", { code: "invalid_auth" }); + // } + // + // return { userId: "user123", token: apiKey }; + // }, + onFetch(ctx: ActorContext, request: Request) { const url = new URL(request.url); ctx.state.requestCount++; @@ -67,10 +64,7 @@ export const rawHttpNoAuthActor = actor({ state: { value: 42, }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { return new Response( JSON.stringify({ value: ctx.state.value, @@ -92,13 +86,7 @@ export const rawHttpPublicActor = actor({ state: { visitors: 0, }, - onAuth: () => { - return null; // Allow public access - }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { ctx.state.visitors++; return new Response( JSON.stringify({ @@ -123,14 +111,7 @@ export const rawHttpCustomAuthActor = actor({ authorized: 0, unauthorized: 0, }, - onAuth: () => { - // Allow all connections - auth will be handled in onFetch - return {}; - }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { // Custom auth check in onFetch const authHeader = request.headers.get("Authorization"); diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts index 068bd60e9..69895e32e 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts @@ -1,15 +1,8 @@ import { type ActorContext, actor } from "rivetkit"; export const rawHttpRequestPropertiesActor = actor({ - onAuth() { - // Allow public access - empty onAuth - return {}; - }, actions: {}, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { // Extract all relevant Request properties const url = new URL(request.url); const method = request.method; diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts index 51f2b0ccd..33b2bdf9a 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts @@ -5,14 +5,7 @@ export const rawHttpActor = actor({ state: { requestCount: 0, }, - onAuth() { - // Allow public access - empty onAuth - return {}; - }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { const url = new URL(request.url); const method = request.method; @@ -57,19 +50,10 @@ export const rawHttpActor = actor({ }); export const rawHttpNoHandlerActor = actor({ - // No onFetch handler - all requests should return 404 - onAuth() { - // Allow public access - empty onAuth - return {}; - }, actions: {}, }); export const rawHttpVoidReturnActor = actor({ - onAuth() { - // Allow public access - empty onAuth - return {}; - }, onFetch(ctx, request) { // Intentionally return void to test error handling return undefined as any; @@ -78,10 +62,6 @@ export const rawHttpVoidReturnActor = actor({ }); export const rawHttpHonoActor = actor({ - onAuth() { - // Allow public access - return {}; - }, createVars() { const router = new Hono(); @@ -119,10 +99,7 @@ export const rawHttpHonoActor = actor({ // Return the router as a var return { router }; }, - onFetch( - ctx: ActorContext, - request: Request, - ) { + onFetch(ctx: ActorContext, request: Request) { // Use the Hono router from vars return ctx.vars.router.fetch(request); }, diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket-auth.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket-auth.ts index 92e62ba6b..613152851 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket-auth.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket-auth.ts @@ -11,18 +11,18 @@ export const rawWebSocketAuthActor = actor({ connectionCount: 0, messageCount: 0, }, - onAuth: (opts, params: { apiKey?: string }) => { - const apiKey = params.apiKey; - if (!apiKey) { - throw new UserError("API key required", { code: "missing_auth" }); - } - - if (apiKey !== "valid-api-key") { - throw new UserError("Invalid API key", { code: "invalid_auth" }); - } - - return { userId: "user123", token: apiKey }; - }, + // onAuth: (opts, params: { apiKey?: string }) => { + // const apiKey = params.apiKey; + // if (!apiKey) { + // throw new UserError("API key required", { code: "missing_auth" }); + // } + // + // if (apiKey !== "valid-api-key") { + // throw new UserError("Invalid API key", { code: "invalid_auth" }); + // } + // + // return { userId: "user123", token: apiKey }; + // }, onWebSocket(ctx, websocket) { ctx.state.connectionCount++; @@ -104,9 +104,6 @@ export const rawWebSocketPublicActor = actor({ state: { visitors: 0, }, - onAuth: () => { - return null; // Allow public access - }, onWebSocket(ctx, websocket) { ctx.state.visitors++; @@ -136,10 +133,6 @@ export const rawWebSocketCustomAuthActor = actor({ authorized: 0, unauthorized: 0, }, - onAuth: () => { - // Allow all connections - auth will be handled in onWebSocket - return {}; - }, onWebSocket(ctx, websocket, opts) { // Check for auth token in URL or headers const url = new URL(opts.request.url); diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts index 7d4e799a5..7194c35a8 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts @@ -5,10 +5,6 @@ export const rawWebSocketActor = actor({ connectionCount: 0, messageCount: 0, }, - onAuth(params) { - // Allow all connections and pass through connection params - return { connParams: params }; - }, onWebSocket(ctx, websocket, opts) { ctx.state.connectionCount = ctx.state.connectionCount + 1; console.log(`[ACTOR] New connection, count: ${ctx.state.connectionCount}`); @@ -105,10 +101,6 @@ export const rawWebSocketActor = actor({ }); export const rawWebSocketBinaryActor = actor({ - onAuth() { - // Allow all connections - return {}; - }, onWebSocket(ctx, websocket, opts) { // Handle binary data websocket.addEventListener("message", (event: any) => { diff --git a/packages/rivetkit/fixtures/driver-test-suite/registry.ts b/packages/rivetkit/fixtures/driver-test-suite/registry.ts index bffd29b2f..a423f1e6a 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/registry.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/registry.ts @@ -13,13 +13,6 @@ import { syncActionActor, } from "./action-types"; import { onStateChangeActor } from "./actor-onstatechange"; -import { - asyncAuthActor, - authActor, - intentAuthActor, - noAuthActor, - publicActor, -} from "./auth"; import { connLivenessActor } from "./conn-liveness"; import { counterWithParams } from "./conn-params"; import { connStateActor } from "./conn-state"; @@ -50,7 +43,6 @@ import { rawWebSocketPublicActor, } from "./raw-websocket-auth"; import { requestAccessActor } from "./request-access"; -import { requestAccessAuthActor } from "./request-access-auth"; import { scheduled } from "./scheduled"; import { sleep, @@ -112,12 +104,6 @@ export const registry = setup({ dynamicVarActor, uniqueVarActor, driverCtxActor, - // From auth.ts - authActor, - intentAuthActor, - publicActor, - noAuthActor, - asyncAuthActor, // From raw-http.ts rawHttpActor, rawHttpNoHandlerActor, @@ -140,7 +126,6 @@ export const registry = setup({ rawWebSocketCustomAuthActor, // From request-access.ts requestAccessActor, - requestAccessAuthActor, // From actor-onstatechange.ts onStateChangeActor, }, diff --git a/packages/rivetkit/fixtures/driver-test-suite/request-access-auth.ts b/packages/rivetkit/fixtures/driver-test-suite/request-access-auth.ts deleted file mode 100644 index cc0ec91f2..000000000 --- a/packages/rivetkit/fixtures/driver-test-suite/request-access-auth.ts +++ /dev/null @@ -1,48 +0,0 @@ -import { actor } from "rivetkit"; - -/** - * Test fixture to verify request object access in onAuth hook - * onAuth runs on the HTTP server, not in the actor, so we test it separately - */ -export const requestAccessAuthActor = actor({ - onAuth: ({ request, intents }, params: { trackRequest?: boolean }) => { - if (params?.trackRequest) { - // Extract request info and return it as auth data - const headers: Record = {}; - request.headers.forEach((value, key) => { - headers[key] = value; - }); - - return { - hasRequest: true, - requestUrl: request.url, - requestMethod: request.method, - requestHeaders: headers, - intents: Array.from(intents), - }; - } - - // Return empty auth data when not tracking - return {}; - }, - state: { - authData: null as any, - }, - onConnect: (c, conn) => { - // Store auth data in state so we can retrieve it - c.state.authData = conn.auth; - }, - actions: { - getAuthRequestInfo: (c) => { - // Return the stored auth data or a default object - const authData = c.state.authData || { - hasRequest: false, - requestUrl: null, - requestMethod: null, - requestHeaders: {}, - intents: [], - }; - return authData; - }, - }, -}); diff --git a/packages/rivetkit/fixtures/driver-test-suite/request-access.ts b/packages/rivetkit/fixtures/driver-test-suite/request-access.ts index a43166ab3..24ff8356c 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/request-access.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/request-access.ts @@ -4,7 +4,6 @@ import { actor } from "rivetkit"; * Test fixture to verify request object access in all lifecycle hooks */ export const requestAccessActor = actor({ - onAuth: () => {}, // Allow unauthenticated connections state: { // Track request info from different hooks onBeforeConnectRequest: { diff --git a/packages/rivetkit/fixtures/driver-test-suite/scheduled.ts b/packages/rivetkit/fixtures/driver-test-suite/scheduled.ts index 7a12f551d..7bac35bac 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/scheduled.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/scheduled.ts @@ -1,7 +1,6 @@ import { actor } from "rivetkit"; export const scheduled = actor({ - onAuth: () => {}, state: { lastRun: 0, scheduledCount: 0, diff --git a/packages/rivetkit/fixtures/driver-test-suite/sleep.ts b/packages/rivetkit/fixtures/driver-test-suite/sleep.ts index ef458478f..29f85422b 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/sleep.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/sleep.ts @@ -3,7 +3,6 @@ import { actor, type UniversalWebSocket } from "rivetkit"; export const SLEEP_TIMEOUT = 500; export const sleep = actor({ - onAuth: () => {}, state: { startCount: 0, sleepCount: 0 }, onStart: (c) => { c.state.startCount += 1; @@ -31,7 +30,6 @@ export const sleep = actor({ }); export const sleepWithLongRpc = actor({ - onAuth: () => {}, state: { startCount: 0, sleepCount: 0 }, createVars: () => ({}) as { longRunningResolve: PromiseWithResolvers }, onStart: (c) => { @@ -59,7 +57,6 @@ export const sleepWithLongRpc = actor({ }); export const sleepWithRawHttp = actor({ - onAuth: () => {}, state: { startCount: 0, sleepCount: 0, requestCount: 0 }, onStart: (c) => { c.state.startCount += 1; @@ -98,7 +95,6 @@ export const sleepWithRawHttp = actor({ }); export const sleepWithRawWebSocket = actor({ - onAuth: () => {}, state: { startCount: 0, sleepCount: 0, connectionCount: 0 }, onStart: (c) => { c.state.startCount += 1; @@ -168,7 +164,6 @@ export const sleepWithRawWebSocket = actor({ }); export const sleepWithNoSleepOption = actor({ - onAuth: () => {}, state: { startCount: 0, sleepCount: 0 }, onStart: (c) => { c.state.startCount += 1; diff --git a/packages/rivetkit/fixtures/driver-test-suite/vars.ts b/packages/rivetkit/fixtures/driver-test-suite/vars.ts index dac3ec834..7a6231982 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/vars.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/vars.ts @@ -2,7 +2,6 @@ import { actor } from "rivetkit"; // Actor with static vars export const staticVarActor = actor({ - onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, vars: { counter: 42, name: "test-actor" }, @@ -18,7 +17,6 @@ export const staticVarActor = actor({ // Actor with nested vars export const nestedVarActor = actor({ - onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, vars: { @@ -45,7 +43,6 @@ export const nestedVarActor = actor({ // Actor with dynamic vars export const dynamicVarActor = actor({ - onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: () => { @@ -63,7 +60,6 @@ export const dynamicVarActor = actor({ // Actor with unique vars per instance export const uniqueVarActor = actor({ - onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: () => { @@ -80,7 +76,6 @@ export const uniqueVarActor = actor({ // Actor that uses driver context export const driverCtxActor = actor({ - onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: (c, driverCtx: any) => { diff --git a/packages/rivetkit/schemas/actor-persist/v1.bare b/packages/rivetkit/schemas/actor-persist/v1.bare index 3c50f950a..5725bc135 100644 --- a/packages/rivetkit/schemas/actor-persist/v1.bare +++ b/packages/rivetkit/schemas/actor-persist/v1.bare @@ -19,8 +19,6 @@ type PersistedConnection struct { parameters: data # Connection state state: data - # Authentication data - auth: optional # Active subscriptions subscriptions: list # Last seen timestamp diff --git a/packages/rivetkit/schemas/client-protocol/v1.bare b/packages/rivetkit/schemas/client-protocol/v1.bare index c3f9a0171..35423ef87 100644 --- a/packages/rivetkit/schemas/client-protocol/v1.bare +++ b/packages/rivetkit/schemas/client-protocol/v1.bare @@ -6,6 +6,7 @@ type Init struct { } type Error struct { + group: str code: str message: str metadata: optional @@ -68,6 +69,7 @@ type HttpActionResponse struct { # MARK: HTTP Error type HttpResponseError struct { + group: str code: str message: str metadata: optional diff --git a/packages/rivetkit/scripts/dump-openapi.ts b/packages/rivetkit/scripts/dump-openapi.ts index 9962e818a..b582e50d6 100644 --- a/packages/rivetkit/scripts/dump-openapi.ts +++ b/packages/rivetkit/scripts/dump-openapi.ts @@ -1,6 +1,5 @@ import * as fs from "node:fs/promises"; import { resolve } from "node:path"; -import type { ClientDriver } from "@/client/client"; import { createFileSystemOrMemoryDriver } from "@/drivers/file-system/mod"; import type { ManagerDriver } from "@/manager/driver"; import { createManagerRouter } from "@/manager/router"; @@ -19,16 +18,6 @@ function main() { getUpgradeWebSocket: () => () => unimplemented(), }); - const inlineClientDriver: ClientDriver = { - action: unimplemented, - resolveActorId: unimplemented, - connectWebSocket: unimplemented, - connectSse: unimplemented, - sendHttpMessage: unimplemented, - rawHttpRequest: unimplemented, - rawWebSocket: unimplemented, - }; - const managerDriver: ManagerDriver = { getForId: unimplemented, getWithKey: unimplemented, @@ -44,7 +33,6 @@ function main() { const { openapi } = createManagerRouter( registryConfig, driverConfig, - inlineClientDriver, managerDriver, true, ); diff --git a/packages/rivetkit/src/actor/action.ts b/packages/rivetkit/src/actor/action.ts index 1104a0305..a3b50ecb7 100644 --- a/packages/rivetkit/src/actor/action.ts +++ b/packages/rivetkit/src/actor/action.ts @@ -19,7 +19,6 @@ export class ActionContext< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, > { #actorContext: ActorContext< @@ -28,7 +27,6 @@ export class ActionContext< TConnState, TVars, TInput, - TAuthData, TDatabase >; @@ -45,7 +43,6 @@ export class ActionContext< TConnState, TVars, TInput, - TAuthData, TDatabase >, public readonly conn: Conn< @@ -54,7 +51,6 @@ export class ActionContext< TConnState, TVars, TInput, - TAuthData, TDatabase >, ) { @@ -129,7 +125,7 @@ export class ActionContext< */ get conns(): Map< ConnId, - Conn + Conn > { return this.#actorContext.conns; } diff --git a/packages/rivetkit/src/actor/config.ts b/packages/rivetkit/src/actor/config.ts index 8c91d09e8..ccdbdca4b 100644 --- a/packages/rivetkit/src/actor/config.ts +++ b/packages/rivetkit/src/actor/config.ts @@ -11,7 +11,6 @@ export type InitContext = ActorContext< undefined, undefined, undefined, - undefined, undefined >; @@ -21,7 +20,6 @@ export interface ActorTypes< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, > { state?: TState; @@ -29,7 +27,6 @@ export interface ActorTypes< connState?: TConnState; vars?: TVars; input?: TInput; - authData?: TAuthData; database?: TDatabase; } @@ -40,7 +37,6 @@ export interface ActorTypes< // (b) it makes the type definitions incredibly difficult to read as opposed to vanilla TypeScript. export const ActorConfigSchema = z .object({ - onAuth: z.function().optional(), onCreate: z.function().optional(), onStart: z.function().optional(), onStop: z.function().optional(), @@ -116,15 +112,7 @@ export interface OnConnectOptions { // This must have only one or the other or else TState will not be able to be inferred // // Data returned from this handler will be available on `c.state`. -type CreateState< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase, -> = +type CreateState = | { state: TState } | { createState: (c: InitContext, input: TInput) => TState | Promise; @@ -142,7 +130,6 @@ type CreateConnState< TConnState, TVars, TInput, - TAuthData, TDatabase, > = | { connState: TConnState } @@ -161,15 +148,7 @@ type CreateConnState< /** * @experimental */ -type CreateVars< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase, -> = +type CreateVars = | { /** * @experimental @@ -184,62 +163,16 @@ type CreateVars< } | Record; -// Creates auth config -// -// This must have only one or the other or else TAuthData will not be able to be inferred -type OnAuth = - | { - /** - * Called on the HTTP server before clients can interact with the actor. - * - * Only called for public endpoints. Calls to actors from within the backend - * do not trigger this handler. - * - * Data returned from this handler will be available on `c.conn.auth`. - * - * This function is required for any public HTTP endpoint access. Use this hook - * to validate client credentials and return authentication data that will be - * available on connections. This runs on the HTTP server (not the actor) - * in order to reduce load on the actor & prevent denial of server attacks - * against individual actors. - * - * If you need access to actor state for authentication, use onBeforeConnect - * with an empty onAuth function instead. - * - * You can also provide your own authentication middleware on your router if you - * choose, then use onAuth to pass the authentication data (e.g. user ID) to the - * actor itself. - * - * @param opts Authentication options including request and intent - * @returns Authentication data to attach to connections (must be serializable) - * @throws Throw an error to deny access to the actor - */ - onAuth: ( - opts: OnAuthOptions, - params: TConnParams, - ) => TAuthData | Promise; - } - | Record; - export interface Actions< TState, TConnParams, TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, > { [Action: string]: ( - c: ActionContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActionContext, ...args: any[] ) => any; } @@ -254,21 +187,12 @@ export interface Actions< */ export type AuthIntent = "get" | "create" | "connect" | "action" | "message"; -export interface OnAuthOptions { - request: Request; - /** - * @experimental - */ - intents: Set; -} - interface BaseActorConfig< TState, TConnParams, TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, TActions extends Actions< TState, @@ -276,7 +200,6 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TAuthData, TDatabase >, > { @@ -287,15 +210,7 @@ interface BaseActorConfig< * This is called before any other lifecycle hooks. */ onCreate?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, input: TInput, ) => void | Promise; @@ -308,15 +223,7 @@ interface BaseActorConfig< * @returns Void or a Promise that resolves when startup is complete */ onStart?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, ) => void | Promise; /** @@ -330,15 +237,7 @@ interface BaseActorConfig< * @returns Void or a Promise that resolves when shutdown is complete */ onStop?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, ) => void | Promise; /** @@ -353,48 +252,22 @@ interface BaseActorConfig< * @param newState The updated state */ onStateChange?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, newState: TState, ) => void; /** * Called before a client connects to the actor. * - * Unlike onAuth, this handler is still called for both internal and - * public clients. - * * Use this hook to determine if a connection should be accepted - * and to initialize connection-specific state. Unlike onAuth, this runs - * on the actor and has access to actor state, but uses slightly - * more resources on the actor rather than authenticating with onAuth. - * - * For authentication without actor state access, prefer onAuth. - * - * For authentication with actor state, use onBeforeConnect with an empty - * onAuth handler. + * and to initialize connection-specific state. * * @param opts Connection parameters including client-provided data * @returns The initial connection state or a Promise that resolves to it * @throws Throw an error to reject the connection */ onBeforeConnect?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, opts: OnConnectOptions, params: TConnParams, ) => void | Promise; @@ -409,24 +282,8 @@ interface BaseActorConfig< * @returns Void or a Promise that resolves when connection handling is complete */ onConnect?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, - conn: Conn< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, + conn: Conn, ) => void | Promise; /** @@ -439,24 +296,8 @@ interface BaseActorConfig< * @returns Void or a Promise that resolves when disconnect handling is complete */ onDisconnect?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, - conn: Conn< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, + conn: Conn, ) => void | Promise; /** @@ -472,15 +313,7 @@ interface BaseActorConfig< * @returns The modified output to send to the client */ onBeforeActionResponse?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, name: string, args: unknown[], output: Out, @@ -496,17 +329,9 @@ interface BaseActorConfig< * @returns A Response object to send back, or void to continue with default routing */ onFetch?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, request: Request, - opts: { auth: TAuthData }, + opts: {}, ) => Response | Promise; /** @@ -519,17 +344,9 @@ interface BaseActorConfig< * @param request The original HTTP upgrade request */ onWebSocket?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >, + c: ActorContext, websocket: UniversalWebSocket, - opts: { request: Request; auth: TAuthData }, + opts: { request: Request }, ) => void | Promise; actions: TActions; @@ -553,12 +370,10 @@ export type ActorConfig< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, > = Omit< z.infer, | "actions" - | "onAuth" | "onCreate" | "onStart" | "onStateChange" @@ -582,46 +397,12 @@ export type ActorConfig< TConnState, TVars, TInput, - TAuthData, TDatabase, - Actions< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > - > & - OnAuth & - CreateState< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > & - CreateConnState< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > & - CreateVars< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase + Actions > & + CreateState & + CreateConnState & + CreateVars & ActorDatabaseConfig; // See description on `ActorConfig` @@ -631,7 +412,6 @@ export type ActorConfigInput< TConnState = undefined, TVars = undefined, TInput = undefined, - TAuthData = undefined, TDatabase extends AnyDatabaseProvider = undefined, TActions extends Actions< TState, @@ -639,23 +419,13 @@ export type ActorConfigInput< TConnState, TVars, TInput, - TAuthData, TDatabase > = Record, > = { - types?: ActorTypes< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - >; + types?: ActorTypes; } & Omit< z.input, | "actions" - | "onAuth" | "onCreate" | "onStart" | "onStop" @@ -680,38 +450,12 @@ export type ActorConfigInput< TConnState, TVars, TInput, - TAuthData, TDatabase, TActions > & - OnAuth & - CreateState< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > & - CreateConnState< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > & - CreateVars< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase - > & + CreateState & + CreateConnState & + CreateVars & ActorDatabaseConfig; // For testing type definitions: @@ -721,7 +465,6 @@ export function test< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, TActions extends Actions< TState, @@ -729,7 +472,6 @@ export function test< TConnState, TVars, TInput, - TAuthData, TDatabase >, >( @@ -739,26 +481,16 @@ export function test< TConnState, TVars, TInput, - TAuthData, TDatabase, TActions >, -): ActorConfig< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TAuthData, - TDatabase -> { +): ActorConfig { const config = ActorConfigSchema.parse(input) as ActorConfig< TState, TConnParams, TConnState, TVars, TInput, - TAuthData, TDatabase >; return config; diff --git a/packages/rivetkit/src/actor/connection.ts b/packages/rivetkit/src/actor/connection.ts index 6edecc050..8eacce71d 100644 --- a/packages/rivetkit/src/actor/connection.ts +++ b/packages/rivetkit/src/actor/connection.ts @@ -21,7 +21,7 @@ export function generateConnToken(): string { export type ConnId = string; -export type AnyConn = Conn; +export type AnyConn = Conn; export const CONNECTION_DRIVER_WEBSOCKET = "webSocket"; export const CONNECTION_DRIVER_SSE = "sse"; @@ -43,13 +43,13 @@ export const CONNECTION_CHECK_LIVENESS_SYMBOL = Symbol("checkLiveness"); * * @see {@link https://rivet.gg/docs/connections|Connection Documentation} */ -export class Conn { +export class Conn { subscriptions: Set = new Set(); #stateEnabled: boolean; // TODO: Remove this cyclical reference - #actor: ActorInstance; + #actor: ActorInstance; #status: ConnectionStatus = "connected"; @@ -71,10 +71,6 @@ export class Conn { return this.__persist.params; } - public get auth(): AD { - return this.__persist.authData as AD; - } - public get driver(): ConnectionDriver { return this.__persist.connDriver as ConnectionDriver; } @@ -140,7 +136,7 @@ export class Conn { * @protected */ public constructor( - actor: ActorInstance, + actor: ActorInstance, persist: PersistedConn, driver: ConnDriver, stateEnabled: boolean, diff --git a/packages/rivetkit/src/actor/context.ts b/packages/rivetkit/src/actor/context.ts index 1c447c472..d4d415b84 100644 --- a/packages/rivetkit/src/actor/context.ts +++ b/packages/rivetkit/src/actor/context.ts @@ -16,7 +16,6 @@ export class ActorContext< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, > { #actor: ActorInstance< @@ -25,7 +24,6 @@ export class ActorContext< TConnState, TVars, TInput, - TAuthData, TDatabase >; @@ -36,7 +34,6 @@ export class ActorContext< TConnState, TVars, TInput, - TAuthData, TDatabase >, ) { @@ -114,7 +111,7 @@ export class ActorContext< */ get conns(): Map< ConnId, - Conn + Conn > { return this.#actor.conns; } diff --git a/packages/rivetkit/src/actor/definition.ts b/packages/rivetkit/src/actor/definition.ts index 28670a952..7f54f376a 100644 --- a/packages/rivetkit/src/actor/definition.ts +++ b/packages/rivetkit/src/actor/definition.ts @@ -12,7 +12,6 @@ export type AnyActorDefinition = ActorDefinition< any, any, any, - any, any >; @@ -26,11 +25,10 @@ export type ActorContextOf = infer CS, infer V, infer I, - infer AD, infer DB, any > - ? ActorContext + ? ActorContext : never; /** @@ -43,11 +41,10 @@ export type ActionContextOf = infer CS, infer V, infer I, - infer AD, infer DB, any > - ? ActionContext + ? ActionContext : never; export class ActorDefinition< @@ -56,21 +53,20 @@ export class ActorDefinition< CS, V, I, - AD, DB extends AnyDatabaseProvider, - R extends Actions, + R extends Actions, > { - #config: ActorConfig; + #config: ActorConfig; - constructor(config: ActorConfig) { + constructor(config: ActorConfig) { this.#config = config; } - get config(): ActorConfig { + get config(): ActorConfig { return this.#config; } - instantiate(): ActorInstance { + instantiate(): ActorInstance { return new ActorInstance(this.#config); } } diff --git a/packages/rivetkit/src/actor/errors.ts b/packages/rivetkit/src/actor/errors.ts index a39f129ac..fd15ffaec 100644 --- a/packages/rivetkit/src/actor/errors.ts +++ b/packages/rivetkit/src/actor/errors.ts @@ -20,6 +20,8 @@ export class ActorError extends Error { public public: boolean; public metadata?: unknown; public statusCode = 500; + public readonly group: string; + public readonly code: string; public static isActorError( error: unknown, @@ -31,11 +33,14 @@ export class ActorError extends Error { } constructor( - public readonly code: string, + group: string, + code: string, message: string, opts?: ActorErrorOptions, ) { super(message, { cause: opts?.cause }); + this.group = group; + this.code = code; this.public = opts?.public ?? false; this.metadata = opts?.metadata; @@ -64,7 +69,7 @@ export class ActorError extends Error { export class InternalError extends ActorError { constructor(message: string) { - super(INTERNAL_ERROR_CODE, message); + super("actor", INTERNAL_ERROR_CODE, message); } } @@ -77,6 +82,7 @@ export class Unreachable extends InternalError { export class StateNotEnabled extends ActorError { constructor() { super( + "actor", "state_not_enabled", "State not enabled. Must implement `createState` or `state` to use state. (https://www.rivet.gg/docs/actors/state/#initializing-state)", ); @@ -86,6 +92,7 @@ export class StateNotEnabled extends ActorError { export class ConnStateNotEnabled extends ActorError { constructor() { super( + "actor", "conn_state_not_enabled", "Connection state not enabled. Must implement `createConnectionState` or `connectionState` to use connection state. (https://www.rivet.gg/docs/actors/connections/#connection-state)", ); @@ -95,6 +102,7 @@ export class ConnStateNotEnabled extends ActorError { export class VarsNotEnabled extends ActorError { constructor() { super( + "actor", "vars_not_enabled", "Variables not enabled. Must implement `createVars` or `vars` to use state. (https://www.rivet.gg/docs/actors/ephemeral-variables/#initializing-variables)", ); @@ -104,7 +112,8 @@ export class VarsNotEnabled extends ActorError { export class ActionTimedOut extends ActorError { constructor() { super( - "action_timed_out", + "action", + "timed_out", "Action timed out. This can be increased with: `actor({ options: { action: { timeout: ... } } })`", { public: true }, ); @@ -114,7 +123,8 @@ export class ActionTimedOut extends ActorError { export class ActionNotFound extends ActorError { constructor(name: string) { super( - "action_not_found", + "action", + "not_found", `Action '${name}' not found. Validate the action exists on your actor.`, { public: true }, ); @@ -124,7 +134,8 @@ export class ActionNotFound extends ActorError { export class InvalidEncoding extends ActorError { constructor(format?: string) { super( - "invalid_encoding", + "encoding", + "invalid", `Invalid encoding \`${format}\`. (https://www.rivet.gg/docs/actors/clients/#actor-client)`, { public: true, @@ -135,7 +146,7 @@ export class InvalidEncoding extends ActorError { export class ConnNotFound extends ActorError { constructor(id?: string) { - super("conn_not_found", `Connection not found for ID: ${id}`, { + super("connection", "not_found", `Connection not found for ID: ${id}`, { public: true, }); } @@ -143,7 +154,7 @@ export class ConnNotFound extends ActorError { export class IncorrectConnToken extends ActorError { constructor() { - super("incorrect_conn_token", "Incorrect connection token.", { + super("connection", "incorrect_token", "Incorrect connection token.", { public: true, }); } @@ -152,7 +163,8 @@ export class IncorrectConnToken extends ActorError { export class MessageTooLong extends ActorError { constructor() { super( - "message_too_long", + "message", + "too_long", "Message too long. This can be configured with: `registry.runServer({ maxIncomingMessageSize: ... })`", { public: true }, ); @@ -161,7 +173,7 @@ export class MessageTooLong extends ActorError { export class MalformedMessage extends ActorError { constructor(cause?: unknown) { - super("malformed_message", `Malformed message: ${cause}`, { + super("message", "malformed", `Malformed message: ${cause}`, { public: true, cause, }); @@ -182,13 +194,13 @@ export class InvalidStateType extends ActorError { } msg += " Valid types include: null, undefined, boolean, string, number, BigInt, Date, RegExp, Error, typed arrays (Uint8Array, Int8Array, Float32Array, etc.), Map, Set, Array, and plain objects. (https://www.rivet.gg/docs/actors/state/#limitations)"; - super("invalid_state_type", msg); + super("state", "invalid_type", msg); } } export class Unsupported extends ActorError { constructor(feature: string) { - super("unsupported", `Unsupported feature: ${feature}`); + super("feature", "unsupported", `Unsupported feature: ${feature}`); } } @@ -216,7 +228,7 @@ export class UserError extends ActorError { * @param opts - Optional parameters for the error, including a machine-readable code and additional metadata. */ constructor(message: string, opts?: UserErrorOptions) { - super(opts?.code ?? USER_ERROR_CODE, message, { + super("user", opts?.code ?? USER_ERROR_CODE, message, { public: true, metadata: opts?.metadata, }); @@ -225,7 +237,7 @@ export class UserError extends ActorError { export class InvalidQueryJSON extends ActorError { constructor(error?: unknown) { - super("invalid_query_json", `Invalid query JSON: ${error}`, { + super("request", "invalid_query_json", `Invalid query JSON: ${error}`, { public: true, cause: error, }); @@ -234,7 +246,7 @@ export class InvalidQueryJSON extends ActorError { export class InvalidRequest extends ActorError { constructor(error?: unknown) { - super("invalid_request", `Invalid request: ${error}`, { + super("request", "invalid", `Invalid request: ${error}`, { public: true, cause: error, }); @@ -244,7 +256,8 @@ export class InvalidRequest extends ActorError { export class ActorNotFound extends ActorError { constructor(identifier?: string) { super( - "actor_not_found", + "actor", + "not_found", identifier ? `Actor not found: ${identifier} (https://www.rivet.gg/docs/actors/clients/#actor-client)` : "Actor not found (https://www.rivet.gg/docs/actors/clients/#actor-client)", @@ -256,7 +269,8 @@ export class ActorNotFound extends ActorError { export class ActorAlreadyExists extends ActorError { constructor(name: string, key: string[]) { super( - "actor_already_exists", + "actor", + "already_exists", `Actor already exists with name '${name}' and key '${JSON.stringify(key)}' (https://www.rivet.gg/docs/actors/clients/#actor-client)`, { public: true }, ); @@ -266,7 +280,8 @@ export class ActorAlreadyExists extends ActorError { export class ProxyError extends ActorError { constructor(operation: string, error?: unknown) { super( - "proxy_error", + "proxy", + "error", `Error proxying ${operation}, this is likely an internal error: ${error}`, { public: true, @@ -278,19 +293,20 @@ export class ProxyError extends ActorError { export class InvalidActionRequest extends ActorError { constructor(message: string) { - super("invalid_action_request", message, { public: true }); + super("action", "invalid_request", message, { public: true }); } } export class InvalidParams extends ActorError { constructor(message: string) { - super("invalid_params", message, { public: true }); + super("params", "invalid", message, { public: true }); } } export class Unauthorized extends ActorError { constructor(message?: string) { super( + "auth", "unauthorized", message ?? "Unauthorized. Access denied. (https://www.rivet.gg/docs/actors/authentication/)", @@ -305,6 +321,7 @@ export class Unauthorized extends ActorError { export class Forbidden extends ActorError { constructor(message?: string, opts?: { metadata?: unknown }) { super( + "auth", "forbidden", message ?? "Forbidden. Access denied. (https://www.rivet.gg/docs/actors/authentication/)", @@ -320,7 +337,8 @@ export class Forbidden extends ActorError { export class DatabaseNotEnabled extends ActorError { constructor() { super( - "database_not_enabled", + "database", + "not_enabled", "Database not enabled. Must implement `database` to use database.", ); } @@ -329,7 +347,8 @@ export class DatabaseNotEnabled extends ActorError { export class FetchHandlerNotDefined extends ActorError { constructor() { super( - "fetch_handler_not_defined", + "handler", + "fetch_not_defined", "Raw HTTP handler not defined. Actor must implement `onFetch` to handle raw HTTP requests. (https://www.rivet.gg/docs/actors/fetch-and-websocket-handler/)", { public: true }, ); @@ -340,7 +359,8 @@ export class FetchHandlerNotDefined extends ActorError { export class WebSocketHandlerNotDefined extends ActorError { constructor() { super( - "websocket_handler_not_defined", + "handler", + "websocket_not_defined", "Raw WebSocket handler not defined. Actor must implement `onWebSocket` to handle raw WebSocket connections. (https://www.rivet.gg/docs/actors/fetch-and-websocket-handler/)", { public: true }, ); @@ -351,6 +371,7 @@ export class WebSocketHandlerNotDefined extends ActorError { export class InvalidFetchResponse extends ActorError { constructor() { super( + "handler", "invalid_fetch_response", "Actor's onFetch handler must return a Response object. Returning void/undefined is not allowed. (https://www.rivet.gg/docs/actors/fetch-and-websocket-handler/)", { public: true }, @@ -358,3 +379,44 @@ export class InvalidFetchResponse extends ActorError { this.statusCode = 500; } } + +// Manager-specific errors +export class MissingActorHeader extends ActorError { + constructor() { + super( + "request", + "missing_actor_header", + "Missing x-rivet-actor header when x-rivet-target=actor", + { public: true }, + ); + this.statusCode = 400; + } +} + +export class WebSocketsNotEnabled extends ActorError { + constructor() { + super( + "driver", + "websockets_not_enabled", + "WebSockets are not enabled for this driver", + { public: true }, + ); + this.statusCode = 400; + } +} + +export class FeatureNotImplemented extends ActorError { + constructor(feature: string) { + super("feature", "not_implemented", `${feature} is not implemented`, { + public: true, + }); + this.statusCode = 501; + } +} + +export class RouteNotFound extends ActorError { + constructor() { + super("route", "not_found", "Route not found", { public: true }); + this.statusCode = 404; + } +} diff --git a/packages/rivetkit/src/actor/instance.ts b/packages/rivetkit/src/actor/instance.ts index ce01fd968..ae90e88e9 100644 --- a/packages/rivetkit/src/actor/instance.ts +++ b/packages/rivetkit/src/actor/instance.ts @@ -6,7 +6,6 @@ import type { Client } from "@/client/client"; import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; import { isCborSerializable, stringifyError } from "@/common/utils"; import type { UniversalWebSocket } from "@/common/websocket-interface"; -import { serializeActorKey } from "@/drivers/engine/keys"; import { ActorInspector } from "@/inspector/actor"; import type { Registry } from "@/mod"; import type * as bareSchema from "@/schemas/actor-persist/mod"; @@ -30,6 +29,7 @@ import { ActorContext } from "./context"; import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; import type { ActorDriver, ConnDriver, ConnectionDriversMap } from "./driver"; import * as errors from "./errors"; +import { serializeActorKey } from "./keys"; import { loggerWithoutContext } from "./log"; import type { PersistedActor, @@ -66,8 +66,6 @@ export type AnyActorInstance = ActorInstance< // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` any >; @@ -83,8 +81,6 @@ export type ExtractActorState = // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? State @@ -102,8 +98,6 @@ export type ExtractActorConnParams = // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? ConnParams @@ -121,24 +115,14 @@ export type ExtractActorConnState = // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? ConnState : never; -export class ActorInstance< - S, - CP, - CS, - V, - I, - AD, - DB extends AnyDatabaseProvider, -> { +export class ActorInstance { // Shared actor context for this instance - actorContext: ActorContext; + actorContext: ActorContext; /** Actor log, intended for the user to call */ #log!: Logger; @@ -176,7 +160,7 @@ export class ActorInstance< #backgroundPromises: Promise[] = []; #abortController = new AbortController(); - #config: ActorConfig; + #config: ActorConfig; #connectionDrivers!: ConnectionDriversMap; #actorDriver!: ActorDriver; #inlineClient!: Client>; @@ -186,8 +170,8 @@ export class ActorInstance< #region!: string; #ready = false; - #connections = new Map>(); - #subscriptionIndex = new Map>>(); + #connections = new Map>(); + #subscriptionIndex = new Map>>(); #checkConnLivenessInterval?: NodeJS.Timeout; #sleepTimeout?: NodeJS.Timeout; @@ -225,7 +209,6 @@ export class ActorInstance< stateEnabled: conn._stateEnabled, params: conn.params as {}, state: conn._stateEnabled ? conn.state : undefined, - auth: conn.auth as {}, })); }, setState: async (state: unknown) => { @@ -265,7 +248,7 @@ export class ActorInstance< * * @private */ - constructor(config: ActorConfig) { + constructor(config: ActorConfig) { this.#config = config; this.actorContext = new ActorContext(this); } @@ -320,7 +303,6 @@ export class ActorInstance< undefined, undefined, undefined, - undefined, any >, this.#actorDriver.getContext(this.#actorId), @@ -738,7 +720,7 @@ export class ActorInstance< for (const connPersist of this.#persist.connections) { // Create connections const driver = this.__getConnDriver(connPersist.connDriver); - const conn = new Conn( + const conn = new Conn( this, connPersist, driver, @@ -770,7 +752,6 @@ export class ActorInstance< undefined, undefined, undefined, - undefined, undefined >, persistData.input!, @@ -805,14 +786,14 @@ export class ActorInstance< } } - __getConnForId(id: string): Conn | undefined { + __getConnForId(id: string): Conn | undefined { return this.#connections.get(id); } /** * Removes a connection and cleans up its resources. */ - __removeConn(conn: Conn | undefined) { + __removeConn(conn: Conn | undefined) { if (!conn) { this.#rLog.warn({ msg: "`conn` does not exist" }); return; @@ -894,7 +875,6 @@ export class ActorInstance< undefined, undefined, undefined, - undefined, undefined >, onBeforeConnectOpts, @@ -938,7 +918,7 @@ export class ActorInstance< driverId: ConnectionDriver, driverState: unknown, authData: unknown, - ): Promise> { + ): Promise> { this.#assertReady(); if (this.#connections.has(connectionId)) { @@ -958,7 +938,7 @@ export class ActorInstance< lastSeen: Date.now(), subscriptions: [], }; - const conn = new Conn( + const conn = new Conn( this, persist, driver, @@ -1024,7 +1004,7 @@ export class ActorInstance< // MARK: Messages async processMessage( message: protocol.ToServer, - conn: Conn, + conn: Conn, ) { await processMessage(message, this, conn, { onExecuteAction: async (ctx, name, args) => { @@ -1058,7 +1038,7 @@ export class ActorInstance< // MARK: Events #addSubscription( eventName: string, - connection: Conn, + connection: Conn, fromPersist: boolean, ) { if (connection.subscriptions.has(eventName)) { @@ -1091,7 +1071,7 @@ export class ActorInstance< #removeSubscription( eventName: string, - connection: Conn, + connection: Conn, fromRemoveConn: boolean, ) { if (!connection.subscriptions.has(eventName)) { @@ -1205,7 +1185,7 @@ export class ActorInstance< * @internal */ async executeAction( - ctx: ActionContext, + ctx: ActionContext, actionName: string, args: unknown[], ): Promise { @@ -1329,7 +1309,7 @@ export class ActorInstance< /** * Handles raw HTTP requests to the actor. */ - async handleFetch(request: Request, opts: { auth: AD }): Promise { + async handleFetch(request: Request, opts: {}): Promise { this.#assertReady(); if (!this.#config.onFetch) { @@ -1366,7 +1346,7 @@ export class ActorInstance< */ async handleWebSocket( websocket: UniversalWebSocket, - opts: { request: Request; auth: AD }, + opts: { request: Request }, ): Promise { this.#assertReady(); @@ -1459,7 +1439,7 @@ export class ActorInstance< /** * Gets the map of connections. */ - get conns(): Map> { + get conns(): Map> { return this.#connections; } @@ -1809,10 +1789,6 @@ export class ActorInstance< ), parameters: bufferToArrayBuffer(cbor.encode(conn.params || {})), state: bufferToArrayBuffer(cbor.encode(conn.state || {})), - auth: - conn.authData !== undefined - ? bufferToArrayBuffer(cbor.encode(conn.authData)) - : null, subscriptions: conn.subscriptions.map((sub) => ({ eventName: sub.eventName, })), @@ -1848,9 +1824,6 @@ export class ActorInstance< connDriverState: cbor.decode(new Uint8Array(conn.driverState)), params: cbor.decode(new Uint8Array(conn.parameters)), state: cbor.decode(new Uint8Array(conn.state)), - authData: conn.auth - ? cbor.decode(new Uint8Array(conn.auth)) - : undefined, subscriptions: conn.subscriptions.map((sub) => ({ eventName: sub.eventName, })), diff --git a/packages/rivetkit/src/drivers/engine/keys.test.ts b/packages/rivetkit/src/actor/keys.test.ts similarity index 100% rename from packages/rivetkit/src/drivers/engine/keys.test.ts rename to packages/rivetkit/src/actor/keys.test.ts diff --git a/packages/rivetkit/src/drivers/engine/keys.ts b/packages/rivetkit/src/actor/keys.ts similarity index 100% rename from packages/rivetkit/src/drivers/engine/keys.ts rename to packages/rivetkit/src/actor/keys.ts diff --git a/packages/rivetkit/src/actor/mod.ts b/packages/rivetkit/src/actor/mod.ts index 0ee22aa1d..58c585070 100644 --- a/packages/rivetkit/src/actor/mod.ts +++ b/packages/rivetkit/src/actor/mod.ts @@ -14,7 +14,6 @@ export function actor< TConnState, TVars, TInput, - TAuthData, TDatabase extends AnyDatabaseProvider, TActions extends Actions< TState, @@ -22,7 +21,6 @@ export function actor< TConnState, TVars, TInput, - TAuthData, TDatabase >, >( @@ -32,7 +30,6 @@ export function actor< TConnState, TVars, TInput, - TAuthData, TDatabase, TActions >, @@ -42,7 +39,6 @@ export function actor< TConnState, TVars, TInput, - TAuthData, TDatabase, TActions > { @@ -52,7 +48,6 @@ export function actor< TConnState, TVars, TInput, - TAuthData, TDatabase >; return new ActorDefinition(config); diff --git a/packages/rivetkit/src/actor/protocol/old.ts b/packages/rivetkit/src/actor/protocol/old.ts index eb557f331..5659ce25a 100644 --- a/packages/rivetkit/src/actor/protocol/old.ts +++ b/packages/rivetkit/src/actor/protocol/old.ts @@ -94,21 +94,20 @@ export interface ProcessMessageHandler< CS, V, I, - AD, DB extends AnyDatabaseProvider, > { onExecuteAction?: ( - ctx: ActionContext, + ctx: ActionContext, name: string, args: unknown[], ) => Promise; onSubscribe?: ( eventName: string, - conn: Conn, + conn: Conn, ) => Promise; onUnsubscribe?: ( eventName: string, - conn: Conn, + conn: Conn, ) => Promise; } @@ -118,13 +117,12 @@ export async function processMessage< CS, V, I, - AD, DB extends AnyDatabaseProvider, >( message: protocol.ToServer, - actor: ActorInstance, - conn: Conn, - handler: ProcessMessageHandler, + actor: ActorInstance, + conn: Conn, + handler: ProcessMessageHandler, ) { let actionId: bigint | undefined; let actionName: string | undefined; @@ -148,7 +146,7 @@ export async function processMessage< actionName: name, }); - const ctx = new ActionContext( + const ctx = new ActionContext( actor.actorContext, conn, ); @@ -214,11 +212,15 @@ export async function processMessage< assertUnreachable(message.body); } } catch (error) { - const { code, message, metadata } = deconstructError(error, actor.rLog, { - connectionId: conn.id, - actionId, - actionName, - }); + const { group, code, message, metadata } = deconstructError( + error, + actor.rLog, + { + connectionId: conn.id, + actionId, + actionName, + }, + ); actor.rLog.debug({ msg: "sending error response", @@ -235,6 +237,7 @@ export async function processMessage< body: { tag: "Error", val: { + group, code, message, metadata: bufferToArrayBuffer(cbor.encode(metadata)), diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index c70ed55be..49eb0cf49 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -571,7 +571,6 @@ export async function handleRawWebSocketHandler( // Call the actor's onWebSocket handler with the adapted WebSocket actor.handleWebSocket(adapter, { request: newRequest, - auth: authData, }); }, onMessage: (event: any, ws: any) => { @@ -613,13 +612,9 @@ export function getRequestEncoding(req: HonoRequest): Encoding { return result.data; } -export function getRequestExposeInternalError(req: Request): boolean { - const param = req.headers.get(HEADER_EXPOSE_INTERNAL_ERROR); - if (!param) { - return false; - } - - return param === "true"; +export function getRequestExposeInternalError(_req: Request): boolean { + // Unipmlemented + return false; } export function getRequestQuery(c: HonoContext): unknown { @@ -644,9 +639,6 @@ export const HEADER_ACTOR_QUERY = "X-RivetKit-Query"; export const HEADER_ENCODING = "X-RivetKit-Encoding"; -// Internal header -export const HEADER_EXPOSE_INTERNAL_ERROR = "X-RivetKit-Expose-Internal-Error"; - // IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params. export const HEADER_CONN_PARAMS = "X-RivetKit-Conn-Params"; diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 3be4d860d..3802215f8 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -254,12 +254,7 @@ export function createActorRouter( } router.notFound(handleRouteNotFound); - router.onError( - handleRouteError.bind(undefined, { - // All headers to this endpoint are considered secure, so we can enable the expose internal error header for requests from the internal client - enableExposeInternalError: true, - }), - ); + router.onError(handleRouteError); return router; } diff --git a/packages/rivetkit/src/client/actor-common.ts b/packages/rivetkit/src/client/actor-common.ts index efe23dc79..0d242e325 100644 --- a/packages/rivetkit/src/client/actor-common.ts +++ b/packages/rivetkit/src/client/actor-common.ts @@ -21,7 +21,7 @@ export type ActorActionFunction< */ export type ActorDefinitionActions = // biome-ignore lint/suspicious/noExplicitAny: safe to use any here - AD extends ActorDefinition + AD extends ActorDefinition ? { [K in keyof R]: R[K] extends (...args: infer Args) => infer Return ? ActorActionFunction diff --git a/packages/rivetkit/src/client/actor-conn.ts b/packages/rivetkit/src/client/actor-conn.ts index b3fa8477c..f5d4a524c 100644 --- a/packages/rivetkit/src/client/actor-conn.ts +++ b/packages/rivetkit/src/client/actor-conn.ts @@ -1,7 +1,7 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; import pRetry from "p-retry"; -import type { CloseEvent, WebSocket } from "ws"; +import type { CloseEvent } from "ws"; import type { AnyActorDefinition } from "@/actor/definition"; import { inputDataToBuffer } from "@/actor/protocol/old"; import { type Encoding, jsonStringifyCompat } from "@/actor/protocol/serde"; @@ -11,7 +11,14 @@ import type { UniversalMessageEvent, } from "@/common/eventsource-interface"; import { assertUnreachable, stringifyError } from "@/common/utils"; +import { + HEADER_CONN_ID, + HEADER_CONN_TOKEN, + HEADER_ENCODING, + type ManagerDriver, +} from "@/driver-helpers/mod"; import type { ActorQuery } from "@/manager/protocol/query"; +import { PATH_CONNECT_WEBSOCKET, type UniversalWebSocket } from "@/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED, @@ -24,15 +31,15 @@ import { } from "@/serde"; import { bufferToArrayBuffer, getEnvUniversal } from "@/utils"; import type { ActorDefinitionActions } from "./actor-common"; -import { - ACTOR_CONNS_SYMBOL, - type ClientDriver, - type ClientRaw, - TRANSPORT_SYMBOL, -} from "./client"; +import { queryActor } from "./actor-query"; +import { ACTOR_CONNS_SYMBOL, type ClientRaw, TRANSPORT_SYMBOL } from "./client"; import * as errors from "./errors"; import { logger } from "./log"; -import { type WebSocketMessage as ConnMessage, messageLength } from "./utils"; +import { + type WebSocketMessage as ConnMessage, + messageLength, + sendHttpRequest, +} from "./utils"; interface ActionInFlight { name: string; @@ -65,7 +72,7 @@ export interface SendHttpMessageOpts { } export type ConnTransport = - | { websocket: WebSocket } + | { websocket: UniversalWebSocket } | { sse: UniversalEventSource }; export const CONNECT_SYMBOL = Symbol("connect"); @@ -112,9 +119,9 @@ export class ActorConnRaw { #onOpenPromise?: PromiseWithResolvers; #client: ClientRaw; - #driver: ClientDriver; + #driver: ManagerDriver; #params: unknown; - #encodingKind: Encoding; + #encoding: Encoding; #actorQuery: ActorQuery; // TODO: ws message queue @@ -128,15 +135,15 @@ export class ActorConnRaw { */ public constructor( client: ClientRaw, - driver: ClientDriver, + driver: ManagerDriver, params: unknown, - encodingKind: Encoding, + encoding: Encoding, actorQuery: ActorQuery, ) { this.#client = client; this.#driver = driver; this.#params = params; - this.#encodingKind = encodingKind; + this.#encoding = encoding; this.#actorQuery = actorQuery; this.#keepNodeAliveInterval = setInterval(() => 60_000); @@ -261,13 +268,17 @@ enc } } - async #connectWebSocket({ signal }: { signal?: AbortSignal } = {}) { - const ws = await this.#driver.connectWebSocket( + async #connectWebSocket() { + const { actorId } = await queryActor( undefined, this.#actorQuery, - this.#encodingKind, + this.#driver, + ); + const ws = await this.#driver.openWebSocket( + PATH_CONNECT_WEBSOCKET, + actorId, + this.#encoding, this.#params, - signal ? { signal } : undefined, ); this.#transport = { websocket: ws }; ws.addEventListener("open", () => { @@ -284,31 +295,66 @@ enc }); } - async #connectSse({ signal }: { signal?: AbortSignal } = {}) { - const eventSource = await this.#driver.connectSse( - undefined, - this.#actorQuery, - this.#encodingKind, - this.#params, - signal ? { signal } : undefined, - ); - this.#transport = { sse: eventSource }; - eventSource.onopen = () => { - logger().debug({ msg: "eventsource open" }); - // #handleOnOpen is called on "i" event - }; - eventSource.onmessage = (ev: UniversalMessageEvent) => { - this.#handleOnMessage(ev.data); - }; - eventSource.onerror = (_ev: UniversalErrorEvent) => { - if (eventSource.readyState === eventSource.CLOSED) { - // This error indicates a close event - this.#handleOnClose(new Event("error")); - } else { - // Log error since event source is still open - this.#handleOnError(); - } - }; + async #connectSse() { + throw "TODO"; + + // OLD: + // const eventSource = await this.#driver.connectSse( + // undefined, + // this.#actorQuery, + // this.#encodingKind, + // this.#params, + // signal ? { signal } : undefined, + // ); + // this.#transport = { sse: eventSource }; + // eventSource.onopen = () => { + // logger().debug({ msg: "eventsource open" }); + // // #handleOnOpen is called on "i" event + // }; + // eventSource.onmessage = (ev: UniversalMessageEvent) => { + // this.#handleOnMessage(ev.data); + // }; + // eventSource.onerror = (_ev: UniversalErrorEvent) => { + // if (eventSource.readyState === eventSource.CLOSED) { + // // This error indicates a close event + // this.#handleOnClose(new Event("error")); + // } else { + // // Log error since event source is still open + // this.#handleOnError(); + // } + // }; + + // NEW: + // const EventSource = await importEventSource(); + // + // // Get the actor ID + // const { actorId } = await managerDriver.queryActor(c, actorQuery); + // logger().debug({ msg: "found actor for sse connection", actorId }); + // invariant(actorId, "Missing actor ID"); + // + // logger().debug({ + // msg: "opening sse connection", + // actorId, + // encoding: encodingKind, + // }); + // + // const eventSource = new EventSource("http://actor/connect/sse", { + // fetch: (input, init) => { + // return fetch(input, { + // ...init, + // headers: { + // ...init?.headers, + // "User-Agent": httpUserAgent(), + // [HEADER_ENCODING]: encodingKind, + // ...(params !== undefined + // ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } + // : {}), + // }, + // }); + // }, + // }) as UniversalEventSource; + // + // return eventSource; } /** Called by the onopen event from drivers. */ @@ -372,7 +418,7 @@ enc this.#handleOnOpen(); } else if (response.body.tag === "Error") { // Connection error - const { code, message, metadata, actionId } = response.body.val; + const { group, code, message, metadata, actionId } = response.body.val; if (actionId) { const inFlight = this.#takeActionInFlight(Number(actionId)); @@ -381,22 +427,29 @@ enc msg: "action error", actionId: actionId, actionName: inFlight?.name, + group, code, message, metadata, }); - inFlight.reject(new errors.ActorError(code, message, metadata)); + inFlight.reject(new errors.ActorError(group, code, message, metadata)); } else { logger().warn({ msg: "connection error", + group, code, message, metadata, }); // Create a connection error - const actorError = new errors.ActorError(code, message, metadata); + const actorError = new errors.ActorError( + group, + code, + message, + metadata, + ); // If we have an onOpenPromise, reject it with the error if (this.#onOpenPromise) { @@ -623,7 +676,7 @@ enc if (this.#transport.websocket.readyState === 1) { try { const messageSerialized = serializeWithEncoding( - this.#encodingKind, + this.#encoding, message, TO_SERVER_VERSIONED, ); @@ -673,20 +726,33 @@ enc getEnvUniversal("_RIVETKIT_LOG_MESSAGE") ? { msg: "sent http message", - message: jsonStringifyCompat(message).substring(0, 100) + "...", + message: `${jsonStringifyCompat(message).substring(0, 100)}...`, } : { msg: "sent http message" }, ); - await this.#driver.sendHttpMessage( - undefined, - this.#actorId, - this.#encodingKind, - this.#connectionId, - this.#connectionToken, - message, - opts?.signal ? { signal: opts.signal } : undefined, - ); + logger().debug({ + msg: "sending http message", + actorId: this.#actorId, + connectionId: this.#connectionId, + }); + + // Send an HTTP request to the connections endpoint + await sendHttpRequest({ + url: "http://actor/connections/message", + method: "POST", + headers: { + [HEADER_ENCODING]: this.#encoding, + [HEADER_CONN_ID]: this.#connectionId, + [HEADER_CONN_TOKEN]: this.#connectionToken, + }, + body: message, + encoding: this.#encoding, + skipParseResponse: true, + customFetch: this.#driver.sendRequest.bind(this.#driver, this.#actorId), + requestVersionedDataHandler: TO_SERVER_VERSIONED, + responseVersionedDataHandler: TO_CLIENT_VERSIONED, + }); } catch (error) { // TODO: This will not automatically trigger a re-broadcast of HTTP events since SSE is separate from the HTTP action @@ -705,7 +771,7 @@ enc invariant(this.#transport, "transport must be defined"); // Decode base64 since SSE sends raw strings - if (encodingIsBinary(this.#encodingKind) && "sse" in this.#transport) { + if (encodingIsBinary(this.#encoding) && "sse" in this.#transport) { if (typeof data === "string") { const binaryString = atob(data); data = new Uint8Array( @@ -720,11 +786,7 @@ enc const buffer = await inputDataToBuffer(data); - return deserializeWithEncoding( - this.#encodingKind, - buffer, - TO_CLIENT_VERSIONED, - ); + return deserializeWithEncoding(this.#encoding, buffer, TO_CLIENT_VERSIONED); } /** diff --git a/packages/rivetkit/src/client/actor-handle.ts b/packages/rivetkit/src/client/actor-handle.ts index 6bdfdc07f..484b5ef4d 100644 --- a/packages/rivetkit/src/client/actor-handle.ts +++ b/packages/rivetkit/src/client/actor-handle.ts @@ -1,18 +1,30 @@ +import * as cbor from "cbor-x"; import invariant from "invariant"; import type { AnyActorDefinition } from "@/actor/definition"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/actor/utils"; +import { deconstructError } from "@/common/utils"; import { importWebSocket } from "@/common/websocket"; +import { + HEADER_CONN_PARAMS, + HEADER_ENCODING, + type ManagerDriver, +} from "@/driver-helpers/mod"; import type { ActorQuery } from "@/manager/protocol/query"; +import type * as protocol from "@/schemas/client-protocol/mod"; +import { + HTTP_ACTION_REQUEST_VERSIONED, + HTTP_ACTION_RESPONSE_VERSIONED, +} from "@/schemas/client-protocol/versioned"; +import { bufferToArrayBuffer } from "@/utils"; import type { ActorDefinitionActions } from "./actor-common"; import { type ActorConn, ActorConnRaw } from "./actor-conn"; -import { - type ClientDriver, - type ClientRaw, - CREATE_ACTOR_CONN_PROXY, -} from "./client"; +import { queryActor } from "./actor-query"; +import { type ClientRaw, CREATE_ACTOR_CONN_PROXY } from "./client"; +import { ActorError } from "./errors"; import { logger } from "./log"; import { rawHttpFetch, rawWebSocket } from "./raw-utils"; +import { sendHttpRequest } from "./utils"; /** * Provides underlying functions for stateless {@link ActorHandle} for action calls. @@ -22,8 +34,8 @@ import { rawHttpFetch, rawWebSocket } from "./raw-utils"; */ export class ActorHandleRaw { #client: ClientRaw; - #driver: ClientDriver; - #encodingKind: Encoding; + #driver: ManagerDriver; + #encoding: Encoding; #actorQuery: ActorQuery; #params: unknown; @@ -36,14 +48,14 @@ export class ActorHandleRaw { */ public constructor( client: any, - driver: ClientDriver, + driver: ManagerDriver, params: unknown, - encodingKind: Encoding, + encoding: Encoding, actorQuery: ActorQuery, ) { this.#client = client; this.#driver = driver; - this.#encodingKind = encodingKind; + this.#encoding = encoding; this.#actorQuery = actorQuery; this.#params = params; } @@ -63,15 +75,64 @@ export class ActorHandleRaw { args: Args; signal?: AbortSignal; }): Promise { - return await this.#driver.action( - undefined, - this.#actorQuery, - this.#encodingKind, - this.#params, - opts.name, - opts.args, - { signal: opts.signal }, - ); + // return await this.#driver.action( + // undefined, + // this.#actorQuery, + // this.#encodingKind, + // this.#params, + // opts.name, + // opts.args, + // { signal: opts.signal }, + // ); + try { + // Get the actor ID + const { actorId } = await queryActor( + undefined, + this.#actorQuery, + this.#driver, + ); + logger().debug({ msg: "found actor for action", actorId }); + invariant(actorId, "Missing actor ID"); + + // Invoke the action + logger().debug({ + msg: "handling action", + name: opts.name, + encoding: this.#encoding, + }); + const responseData = await sendHttpRequest< + protocol.HttpActionRequest, + protocol.HttpActionResponse + >({ + url: `http://actor/action/${encodeURIComponent(opts.name)}`, + method: "POST", + headers: { + [HEADER_ENCODING]: this.#encoding, + ...(this.#params !== undefined + ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.#params) } + : {}), + }, + body: { + args: bufferToArrayBuffer(cbor.encode(opts.args)), + } satisfies protocol.HttpActionRequest, + encoding: this.#encoding, + customFetch: this.#driver.sendRequest.bind(this.#driver, actorId), + signal: opts?.signal, + requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED, + responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, + }); + + return cbor.decode(new Uint8Array(responseData.output)); + } catch (err) { + // Standardize to ClientActorError instead of the native backend error + const { group, code, message, metadata } = deconstructError( + err, + logger(), + {}, + true, + ); + throw new ActorError(group, code, message, metadata); + } } /** @@ -90,7 +151,7 @@ export class ActorHandleRaw { this.#client, this.#driver, this.#params, - this.#encodingKind, + this.#encoding, this.#actorQuery, ); @@ -159,12 +220,10 @@ export class ActorHandleRaw { assertUnreachable(this.#actorQuery); } - const actorId = await this.#driver.resolveActorId( + const { actorId } = await queryActor( undefined, this.#actorQuery, - this.#encodingKind, - this.#params, - signal ? { signal } : undefined, + this.#driver, ); this.#actorQuery = { getForId: { actorId, name } }; diff --git a/packages/rivetkit/src/client/actor-query.ts b/packages/rivetkit/src/client/actor-query.ts new file mode 100644 index 000000000..645d7e680 --- /dev/null +++ b/packages/rivetkit/src/client/actor-query.ts @@ -0,0 +1,65 @@ +import type { Context as HonoContext } from "hono"; +import * as errors from "@/actor/errors"; +import type { ManagerDriver } from "@/driver-helpers/mod"; +import type { ActorQuery } from "@/manager/protocol/query"; +import { logger } from "./log"; + +/** + * Query the manager driver to get or create a actor based on the provided query + */ +export async function queryActor( + c: HonoContext | undefined, + query: ActorQuery, + managerDriver: ManagerDriver, +): Promise<{ actorId: string }> { + logger().debug({ msg: "querying actor", query: JSON.stringify(query) }); + let actorOutput: { actorId: string }; + if ("getForId" in query) { + const output = await managerDriver.getForId({ + c, + name: query.getForId.name, + actorId: query.getForId.actorId, + }); + if (!output) throw new errors.ActorNotFound(query.getForId.actorId); + actorOutput = output; + } else if ("getForKey" in query) { + const existingActor = await managerDriver.getWithKey({ + c, + name: query.getForKey.name, + key: query.getForKey.key, + }); + if (!existingActor) { + throw new errors.ActorNotFound( + `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, + ); + } + actorOutput = existingActor; + } else if ("getOrCreateForKey" in query) { + const getOrCreateOutput = await managerDriver.getOrCreateWithKey({ + c, + name: query.getOrCreateForKey.name, + key: query.getOrCreateForKey.key, + input: query.getOrCreateForKey.input, + region: query.getOrCreateForKey.region, + }); + actorOutput = { + actorId: getOrCreateOutput.actorId, + }; + } else if ("create" in query) { + const createOutput = await managerDriver.createActor({ + c, + name: query.create.name, + key: query.create.key, + input: query.create.input, + region: query.create.region, + }); + actorOutput = { + actorId: createOutput.actorId, + }; + } else { + throw new errors.InvalidRequest("Invalid query format"); + } + + logger().debug({ msg: "actor query result", actorId: actorOutput.actorId }); + return { actorId: actorOutput.actorId }; +} diff --git a/packages/rivetkit/src/client/client.ts b/packages/rivetkit/src/client/client.ts index 097f6b622..546d7d2cf 100644 --- a/packages/rivetkit/src/client/client.ts +++ b/packages/rivetkit/src/client/client.ts @@ -1,12 +1,9 @@ -import type { Context as HonoContext } from "hono"; -import type { WebSocket } from "ws"; import type { AnyActorDefinition } from "@/actor/definition"; import type { Transport } from "@/actor/protocol/old"; import type { Encoding } from "@/actor/protocol/serde"; -import type { UniversalEventSource } from "@/common/eventsource-interface"; +import type { ManagerDriver } from "@/driver-helpers/mod"; import type { ActorQuery } from "@/manager/protocol/query"; import type { Registry } from "@/mod"; -import type { ToServer } from "@/schemas/client-protocol/mod"; import type { ActorActionFunction } from "./actor-common"; import { type ActorConn, @@ -14,8 +11,12 @@ import { CONNECT_SYMBOL, } from "./actor-conn"; import { type ActorHandle, ActorHandleRaw } from "./actor-handle"; +import { queryActor } from "./actor-query"; +import type { ClientConfig } from "./config"; import { logger } from "./log"; +export type { ClientConfig, ClientConfigInput } from "./config"; + /** Extract the actor registry from the registry definition. */ export type ExtractActorsFromRegistry> = A extends Registry ? Actors : never; @@ -78,15 +79,6 @@ export interface ActorAccessor { ): Promise>; } -/** - * Options for configuring the client. - * @typedef {Object} ClientOptions - */ -export interface ClientOptions { - encoding?: Encoding; - transport?: Transport; -} - /** * Options for querying actors. * @typedef {Object} QueryOptions @@ -159,66 +151,6 @@ export const ACTOR_CONNS_SYMBOL = Symbol("actorConns"); export const CREATE_ACTOR_CONN_PROXY = Symbol("createActorConnProxy"); export const TRANSPORT_SYMBOL = Symbol("transport"); -export interface ClientDriver { - action = unknown[], Response = unknown>( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - name: string, - args: Args, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - resolveActorId( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - connectWebSocket( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - connectSse( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - sendHttpMessage( - c: HonoContext | undefined, - actorId: string, - encoding: Encoding, - connectionId: string, - connectionToken: string, - message: ToServer, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - rawHttpRequest( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - init: RequestInit, - opts: { signal?: AbortSignal } | undefined, - ): Promise; - rawWebSocket( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - protocols: string | string[] | undefined, - opts: { signal?: AbortSignal } | undefined, - ): Promise; -} - /** * Client for managing & connecting to actors. * @@ -230,7 +162,7 @@ export class ClientRaw { [ACTOR_CONNS_SYMBOL] = new Set(); - #driver: ClientDriver; + #driver: ManagerDriver; #encodingKind: Encoding; [TRANSPORT_SYMBOL]: Transport; @@ -238,10 +170,10 @@ export class ClientRaw { * Creates an instance of Client. * * @param {string} managerEndpoint - The manager endpoint. See {@link https://rivet.gg/docs/setup|Initial Setup} for instructions on getting the manager endpoint. - * @param {ClientOptions} [opts] - Options for configuring the client. + * @param {ClientConfig} [opts] - Options for configuring the client. * @see {@link https://rivet.gg/docs/setup|Initial Setup} */ - public constructor(driver: ClientDriver, opts?: ClientOptions) { + public constructor(driver: ManagerDriver, opts?: ClientConfig) { this.#driver = driver; this.#encodingKind = opts?.encoding ?? "bare"; @@ -389,13 +321,7 @@ export class ClientRaw { }); // Create the actor - const actorId = await this.#driver.resolveActorId( - undefined, - createQuery, - this.#encodingKind, - opts?.params, - opts?.signal ? { signal: opts.signal } : undefined, - ); + const { actorId } = await queryActor(undefined, createQuery, this.#driver); logger().debug({ msg: "created actor with ID", name, @@ -479,10 +405,10 @@ export type Client> = ClientRaw & { export type AnyClient = Client>; export function createClientWithDriver>( - driver: ClientDriver, - opts?: ClientOptions, + driver: ManagerDriver, + config?: ClientConfig, ): Client { - const client = new ClientRaw(driver, opts); + const client = new ClientRaw(driver, config); // Create proxy for accessing actors by name return new Proxy(client, { diff --git a/packages/rivetkit/src/client/config.ts b/packages/rivetkit/src/client/config.ts new file mode 100644 index 000000000..bc4c15b9d --- /dev/null +++ b/packages/rivetkit/src/client/config.ts @@ -0,0 +1,44 @@ +import z from "zod"; +import { TransportSchema } from "@/actor/protocol/old"; +import { EncodingSchema } from "@/actor/protocol/serde"; +import { getEnvUniversal, type UpgradeWebSocket } from "@/utils"; + +export type GetUpgradeWebSocket = () => UpgradeWebSocket; + +export const ClientConfigSchema = z.object({ + /** Configure serving the API */ + api: z + .object({ + host: z.string().default("127.0.0.1"), + port: z.number().default(6420), + }) + .default({}), + + /** Endpoint to connect to the Rivet engine. Can be configured via RIVET_ENGINE env var. */ + endpoint: z + .string() + .nullable() + .default(() => getEnvUniversal("RIVET_ENGINE") ?? null), + + namespace: z + .string() + .default(() => getEnvUniversal("RIVET_NAMESPACE") ?? "default"), + + runnerName: z + .string() + .default(() => getEnvUniversal("RIVET_RUNNER") ?? "rivetkit"), + + encoding: EncodingSchema.default("bare"), + + transport: TransportSchema.default("websocket"), + + // This is a function to allow for lazy configuration of upgradeWebSocket on the + // fly. This is required since the dependencies that upgradeWebSocket + // (specifically Node.js) can sometimes only be specified after the router is + // created or must be imported async using `await import(...)` + getUpgradeWebSocket: z.custom().optional(), +}); + +export type ClientConfig = z.infer; + +export type ClientConfigInput = z.input; diff --git a/packages/rivetkit/src/client/errors.ts b/packages/rivetkit/src/client/errors.ts index 6244a43cf..25256f825 100644 --- a/packages/rivetkit/src/client/errors.ts +++ b/packages/rivetkit/src/client/errors.ts @@ -20,6 +20,7 @@ export class ActorError extends ActorClientError { __type = "ActorError"; constructor( + public readonly group: string, public readonly code: string, message: string, public readonly metadata?: unknown, diff --git a/packages/rivetkit/src/client/http-client-driver.ts b/packages/rivetkit/src/client/http-client-driver.ts deleted file mode 100644 index 28c66c452..000000000 --- a/packages/rivetkit/src/client/http-client-driver.ts +++ /dev/null @@ -1,329 +0,0 @@ -import * as cbor from "cbor-x"; -import type { Context as HonoContext } from "hono"; -import type { WebSocket } from "ws"; -import type { Encoding } from "@/actor/protocol/serde"; -import { - HEADER_ACTOR_ID, - HEADER_ACTOR_QUERY, - HEADER_CONN_ID, - HEADER_CONN_PARAMS, - HEADER_CONN_TOKEN, - HEADER_ENCODING, -} from "@/actor/router-endpoints"; -import { importEventSource } from "@/common/eventsource"; -import type { UniversalEventSource } from "@/common/eventsource-interface"; -import { importWebSocket } from "@/common/websocket"; -import type { ActorQuery } from "@/manager/protocol/query"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { - HTTP_ACTION_REQUEST_VERSIONED, - HTTP_ACTION_RESPONSE_VERSIONED, - HTTP_RESOLVE_REQUEST_VERSIONED, - HTTP_RESOLVE_RESPONSE_VERSIONED, - TO_SERVER_VERSIONED, -} from "@/schemas/client-protocol/versioned"; -import { serializeWithEncoding, wsBinaryTypeForEncoding } from "@/serde"; -import { assertUnreachable, bufferToArrayBuffer, httpUserAgent } from "@/utils"; -import type { ClientDriver } from "./client"; -import * as errors from "./errors"; -import { logger } from "./log"; -import { sendHttpRequest } from "./utils"; - -/** - * Client driver that communicates with the manager via HTTP. - */ -export function createHttpClientDriver(managerEndpoint: string): ClientDriver { - // Lazily import the dynamic imports so we don't have to turn `createClient` in to an async fn - const dynamicImports = (async () => { - // Import dynamic dependencies - const [WebSocket, EventSource] = await Promise.all([ - importWebSocket(), - importEventSource(), - ]); - return { - WebSocket, - EventSource, - }; - })(); - - const driver: ClientDriver = { - action: async = unknown[], Response = unknown>( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - name: string, - args: Args, - opts: { signal?: AbortSignal } | undefined, - ): Promise => { - logger().debug({ - msg: "actor handle action", - name, - args, - query: actorQuery, - }); - - const responseData = await sendHttpRequest< - protocol.HttpActionRequest, - protocol.HttpActionResponse - >({ - url: `${managerEndpoint}/registry/actors/actions/${encodeURIComponent(name)}`, - method: "POST", - headers: { - [HEADER_ENCODING]: encoding, - [HEADER_ACTOR_QUERY]: JSON.stringify(actorQuery), - ...(params !== undefined - ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } - : {}), - }, - body: { - args: bufferToArrayBuffer(cbor.encode(args)), - } satisfies protocol.HttpActionRequest, - encoding: encoding, - signal: opts?.signal, - requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED, - responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, - }); - - return cbor.decode(new Uint8Array(responseData.output)); - }, - - resolveActorId: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - logger().debug({ msg: "resolving actor ID", query: actorQuery }); - - try { - const result = await sendHttpRequest< - null, - protocol.HttpResolveResponse - >({ - url: `${managerEndpoint}/registry/actors/resolve`, - method: "POST", - headers: { - [HEADER_ENCODING]: encodingKind, - [HEADER_ACTOR_QUERY]: JSON.stringify(actorQuery), - ...(params !== undefined - ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } - : {}), - }, - body: null, - encoding: encodingKind, - requestVersionedDataHandler: HTTP_RESOLVE_REQUEST_VERSIONED, - responseVersionedDataHandler: HTTP_RESOLVE_RESPONSE_VERSIONED, - }); - - logger().debug({ msg: "resolved actor ID", actorId: result.actorId }); - return result.actorId; - } catch (error) { - logger().error({ msg: "failed to resolve actor ID", error }); - if (error instanceof errors.ActorError) { - throw error; - } else { - throw new errors.InternalError( - `Failed to resolve actor ID: ${String(error)}`, - ); - } - } - }, - - connectWebSocket: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - const { WebSocket } = await dynamicImports; - - const endpoint = managerEndpoint - .replace(/^http:/, "ws:") - .replace(/^https:/, "wss:"); - const url = `${endpoint}/registry/actors/connect/websocket`; - - // Pass sensitive data via protocol - const protocol = [ - `query.${encodeURIComponent(JSON.stringify(actorQuery))}`, - `encoding.${encodingKind}`, - ]; - if (params) - protocol.push( - `conn_params.${encodeURIComponent(JSON.stringify(params))}`, - ); - - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - protocol.push("rivetkit"); - - logger().debug({ msg: "connecting to websocket", url }); - const ws = new WebSocket(url, protocol); - // HACK: Bun bug prevents changing binary type, so we ignore the error https://github.com/oven-sh/bun/issues/17005 - try { - ws.binaryType = wsBinaryTypeForEncoding(encodingKind); - } catch (error) {} - - // Node & web WebSocket types not compatible - return ws as any; - }, - - connectSse: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - const { EventSource } = await dynamicImports; - - const url = `${managerEndpoint}/registry/actors/connect/sse`; - - logger().debug({ msg: "connecting to sse", url }); - const eventSource = new EventSource(url, { - fetch: (input, init) => { - return fetch(input, { - ...init, - headers: { - ...init?.headers, - "User-Agent": httpUserAgent(), - [HEADER_ENCODING]: encodingKind, - [HEADER_ACTOR_QUERY]: JSON.stringify(actorQuery), - ...(params !== undefined - ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } - : {}), - }, - credentials: "include", - }); - }, - }); - - return eventSource as UniversalEventSource; - }, - - sendHttpMessage: async ( - _c: HonoContext | undefined, - actorId: string, - encoding: Encoding, - connectionId: string, - connectionToken: string, - message: protocol.ToServer, - ): Promise => { - // TODO: Implement ordered messages, this is not guaranteed order. Needs to use an index in order to ensure we can pipeline requests efficiently. - // TODO: Validate that we're using HTTP/3 whenever possible for pipelining requests - const messageSerialized = serializeWithEncoding( - encoding, - message, - TO_SERVER_VERSIONED, - ); - const res = await fetch(`${managerEndpoint}/registry/actors/message`, { - method: "POST", - headers: { - "User-Agent": httpUserAgent(), - [HEADER_ENCODING]: encoding, - [HEADER_ACTOR_ID]: actorId, - [HEADER_CONN_ID]: connectionId, - [HEADER_CONN_TOKEN]: connectionToken, - }, - body: messageSerialized, - credentials: "include", - }); - if (!res.ok) { - throw new errors.InternalError( - `Publish message over HTTP error (${res.statusText}):\n${await res.text()}`, - ); - } - - // Discard response - await res.body?.cancel(); - }, - - rawHttpRequest: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - init: RequestInit, - ): Promise => { - // Build the full URL - // Remove leading slash from path to avoid double slashes - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - const url = `${managerEndpoint}/registry/actors/raw/http/${normalizedPath}`; - - logger().debug({ - msg: "rewriting http url", - from: path, - to: url, - }); - - // Merge headers properly - const headers = new Headers(init.headers); - headers.set("User-Agent", httpUserAgent()); - headers.set(HEADER_ACTOR_QUERY, JSON.stringify(actorQuery)); - headers.set(HEADER_ENCODING, encoding); - if (params !== undefined) { - headers.set(HEADER_CONN_PARAMS, JSON.stringify(params)); - } - - // Forward the request with query in headers - return await fetch(url, { - ...init, - headers, - }); - }, - - rawWebSocket: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - protocols: string | string[] | undefined, - ): Promise => { - const { WebSocket } = await dynamicImports; - - // Build the WebSocket URL with normalized path - const wsEndpoint = managerEndpoint - .replace(/^http:/, "ws:") - .replace(/^https:/, "wss:"); - // Normalize path to match raw HTTP behavior - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - const url = `${wsEndpoint}/registry/actors/raw/websocket/${normalizedPath}`; - - logger().debug({ - msg: "rewriting websocket url", - from: path, - to: url, - }); - - // Pass data via WebSocket protocol subprotocols - const protocolList: string[] = []; - protocolList.push( - `query.${encodeURIComponent(JSON.stringify(actorQuery))}`, - ); - protocolList.push(`encoding.${encoding}`); - if (params) { - protocolList.push( - `conn_params.${encodeURIComponent(JSON.stringify(params))}`, - ); - } - - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - protocolList.push("rivetkit"); - - // Add user protocols - if (protocols) { - if (Array.isArray(protocols)) { - protocolList.push(...protocols); - } else { - protocolList.push(protocols); - } - } - - // Create WebSocket connection - logger().debug({ msg: "opening raw websocket", url }); - return new WebSocket(url, protocolList) as any; - }, - }; - - return driver; -} diff --git a/packages/rivetkit/src/client/mod.ts b/packages/rivetkit/src/client/mod.ts index 5fd4e4257..41b3215e2 100644 --- a/packages/rivetkit/src/client/mod.ts +++ b/packages/rivetkit/src/client/mod.ts @@ -1,10 +1,11 @@ import type { Registry } from "@/registry/mod"; +import { RemoteManagerDriver } from "@/remote-manager-driver/mod"; import { type Client, - type ClientOptions, + type ClientConfigInput, createClientWithDriver, } from "./client"; -import { createHttpClientDriver } from "./http-client-driver"; +import { ClientConfigSchema } from "./config"; export { ActorDefinition, @@ -28,7 +29,6 @@ export { ActorHandleRaw } from "./actor-handle"; export type { ActorAccessor, Client, - ClientOptions, ClientRaw, CreateOptions, ExtractActorsFromRegistry, @@ -41,16 +41,20 @@ export type { /** * Creates a client with the actor accessor proxy. - * - * @template A The actor application type. - * @param {string} managerEndpoint - The manager endpoint. - * @param {ClientOptions} [opts] - Options for configuring the client. - * @returns {Client} - A proxied client that supports the `client.myActor.connect()` syntax. */ export function createClient>( - endpoint: string, - opts?: ClientOptions, + endpointOrConfig?: string | ClientConfigInput, ): Client { - const driver = createHttpClientDriver(endpoint); - return createClientWithDriver(driver, opts); + // Parse config + const configInput = + endpointOrConfig === undefined + ? {} + : typeof endpointOrConfig === "string" + ? { engine: endpointOrConfig } + : endpointOrConfig; + const config = ClientConfigSchema.parse(configInput); + + // Create client + const driver = new RemoteManagerDriver(config); + return createClientWithDriver(driver, config); } diff --git a/packages/rivetkit/src/client/raw-utils.ts b/packages/rivetkit/src/client/raw-utils.ts index dad4458fd..08800a2b4 100644 --- a/packages/rivetkit/src/client/raw-utils.ts +++ b/packages/rivetkit/src/client/raw-utils.ts @@ -1,11 +1,17 @@ +import invariant from "invariant"; +import { PATH_RAW_WEBSOCKET_PREFIX } from "@/actor/router"; +import { deconstructError } from "@/common/utils"; +import { HEADER_CONN_PARAMS, type ManagerDriver } from "@/driver-helpers/mod"; import type { ActorQuery } from "@/manager/protocol/query"; -import type { ClientDriver } from "./client"; +import { queryActor } from "./actor-query"; +import { ActorError } from "./errors"; +import { logger } from "./log"; /** * Shared implementation for raw HTTP fetch requests */ export async function rawHttpFetch( - driver: ClientDriver, + driver: ManagerDriver, actorQuery: ActorQuery, params: unknown, input: string | URL | Request, @@ -55,38 +61,81 @@ export async function rawHttpFetch( throw new TypeError("Invalid input type for fetch"); } - // Use the driver's raw HTTP method - just pass the sub-path - return await driver.rawHttpRequest( - undefined, - actorQuery, - // Force JSON so it's readable by the user - "json", - params, - path, - mergedInit, - undefined, - ); + try { + // Get the actor ID + const { actorId } = await queryActor(undefined, actorQuery, driver); + logger().debug({ msg: "found actor for raw http", actorId }); + invariant(actorId, "Missing actor ID"); + + // Build the URL with normalized path + const normalizedPath = path.startsWith("/") ? path.slice(1) : path; + const url = new URL(`http://actor/raw/http/${normalizedPath}`); + + // Forward conn params if provided + const proxyRequestHeaders = new Headers(mergedInit.headers); + if (params) { + proxyRequestHeaders.set(HEADER_CONN_PARAMS, JSON.stringify(params)); + } + + // Forward the request to the actor + const proxyRequest = new Request(url, { + ...init, + headers: proxyRequestHeaders, + }); + + return driver.sendRequest(actorId, proxyRequest); + } catch (err) { + // Standardize to ClientActorError instead of the native backend error + const { group, code, message, metadata } = deconstructError( + err, + logger(), + {}, + true, + ); + throw new ActorError(group, code, message, metadata); + } } /** * Shared implementation for raw WebSocket connections */ export async function rawWebSocket( - driver: ClientDriver, + driver: ManagerDriver, actorQuery: ActorQuery, params: unknown, path?: string, + // TODO: Supportp rotocols protocols?: string | string[], ): Promise { - // Use the driver's raw WebSocket method - return await driver.rawWebSocket( - undefined, - actorQuery, - // Force JSON so it's readable by the user - "json", + // TODO: Do we need encoding in rawWebSocket? + const encoding = "bare"; + + // Get the actor ID + const { actorId } = await queryActor(undefined, actorQuery, driver); + logger().debug({ msg: "found actor for action", actorId }); + invariant(actorId, "Missing actor ID"); + + // Normalize path to match raw HTTP behavior + const normalizedPath = path + ? path.startsWith("/") + ? path.slice(1) + : path + : ""; + logger().debug({ + msg: "opening websocket", + actorId, + encoding, + path: normalizedPath, + }); + + // Open WebSocket + const ws = await driver.openWebSocket( + `${PATH_RAW_WEBSOCKET_PREFIX}${normalizedPath}`, + actorId, + encoding, params, - path || "", - protocols, - undefined, ); + + // Node & browser WebSocket types are incompatible + return ws as any; } diff --git a/packages/rivetkit/src/client/utils.ts b/packages/rivetkit/src/client/utils.ts index 75ed9ffd6..ca1e07aa5 100644 --- a/packages/rivetkit/src/client/utils.ts +++ b/packages/rivetkit/src/client/utils.ts @@ -41,8 +41,8 @@ export interface HttpRequestOpts { skipParseResponse?: boolean; signal?: AbortSignal; customFetch?: (req: Request) => Promise; - requestVersionedDataHandler: VersionedDataHandler; - responseVersionedDataHandler: VersionedDataHandler; + requestVersionedDataHandler: VersionedDataHandler | undefined; + responseVersionedDataHandler: VersionedDataHandler | undefined; } export async function sendHttpRequest< @@ -122,6 +122,7 @@ export async function sendHttpRequest< // Throw structured error throw new ActorError( + responseData.group, responseData.code, responseData.message, responseData.metadata diff --git a/packages/rivetkit/src/common/router.ts b/packages/rivetkit/src/common/router.ts index 7d3325d57..73684ab45 100644 --- a/packages/rivetkit/src/common/router.ts +++ b/packages/rivetkit/src/common/router.ts @@ -7,7 +7,7 @@ import { } from "@/actor/router-endpoints"; import { HttpResponseError } from "@/schemas/client-protocol/mod"; import { HTTP_RESPONSE_ERROR_VERSIONED } from "@/schemas/client-protocol/versioned"; -import { serializeWithEncoding } from "@/serde"; +import { encodingIsBinary, serializeWithEncoding } from "@/serde"; import { bufferToArrayBuffer } from "@/utils"; import { getLogger, type Logger } from "./log"; import { deconstructError, stringifyError } from "./utils"; @@ -42,19 +42,10 @@ export function handleRouteNotFound(c: HonoContext) { return c.text("Not Found (RivetKit)", 404); } -export interface HandleRouterErrorOpts { - enableExposeInternalError?: boolean; -} - -export function handleRouteError( - opts: HandleRouterErrorOpts, - error: unknown, - c: HonoContext, -) { - const exposeInternalError = - opts.enableExposeInternalError && getRequestExposeInternalError(c.req.raw); +export function handleRouteError(error: unknown, c: HonoContext) { + const exposeInternalError = getRequestExposeInternalError(c.req.raw); - const { statusCode, code, message, metadata } = deconstructError( + const { statusCode, group, code, message, metadata } = deconstructError( error, logger(), { @@ -67,20 +58,20 @@ export function handleRouteError( let encoding: Encoding; try { encoding = getRequestEncoding(c.req); - } catch (err) { - logger().debug({ - msg: "failed to extract encoding", - error: stringifyError(err), - }); + } catch (_) { encoding = "json"; } const output = serializeWithEncoding( encoding, { + group, code, message, - metadata: bufferToArrayBuffer(cbor.encode(metadata)), + // TODO: Cannot serialize non-binary meta since it requires ArrayBuffer atm + metadata: encodingIsBinary(encoding) + ? bufferToArrayBuffer(cbor.encode(metadata)) + : null, }, HTTP_RESPONSE_ERROR_VERSIONED, ); diff --git a/packages/rivetkit/src/common/utils.ts b/packages/rivetkit/src/common/utils.ts index 32ae8364b..e730cbd29 100644 --- a/packages/rivetkit/src/common/utils.ts +++ b/packages/rivetkit/src/common/utils.ts @@ -186,6 +186,7 @@ export interface DeconstructedError { __type: "ActorError"; statusCode: ContentfulStatusCode; public: boolean; + group: string; code: string; message: string; metadata?: unknown; @@ -203,6 +204,7 @@ export function deconstructError( // We log the error here instead of after generating the code & message because we need to log the original error, not the masked internal error. let statusCode: ContentfulStatusCode; let public_: boolean; + let group: string; let code: string; let message: string; let metadata: unknown; @@ -212,12 +214,14 @@ export function deconstructError( "statusCode" in error && error.statusCode ? error.statusCode : 400 ) as ContentfulStatusCode; public_ = true; + group = error.group; code = error.code; message = getErrorMessage(error); metadata = error.metadata; logger.info({ msg: "public error", + group, code, message, issues: "https://github.com/rivet-gg/rivetkit/issues", @@ -228,12 +232,14 @@ export function deconstructError( if (errors.ActorError.isActorError(error)) { statusCode = 500; public_ = false; + group = error.group; code = error.code; message = getErrorMessage(error); metadata = error.metadata; logger.info({ msg: "internal error", + group, code, message, issues: "https://github.com/rivet-gg/rivetkit/issues", @@ -243,11 +249,13 @@ export function deconstructError( } else { statusCode = 500; public_ = false; + group = "internal"; code = errors.INTERNAL_ERROR_CODE; message = getErrorMessage(error); logger.info({ msg: "internal error", + group, code, message, issues: "https://github.com/rivet-gg/rivetkit/issues", @@ -258,6 +266,7 @@ export function deconstructError( } else { statusCode = 500; public_ = false; + group = "internal"; code = errors.INTERNAL_ERROR_CODE; message = errors.INTERNAL_ERROR_DESCRIPTION; metadata = { @@ -278,6 +287,7 @@ export function deconstructError( __type: "ActorError", statusCode, public: public_, + group, code, message, metadata, diff --git a/packages/rivetkit/src/driver-helpers/mod.ts b/packages/rivetkit/src/driver-helpers/mod.ts index 0e00cc679..48b2c0930 100644 --- a/packages/rivetkit/src/driver-helpers/mod.ts +++ b/packages/rivetkit/src/driver-helpers/mod.ts @@ -8,7 +8,6 @@ export { HEADER_CONN_PARAMS, HEADER_CONN_TOKEN, HEADER_ENCODING, - HEADER_EXPOSE_INTERNAL_ERROR, } from "@/actor/router-endpoints"; export type { ActorOutput, diff --git a/packages/rivetkit/src/driver-test-suite/mod.ts b/packages/rivetkit/src/driver-test-suite/mod.ts index 73d074787..b18ab6860 100644 --- a/packages/rivetkit/src/driver-test-suite/mod.ts +++ b/packages/rivetkit/src/driver-test-suite/mod.ts @@ -4,13 +4,12 @@ import { bundleRequire } from "bundle-require"; import invariant from "invariant"; import { describe } from "vitest"; import type { Transport } from "@/client/mod"; -import { createInlineClientDriver } from "@/inline-client-driver/mod"; import { createManagerRouter } from "@/manager/router"; import type { DriverConfig, Registry, RunConfig } from "@/mod"; import { RunConfigSchema } from "@/registry/run-config"; import { getPort } from "@/test/mod"; +import { logger } from "./log"; import { runActionFeaturesTests } from "./tests/action-features"; -import { runActorAuthTests } from "./tests/actor-auth"; import { runActorConnTests } from "./tests/actor-conn"; import { runActorConnStateTests } from "./tests/actor-conn-state"; import { @@ -68,6 +67,8 @@ type ClientType = "http" | "inline"; export interface DriverDeployOutput { endpoint: string; + namespace: string; + runnerName: string; /** Cleans up the test. */ cleanup(): Promise; @@ -114,8 +115,6 @@ export function runDriverTests( runActorErrorHandlingTests(driverTestConfig); - runActorAuthTests(driverTestConfig); - runActorInlineClientTests(driverTestConfig); runRawHttpTests(driverTestConfig); @@ -141,6 +140,7 @@ export function runDriverTests( export async function createTestRuntime( registryPath: string, driverFactory: (registry: Registry) => Promise<{ + rivetEngine?: { endpoint: string; namespace: string; runnerName: string }; driver: DriverConfig; cleanup?: () => Promise; }>, @@ -156,58 +156,84 @@ export async function createTestRuntime( registry.config.test.enabled = true; // Build drivers - const { driver, cleanup: driverCleanup } = await driverFactory(registry); - - // Build driver config - let injectWebSocket: NodeWebSocket["injectWebSocket"] | undefined; - let upgradeWebSocket: any; - const config: RunConfig = RunConfigSchema.parse({ + const { driver, - getUpgradeWebSocket: () => upgradeWebSocket!, - inspector: { - enabled: true, - token: () => "token", - }, - }); + cleanup: driverCleanup, + rivetEngine, + } = await driverFactory(registry); - // Create router - const managerDriver = driver.manager(registry.config, config); - const inlineDriver = createInlineClientDriver(managerDriver); - const { router } = createManagerRouter( - registry.config, - config, - inlineDriver, - managerDriver, - false, - ); - - // Inject WebSocket - const nodeWebSocket = createNodeWebSocket({ app: router }); - upgradeWebSocket = nodeWebSocket.upgradeWebSocket; - injectWebSocket = nodeWebSocket.injectWebSocket; - - // Start server - const port = await getPort(); - const server = honoServe({ - fetch: router.fetch, - hostname: "127.0.0.1", - port, - }); - invariant(injectWebSocket !== undefined, "should have injectWebSocket"); - injectWebSocket(server); - const endpoint = `http://127.0.0.1:${port}`; - - // Cleanup - const cleanup = async () => { - // Stop server - await new Promise((resolve) => server.close(() => resolve(undefined))); - - // Extra cleanup - await driverCleanup?.(); - }; - - return { - endpoint, - cleanup, - }; + if (rivetEngine) { + // TODO: We don't need createTestRuntime fort his + // Using external Rivet engine + + const cleanup = async () => { + await driverCleanup?.(); + }; + + return { + endpoint: rivetEngine.endpoint, + namespace: rivetEngine.namespace, + runnerName: rivetEngine.runnerName, + cleanup, + }; + } else { + // Start server for Rivet engine + + // Build driver config + // biome-ignore lint/style/useConst: Assigned later + let upgradeWebSocket: any; + const config: RunConfig = RunConfigSchema.parse({ + driver, + getUpgradeWebSocket: () => upgradeWebSocket!, + inspector: { + enabled: true, + token: () => "token", + }, + }); + + // Create router + const managerDriver = driver.manager(registry.config, config); + const { router } = createManagerRouter( + registry.config, + config, + managerDriver, + false, + ); + + // Inject WebSocket + const nodeWebSocket = createNodeWebSocket({ app: router }); + upgradeWebSocket = nodeWebSocket.upgradeWebSocket; + + // Start server + const port = await getPort(); + const server = honoServe({ + fetch: router.fetch, + hostname: "127.0.0.1", + port, + }); + invariant( + nodeWebSocket.injectWebSocket !== undefined, + "should have injectWebSocket", + ); + nodeWebSocket.injectWebSocket(server); + const serverEndpoint = `http://127.0.0.1:${port}`; + + logger().info({ msg: "test serer listening", port }); + + // Cleanup + const cleanup = async () => { + // Stop server + await new Promise((resolve) => server.close(() => resolve(undefined))); + + // Extra cleanup + await driverCleanup?.(); + }; + + return { + endpoint: serverEndpoint, + namespace: "default", + runnerName: "rivetkit", + cleanup, + }; + } } diff --git a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts b/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts deleted file mode 100644 index 6bb3c8aed..000000000 --- a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts +++ /dev/null @@ -1,404 +0,0 @@ -import * as cbor from "cbor-x"; -import type { Context as HonoContext } from "hono"; -import type { WebSocket } from "ws"; -import type { Encoding } from "@/actor/protocol/serde"; -import { - HEADER_ACTOR_QUERY, - HEADER_CONN_PARAMS, - HEADER_ENCODING, -} from "@/actor/router-endpoints"; -import { assertUnreachable } from "@/actor/utils"; -import type { ClientDriver } from "@/client/client"; -import { ActorError as ClientActorError } from "@/client/errors"; -import type { Transport } from "@/client/mod"; -import type { UniversalEventSource } from "@/common/eventsource-interface"; -import { importWebSocket } from "@/common/websocket"; -import type { ActorQuery } from "@/manager/protocol/query"; -import type { - TestInlineDriverCallRequest, - TestInlineDriverCallResponse, -} from "@/manager/router"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { logger } from "./log"; - -/** - * Creates a client driver used for testing the inline client driver. This will send a request to the HTTP server which will then internally call the internal client and return the response. - */ -export function createTestInlineClientDriver( - endpoint: string, - transport: Transport, -): ClientDriver { - return { - action: async = unknown[], Response = unknown>( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - name: string, - args: Args, - ): Promise => { - return makeInlineRequest( - endpoint, - encoding, - transport, - "action", - [undefined, actorQuery, encoding, params, name, args], - ); - }, - - resolveActorId: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - return makeInlineRequest( - endpoint, - encodingKind, - transport, - "resolveActorId", - [undefined, actorQuery, encodingKind, params], - ); - }, - - connectWebSocket: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - const WebSocket = await importWebSocket(); - - logger().debug({ - msg: "creating websocket connection via test inline driver", - actorQuery, - encodingKind, - }); - - // Create WebSocket connection to the test endpoint - const wsUrl = new URL( - `${endpoint}/registry/.test/inline-driver/connect-websocket`, - ); - wsUrl.searchParams.set("actorQuery", JSON.stringify(actorQuery)); - if (params !== undefined) - wsUrl.searchParams.set("params", JSON.stringify(params)); - wsUrl.searchParams.set("encodingKind", encodingKind); - - // Convert http/https to ws/wss - const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:"; - const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}${wsUrl.search}`; - - logger().debug({ msg: "connecting to websocket", url: finalWsUrl }); - - // Create and return the WebSocket - // Node & browser WebSocket types are incompatible - const ws = new WebSocket(finalWsUrl, [ - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - "rivetkit", - ]) as any; - - return ws; - }, - - connectSse: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - logger().debug({ - msg: "creating sse connection via test inline driver", - actorQuery, - encodingKind, - params, - }); - - // Dynamically import EventSource if needed - const EventSourceImport = await import("eventsource"); - // Handle both ES modules (default) and CommonJS export patterns - const EventSourceConstructor = - (EventSourceImport as any).default || EventSourceImport; - - // Encode parameters for the URL - const actorQueryParam = encodeURIComponent(JSON.stringify(actorQuery)); - const encodingParam = encodeURIComponent(encodingKind); - const paramsParam = params - ? encodeURIComponent(JSON.stringify(params)) - : null; - - // Create SSE connection URL - const sseUrl = new URL( - `${endpoint}/registry/.test/inline-driver/connect-sse`, - ); - sseUrl.searchParams.set("actorQueryRaw", actorQueryParam); - sseUrl.searchParams.set("encodingKind", encodingParam); - if (paramsParam) { - sseUrl.searchParams.set("params", paramsParam); - } - - logger().debug({ msg: "connecting to sse", url: sseUrl.toString() }); - - // Create and return the EventSource - const eventSource = new EventSourceConstructor(sseUrl.toString()); - - // Wait for the connection to be established before returning - await new Promise((resolve, reject) => { - eventSource.onopen = () => { - logger().debug({ msg: "sse connection established" }); - resolve(); - }; - - eventSource.onerror = (event: Event) => { - logger().error({ msg: "sse connection failed", event }); - reject(new Error("Failed to establish SSE connection")); - }; - - // Set a timeout in case the connection never establishes - setTimeout(() => { - if (eventSource.readyState !== EventSourceConstructor.OPEN) { - reject(new Error("SSE connection timed out")); - } - }, 10000); // 10 second timeout - }); - - return eventSource as UniversalEventSource; - }, - - sendHttpMessage: async ( - _c: HonoContext | undefined, - actorId: string, - encoding: Encoding, - connectionId: string, - connectionToken: string, - message: protocol.ToServer, - ): Promise => { - logger().debug({ - msg: "sending http message via test inline driver", - actorId, - encoding, - connectionId, - transport, - }); - - const result = await fetch( - `${endpoint}/registry/.test/inline-driver/call`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - encoding, - transport, - method: "sendHttpMessage", - args: [ - undefined, - actorId, - encoding, - connectionId, - connectionToken, - message, - ], - } satisfies TestInlineDriverCallRequest), - }, - ); - - if (!result.ok) { - throw new Error(`Failed to send HTTP message: ${result.statusText}`); - } - - // Discard response - await result.body?.cancel(); - }, - - rawHttpRequest: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - init: RequestInit, - ): Promise => { - // Normalize path to match other drivers - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - - logger().debug({ - msg: "sending raw http request via test inline driver", - actorQuery, - encoding, - path: normalizedPath, - }); - - // Use the dedicated raw HTTP endpoint - const url = `${endpoint}/registry/.test/inline-driver/raw-http/${normalizedPath}`; - - logger().debug({ msg: "rewriting http url", from: path, to: url }); - - // Merge headers with our metadata - const headers = new Headers(init.headers); - headers.set(HEADER_ACTOR_QUERY, JSON.stringify(actorQuery)); - headers.set(HEADER_ENCODING, encoding); - if (params !== undefined) { - headers.set(HEADER_CONN_PARAMS, JSON.stringify(params)); - } - - // Forward the request directly - const response = await fetch(url, { - ...init, - headers, - }); - - // Check if it's an error response from our handler - if ( - !response.ok && - response.headers.get("content-type")?.includes("application/json") - ) { - try { - // Clone the response to avoid consuming the body - const clonedResponse = response.clone(); - const errorData = (await clonedResponse.json()) as any; - if (errorData.error) { - // Handle both error formats: - // 1. { error: { code, message, metadata } } - structured format - // 2. { error: "message" } - simple string format (from custom onFetch handlers) - if (typeof errorData.error === "object") { - throw new ClientActorError( - errorData.error.code, - errorData.error.message, - errorData.error.metadata, - ); - } - // For simple string errors, just return the response as-is - // This allows custom onFetch handlers to return their own error formats - } - } catch (e) { - // If it's not our error format, just return the response as-is - if (!(e instanceof ClientActorError)) { - return response; - } - throw e; - } - } - - return response; - }, - - rawWebSocket: async ( - _c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - protocols: string | string[] | undefined, - ): Promise => { - logger().debug({ msg: "test inline driver rawWebSocket called" }); - const WebSocket = await importWebSocket(); - - // Normalize path to match other drivers - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - - logger().debug({ - msg: "creating raw websocket connection via test inline driver", - actorQuery, - encoding, - path: normalizedPath, - protocols, - }); - - // Create WebSocket connection to the test endpoint - const wsUrl = new URL( - `${endpoint}/registry/.test/inline-driver/raw-websocket`, - ); - wsUrl.searchParams.set("actorQuery", JSON.stringify(actorQuery)); - if (params !== undefined) - wsUrl.searchParams.set("params", JSON.stringify(params)); - wsUrl.searchParams.set("encodingKind", encoding); - wsUrl.searchParams.set("path", normalizedPath); - if (protocols !== undefined) - wsUrl.searchParams.set("protocols", JSON.stringify(protocols)); - - // Convert http/https to ws/wss - const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:"; - const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}${wsUrl.search}`; - - logger().debug({ msg: "connecting to raw websocket", url: finalWsUrl }); - - logger().debug({ - msg: "rewriting websocket url", - from: path, - to: finalWsUrl, - }); - - // Create and return the WebSocket - // Node & browser WebSocket types are incompatible - const ws = new WebSocket(finalWsUrl, [ - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - "rivetkit", - ]) as any; - - logger().debug({ - msg: "test inline driver created websocket", - readyState: ws.readyState, - url: ws.url, - }); - - return ws; - }, - }; -} - -async function makeInlineRequest( - endpoint: string, - encoding: Encoding, - transport: Transport, - method: string, - args: unknown[], -): Promise { - logger().debug({ - msg: "sending inline request", - encoding, - transport, - method, - args, - }); - - // Call driver - const response = await fetch( - `${endpoint}/registry/.test/inline-driver/call`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: cbor.encode({ - encoding, - transport, - method, - args, - } satisfies TestInlineDriverCallRequest), - }, - ); - - if (!response.ok) { - throw new Error(`Failed to call inline ${method}: ${response.statusText}`); - } - - // Parse response - const buffer = await response.arrayBuffer(); - const callResponse: TestInlineDriverCallResponse = cbor.decode( - new Uint8Array(buffer), - ); - - // Throw or OK - if ("ok" in callResponse) { - return callResponse.ok; - } else if ("err" in callResponse) { - throw new ClientActorError( - callResponse.err.code, - callResponse.err.message, - callResponse.err.metadata, - ); - } else { - assertUnreachable(callResponse); - } -} diff --git a/packages/rivetkit/src/driver-test-suite/tests/actor-auth.ts b/packages/rivetkit/src/driver-test-suite/tests/actor-auth.ts deleted file mode 100644 index facd7f448..000000000 --- a/packages/rivetkit/src/driver-test-suite/tests/actor-auth.ts +++ /dev/null @@ -1,591 +0,0 @@ -import { describe, expect, test } from "vitest"; -import type { ActorError } from "@/client/errors"; -import type { DriverTestConfig } from "../mod"; -import { setupDriverTest } from "../utils"; - -export function runActorAuthTests(driverTestConfig: DriverTestConfig) { - describe("Actor Authentication Tests", () => { - describe("Basic Authentication", () => { - test("should allow access with valid auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Create client with valid auth params - const instance = client.authActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - // This should succeed with valid API key - const authData = await instance.getUserAuth(); - if (driverTestConfig.clientType === "inline") { - // Inline clients don't have auth data - expect(authData).toBeUndefined(); - } else { - // HTTP clients should have auth data - expect(authData).toEqual({ - userId: "user123", - token: "valid-api-key", - }); - } - - // Should be able to call actions - const requests = await instance.getRequests(); - expect(requests).toBe(1); - }); - - test("should deny access with invalid auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // This should fail without proper authorization - const instance = client.authActor.getOrCreate(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - const requests = await instance.getRequests(); - expect(typeof requests).toBe("number"); - } else { - // HTTP clients should enforce authentication - try { - await instance.getRequests(); - expect.fail("Expected authentication error"); - } catch (error) { - expect((error as ActorError).code).toBe("missing_auth"); - } - } - }); - - test("should expose auth data on connection", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.authActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - // Auth data should be available via c.conn.auth - const authData = await instance.getUserAuth(); - if (driverTestConfig.clientType === "inline") { - // Inline clients don't have auth data - expect(authData).toBeUndefined(); - } else { - // HTTP clients should have auth data - expect(authData).toBeDefined(); - expect((authData as any).userId).toBe("user123"); - expect((authData as any).token).toBe("valid-api-key"); - } - }); - }); - - describe("Intent-Based Authentication", () => { - test("should allow get operations for any role", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const createdInstance = await client.intentAuthActor.create(["foo"], { - params: { role: "admin" }, - }); - const actorId = await createdInstance.resolve(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - const instance = client.intentAuthActor.getForId(actorId); - const value = await instance.getValue(); - expect(value).toBe(0); - } else { - // HTTP clients - actions require user or admin role - const instance = client.intentAuthActor.getForId(actorId, { - params: { role: "user" }, // Actions require user or admin role - }); - const value = await instance.getValue(); - expect(value).toBe(0); - } - }); - - test("should require admin role for create operations", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const instance = client.intentAuthActor.getOrCreate(undefined, { - params: { role: "user" }, - }); - const value = await instance.getValue(); - expect(value).toBe(0); - } else { - // HTTP clients should enforce authentication - try { - const instance = client.intentAuthActor.getOrCreate(undefined, { - params: { role: "user" }, - }); - await instance.getValue(); - expect.fail("Expected permission error for create operation"); - } catch (error) { - expect((error as ActorError).code).toBe("insufficient_permissions"); - expect((error as ActorError).message).toContain( - "Admin role required", - ); - } - } - }); - - test("should allow actions for user and admin roles", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const createdInstance = await client.intentAuthActor.create(["foo"], { - params: { role: "admin" }, - }); - const actorId = await createdInstance.resolve(); - - // This should fail - actions require user or admin role - const instance = client.intentAuthActor.getForId(actorId, { - params: { role: "guest" }, - }); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const result = await instance.setValue(42); - expect(result).toBe(42); - } else { - // HTTP clients should enforce authentication - try { - await instance.setValue(42); - expect.fail("Expected permission error for action"); - } catch (error) { - expect((error as ActorError).code).toBe("insufficient_permissions"); - expect((error as ActorError).message).toContain( - "User or admin role required", - ); - } - } - }); - }); - - describe("Public Access", () => { - test("should allow access with empty onAuth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Public actor should allow access without authentication - const instance = client.publicActor.getOrCreate(); - - const visitors = await instance.visit(); - expect(visitors).toBe(1); - - // Should be able to call multiple times - const visitors2 = await instance.visit(); - expect(visitors2).toBe(2); - }); - - test("should deny access without onAuth defined", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Actor without onAuth should be blocked - const instance = client.noAuthActor.getOrCreate(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const value = await instance.getValue(); - expect(value).toBe(42); - } else { - // HTTP clients should enforce authentication - try { - await instance.getValue(); - expect.fail( - "Expected access to be denied for actor without onAuth", - ); - } catch (error) { - expect((error as ActorError).code).toBe("forbidden"); - } - } - }); - }); - - describe("Async Authentication", () => { - test("should handle promise-based auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.asyncAuthActor.getOrCreate(undefined, { - params: { token: "valid" }, - }); - - // Should succeed with valid token - const result = await instance.increment(); - expect(result).toBe(1); - - // Auth data should be available - const authData = await instance.getAuthData(); - if (driverTestConfig.clientType === "inline") { - // Inline clients don't have auth data - expect(authData).toBeUndefined(); - } else { - // HTTP clients should have auth data - expect(authData).toBeDefined(); - expect((authData as any).userId).toBe("user-valid"); - expect((authData as any).validated).toBe(true); - } - }); - - test("should handle async auth failures", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.asyncAuthActor.getOrCreate(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const result = await instance.increment(); - expect(result).toBe(1); - } else { - // HTTP clients should enforce authentication - try { - await instance.increment(); - expect.fail("Expected async auth failure"); - } catch (error) { - expect((error as ActorError).code).toBe("missing_token"); - } - } - }); - }); - - describe("Authentication Across Transports", () => { - if (driverTestConfig.transport === "websocket") { - test("should authenticate WebSocket connections", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Test WebSocket connection auth - const instance = client.authActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - // Should be able to establish connection and call actions - const authData = await instance.getUserAuth(); - expect(authData).toBeDefined(); - expect((authData as any).userId).toBe("user123"); - }); - } - - test("should authenticate HTTP actions", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Test HTTP action auth - const instance = client.authActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - // Actions should require authentication - const requests = await instance.getRequests(); - expect(typeof requests).toBe("number"); - }); - }); - - describe("Error Handling", () => { - test("should handle auth errors gracefully", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.authActor.getOrCreate(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const requests = await instance.getRequests(); - expect(typeof requests).toBe("number"); - } else { - // HTTP clients should enforce authentication - try { - await instance.getRequests(); - expect.fail("Expected authentication error"); - } catch (error) { - // Error should be properly structured - const actorError = error as ActorError; - expect(actorError.code).toBeDefined(); - expect(actorError.message).toBeDefined(); - } - } - }); - - test("should preserve error details for debugging", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.asyncAuthActor.getOrCreate(); - - if (driverTestConfig.clientType === "inline") { - // Inline clients bypass authentication - should succeed - const result = await instance.increment(); - expect(result).toBe(1); - } else { - // HTTP clients should enforce authentication - try { - await instance.increment(); - expect.fail("Expected token error"); - } catch (error) { - const actorError = error as ActorError; - expect(actorError.code).toBe("missing_token"); - expect(actorError.message).toBe("Token required"); - } - } - }); - }); - - describe("Raw HTTP Authentication", () => { - test("should allow raw HTTP access with valid auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Create actor with valid auth - const instance = client.rawHttpAuthActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - // Raw HTTP request should succeed - const response = await instance.fetch("api/auth-info"); - expect(response.ok).toBe(true); - - const data = (await response.json()) as any; - expect(data.message).toBe("Authenticated request"); - expect(data.requestCount).toBe(1); - - // Regular actions should also work - const count = await instance.getRequestCount(); - expect(count).toBe(1); - }); - - test("should deny raw HTTP access without auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Create actor without auth - const instance = client.rawHttpAuthActor.getOrCreate(); - - // All clients should now enforce authentication for raw endpoints - const response = await instance.fetch("api/protected"); - if (driverTestConfig.clientType === "inline") { - expect(response.ok).toBe(true); - expect(response.status).toBe(200); - } else { - expect(response.ok).toBe(false); - expect(response.status).toBe(400); - } - - // Check error details - try { - const errorData = (await response.json()) as any; - expect(errorData.c || errorData.code).toBe("missing_auth"); - } catch { - // Response might be CBOR encoded, status code check is sufficient - } - }); - - test("should deny raw HTTP for actors without onAuth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawHttpNoAuthActor.getOrCreate(); - - // All clients should now enforce authentication for raw endpoints - const response = await instance.fetch("api/test"); - if (driverTestConfig.clientType === "inline") { - expect(response.ok).toBe(true); - expect(response.status).toBe(200); - } else { - expect(response.ok).toBe(false); - expect(response.status).toBe(403); - } - - // Check error details - try { - const errorData = (await response.json()) as any; - expect(errorData.c || errorData.code).toBe("forbidden"); - } catch { - // Response might be CBOR encoded, status code check is sufficient - } - }); - - test("should allow public raw HTTP access", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawHttpPublicActor.getOrCreate(); - - // Should work without auth - const response = await instance.fetch("api/visit"); - expect(response.ok).toBe(true); - - const data = (await response.json()) as any; - expect(data.message).toBe("Welcome visitor!"); - expect(data.count).toBe(1); - - // Second request - const response2 = await instance.fetch("api/visit"); - const data2 = (await response2.json()) as any; - expect(data2.count).toBe(2); - }); - - test("should handle custom auth in onFetch", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawHttpCustomAuthActor.getOrCreate(); - - // Request without auth should fail - const response1 = await instance.fetch("api/data"); - expect(response1.ok).toBe(false); - expect(response1.status).toBe(401); - - const error1 = (await response1.json()) as any; - expect(error1.error).toBe("Unauthorized"); - - // Request with wrong token should fail - const response2 = await instance.fetch("api/data", { - headers: { - Authorization: "Bearer wrong-token", - }, - }); - expect(response2.ok).toBe(false); - expect(response2.status).toBe(403); - - // Request with correct token should succeed - const response3 = await instance.fetch("api/data", { - headers: { - Authorization: "Bearer custom-token", - }, - }); - expect(response3.ok).toBe(true); - - const data = (await response3.json()) as any; - expect(data.message).toBe("Authorized!"); - expect(data.authorized).toBe(1); - - // Check stats - const stats = await instance.getStats(); - expect(stats.authorized).toBe(1); - expect(stats.unauthorized).toBe(2); - }); - }); - - describe("Raw WebSocket Authentication", () => { - test("should allow raw WebSocket access with valid auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Create actor with valid auth - const instance = client.rawWebSocketAuthActor.getOrCreate(undefined, { - params: { apiKey: "valid-api-key" }, - }); - - const ws = await instance.websocket(); - - // Wait for welcome message - const welcomePromise = new Promise((resolve, reject) => { - ws.addEventListener("message", (event: any) => { - const data = JSON.parse(event.data); - if (data.type === "welcome") { - resolve(data); - } - }); - ws.addEventListener("close", () => reject("closed")); - }); - - const welcomeData = (await welcomePromise) as any; - expect(welcomeData.message).toBe("Authenticated WebSocket connection"); - expect(welcomeData.connectionCount).toBe(1); - - ws.close(); - }); - - test("should deny raw WebSocket access without auth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawWebSocketAuthActor.getOrCreate(); - - // All clients should now enforce authentication for raw endpoints - try { - await instance.websocket(); - expect.fail("Expected authentication error"); - } catch (error) { - // WebSocket connection failures may not always have structured error codes - expect(error).toBeDefined(); - } - }); - - test("should deny raw WebSocket for actors without onAuth", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawWebSocketNoAuthActor.getOrCreate(); - - // All clients should now enforce authentication for raw endpoints - try { - await instance.websocket(); - expect.fail("Expected forbidden error"); - } catch (error) { - // WebSocket connection failures may not always have structured error codes - expect(error).toBeDefined(); - } - }); - - test("should allow public raw WebSocket access", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawWebSocketPublicActor.getOrCreate(); - - // Should work without auth - const ws = await instance.websocket(); - - const welcomePromise = new Promise((resolve, reject) => { - ws.addEventListener("message", (event: any) => { - const data = JSON.parse(event.data); - if (data.type === "welcome") { - resolve(data); - } - }); - ws.addEventListener("close", reject); - }); - - const welcomeData = (await welcomePromise) as any; - expect(welcomeData.message).toBe("Public WebSocket connection"); - expect(welcomeData.visitorNumber).toBe(1); - - ws.close(); - }); - - test("should handle custom auth in onWebSocket", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - const instance = client.rawWebSocketCustomAuthActor.getOrCreate(); - - // WebSocket without token should be rejected - try { - const ws1 = await instance.websocket(); - - // Listen for error message before close - const errorPromise = new Promise((resolve, reject) => { - ws1.addEventListener("message", (event: any) => { - const data = JSON.parse(event.data); - if (data.type === "error") { - resolve(data); - } - }); - ws1.addEventListener("close", reject); - }); - - const errorData = (await errorPromise) as any; - expect(errorData.type).toBe("error"); - expect(errorData.message).toBe("Unauthorized"); - } catch (error) { - // Some drivers might reject the connection immediately - expect(error).toBeDefined(); - } - - // WebSocket with correct token should succeed - const ws2 = await instance.websocket("?token=custom-ws-token"); - - const authPromise = new Promise((resolve, reject) => { - ws2.addEventListener("message", (event: any) => { - const data = JSON.parse(event.data); - if (data.type === "authorized") { - resolve(data); - } - }); - ws2.addEventListener("close", reject); - }); - - const authData = (await authPromise) as any; - expect(authData.message).toBe("Welcome authenticated user!"); - - ws2.close(); - - // Check stats - const stats = await instance.getStats(); - expect(stats.authorized).toBeGreaterThanOrEqual(1); - expect(stats.unauthorized).toBeGreaterThanOrEqual(1); - }); - }); - }); -} diff --git a/packages/rivetkit/src/driver-test-suite/tests/actor-handle.ts b/packages/rivetkit/src/driver-test-suite/tests/actor-handle.ts index d278d2aab..7f51e3ac5 100644 --- a/packages/rivetkit/src/driver-test-suite/tests/actor-handle.ts +++ b/packages/rivetkit/src/driver-test-suite/tests/actor-handle.ts @@ -1,4 +1,5 @@ import { describe, expect, test } from "vitest"; +import type { ActorError } from "@/client/mod"; import type { DriverTestConfig } from "../mod"; import { setupDriverTest } from "../utils"; @@ -74,6 +75,38 @@ export function runActorHandleTests(driverTestConfig: DriverTestConfig) { const retrievedCount = await handle.getCount(); expect(retrievedCount).toBe(9); }); + + test("errors when calling create twice with the same key", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const key = ["duplicate-create-handle", crypto.randomUUID()]; + + // First create should succeed + await client.counter.create(key); + + // Second create with same key should throw ActorAlreadyExists + try { + await client.counter.create(key); + expect.fail("did not error on duplicate create"); + } catch (err) { + expect((err as ActorError).group).toBe("actor"); + expect((err as ActorError).code).toBe("already_exists"); + } + }); + + test(".get().resolve() errors for non-existent actor", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const missingId = `nonexistent-${crypto.randomUUID()}`; + + try { + await client.counter.get([missingId]).resolve(); + expect.fail("did not error for get().resolve() on missing actor"); + } catch (err) { + expect((err as ActorError).group).toBe("actor"); + expect((err as ActorError).code).toBe("not_found"); + } + }); }); describe("Action Functionality", () => { diff --git a/packages/rivetkit/src/driver-test-suite/tests/manager-driver.ts b/packages/rivetkit/src/driver-test-suite/tests/manager-driver.ts index 415c5f1f5..408b553bc 100644 --- a/packages/rivetkit/src/driver-test-suite/tests/manager-driver.ts +++ b/packages/rivetkit/src/driver-test-suite/tests/manager-driver.ts @@ -39,7 +39,8 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { await client.counter.create(uniqueKey); expect.fail("did not error on duplicate create"); } catch (err) { - expect((err as ActorError).code).toBe("actor_already_exists"); + expect((err as ActorError).group).toBe("actor"); + expect((err as ActorError).code).toBe("already_exists"); } // Verify the original actor still works and has its state @@ -60,7 +61,8 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { await client.counter.get([nonexistentId]).resolve(); expect.fail("did not error for get"); } catch (err) { - expect((err as ActorError).code).toBe("actor_not_found"); + expect((err as ActorError).group).toBe("actor"); + expect((err as ActorError).code).toBe("not_found"); } // Create the actor diff --git a/packages/rivetkit/src/driver-test-suite/utils.ts b/packages/rivetkit/src/driver-test-suite/utils.ts index adb68c66b..82cc16bc7 100644 --- a/packages/rivetkit/src/driver-test-suite/utils.ts +++ b/packages/rivetkit/src/driver-test-suite/utils.ts @@ -1,11 +1,9 @@ import { resolve } from "node:path"; import { type TestContext, vi } from "vitest"; import { assertUnreachable } from "@/actor/utils"; -import { createClientWithDriver } from "@/client/client"; import { type Client, createClient } from "@/client/mod"; import type { registry } from "../../fixtures/driver-test-suite/registry"; import type { DriverTestConfig } from "./mod"; -import { createTestInlineClientDriver } from "./test-inline-client-driver"; export const FAKE_TIME = new Date("2024-01-01T00:00:00.000Z"); @@ -24,22 +22,27 @@ export async function setupDriverTest( // Build drivers const projectPath = resolve(__dirname, "../../fixtures/driver-test-suite"); - const { endpoint, cleanup } = await driverTestConfig.start(projectPath); + const { endpoint, namespace, runnerName, cleanup } = + await driverTestConfig.start(projectPath); c.onTestFinished(cleanup); let client: Client; if (driverTestConfig.clientType === "http") { // Create client - client = createClient(endpoint, { + client = createClient({ + endpoint, + namespace, + runnerName, transport: driverTestConfig.transport, }); } else if (driverTestConfig.clientType === "inline") { - // Use inline client from driver - const clientDriver = createTestInlineClientDriver( - endpoint, - driverTestConfig.transport ?? "websocket", - ); - client = createClientWithDriver(clientDriver); + throw "TODO"; + // // Use inline client from driver + // const clientDriver = createTestInlineClientDriver( + // endpoint, + // driverTestConfig.transport ?? "websocket", + // ); + // client = createClientWithDriver(clientDriver); } else { assertUnreachable(driverTestConfig.clientType); } diff --git a/packages/rivetkit/src/drivers/default.ts b/packages/rivetkit/src/drivers/default.ts index 3408bcb09..ea85629a3 100644 --- a/packages/rivetkit/src/drivers/default.ts +++ b/packages/rivetkit/src/drivers/default.ts @@ -9,7 +9,7 @@ import { getEnvUniversal } from "@/utils"; * Chooses the appropriate driver based on the run configuration. */ export function chooseDefaultDriver(runConfig: RunConfig): DriverConfig { - const engineEndpoint = runConfig.engine || getEnvUniversal("RIVET_ENGINE"); + const engineEndpoint = runConfig.endpoint ?? getEnvUniversal("RIVET_ENGINE"); if (engineEndpoint && runConfig.driver) { throw new UserError( diff --git a/packages/rivetkit/src/drivers/engine/actor-driver.ts b/packages/rivetkit/src/drivers/engine/actor-driver.ts index 9ed7ea8c3..8be47142a 100644 --- a/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -6,6 +6,7 @@ import { Runner } from "@rivetkit/engine-runner"; import * as cbor from "cbor-x"; import { WSContext } from "hono/ws"; import invariant from "invariant"; +import { deserializeActorKey } from "@/actor/keys"; import { EncodingSchema } from "@/actor/protocol/serde"; import type { Client } from "@/client/client"; import { getLogger } from "@/common/log"; @@ -36,7 +37,6 @@ import { PATH_RAW_WEBSOCKET_PREFIX, } from "@/mod"; import type { Config } from "./config"; -import { deserializeActorKey } from "./keys"; import { KEYS } from "./kv"; import { logger } from "./log"; diff --git a/packages/rivetkit/src/drivers/engine/api-endpoints.ts b/packages/rivetkit/src/drivers/engine/api-endpoints.ts deleted file mode 100644 index 4a6e42a44..000000000 --- a/packages/rivetkit/src/drivers/engine/api-endpoints.ts +++ /dev/null @@ -1,128 +0,0 @@ -import { apiCall } from "./api-utils"; -import type { Config } from "./config"; -import { serializeActorKey } from "./keys"; - -// MARK: Common types -export type RivetId = string; - -export interface Actor { - actor_id: RivetId; - name: string; - key: string; - namespace_id: RivetId; - runner_name_selector: string; - create_ts: number; - connectable_ts?: number | null; - destroy_ts?: number | null; - sleep_ts?: number | null; - start_ts?: number | null; -} - -export interface ActorsGetResponse { - actor: Actor; -} - -export interface ActorsGetByIdResponse { - actor_id?: RivetId | null; -} - -export interface ActorsGetOrCreateResponse { - actor: Actor; - created: boolean; -} - -export interface ActorsGetOrCreateByIdResponse { - actor_id: RivetId; - created: boolean; -} - -export interface ActorsCreateRequest { - name: string; - runner_name_selector: string; - crash_policy: string; - key?: string | null; - input?: string | null; -} - -export interface ActorsCreateResponse { - actor: Actor; -} - -// MARK: Get actor -export async function getActor( - config: Config, - actorId: RivetId, -): Promise { - return apiCall( - config.endpoint, - config.namespace, - "GET", - `/actors/${encodeURIComponent(actorId)}`, - ); -} - -// MARK: Get actor by id -export async function getActorById( - config: Config, - name: string, - key: string[], -): Promise { - const serializedKey = serializeActorKey(key); - return apiCall( - config.endpoint, - config.namespace, - "GET", - `/actors/by-id?name=${encodeURIComponent(name)}&key=${encodeURIComponent(serializedKey)}`, - ); -} - -// MARK: Get or create actor by id -export interface ActorsGetOrCreateByIdRequest { - name: string; - key: string; - runner_name_selector: string; - crash_policy: string; - input?: string | null; -} - -export async function getOrCreateActorById( - config: Config, - request: ActorsGetOrCreateByIdRequest, -): Promise { - return apiCall( - config.endpoint, - config.namespace, - "PUT", - `/actors/by-id`, - request, - ); -} - -// MARK: Create actor -export async function createActor( - config: Config, - request: ActorsCreateRequest, -): Promise { - return apiCall( - config.endpoint, - config.namespace, - "POST", - `/actors`, - request, - ); -} - -// MARK: Destroy actor -export type ActorsDeleteResponse = {}; - -export async function destroyActor( - config: Config, - actorId: RivetId, -): Promise { - return apiCall( - config.endpoint, - config.namespace, - "DELETE", - `/actors/${encodeURIComponent(actorId)}`, - ); -} diff --git a/packages/rivetkit/src/drivers/engine/api-utils.ts b/packages/rivetkit/src/drivers/engine/api-utils.ts deleted file mode 100644 index ba5d0d04a..000000000 --- a/packages/rivetkit/src/drivers/engine/api-utils.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { logger } from "./log"; - -// Error class for Engine API errors -export class EngineApiError extends Error { - constructor( - public readonly group: string, - public readonly code: string, - message?: string, - ) { - super(message || `Engine API error: ${group}/${code}`); - this.name = "EngineApiError"; - } -} - -// Helper function for making API calls -export async function apiCall( - endpoint: string, - namespace: string, - method: "GET" | "POST" | "PUT" | "DELETE", - path: string, - body?: TInput, -): Promise { - const url = `${endpoint}${path}${path.includes("?") ? "&" : "?"}namespace=${encodeURIComponent(namespace)}`; - - const options: RequestInit = { - method, - headers: { - "Content-Type": "application/json", - }, - }; - - if (body !== undefined && method !== "GET") { - options.body = JSON.stringify(body); - } - - logger().debug({ msg: "making api call", method, url }); - - const response = await fetch(url, options); - - if (!response.ok) { - const errorText = await response.text(); - logger().error({ - msg: "api call failed", - status: response.status, - statusText: response.statusText, - error: errorText, - method, - path, - }); - - // Try to parse error response - try { - const errorData = JSON.parse(errorText); - if (errorData.kind === "error" && errorData.group && errorData.code) { - throw new EngineApiError( - errorData.group, - errorData.code, - errorData.message, - ); - } - } catch (parseError) { - // If parsing fails or it's not our expected error format, continue - } - - throw new Error( - `API call failed: ${response.status} ${response.statusText}`, - ); - } - - return response.json() as Promise; -} diff --git a/packages/rivetkit/src/drivers/engine/manager-driver.ts b/packages/rivetkit/src/drivers/engine/manager-driver.ts deleted file mode 100644 index 3f436716c..000000000 --- a/packages/rivetkit/src/drivers/engine/manager-driver.ts +++ /dev/null @@ -1,405 +0,0 @@ -import * as cbor from "cbor-x"; -import type { Context as HonoContext } from "hono"; -import invariant from "invariant"; -import { ActorAlreadyExists } from "@/actor/errors"; -import { - HEADER_AUTH_DATA, - HEADER_CONN_PARAMS, - HEADER_ENCODING, - HEADER_EXPOSE_INTERNAL_ERROR, -} from "@/actor/router-endpoints"; -import { generateRandomString } from "@/actor/utils"; -import { importWebSocket } from "@/common/websocket"; -import type { - ActorOutput, - CreateInput, - GetForIdInput, - GetOrCreateWithKeyInput, - GetWithKeyInput, - ManagerDriver, -} from "@/driver-helpers/mod"; -import type { ManagerDisplayInformation } from "@/manager/driver"; -import { type Encoding, noopNext, type RunConfig } from "@/mod"; -import { - createActor, - destroyActor, - getActor, - getActorById, - getOrCreateActorById, -} from "./api-endpoints"; -import { EngineApiError } from "./api-utils"; -import type { Config } from "./config"; -import { deserializeActorKey, serializeActorKey } from "./keys"; -import { logger } from "./log"; -import { createWebSocketProxy } from "./ws-proxy"; - -export class EngineManagerDriver implements ManagerDriver { - #config: Config; - #runConfig: RunConfig; - #importWebSocketPromise: Promise; - - constructor(config: Config, runConfig: RunConfig) { - this.#config = config; - this.#runConfig = runConfig; - if (!this.#runConfig.inspector.token()) { - const token = generateRandomString(); - this.#runConfig.inspector.token = () => token; - } - this.#importWebSocketPromise = importWebSocket(); - } - - async sendRequest(actorId: string, actorRequest: Request): Promise { - logger().debug({ - msg: "sending request to actor via guard", - actorId, - method: actorRequest.method, - url: actorRequest.url, - }); - - return this.#forwardHttpRequest(actorRequest, actorId); - } - - async openWebSocket( - path: string, - actorId: string, - encoding: Encoding, - params: unknown, - ): Promise { - const WebSocket = await this.#importWebSocketPromise; - - // WebSocket connections go through guard - const guardUrl = `${this.#config.endpoint}${path}`; - - logger().debug({ - msg: "opening websocket to actor via guard", - actorId, - path, - guardUrl, - }); - - // Create WebSocket connection - const ws = new WebSocket(guardUrl, { - headers: buildGuardHeadersForWebSocket(actorId, encoding, params), - }); - - logger().debug({ msg: "websocket connection opened", actorId }); - - return ws; - } - - async proxyRequest( - _c: HonoContext, - actorRequest: Request, - actorId: string, - ): Promise { - logger().debug({ - msg: "forwarding request to actor via guard", - actorId, - method: actorRequest.method, - url: actorRequest.url, - hasBody: !!actorRequest.body, - }); - - return this.#forwardHttpRequest(actorRequest, actorId); - } - - async proxyWebSocket( - c: HonoContext, - path: string, - actorId: string, - encoding: Encoding, - params: unknown, - authData: unknown, - ): Promise { - const upgradeWebSocket = this.#runConfig.getUpgradeWebSocket?.(); - invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); - - const guardUrl = `${this.#config.endpoint}${path}`; - const wsGuardUrl = guardUrl.replace("http://", "ws://"); - - logger().debug({ - msg: "forwarding websocket to actor via guard", - actorId, - path, - guardUrl, - }); - - // Build headers - const headers = buildGuardHeadersForWebSocket( - actorId, - encoding, - params, - authData, - ); - const args = await createWebSocketProxy(c, wsGuardUrl, headers); - - return await upgradeWebSocket(() => args)(c, noopNext()); - } - - displayInformation(): ManagerDisplayInformation { - return { - name: "Rivet Engine", - properties: { - Endpoint: this.#config.endpoint, - Namespace: this.#config.namespace, - Runner: this.#config.runnerName, - }, - }; - } - - extraStartupLog() { - return { - engine: this.#config.endpoint, - namespace: this.#config.namespace, - runner: this.#config.runnerName, - }; - } - - async getForId({ - c, - name, - actorId, - }: GetForIdInput): Promise { - // Fetch from API if not in cache - try { - const response = await getActor(this.#config, actorId); - - // Validate name matches - if (response.actor.name !== name) { - logger().debug({ - msg: "actor name mismatch from api", - actorId, - apiName: response.actor.name, - requestedName: name, - }); - return undefined; - } - - const keyRaw = response.actor.key; - invariant(keyRaw, `actor ${actorId} should have key`); - const key = deserializeActorKey(keyRaw); - - return { - actorId, - name, - key, - }; - } catch (error) { - if ( - error instanceof EngineApiError && - (error as EngineApiError).group === "actor" && - (error as EngineApiError).code === "not_found" - ) { - return undefined; - } - throw error; - } - } - - async getWithKey({ - c, - name, - key, - }: GetWithKeyInput): Promise { - logger().debug({ msg: "getWithKey: searching for actor", name, key }); - - // If not in local cache, fetch by key from API - try { - const response = await getActorById(this.#config, name, key); - - if (!response.actor_id) { - return undefined; - } - - const actorId = response.actor_id; - - logger().debug({ - msg: "getWithKey: found actor via api", - actorId, - name, - key, - }); - - return { - actorId, - name, - key, - }; - } catch (error) { - if ( - error instanceof EngineApiError && - (error as EngineApiError).group === "actor" && - (error as EngineApiError).code === "not_found" - ) { - return undefined; - } - throw error; - } - } - - async getOrCreateWithKey( - input: GetOrCreateWithKeyInput, - ): Promise { - const { c, name, key, input: actorInput, region } = input; - - logger().info({ - msg: "getOrCreateWithKey: getting or creating actor via engine api", - name, - key, - }); - - const response = await getOrCreateActorById(this.#config, { - name, - key: serializeActorKey(key), - runner_name_selector: this.#config.runnerName, - input: input ? cbor.encode(actorInput).toString("base64") : undefined, - crash_policy: "sleep", - }); - - const actorId = response.actor_id; - - logger().info({ - msg: "getOrCreateWithKey: actor ready", - actorId, - name, - key, - created: response.created, - }); - - return { - actorId, - name, - key, - }; - } - - async createActor({ - c, - name, - key, - input, - }: CreateInput): Promise { - // Check if actor with the same name and key already exists - const existingActor = await this.getWithKey({ c, name, key }); - if (existingActor) { - throw new ActorAlreadyExists(name, key); - } - - logger().info({ msg: "creating actor via engine api", name, key }); - - // Create actor via engine API - const result = await createActor(this.#config, { - name, - runner_name_selector: this.#config.runnerName, - key: serializeActorKey(key), - input: input ? cbor.encode(input).toString("base64") : null, - crash_policy: "sleep", - }); - const actorId = result.actor.actor_id; - - logger().info({ msg: "actor created", actorId, name, key }); - - return { - actorId, - name, - key, - }; - } - - async destroyActor(actorId: string): Promise { - logger().info({ msg: "destroying actor via engine api", actorId }); - - await destroyActor(this.#config, actorId); - - logger().info({ msg: "actor destroyed", actorId }); - } - - async #forwardHttpRequest( - actorRequest: Request, - actorId: string, - ): Promise { - // Route through guard port - const url = new URL(actorRequest.url); - const guardUrl = `${this.#config.endpoint}${url.pathname}${url.search}`; - - // Handle body properly based on method and presence - let bodyToSend: ArrayBuffer | null = null; - const guardHeaders = buildGuardHeadersForHttp(actorRequest, actorId); - - if ( - actorRequest.body && - actorRequest.method !== "GET" && - actorRequest.method !== "HEAD" - ) { - if (actorRequest.bodyUsed) { - throw new Error("Request body has already been consumed"); - } - - // TODO: This buffers the entire request in memory every time. We - // need to properly implement streaming bodies. - // Clone and read the body to ensure it can be sent - const clonedRequest = actorRequest.clone(); - bodyToSend = await clonedRequest.arrayBuffer(); - - // If this is a streaming request, we need to convert the headers - // for the basic array buffer - guardHeaders.delete("transfer-encoding"); - guardHeaders.set( - "content-length", - String((bodyToSend as ArrayBuffer).byteLength), - ); - } - - const guardRequest = new Request(guardUrl, { - method: actorRequest.method, - headers: guardHeaders, - body: bodyToSend, - }); - - return mutableResponse(await fetch(guardRequest)); - } -} - -function mutableResponse(fetchRes: Response): Response { - // We cannot return the raw response from `fetch` since the response type is not mutable. - // - // In order for middleware to be able to mutate the response, we need to build a new Response object that is mutable. - return new Response(fetchRes.body, fetchRes); -} - -function buildGuardHeadersForHttp( - actorRequest: Request, - actorId: string, -): Headers { - const headers = new Headers(); - // Copy all headers from the original request - for (const [key, value] of actorRequest.headers.entries()) { - headers.set(key, value); - } - // Add guard-specific headers - headers.set("x-rivet-target", "actor"); - headers.set("x-rivet-actor", actorId); - headers.set("x-rivet-port", "main"); - return headers; -} - -function buildGuardHeadersForWebSocket( - actorId: string, - encoding: Encoding, - params?: unknown, - authData?: unknown, -): Record { - const headers: Record = {}; - headers["x-rivet-target"] = "actor"; - headers["x-rivet-actor"] = actorId; - headers["x-rivet-port"] = "main"; - headers[HEADER_EXPOSE_INTERNAL_ERROR] = "true"; - headers[HEADER_ENCODING] = encoding; - if (params) { - headers[HEADER_CONN_PARAMS] = JSON.stringify(params); - } - if (authData) { - headers[HEADER_AUTH_DATA] = JSON.stringify(authData); - } - return headers; -} diff --git a/packages/rivetkit/src/drivers/engine/mod.ts b/packages/rivetkit/src/drivers/engine/mod.ts index a2f52970d..a282f4de7 100644 --- a/packages/rivetkit/src/drivers/engine/mod.ts +++ b/packages/rivetkit/src/drivers/engine/mod.ts @@ -2,13 +2,12 @@ import type { Client } from "@/client/client"; import type { ManagerDriver } from "@/manager/driver"; import type { RegistryConfig } from "@/registry/config"; import type { DriverConfig, RunConfig } from "@/registry/run-config"; +import { RemoteManagerDriver } from "@/remote-manager-driver/mod"; import { EngineActorDriver } from "./actor-driver"; import { ConfigSchema, type InputConfig } from "./config"; -import { EngineManagerDriver } from "./manager-driver"; export { EngineActorDriver } from "./actor-driver"; export { type Config, ConfigSchema, type InputConfig } from "./config"; -export { EngineManagerDriver } from "./manager-driver"; export function createEngineDriver(inputConfig?: InputConfig): DriverConfig { const config = ConfigSchema.parse(inputConfig); @@ -16,7 +15,7 @@ export function createEngineDriver(inputConfig?: InputConfig): DriverConfig { return { name: "engine", manager: (_registryConfig, runConfig) => { - return new EngineManagerDriver(config, runConfig); + return new RemoteManagerDriver(runConfig); }, actor: ( registryConfig: RegistryConfig, diff --git a/packages/rivetkit/src/drivers/file-system/manager.ts b/packages/rivetkit/src/drivers/file-system/manager.ts index fa1aabd8b..08b696840 100644 --- a/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/packages/rivetkit/src/drivers/file-system/manager.ts @@ -17,7 +17,6 @@ import type { GetWithKeyInput, ManagerDriver, } from "@/driver-helpers/mod"; -import { createInlineClientDriver } from "@/inline-client-driver/mod"; import { ManagerInspector } from "@/inspector/manager"; import { type Actor, ActorFeature, type ActorId } from "@/inspector/mod"; import type { ManagerDisplayInformation } from "@/manager/driver"; @@ -28,6 +27,7 @@ import { PATH_RAW_WEBSOCKET_PREFIX, type RegistryConfig, type RunConfig, + type UniversalWebSocket, } from "@/mod"; import type * as schema from "@/schemas/file-system-driver/mod"; import type { FileSystemGlobalState } from "./global-state"; @@ -120,7 +120,7 @@ export class FileSystemManagerDriver implements ManagerDriver { } // Actors run on the same node as the manager, so we create a dummy actor router that we route requests to - const inlineClient = createClientWithDriver(createInlineClientDriver(this)); + const inlineClient = createClientWithDriver(this); this.#actorDriver = this.#driverConfig.actor( registryConfig, runConfig, @@ -141,7 +141,7 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - ): Promise { + ): Promise { // TODO: // Handle raw WebSocket paths diff --git a/packages/rivetkit/src/inline-client-driver/mod.ts b/packages/rivetkit/src/inline-client-driver/mod.ts deleted file mode 100644 index 178db3ecb..000000000 --- a/packages/rivetkit/src/inline-client-driver/mod.ts +++ /dev/null @@ -1,389 +0,0 @@ -import * as cbor from "cbor-x"; -import type { Context as HonoContext } from "hono"; -import invariant from "invariant"; -import onChange from "on-change"; -import type { WebSocket } from "ws"; -import * as errors from "@/actor/errors"; -import type { Encoding } from "@/actor/protocol/serde"; -import { - PATH_CONNECT_WEBSOCKET, - PATH_RAW_WEBSOCKET_PREFIX, -} from "@/actor/router"; -import { - HEADER_CONN_ID, - HEADER_CONN_PARAMS, - HEADER_CONN_TOKEN, - HEADER_ENCODING, - HEADER_EXPOSE_INTERNAL_ERROR, -} from "@/actor/router-endpoints"; -import { assertUnreachable } from "@/actor/utils"; -import type { ClientDriver } from "@/client/client"; -import { ActorError as ClientActorError } from "@/client/errors"; -import { sendHttpRequest } from "@/client/utils"; -import { importEventSource } from "@/common/eventsource"; -import type { UniversalEventSource } from "@/common/eventsource-interface"; -import { deconstructError } from "@/common/utils"; -import type { ManagerDriver } from "@/manager/driver"; -import type { ActorQuery } from "@/manager/protocol/query"; -import type { RunConfig } from "@/mod"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { - HTTP_ACTION_REQUEST_VERSIONED, - HTTP_ACTION_RESPONSE_VERSIONED, - TO_CLIENT_VERSIONED, - TO_SERVER_VERSIONED, -} from "@/schemas/client-protocol/versioned"; -import { bufferToArrayBuffer, httpUserAgent } from "@/utils"; -import { logger } from "./log"; - -/** - * Client driver that calls the manager driver inline. - * - * This is only applicable to standalone & coordinated topologies. - * - * This driver can access private resources. - * - * This driver serves a double purpose as: - * - Providing the client for the internal requests - * - Provide the driver for the manager HTTP router (see manager/router.ts) - */ -export function createInlineClientDriver( - managerDriver: ManagerDriver, -): ClientDriver { - const driver: ClientDriver = { - action: async = unknown[], Response = unknown>( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - actionName: string, - args: Args, - opts: { signal?: AbortSignal }, - ): Promise => { - try { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "found actor for action", actorId }); - invariant(actorId, "Missing actor ID"); - - // Invoke the action - logger().debug({ msg: "handling action", actionName, encoding }); - const responseData = await sendHttpRequest< - protocol.HttpActionRequest, - protocol.HttpActionResponse - >({ - url: `http://actor/action/${encodeURIComponent(actionName)}`, - method: "POST", - headers: { - [HEADER_ENCODING]: encoding, - ...(params !== undefined - ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } - : {}), - [HEADER_EXPOSE_INTERNAL_ERROR]: "true", - }, - body: { - args: bufferToArrayBuffer(cbor.encode(args)), - } satisfies protocol.HttpActionRequest, - encoding: encoding, - customFetch: managerDriver.sendRequest.bind(managerDriver, actorId), - signal: opts?.signal, - requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED, - responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, - }); - - return cbor.decode(new Uint8Array(responseData.output)); - } catch (err) { - // Standardize to ClientActorError instead of the native backend error - const { code, message, metadata } = deconstructError( - err, - logger(), - {}, - true, - ); - const x = new ClientActorError(code, message, metadata); - throw new ClientActorError(code, message, metadata); - } - }, - - resolveActorId: async ( - c: HonoContext | undefined, - actorQuery: ActorQuery, - _encodingKind: Encoding, - ): Promise => { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "resolved actor", actorId }); - invariant(actorId, "missing actor ID"); - - return actorId; - }, - - connectWebSocket: async ( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params?: unknown, - ): Promise => { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "found actor for action", actorId }); - invariant(actorId, "Missing actor ID"); - - // Invoke the action - logger().debug({ - msg: "opening websocket", - actorId, - encoding: encodingKind, - }); - - // Open WebSocket - const ws = await managerDriver.openWebSocket( - PATH_CONNECT_WEBSOCKET, - actorId, - encodingKind, - params, - ); - - // Node & browser WebSocket types are incompatible - return ws as any; - }, - - connectSse: async ( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encodingKind: Encoding, - params: unknown, - ): Promise => { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "found actor for sse connection", actorId }); - invariant(actorId, "Missing actor ID"); - - logger().debug({ - msg: "opening sse connection", - actorId, - encoding: encodingKind, - }); - - const EventSourceClass = await importEventSource(); - - const eventSource = new EventSourceClass("http://actor/connect/sse", { - fetch: (input, init) => { - return fetch(input, { - ...init, - headers: { - ...init?.headers, - "User-Agent": httpUserAgent(), - [HEADER_ENCODING]: encodingKind, - ...(params !== undefined - ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } - : {}), - [HEADER_EXPOSE_INTERNAL_ERROR]: "true", - }, - }); - }, - }) as UniversalEventSource; - - return eventSource; - }, - - sendHttpMessage: async ( - c: HonoContext | undefined, - actorId: string, - encoding: Encoding, - connectionId: string, - connectionToken: string, - message: protocol.ToServer, - ): Promise => { - logger().debug({ msg: "sending http message", actorId, connectionId }); - - // Send an HTTP request to the connections endpoint - await sendHttpRequest({ - url: "http://actor/connections/message", - method: "POST", - headers: { - [HEADER_ENCODING]: encoding, - [HEADER_CONN_ID]: connectionId, - [HEADER_CONN_TOKEN]: connectionToken, - [HEADER_EXPOSE_INTERNAL_ERROR]: "true", - }, - body: message, - encoding, - skipParseResponse: true, - customFetch: managerDriver.sendRequest.bind(managerDriver, actorId), - requestVersionedDataHandler: TO_SERVER_VERSIONED, - responseVersionedDataHandler: TO_CLIENT_VERSIONED, - }); - }, - - rawHttpRequest: async ( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - init: RequestInit, - ): Promise => { - try { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "found actor for raw http", actorId }); - invariant(actorId, "Missing actor ID"); - - // Build the URL with normalized path - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - const url = new URL(`http://actor/raw/http/${normalizedPath}`); - - // Forward conn params if provided - const proxyRequestHeaders = new Headers(init.headers); - if (params) { - proxyRequestHeaders.set(HEADER_CONN_PARAMS, JSON.stringify(params)); - } - - // Forward the request to the actor - const proxyRequest = new Request(url, { - ...init, - headers: proxyRequestHeaders, - }); - - return await managerDriver.sendRequest(actorId, proxyRequest); - } catch (err) { - // Standardize to ClientActorError instead of the native backend error - const { code, message, metadata } = deconstructError( - err, - logger(), - {}, - true, - ); - throw new ClientActorError(code, message, metadata); - } - }, - - rawWebSocket: async ( - c: HonoContext | undefined, - actorQuery: ActorQuery, - encoding: Encoding, - params: unknown, - path: string, - protocols: string | string[] | undefined, - ): Promise => { - // Get the actor ID - const { actorId } = await queryActor(c, actorQuery, managerDriver); - logger().debug({ msg: "found actor for action", actorId }); - invariant(actorId, "Missing actor ID"); - - // Normalize path to match raw HTTP behavior - const normalizedPath = path.startsWith("/") ? path.slice(1) : path; - logger().debug({ - msg: "opening websocket", - actorId, - encoding, - path: normalizedPath, - }); - - // Open WebSocket - const ws = await managerDriver.openWebSocket( - `${PATH_RAW_WEBSOCKET_PREFIX}${normalizedPath}`, - actorId, - encoding, - params, - ); - - // Node & browser WebSocket types are incompatible - return ws as any; - }, - }; - - return driver; -} - -/** - * Query the manager driver to get or create a actor based on the provided query - */ -export async function queryActor( - c: HonoContext | undefined, - query: ActorQuery, - driver: ManagerDriver, -): Promise<{ actorId: string }> { - logger().debug({ msg: "querying actor", query }); - let actorOutput: { actorId: string }; - if ("getForId" in query) { - const output = await driver.getForId({ - c, - name: query.getForId.name, - actorId: query.getForId.actorId, - }); - if (!output) throw new errors.ActorNotFound(query.getForId.actorId); - actorOutput = output; - } else if ("getForKey" in query) { - const existingActor = await driver.getWithKey({ - c, - name: query.getForKey.name, - key: query.getForKey.key, - }); - if (!existingActor) { - throw new errors.ActorNotFound( - `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, - ); - } - actorOutput = existingActor; - } else if ("getOrCreateForKey" in query) { - const getOrCreateOutput = await driver.getOrCreateWithKey({ - c, - name: query.getOrCreateForKey.name, - key: query.getOrCreateForKey.key, - input: query.getOrCreateForKey.input, - region: query.getOrCreateForKey.region, - }); - actorOutput = { - actorId: getOrCreateOutput.actorId, - }; - } else if ("create" in query) { - const createOutput = await driver.createActor({ - c, - name: query.create.name, - key: query.create.key, - input: query.create.input, - region: query.create.region, - }); - actorOutput = { - actorId: createOutput.actorId, - }; - } else { - throw new errors.InvalidRequest("Invalid query format"); - } - - logger().debug({ msg: "actor query result", actorId: actorOutput.actorId }); - return { actorId: actorOutput.actorId }; -} - -/** - * Removes the on-change library's proxy recursively from a value so we can clone it with `structuredClone`. - */ -function unproxyRecursive(objProxied: T): T { - const obj = onChange.target(objProxied); - - // Short circuit if this object was proxied - // - // If the reference is different, then this value was proxied and no - // nested values are proxied - if (obj !== objProxied) return obj; - - // Handle null/undefined - if (!obj || typeof obj !== "object") { - return obj; - } - - // Handle arrays - if (Array.isArray(obj)) { - return obj.map((x) => unproxyRecursive(x)) as T; - } - - // Handle objects - const result: any = {}; - for (const key in obj) { - result[key] = unproxyRecursive(obj[key]); - } - - return result; -} diff --git a/packages/rivetkit/src/manager-api/routes/actors-create.ts b/packages/rivetkit/src/manager-api/routes/actors-create.ts new file mode 100644 index 000000000..e9da2fc38 --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/actors-create.ts @@ -0,0 +1,16 @@ +import { z } from "zod"; +import { ActorSchema } from "./common"; + +export const ActorsCreateRequestSchema = z.object({ + name: z.string(), + runner_name_selector: z.string(), + crash_policy: z.string(), + key: z.string().nullable().optional(), + input: z.string().nullable().optional(), +}); +export type ActorsCreateRequest = z.infer; + +export const ActorsCreateResponseSchema = z.object({ + actor: ActorSchema, +}); +export type ActorsCreateResponse = z.infer; diff --git a/packages/rivetkit/src/manager-api/routes/actors-delete.ts b/packages/rivetkit/src/manager-api/routes/actors-delete.ts new file mode 100644 index 000000000..483058e65 --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/actors-delete.ts @@ -0,0 +1,4 @@ +import { z } from "zod"; + +export const ActorsDeleteResponseSchema = z.object({}); +export type ActorsDeleteResponse = z.infer; diff --git a/packages/rivetkit/src/manager-api/routes/actors-get-by-id.ts b/packages/rivetkit/src/manager-api/routes/actors-get-by-id.ts new file mode 100644 index 000000000..ebbbfb39d --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/actors-get-by-id.ts @@ -0,0 +1,7 @@ +import { z } from "zod"; +import { RivetIdSchema } from "./common"; + +export const ActorsGetByIdResponseSchema = z.object({ + actor_id: RivetIdSchema.nullable().optional(), +}); +export type ActorsGetByIdResponse = z.infer; diff --git a/packages/rivetkit/src/manager-api/routes/actors-get-or-create-by-id.ts b/packages/rivetkit/src/manager-api/routes/actors-get-or-create-by-id.ts new file mode 100644 index 000000000..ed8669784 --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/actors-get-or-create-by-id.ts @@ -0,0 +1,29 @@ +import { z } from "zod"; +import { ActorSchema, RivetIdSchema } from "./common"; + +export const ActorsGetOrCreateResponseSchema = z.object({ + actor: ActorSchema, + created: z.boolean(), +}); +export type ActorsGetOrCreateResponse = z.infer< + typeof ActorsGetOrCreateResponseSchema +>; + +export const ActorsGetOrCreateByIdResponseSchema = z.object({ + actor_id: RivetIdSchema, + created: z.boolean(), +}); +export type ActorsGetOrCreateByIdResponse = z.infer< + typeof ActorsGetOrCreateByIdResponseSchema +>; + +export const ActorsGetOrCreateByIdRequestSchema = z.object({ + name: z.string(), + key: z.string(), + runner_name_selector: z.string(), + crash_policy: z.string(), + input: z.string().nullable().optional(), +}); +export type ActorsGetOrCreateByIdRequest = z.infer< + typeof ActorsGetOrCreateByIdRequestSchema +>; diff --git a/packages/rivetkit/src/manager-api/routes/actors-get.ts b/packages/rivetkit/src/manager-api/routes/actors-get.ts new file mode 100644 index 000000000..6915389ce --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/actors-get.ts @@ -0,0 +1,7 @@ +import { z } from "zod"; +import { ActorSchema } from "./common"; + +export const ActorsGetResponseSchema = z.object({ + actor: ActorSchema, +}); +export type ActorsGetResponse = z.infer; diff --git a/packages/rivetkit/src/manager-api/routes/common.ts b/packages/rivetkit/src/manager-api/routes/common.ts new file mode 100644 index 000000000..c19109cb3 --- /dev/null +++ b/packages/rivetkit/src/manager-api/routes/common.ts @@ -0,0 +1,18 @@ +import { z } from "zod"; + +export const RivetIdSchema = z.string(); +export type RivetId = z.infer; + +export const ActorSchema = z.object({ + actor_id: RivetIdSchema, + name: z.string(), + key: z.string(), + namespace_id: RivetIdSchema, + runner_name_selector: z.string(), + create_ts: z.number(), + connectable_ts: z.number().nullable().optional(), + destroy_ts: z.number().nullable().optional(), + sleep_ts: z.number().nullable().optional(), + start_ts: z.number().nullable().optional(), +}); +export type Actor = z.infer; diff --git a/packages/rivetkit/src/manager/auth.ts b/packages/rivetkit/src/manager/auth.ts deleted file mode 100644 index e88e4d42d..000000000 --- a/packages/rivetkit/src/manager/auth.ts +++ /dev/null @@ -1,124 +0,0 @@ -import type { Context as HonoContext } from "hono"; -import type { AuthIntent } from "@/actor/config"; -import type { AnyActorDefinition } from "@/actor/definition"; -import * as errors from "@/actor/errors"; -import type { RegistryConfig } from "@/registry/config"; -import { stringifyError } from "@/utils"; -import type { ManagerDriver } from "./driver"; -import { logger } from "./log"; -import type { ActorQuery } from "./protocol/query"; - -/** - * Get authentication intents from a actor query - */ -export function getIntentsFromQuery(query: ActorQuery): Set { - const intents = new Set(); - - if ("getForId" in query) { - intents.add("get"); - } else if ("getForKey" in query) { - intents.add("get"); - } else if ("getOrCreateForKey" in query) { - intents.add("get"); - intents.add("create"); - } else if ("create" in query) { - intents.add("create"); - } - - return intents; -} - -/** - * Get actor name from a actor query - */ -export async function getActorNameFromQuery( - c: HonoContext, - driver: ManagerDriver, - query: ActorQuery, -): Promise { - if ("getForId" in query) { - // TODO: This will have a duplicate call to getForId between this and queryActor - const output = await driver.getForId({ - c, - name: query.getForId.name, - actorId: query.getForId.actorId, - }); - if (!output) throw new errors.ActorNotFound(query.getForId.actorId); - return output.name; - } else if ("getForKey" in query) { - return query.getForKey.name; - } else if ("getOrCreateForKey" in query) { - return query.getOrCreateForKey.name; - } else if ("create" in query) { - return query.create.name; - } else { - throw new errors.InvalidRequest("Invalid query format"); - } -} - -/** - * Authenticate a request using the actor's onAuth function - */ -export async function authenticateRequest( - c: HonoContext, - actorDefinition: AnyActorDefinition, - intents: Set, - params: unknown, -): Promise { - if (!("onAuth" in actorDefinition.config)) { - throw new errors.Forbidden( - "Actor requires authentication but no onAuth handler is defined (https://rivet.gg/docs/actors/authentication/). Provide an empty handler to disable auth: `onAuth: () => {}`", - ); - } - - try { - const dataOrPromise = actorDefinition.config.onAuth( - { - request: c.req.raw, - intents, - }, - params, - ); - if (dataOrPromise instanceof Promise) { - return await dataOrPromise; - } else { - return dataOrPromise; - } - } catch (error) { - logger().info({ - msg: "authentication error", - error: stringifyError(error), - }); - throw error; - } -} - -/** - * Simplified authentication for endpoints that combines all auth steps - */ -export async function authenticateEndpoint( - c: HonoContext, - driver: ManagerDriver, - registryConfig: RegistryConfig, - query: ActorQuery, - additionalIntents: AuthIntent[], - params: unknown, -): Promise { - // Get base intents from query - const intents = getIntentsFromQuery(query); - - // Add endpoint-specific intents - for (const intent of additionalIntents) { - intents.add(intent); - } - - // Get actor definition - const actorName = await getActorNameFromQuery(c, driver, query); - const actorDefinition = registryConfig.use[actorName]; - if (!actorDefinition) { - throw new errors.ActorNotFound(actorName); - } - - // Authenticate - return await authenticateRequest(c, actorDefinition, intents, params); -} diff --git a/packages/rivetkit/src/manager/driver.ts b/packages/rivetkit/src/manager/driver.ts index 5395bfcb6..125ade201 100644 --- a/packages/rivetkit/src/manager/driver.ts +++ b/packages/rivetkit/src/manager/driver.ts @@ -1,5 +1,5 @@ import type { Env, Hono, Context as HonoContext } from "hono"; -import type { ActorKey, Encoding } from "@/actor/mod"; +import type { ActorKey, Encoding, UniversalWebSocket } from "@/actor/mod"; import type { ManagerInspector } from "@/inspector/manager"; import type { RunConfig } from "@/mod"; import type { RegistryConfig } from "@/registry/config"; @@ -21,7 +21,7 @@ export interface ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - ): Promise; + ): Promise; proxyRequest( c: HonoContext, actorRequest: Request, diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index c30ac77f4..4abc76876 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -1,127 +1,39 @@ import { createRoute, OpenAPIHono } from "@hono/zod-openapi"; import * as cbor from "cbor-x"; -import { - Hono, - type Context as HonoContext, - type MiddlewareHandler, -} from "hono"; +import type { Hono } from "hono"; import { cors } from "hono/cors"; -import { streamSSE } from "hono/streaming"; -import type { WSContext } from "hono/ws"; -import invariant from "invariant"; -import type { CloseEvent, MessageEvent, WebSocket } from "ws"; import { z } from "zod"; -import * as errors from "@/actor/errors"; -import type { Transport } from "@/actor/protocol/old"; -import type { Encoding } from "@/actor/protocol/serde"; -import { - PATH_CONNECT_WEBSOCKET, - PATH_RAW_WEBSOCKET_PREFIX, -} from "@/actor/router"; import { - ALLOWED_PUBLIC_HEADERS, - getRequestEncoding, - getRequestQuery, - HEADER_ACTOR_ID, - HEADER_ACTOR_QUERY, - HEADER_AUTH_DATA, - HEADER_CONN_ID, - HEADER_CONN_PARAMS, - HEADER_CONN_TOKEN, - HEADER_ENCODING, -} from "@/actor/router-endpoints"; -import type { ClientDriver } from "@/client/client"; + ActorError, + ActorNotFound, + FeatureNotImplemented, + MissingActorHeader, + RouteNotFound, + WebSocketsNotEnabled, +} from "@/actor/errors"; import { handleRouteError, handleRouteNotFound, loggerMiddleware, } from "@/common/router"; import { - type DeconstructedError, - deconstructError, - noopNext, - stringifyError, -} from "@/common/utils"; -import { createManagerInspectorRouter } from "@/inspector/manager"; -import { secureInspector } from "@/inspector/utils"; -import type { UpgradeWebSocketArgs } from "@/mod"; + type ActorsCreateRequest, + ActorsCreateRequestSchema, + ActorsCreateResponseSchema, +} from "@/manager-api/routes/actors-create"; +import { ActorsDeleteResponseSchema } from "@/manager-api/routes/actors-delete"; +import { ActorsGetResponseSchema } from "@/manager-api/routes/actors-get"; +import { ActorsGetByIdResponseSchema } from "@/manager-api/routes/actors-get-by-id"; +import { + type ActorsGetOrCreateByIdRequest, + ActorsGetOrCreateByIdRequestSchema, + ActorsGetOrCreateByIdResponseSchema, +} from "@/manager-api/routes/actors-get-or-create-by-id"; +import { RivetIdSchema } from "@/manager-api/routes/common"; import type { RegistryConfig } from "@/registry/config"; import type { RunConfig } from "@/registry/run-config"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { - HTTP_RESOLVE_RESPONSE_VERSIONED, - TO_CLIENT_VERSIONED, -} from "@/schemas/client-protocol/versioned"; -import { serializeWithEncoding } from "@/serde"; -import { bufferToArrayBuffer } from "@/utils"; -import { authenticateEndpoint } from "./auth"; import type { ManagerDriver } from "./driver"; import { logger } from "./log"; -import type { ActorQuery } from "./protocol/query"; -import { - ActorQuerySchema, - ConnectRequestSchema, - ConnectWebSocketRequestSchema, - ConnMessageRequestSchema, - ResolveRequestSchema, -} from "./protocol/query"; - -/** - * Parse WebSocket protocol headers for query and connection parameters - */ -function parseWebSocketProtocols(protocols: string | undefined): { - queryRaw: string | undefined; - encodingRaw: string | undefined; - connParamsRaw: string | undefined; -} { - let queryRaw: string | undefined; - let encodingRaw: string | undefined; - let connParamsRaw: string | undefined; - - if (protocols) { - const protocolList = protocols.split(",").map((p) => p.trim()); - for (const protocol of protocolList) { - if (protocol.startsWith("query.")) { - queryRaw = decodeURIComponent(protocol.substring("query.".length)); - } else if (protocol.startsWith("encoding.")) { - encodingRaw = protocol.substring("encoding.".length); - } else if (protocol.startsWith("conn_params.")) { - connParamsRaw = decodeURIComponent( - protocol.substring("conn_params.".length), - ); - } - } - } - - return { queryRaw, encodingRaw, connParamsRaw }; -} - -const OPENAPI_ENCODING = z.string().openapi({ - description: "The encoding format to use for the response (json, cbor)", - example: "json", -}); - -const OPENAPI_ACTOR_QUERY = z.string().openapi({ - description: "Actor query information", -}); - -const OPENAPI_CONN_PARAMS = z.string().openapi({ - description: "Connection parameters", -}); - -const OPENAPI_ACTOR_ID = z.string().openapi({ - description: "Actor ID (used in some endpoints)", - example: "actor-123456", -}); - -const OPENAPI_CONN_ID = z.string().openapi({ - description: "Connection ID", - example: "conn-123456", -}); - -const OPENAPI_CONN_TOKEN = z.string().openapi({ - description: "Connection token", -}); function buildOpenApiResponses(schema: T, validateBody: boolean) { return { @@ -144,17 +56,9 @@ function buildOpenApiResponses(schema: T, validateBody: boolean) { }; } -/** - * Only use `validateBody` to `true` if you need to export OpenAPI JSON. - * - * If left enabled for production, this will cause errors. We disable JSON validation since: - * - It prevents us from proxying requests, since validating the body requires consuming the body so we can't forward the body - * - We validate all types at the actor router layer since most requests are proxied - */ export function createManagerRouter( registryConfig: RegistryConfig, runConfig: RunConfig, - inlineClientDriver: ClientDriver, managerDriver: ManagerDriver, validateBody: boolean, ): { router: Hono; openapi: OpenAPIHono } { @@ -164,1653 +68,311 @@ export function createManagerRouter( router.use("*", loggerMiddleware(logger())); - if (runConfig.cors || runConfig.inspector?.cors) { - router.use("*", async (c, next) => { - // Don't apply to WebSocket routes - // HACK: This could be insecure if we had a varargs path. We have to check the path suffix for WS since we don't know the path that this router was mounted. - // HACK: Checking "/websocket/" is not safe, but there is no other way to handle this if we don't know the base path this is - // mounted on - const path = c.req.path; - if ( - path.endsWith("/actors/connect/websocket") || - path.includes("/actors/raw/websocket/") || - // inspectors implement their own CORS handling - path.endsWith("/inspect") || - path.endsWith("/actors/inspect") - ) { - return next(); + if (runConfig.cors) { + router.use("*", cors(runConfig.cors)); + } + + // Actor proxy middleware - intercept requests with x-rivet-target=actor + router.use("*", async (c, next) => { + const target = c.req.header("x-rivet-target"); + const actorId = c.req.header("x-rivet-actor"); + + if (target === "actor") { + if (!actorId) { + throw new MissingActorHeader(); } - return cors({ - ...(runConfig.cors ?? {}), - ...(runConfig.inspector?.cors ?? {}), - origin: (origin, c) => { - const inspectorOrigin = runConfig.inspector?.cors?.origin; + logger().debug({ + msg: "proxying request to actor", + actorId, + path: c.req.path, + method: c.req.method, + }); - if (inspectorOrigin !== undefined) { - if (typeof inspectorOrigin === "function") { - const allowed = inspectorOrigin(origin, c); - if (allowed) return allowed; - // Proceed to next CORS config if none provided - } else if (Array.isArray(inspectorOrigin)) { - return inspectorOrigin.includes(origin) ? origin : undefined; - } else { - return inspectorOrigin; - } - } + // Handle WebSocket upgrade + if (c.req.header("upgrade") === "websocket") { + const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); + if (!upgradeWebSocket) { + throw new WebSocketsNotEnabled(); + } - if (runConfig.cors?.origin !== undefined) { - if (typeof runConfig.cors.origin === "function") { - const allowed = runConfig.cors.origin(origin, c); - if (allowed) return allowed; - } else { - return runConfig.cors.origin as string; - } - } + // For WebSocket, use the driver's proxyWebSocket method + // Extract any additional headers that might be needed + const encoding = c.req.header("x-rivet-encoding") || "json"; + const connParams = c.req.header("x-rivet-conn-params"); + const authData = c.req.header("x-rivet-auth-data"); - return null; - }, - allowMethods: (origin, c) => { - const inspectorMethods = runConfig.inspector?.cors?.allowMethods; - if (inspectorMethods) { - if (typeof inspectorMethods === "function") { - return inspectorMethods(origin, c); - } - return inspectorMethods; - } + return await managerDriver.proxyWebSocket( + c, + c.req.path, + actorId, + encoding as any, // Will be validated by driver + connParams ? JSON.parse(connParams) : undefined, + authData ? JSON.parse(authData) : undefined, + ); + } - if (runConfig.cors?.allowMethods) { - if (typeof runConfig.cors.allowMethods === "function") { - return runConfig.cors.allowMethods(origin, c); - } - return runConfig.cors.allowMethods; - } + // Handle regular HTTP requests + // Preserve all headers except the routing headers + const proxyHeaders = new Headers(c.req.raw.headers); + proxyHeaders.delete("x-rivet-target"); + proxyHeaders.delete("x-rivet-actor"); - return []; - }, - allowHeaders: [ - ...(runConfig.cors?.allowHeaders ?? []), - ...(runConfig.inspector?.cors?.allowHeaders ?? []), - ...ALLOWED_PUBLIC_HEADERS, - "Content-Type", - "User-Agent", - ], - credentials: - runConfig.cors?.credentials ?? - runConfig.inspector?.cors?.credentials ?? - true, - })(c, next); - }); - } + // Build the proxy request with the actor URL format + const url = new URL(c.req.url); + const proxyUrl = new URL(`http://actor${url.pathname}${url.search}`); + + const proxyRequest = new Request(proxyUrl, { + method: c.req.method, + headers: proxyHeaders, + body: c.req.raw.body, + }); + + return await managerDriver.proxyRequest(c, proxyRequest, actorId); + } + + return next(); + }); // GET / - router.get("/", (c: HonoContext) => { + router.get("/", (c) => { return c.text( - "This is an RivetKit registry.\n\nLearn more at https://rivetkit.org", + "This is a RivetKit server.\n\nLearn more at https://rivetkit.org", ); }); - // POST /actors/resolve + // GET /actors/by-id { - const ResolveQuerySchema = z - .object({ - query: z.any().openapi({ - example: { getForId: { actorId: "actor-123" } }, + const route = createRoute({ + method: "get", + path: "/actors/by-id", + request: { + query: z.object({ + name: z.string(), + key: z.string(), }), - }) - .openapi("ResolveQuery"); + }, + responses: buildOpenApiResponses( + ActorsGetByIdResponseSchema, + validateBody, + ), + }); - const ResolveResponseSchema = z - .object({ - i: z.string().openapi({ - example: "actor-123", - }), - }) - .openapi("ResolveResponse"); + router.openapi(route, async (c) => { + const { name, key } = c.req.valid("query"); - const resolveRoute = createRoute({ - method: "post", - path: "/actors/resolve", + // Get actor by key from the driver + const actorOutput = await managerDriver.getWithKey({ + c, + name, + key: [key], // Convert string to ActorKey array + }); + + return c.json({ + actor_id: actorOutput?.actorId || null, + }); + }); + } + + // PUT /actors/by-id + { + const route = createRoute({ + method: "put", + path: "/actors/by-id", request: { body: { content: validateBody ? { "application/json": { - schema: ResolveQuerySchema, + schema: ActorsGetOrCreateByIdRequestSchema, }, } : {}, }, - headers: z.object({ - [HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY, - }), }, - responses: buildOpenApiResponses(ResolveResponseSchema, validateBody), + responses: buildOpenApiResponses( + ActorsGetOrCreateByIdResponseSchema, + validateBody, + ), }); - router.openapi(resolveRoute, (c) => - handleResolveRequest(c, registryConfig, managerDriver), - ); - } + router.openapi(route, async (c) => { + const body = validateBody + ? await c.req.json() + : await c.req.json(); - // GET /actors/connect/websocket - { - // HACK: WebSockets don't work with mounts, so we need to dynamically match the trailing path - router.use("*", (c, next) => { - if (c.req.path.endsWith("/actors/connect/websocket")) { - return handleWebSocketConnectRequest( - c, - registryConfig, - runConfig, - managerDriver, - ); + // Parse and validate the request body if validation is enabled + if (validateBody) { + ActorsGetOrCreateByIdRequestSchema.parse(body); } - return next(); - }); + // Check if actor already exists + const existingActor = await managerDriver.getWithKey({ + c, + name: body.name, + key: [body.key], // Convert string to ActorKey array + }); - // This route is a noop, just used to generate docs - const wsRoute = createRoute({ - method: "get", - path: "/actors/connect/websocket", - responses: { - 101: { - description: "WebSocket upgrade", - }, - }, - }); + if (existingActor) { + return c.json({ + actor_id: existingActor.actorId, + created: false, + }); + } + + // Create new actor + const newActor = await managerDriver.getOrCreateWithKey({ + c, + name: body.name, + key: [body.key], // Convert string to ActorKey array + input: body.input + ? cbor.decode(Buffer.from(body.input, "base64")) + : undefined, + region: undefined, // Not provided in the request schema + }); - router.openapi(wsRoute, () => { - throw new Error("Should be unreachable"); + return c.json({ + actor_id: newActor.actorId, + created: true, + }); }); } - // GET /actors/connect/sse + // GET /actors/{actor_id} { - const sseRoute = createRoute({ + const route = createRoute({ method: "get", - path: "/actors/connect/sse", + path: "/actors/{actor_id}", request: { - headers: z.object({ - [HEADER_ENCODING]: OPENAPI_ENCODING, - [HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY, - [HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(), + params: z.object({ + actor_id: RivetIdSchema, }), }, - responses: { - 200: { - description: "SSE stream", - content: { - "text/event-stream": { - schema: z.unknown(), - }, - }, - }, - }, + responses: buildOpenApiResponses(ActorsGetResponseSchema, validateBody), }); - router.openapi(sseRoute, (c) => - handleSseConnectRequest(c, registryConfig, runConfig, managerDriver), - ); - } + router.openapi(route, async (c) => { + const { actor_id } = c.req.valid("param"); - // POST /actors/action/:action - { - const ActionParamsSchema = z - .object({ - action: z.string().openapi({ - param: { - name: "action", - in: "path", - }, - example: "myAction", - }), - }) - .openapi("ActionParams"); + // Get actor by ID from the driver + const actorOutput = await managerDriver.getForId({ + c, + name: "", // TODO: The API doesn't provide the name, this may need to be resolved + actorId: actor_id, + }); - const ActionRequestSchema = z - .object({ - query: z.any().openapi({ - example: { getForId: { actorId: "actor-123" } }, - }), - body: z - .any() - .optional() - .openapi({ - example: { param1: "value1", param2: 123 }, - }), - }) - .openapi("ActionRequest"); + if (!actorOutput) { + throw new ActorNotFound(actor_id); + } - const ActionResponseSchema = z.any().openapi("ActionResponse"); + // Transform ActorOutput to match ActorSchema + // Note: Some fields are not available from the driver and need defaults + const actor = { + actor_id: actorOutput.actorId, + name: actorOutput.name, + key: actorOutput.key, + namespace_id: "", // Not available from driver + runner_name_selector: "", // Not available from driver + create_ts: Date.now(), // Not available from driver + connectable_ts: null, + destroy_ts: null, + sleep_ts: null, + start_ts: null, + }; - const actionRoute = createRoute({ - method: "post", - path: "/actors/actions/{action}", - request: { - params: ActionParamsSchema, - body: { - content: validateBody - ? { - "application/json": { - schema: ActionRequestSchema, - }, - } - : {}, - }, - headers: z.object({ - [HEADER_ENCODING]: OPENAPI_ENCODING, - [HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(), - }), - }, - responses: buildOpenApiResponses(ActionResponseSchema, validateBody), + return c.json({ actor }); }); - - router.openapi(actionRoute, (c) => - handleActionRequest(c, registryConfig, runConfig, managerDriver), - ); } - // POST /actors/message + // POST /actors { - const ConnectionMessageRequestSchema = z - .object({ - message: z.any().openapi({ - example: { type: "message", content: "Hello, actor!" }, - }), - }) - .openapi("ConnectionMessageRequest"); - - const ConnectionMessageResponseSchema = z - .any() - .openapi("ConnectionMessageResponse"); - - const messageRoute = createRoute({ + const route = createRoute({ method: "post", - path: "/actors/message", + path: "/actors", request: { body: { content: validateBody ? { "application/json": { - schema: ConnectionMessageRequestSchema, + schema: ActorsCreateRequestSchema, }, } : {}, }, - headers: z.object({ - [HEADER_ACTOR_ID]: OPENAPI_ACTOR_ID, - [HEADER_CONN_ID]: OPENAPI_CONN_ID, - [HEADER_ENCODING]: OPENAPI_ENCODING, - [HEADER_CONN_TOKEN]: OPENAPI_CONN_TOKEN, - }), }, responses: buildOpenApiResponses( - ConnectionMessageResponseSchema, + ActorsCreateResponseSchema, validateBody, ), }); - router.openapi(messageRoute, (c) => - handleMessageRequest(c, registryConfig, runConfig, managerDriver), - ); - } - - // Raw HTTP endpoints - /actors/raw/http/* - { - const RawHttpRequestBodySchema = z.any().optional().openapi({ - description: "Raw request body (can be any content type)", - }); - - const RawHttpResponseSchema = z.any().openapi({ - description: "Raw response from actor's onFetch handler", - }); - - // Define common route config - const rawHttpRouteConfig = { - path: "/actors/raw/http/*", - request: { - headers: z.object({ - [HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY.optional(), - [HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(), - }), - body: { - content: { - "*/*": { - schema: RawHttpRequestBodySchema, - }, - }, - }, - }, - responses: { - 200: { - description: "Success - response from actor's onFetch handler", - content: { - "*/*": { - schema: RawHttpResponseSchema, - }, - }, - }, - 404: { - description: "Actor does not have an onFetch handler", - }, - 500: { - description: "Internal server error or invalid response from actor", - }, - }, - }; - - // Create routes for each HTTP method - const httpMethods = [ - "get", - "post", - "put", - "delete", - "patch", - "head", - "options", - ] as const; - for (const method of httpMethods) { - const route = createRoute({ - method, - ...rawHttpRouteConfig, - }); + router.openapi(route, async (c) => { + const body = validateBody + ? await c.req.json() + : await c.req.json(); - router.openapi(route, async (c) => { - return handleRawHttpRequest( - c, - registryConfig, - runConfig, - managerDriver, - ); - }); - } - } - - // Raw WebSocket endpoint - /actors/raw/websocket/* - { - // HACK: WebSockets don't work with mounts, so we need to dynamically match the trailing path - router.use("*", async (c, next) => { - if (c.req.path.includes("/raw/websocket/")) { - return handleRawWebSocketRequest( - c, - registryConfig, - runConfig, - managerDriver, - ); + // Parse and validate the request body if validation is enabled + if (validateBody) { + ActorsCreateRequestSchema.parse(body); } - return next(); - }); - - // This route is a noop, just used to generate docs - const rawWebSocketRoute = createRoute({ - method: "get", - path: "/actors/raw/websocket/*", - request: {}, - responses: { - 101: { - description: "WebSocket upgrade successful", - }, - 400: { - description: "WebSockets not enabled or invalid request", - }, - 404: { - description: "Actor does not have an onWebSocket handler", - }, - }, - }); - - router.openapi(rawWebSocketRoute, () => { - throw new Error("Should be unreachable"); - }); - } - - if (runConfig.inspector?.enabled) { - router.route( - "/actors/inspect", - new Hono() - .use( - cors(runConfig.inspector.cors), - secureInspector(runConfig), - universalActorProxy({ - registryConfig, - runConfig, - driver: managerDriver, - }), - ) - .all("/", (c) => - // this should be handled by the actor proxy, but just in case - c.text("Unreachable.", 404), - ), - ); - router.route( - "/inspect", - new Hono() - .use( - cors(runConfig.inspector.cors), - secureInspector(runConfig), - async (c, next) => { - const inspector = managerDriver.inspector; - invariant(inspector, "inspector not supported on this platform"); - - c.set("inspector", inspector); - await next(); - }, - ) - .route("/", createManagerInspectorRouter()), - ); - } - - if (registryConfig.test.enabled) { - // Add HTTP endpoint to test the inline client - // - // We have to do this in a router since this needs to run in the same server as the RivetKit registry. Some test contexts to not run in the same server. - router.post(".test/inline-driver/call", async (c) => { - // TODO: use openapi instead - const buffer = await c.req.arrayBuffer(); - const { encoding, transport, method, args }: TestInlineDriverCallRequest = - cbor.decode(new Uint8Array(buffer)); - - logger().debug({ - msg: "received inline request", - encoding, - transport, - method, - args, + // Create actor using the driver + const actorOutput = await managerDriver.createActor({ + c, + name: body.name, + key: [body.key || crypto.randomUUID()], // Generate key if not provided, convert to ActorKey array + input: body.input ? JSON.parse(body.input) : undefined, + region: undefined, // Not provided in the request schema }); - // Forward inline driver request - let response: TestInlineDriverCallResponse; - try { - const output = await ((inlineClientDriver as any)[method] as any)( - ...args, - ); - response = { ok: output }; - } catch (rawErr) { - const err = deconstructError(rawErr, logger(), {}, true); - response = { err }; - } - - return c.body(cbor.encode(response)); - }); - - router.get(".test/inline-driver/connect-websocket", async (c) => { - const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); - invariant(upgradeWebSocket, "websockets not supported on this platform"); - - return upgradeWebSocket(async (c: any) => { - const { - actorQuery: actorQueryRaw, - params: paramsRaw, - encodingKind, - } = c.req.query() as { - actorQuery: string; - params?: string; - encodingKind: Encoding; - }; - const actorQuery = JSON.parse(actorQueryRaw); - const params = - paramsRaw !== undefined ? JSON.parse(paramsRaw) : undefined; - - logger().debug({ - msg: "received test inline driver websocket", - actorQuery, - params, - encodingKind, - }); - - // Connect to the actor using the inline client driver - this returns a Promise - const clientWsPromise = inlineClientDriver.connectWebSocket( - undefined, - actorQuery, - encodingKind, - params, - undefined, - ); - - return await createTestWebSocketProxy(clientWsPromise, "standard"); - })(c, noopNext()); - }); - - router.get(".test/inline-driver/raw-websocket", async (c) => { - const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); - invariant(upgradeWebSocket, "websockets not supported on this platform"); - - return upgradeWebSocket(async (c: any) => { - const { - actorQuery: actorQueryRaw, - params: paramsRaw, - encodingKind, - path, - protocols: protocolsRaw, - } = c.req.query() as { - actorQuery: string; - params?: string; - encodingKind: Encoding; - path: string; - protocols?: string; - }; - const actorQuery = JSON.parse(actorQueryRaw); - const params = - paramsRaw !== undefined ? JSON.parse(paramsRaw) : undefined; - const protocols = - protocolsRaw !== undefined ? JSON.parse(protocolsRaw) : undefined; - - logger().debug({ - msg: "received test inline driver raw websocket", - actorQuery, - params, - encodingKind, - path, - protocols, - }); - - // Connect to the actor using the inline client driver - this returns a Promise - logger().debug({ msg: "calling inlineClientDriver.rawWebSocket" }); - const clientWsPromise = inlineClientDriver.rawWebSocket( - undefined, - actorQuery, - encodingKind, - params, - path, - protocols, - undefined, - ); - - logger().debug({ msg: "calling createTestWebSocketProxy" }); - return await createTestWebSocketProxy(clientWsPromise, "raw"); - })(c, noopNext()); - }); - - // Raw HTTP endpoint for test inline driver - router.all(".test/inline-driver/raw-http/*", async (c) => { - // Extract parameters from headers - const actorQueryHeader = c.req.header(HEADER_ACTOR_QUERY); - const paramsHeader = c.req.header(HEADER_CONN_PARAMS); - const encodingHeader = c.req.header(HEADER_ENCODING); - - if (!actorQueryHeader || !encodingHeader) { - return c.text("Missing required headers", 400); - } - - const actorQuery = JSON.parse(actorQueryHeader); - const params = paramsHeader ? JSON.parse(paramsHeader) : undefined; - const encoding = encodingHeader as Encoding; - - // Extract the path after /raw-http/ - const fullPath = c.req.path; - const pathOnly = - fullPath.split("/.test/inline-driver/raw-http/")[1] || ""; - - // Include query string - const url = new URL(c.req.url); - const pathWithQuery = pathOnly + url.search; - - logger().debug({ - msg: "received test inline driver raw http", - actorQuery, - params, - encoding, - path: pathWithQuery, - method: c.req.method, - }); - - try { - // Forward the request using the inline client driver - const response = await inlineClientDriver.rawHttpRequest( - undefined, - actorQuery, - encoding, - params, - pathWithQuery, - { - method: c.req.method, - headers: c.req.raw.headers, - body: c.req.raw.body, - }, - undefined, - ); - - // Return the response directly - return response; - } catch (error) { - logger().error({ - msg: "error in test inline raw http", - error: stringifyError(error), - }); - - // Return error response - const err = deconstructError(error, logger(), {}, true); - return c.json( - { - error: { - code: err.code, - message: err.message, - metadata: err.metadata, - }, - }, - err.statusCode, - ); - } - }); - } - - managerDriver.modifyManagerRouter?.( - registryConfig, - router as unknown as Hono, - ); - - // Mount on both / and /registry - // - // We do this because the default requests are to `/registry/*`. - // - // If using `app.fetch` directly in a non-hono router, paths - // might not be truncated so they'll come to this router as - // `/registry/*`. If mounted correctly in Hono, requests will - // come in at the root as `/*`. - const mountedRouter = new Hono(); - mountedRouter.route("/", router); - mountedRouter.route("/registry", router); - - // IMPORTANT: These must be on `mountedRouter` instead of `router` or else they will not be called. - mountedRouter.notFound(handleRouteNotFound); - mountedRouter.onError(handleRouteError.bind(undefined, {})); - - return { router: mountedRouter, openapi: router }; -} - -export interface TestInlineDriverCallRequest { - encoding: Encoding; - transport: Transport; - method: string; - args: unknown[]; -} - -export type TestInlineDriverCallResponse = - | { - ok: T; - } - | { - err: DeconstructedError; - }; - -/** - * Query the manager driver to get or create a actor based on the provided query - */ -export async function queryActor( - c: HonoContext, - query: ActorQuery, - driver: ManagerDriver, -): Promise<{ actorId: string }> { - logger().debug({ msg: "querying actor", query }); - let actorOutput: { actorId: string }; - if ("getForId" in query) { - const output = await driver.getForId({ - c, - name: query.getForId.name, - actorId: query.getForId.actorId, - }); - if (!output) throw new errors.ActorNotFound(query.getForId.actorId); - actorOutput = output; - } else if ("getForKey" in query) { - const existingActor = await driver.getWithKey({ - c, - name: query.getForKey.name, - key: query.getForKey.key, - }); - if (!existingActor) { - throw new errors.ActorNotFound( - `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, - ); - } - actorOutput = existingActor; - } else if ("getOrCreateForKey" in query) { - const getOrCreateOutput = await driver.getOrCreateWithKey({ - c, - name: query.getOrCreateForKey.name, - key: query.getOrCreateForKey.key, - input: query.getOrCreateForKey.input, - region: query.getOrCreateForKey.region, - }); - actorOutput = { - actorId: getOrCreateOutput.actorId, - }; - } else if ("create" in query) { - const createOutput = await driver.createActor({ - c, - name: query.create.name, - key: query.create.key, - input: query.create.input, - region: query.create.region, - }); - actorOutput = { - actorId: createOutput.actorId, - }; - } else { - throw new errors.InvalidRequest("Invalid query format"); - } - - logger().debug({ msg: "actor query result", actorId: actorOutput.actorId }); - return { actorId: actorOutput.actorId }; -} - -/** - * Creates a WebSocket proxy for test endpoints that forwards messages between server and client WebSockets - */ -async function createTestWebSocketProxy( - clientWsPromise: Promise, - connectionType: string, -): Promise { - // Store a reference to the resolved WebSocket - let clientWs: WebSocket | null = null; - try { - // Resolve the client WebSocket promise - logger().debug({ msg: "awaiting client websocket promise" }); - const ws = await clientWsPromise; - clientWs = ws; - logger().debug({ - msg: "client websocket promise resolved", - constructor: ws?.constructor.name, - }); - - // Wait for ws to open - await new Promise((resolve, reject) => { - const onOpen = () => { - logger().debug({ msg: "test websocket connection opened" }); - resolve(); - }; - const onError = (error: any) => { - logger().error({ msg: "test websocket connection failed", error }); - reject( - new Error(`Failed to open WebSocket: ${error.message || error}`), - ); + // Transform ActorOutput to match ActorSchema + const actor = { + actor_id: actorOutput.actorId, + name: actorOutput.name, + key: actorOutput.key, + namespace_id: "", // Not available from driver + runner_name_selector: body.runner_name_selector, + create_ts: Date.now(), + connectable_ts: null, + destroy_ts: null, + sleep_ts: null, + start_ts: null, }; - ws.addEventListener("open", onOpen); - ws.addEventListener("error", onError); - }); - } catch (error) { - logger().error({ - msg: `failed to establish client ${connectionType} websocket connection`, - error, - }); - return { - onOpen: (_evt, serverWs) => { - serverWs.close(1011, "Failed to establish connection"); - }, - onMessage: () => {}, - onError: () => {}, - onClose: () => {}, - }; - } - - // Create WebSocket proxy handlers to relay messages between client and server - return { - onOpen: (_evt: any, serverWs: WSContext) => { - logger().debug({ - msg: `test ${connectionType} websocket connection opened`, - }); - - // Check WebSocket type - logger().debug({ - msg: "clientWs info", - constructor: clientWs.constructor.name, - hasAddEventListener: typeof clientWs.addEventListener === "function", - readyState: clientWs.readyState, - }); - - // Add message handler to forward messages from client to server - clientWs.addEventListener("message", (clientEvt: MessageEvent) => { - logger().debug({ - msg: `test ${connectionType} websocket connection message from client`, - dataType: typeof clientEvt.data, - isBlob: clientEvt.data instanceof Blob, - isArrayBuffer: clientEvt.data instanceof ArrayBuffer, - dataConstructor: clientEvt.data?.constructor?.name, - dataStr: - typeof clientEvt.data === "string" - ? clientEvt.data.substring(0, 100) - : undefined, - }); - - if (serverWs.readyState === 1) { - // OPEN - // Handle Blob data - if (clientEvt.data instanceof Blob) { - clientEvt.data - .arrayBuffer() - .then((buffer) => { - logger().debug({ - msg: "converted client blob to arraybuffer, sending to server", - bufferSize: buffer.byteLength, - }); - serverWs.send(buffer as any); - }) - .catch((error) => { - logger().error({ - msg: "failed to convert blob to arraybuffer", - error, - }); - }); - } else { - logger().debug({ - msg: "sending client data directly to server", - dataType: typeof clientEvt.data, - dataLength: - typeof clientEvt.data === "string" - ? clientEvt.data.length - : undefined, - }); - serverWs.send(clientEvt.data as any); - } - } - }); - - // Add close handler to close server when client closes - clientWs.addEventListener("close", (clientEvt: CloseEvent) => { - logger().debug({ - msg: `test ${connectionType} websocket connection closed`, - }); - - if (serverWs.readyState !== 3) { - // Not CLOSED - serverWs.close(clientEvt.code, clientEvt.reason); - } - }); - - // Add error handler - clientWs.addEventListener("error", () => { - logger().debug({ - msg: `test ${connectionType} websocket connection error`, - }); - - if (serverWs.readyState !== 3) { - // Not CLOSED - serverWs.close(1011, "Error in client websocket"); - } - }); - }, - onMessage: (evt: { data: any }) => { - logger().debug({ - msg: "received message from server", - dataType: typeof evt.data, - isBlob: evt.data instanceof Blob, - isArrayBuffer: evt.data instanceof ArrayBuffer, - dataConstructor: evt.data?.constructor?.name, - dataStr: - typeof evt.data === "string" ? evt.data.substring(0, 100) : undefined, - }); - - // Forward messages from server websocket to client websocket - if (clientWs.readyState === 1) { - // OPEN - // Handle Blob data - if (evt.data instanceof Blob) { - evt.data - .arrayBuffer() - .then((buffer) => { - logger().debug({ - msg: "converted blob to arraybuffer, sending", - bufferSize: buffer.byteLength, - }); - clientWs.send(buffer); - }) - .catch((error) => { - logger().error({ - msg: "failed to convert blob to arraybuffer", - error, - }); - }); - } else { - logger().debug({ - msg: "sending data directly", - dataType: typeof evt.data, - dataLength: - typeof evt.data === "string" ? evt.data.length : undefined, - }); - clientWs.send(evt.data); - } - } - }, - onClose: ( - event: { - wasClean: boolean; - code: number; - reason: string; - }, - serverWs: WSContext, - ) => { - logger().debug({ - msg: `server ${connectionType} websocket closed`, - wasClean: event.wasClean, - code: event.code, - reason: event.reason, - }); - - // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state - // https://github.com/cloudflare/workerd/issues/2569 - serverWs.close(1000, "hack_force_close"); - - // Close the client websocket when the server websocket closes - if ( - clientWs && - clientWs.readyState !== clientWs.CLOSED && - clientWs.readyState !== clientWs.CLOSING - ) { - // Don't pass code/message since this may affect how close events are triggered - clientWs.close(1000, event.reason); - } - }, - onError: (error: unknown) => { - logger().error({ - msg: `error in server ${connectionType} websocket`, - error, - }); - - // Close the client websocket on error - if ( - clientWs && - clientWs.readyState !== clientWs.CLOSED && - clientWs.readyState !== clientWs.CLOSING - ) { - clientWs.close(1011, "Error in server websocket"); - } - }, - }; -} - -/** - * Handle SSE connection request - */ -async function handleSseConnectRequest( - c: HonoContext, - registryConfig: RegistryConfig, - _runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - let encoding: Encoding | undefined; - try { - encoding = getRequestEncoding(c.req); - logger().debug({ msg: "sse connection request received", encoding }); - const params = ConnectRequestSchema.safeParse({ - query: getRequestQuery(c), - encoding: c.req.header(HEADER_ENCODING), - connParams: c.req.header(HEADER_CONN_PARAMS), + return c.json({ actor }); }); - - if (!params.success) { - logger().error({ - msg: "invalid connection parameters", - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - - const query = params.data.query; - - // Parse connection parameters for authentication - const connParams = params.data.connParams - ? JSON.parse(params.data.connParams) - : undefined; - - // Authenticate the request - const authData = await authenticateEndpoint( - c, - driver, - registryConfig, - query, - ["connect"], - connParams, - ); - - // Get the actor ID - const { actorId } = await queryActor(c, query, driver); - invariant(actorId, "Missing actor ID"); - logger().debug({ msg: "sse connection to actor", actorId }); - - // Handle based on mode - logger().debug({ msg: "using custom proxy mode for sse connection" }); - const url = new URL("http://actor/connect/sse"); - - // Always build fresh request to prevent forwarding unwanted headers - const proxyRequestHeaderes = new Headers(); - proxyRequestHeaderes.set(HEADER_ENCODING, params.data.encoding); - if (params.data.connParams) { - proxyRequestHeaderes.set(HEADER_CONN_PARAMS, params.data.connParams); - } - if (authData) { - proxyRequestHeaderes.set(HEADER_AUTH_DATA, JSON.stringify(authData)); - } - - const proxyRequest = new Request(url, { headers: proxyRequestHeaderes }); - - return await driver.proxyRequest(c, proxyRequest, actorId); - } catch (error) { - // If we receive an error during setup, we send the error and close the socket immediately - // - // We have to return the error over SSE since SSE clients cannot read vanilla HTTP responses - - const { code, message, metadata } = deconstructError(error, logger(), { - sseEvent: "setup", - }); - - return streamSSE(c, async (stream) => { - try { - if (encoding) { - // Serialize and send the connection error - const errorMsg: protocol.ToClient = { - body: { - tag: "Error", - val: { - code, - message, - metadata: bufferToArrayBuffer(cbor.encode(metadata)), - actionId: null, - }, - }, - }; - - // Send the error message to the client - const serialized = serializeWithEncoding( - encoding, - errorMsg, - TO_CLIENT_VERSIONED, - ); - await stream.writeSSE({ - data: - typeof serialized === "string" - ? serialized - : Buffer.from(serialized).toString("base64"), - }); - } else { - // We don't know the encoding, send an error and close - await stream.writeSSE({ - data: code, - event: "error", - }); - } - } catch (serializeError) { - logger().error({ - msg: "failed to send error to sse client", - error: serializeError, - }); - await stream.writeSSE({ - data: "internal error during error handling", - event: "error", - }); - } - - // Stream will exit completely once function exits - }); - } -} - -/** - * Handle WebSocket connection request - */ -async function handleWebSocketConnectRequest( - c: HonoContext, - registryConfig: RegistryConfig, - runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); - if (!upgradeWebSocket) { - return c.text( - "WebSockets are not enabled for this driver. Use SSE instead.", - 400, - ); } - let encoding: Encoding | undefined; - try { - logger().debug({ msg: "websocket connection request received" }); - - // Parse configuration from Sec-WebSocket-Protocol header - // - // We use this instead of query parameters since this is more secure than - // query parameters. Query parameters often get logged. - // - // Browsers don't support using headers, so this is the only way to - // pass data securely. - const protocols = c.req.header("sec-websocket-protocol"); - const { queryRaw, encodingRaw, connParamsRaw } = - parseWebSocketProtocols(protocols); - - // Parse query - let queryUnvalidated: unknown; - try { - queryUnvalidated = JSON.parse(queryRaw!); - } catch (error) { - logger().error({ msg: "invalid query json", error }); - throw new errors.InvalidQueryJSON(error); - } - - // Parse conn params - let connParamsUnvalidated: unknown = null; - try { - if (connParamsRaw) { - connParamsUnvalidated = JSON.parse(connParamsRaw!); - } - } catch (error) { - logger().error({ msg: "invalid conn params", error }); - throw new errors.InvalidParams( - `Invalid params JSON: ${stringifyError(error)}`, - ); - } - - // We can't use the standard headers with WebSockets - // - // All other information will be sent over the socket itself, since that data needs to be E2EE - const params = ConnectWebSocketRequestSchema.safeParse({ - query: queryUnvalidated, - encoding: encodingRaw, - connParams: connParamsUnvalidated, - }); - if (!params.success) { - logger().error({ - msg: "invalid connection parameters", - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - encoding = params.data.encoding; - - // Authenticate endpoint - const authData = await authenticateEndpoint( - c, - driver, - registryConfig, - params.data.query, - ["connect"], - connParamsRaw, - ); - - // Get the actor ID - const { actorId } = await queryActor(c, params.data.query, driver); - logger().debug({ msg: "found actor for websocket connection", actorId }); - invariant(actorId, "missing actor id"); - - // Proxy the WebSocket connection to the actor - // - // The proxyWebSocket handler will: - // 1. Validate the WebSocket upgrade request - // 2. Forward the request to the actor with the appropriate path - // 3. Handle the WebSocket pair and proxy messages between client and actor - return await driver.proxyWebSocket( - c, - PATH_CONNECT_WEBSOCKET, - actorId, - params.data.encoding, - params.data.connParams, - authData, - ); - } catch (error) { - // If we receive an error during setup, we send the error and close the socket immediately - // - // We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses - - const { code, message, metadata } = deconstructError(error, logger(), { - wsEvent: "setup", - }); - - return await upgradeWebSocket(() => ({ - onOpen: (_evt: unknown, ws: WSContext) => { - if (encoding) { - try { - // Serialize and send the connection error - const errorMsg: protocol.ToClient = { - body: { - tag: "Error", - val: { - code, - message, - metadata: bufferToArrayBuffer(cbor.encode(metadata)), - actionId: null, - }, - }, - }; - - // Send the error message to the client - const serialized = serializeWithEncoding( - encoding, - errorMsg, - TO_CLIENT_VERSIONED, - ); - ws.send(serialized); - - // Close the connection with an error code - ws.close(1011, code); - } catch (serializeError) { - logger().error({ - msg: "failed to send error to websocket client", - error: serializeError, - }); - ws.close(1011, "internal error during error handling"); - } - } else { - // We don't know the encoding so we send what we can - ws.close(1011, code); - } + // DELETE /actors/{actor_id} + { + const route = createRoute({ + method: "delete", + path: "/actors/{actor_id}", + request: { + params: z.object({ + actor_id: RivetIdSchema, + }), }, - }))(c, noopNext()); - } -} - -/** - * Handle a connection message request to a actor - * - * There is no authentication handler on this request since the connection - * token is used to authenticate the message. - */ -async function handleMessageRequest( - c: HonoContext, - _registryConfig: RegistryConfig, - _runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - logger().debug({ msg: "connection message request received" }); - try { - const params = ConnMessageRequestSchema.safeParse({ - actorId: c.req.header(HEADER_ACTOR_ID), - connId: c.req.header(HEADER_CONN_ID), - encoding: c.req.header(HEADER_ENCODING), - connToken: c.req.header(HEADER_CONN_TOKEN), - }); - if (!params.success) { - logger().error({ - msg: "invalid connection parameters", - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - const { actorId, connId, encoding, connToken } = params.data; - - // TODO: This endpoint can be used to exhause resources (DoS attack) on an actor if you know the actor ID: - // 1. Get the actor ID (usually this is reasonably secure, but we don't assume actor ID is sensitive) - // 2. Spam messages to the actor (the conn token can be invalid) - // 3. The actor will be exhausted processing messages — even if the token is invalid - // - // The solution is we need to move the authorization of the connection token to this request handler - // AND include the actor ID in the connection token so we can verify that it has permission to send - // a message to that actor. This would require changing the token to a JWT so we can include a secure - // payload, but this requires managing a private key & managing key rotations. - // - // All other solutions (e.g. include the actor name as a header or include the actor name in the actor ID) - // have exploits that allow the caller to send messages to arbitrary actors. - // - // Currently, we assume this is not a critical problem because requests will likely get rate - // limited before enough messages are passed to the actor to exhaust resources. - - const url = new URL("http://actor/connections/message"); - - // Always build fresh request to prevent forwarding unwanted headers - const proxyRequestHeaders = new Headers(); - proxyRequestHeaders.set(HEADER_ENCODING, encoding); - proxyRequestHeaders.set(HEADER_CONN_ID, connId); - proxyRequestHeaders.set(HEADER_CONN_TOKEN, connToken); - - const proxyRequest = new Request(url, { - method: "POST", - body: c.req.raw.body, - duplex: "half", - headers: proxyRequestHeaders, - }); - - return await driver.proxyRequest(c, proxyRequest, actorId); - } catch (error) { - logger().error({ msg: "error proxying connection message", error }); - - // Use ProxyError if it's not already an ActorError - if (!errors.ActorError.isActorError(error)) { - throw new errors.ProxyError("connection message", error); - } else { - throw error; - } - } -} - -/** - * Handle an action request to a actor - */ -async function handleActionRequest( - c: HonoContext, - registryConfig: RegistryConfig, - _runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - try { - const actionName = c.req.param("action"); - logger().debug({ msg: "action call received", actionName }); - - const params = ConnectRequestSchema.safeParse({ - query: getRequestQuery(c), - encoding: c.req.header(HEADER_ENCODING), - connParams: c.req.header(HEADER_CONN_PARAMS), - }); - - if (!params.success) { - logger().error({ - msg: "invalid connection parameters", - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - - // Parse connection parameters for authentication - const connParams = params.data.connParams - ? JSON.parse(params.data.connParams) - : undefined; - - // Authenticate the request - const authData = await authenticateEndpoint( - c, - driver, - registryConfig, - params.data.query, - ["action"], - connParams, - ); - - // Get the actor ID - const { actorId } = await queryActor(c, params.data.query, driver); - logger().debug({ msg: "found actor for action", actorId }); - invariant(actorId, "Missing actor ID"); - - const url = new URL( - `http://actor/action/${encodeURIComponent(actionName)}`, - ); - - // Always build fresh request to prevent forwarding unwanted headers - const proxyRequestHeaders = new Headers(); - proxyRequestHeaders.set(HEADER_ENCODING, params.data.encoding); - if (params.data.connParams) { - proxyRequestHeaders.set(HEADER_CONN_PARAMS, params.data.connParams); - } - if (authData) { - proxyRequestHeaders.set(HEADER_AUTH_DATA, JSON.stringify(authData)); - } - - const proxyRequest = new Request(url, { - method: "POST", - body: c.req.raw.body, - headers: proxyRequestHeaders, - }); - - return await driver.proxyRequest(c, proxyRequest, actorId); - } catch (error) { - logger().error({ - msg: "error in action handler", - error: stringifyError(error), - }); - - // Use ProxyError if it's not already an ActorError - if (!errors.ActorError.isActorError(error)) { - throw new errors.ProxyError("Action call", error); - } else { - throw error; - } - } -} - -/** - * Handle the resolve request to get a actor ID from a query - */ -async function handleResolveRequest( - c: HonoContext, - registryConfig: RegistryConfig, - driver: ManagerDriver, -): Promise { - const encoding = getRequestEncoding(c.req); - logger().debug({ msg: "resolve request encoding", encoding }); - - const params = ResolveRequestSchema.safeParse({ - query: getRequestQuery(c), - connParams: c.req.header(HEADER_CONN_PARAMS), - }); - if (!params.success) { - logger().error({ - msg: "invalid connection parameters", - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - - // Parse connection parameters for authentication - const connParams = params.data.connParams - ? JSON.parse(params.data.connParams) - : undefined; - - const query = params.data.query; - - // Authenticate the request - await authenticateEndpoint(c, driver, registryConfig, query, [], connParams); - - // Get the actor ID - const { actorId } = await queryActor(c, query, driver); - logger().debug({ msg: "resolved actor", actorId }); - invariant(actorId, "Missing actor ID"); - - // Format response according to protocol - const response: protocol.HttpResolveResponse = { - actorId, - }; - const serialized = serializeWithEncoding( - encoding, - response, - HTTP_RESOLVE_RESPONSE_VERSIONED, - ); - return c.body(serialized); -} - -/** - * Handle raw HTTP requests to an actor - */ -async function handleRawHttpRequest( - c: HonoContext, - registryConfig: RegistryConfig, - _runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - try { - const subpath = c.req.path.split("/raw/http/")[1] || ""; - logger().debug({ msg: "raw http request received", subpath }); - - // Get actor query from header (consistent with other endpoints) - const queryHeader = c.req.header(HEADER_ACTOR_QUERY); - if (!queryHeader) { - throw new errors.InvalidRequest("Missing actor query header"); - } - const query: ActorQuery = JSON.parse(queryHeader); - - // Parse connection parameters for authentication - const connParamsHeader = c.req.header(HEADER_CONN_PARAMS); - const connParams = connParamsHeader - ? JSON.parse(connParamsHeader) - : undefined; - - // Authenticate the request - const authData = await authenticateEndpoint( - c, - driver, - registryConfig, - query, - ["action"], - connParams, - ); - - // Get the actor ID - const { actorId } = await queryActor(c, query, driver); - logger().debug({ msg: "found actor for raw http", actorId }); - invariant(actorId, "Missing actor ID"); - - // Preserve the original URL's query parameters - const originalUrl = new URL(c.req.url); - const url = new URL( - `http://actor/raw/http/${subpath}${originalUrl.search}`, - ); - - // Forward the request to the actor - - logger().debug({ msg: "rewriting http url", from: c.req.url, to: url }); - - const proxyRequestHeaders = new Headers(c.req.raw.headers); - if (connParams) { - proxyRequestHeaders.set(HEADER_CONN_PARAMS, JSON.stringify(connParams)); - } - if (authData) { - proxyRequestHeaders.set(HEADER_AUTH_DATA, JSON.stringify(authData)); - } - - const proxyRequest = new Request(url, { - method: c.req.method, - headers: proxyRequestHeaders, - body: c.req.raw.body, - }); - - return await driver.proxyRequest(c, proxyRequest, actorId); - } catch (error) { - logger().error({ - msg: "error in raw http handler", - error: stringifyError(error), - }); - - // Use ProxyError if it's not already an ActorError - if (!errors.ActorError.isActorError(error)) { - throw new errors.ProxyError("Raw HTTP request", error); - } else { - throw error; - } - } -} - -/** - * Handle raw WebSocket requests to an actor - */ -async function handleRawWebSocketRequest( - c: HonoContext, - registryConfig: RegistryConfig, - runConfig: RunConfig, - driver: ManagerDriver, -): Promise { - const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); - if (!upgradeWebSocket) { - return c.text("WebSockets are not enabled for this driver.", 400); - } - - try { - const subpath = c.req.path.split("/raw/websocket/")[1] || ""; - logger().debug({ msg: "raw websocket request received", subpath }); - - // Parse protocols from Sec-WebSocket-Protocol header - const protocols = c.req.header("sec-websocket-protocol"); - const { - queryRaw: queryFromProtocol, - connParamsRaw: connParamsFromProtocol, - } = parseWebSocketProtocols(protocols); - - if (!queryFromProtocol) { - throw new errors.InvalidRequest("Missing query in WebSocket protocol"); - } - const query = JSON.parse(queryFromProtocol); - - // Parse connection parameters from protocol - let connParams: unknown; - if (connParamsFromProtocol) { - connParams = JSON.parse(connParamsFromProtocol); - } - - // Authenticate the request - const authData = await authenticateEndpoint( - c, - driver, - registryConfig, - query, - ["action"], - connParams, - ); - - // Get the actor ID - const { actorId } = await queryActor(c, query, driver); - logger().debug({ msg: "found actor for raw websocket", actorId }); - invariant(actorId, "Missing actor ID"); - - logger().debug({ msg: "using custom proxy mode for raw websocket" }); - - // Preserve the original URL's query parameters - const originalUrl = new URL(c.req.url); - const proxyPath = `${PATH_RAW_WEBSOCKET_PREFIX}${subpath}${originalUrl.search}`; - - logger().debug({ - msg: "manager router proxyWebSocket", - originalUrl: c.req.url, - subpath, - search: originalUrl.search, - proxyPath, + responses: buildOpenApiResponses( + ActorsDeleteResponseSchema, + validateBody, + ), }); - // For raw WebSocket, we need to use proxyWebSocket instead of proxyRequest - return await driver.proxyWebSocket( - c, - proxyPath, - actorId, - "json", // Default encoding for raw WebSocket - connParams, - authData, - ); - } catch (error) { - // If we receive an error during setup, we send the error and close the socket immediately - // - // We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses + router.openapi(route, async (c) => { + const { actor_id } = c.req.valid("param"); - const { code } = deconstructError(error, logger(), { - wsEvent: "setup", + // NOTE: The ManagerDriver interface doesn't currently have a deleteActor method + // This endpoint cannot be implemented until the driver supports actor deletion + throw new FeatureNotImplemented( + "Actor deletion - ManagerDriver lacks deleteActor method", + ); }); - - return await upgradeWebSocket(() => ({ - onOpen: (_evt: unknown, ws: WSContext) => { - // Close with message so we can see the error on the client - ws.close(1011, code); - }, - }))(c, noopNext()); } -} -function universalActorProxy({ - registryConfig, - runConfig, - driver, -}: { - registryConfig: RegistryConfig; - runConfig: RunConfig; - driver: ManagerDriver; -}): MiddlewareHandler { - return async (c, _next) => { - if (c.req.header("upgrade") === "websocket") { - return handleRawWebSocketRequest(c, registryConfig, runConfig, driver); - } else { - const queryHeader = c.req.header(HEADER_ACTOR_QUERY); - if (!queryHeader) { - throw new errors.InvalidRequest("Missing actor query header"); - } - const query = ActorQuerySchema.parse(JSON.parse(queryHeader)); - - const { actorId } = await queryActor(c, query, driver); + // Error handling + router.notFound(handleRouteNotFound); + router.onError(handleRouteError); - const url = new URL(c.req.url); - url.hostname = "actor"; - url.pathname = url.pathname - .replace(new RegExp(`^${runConfig.basePath}`, ""), "") - .replace(/^\/?registry\/actors/, "") - .replace(/^\/?actors/, ""); // Remove /registry prefix if present - - const proxyRequest = new Request(url, { - method: c.req.method, - headers: c.req.raw.headers, - body: c.req.raw.body, - }); - return await driver.proxyRequest(c, proxyRequest, actorId); - } - }; + return { router: router as Hono, openapi: router }; } diff --git a/packages/rivetkit/src/mod.ts b/packages/rivetkit/src/mod.ts index 3cd62b030..4b7cc14b8 100644 --- a/packages/rivetkit/src/mod.ts +++ b/packages/rivetkit/src/mod.ts @@ -3,7 +3,6 @@ export * from "@/actor/mod"; export { type AnyClient, type Client, - type ClientDriver, createClientWithDriver, } from "@/client/client"; export { InlineWebSocketAdapter2 } from "@/common/inline-websocket-adapter2"; @@ -13,7 +12,6 @@ export { createFileSystemDriver, createMemoryDriver, } from "@/drivers/file-system/mod"; -export { createInlineClientDriver } from "@/inline-client-driver/mod"; // Re-export important protocol types and utilities needed by drivers export type { ActorQuery } from "@/manager/protocol/query"; export * from "@/registry/mod"; diff --git a/packages/rivetkit/src/registry/config.ts b/packages/rivetkit/src/registry/config.ts index bb9e970c2..9a6428211 100644 --- a/packages/rivetkit/src/registry/config.ts +++ b/packages/rivetkit/src/registry/config.ts @@ -5,7 +5,7 @@ import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; export const ActorsSchema = z.record( z.string(), - z.custom>(), + z.custom>(), ); export type RegistryActors = z.infer; diff --git a/packages/rivetkit/src/registry/mod.ts b/packages/rivetkit/src/registry/mod.ts index 94eba9953..895c61157 100644 --- a/packages/rivetkit/src/registry/mod.ts +++ b/packages/rivetkit/src/registry/mod.ts @@ -6,7 +6,6 @@ import { getPinoLevel, } from "@/common/log"; import { chooseDefaultDriver } from "@/drivers/default"; -import { createInlineClientDriver } from "@/inline-client-driver/mod"; import { getInspectorUrl } from "@/inspector/utils"; import { createManagerRouter } from "@/manager/router"; import pkg from "../../package.json" with { type: "json" }; @@ -70,17 +69,15 @@ export class Registry { // Create router const managerDriver = driver.manager(this.#config, config); - const clientDriver = createInlineClientDriver(managerDriver); const { router: hono } = createManagerRouter( this.#config, config, - clientDriver, managerDriver, false, ); // Create client - const client = createClientWithDriver(clientDriver); + const client = createClientWithDriver(managerDriver, config); const driverLog = managerDriver.extraStartupLog?.() ?? {}; logger().info({ @@ -98,6 +95,7 @@ export class Registry { const displayInfo = managerDriver.displayInformation(); console.log(); console.log(` RivetKit ${pkg.version} (${displayInfo.name})`); + console.log(` - Endpoint: http://127.0.0.1:6420`); for (const [k, v] of Object.entries(displayInfo.properties)) { const padding = " ".repeat(Math.max(0, 13 - k.length)); console.log(` - ${k}:${padding}${v}`); @@ -109,18 +107,14 @@ export class Registry { } // Create runner - if (config.role === "all" || config.role === "runner") { - const inlineClient = createClientWithDriver( - createInlineClientDriver(managerDriver), - ); - const _actorDriver = driver.actor( - this.#config, - config, - managerDriver, - inlineClient, - ); - // TODO: What do we do with the actor driver here? - } + // + // Even though we do not use the return value, this is required to start the code that will handle incoming actors + const _actorDriver = driver.actor( + this.#config, + config, + managerDriver, + client, + ); return { client, diff --git a/packages/rivetkit/src/registry/run-config.ts b/packages/rivetkit/src/registry/run-config.ts index 62521702f..f457c4a28 100644 --- a/packages/rivetkit/src/registry/run-config.ts +++ b/packages/rivetkit/src/registry/run-config.ts @@ -2,15 +2,13 @@ import type { cors } from "hono/cors"; import type { Logger } from "pino"; import { z } from "zod"; import type { ActorDriverBuilder } from "@/actor/driver"; +import { ClientConfigSchema } from "@/client/config"; import { LogLevelSchema } from "@/common/log"; import { InspectorConfigSchema } from "@/inspector/config"; import type { ManagerDriverBuilder } from "@/manager/driver"; -import type { UpgradeWebSocket } from "@/utils"; type CorsOptions = NonNullable[0]>; -export type GetUpgradeWebSocket = () => UpgradeWebSocket; - export const DriverConfigSchema = z.object({ /** Machine-readable name to identify this driver by. */ name: z.string(), @@ -21,47 +19,34 @@ export const DriverConfigSchema = z.object({ export type DriverConfig = z.infer; /** Base config used for the actor config across all platforms. */ -export const RunConfigSchema = z - .object({ - driver: DriverConfigSchema.optional(), - - /** Endpoint to connect to the Rivet engine. Can be configured via RIVET_ENGINE env var. */ - engine: z.string().optional(), - - // This is a function to allow for lazy configuration of upgradeWebSocket on the - // fly. This is required since the dependencies that profie upgradeWebSocket - // (specifically Node.js) can sometimes only be specified after the router is - // created or must be imported async using `await import(...)` - getUpgradeWebSocket: z.custom().optional(), - - role: z.enum(["all", "server", "runner"]).optional().default("all"), - - /** CORS configuration for the router. Uses Hono's CORS middleware options. */ - cors: z.custom().optional(), - - maxIncomingMessageSize: z.number().optional().default(65_536), - - inspector: InspectorConfigSchema, - - /** - * Base path for the router. This is used to prefix all routes. - * For example, if the base path is `/api`, then the route `/actors` will be - * available at `/api/actors`. - */ - basePath: z.string().optional().default("/"), - - /** Disable welcome message. */ - noWelcome: z.boolean().optional().default(false), - - logging: z - .object({ - baseLogger: z.custom().optional(), - level: LogLevelSchema.optional(), - }) - .optional() - .default({}), - }) - .default({}); +export const RunConfigSchema = ClientConfigSchema.extend({ + driver: DriverConfigSchema.optional(), + + /** CORS configuration for the router. Uses Hono's CORS middleware options. */ + cors: z.custom().optional(), + + maxIncomingMessageSize: z.number().optional().default(65_536), + + inspector: InspectorConfigSchema, + + /** + * Base path for the router. This is used to prefix all routes. + * For example, if the base path is `/api`, then the route `/actors` will be + * available at `/api/actors`. + */ + basePath: z.string().optional().default("/"), + + /** Disable welcome message. */ + noWelcome: z.boolean().optional().default(false), + + logging: z + .object({ + baseLogger: z.custom().optional(), + level: LogLevelSchema.optional(), + }) + .optional() + .default({}), +}).default({}); export type RunConfig = z.infer; export type RunConfigInput = z.input; diff --git a/packages/rivetkit/src/registry/serve.ts b/packages/rivetkit/src/registry/serve.ts index 34f20396d..8dd3a55bf 100644 --- a/packages/rivetkit/src/registry/serve.ts +++ b/packages/rivetkit/src/registry/serve.ts @@ -21,7 +21,8 @@ export async function crossPlatformServe( } // Mount registry - app.route("/registry", rivetKitRouter); + // app.route("/registry", rivetKitRouter); + app.route("/", rivetKitRouter); // Import @hono/node-ws let createNodeWebSocket: any; @@ -41,9 +42,7 @@ export async function crossPlatformServe( }); // Start server - const port = Number.parseInt( - getEnvUniversal("PORT") ?? getEnvUniversal("PORT_HTTP") ?? "8080", - ); + const port = 6420; const server = serve({ fetch: app.fetch, port }, () => logger().info({ msg: "server listening", port }), ); diff --git a/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts new file mode 100644 index 000000000..87302a7dc --- /dev/null +++ b/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts @@ -0,0 +1,72 @@ +import type { ClientConfig } from "@/client/config"; +import { getEndpoint } from "./api-utils"; + +export async function sendHttpRequestToActor( + runConfig: ClientConfig, + actorId: string, + actorRequest: Request, +): Promise { + // Route through guard port + const url = new URL(actorRequest.url); + const endpoint = getEndpoint(runConfig); + const guardUrl = `${endpoint}${url.pathname}${url.search}`; + + // Handle body properly based on method and presence + let bodyToSend: ArrayBuffer | null = null; + const guardHeaders = buildGuardHeadersForHttp(actorRequest, actorId); + + if ( + actorRequest.body && + actorRequest.method !== "GET" && + actorRequest.method !== "HEAD" + ) { + if (actorRequest.bodyUsed) { + throw new Error("Request body has already been consumed"); + } + + // TODO: This buffers the entire request in memory every time. We + // need to properly implement streaming bodies. + // Clone and read the body to ensure it can be sent + const clonedRequest = actorRequest.clone(); + bodyToSend = await clonedRequest.arrayBuffer(); + + // If this is a streaming request, we need to convert the headers + // for the basic array buffer + guardHeaders.delete("transfer-encoding"); + guardHeaders.set( + "content-length", + String((bodyToSend as ArrayBuffer).byteLength), + ); + } + + const guardRequest = new Request(guardUrl, { + method: actorRequest.method, + headers: guardHeaders, + body: bodyToSend, + }); + + return mutableResponse(await fetch(guardRequest)); +} + +function mutableResponse(fetchRes: Response): Response { + // We cannot return the raw response from `fetch` since the response type is not mutable. + // + // In order for middleware to be able to mutate the response, we need to build a new Response object that is mutable. + return new Response(fetchRes.body, fetchRes); +} + +function buildGuardHeadersForHttp( + actorRequest: Request, + actorId: string, +): Headers { + const headers = new Headers(); + // Copy all headers from the original request + for (const [key, value] of actorRequest.headers.entries()) { + headers.set(key, value); + } + // Add guard-specific headers + headers.set("x-rivet-target", "actor"); + headers.set("x-rivet-actor", actorId); + headers.set("x-rivet-port", "main"); + return headers; +} diff --git a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts new file mode 100644 index 000000000..abb7e9db0 --- /dev/null +++ b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -0,0 +1,60 @@ +import { + HEADER_AUTH_DATA, + HEADER_CONN_PARAMS, + HEADER_ENCODING, +} from "@/actor/router-endpoints"; +import type { ClientConfig } from "@/client/config"; +import { importWebSocket } from "@/common/websocket"; +import type { Encoding, UniversalWebSocket } from "@/mod"; +import { getEndpoint } from "./api-utils"; +import { logger } from "./log"; + +export async function openWebSocketToActor( + runConfig: ClientConfig, + path: string, + actorId: string, + encoding: Encoding, + params: unknown, +): Promise { + const WebSocket = await importWebSocket(); + + // WebSocket connections go through guard + const endpoint = getEndpoint(runConfig); + const guardUrl = `${endpoint}${path}`; + + logger().debug({ + msg: "opening websocket to actor via guard", + actorId, + path, + guardUrl, + }); + + // Create WebSocket connection + const ws = new WebSocket(guardUrl, { + headers: buildGuardHeadersForWebSocket(actorId, encoding, params), + }); + + logger().debug({ msg: "websocket connection opened", actorId }); + + return ws as UniversalWebSocket; +} + +export function buildGuardHeadersForWebSocket( + actorId: string, + encoding: Encoding, + params?: unknown, + authData?: unknown, +): Record { + const headers: Record = {}; + headers["x-rivet-target"] = "actor"; + headers["x-rivet-actor"] = actorId; + headers["x-rivet-port"] = "main"; + headers[HEADER_ENCODING] = encoding; + if (params) { + headers[HEADER_CONN_PARAMS] = JSON.stringify(params); + } + if (authData) { + headers[HEADER_AUTH_DATA] = JSON.stringify(authData); + } + return headers; +} diff --git a/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts b/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts new file mode 100644 index 000000000..ae00a558a --- /dev/null +++ b/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts @@ -0,0 +1,79 @@ +import { serializeActorKey } from "@/actor/keys"; +import type { ClientConfig } from "@/client/client"; +import type { + ActorsCreateRequest, + ActorsCreateResponse, +} from "@/manager-api/routes/actors-create"; +import type { ActorsDeleteResponse } from "@/manager-api/routes/actors-delete"; +import type { ActorsGetResponse } from "@/manager-api/routes/actors-get"; +import type { ActorsGetByIdResponse } from "@/manager-api/routes/actors-get-by-id"; +import type { + ActorsGetOrCreateByIdRequest, + ActorsGetOrCreateByIdResponse, +} from "@/manager-api/routes/actors-get-or-create-by-id"; +import type { RivetId } from "@/manager-api/routes/common"; +import { apiCall } from "./api-utils"; + +// MARK: Get actor +export async function getActor( + config: ClientConfig, + actorId: RivetId, +): Promise { + return apiCall( + config, + "GET", + `/actors/${encodeURIComponent(actorId)}`, + ); +} + +// MARK: Get actor by id +export async function getActorById( + config: ClientConfig, + name: string, + key: string[], +): Promise { + const serializedKey = serializeActorKey(key); + return apiCall( + config, + "GET", + `/actors/by-id?name=${encodeURIComponent(name)}&key=${encodeURIComponent(serializedKey)}`, + ); +} + +// MARK: Get or create actor by id +export async function getOrCreateActorById( + config: ClientConfig, + request: ActorsGetOrCreateByIdRequest, +): Promise { + return apiCall( + config, + "PUT", + `/actors/by-id`, + request, + ); +} + +// MARK: Create actor +export async function createActor( + config: ClientConfig, + request: ActorsCreateRequest, +): Promise { + return apiCall( + config, + "POST", + `/actors`, + request, + ); +} + +// MARK: Destroy actor +export async function destroyActor( + config: ClientConfig, + actorId: RivetId, +): Promise { + return apiCall( + config, + "DELETE", + `/actors/${encodeURIComponent(actorId)}`, + ); +} diff --git a/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/packages/rivetkit/src/remote-manager-driver/api-utils.ts new file mode 100644 index 000000000..0b4e48bcc --- /dev/null +++ b/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -0,0 +1,43 @@ +import type { ClientConfig } from "@/client/config"; +import { sendHttpRequest } from "@/client/utils"; +import { logger } from "./log"; + +// Error class for Engine API errors +export class EngineApiError extends Error { + constructor( + public readonly group: string, + public readonly code: string, + message?: string, + ) { + super(message || `Engine API error: ${group}/${code}`); + this.name = "EngineApiError"; + } +} + +export function getEndpoint(config: ClientConfig) { + return config.endpoint ?? "http://127.0.0.1:6420"; +} + +// Helper function for making API calls +export async function apiCall( + config: ClientConfig, + method: "GET" | "POST" | "PUT" | "DELETE", + path: string, + body?: TInput, +): Promise { + const endpoint = getEndpoint(config); + const url = `${endpoint}${path}${path.includes("?") ? "&" : "?"}namespace=${encodeURIComponent(config.namespace)}`; + + logger().debug({ msg: "making api call", method, url }); + + return await sendHttpRequest({ + method, + url, + headers: {}, + body, + encoding: "json", + skipParseResponse: false, + requestVersionedDataHandler: undefined, + responseVersionedDataHandler: undefined, + }); +} diff --git a/packages/rivetkit/src/inline-client-driver/log.ts b/packages/rivetkit/src/remote-manager-driver/log.ts similarity index 62% rename from packages/rivetkit/src/inline-client-driver/log.ts rename to packages/rivetkit/src/remote-manager-driver/log.ts index 17a0d5f31..46e83a5dc 100644 --- a/packages/rivetkit/src/inline-client-driver/log.ts +++ b/packages/rivetkit/src/remote-manager-driver/log.ts @@ -1,5 +1,5 @@ import { getLogger } from "@/common//log"; export function logger() { - return getLogger("inline-client-driver"); + return getLogger("remote-manager-driver"); } diff --git a/packages/rivetkit/src/remote-manager-driver/mod.ts b/packages/rivetkit/src/remote-manager-driver/mod.ts new file mode 100644 index 000000000..29f9fe158 --- /dev/null +++ b/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -0,0 +1,274 @@ +import * as cbor from "cbor-x"; +import type { Hono, Context as HonoContext } from "hono"; +import invariant from "invariant"; +import { ActorAlreadyExists } from "@/actor/errors"; +import { deserializeActorKey, serializeActorKey } from "@/actor/keys"; +import type { ClientConfig } from "@/client/client"; +import { noopNext } from "@/common/utils"; +import type { + ActorOutput, + CreateInput, + GetForIdInput, + GetOrCreateWithKeyInput, + GetWithKeyInput, + ManagerDisplayInformation, + ManagerDriver, +} from "@/driver-helpers/mod"; +import type { ManagerInspector } from "@/inspector/manager"; +import type { Encoding, RegistryConfig, UniversalWebSocket } from "@/mod"; +import type { RunConfig } from "@/registry/run-config"; +import { sendHttpRequestToActor } from "./actor-http-client"; +import { + buildGuardHeadersForWebSocket, + openWebSocketToActor, +} from "./actor-websocket-client"; +import { + createActor, + destroyActor, + getActor, + getActorById, + getOrCreateActorById, +} from "./api-endpoints"; +import { EngineApiError, getEndpoint } from "./api-utils"; +import { logger } from "./log"; +import { createWebSocketProxy } from "./ws-proxy"; + +// TODO: +// // Lazily import the dynamic imports so we don't have to turn `createClient` in to an async fn +// const dynamicImports = (async () => { +// // Import dynamic dependencies +// const [WebSocket, EventSource] = await Promise.all([ +// importWebSocket(), +// importEventSource(), +// ]); +// return { +// WebSocket, +// EventSource, +// }; +// })(); + +export class RemoteManagerDriver implements ManagerDriver { + #config: ClientConfig; + + constructor(runConfig: ClientConfig) { + this.#config = runConfig; + } + + async getForId({ + c, + name, + actorId, + }: GetForIdInput): Promise { + // Fetch from API if not in cache + try { + const response = await getActor(this.#config, actorId); + + // Validate name matches + if (response.actor.name !== name) { + logger().debug({ + msg: "actor name mismatch from api", + actorId, + apiName: response.actor.name, + requestedName: name, + }); + return undefined; + } + + const keyRaw = response.actor.key; + invariant(keyRaw, `actor ${actorId} should have key`); + const key = deserializeActorKey(keyRaw); + + return { + actorId, + name, + key, + }; + } catch (error) { + if ( + error instanceof EngineApiError && + (error as EngineApiError).group === "actor" && + (error as EngineApiError).code === "not_found" + ) { + return undefined; + } + throw error; + } + } + + async getWithKey({ + c, + name, + key, + }: GetWithKeyInput): Promise { + logger().debug({ msg: "getWithKey: searching for actor", name, key }); + + // If not in local cache, fetch by key from API + try { + const response = await getActorById(this.#config, name, key); + + if (!response.actor_id) { + return undefined; + } + + const actorId = response.actor_id; + + logger().debug({ + msg: "getWithKey: found actor via api", + actorId, + name, + key, + }); + + return { + actorId, + name, + key, + }; + } catch (error) { + if ( + error instanceof EngineApiError && + (error as EngineApiError).group === "actor" && + (error as EngineApiError).code === "not_found" + ) { + return undefined; + } + throw error; + } + } + + async getOrCreateWithKey( + input: GetOrCreateWithKeyInput, + ): Promise { + const { c, name, key, input: actorInput, region } = input; + + logger().info({ + msg: "getOrCreateWithKey: getting or creating actor via engine api", + name, + key, + }); + + const response = await getOrCreateActorById(this.#config, { + name, + key: serializeActorKey(key), + runner_name_selector: this.#config.runnerName, + input: input ? cbor.encode(actorInput).toString("base64") : undefined, + crash_policy: "sleep", + }); + + const actorId = response.actor_id; + + logger().info({ + msg: "getOrCreateWithKey: actor ready", + actorId, + name, + key, + created: response.created, + }); + + return { + actorId, + name, + key, + }; + } + + async createActor({ + c, + name, + key, + input, + }: CreateInput): Promise { + logger().info({ msg: "creating actor via engine api", name, key }); + + // Create actor via engine API + const result = await createActor(this.#config, { + name, + runner_name_selector: this.#config.runnerName, + key: serializeActorKey(key), + input: input ? cbor.encode(input).toString("base64") : null, + crash_policy: "sleep", + }); + const actorId = result.actor.actor_id; + + logger().info({ msg: "actor created", actorId, name, key }); + + return { + actorId, + name, + key, + }; + } + + async destroyActor(actorId: string): Promise { + logger().info({ msg: "destroying actor via engine api", actorId }); + + await destroyActor(this.#config, actorId); + + logger().info({ msg: "actor destroyed", actorId }); + } + + async sendRequest(actorId: string, actorRequest: Request): Promise { + return await sendHttpRequestToActor(this.#config, actorId, actorRequest); + } + + async openWebSocket( + path: string, + actorId: string, + encoding: Encoding, + params: unknown, + ): Promise { + return await openWebSocketToActor( + this.#config, + path, + actorId, + encoding, + params, + ); + } + + async proxyRequest( + _c: HonoContext, + actorRequest: Request, + actorId: string, + ): Promise { + return await sendHttpRequestToActor(this.#config, actorId, actorRequest); + } + + async proxyWebSocket( + c: HonoContext, + path: string, + actorId: string, + encoding: Encoding, + params: unknown, + authData: unknown, + ): Promise { + const upgradeWebSocket = this.#config.getUpgradeWebSocket?.(); + invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); + + const endpoint = getEndpoint(this.#config); + const guardUrl = `${endpoint}${path}`; + const wsGuardUrl = guardUrl.replace("http://", "ws://"); + + logger().debug({ + msg: "forwarding websocket to actor via guard", + actorId, + path, + guardUrl, + }); + + // Build headers + const headers = buildGuardHeadersForWebSocket( + actorId, + encoding, + params, + authData, + ); + const args = await createWebSocketProxy(c, wsGuardUrl, headers); + + return await upgradeWebSocket(() => args)(c, noopNext()); + } + + displayInformation(): ManagerDisplayInformation { + return { name: "Remote", properties: {} }; + } +} diff --git a/packages/rivetkit/src/drivers/engine/ws-proxy.ts b/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts similarity index 97% rename from packages/rivetkit/src/drivers/engine/ws-proxy.ts rename to packages/rivetkit/src/remote-manager-driver/ws-proxy.ts index e0a551894..c631e1d5a 100644 --- a/packages/rivetkit/src/drivers/engine/ws-proxy.ts +++ b/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts @@ -1,8 +1,6 @@ import type { Context as HonoContext } from "hono"; import type { WSContext } from "hono/ws"; -import invariant from "invariant"; -import type { CloseEvent } from "ws"; -import { deconstructError, stringifyError } from "@/common/utils"; +import { stringifyError } from "@/common/utils"; import { importWebSocket } from "@/common/websocket"; import type { UpgradeWebSocketArgs } from "@/mod"; import { logger } from "./log"; diff --git a/packages/rivetkit/src/serde.ts b/packages/rivetkit/src/serde.ts index 73d4b3393..79d50fe0e 100644 --- a/packages/rivetkit/src/serde.ts +++ b/packages/rivetkit/src/serde.ts @@ -40,13 +40,16 @@ export function wsBinaryTypeForEncoding( export function serializeWithEncoding( encoding: Encoding, value: T, - versionedDataHandler: VersionedDataHandler, + versionedDataHandler: VersionedDataHandler | undefined, ): Uint8Array | string { if (encoding === "json") { return jsonStringifyCompat(value); } else if (encoding === "cbor") { return cbor.encode(value); } else if (encoding === "bare") { + if (!versionedDataHandler) { + throw new Error("VersionedDataHandler is required for 'bare' encoding"); + } return versionedDataHandler.serializeWithEmbeddedVersion(value); } else { assertUnreachable(encoding); @@ -56,7 +59,7 @@ export function serializeWithEncoding( export function deserializeWithEncoding( encoding: Encoding, buffer: Uint8Array | string, - versionedDataHandler: VersionedDataHandler, + versionedDataHandler: VersionedDataHandler | undefined, ): T { if (encoding === "json") { if (typeof buffer === "string") { @@ -77,6 +80,9 @@ export function deserializeWithEncoding( typeof buffer !== "string", "buffer cannot be string for bare encoding", ); + if (!versionedDataHandler) { + throw new Error("VersionedDataHandler is required for 'bare' encoding"); + } return versionedDataHandler.deserializeWithEmbeddedVersion(buffer); } else { assertUnreachable(encoding); diff --git a/packages/rivetkit/src/test/mod.ts b/packages/rivetkit/src/test/mod.ts index b99c45fe8..a6f9ec22d 100644 --- a/packages/rivetkit/src/test/mod.ts +++ b/packages/rivetkit/src/test/mod.ts @@ -5,7 +5,6 @@ import { type TestContext, vi } from "vitest"; import { type Client, createClient } from "@/client/mod"; import { chooseDefaultDriver } from "@/drivers/default"; import { createFileSystemOrMemoryDriver } from "@/drivers/file-system/mod"; -import { createInlineClientDriver } from "@/inline-client-driver/mod"; import { getInspectorUrl } from "@/inspector/utils"; import { createManagerRouter } from "@/manager/router"; import type { Registry } from "@/registry/mod"; @@ -28,11 +27,9 @@ function serve(registry: Registry, inputConfig?: InputConfig): ServerType { const runConfig = RunConfigSchema.parse(inputConfig); const driver = inputConfig.driver ?? createFileSystemOrMemoryDriver(false); const managerDriver = driver.manager(registry.config, config); - const inlineClientDriver = createInlineClientDriver(managerDriver); const { router } = createManagerRouter( registry.config, runConfig, - inlineClientDriver, managerDriver, false, ); @@ -89,18 +86,21 @@ export async function setupTest>( async () => await new Promise((resolve) => server.close(() => resolve())), ); - // Create client - const client = createClient(`http://127.0.0.1:${port}`); - c.onTestFinished(async () => await client.dispose()); - - return { - client, - mockDriver: { - actorDriver: { - setCreateVarsContext: setDriverContextFn, - }, - }, - }; + throw "TODO: Fix engine port"; + + // // TODO: Figure out how to make this the correct endpoint + // // Create client + // const client = createClient(`http://127.0.0.1:${port}`); + // c.onTestFinished(async () => await client.dispose()); + // + // return { + // client, + // mockDriver: { + // actorDriver: { + // setCreateVarsContext: setDriverContextFn, + // }, + // }, + // }; } export async function getPort(): Promise { diff --git a/packages/rivetkit/tests/actor-types.test.ts b/packages/rivetkit/tests/actor-types.test.ts index 1639ca277..cf9f04312 100644 --- a/packages/rivetkit/tests/actor-types.test.ts +++ b/packages/rivetkit/tests/actor-types.test.ts @@ -26,10 +26,6 @@ describe("ActorDefinition", () => { bar: string; } - interface TestAuthData { - baz: string; - } - interface TestDatabase { createClient: (ctx: { getDatabase: () => Promise; @@ -46,7 +42,6 @@ describe("ActorDefinition", () => { TestConnState, TestVars, TestInput, - TestAuthData, TestDatabase, TestActions >; @@ -59,7 +54,6 @@ describe("ActorDefinition", () => { TestConnState, TestVars, TestInput, - TestAuthData, TestDatabase > >(); @@ -76,7 +70,6 @@ describe("ActorDefinition", () => { TestConnState, TestVars, TestInput, - TestAuthData, TestDatabase > >(); diff --git a/packages/rivetkit/tests/driver-engine.test.ts b/packages/rivetkit/tests/driver-engine.test.ts index 679c063cc..a1d173cc9 100644 --- a/packages/rivetkit/tests/driver-engine.test.ts +++ b/packages/rivetkit/tests/driver-engine.test.ts @@ -2,7 +2,6 @@ import { join } from "node:path"; import { createClientWithDriver } from "@/client/client"; import { createTestRuntime, runDriverTests } from "@/driver-test-suite/mod"; import { createEngineDriver } from "@/drivers/engine/mod"; -import { createInlineClientDriver } from "@/inline-client-driver/mod"; import { RunConfigSchema } from "@/registry/run-config"; import { getPort } from "@/test/mod"; @@ -20,6 +19,7 @@ runDriverTests({ // Get configuration from environment or use defaults const endpoint = process.env.RIVET_ENDPOINT || "http://localhost:6420"; const namespace = `test-${crypto.randomUUID().slice(0, 8)}`; + const runnerName = "test-runner"; // Create namespace const response = await fetch(`${endpoint}/namespaces`, { @@ -40,7 +40,7 @@ runDriverTests({ const driverConfig = createEngineDriver({ endpoint, namespace, - runnerName: "test-runner", + runnerName, totalSlots: 1000, }); @@ -50,8 +50,7 @@ runDriverTests({ getUpgradeWebSocket: () => undefined, }); const managerDriver = driverConfig.manager(registry.config, runConfig); - const inlineClientDriver = createInlineClientDriver(managerDriver); - const inlineClient = createClientWithDriver(inlineClientDriver); + const inlineClient = createClientWithDriver(managerDriver, runConfig); const actorDriver = driverConfig.actor( registry.config, runConfig, @@ -60,6 +59,11 @@ runDriverTests({ ); return { + rivetEngine: { + endpoint: "http://127.0.0.1:6420", + namespace: namespace, + runnerName: runnerName, + }, driver: driverConfig, cleanup: async () => { await actorDriver.shutdown?.(true); diff --git a/vitest.base.ts b/vitest.base.ts index 6dee536c5..2cd4510fb 100644 --- a/vitest.base.ts +++ b/vitest.base.ts @@ -7,7 +7,7 @@ export default { // Enable parallelism sequence: { // TODO: This breaks fake timers, unsure how to make tests run in parallel within the same file - concurrent: true, + // concurrent: true, }, env: { // Enable logging From a634459a18391a42620eeef7666093f7fbda2e49 Mon Sep 17 00:00:00 2001 From: Kacper Wojciechowski <39823706+jog1t@users.noreply.github.com> Date: Tue, 16 Sep 2025 05:07:52 +0200 Subject: [PATCH 2/3] fix(inspector): make inspector work again (#1228) --- packages/rivetkit/scripts/dump-openapi.ts | 3 ++ packages/rivetkit/src/actor/router.ts | 21 +++++++----- packages/rivetkit/src/inspector/config.ts | 10 ++++-- packages/rivetkit/src/manager/router.ts | 42 ++++++++++++++++++----- packages/rivetkit/src/registry/mod.ts | 10 ++---- 5 files changed, 59 insertions(+), 27 deletions(-) diff --git a/packages/rivetkit/scripts/dump-openapi.ts b/packages/rivetkit/scripts/dump-openapi.ts index b582e50d6..4022e0729 100644 --- a/packages/rivetkit/scripts/dump-openapi.ts +++ b/packages/rivetkit/scripts/dump-openapi.ts @@ -16,6 +16,9 @@ function main() { const driverConfig: RunConfig = RunConfigSchema.parse({ driver: createFileSystemOrMemoryDriver(false), getUpgradeWebSocket: () => () => unimplemented(), + inspector: { + enabled: false, + }, }); const managerDriver: ManagerDriver = { diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 3802215f8..53025ba59 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -1,4 +1,5 @@ import { Hono, type Context as HonoContext } from "hono"; +import { cors } from "hono/cors"; import invariant from "invariant"; import { EncodingSchema } from "@/actor/protocol/serde"; import { @@ -241,14 +242,18 @@ export function createActorRouter( router.route( "/inspect", new Hono() - .use(secureInspector(runConfig), async (c, next) => { - const inspector = (await actorDriver.loadActor(c.env.actorId)) - .inspector; - invariant(inspector, "inspector not supported on this platform"); - - c.set("inspector", inspector); - await next(); - }) + .use( + cors(runConfig.inspector.cors), + secureInspector(runConfig), + async (c, next) => { + const inspector = (await actorDriver.loadActor(c.env.actorId)) + .inspector; + invariant(inspector, "inspector not supported on this platform"); + + c.set("inspector", inspector); + return next(); + }, + ) .route("/", createActorInspectorRouter()), ); } diff --git a/packages/rivetkit/src/inspector/config.ts b/packages/rivetkit/src/inspector/config.ts index 88cd99cf5..c612f498d 100644 --- a/packages/rivetkit/src/inspector/config.ts +++ b/packages/rivetkit/src/inspector/config.ts @@ -24,6 +24,7 @@ const defaultEnabled = () => { const defaultInspectorOrigins = [ "http://localhost:43708", + "http://localhost:43709", "https://studio.rivet.gg", ]; @@ -40,10 +41,13 @@ const defaultCors: CorsOptions = { }, allowMethods: ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], allowHeaders: [ - "Content-Type", "Authorization", - HEADER_ACTOR_QUERY, - "last-event-id", + "Content-Type", + "User-Agent", + "baggage", + "sentry-trace", + "x-rivet-actor", + "x-rivet-target", ], maxAge: 3600, credentials: true, diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index 4abc76876..ee00fd1a0 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -1,14 +1,14 @@ import { createRoute, OpenAPIHono } from "@hono/zod-openapi"; import * as cbor from "cbor-x"; -import type { Hono } from "hono"; -import { cors } from "hono/cors"; +import { Hono } from "hono"; +import { cors as corsMiddleware } from "hono/cors"; +import { createMiddleware } from "hono/factory"; import { z } from "zod"; import { - ActorError, ActorNotFound, FeatureNotImplemented, MissingActorHeader, - RouteNotFound, + Unsupported, WebSocketsNotEnabled, } from "@/actor/errors"; import { @@ -16,6 +16,8 @@ import { handleRouteNotFound, loggerMiddleware, } from "@/common/router"; +import { createManagerInspectorRouter } from "@/inspector/manager"; +import { secureInspector } from "@/inspector/utils"; import { type ActorsCreateRequest, ActorsCreateRequestSchema, @@ -68,12 +70,12 @@ export function createManagerRouter( router.use("*", loggerMiddleware(logger())); - if (runConfig.cors) { - router.use("*", cors(runConfig.cors)); - } + const cors = runConfig.cors + ? corsMiddleware(runConfig.cors) + : createMiddleware((_c, next) => next()); // Actor proxy middleware - intercept requests with x-rivet-target=actor - router.use("*", async (c, next) => { + router.use("*", cors, async (c, next) => { const target = c.req.header("x-rivet-target"); const actorId = c.req.header("x-rivet-actor"); @@ -135,7 +137,7 @@ export function createManagerRouter( }); // GET / - router.get("/", (c) => { + router.get("/", cors, (c) => { return c.text( "This is a RivetKit server.\n\nLearn more at https://rivetkit.org", ); @@ -144,6 +146,7 @@ export function createManagerRouter( // GET /actors/by-id { const route = createRoute({ + middleware: [cors], method: "get", path: "/actors/by-id", request: { @@ -177,6 +180,7 @@ export function createManagerRouter( // PUT /actors/by-id { const route = createRoute({ + cors: [cors], method: "put", path: "/actors/by-id", request: { @@ -241,6 +245,7 @@ export function createManagerRouter( // GET /actors/{actor_id} { const route = createRoute({ + middleware: [cors], method: "get", path: "/actors/{actor_id}", request: { @@ -287,6 +292,7 @@ export function createManagerRouter( // POST /actors { const route = createRoute({ + middleware: [cors], method: "post", path: "/actors", request: { @@ -346,6 +352,7 @@ export function createManagerRouter( // DELETE /actors/{actor_id} { const route = createRoute({ + middleware: [cors], method: "delete", path: "/actors/{actor_id}", request: { @@ -370,6 +377,23 @@ export function createManagerRouter( }); } + if (runConfig.inspector?.enabled) { + if (!managerDriver.inspector) { + throw new Unsupported("inspector"); + } + router.route( + "/inspect", + new Hono<{ Variables: { inspector: any } }>() + .use(corsMiddleware(runConfig.inspector.cors)) + .use(secureInspector(runConfig)) + .use((c, next) => { + c.set("inspector", managerDriver.inspector!); + return next(); + }) + .route("/", createManagerInspectorRouter()), + ); + } + // Error handling router.notFound(handleRouteNotFound); router.onError(handleRouteError); diff --git a/packages/rivetkit/src/registry/mod.ts b/packages/rivetkit/src/registry/mod.ts index 895c61157..6ee7f1375 100644 --- a/packages/rivetkit/src/registry/mod.ts +++ b/packages/rivetkit/src/registry/mod.ts @@ -1,10 +1,6 @@ import type { Hono } from "hono"; import { type Client, createClientWithDriver } from "@/client/client"; -import { - configureBaseLogger, - configureDefaultLogger, - getPinoLevel, -} from "@/common/log"; +import { configureBaseLogger, configureDefaultLogger } from "@/common/log"; import { chooseDefaultDriver } from "@/drivers/default"; import { getInspectorUrl } from "@/inspector/utils"; import { createManagerRouter } from "@/manager/router"; @@ -86,7 +82,7 @@ export class Registry { definitions: Object.keys(this.#config.use).length, ...driverLog, }); - if (config.inspector?.enabled) { + if (config.inspector?.enabled && managerDriver.inspector) { logger().info({ msg: "inspector ready", url: getInspectorUrl(config) }); } @@ -100,7 +96,7 @@ export class Registry { const padding = " ".repeat(Math.max(0, 13 - k.length)); console.log(` - ${k}:${padding}${v}`); } - if (config.inspector?.enabled) { + if (config.inspector?.enabled && managerDriver.inspector) { console.log(` - Inspector: ${getInspectorUrl(config)}`); } console.log(); From 30c5c6d5997f2d167d6ab0386f7fbe9cb8ee2e55 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 15 Sep 2025 20:15:14 -0700 Subject: [PATCH 3/3] fix(core): fix websockets (#1229) * fix(inspector): make inspector work again * fix(core): fix websockets --------- Co-authored-by: Kacper Wojciechowski <39823706+jog1t@users.noreply.github.com> --- packages/rivetkit/src/manager/router.ts | 13 ++++++++++--- .../remote-manager-driver/actor-websocket-client.ts | 3 +++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index ee00fd1a0..87697f74d 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -100,9 +100,16 @@ export function createManagerRouter( // For WebSocket, use the driver's proxyWebSocket method // Extract any additional headers that might be needed - const encoding = c.req.header("x-rivet-encoding") || "json"; - const connParams = c.req.header("x-rivet-conn-params"); - const authData = c.req.header("x-rivet-auth-data"); + const encoding = + c.req.header("X-RivetKit-Encoding") || + c.req.header("x-rivet-encoding") || + "json"; + const connParams = + c.req.header("X-RivetKit-Conn-Params") || + c.req.header("x-rivet-conn-params"); + const authData = + c.req.header("X-RivetKit-Auth-Data") || + c.req.header("x-rivet-auth-data"); return await managerDriver.proxyWebSocket( c, diff --git a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts index abb7e9db0..3a08593cb 100644 --- a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts +++ b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -34,6 +34,9 @@ export async function openWebSocketToActor( headers: buildGuardHeadersForWebSocket(actorId, encoding, params), }); + // Set binary type to arraybuffer for proper encoding support + ws.binaryType = "arraybuffer"; + logger().debug({ msg: "websocket connection opened", actorId }); return ws as UniversalWebSocket;