Skip to content

Commit

Permalink
feat: support upgrade hook to set headers
Browse files Browse the repository at this point in the history
  • Loading branch information
pi0 committed Feb 24, 2024
1 parent cb6721c commit 91edb54
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 162 deletions.
17 changes: 13 additions & 4 deletions playground/_shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,25 @@ export function createDemo<T extends WebSocketAdapter>(
peer.send("pong");
}
},
upgrade(req) {
return {
headers: {
"x-powered-by": "cross-ws",
"set-cookie": "cross-ws=1; SameSite=None; Secure",
},
};
},
});

const resolve: CrossWSOptions["resolve"] = (peer) => {
const resolve: CrossWSOptions["resolve"] = (info) => {
return {
open: () => {
open: (peer) => {
peer.send(
JSON.stringify(
{
url: peer.url,
headers: peer.headers && Object.fromEntries(peer.headers),
url: info.url,
headers:
info.headers && Object.fromEntries(new Headers(info.headers)),
},
undefined,
2,
Expand Down
6 changes: 3 additions & 3 deletions playground/bun.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import bunAdapter from "../src/adapters/bun";
import { createDemo, getIndexHTML } from "./_shared";

const adapter = createDemo(bunAdapter);
const { websocket, handleUpgrade } = createDemo(bunAdapter);

Bun.serve({
port: 3001,
websocket: adapter.websocket,
websocket,
async fetch(req, server) {
if (server.upgrade(req, { data: { req, server } })) {
if (await handleUpgrade(req, server)) {
return;
}
return new Response(await getIndexHTML({ name: "bun" }), {
Expand Down
36 changes: 18 additions & 18 deletions src/adapters/bun.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,82 +4,82 @@ import type { WebSocketHandler, ServerWebSocket, Server } from "bun";

import { WebSocketMessage } from "../message";
import { WebSocketError } from "../error";
import { WebSocketPeer } from "../peer";
import { WSPeer } from "../peer";
import { defineWebSocketAdapter } from "../adapter";
import { CrossWSOptions, createCrossWS } from "../crossws";

export interface AdapterOptions extends CrossWSOptions {}

type ContextData = {
_peer?: WebSocketPeer;
_peer?: WSPeer;
req?: Request;
server?: Server;
};

export interface Adapter {
websocket: WebSocketHandler<ContextData>;
handleUpgrade(req: Request, server: Server): Promise<boolean>;
}

export default defineWebSocketAdapter<Adapter, AdapterOptions>(
(hooks, options = {}) => {
const crossws = createCrossWS(hooks, options);

const getPeer = (ws: ServerWebSocket<ContextData>) => {
const getWSPeer = (ws: ServerWebSocket<ContextData>) => {
if (ws.data?._peer) {
return ws.data._peer;
}
const peer = new BunPeer({ bun: { ws } });
const peer = new BunWSPeer({ bun: { ws } });
ws.data = ws.data || {};
ws.data._peer = peer;
return peer;
};

return {
handleUpgrade(req: Request, server: Server) {
async handleUpgrade(req: Request, server: Server) {
const { headers } = await crossws.upgrade({
url: req.url,
headers: req.headers,
});
return server.upgrade(req, {
data: { req, server },
headers,
});
},
websocket: {
message: (ws, message) => {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:message", peer, ws, message);
crossws.message(peer, new WebSocketMessage(message));
},
open: (ws) => {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:open", peer, ws);
crossws.open(peer);
},
close: (ws) => {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:close", peer, ws);
crossws.close(peer, {});
},
drain: (ws) => {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:drain", peer);
},
// @ts-expect-error types unavailable but mentioned in docs
error: (ws, error) => {
const peer = getPeer(ws);
crossws.$("bun:error", peer, ws, error);
crossws.error(peer, new WebSocketError(error));
},
ping(ws, data) {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:ping", peer, ws, data);
},
pong(ws, data) {
const peer = getPeer(ws);
const peer = getWSPeer(ws);
crossws.$("bun:pong", peer, ws, data);
},
},
};
},
);

class BunPeer extends WebSocketPeer<{
class BunWSPeer extends WSPeer<{
bun: { ws: ServerWebSocket<ContextData> };
}> {
get id() {
Expand Down
18 changes: 10 additions & 8 deletions src/adapters/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import type * as _cf from "@cloudflare/workers-types";

import { WebSocketPeer } from "../peer";
import { WSPeer } from "../peer";
import { defineWebSocketAdapter } from "../adapter.js";
import { WebSocketMessage } from "../message";
import { WebSocketError } from "../error";
Expand All @@ -20,14 +20,14 @@ export interface Adapter {
req: _cf.Request,
env: Env,
context: _cf.ExecutionContext,
): _cf.Response;
): Promise<_cf.Response>;
}

export default defineWebSocketAdapter<Adapter, AdapterOptions>(
(hooks, options = {}) => {
const crossws = createCrossWS(hooks, options);

const handleUpgrade = (
const handleUpgrade = async (
req: _cf.Request,
env: Env,
context: _cf.ExecutionContext,
Expand All @@ -40,30 +40,32 @@ export default defineWebSocketAdapter<Adapter, AdapterOptions>(
cloudflare: { client, server, req, env, context },
});

server.accept();
const { headers } = await crossws.upgrade(peer);

server.accept();
crossws.$("cloudflare:accept", peer);
crossws.open(peer);

server.addEventListener("message", (event) => {
crossws.$("cloudflare:message", peer, event);
hooks.message?.(peer, new WebSocketMessage(event.data));
crossws.message(peer, new WebSocketMessage(event.data));
});

server.addEventListener("error", (event) => {
crossws.$("cloudflare:error", peer, event);
hooks.error?.(peer, new WebSocketError(event.error));
crossws.error(peer, new WebSocketError(event.error));
});

server.addEventListener("close", (event) => {
crossws.$("cloudflare:close", peer, event);
hooks.close?.(peer, { code: event.code, reason: event.reason });
crossws.close(peer, { code: event.code, reason: event.reason });
});

// eslint-disable-next-line unicorn/no-null
return new Response(null, {
status: 101,
webSocket: client,
headers,
});
};

Expand All @@ -73,7 +75,7 @@ export default defineWebSocketAdapter<Adapter, AdapterOptions>(
},
);

class CloudflarePeer extends WebSocketPeer<{
class CloudflarePeer extends WSPeer<{
cloudflare: {
client: _cf.WebSocket;
server: _cf.WebSocket;
Expand Down
28 changes: 19 additions & 9 deletions src/adapters/deno.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import { WebSocketMessage } from "../message";
import { WebSocketError } from "../error";
import { WebSocketPeer } from "../peer";
import { WSPeer } from "../peer";
import { defineWebSocketAdapter } from "../adapter.js";
import { CrossWSOptions, createCrossWS } from "../crossws";

export interface AdapterOptions extends CrossWSOptions {}

export interface Adapter {
handleUpgrade(req: Request): Response;
handleUpgrade(req: Request): Promise<Response>;
}

declare global {
Expand All @@ -22,26 +22,36 @@ export default defineWebSocketAdapter<Adapter, AdapterOptions>(
(hooks, options = {}) => {
const crossws = createCrossWS(hooks, options);

const handleUpgrade = (req: Request) => {
const upgrade = Deno.upgradeWebSocket(req);
const peer = new DenoPeer({
const handleUpgrade = async (req: Request) => {
const { headers } = await crossws.upgrade({
url: req.url,
headers: req.headers,
});

const upgrade = Deno.upgradeWebSocket(req, {
// @ts-expect-error https://github.com/denoland/deno/pull/22242
headers,
});

const peer = new DenoWSPeer({
deno: { ws: upgrade.socket, req },
});

upgrade.socket.addEventListener("open", () => {
crossws.$("deno:open", peer);
crossws.open(peer);
});
upgrade.socket.addEventListener("message", (event) => {
crossws.$("deno:message", peer, event);
hooks.message?.(peer, new WebSocketMessage(event.data));
crossws.message(peer, new WebSocketMessage(event.data));
});
upgrade.socket.addEventListener("close", () => {
crossws.$("deno:close", peer);
hooks.close?.(peer, {});
crossws.close(peer, {});
});
upgrade.socket.addEventListener("error", (error) => {
crossws.$("deno:error", peer, error);
hooks.error?.(peer, new WebSocketError(error));
crossws.error(peer, new WebSocketError(error));
});
return upgrade.response;
};
Expand All @@ -52,7 +62,7 @@ export default defineWebSocketAdapter<Adapter, AdapterOptions>(
},
);

class DenoPeer extends WebSocketPeer<{
class DenoWSPeer extends WSPeer<{
deno: { ws: any; req: Request };
}> {
get id() {
Expand Down
Loading

0 comments on commit 91edb54

Please sign in to comment.