Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 34 additions & 36 deletions src/websocket/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@ import type http from "node:http";
import type { Duplex } from "node:stream";

import { ClientInfoSchema } from "@shellular/protocol";
import type { WebSocket, WebSocketServer } from "ws";
import { z } from "zod";

import { getClient, verifyClient } from "@/db/client";
import { getHost } from "@/db/host";
import { env } from "@/env";
import { logger } from "@/logger";
import { getActiveSessionForHost, joinSession } from "./sessions";
import { CloseCodeAndReason, closeWsWithError } from "./shared";
import { initAppWebSocket, requestClientApprovalFromHost } from "./ws-app";
import { initCliWebSocket } from "./ws-cli";

const PING_INTERVAL_MS = 30_000;

const HostQuerySchema = z.object({
hostId: z.string(),
});
Expand All @@ -27,9 +25,6 @@ export function initWebSocketRelay(server: http.Server) {
handleUpgradeRequest(request, socket, head);
});

setupKeepAlive(cliWsServer);
setupKeepAlive(appWsServer);

return { cliWsServer, appWsServer };
}

Expand Down Expand Up @@ -70,6 +65,14 @@ async function handleUpgradeRequest(
}

if (pathname === "/app") {
const origin = request.headers.origin ?? "";
if (env.NODE_ENV !== "dev" && !isAppOriginAllowed(origin)) {
logger.warn(`Rejecting app websocket: disallowed origin: '${origin}'`);
socket.write("HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n");
socket.destroy();
return;
}

// async approval before upgrade is fine
const parsed = ClientInfoSchema.safeParse(query);

Expand Down Expand Up @@ -110,16 +113,12 @@ async function handleUpgradeRequest(
return;
}

const approval = await requestClientApprovalFromHost(
session,
parsed.data,
);
if (!approval.approved) {
const { code, reason } = CloseCodeAndReason.APPROVAL_DENIED;
const failure = await requestClientApprovalFromHost(session, parsed.data);
if (failure) {
logger.info(
`Rejecting app websocket: approval denied for hostId=${hostId} clientId=${parsed.data.clientId} reason=${approval.reason}`,
`Rejecting app websocket: approval denied for hostId=${hostId} clientId=${parsed.data.clientId} reason=${failure.reason}`,
);
closeWsWithError(ws, code, reason);
closeWsWithError(ws, failure.code, failure.reason);
return;
}

Expand All @@ -144,31 +143,30 @@ async function handleUpgradeRequest(
socket.destroy();
}

function setupKeepAlive(wsServer: WebSocketServer) {
const aliveSet = new WeakSet<WebSocket>();

wsServer.on("connection", (ws) => {
aliveSet.add(ws);
const APP_PROTOCOL = "shellular:";
const WEB_PROTOCOLS = new Set(["https:", "wss:"]);

ws.on("pong", () => {
aliveSet.add(ws);
});
function isAppOriginAllowed(origin: string): boolean {
try {
const url = new URL(origin);

ws.on("close", () => {
aliveSet.delete(ws);
});
});
if (url.protocol === APP_PROTOCOL) {
return true;
}
Comment thread
biraj21 marked this conversation as resolved.

// Periodic ping to detect dead connections and keep connections alive
// through reverse proxies and load balancers
setInterval(() => {
for (const ws of wsServer.clients) {
if (!aliveSet.has(ws)) {
continue;
}
if (!WEB_PROTOCOLS.has(url.protocol)) {
return false;
}

aliveSet.delete(ws);
ws.ping();
if (
url.hostname === "shellular.dev" ||
url.hostname.endsWith(".shellular.dev")
) {
return true;
}
}, PING_INTERVAL_MS);

return false;
} catch {
return false;
}
}
26 changes: 11 additions & 15 deletions src/websocket/sessions.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { type ClientInfo, type HostInfo, MsgType } from "@shellular/protocol";
import type { ClientInfo, HostInfo } from "@shellular/protocol";
import { nanoid } from "nanoid";
import type { WebSocket } from "ws";

import { sleep } from "@/utils";
import { CloseCodeAndReason } from "./shared";

export interface ClientInfoWithWebSocket {
Expand Down Expand Up @@ -70,9 +69,14 @@ export function joinSession(

export function removeClient(sessionId: string, clientId: string): void {
const session = sessions.get(sessionId);
if (!session) return;
if (!session) {
return;
}

const clientInfo = session.clients.get(clientId);
if (!clientInfo) return;
if (!clientInfo) {
return;
}

session.clients.delete(clientId);
socketToSession.delete(clientInfo.ws);
Expand All @@ -91,7 +95,7 @@ export function getSessionForSocket(ws: WebSocket) {
return socketToSession.get(ws) ?? null;
}

export async function removeSocket(ws: WebSocket) {
export function removeSocket(ws: WebSocket) {
const entry = socketToSession.get(ws);
if (!entry) {
return;
Expand All @@ -100,19 +104,11 @@ export async function removeSocket(ws: WebSocket) {
socketToSession.delete(ws);

if (entry.role === "host") {
// Notify all WS clients that host disconnected
// Close all WS clients that with host disconnected code
for (const [, clientInfo] of entry.session.clients) {
clientInfo.ws.send(
JSON.stringify({
type: MsgType.SESSION_ERROR,
id: `server_${nanoid(8)}`,
error: "Host disconnected",
}),
);
await sleep(250); // Give client a moment to receive message before closing
const { code, reason } = CloseCodeAndReason.HOST_DISCONNECTED;
clientInfo.ws.close(code, reason);
socketToSession.delete(ws);
socketToSession.delete(clientInfo.ws);
}
entry.session.clients.clear();
sessions.delete(entry.session.id);
Expand Down
48 changes: 47 additions & 1 deletion src/websocket/shared.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { MsgType, type SessionErrorMsg } from "@shellular/protocol";
import { nanoid } from "nanoid";
import type { WebSocket } from "ws";
import type { WebSocket, WebSocketServer } from "ws";

import { env } from "@/env";
import { logger } from "@/logger";
import { removeSocket } from "./sessions";

const PING_INTERVAL_MS = 30_000;

function sendSessionErrorMsg(
ws: WebSocket,
Expand Down Expand Up @@ -67,8 +70,51 @@ export const CloseCodeAndReason = {
HOST_AUTH_FAILED: { code: 4007, reason: "host_auth_failed" },
} as const;

export type CloseCodeAndReasonValue =
(typeof CloseCodeAndReason)[keyof typeof CloseCodeAndReason];

export function closeWsWithError(ws: WebSocket, code: number, reason: string) {
// reason should stay short (<123 bytes)
logger.info(`Closing websocket with code=${code} reason=${reason}`);
ws.close(code, reason.slice(0, 123));
}

/**
* Periodic ping to detect dead connections and keep connections alive through
* reverse proxies and load balancers. A socket that misses a pong between two
* cycles is terminated so its session is freed for reconnect.
*
*/
export function setupKeepAlive(wsServer: WebSocketServer): void {
const aliveSet = new WeakSet<WebSocket>();

wsServer.on("connection", (ws) => {
aliveSet.add(ws);

ws.on("pong", () => {
aliveSet.add(ws);
});

ws.on("close", () => {
aliveSet.delete(ws);
});
});

const interval = setInterval(() => {
for (const ws of wsServer.clients) {
if (!aliveSet.has(ws)) {
logger.info("Terminating unresponsive websocket (missed pong)");
removeSocket(ws);
ws.terminate();
continue;
}

aliveSet.delete(ws);
ws.ping();
}
}, PING_INTERVAL_MS);

wsServer.on("close", () => {
clearInterval(interval);
});
}
Loading