diff --git a/packages/client/lib/RESP/encoder.ts b/packages/client/lib/RESP/encoder.ts index af857711dc..854bedb60a 100644 --- a/packages/client/lib/RESP/encoder.ts +++ b/packages/client/lib/RESP/encoder.ts @@ -2,7 +2,7 @@ import { RedisArgument } from './types'; const CRLF = '\r\n'; -export default function encodeCommand(args: Array): Array { +export default function encodeCommand(args: ReadonlyArray): Array { const toWrite: Array = []; let strings = '*' + args.length + CRLF; diff --git a/packages/client/lib/client/commands-queue.ts b/packages/client/lib/client/commands-queue.ts index a4029779fc..43faae8330 100644 --- a/packages/client/lib/client/commands-queue.ts +++ b/packages/client/lib/client/commands-queue.ts @@ -1,8 +1,8 @@ import { SinglyLinkedList, DoublyLinkedNode, DoublyLinkedList } from './linked-list'; import encodeCommand from '../RESP/encoder'; import { Decoder, PUSH_TYPE_MAPPING, RESP_TYPES } from '../RESP/decoder'; -import { CommandArguments, TypeMapping, ReplyUnion, RespVersions } from '../RESP/types'; -import { ChannelListeners, PubSub, PubSubCommand, PubSubListener, PubSubType, PubSubTypeListeners } from './pub-sub'; +import { TypeMapping, ReplyUnion, RespVersions, RedisArgument } from '../RESP/types'; +import { COMMANDS, ChannelListeners, PUBSUB_TYPE, PubSub, PubSubCommand, PubSubListener, PubSubType, PubSubTypeListeners } from './pub-sub'; import { AbortError, ErrorReply } from '../errors'; import { MonitorCallback } from '.'; @@ -17,7 +17,7 @@ export interface CommandOptions { } export interface CommandToWrite extends CommandWaitingForReply { - args: CommandArguments; + args: ReadonlyArray; chainId: symbol | undefined; abort: { signal: AbortSignal; @@ -51,6 +51,7 @@ export default class RedisCommandsQueue { #chainInExecution: symbol | undefined; readonly decoder; readonly #pubSub = new PubSub(); + readonly #pushHandlers: Map) => unknown>> = new Map(); get isPubSubActive() { return this.#pubSub.isActive; @@ -64,6 +65,17 @@ export default class RedisCommandsQueue { this.#respVersion = respVersion; this.#maxLength = maxLength; this.#onShardedChannelMoved = onShardedChannelMoved; + + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.CHANNELS].message.toString(), this.#pubSub.handleMessageReplyChannel.bind(this.#pubSub)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.CHANNELS].subscribe.toString(), this.#handleStatusReply.bind(this)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.CHANNELS].unsubscribe.toString(), this.#handleStatusReply.bind(this)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.PATTERNS].message.toString(), this.#pubSub.handleMessageReplyPattern.bind(this.#pubSub)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.PATTERNS].subscribe.toString(), this.#handleStatusReply.bind(this)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.PATTERNS].unsubscribe.toString(), this.#handleStatusReply.bind(this)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.SHARDED].message.toString(), this.#pubSub.handleMessageReplySharded.bind(this.#pubSub)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.SHARDED].subscribe.toString(), this.#handleStatusReply.bind(this)); + this.#addPushHandler(COMMANDS[PUBSUB_TYPE.SHARDED].unsubscribe.toString(), this.#handleShardedUnsubscribe.bind(this)); + this.decoder = this.#initiateDecoder(); } @@ -75,28 +87,68 @@ export default class RedisCommandsQueue { this.#waitingForReply.shift()!.reject(err); } - #onPush(push: Array) { - // TODO: type - if (this.#pubSub.handleMessageReply(push)) return true; - - const isShardedUnsubscribe = PubSub.isShardedUnsubscribe(push); - if (isShardedUnsubscribe && !this.#waitingForReply.length) { + #handleStatusReply(push: ReadonlyArray) { + const head = this.#waitingForReply.head!.value; + if ( + (Number.isNaN(head.channelsCounter!) && push[2] === 0) || + --head.channelsCounter! === 0 + ) { + this.#waitingForReply.shift()!.resolve(); + } + } + + #handleShardedUnsubscribe(push: ReadonlyArray) { + if (!this.#waitingForReply.length) { const channel = push[1].toString(); this.#onShardedChannelMoved( channel, this.#pubSub.removeShardedListeners(channel) ); - return true; - } else if (isShardedUnsubscribe || PubSub.isStatusReply(push)) { - const head = this.#waitingForReply.head!.value; - if ( - (Number.isNaN(head.channelsCounter!) && push[2] === 0) || - --head.channelsCounter! === 0 - ) { - this.#waitingForReply.shift()!.resolve(); + } else { + this.#handleStatusReply(push); + } + } + + #addPushHandler(messageType: string, handler: (pushMsg: ReadonlyArray) => unknown) { + let handlerMap = this.#pushHandlers.get(messageType); + if (handlerMap === undefined) { + handlerMap = new Map(); + this.#pushHandlers.set(messageType, handlerMap); + } + + const symbol = Symbol(messageType); + handlerMap.set(symbol, handler); + + return symbol; + } + + addPushHandler(messageType: string, handler: (pushMsg: ReadonlyArray) => unknown) { + if (this.#respVersion !== 3) throw new Error("cannot add push handlers to resp2 clients") + + return this.#addPushHandler(messageType, handler); + } + + removePushHandler(symbol: Symbol) { + const handlers = this.#pushHandlers.get(symbol.description!); + if (handlers) { + handlers.delete(symbol); + if (handlers.size === 0) { + this.#pushHandlers.delete(symbol.description!); } + } + } + + #onPush(push: Array) { + const handlers = this.#pushHandlers.get(push[0].toString()); + if (handlers) { + for (const handler of handlers.values()) { + handler(push); + } + return true; } + + return false; } #getTypeMapping() { @@ -108,16 +160,14 @@ export default class RedisCommandsQueue { onReply: reply => this.#onReply(reply), onErrorReply: err => this.#onErrorReply(err), onPush: push => { - if (!this.#onPush(push)) { - - } + return this.#onPush(push); }, getTypeMapping: () => this.#getTypeMapping() }); } addCommand( - args: CommandArguments, + args: ReadonlyArray, options?: CommandOptions ): Promise { if (this.#maxLength && this.#toWrite.length + this.#waitingForReply.length >= this.#maxLength) { @@ -346,7 +396,7 @@ export default class RedisCommandsQueue { *commandsToWrite() { let toSend = this.#toWrite.shift(); while (toSend) { - let encoded: CommandArguments; + let encoded: ReadonlyArray try { encoded = encodeCommand(toSend.args); } catch (err) { diff --git a/packages/client/lib/client/index.spec.ts b/packages/client/lib/client/index.spec.ts index 2fd689b9d7..8f5852af98 100644 --- a/packages/client/lib/client/index.spec.ts +++ b/packages/client/lib/client/index.spec.ts @@ -9,6 +9,7 @@ import { MATH_FUNCTION, loadMathFunction } from '../commands/FUNCTION_LOAD.spec' import { RESP_TYPES } from '../RESP/decoder'; import { BlobStringReply, NumberReply } from '../RESP/types'; import { SortedSetMember } from '../commands/generic-transformers'; +import { COMMANDS, PUBSUB_TYPE } from './pub-sub'; export const SQUARE_SCRIPT = defineScript({ SCRIPT: @@ -769,4 +770,39 @@ describe('Client', () => { } }, GLOBAL.SERVERS.OPEN); }); + + describe('Push Handlers', () => { + testUtils.testWithClient('RESP3: add/remove invalidate handler, and validate its called', async client => { + const key = 'x' + + let nodeResolve; + + const promise = new Promise((res) => { + nodeResolve = res; + }); + + const symbol = client.addPushHandler("invalidate", (push: ReadonlyArray) => { + assert.equal(push[0].toString(), "invalidate"); + assert.equal(push[1].length, 1); + assert.equal(push[1].length, 1); + assert.equal(push[1][0].toString(), key); + // this test removing the handler, + // as flushAll in cleanup of test will issue a full invalidate, + // which would fail if this handler is called on it + client.removePushHandler(symbol); + nodeResolve(); + }) + + await client.sendCommand(['CLIENT', 'TRACKING', 'ON']); + await client.get(key); + await client.set(key, '1'); + + await promise; + }, { + ...GLOBAL.SERVERS.OPEN, + clientOptions: { + RESP: 3 + } + }); + }); }); diff --git a/packages/client/lib/client/index.ts b/packages/client/lib/client/index.ts index 3efa793eeb..43811fde9b 100644 --- a/packages/client/lib/client/index.ts +++ b/packages/client/lib/client/index.ts @@ -573,6 +573,14 @@ export default class RedisClient< return this as unknown as RedisClientType; } + addPushHandler(messageType: string, handler: (pushMsg: ReadonlyArray) => unknown) { + return this._self.#queue.addPushHandler(messageType, handler); + } + + removePushHandler(symbol: Symbol) { + this._self.#queue.removePushHandler(symbol); + } + sendCommand( args: Array, options?: CommandOptions diff --git a/packages/client/lib/client/pub-sub.ts b/packages/client/lib/client/pub-sub.ts index 1387aea841..c5b0ba2409 100644 --- a/packages/client/lib/client/pub-sub.ts +++ b/packages/client/lib/client/pub-sub.ts @@ -11,7 +11,7 @@ export type PUBSUB_TYPE = typeof PUBSUB_TYPE; export type PubSubType = PUBSUB_TYPE[keyof PUBSUB_TYPE]; -const COMMANDS = { +export const COMMANDS = { [PUBSUB_TYPE.CHANNELS]: { subscribe: Buffer.from('subscribe'), unsubscribe: Buffer.from('unsubscribe'), @@ -344,28 +344,37 @@ export class PubSub { return commands; } + handleMessageReplyChannel(push: ReadonlyArray) { + this.#emitPubSubMessage( + PUBSUB_TYPE.CHANNELS, + push[2], + push[1] + ); + } + + handleMessageReplyPattern(push: ReadonlyArray) { + this.#emitPubSubMessage( + PUBSUB_TYPE.PATTERNS, + push[3], + push[2], + push[1] + ); + } + + handleMessageReplySharded(push: ReadonlyArray) { + this.#emitPubSubMessage( + PUBSUB_TYPE.SHARDED, + push[2], + push[1] + ); + } + handleMessageReply(reply: Array): boolean { if (COMMANDS[PUBSUB_TYPE.CHANNELS].message.equals(reply[0])) { - this.#emitPubSubMessage( - PUBSUB_TYPE.CHANNELS, - reply[2], - reply[1] - ); return true; } else if (COMMANDS[PUBSUB_TYPE.PATTERNS].message.equals(reply[0])) { - this.#emitPubSubMessage( - PUBSUB_TYPE.PATTERNS, - reply[3], - reply[2], - reply[1] - ); return true; } else if (COMMANDS[PUBSUB_TYPE.SHARDED].message.equals(reply[0])) { - this.#emitPubSubMessage( - PUBSUB_TYPE.SHARDED, - reply[2], - reply[1] - ); return true; } diff --git a/packages/client/lib/client/socket.ts b/packages/client/lib/client/socket.ts index dcadad4c3d..384dd7364e 100644 --- a/packages/client/lib/client/socket.ts +++ b/packages/client/lib/client/socket.ts @@ -271,7 +271,7 @@ export default class RedisSocket extends EventEmitter { }); } - write(iterable: Iterable>) { + write(iterable: Iterable>) { if (!this.#socket) return; this.#socket.cork();