diff --git a/src/websocket/index.ts b/src/websocket/index.ts index 5f10b9c..238f9bc 100644 --- a/src/websocket/index.ts +++ b/src/websocket/index.ts @@ -2,19 +2,17 @@ import type http from "node:http"; import type { Duplex } from "node:stream"; import { ClientInfoSchema } from "@shellular/protocol"; -import type { WebSocket, WebSocketServer } from "ws"; import { z } from "zod"; import { getClient, verifyClient } from "@/db/client"; import { getHost } from "@/db/host"; +import { env } from "@/env"; import { logger } from "@/logger"; import { getActiveSessionForHost, joinSession } from "./sessions"; import { CloseCodeAndReason, closeWsWithError } from "./shared"; import { initAppWebSocket, requestClientApprovalFromHost } from "./ws-app"; import { initCliWebSocket } from "./ws-cli"; -const PING_INTERVAL_MS = 30_000; - const HostQuerySchema = z.object({ hostId: z.string(), }); @@ -27,9 +25,6 @@ export function initWebSocketRelay(server: http.Server) { handleUpgradeRequest(request, socket, head); }); - setupKeepAlive(cliWsServer); - setupKeepAlive(appWsServer); - return { cliWsServer, appWsServer }; } @@ -70,6 +65,14 @@ async function handleUpgradeRequest( } if (pathname === "/app") { + const origin = request.headers.origin ?? ""; + if (env.NODE_ENV !== "dev" && !isAppOriginAllowed(origin)) { + logger.warn(`Rejecting app websocket: disallowed origin: '${origin}'`); + socket.write("HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n"); + socket.destroy(); + return; + } + // async approval before upgrade is fine const parsed = ClientInfoSchema.safeParse(query); @@ -110,16 +113,12 @@ async function handleUpgradeRequest( return; } - const approval = await requestClientApprovalFromHost( - session, - parsed.data, - ); - if (!approval.approved) { - const { code, reason } = CloseCodeAndReason.APPROVAL_DENIED; + const failure = await requestClientApprovalFromHost(session, parsed.data); + if (failure) { logger.info( - `Rejecting app websocket: approval denied for hostId=${hostId} clientId=${parsed.data.clientId} reason=${approval.reason}`, + `Rejecting app websocket: approval denied for hostId=${hostId} clientId=${parsed.data.clientId} reason=${failure.reason}`, ); - closeWsWithError(ws, code, reason); + closeWsWithError(ws, failure.code, failure.reason); return; } @@ -144,31 +143,30 @@ async function handleUpgradeRequest( socket.destroy(); } -function setupKeepAlive(wsServer: WebSocketServer) { - const aliveSet = new WeakSet(); - - wsServer.on("connection", (ws) => { - aliveSet.add(ws); +const APP_PROTOCOL = "shellular:"; +const WEB_PROTOCOLS = new Set(["https:", "wss:"]); - ws.on("pong", () => { - aliveSet.add(ws); - }); +function isAppOriginAllowed(origin: string): boolean { + try { + const url = new URL(origin); - ws.on("close", () => { - aliveSet.delete(ws); - }); - }); + if (url.protocol === APP_PROTOCOL) { + return true; + } - // Periodic ping to detect dead connections and keep connections alive - // through reverse proxies and load balancers - setInterval(() => { - for (const ws of wsServer.clients) { - if (!aliveSet.has(ws)) { - continue; - } + if (!WEB_PROTOCOLS.has(url.protocol)) { + return false; + } - aliveSet.delete(ws); - ws.ping(); + if ( + url.hostname === "shellular.dev" || + url.hostname.endsWith(".shellular.dev") + ) { + return true; } - }, PING_INTERVAL_MS); + + return false; + } catch { + return false; + } } diff --git a/src/websocket/sessions.ts b/src/websocket/sessions.ts index 91f14c4..2ca26a3 100644 --- a/src/websocket/sessions.ts +++ b/src/websocket/sessions.ts @@ -1,8 +1,7 @@ -import { type ClientInfo, type HostInfo, MsgType } from "@shellular/protocol"; +import type { ClientInfo, HostInfo } from "@shellular/protocol"; import { nanoid } from "nanoid"; import type { WebSocket } from "ws"; -import { sleep } from "@/utils"; import { CloseCodeAndReason } from "./shared"; export interface ClientInfoWithWebSocket { @@ -70,9 +69,14 @@ export function joinSession( export function removeClient(sessionId: string, clientId: string): void { const session = sessions.get(sessionId); - if (!session) return; + if (!session) { + return; + } + const clientInfo = session.clients.get(clientId); - if (!clientInfo) return; + if (!clientInfo) { + return; + } session.clients.delete(clientId); socketToSession.delete(clientInfo.ws); @@ -91,7 +95,7 @@ export function getSessionForSocket(ws: WebSocket) { return socketToSession.get(ws) ?? null; } -export async function removeSocket(ws: WebSocket) { +export function removeSocket(ws: WebSocket) { const entry = socketToSession.get(ws); if (!entry) { return; @@ -100,19 +104,11 @@ export async function removeSocket(ws: WebSocket) { socketToSession.delete(ws); if (entry.role === "host") { - // Notify all WS clients that host disconnected + // Close all WS clients that with host disconnected code for (const [, clientInfo] of entry.session.clients) { - clientInfo.ws.send( - JSON.stringify({ - type: MsgType.SESSION_ERROR, - id: `server_${nanoid(8)}`, - error: "Host disconnected", - }), - ); - await sleep(250); // Give client a moment to receive message before closing const { code, reason } = CloseCodeAndReason.HOST_DISCONNECTED; clientInfo.ws.close(code, reason); - socketToSession.delete(ws); + socketToSession.delete(clientInfo.ws); } entry.session.clients.clear(); sessions.delete(entry.session.id); diff --git a/src/websocket/shared.ts b/src/websocket/shared.ts index 0066b3f..4fe8332 100644 --- a/src/websocket/shared.ts +++ b/src/websocket/shared.ts @@ -1,9 +1,12 @@ import { MsgType, type SessionErrorMsg } from "@shellular/protocol"; import { nanoid } from "nanoid"; -import type { WebSocket } from "ws"; +import type { WebSocket, WebSocketServer } from "ws"; import { env } from "@/env"; import { logger } from "@/logger"; +import { removeSocket } from "./sessions"; + +const PING_INTERVAL_MS = 30_000; function sendSessionErrorMsg( ws: WebSocket, @@ -67,8 +70,51 @@ export const CloseCodeAndReason = { HOST_AUTH_FAILED: { code: 4007, reason: "host_auth_failed" }, } as const; +export type CloseCodeAndReasonValue = + (typeof CloseCodeAndReason)[keyof typeof CloseCodeAndReason]; + export function closeWsWithError(ws: WebSocket, code: number, reason: string) { // reason should stay short (<123 bytes) logger.info(`Closing websocket with code=${code} reason=${reason}`); ws.close(code, reason.slice(0, 123)); } + +/** + * Periodic ping to detect dead connections and keep connections alive through + * reverse proxies and load balancers. A socket that misses a pong between two + * cycles is terminated so its session is freed for reconnect. + * + */ +export function setupKeepAlive(wsServer: WebSocketServer): void { + const aliveSet = new WeakSet(); + + wsServer.on("connection", (ws) => { + aliveSet.add(ws); + + ws.on("pong", () => { + aliveSet.add(ws); + }); + + ws.on("close", () => { + aliveSet.delete(ws); + }); + }); + + const interval = setInterval(() => { + for (const ws of wsServer.clients) { + if (!aliveSet.has(ws)) { + logger.info("Terminating unresponsive websocket (missed pong)"); + removeSocket(ws); + ws.terminate(); + continue; + } + + aliveSet.delete(ws); + ws.ping(); + } + }, PING_INTERVAL_MS); + + wsServer.on("close", () => { + clearInterval(interval); + }); +} diff --git a/src/websocket/ws-app.ts b/src/websocket/ws-app.ts index b25ba29..1e07182 100644 --- a/src/websocket/ws-app.ts +++ b/src/websocket/ws-app.ts @@ -16,8 +16,19 @@ import { z } from "zod"; import { registerClient } from "@/db/client"; import { logger } from "@/logger"; import { type ClientToHostMsg, ClientToHostMsgSchema } from "./protocol"; -import { getSessionForSocket, removeSocket, type Session } from "./sessions"; -import { CloseCodeAndReason, sendSessionErrorToClient } from "./shared"; +import { + getSessionForSocket, + removeClient, + removeSocket, + type Session, +} from "./sessions"; +import { + CloseCodeAndReason, + type CloseCodeAndReasonValue, + closeWsWithError, + sendSessionErrorToClient, + setupKeepAlive, +} from "./shared"; const CLIENT_APPROVAL_TIMEOUT_MS = 60_000; @@ -26,97 +37,123 @@ type ApprovalDecision = { reason: string; }; -type PreUpgradeApprovalEntry = { +type PendingApproval = { hostId: string; timer: ReturnType; resolve: (decision: ApprovalDecision) => void; }; -const preUpgradeApprovals = new Map(); - -/** - * Returns true if the clientId is already connected or awaiting approval for - * the given session. - */ -export function isClientOccupied(session: Session, clientId: string): boolean { - const existingClient = session.clients.get(clientId); - if (existingClient) { - // User might be switching from one device to another, so we allow a new connection to take over an existing one. - const { code, reason } = CloseCodeAndReason.CLIENT_REPLACED; - existingClient.ws.close(code, reason); - session.clients.delete(clientId); - } - return preUpgradeApprovals.has(clientId); -} +const pendingApprovals = new Map(); export function requestClientApprovalFromHost( session: Session, clientInfo: ClientInfo, -): Promise { - if (isClientOccupied(session, clientInfo.clientId)) { +): Promise { + const { clientId } = clientInfo; + + // If an earlier approval is still in flight for this clientId, cancel it. + // Resolve the prior promise so its caller cleans up; the new connection + // will start its own approval below. + const pendingApprovalEntry = pendingApprovals.get(clientId); + if (pendingApprovalEntry) { logger.info( - `Denying pending app join for hostId=${session.hostId} clientId=${clientInfo.clientId}: occupied or pending`, + `Superseding pending app approval for clientId=${clientId} hostId=${session.hostId} due to new connection`, ); - return Promise.resolve({ + clearTimeout(pendingApprovalEntry.timer); + pendingApprovals.delete(clientId); + pendingApprovalEntry.resolve({ approved: false, - reason: "Client is already connected or pending approval", + reason: "Superseded by newer connection", }); } + // If there's already an active socket for this clientId, replace it. The + // new connection wins; the old one is closed with CLIENT_REPLACED so the + // previous tab/window knows it was preempted. + const existingClient = session.clients.get(clientId); + if (existingClient) { + const { code, reason } = CloseCodeAndReason.CLIENT_REPLACED; + logger.info( + `Replacing existing app connection for clientId=${clientId} hostId=${session.hostId} (pendingApprovalEntry=${pendingApprovalEntry}) because a new connection was established with the same clientId`, + ); + closeWsWithError(existingClient.ws, code, reason); + removeClient(session.id, clientId); + } + return new Promise((resolve) => { const timer = setTimeout(() => { - const entry = preUpgradeApprovals.get(clientInfo.clientId); + const entry = pendingApprovals.get(clientId); if (!entry) { return; } - preUpgradeApprovals.delete(clientInfo.clientId); + pendingApprovals.delete(clientId); logger.info( - `App join approval timed out for hostId=${session.hostId} clientId=${clientInfo.clientId}`, + `App join approval timed out for hostId=${session.hostId} clientId=${clientId}`, ); - resolve({ approved: false, reason: "Connection request timed out" }); + resolve(CloseCodeAndReason.SESSION_JOIN_FAILED); }, CLIENT_APPROVAL_TIMEOUT_MS); - preUpgradeApprovals.set(clientInfo.clientId, { + pendingApprovals.set(clientId, { hostId: session.hostId, timer, - resolve, + resolve: (approval) => { + if (approval.approved) { + resolve(undefined); + } else { + logger.info( + `App join approval rejected by host for hostId=${session.hostId} clientId=${clientId} reason=${approval.reason}`, + ); + resolve(CloseCodeAndReason.APPROVAL_DENIED); + } + }, }); try { logger.info( - `Requesting app join approval for hostId=${session.hostId} clientId=${clientInfo.clientId}`, + `Requesting app join approval for hostId=${session.hostId} clientId=${clientId}`, ); notifyHostPendingClient(session.host, clientInfo); } catch { clearTimeout(timer); - preUpgradeApprovals.delete(clientInfo.clientId); + pendingApprovals.delete(clientId); logger.info( - `Failed to notify host about app join for hostId=${session.hostId} clientId=${clientInfo.clientId}`, + `Failed to notify host about app join for hostId=${session.hostId} clientId=${clientId}`, ); - resolve({ approved: false, reason: "Failed to reach host" }); + resolve(CloseCodeAndReason.APPROVAL_DENIED); } }); } /** * Called by ws-cli.ts when the CLI host approves or rejects a pending client. + * The resolving host must own the pending approval — pendingApprovals is keyed + * globally by clientId, so we verify hostId to prevent one host resolving + * another host's pending approval. */ export function resolvePendingClient( + hostId: string, clientId: string, approved: boolean, ): void { - const preUpgradeEntry = preUpgradeApprovals.get(clientId); - if (!preUpgradeEntry) { + const pendingApproval = pendingApprovals.get(clientId); + if (!pendingApproval) { return; } - clearTimeout(preUpgradeEntry.timer); - preUpgradeApprovals.delete(clientId); + if (pendingApproval.hostId !== hostId) { + logger.warn( + `Ignoring approval result from hostId=${hostId} for clientId=${clientId}: pending approval is owned by hostId=${pendingApproval.hostId}`, + ); + return; + } + + clearTimeout(pendingApproval.timer); + pendingApprovals.delete(clientId); logger.info( - `Resolved app join approval for hostId=${preUpgradeEntry.hostId} clientId=${clientId} approved=${approved}`, + `Resolved app join approval for hostId=${pendingApproval.hostId} clientId=${clientId} approved=${approved}`, ); - preUpgradeEntry.resolve( + pendingApproval.resolve( approved ? { approved: true, reason: "" } : { approved: false, reason: "Connection rejected by host" }, @@ -136,6 +173,7 @@ function notifyHostPendingClient( export function initAppWebSocket() { const wsServer = new WebSocketServer({ noServer: true }); + setupKeepAlive(wsServer); wsServer.on("connection", (ws) => { logger.info("App websocket connection established"); @@ -212,6 +250,7 @@ export function initAppWebSocket() { ); const entry = getSessionForSocket(ws); if (entry && entry.role === "client" && entry.clientId) { + pendingApprovals.delete(entry.clientId); // Notify CLI that client disconnected const { session, clientId } = entry; const activeClient = session.clients.get(clientId); @@ -233,6 +272,10 @@ export function initAppWebSocket() { ws.on("error", (err) => { logger.error("App websocket error", err); + const entry = getSessionForSocket(ws); + if (entry?.role === "client" && entry.clientId) { + pendingApprovals.delete(entry.clientId); + } removeSocket(ws); }); }); diff --git a/src/websocket/ws-cli.ts b/src/websocket/ws-cli.ts index 229d59b..28abaf3 100644 --- a/src/websocket/ws-cli.ts +++ b/src/websocket/ws-cli.ts @@ -27,11 +27,13 @@ import { closeWsWithError, sendSessionErrorToClient, sendSessionErrorToHost, + setupKeepAlive, } from "./shared"; import { resolvePendingClient } from "./ws-app"; export function initCliWebSocket() { const wsServer = new WebSocketServer({ noServer: true }); + setupKeepAlive(wsServer); wsServer.on("connection", (ws) => { /** * Session associated with this WebSocket connection. @@ -98,6 +100,7 @@ export function initCliWebSocket() { }); } else { resolvePendingClient( + session.hostId, parsed.data.data.clientId, parsed.data.data.approved, );