Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/afraid-adults-chew.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@srcbook/components': patch
'@srcbook/shared': patch
'@srcbook/api': patch
'@srcbook/web': patch
---

Update websocket client to pass context and connection
181 changes: 123 additions & 58 deletions packages/api/server/ws-client.mts
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,135 @@ import z from 'zod';
import { type RawData, WebSocket } from 'ws';
import { WebSocketMessageSchema } from '@srcbook/shared';

const VALID_TOPIC_RE = /^[a-zA-Z0-9_:]+$/;
type TopicPart = { dynamic: false; segment: string } | { dynamic: true; parameter: string };

export type MessageContextType<Key extends string = string> = {
topic: string;
event: string;
params: Record<Key, string>;
};

type TopicMatch = Pick<MessageContextType, 'topic' | 'params'>;

export interface ConnectionContextType {
reply: (topic: string, event: string, payload: Record<string, any>) => void;
}

/**
* Channel is responsible for dispatching incoming and outgoing messages for a given topic.
* Channel is responsible for dispatching incoming messages for a given topic.
*
* Topics are strings that represent a channel for messages. Topics
* can be broken into multiple parts separated by a colon. The following
* are all examples of valid topics:
*
* - session
* - session:123
* - room:123:users:456:messages
*
* When we define a topic, we can use the `<variable>` syntax to indicate a
* wildcard match. For example, the topic `room:<roomId>:messages` would match
* `room:123:messages`, `room:456:messages`, etc.
*
* The wildcard syntax must be between two colons (or at the start/end of the string).
* The text inside must be a valid JavaScript identifier.
*
* Examples:
*
* const channel = new Channel("session") // matches "session" only
* const channel = new Channel("session:*") // matches "session:123", "session:456", etc.
* const channel = new Channel("session") // matches "session" only
* const channel = new Channel("session:<sessionId>") // matches "session:123", "session:456", etc.
*
*/
export class Channel {
// The topic pattern, e.g. "sessions:<sessionId>"
readonly topic: string;

readonly events: {
incoming: Record<
string,
{ schema: z.ZodTypeAny; handler: (payload: Record<string, any>) => void }
>;
outgoing: Record<string, z.ZodTypeAny>;
} = { incoming: {}, outgoing: {} };

private wildcardMatch = false;
// The parts of the topic string, e.g. "sessions" and "<sessionId>" for "sessions:<sessionId>"
private readonly parts: TopicPart[];

readonly events: Record<
string,
{
schema: z.ZodTypeAny;
handler: (
payload: Record<string, any>,
context: MessageContextType,
conn: ConnectionContextType,
) => void;
}
> = {};

constructor(topic: string) {
if (topic.endsWith(':*')) {
// Remove asterisk from topic
topic = topic.slice(0, -1);
this.wildcardMatch = true;
}
this.topic = topic;
this.parts = this.splitIntoParts(topic);
}

private splitIntoParts(topic: string) {
const parts: TopicPart[] = [];

for (const part of topic.split(':')) {
const parameter = part.match(/^<([a-zA-Z_]+[a-zA-Z0-9_]*)>$/);

if (!VALID_TOPIC_RE.test(topic)) {
throw new Error(`Invalid channel topic '${topic}'`);
if (parameter !== null) {
parts.push({ dynamic: true, parameter: parameter[1] as string });
continue;
}

if (/^[a-zA-Z0-9_]+$/.test(part)) {
parts.push({ dynamic: false, segment: part });
continue;
}

throw new Error(`Invalid channel topic: ${topic}`);
}

this.topic = topic;
return parts;
}

matches(topic: string) {
if (topic === this.topic) {
return true;
match(topic: string): TopicMatch | null {
const parts = topic.split(':');

if (parts.length !== this.parts.length) {
return null;
}

if (this.wildcardMatch) {
return topic.startsWith(this.topic) && topic.length > this.topic.length;
const match: TopicMatch = {
topic: topic,
params: {},
};

for (let i = 0, len = this.parts.length; i < len; i++) {
const thisPart = this.parts[i] as TopicPart;

if (thisPart.dynamic) {
match.params[thisPart.parameter] = parts[i] as string;
continue;
} else if (thisPart.segment === parts[i]) {
continue;
}

return null;
}

return false;
return match;
}

incoming<T extends z.ZodTypeAny>(
on<T extends z.ZodTypeAny>(
event: string,
schema: T,
handler: (payload: z.infer<T>) => void,
handler: (
payload: z.infer<T>,
context: MessageContextType,
conn: ConnectionContextType,
) => void,
) {
this.events.incoming[event] = { schema, handler };
return this;
}

outgoing<T extends z.ZodTypeAny>(event: string, schema: T) {
this.events.outgoing[event] = schema;
this.events[event] = { schema, handler };
return this;
}
}

type ConnectionType = {
// Reply only to this connection, not to all connections.
reply: (topic: string, event: string, payload: Record<string, any>) => void;
socket: WebSocket;
subscriptions: string[];
};
Expand All @@ -90,7 +153,13 @@ export default class WebSocketServer {
return;
}

const connection = { socket, subscriptions: [] };
const connection = {
socket,
subscriptions: [],
reply: (topic: string, event: string, payload: Record<string, any>) => {
this.send(connection, topic, event, payload);
},
};

this.connections.push(connection);

Expand All @@ -115,23 +184,9 @@ export default class WebSocketServer {
}

broadcast(topic: string, event: string, payload: Record<string, any>) {
const channel = this.findChannel(topic);

if (channel === undefined) {
throw new Error(`Cannot broadcast to unknown topic '${topic}'`);
}

const schema = channel.events.outgoing[event];

if (schema === undefined) {
throw new Error(`Cannot broadcast to unknown event '${event}'`);
}

const validatedPayload = schema.parse(payload);

for (const conn of this.connections) {
if (conn.subscriptions.includes(topic)) {
conn.socket.send(JSON.stringify([topic, event, validatedPayload]));
this.send(conn, topic, event, payload);
}
}
}
Expand All @@ -140,9 +195,9 @@ export default class WebSocketServer {
const parsed = JSON.parse(message.toString('utf8'));
const [topic, event, payload] = WebSocketMessageSchema.parse(parsed);

const channel = this.findChannel(topic);
const channelMatch = this.findChannelMatch(topic);

if (channel === undefined) {
if (channelMatch === null) {
console.warn(`Server received unknown topic '${topic}'`);
return;
}
Expand All @@ -157,7 +212,9 @@ export default class WebSocketServer {
return;
}

const registeredEvent = channel.events.incoming[event];
const { channel, match } = channelMatch;

const registeredEvent = channel.events[event];

if (registeredEvent === undefined) {
console.warn(`Server received unknown event '${event}' for topic '${topic}'`);
Expand All @@ -175,20 +232,28 @@ export default class WebSocketServer {
return;
}

handler(result.data);
handler(result.data, { topic: match.topic, event: event, params: match.params }, conn);
}

private findChannel(topic: string) {
private findChannelMatch(topic: string): { channel: Channel; match: TopicMatch } | null {
for (const channel of this.channels) {
if (channel.matches(topic)) {
return channel;
const match = channel.match(topic);

if (match !== null) {
return { channel, match };
}
}

return null;
}

private removeConnection(socket: WebSocket) {
this.connections = this.connections.filter((conn) => {
return conn.socket !== socket;
});
}

private send(conn: ConnectionType, topic: string, event: string, payload: Record<string, any>) {
conn.socket.send(JSON.stringify([topic, event, payload]));
}
}
Loading