From 4011e0e98292828217b47ef11426b7819c1f9db4 Mon Sep 17 00:00:00 2001 From: Ben Reinhart Date: Wed, 25 Sep 2024 15:52:38 -0700 Subject: [PATCH 1/3] Update WS server client to pass in topic context and connection --- packages/api/server/ws-client.mts | 181 +++++++++++------ packages/api/server/ws.mts | 182 ++++++++---------- packages/shared/src/schemas/websockets.mts | 26 +-- packages/web/src/components/cells/code.tsx | 14 +- .../src/components/cells/get-completions.ts | 2 - packages/web/src/components/cells/hover.ts | 8 +- .../web/src/components/use-package-json.tsx | 5 +- .../web/src/components/use-tsconfig-json.tsx | 1 - packages/web/src/routes/session.tsx | 12 +- 9 files changed, 221 insertions(+), 210 deletions(-) diff --git a/packages/api/server/ws-client.mts b/packages/api/server/ws-client.mts index 396f65b3..d4d44d45 100644 --- a/packages/api/server/ws-client.mts +++ b/packages/api/server/ws-client.mts @@ -3,72 +3,135 @@ import z from 'zod'; import { type RawData, WebSocket } from 'ws'; import { WebSocketMessageSchema } from '@srcbook/shared'; -const VALID_TOPIC_RE = /^[a-zA-Z0-9_:]+$/; +type TopicPart = { dynamic: false; segment: string } | { dynamic: true; parameter: string }; + +export type MessageContextType = { + topic: string; + event: string; + params: Record; +}; + +type TopicMatch = Pick; + +export interface ConnectionContextType { + reply: (topic: string, event: string, payload: Record) => void; +} /** - * Channel is responsible for dispatching incoming and outgoing messages for a given topic. + * Channel is responsible for dispatching incoming messages for a given topic. + * + * Topics are strings that represent a channel for messages. Topics + * can be broken into multiple parts separated by a colon. The following + * are all examples of valid topics: + * + * - session + * - session:123 + * - room:123:users:456:messages + * + * When we define a topic, we can use the `` syntax to indicate a + * wildcard match. For example, the topic `room::messages` would match + * `room:123:messages`, `room:456:messages`, etc. + * + * The wildcard syntax must be between two colons (or at the start/end of the string). + * The text inside must be a valid JavaScript identifier. * * Examples: * - * const channel = new Channel("session") // matches "session" only - * const channel = new Channel("session:*") // matches "session:123", "session:456", etc. + * const channel = new Channel("session") // matches "session" only + * const channel = new Channel("session:") // matches "session:123", "session:456", etc. * */ export class Channel { + // The topic pattern, e.g. "sessions:" readonly topic: string; - readonly events: { - incoming: Record< - string, - { schema: z.ZodTypeAny; handler: (payload: Record) => void } - >; - outgoing: Record; - } = { incoming: {}, outgoing: {} }; - - private wildcardMatch = false; + // The parts of the topic string, e.g. "sessions" and "" for "sessions:" + private readonly parts: TopicPart[]; + + readonly events: Record< + string, + { + schema: z.ZodTypeAny; + handler: ( + payload: Record, + context: MessageContextType, + conn: ConnectionContextType, + ) => void; + } + > = {}; constructor(topic: string) { - if (topic.endsWith(':*')) { - // Remove asterisk from topic - topic = topic.slice(0, -1); - this.wildcardMatch = true; - } + this.topic = topic; + this.parts = this.splitIntoParts(topic); + } + + private splitIntoParts(topic: string) { + const parts: TopicPart[] = []; + + for (const part of topic.split(':')) { + const parameter = part.match(/^<([a-zA-Z_]+[a-zA-Z0-9_]*)>$/); - if (!VALID_TOPIC_RE.test(topic)) { - throw new Error(`Invalid channel topic '${topic}'`); + if (parameter !== null) { + parts.push({ dynamic: true, parameter: parameter[1] as string }); + continue; + } + + if (/^[a-zA-Z0-9_]+$/.test(part)) { + parts.push({ dynamic: false, segment: part }); + continue; + } + + throw new Error(`Invalid channel topic: ${topic}`); } - this.topic = topic; + return parts; } - matches(topic: string) { - if (topic === this.topic) { - return true; + match(topic: string): TopicMatch | null { + const parts = topic.split(':'); + + if (parts.length !== this.parts.length) { + return null; } - if (this.wildcardMatch) { - return topic.startsWith(this.topic) && topic.length > this.topic.length; + const match: TopicMatch = { + topic: topic, + params: {}, + }; + + for (let i = 0, len = this.parts.length; i < len; i++) { + const thisPart = this.parts[i] as TopicPart; + + if (thisPart.dynamic) { + match.params[thisPart.parameter] = parts[i] as string; + continue; + } else if (thisPart.segment === parts[i]) { + continue; + } + + return null; } - return false; + return match; } - incoming( + on( event: string, schema: T, - handler: (payload: z.infer) => void, + handler: ( + payload: z.infer, + context: MessageContextType, + conn: ConnectionContextType, + ) => void, ) { - this.events.incoming[event] = { schema, handler }; - return this; - } - - outgoing(event: string, schema: T) { - this.events.outgoing[event] = schema; + this.events[event] = { schema, handler }; return this; } } type ConnectionType = { + // Reply only to this connection, not to all connections. + reply: (topic: string, event: string, payload: Record) => void; socket: WebSocket; subscriptions: string[]; }; @@ -90,7 +153,13 @@ export default class WebSocketServer { return; } - const connection = { socket, subscriptions: [] }; + const connection = { + socket, + subscriptions: [], + reply: (topic: string, event: string, payload: Record) => { + this.send(connection, topic, event, payload); + }, + }; this.connections.push(connection); @@ -115,23 +184,9 @@ export default class WebSocketServer { } broadcast(topic: string, event: string, payload: Record) { - const channel = this.findChannel(topic); - - if (channel === undefined) { - throw new Error(`Cannot broadcast to unknown topic '${topic}'`); - } - - const schema = channel.events.outgoing[event]; - - if (schema === undefined) { - throw new Error(`Cannot broadcast to unknown event '${event}'`); - } - - const validatedPayload = schema.parse(payload); - for (const conn of this.connections) { if (conn.subscriptions.includes(topic)) { - conn.socket.send(JSON.stringify([topic, event, validatedPayload])); + this.send(conn, topic, event, payload); } } } @@ -140,9 +195,9 @@ export default class WebSocketServer { const parsed = JSON.parse(message.toString('utf8')); const [topic, event, payload] = WebSocketMessageSchema.parse(parsed); - const channel = this.findChannel(topic); + const channelMatch = this.findChannelMatch(topic); - if (channel === undefined) { + if (channelMatch === null) { console.warn(`Server received unknown topic '${topic}'`); return; } @@ -157,7 +212,9 @@ export default class WebSocketServer { return; } - const registeredEvent = channel.events.incoming[event]; + const { channel, match } = channelMatch; + + const registeredEvent = channel.events[event]; if (registeredEvent === undefined) { console.warn(`Server received unknown event '${event}' for topic '${topic}'`); @@ -175,15 +232,19 @@ export default class WebSocketServer { return; } - handler(result.data); + handler(result.data, { topic: match.topic, event: event, params: match.params }, conn); } - private findChannel(topic: string) { + private findChannelMatch(topic: string): { channel: Channel; match: TopicMatch } | null { for (const channel of this.channels) { - if (channel.matches(topic)) { - return channel; + const match = channel.match(topic); + + if (match !== null) { + return { channel, match }; } } + + return null; } private removeConnection(socket: WebSocket) { @@ -191,4 +252,8 @@ export default class WebSocketServer { return conn.socket !== socket; }); } + + private send(conn: ConnectionType, topic: string, event: string, payload: Record) { + conn.socket.send(JSON.stringify([topic, event, payload])); + } } diff --git a/packages/api/server/ws.mts b/packages/api/server/ws.mts index b84c37aa..0c303289 100644 --- a/packages/api/server/ws.mts +++ b/packages/api/server/ws.mts @@ -40,42 +40,32 @@ import type { TsServerDefinitionLocationRequestPayloadType, } from '@srcbook/shared'; import { - CellErrorPayloadSchema, CellUpdatePayloadSchema, - CellUpdatedPayloadSchema, CellRenamePayloadSchema, CellDeletePayloadSchema, CellFormatPayloadSchema, CellExecPayloadSchema, CellStopPayloadSchema, AiGenerateCellPayloadSchema, - AiGeneratedCellPayloadSchema, AiFixDiagnosticsPayloadSchema, DepsInstallPayloadSchema, DepsValidatePayloadSchema, - CellOutputPayloadSchema, - DepsValidateResponsePayloadSchema, TsServerStartPayloadSchema, TsServerStopPayloadSchema, - TsServerCellDiagnosticsPayloadSchema, CellCreatePayloadSchema, TsConfigUpdatePayloadSchema, - TsConfigUpdatedPayloadSchema, - TsServerCellSuggestionsPayloadSchema, TsServerQuickInfoRequestPayloadSchema, - TsServerQuickInfoResponsePayloadSchema, - CellFormattedPayloadSchema, TsServerDefinitionLocationRequestPayloadSchema, - TsServerDefinitionLocationResponsePayloadSchema, - TsServerCompletionEntriesPayloadSchema, } from '@srcbook/shared'; import tsservers from '../tsservers.mjs'; import { TsServer } from '../tsserver/tsserver.mjs'; -import WebSocketServer from './ws-client.mjs'; +import WebSocketServer, { MessageContextType } from './ws-client.mjs'; import { filenameFromPath, pathToCodeFile } from '../srcbook/path.mjs'; import { normalizeDiagnostic } from '../tsserver/utils.mjs'; import { removeCodeCellFromDisk } from '../srcbook/index.mjs'; +type SessionsContextType = MessageContextType<'sessionId'>; + const wss = new WebSocketServer(); function addRunningProcess( @@ -120,8 +110,8 @@ async function nudgeMissingDeps(wss: WebSocketServer, session: SessionType) { } } -async function cellExec(payload: CellExecPayloadType) { - const session = await findSession(payload.sessionId); +async function cellExec(payload: CellExecPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); const cell = findCell(session, payload.cellId); if (!cell || cell.type !== 'code') { console.error(`Cannot execute cell with id ${payload.cellId}; cell not found.`); @@ -229,8 +219,8 @@ async function tsxExec({ session, cell, secrets }: ExecRequestType) { ); } -async function depsInstall(payload: DepsInstallPayloadType) { - const session = await findSession(payload.sessionId); +async function depsInstall(payload: DepsInstallPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); const cell = session.cells.find( (cell) => cell.type === 'package.json', ) as PackageJsonCellType | void; @@ -306,13 +296,13 @@ async function depsInstall(payload: DepsInstallPayloadType) { ); } -async function depsValidate(payload: DepsValidatePayloadType) { - const session = await findSession(payload.sessionId); +async function depsValidate(_payload: DepsValidatePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); nudgeMissingDeps(wss, session); } -async function cellStop(payload: CellStopPayloadType) { - const session = await findSession(payload.sessionId); +async function cellStop(payload: CellStopPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); const cell = findCell(session, payload.cellId); if (!cell || cell.type !== 'code') { @@ -337,11 +327,11 @@ async function cellStop(payload: CellStopPayloadType) { } } -async function cellCreate(payload: CellCreatePayloadType) { - const session = await findSession(payload.sessionId); +async function cellCreate(payload: CellCreatePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } const { index, cell } = payload; @@ -392,8 +382,8 @@ function reopenFileInTsServer( tsserver.open({ file: openFilePath, fileContent: file.source }); } -async function cellGenerate(payload: AiGenerateCellPayloadType) { - const session = await findSession(payload.sessionId); +async function cellGenerate(payload: AiGenerateCellPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); const cell = session.cells.find((cell) => cell.id === payload.cellId) as CodeCellType; posthog.capture({ @@ -412,8 +402,11 @@ async function cellGenerate(payload: AiGenerateCellPayloadType) { }); } -async function cellFixDiagnostics(payload: AiFixDiagnosticsPayloadType) { - const session = await findSession(payload.sessionId); +async function cellFixDiagnostics( + payload: AiFixDiagnosticsPayloadType, + context: SessionsContextType, +) { + const session = await findSession(context.params.sessionId); const cell = findCell(session, payload.cellId) as CodeCellType; const result = await fixDiagnostics(session, cell, payload.diagnostics); @@ -424,16 +417,16 @@ async function cellFixDiagnostics(payload: AiFixDiagnosticsPayloadType) { }); } -async function cellFormat(payload: CellFormatPayloadType) { - const session = await findSession(payload.sessionId); +async function cellFormat(payload: CellFormatPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } const cellBeforeUpdate = findCell(session, payload.cellId); if (!cellBeforeUpdate || cellBeforeUpdate.type !== 'code') { throw new Error( - `No cell exists or not a code cell for session '${payload.sessionId}' and cell '${payload.cellId}'`, + `No cell exists or not a code cell for session '${context.params.sessionId}' and cell '${payload.cellId}'`, ); } const result = await formatAndUpdateCodeCell(session, cellBeforeUpdate); @@ -461,18 +454,18 @@ async function cellFormat(payload: CellFormatPayloadType) { } } -async function cellUpdate(payload: CellUpdatePayloadType) { - const session = await findSession(payload.sessionId); +async function cellUpdate(payload: CellUpdatePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } const cellBeforeUpdate = findCell(session, payload.cellId); if (!cellBeforeUpdate) { throw new Error( - `No cell exists for session '${payload.sessionId}' and cell '${payload.cellId}'`, + `No cell exists for session '${context.params.sessionId}' and cell '${payload.cellId}'`, ); } const result = await updateCell(session, cellBeforeUpdate, payload.updates); @@ -486,18 +479,18 @@ async function cellUpdate(payload: CellUpdatePayloadType) { refreshCodeCellDiagnostics(session, cell); } -async function cellRename(payload: CellRenamePayloadType) { - const session = await findSession(payload.sessionId); +async function cellRename(payload: CellRenamePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } const cellBeforeUpdate = findCell(session, payload.cellId); if (!cellBeforeUpdate) { throw new Error( - `No cell exists for session '${payload.sessionId}' and cell '${payload.cellId}'`, + `No cell exists for session '${context.params.sessionId}' and cell '${payload.cellId}'`, ); } @@ -559,18 +552,18 @@ async function cellRename(payload: CellRenamePayloadType) { } } -async function cellDelete(payload: CellDeletePayloadType) { - const session = await findSession(payload.sessionId); +async function cellDelete(payload: CellDeletePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } const cell = findCell(session, payload.cellId); if (!cell) { throw new Error( - `No cell exists for session '${payload.sessionId}' and cell '${payload.cellId}'`, + `No cell exists for session '${context.params.sessionId}' and cell '${payload.cellId}'`, ); } @@ -682,11 +675,11 @@ function createTsServer(session: SessionType) { return tsserver; } -async function tsserverStart(payload: TsServerStartPayloadType) { - const session = await findSession(payload.sessionId); +async function tsserverStart(_payload: TsServerStartPayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } if (session.language !== 'typescript') { @@ -699,15 +692,15 @@ async function tsserverStart(payload: TsServerStartPayloadType) { ); } -async function tsserverStop(payload: TsServerStopPayloadType) { - tsservers.shutdown(payload.sessionId); +async function tsserverStop(_payload: TsServerStopPayloadType, context: SessionsContextType) { + tsservers.shutdown(context.params.sessionId); } -async function tsconfigUpdate(payload: TsConfigUpdatePayloadType) { - const session = await findSession(payload.sessionId); +async function tsconfigUpdate(payload: TsConfigUpdatePayloadType, context: SessionsContextType) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } posthog.capture({ event: 'user updated tsconfig' }); @@ -725,11 +718,14 @@ async function tsconfigUpdate(payload: TsConfigUpdatePayloadType) { }); } -async function tsserverQuickInfo(payload: TsServerQuickInfoRequestPayloadType) { - const session = await findSession(payload.sessionId); +async function tsserverQuickInfo( + payload: TsServerQuickInfoRequestPayloadType, + context: SessionsContextType, +) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } if (session.language !== 'typescript') { @@ -764,11 +760,14 @@ async function tsserverQuickInfo(payload: TsServerQuickInfoRequestPayloadType) { }); } -async function getCompletions(payload: TsServerDefinitionLocationRequestPayloadType) { - const session = await findSession(payload.sessionId); +async function getCompletions( + payload: TsServerDefinitionLocationRequestPayloadType, + context: SessionsContextType, +) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } if (session.language !== 'typescript') { @@ -798,11 +797,14 @@ async function getCompletions(payload: TsServerDefinitionLocationRequestPayloadT }); } -async function getDefinitionLocation(payload: TsServerDefinitionLocationRequestPayloadType) { - const session = await findSession(payload.sessionId); +async function getDefinitionLocation( + payload: TsServerDefinitionLocationRequestPayloadType, + context: SessionsContextType, +) { + const session = await findSession(context.params.sessionId); if (!session) { - throw new Error(`No session exists for session '${payload.sessionId}'`); + throw new Error(`No session exists for session '${context.params.sessionId}'`); } if (session.language !== 'typescript') { @@ -852,51 +854,33 @@ function refreshCodeCellDiagnostics(session: SessionType, cell: CodeCellType) { requestAllDiagnostics(tsserver, session); } } + wss - .channel('session:*') - .incoming('cell:exec', CellExecPayloadSchema, cellExec) - .incoming('cell:stop', CellStopPayloadSchema, cellStop) - .incoming('cell:create', CellCreatePayloadSchema, cellCreate) - .incoming('cell:update', CellUpdatePayloadSchema, cellUpdate) - .incoming('cell:rename', CellRenamePayloadSchema, cellRename) - .incoming('cell:delete', CellDeletePayloadSchema, cellDelete) - .incoming('cell:format', CellFormatPayloadSchema, cellFormat) - .incoming('ai:generate', AiGenerateCellPayloadSchema, cellGenerate) - .incoming('ai:fix_diagnostics', AiFixDiagnosticsPayloadSchema, cellFixDiagnostics) - .incoming('deps:install', DepsInstallPayloadSchema, depsInstall) - .incoming('deps:validate', DepsValidatePayloadSchema, depsValidate) - .incoming('tsserver:start', TsServerStartPayloadSchema, tsserverStart) - .incoming('tsserver:stop', TsServerStopPayloadSchema, tsserverStop) - .incoming('tsconfig.json:update', TsConfigUpdatePayloadSchema, tsconfigUpdate) - .incoming( - 'tsserver:cell:quickinfo:request', - TsServerQuickInfoRequestPayloadSchema, - tsserverQuickInfo, - ) - .incoming( + .channel('session:') + .on('cell:exec', CellExecPayloadSchema, cellExec) + .on('cell:stop', CellStopPayloadSchema, cellStop) + .on('cell:create', CellCreatePayloadSchema, cellCreate) + .on('cell:update', CellUpdatePayloadSchema, cellUpdate) + .on('cell:rename', CellRenamePayloadSchema, cellRename) + .on('cell:delete', CellDeletePayloadSchema, cellDelete) + .on('cell:format', CellFormatPayloadSchema, cellFormat) + .on('ai:generate', AiGenerateCellPayloadSchema, cellGenerate) + .on('ai:fix_diagnostics', AiFixDiagnosticsPayloadSchema, cellFixDiagnostics) + .on('deps:install', DepsInstallPayloadSchema, depsInstall) + .on('deps:validate', DepsValidatePayloadSchema, depsValidate) + .on('tsserver:start', TsServerStartPayloadSchema, tsserverStart) + .on('tsserver:stop', TsServerStopPayloadSchema, tsserverStop) + .on('tsconfig.json:update', TsConfigUpdatePayloadSchema, tsconfigUpdate) + .on('tsserver:cell:quickinfo:request', TsServerQuickInfoRequestPayloadSchema, tsserverQuickInfo) + .on( 'tsserver:cell:definition_location:request', TsServerDefinitionLocationRequestPayloadSchema, getDefinitionLocation, ) - .incoming( + .on( 'tsserver:cell:completions:request', TsServerDefinitionLocationRequestPayloadSchema, getCompletions, - ) - .outgoing('tsserver:cell:completions:response', TsServerCompletionEntriesPayloadSchema) - .outgoing('tsserver:cell:quickinfo:response', TsServerQuickInfoResponsePayloadSchema) - .outgoing( - 'tsserver:cell:definition_location:response', - TsServerDefinitionLocationResponsePayloadSchema, - ) - .outgoing('cell:updated', CellUpdatedPayloadSchema) - .outgoing('cell:formatted', CellFormattedPayloadSchema) - .outgoing('cell:error', CellErrorPayloadSchema) - .outgoing('cell:output', CellOutputPayloadSchema) - .outgoing('ai:generated', AiGeneratedCellPayloadSchema) - .outgoing('deps:validate:response', DepsValidateResponsePayloadSchema) - .outgoing('tsserver:cell:diagnostics', TsServerCellDiagnosticsPayloadSchema) - .outgoing('tsserver:cell:suggestions', TsServerCellSuggestionsPayloadSchema) - .outgoing('tsconfig.json:updated', TsConfigUpdatedPayloadSchema); + ); export default wss; diff --git a/packages/shared/src/schemas/websockets.mts b/packages/shared/src/schemas/websockets.mts index 63d55a6a..872cceef 100644 --- a/packages/shared/src/schemas/websockets.mts +++ b/packages/shared/src/schemas/websockets.mts @@ -16,57 +16,47 @@ export const WebSocketMessageSchema = z.tuple([ ]); export const CellExecPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), }); export const CellStopPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), }); export const CellCreatePayloadSchema = z.object({ - sessionId: z.string(), index: z.number(), cell: z.union([MarkdownCellSchema, CodeCellSchema]), }); export const CellUpdatePayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), updates: CellUpdateAttrsSchema, }); export const CellFormatPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), }); export const AiGenerateCellPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), prompt: z.string(), }); export const AiFixDiagnosticsPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), diagnostics: z.string(), }); export const CellRenamePayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), filename: z.string(), }); export const CellDeletePayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), }); export const CellErrorPayloadSchema = z.object({ - sessionId: z.string(), cellId: z.string(), errors: z.array( z.object({ @@ -98,25 +88,18 @@ export const CellOutputPayloadSchema = z.object({ }); export const DepsInstallPayloadSchema = z.object({ - sessionId: z.string(), packages: z.array(z.string()).optional(), }); -export const DepsValidatePayloadSchema = z.object({ - sessionId: z.string(), -}); +export const DepsValidatePayloadSchema = z.object({}); export const DepsValidateResponsePayloadSchema = z.object({ packages: z.array(z.string()).optional(), }); -export const TsServerStartPayloadSchema = z.object({ - sessionId: z.string(), -}); +export const TsServerStartPayloadSchema = z.object({}); -export const TsServerStopPayloadSchema = z.object({ - sessionId: z.string(), -}); +export const TsServerStopPayloadSchema = z.object({}); export const TsServerCellDiagnosticsPayloadSchema = z.object({ cellId: z.string(), @@ -130,7 +113,6 @@ export const TsServerCellSuggestionsPayloadSchema = z.object({ export const TsServerQuickInfoRequestPayloadSchema = z.object({ cellId: z.string(), - sessionId: z.string(), request: TsServerQuickInfoRequestSchema, }); @@ -140,7 +122,6 @@ export const TsServerQuickInfoResponsePayloadSchema = z.object({ export const TsServerDefinitionLocationRequestPayloadSchema = z.object({ cellId: z.string(), - sessionId: z.string(), request: TsServerQuickInfoRequestSchema, }); @@ -153,7 +134,6 @@ export const TsServerCompletionEntriesPayloadSchema = z.object({ }); export const TsConfigUpdatePayloadSchema = z.object({ - sessionId: z.string(), source: z.string(), }); diff --git a/packages/web/src/components/cells/code.tsx b/packages/web/src/components/cells/code.tsx index 2dfdc07d..8f5206f8 100644 --- a/packages/web/src/components/cells/code.tsx +++ b/packages/web/src/components/cells/code.tsx @@ -124,7 +124,7 @@ type ReadOnlyProps = BaseProps & { readOnly: true }; type Props = RegularProps | ReadOnlyProps; export default function ControlledCodeCell(props: Props) { - const { readOnly, session, cell } = props; + const { readOnly, cell } = props; const channel = !readOnly ? props.channel : null; const { theme, codeTheme } = useTheme(); @@ -227,7 +227,6 @@ export default function ControlledCodeCell(props: Props) { updateCellOnClient({ ...cell, filename }); channel.push('cell:rename', { - sessionId: session.id, cellId: cell.id, filename, }); @@ -255,7 +254,6 @@ export default function ControlledCodeCell(props: Props) { setGenerationType('edit'); channel.push('ai:generate', { - sessionId: session.id, cellId: cell.id, prompt, }); @@ -269,7 +267,6 @@ export default function ControlledCodeCell(props: Props) { setCellMode('fixing'); setGenerationType('fix'); channel.push('ai:fix_diagnostics', { - sessionId: session.id, cellId: cell.id, diagnostics, }); @@ -293,7 +290,6 @@ export default function ControlledCodeCell(props: Props) { // TODO: Handle this in a more robust way setTimeout(() => { channel.push('cell:exec', { - sessionId: session.id, cellId: cell.id, }); }, DEBOUNCE_DELAY + 10); @@ -303,7 +299,7 @@ export default function ControlledCodeCell(props: Props) { if (!channel) { return; } - channel.push('cell:stop', { sessionId: session.id, cellId: cell.id }); + channel.push('cell:stop', { cellId: cell.id }); } function onRevertDiff() { @@ -327,7 +323,6 @@ export default function ControlledCodeCell(props: Props) { } setCellMode('formatting'); channel.push('cell:format', { - sessionId: session.id, cellId: cell.id, }); } @@ -364,7 +359,6 @@ export default function ControlledCodeCell(props: Props) { channel.on('tsserver:cell:definition_location:response', gotoDefCallback); channel.push('tsserver:cell:definition_location:request', { - sessionId: session.id, cellId: cell.id, request: { location: mapCMLocationToTsServer(cell.source, pos) }, }); @@ -375,13 +369,13 @@ export default function ControlledCodeCell(props: Props) { // We want the errors to be first, so we call tsLinter before tsHover. const extensions: Array = [javascript({ typescript: true })]; if (channel) { - extensions.push(tsHover(session.id, cell, channel, theme)); + extensions.push(tsHover(cell, channel, theme)); } extensions.push(tsLinter(cell, getTsServerDiagnostics, getTsServerSuggestions)); if (channel) { extensions.push( autocompletion({ - override: [(context) => getCompletions(context, session.id, cell, channel)], + override: [(context) => getCompletions(context, cell, channel)], }), ); } diff --git a/packages/web/src/components/cells/get-completions.ts b/packages/web/src/components/cells/get-completions.ts index ed61fc52..ea1f7a26 100644 --- a/packages/web/src/components/cells/get-completions.ts +++ b/packages/web/src/components/cells/get-completions.ts @@ -5,7 +5,6 @@ import { mapCMLocationToTsServer } from './util'; export function getCompletions( context: CompletionContext, - sessionId: string, cell: CodeCellType, channel: SessionChannel, ): Promise { @@ -43,7 +42,6 @@ export function getCompletions( channel.on('tsserver:cell:completions:response', callback); channel.push('tsserver:cell:completions:request', { - sessionId: sessionId, cellId: cell.id, request: { location: mapCMLocationToTsServer(cell.source, pos), diff --git a/packages/web/src/components/cells/hover.ts b/packages/web/src/components/cells/hover.ts index 2996e598..0d70b0dd 100644 --- a/packages/web/src/components/cells/hover.ts +++ b/packages/web/src/components/cells/hover.ts @@ -12,12 +12,7 @@ import { formatCode } from '@srcbook/components/src/lib/code-theme'; import { type ThemeType } from '@srcbook/components/src/components/use-theme'; /** Hover extension for TS server information */ -export function tsHover( - sessionId: string, - cell: CodeCellType, - channel: SessionChannel, - theme: ThemeType, -): Extension { +export function tsHover(cell: CodeCellType, channel: SessionChannel, theme: ThemeType): Extension { return hoverTooltip(async (view, pos) => { if (cell.language !== 'typescript') { return null; // bail early if not typescript @@ -58,7 +53,6 @@ export function tsHover( mount() { channel.on('tsserver:cell:quickinfo:response', callback); channel.push('tsserver:cell:quickinfo:request', { - sessionId: sessionId, cellId: cell.id, request: { location: mapCMLocationToTsServer(cell.source, pos) }, }); diff --git a/packages/web/src/components/use-package-json.tsx b/packages/web/src/components/use-package-json.tsx index c2a0d161..21c20095 100644 --- a/packages/web/src/components/use-package-json.tsx +++ b/packages/web/src/components/use-package-json.tsx @@ -68,7 +68,7 @@ export function PackageJsonProvider({ channel, session, children }: ProviderProp const [validationError, setValidationError] = useState(null); useEffectOnce(() => { - channel.push('deps:validate', { sessionId: session.id }); + channel.push('deps:validate', {}); }); const npmInstall = useCallback( @@ -79,7 +79,7 @@ export function PackageJsonProvider({ channel, session, children }: ProviderProp updateCellOnClient({ ...cell, status: 'running' }); clearOutput(cell.id); setOutdated(false); - channel.push('deps:install', { sessionId: session.id, packages }); + channel.push('deps:install', { packages }); } }, [cell, channel, session.id, updateCellOnClient, clearOutput], @@ -97,7 +97,6 @@ export function PackageJsonProvider({ channel, session, children }: ProviderProp function updateCellOnServer(updates: PackageJsonCellUpdateAttrsType) { channel.push('cell:update', { - sessionId: session.id, cellId: cell.id, updates, }); diff --git a/packages/web/src/components/use-tsconfig-json.tsx b/packages/web/src/components/use-tsconfig-json.tsx index bf92f584..454fe250 100644 --- a/packages/web/src/components/use-tsconfig-json.tsx +++ b/packages/web/src/components/use-tsconfig-json.tsx @@ -43,7 +43,6 @@ export function TsConfigProvider({ channel, session, children }: ProviderPropsTy if (error === null) { channel.push('tsconfig.json:update', { source, - sessionId: session.id, }); } }, diff --git a/packages/web/src/routes/session.tsx b/packages/web/src/routes/session.tsx index bb630d18..6d7c81c0 100644 --- a/packages/web/src/routes/session.tsx +++ b/packages/web/src/routes/session.tsx @@ -74,7 +74,7 @@ function SessionPage() { oldChannel.unsubscribe(); if (connectedSessionLanguageRef.current === 'typescript') { - oldChannel.push('tsserver:stop', { sessionId: session.id }); + oldChannel.push('tsserver:stop', {}); } } @@ -88,7 +88,7 @@ function SessionPage() { channel.subscribe(); if (session.language === 'typescript') { - channel.push('tsserver:start', { sessionId: session.id }); + channel.push('tsserver:start', {}); } forceComponentRerender(); @@ -170,7 +170,6 @@ function Session(props: SessionProps) { removeCell(cell); channel.push('cell:delete', { - sessionId: session.id, cellId: cell.id, }); } @@ -232,7 +231,6 @@ function Session(props: SessionProps) { return; } channel.push('cell:update', { - sessionId: session.id, cellId: cell.id, updates, }); @@ -250,11 +248,11 @@ function Session(props: SessionProps) { switch (type) { case 'code': cell = createCodeCell(index, session.language); - channel.push('cell:create', { sessionId: session.id, index, cell }); + channel.push('cell:create', { index, cell }); break; case 'markdown': cell = createMarkdownCell(index); - channel.push('cell:create', { sessionId: session.id, index, cell }); + channel.push('cell:create', { index, cell }); break; case 'generate-ai': cell = createGenerateAiCell(index); @@ -280,7 +278,7 @@ function Session(props: SessionProps) { newCell = createMarkdownCell(insertIdx, cell); break; } - channel.push('cell:create', { sessionId: session.id, index: insertIdx, cell: newCell }); + channel.push('cell:create', { index: insertIdx, cell: newCell }); } } From 6649ff19862bb0467075775b81014614182774e8 Mon Sep 17 00:00:00 2001 From: Ben Reinhart Date: Fri, 4 Oct 2024 15:43:07 -0700 Subject: [PATCH 2/3] Fix lint warning --- packages/web/src/components/use-package-json.tsx | 6 ++---- packages/web/src/components/use-tsconfig-json.tsx | 2 +- packages/web/src/routes/session.tsx | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/packages/web/src/components/use-package-json.tsx b/packages/web/src/components/use-package-json.tsx index 21c20095..5092ba6a 100644 --- a/packages/web/src/components/use-package-json.tsx +++ b/packages/web/src/components/use-package-json.tsx @@ -4,7 +4,6 @@ import { PackageJsonCellType, PackageJsonCellUpdateAttrsType, } from '@srcbook/shared'; -import { SessionType } from '@/types'; import { OutputType } from '@srcbook/components/src/types'; import { SessionChannel } from '@/clients/websocket'; import { useCells } from '@srcbook/components/src/components/use-cell'; @@ -34,7 +33,6 @@ export interface PackageJsonContextValue { const PackageJsonContext = createContext(undefined); type ProviderPropsType = { - session: SessionType; channel: SessionChannel; children: React.ReactNode; }; @@ -48,7 +46,7 @@ type ProviderPropsType = { * 2. Decouple the rest of the code from treating package.json * as a cell since we want to move away from that. */ -export function PackageJsonProvider({ channel, session, children }: ProviderPropsType) { +export function PackageJsonProvider({ channel, children }: ProviderPropsType) { const { cells, updateCell: updateCellOnClient, getOutput, clearOutput } = useCells(); const cell = cells.find((cell) => cell.type === 'package.json') as PackageJsonCellType; @@ -82,7 +80,7 @@ export function PackageJsonProvider({ channel, session, children }: ProviderProp channel.push('deps:install', { packages }); } }, - [cell, channel, session.id, updateCellOnClient, clearOutput], + [cell, channel, updateCellOnClient, clearOutput], ); useEffect(() => { diff --git a/packages/web/src/components/use-tsconfig-json.tsx b/packages/web/src/components/use-tsconfig-json.tsx index 454fe250..1fbd15cb 100644 --- a/packages/web/src/components/use-tsconfig-json.tsx +++ b/packages/web/src/components/use-tsconfig-json.tsx @@ -46,7 +46,7 @@ export function TsConfigProvider({ channel, session, children }: ProviderPropsTy }); } }, - [session.id, setSource, channel, setValidationError], + [setSource, channel, setValidationError], ); const context: TsConfigContextValue = { diff --git a/packages/web/src/routes/session.tsx b/packages/web/src/routes/session.tsx index 6d7c81c0..c40a95a9 100644 --- a/packages/web/src/routes/session.tsx +++ b/packages/web/src/routes/session.tsx @@ -96,7 +96,7 @@ function SessionPage() { return ( - + {VITE_SRCBOOK_DEBUG_RENDER_SESSION_AS_READ_ONLY ? ( From 64365e7605217b1bda96dbba5545c7502cd4eca9 Mon Sep 17 00:00:00 2001 From: Ben Reinhart Date: Fri, 4 Oct 2024 15:44:06 -0700 Subject: [PATCH 3/3] Add changeset --- .changeset/afraid-adults-chew.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .changeset/afraid-adults-chew.md diff --git a/.changeset/afraid-adults-chew.md b/.changeset/afraid-adults-chew.md new file mode 100644 index 00000000..4c9c7b9b --- /dev/null +++ b/.changeset/afraid-adults-chew.md @@ -0,0 +1,8 @@ +--- +'@srcbook/components': patch +'@srcbook/shared': patch +'@srcbook/api': patch +'@srcbook/web': patch +--- + +Update websocket client to pass context and connection