diff --git a/lib/client.ts b/lib/client.ts index f69418ede7..3ce7d5a226 100644 --- a/lib/client.ts +++ b/lib/client.ts @@ -114,7 +114,7 @@ export class Client< * @param {Object} auth - the auth parameters * @private */ - private connect(name: string, auth: object = {}): void { + private connect(name: string, auth: Record = {}): void { if (this.server._nsps.has(name)) { debug("connecting to namespace %s", name); return this.doConnect(name, auth); @@ -152,10 +152,10 @@ export class Client< * * @private */ - private doConnect(name: string, auth: object): void { + private doConnect(name: string, auth: Record): void { const nsp = this.server.of(name); - const socket = nsp._add(this, auth, () => { + nsp._add(this, auth, (socket) => { this.sockets.set(socket.id, socket); this.nsps.set(nsp.name, socket); @@ -228,7 +228,7 @@ export class Client< } private writeToEngine( - encodedPackets: Array, + encodedPackets: Array, opts: WriteOptions ): void { if (opts.volatile && !this.conn.transport.writable) { @@ -267,7 +267,7 @@ export class Client< */ private ondecoded(packet: Packet): void { let namespace: string; - let authPayload; + let authPayload: Record; if (this.conn.protocol === 3) { const parsed = url.parse(packet.nsp, true); namespace = parsed.pathname!; diff --git a/lib/index.ts b/lib/index.ts index fa18b7670c..f6ace29308 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -17,7 +17,12 @@ import { Client } from "./client"; import { EventEmitter } from "events"; import { ExtendedError, Namespace, ServerReservedEventsMap } from "./namespace"; import { ParentNamespace } from "./parent-namespace"; -import { Adapter, Room, SocketId } from "socket.io-adapter"; +import { + Adapter, + SessionAwareAdapter, + Room, + SocketId, +} from "socket.io-adapter"; import * as parser from "socket.io-parser"; import type { Encoder } from "socket.io-parser"; import debugModule from "debug"; @@ -72,6 +77,25 @@ interface ServerOptions extends EngineOptions, AttachOptions { * @default 45000 */ connectTimeout: number; + /** + * Whether to enable the recovery of connection state when a client temporarily disconnects. + * + * The connection state includes the missed packets, the rooms the socket was in and the `data` attribute. + */ + connectionStateRecovery: { + /** + * The backup duration of the sessions and the packets. + * + * @default 120000 (2 minutes) + */ + maxDisconnectionDuration?: number; + /** + * Whether to skip middlewares upon successful connection state recovery. + * + * @default true + */ + skipMiddlewares?: boolean; + }; } /** @@ -148,7 +172,7 @@ export class Server< > = new Map(); private _adapter?: AdapterConstructor; private _serveClient: boolean; - private opts: Partial; + private readonly opts: Partial; private eio: Engine; private _path: string; private clientPathRegex: RegExp; @@ -204,9 +228,20 @@ export class Server< this.serveClient(false !== opts.serveClient); this._parser = opts.parser || parser; this.encoder = new this._parser.Encoder(); - this.adapter(opts.adapter || Adapter); - this.sockets = this.of("/"); this.opts = opts; + if (opts.connectionStateRecovery) { + opts.connectionStateRecovery = Object.assign( + { + maxDisconnectionDuration: 2 * 60 * 1000, + skipMiddlewares: true, + }, + opts.connectionStateRecovery + ); + this.adapter(opts.adapter || SessionAwareAdapter); + } else { + this.adapter(opts.adapter || Adapter); + } + this.sockets = this.of("/"); if (srv || typeof srv == "number") this.attach( srv as http.Server | HTTPSServer | Http2SecureServer | number diff --git a/lib/namespace.ts b/lib/namespace.ts index 620ead42e6..ff88e2ffe3 100644 --- a/lib/namespace.ts +++ b/lib/namespace.ts @@ -296,13 +296,25 @@ export class Namespace< * @return {Socket} * @private */ - _add( + async _add( client: Client, - query, - fn?: () => void - ): Socket { + auth: Record, + fn: ( + socket: Socket + ) => void + ) { debug("adding socket to nsp %s", this.name); - const socket = new Socket(this, client, query); + const socket = await this._createSocket(client, auth); + + if ( + // @ts-ignore + this.server.opts.connectionStateRecovery?.skipMiddlewares && + socket.recovered && + client.conn.readyState === "open" + ) { + return this._doConnect(socket, fn); + } + this.run(socket, (err) => { process.nextTick(() => { if ("open" !== client.conn.readyState) { @@ -324,22 +336,53 @@ export class Namespace< } } - // track socket - this.sockets.set(socket.id, socket); - - // it's paramount that the internal `onconnect` logic - // fires before user-set events to prevent state order - // violations (such as a disconnection before the connection - // logic is complete) - socket._onconnect(); - if (fn) fn(); - - // fire user-set events - this.emitReserved("connect", socket); - this.emitReserved("connection", socket); + this._doConnect(socket, fn); }); }); - return socket; + } + + private async _createSocket( + client: Client, + auth: Record + ) { + const sessionId = auth.pid; + const offset = auth.offset; + if ( + // @ts-ignore + this.server.opts.connectionStateRecovery && + typeof sessionId === "string" && + typeof offset === "string" + ) { + const session = await this.adapter.restoreSession(sessionId, offset); + if (session) { + debug("connection state recovered for sid %s", session.sid); + return new Socket(this, client, auth, session); + } else { + debug("unable to restore session state"); + } + } + return new Socket(this, client, auth); + } + + private _doConnect( + socket: Socket, + fn: ( + socket: Socket + ) => void + ) { + // track socket + this.sockets.set(socket.id, socket); + + // it's paramount that the internal `onconnect` logic + // fires before user-set events to prevent state order + // violations (such as a disconnection before the connection + // logic is complete) + socket._onconnect(); + if (fn) fn(socket); + + // fire user-set events + this.emitReserved("connect", socket); + this.emitReserved("connection", socket); } /** diff --git a/lib/socket.ts b/lib/socket.ts index 6c4cb45ea7..3314398351 100644 --- a/lib/socket.ts +++ b/lib/socket.ts @@ -2,19 +2,21 @@ import { Packet, PacketType } from "socket.io-parser"; import debugModule from "debug"; import type { Server } from "./index"; import { - EventParams, + DefaultEventsMap, EventNames, + EventParams, EventsMap, StrictEventEmitter, - DefaultEventsMap, } from "./typed-events"; import type { Client } from "./client"; import type { Namespace, NamespaceReservedEventsMap } from "./namespace"; -import type { IncomingMessage, IncomingHttpHeaders } from "http"; +import type { IncomingHttpHeaders, IncomingMessage } from "http"; import type { Adapter, BroadcastFlags, + PrivateSessionId, Room, + Session, SocketId, } from "socket.io-adapter"; import base64id from "base64id"; @@ -39,6 +41,15 @@ export type DisconnectReason = | "client namespace disconnect" | "server namespace disconnect"; +const RECOVERABLE_DISCONNECT_REASONS: ReadonlySet = new Set([ + "transport error", + "transport close", + "forced close", + "ping timeout", + "server shutting down", + "forced server close", +]); + export interface SocketReservedEventsMap { disconnect: (reason: DisconnectReason) => void; disconnecting: (reason: DisconnectReason) => void; @@ -173,6 +184,11 @@ export class Socket< * An unique identifier for the session. */ public readonly id: SocketId; + /** + * Whether the connection state was recovered after a temporary disconnection. In that case, any missed packets will + * be transmitted to the client, the data attribute and the rooms will be restored. + */ + public readonly recovered: boolean = false; /** * The handshake details. */ @@ -197,6 +213,14 @@ export class Socket< */ public connected: boolean = false; + /** + * The session ID, which must not be shared (unlike {@link id}). + * + * @private + */ + private readonly pid: PrivateSessionId; + + // TODO: remove this unused reference private readonly server: Server< ListenEvents, EmitEvents, @@ -221,16 +245,32 @@ export class Socket< constructor( readonly nsp: Namespace, readonly client: Client, - auth: object + auth: Record, + previousSession?: Session ) { super(); this.server = nsp.server; this.adapter = this.nsp.adapter; - if (client.conn.protocol === 3) { - // @ts-ignore - this.id = nsp.name !== "/" ? nsp.name + "#" + client.id : client.id; + if (previousSession) { + this.id = previousSession.sid; + this.pid = previousSession.pid; + previousSession.rooms.forEach((room) => this.join(room)); + this.data = previousSession.data as Partial; + previousSession.missedPackets.forEach((packet) => { + this.packet({ + type: PacketType.EVENT, + data: packet, + }); + }); + this.recovered = true; } else { - this.id = base64id.generateId(); // don't reuse the Engine.IO id because it's sensitive information + if (client.conn.protocol === 3) { + // @ts-ignore + this.id = nsp.name !== "/" ? nsp.name + "#" + client.id : client.id; + } else { + this.id = base64id.generateId(); // don't reuse the Engine.IO id because it's sensitive information + } + this.pid = base64id.generateId(); } this.handshake = this.buildHandshake(auth); } @@ -299,8 +339,18 @@ export class Socket< const flags = Object.assign({}, this.flags); this.flags = {}; - this.notifyOutgoingListeners(packet); - this.packet(packet, flags); + // @ts-ignore + if (this.nsp.server.opts.connectionStateRecovery) { + // this ensures the packet is stored and can be transmitted upon reconnection + this.adapter.broadcast(packet, { + rooms: new Set([this.id]), + except: new Set(), + flags, + }); + } else { + this.notifyOutgoingListeners(packet); + this.packet(packet, flags); + } return true; } @@ -508,7 +558,10 @@ export class Socket< if (this.conn.protocol === 3) { this.packet({ type: PacketType.CONNECT }); } else { - this.packet({ type: PacketType.CONNECT, data: { sid: this.id } }); + this.packet({ + type: PacketType.CONNECT, + data: { sid: this.id, pid: this.pid }, + }); } } @@ -644,6 +697,17 @@ export class Socket< if (!this.connected) return this; debug("closing socket - reason %s", reason); this.emitReserved("disconnecting", reason); + + if (RECOVERABLE_DISCONNECT_REASONS.has(reason)) { + debug("connection state recovery is enabled for sid %s", this.id); + this.adapter.persistSession({ + sid: this.id, + pid: this.pid, + rooms: [...this.rooms], + data: this.data, + }); + } + this._cleanup(); this.nsp._remove(this); this.client._remove(this); diff --git a/test/close.ts b/test/close.ts index a48368f9d1..d0ecbe91ff 100644 --- a/test/close.ts +++ b/test/close.ts @@ -4,46 +4,13 @@ import { join } from "path"; import { exec } from "child_process"; import { Server } from ".."; import expect from "expect.js"; -import { createClient, getPort } from "./support/util"; -import request from "supertest"; - -// TODO: update superagent as latest release now supports promises -const eioHandshake = (httpServer): Promise => { - return new Promise((resolve) => { - request(httpServer) - .get("/socket.io/") - .query({ transport: "polling", EIO: 4 }) - .end((err, res) => { - const sid = JSON.parse(res.text.substring(1)).sid; - resolve(sid); - }); - }); -}; - -const eioPush = (httpServer, sid: string, body: string): Promise => { - return new Promise((resolve) => { - request(httpServer) - .post("/socket.io/") - .send(body) - .query({ transport: "polling", EIO: 4, sid }) - .expect(200) - .end(() => { - resolve(); - }); - }); -}; - -const eioPoll = (httpServer, sid): Promise => { - return new Promise((resolve) => { - request(httpServer) - .get("/socket.io/") - .query({ transport: "polling", EIO: 4, sid }) - .expect(200) - .end((err, res) => { - resolve(res.text); - }); - }); -}; +import { + createClient, + eioHandshake, + eioPoll, + eioPush, + getPort, +} from "./support/util"; describe("close", () => { it("should be able to close sio sending a srv", (done) => { diff --git a/test/connection-state-recovery.ts b/test/connection-state-recovery.ts new file mode 100644 index 0000000000..d43b5cae07 --- /dev/null +++ b/test/connection-state-recovery.ts @@ -0,0 +1,196 @@ +import { Server, Socket } from ".."; +import expect from "expect.js"; +import { waitFor, eioHandshake, eioPush, eioPoll } from "./support/util"; +import { createServer, Server as HttpServer } from "http"; + +async function init(httpServer: HttpServer, io: Server) { + // Engine.IO handshake + const sid = await eioHandshake(httpServer); + + // Socket.IO handshake + await eioPush(httpServer, sid, "40"); + const handshakeBody = await eioPoll(httpServer, sid); + + expect(handshakeBody.startsWith("40")).to.be(true); + + const handshake = JSON.parse(handshakeBody.substring(2)); + + expect(handshake.sid).to.not.be(undefined); + // in that case, the handshake also contains a private session ID + expect(handshake.pid).to.not.be(undefined); + + io.emit("hello"); + + const message = await eioPoll(httpServer, sid); + + expect(message.startsWith('42["hello"')).to.be(true); + + const offset = JSON.parse(message.substring(2))[1]; + // in that case, each packet also includes an offset in the data array + expect(offset).to.not.be(undefined); + + await eioPush(httpServer, sid, "1"); + + return [handshake.sid, handshake.pid, offset]; +} + +describe("connection state recovery", () => { + it("should restore session and missed packets", async () => { + const httpServer = createServer().listen(0); + const io = new Server(httpServer, { + connectionStateRecovery: {}, + }); + + let serverSocket; + + io.once("connection", (socket) => { + socket.join("room1"); + serverSocket = socket; + }); + + const [sid, pid, offset] = await init(httpServer, io); + + io.emit("hello1"); // broadcast + io.to("room1").emit("hello2"); // broadcast to room + serverSocket.emit("hello3"); // direct message + + const newSid = await eioHandshake(httpServer); + await eioPush( + httpServer, + newSid, + `40{"pid":"${pid}","offset":"${offset}"}` + ); + + const payload = await eioPoll(httpServer, newSid); + const packets = payload.split("\x1e"); + + expect(packets.length).to.eql(4); + + // note: EVENT packets are received before the CONNECT packet, which is a bit weird + // see also: https://github.com/socketio/socket.io-deno/commit/518f534e1c205b746b1cb21fe76b187dabc96f34 + expect(packets[0].startsWith('42["hello1"')).to.be(true); + expect(packets[1].startsWith('42["hello2"')).to.be(true); + expect(packets[2].startsWith('42["hello3"')).to.be(true); + expect(packets[3]).to.eql(`40{"sid":"${sid}","pid":"${pid}"}`); + + io.close(); + }); + + it("should restore rooms and data attributes", async () => { + const httpServer = createServer().listen(0); + const io = new Server(httpServer, { + connectionStateRecovery: {}, + }); + + io.once("connection", (socket) => { + expect(socket.recovered).to.eql(false); + + socket.join("room1"); + socket.join("room2"); + socket.data.foo = "bar"; + }); + + const [sid, pid, offset] = await init(httpServer, io); + + const newSid = await eioHandshake(httpServer); + + const [socket] = await Promise.all([ + waitFor(io, "connection"), + eioPush(httpServer, newSid, `40{"pid":"${pid}","offset":"${offset}"}`), + ]); + + expect(socket.id).to.eql(sid); + expect(socket.recovered).to.eql(true); + + expect(socket.rooms.has(socket.id)).to.eql(true); + expect(socket.rooms.has("room1")).to.eql(true); + expect(socket.rooms.has("room2")).to.eql(true); + + expect(socket.data.foo).to.eql("bar"); + + await eioPoll(httpServer, newSid); // drain buffer + io.close(); + }); + + it("should not run middlewares upon recovery by default", async () => { + const httpServer = createServer().listen(0); + const io = new Server(httpServer, { + connectionStateRecovery: {}, + }); + + const [_, pid, offset] = await init(httpServer, io); + + io.use((socket, next) => { + socket.data.middlewareWasCalled = true; + + next(); + }); + + const newSid = await eioHandshake(httpServer); + + const [socket] = await Promise.all([ + waitFor(io, "connection"), + eioPush(httpServer, newSid, `40{"pid":"${pid}","offset":"${offset}"}`), + ]); + + expect(socket.recovered).to.be(true); + expect(socket.data.middlewareWasCalled).to.be(undefined); + + await eioPoll(httpServer, newSid); // drain buffer + io.close(); + }); + + it("should run middlewares even upon recovery", async () => { + const httpServer = createServer().listen(0); + const io = new Server(httpServer, { + connectionStateRecovery: { + skipMiddlewares: false, + }, + }); + + const [_, pid, offset] = await init(httpServer, io); + + io.use((socket, next) => { + socket.data.middlewareWasCalled = true; + + next(); + }); + + const newSid = await eioHandshake(httpServer); + + const [socket] = await Promise.all([ + waitFor(io, "connection"), + eioPush(httpServer, newSid, `40{"pid":"${pid}","offset":"${offset}"}`), + ]); + + expect(socket.recovered).to.be(true); + expect(socket.data.middlewareWasCalled).to.be(true); + + await eioPoll(httpServer, newSid); // drain buffer + io.close(); + }); + + it("should fail to restore an unknown session", async () => { + const httpServer = createServer().listen(0); + const io = new Server(httpServer, { + connectionStateRecovery: {}, + }); + + // Engine.IO handshake + const sid = await eioHandshake(httpServer); + + // Socket.IO handshake + await eioPush(httpServer, sid, '40{"pid":"foo","offset":"bar"}'); + + const handshakeBody = await eioPoll(httpServer, sid); + + expect(handshakeBody.startsWith("40")).to.be(true); + + const handshake = JSON.parse(handshakeBody.substring(2)); + + expect(handshake.sid).to.not.eql("foo"); + expect(handshake.pid).to.not.eql("bar"); + + io.close(); + }); +}); diff --git a/test/index.ts b/test/index.ts index 880b81e52d..7e21589f62 100644 --- a/test/index.ts +++ b/test/index.ts @@ -20,4 +20,5 @@ describe("socket.io", () => { require("./socket-timeout"); require("./uws"); require("./utility-methods"); + require("./connection-state-recovery"); }); diff --git a/test/support/util.ts b/test/support/util.ts index f052f51658..4467159e3a 100644 --- a/test/support/util.ts +++ b/test/support/util.ts @@ -5,6 +5,7 @@ import { Socket as ClientSocket, SocketOptions, } from "socket.io-client"; +import request from "supertest"; const expect = require("expect.js"); const i = expect.stringify; @@ -73,8 +74,46 @@ export function createPartialDone(count: number, done: (err?: Error) => void) { }; } -export function waitFor(emitter, event) { - return new Promise((resolve) => { +export function waitFor(emitter, event) { + return new Promise((resolve) => { emitter.once(event, resolve); }); } + +// TODO: update superagent as latest release now supports promises +export function eioHandshake(httpServer): Promise { + return new Promise((resolve) => { + request(httpServer) + .get("/socket.io/") + .query({ transport: "polling", EIO: 4 }) + .end((err, res) => { + const sid = JSON.parse(res.text.substring(1)).sid; + resolve(sid); + }); + }); +} + +export function eioPush(httpServer, sid: string, body: string): Promise { + return new Promise((resolve) => { + request(httpServer) + .post("/socket.io/") + .send(body) + .query({ transport: "polling", EIO: 4, sid }) + .expect(200) + .end(() => { + resolve(); + }); + }); +} + +export function eioPoll(httpServer, sid): Promise { + return new Promise((resolve) => { + request(httpServer) + .get("/socket.io/") + .query({ transport: "polling", EIO: 4, sid }) + .expect(200) + .end((err, res) => { + resolve(res.text); + }); + }); +}