diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 6edadccfb..0c3bbf979 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -184,7 +184,11 @@ jobs: npm install npm link streamr-client ./gradlew fatjar - - name: run-client-testing - timeout-minutes: 10 + - uses: nick-invision/retry@v2 + name: run-client-testing working-directory: streamr-client-testing - run: java -jar build/libs/client_testing-1.0-SNAPSHOT.jar -s $TEST_NAME -c config/$CONFIG_NAME.conf -n $NUM_MESSAGES + with: + max_attempts: 2 + timeout_minutes: 3 + retry_on: error + command: java -jar build/libs/client_testing-1.0-SNAPSHOT.jar -s $TEST_NAME -c config/$CONFIG_NAME.conf -n $NUM_MESSAGES diff --git a/src/Config.ts b/src/Config.ts index 7f156ac27..8a0cf7be3 100644 --- a/src/Config.ts +++ b/src/Config.ts @@ -20,6 +20,8 @@ export type EthereumConfig = ExternalProvider|JsonRpcFetchFunc * @category Important */ export type StrictStreamrClientOptions = { + /** Custom human-readable debug id for client. Used in logging. Unique id will be generated regardless. */ + id?: string, /** * Authentication: identity used by this StreamrClient instance. * Can contain member privateKey or (window.)ethereum @@ -169,8 +171,8 @@ export default function ClientConfig(opts: StreamrClientOptions = {}) { ...opts.dataUnion }, cache: { - ...opts.cache, ...STREAM_CLIENT_DEFAULTS.cache, + ...opts.cache, } // NOTE: sidechain is not merged with the defaults } diff --git a/src/Ethereum.js b/src/Ethereum.js index f159ec236..6d36b4d6e 100644 --- a/src/Ethereum.js +++ b/src/Ethereum.js @@ -20,8 +20,8 @@ export default class StreamrEthereum { const key = auth.privateKey const address = getAddress(computeAddress(key)) this._getAddress = async () => address - this.getSigner = () => new Wallet(key, this.getMainnetProvider()) - this.getSidechainSigner = async () => new Wallet(key, this.getSidechainProvider()) + this._getSigner = () => new Wallet(key, this.getMainnetProvider()) + this._getSidechainSigner = async () => new Wallet(key, this.getSidechainProvider()) } else if (auth.ethereum) { this._getAddress = async () => { try { @@ -61,6 +61,10 @@ export default class StreamrEthereum { } } + canEncrypt() { + return !!(this._getAddress && this._getSigner) + } + async getAddress() { if (!this._getAddress) { // _getAddress is assigned in constructor diff --git a/src/StreamrClient.ts b/src/StreamrClient.ts index 7ae237a1b..cf77534a1 100644 --- a/src/StreamrClient.ts +++ b/src/StreamrClient.ts @@ -188,7 +188,7 @@ export class StreamrClient extends EventEmitter { // eslint-disable-line no-rede // TODO annotate connection parameter as internal parameter if possible? constructor(options: StreamrClientOptions = {}, connection?: StreamrConnection) { super() - this.id = counterId(`${this.constructor.name}:${uid}`) + this.id = counterId(`${this.constructor.name}:${uid}${options.id || ''}`) this.debug = Debug(this.id) this.options = Config(options) @@ -217,9 +217,9 @@ export class StreamrClient extends EventEmitter { // eslint-disable-line no-rede .on('disconnected', this.onConnectionDisconnected) .on('error', this.onConnectionError) + this.ethereum = new StreamrEthereum(this) this.publisher = Publisher(this) this.subscriber = new Subscriber(this) - this.ethereum = new StreamrEthereum(this) Plugin(this, new StreamEndpoints(this)) Plugin(this, new LoginEndpoints(this)) @@ -357,10 +357,14 @@ export class StreamrClient extends EventEmitter { // eslint-disable-line no-rede let subTask: Todo let sub: Todo const hasResend = !!(opts.resend || opts.from || opts.to || opts.last) - const onEnd = () => { + const onEnd = (err?: Error) => { if (sub && typeof onMessage === 'function') { sub.off('message', onMessage) } + + if (err) { + throw err + } } if (hasResend) { @@ -429,6 +433,13 @@ export class StreamrClient extends EventEmitter { // eslint-disable-line no-rede return this.getAddress() } + /** + * True if authenticated with private key/ethereum provider + */ + canEncrypt() { + return this.ethereum.canEncrypt() + } + /** * Get token balance in "wei" (10^-18 parts) for given address */ diff --git a/src/publish/Encrypt.ts b/src/publish/Encrypt.ts index b35e265b4..f3a752708 100644 --- a/src/publish/Encrypt.ts +++ b/src/publish/Encrypt.ts @@ -10,15 +10,27 @@ const { StreamMessage } = MessageLayer type PublisherKeyExhangeAPI = ReturnType export default function Encrypt(client: StreamrClient) { - const publisherKeyExchange = PublisherKeyExhange(client, { - groupKeys: { - ...client.options.groupKeys, + let publisherKeyExchange: ReturnType + + function getPublisherKeyExchange() { + if (!publisherKeyExchange) { + publisherKeyExchange = PublisherKeyExhange(client, { + groupKeys: { + ...client.options.groupKeys, + } + }) } - }) + return publisherKeyExchange + } + async function encrypt(streamMessage: MessageLayer.StreamMessage, stream: Stream) { + if (!client.canEncrypt()) { + return + } + if ( - !publisherKeyExchange.hasAnyGroupKey(stream.id) - && !stream.requireEncryptedData + !stream.requireEncryptedData + && !getPublisherKeyExchange().hasAnyGroupKey(stream.id) ) { // not needed return @@ -27,19 +39,23 @@ export default function Encrypt(client: StreamrClient) { if (streamMessage.messageType !== StreamMessage.MESSAGE_TYPES.MESSAGE) { return } - const groupKey = await publisherKeyExchange.useGroupKey(stream.id) + const groupKey = await getPublisherKeyExchange().useGroupKey(stream.id) await EncryptionUtil.encryptStreamMessage(streamMessage, groupKey) } return Object.assign(encrypt, { setNextGroupKey(...args: Parameters) { - return publisherKeyExchange.setNextGroupKey(...args) + return getPublisherKeyExchange().setNextGroupKey(...args) }, rotateGroupKey(...args: Parameters) { - return publisherKeyExchange.rotateGroupKey(...args) + return getPublisherKeyExchange().rotateGroupKey(...args) + }, + start() { + return getPublisherKeyExchange().start() }, stop() { - return publisherKeyExchange.stop() + if (!publisherKeyExchange) { return Promise.resolve() } + return getPublisherKeyExchange().stop() } }) } diff --git a/src/publish/index.js b/src/publish/index.js index 292065a08..3ce58d644 100644 --- a/src/publish/index.js +++ b/src/publish/index.js @@ -122,6 +122,7 @@ function getCreateStreamMessage(client) { [streamId, streamPartition, publisherId, msgChainId].join('|') ), ...cacheOptions, + maxAge: undefined }), { clear() { mem.clear(getMsgChainer) @@ -188,6 +189,9 @@ function getCreateStreamMessage(client) { rotateGroupKey(maybeStreamId) { return encrypt.rotateGroupKey(maybeStreamId) }, + startKeyExchange() { + return encrypt.start() + }, clear() { computeStreamPartition.clear() getMsgChainer.clear() @@ -300,6 +304,9 @@ export default function Publisher(client) { throw error } }, + async startKeyExchange() { + return createStreamMessage.startKeyExchange() + }, async stop() { sendQueue.clear() createStreamMessage.clear() diff --git a/src/stream/Encryption.js b/src/stream/Encryption.ts similarity index 60% rename from src/stream/Encryption.js rename to src/stream/Encryption.ts index d3b2771f1..161c484ec 100644 --- a/src/stream/Encryption.js +++ b/src/stream/Encryption.ts @@ -1,5 +1,6 @@ import crypto from 'crypto' import util from 'util' +import { O } from 'ts-toolbelt' // this is shimmed out for actual browser build allows us to run tests in node against browser API import { Crypto } from 'node-webcrypto-ossl' @@ -8,8 +9,11 @@ import { MessageLayer } from 'streamr-client-protocol' import { uuid } from '../utils' +const { StreamMessage, EncryptedGroupKey } = MessageLayer + export class UnableToDecryptError extends Error { - constructor(message = '', streamMessage) { + streamMessage: MessageLayer.StreamMessage + constructor(message = '', streamMessage: MessageLayer.StreamMessage) { super(`Unable to decrypt. ${message} ${util.inspect(streamMessage)}`) this.streamMessage = streamMessage if (Error.captureStackTrace) { @@ -19,7 +23,8 @@ export class UnableToDecryptError extends Error { } class InvalidGroupKeyError extends Error { - constructor(message, groupKey) { + groupKey: GroupKey | any + constructor(message: string, groupKey?: GroupKey) { super(message) this.groupKey = groupKey if (Error.captureStackTrace) { @@ -28,36 +33,75 @@ class InvalidGroupKeyError extends Error { } } -export class GroupKey { +type GroupKeyObject = { + id: string, + hex: string, + data: Uint8Array, +} + +type GroupKeyProps = { + groupKeyId: string, + groupKeyHex: string, + groupKeyData: Uint8Array, +} + +function GroupKeyObjectFromProps(data: GroupKeyProps | GroupKeyObject) { + if ('groupKeyId' in data) { + return { + id: data.groupKeyId, + hex: data.groupKeyHex, + data: data.groupKeyData, + } + } + + return data +} + +interface GroupKey extends GroupKeyObject {} + +// eslint-disable-next-line no-redeclare +class GroupKey { static InvalidGroupKeyError = InvalidGroupKeyError - static validate(maybeGroupKey) { + static validate(maybeGroupKey: GroupKey) { if (!maybeGroupKey) { - throw new InvalidGroupKeyError(`value must be a ${this.name}: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`value must be a ${this.name}: ${util.inspect(maybeGroupKey)}`, maybeGroupKey) } if (!(maybeGroupKey instanceof this)) { - throw new InvalidGroupKeyError(`value must be a ${this.name}: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`value must be a ${this.name}: ${util.inspect(maybeGroupKey)}`, maybeGroupKey) } if (!maybeGroupKey.id || typeof maybeGroupKey.id !== 'string') { - throw new InvalidGroupKeyError(`${this.name} id must be a string: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`${this.name} id must be a string: ${util.inspect(maybeGroupKey)}`, maybeGroupKey) + } + + if (maybeGroupKey.id.includes('---BEGIN')) { + throw new InvalidGroupKeyError( + `${this.name} public/private key is not a valid group key id: ${util.inspect(maybeGroupKey)}`, + maybeGroupKey + ) } if (!maybeGroupKey.data || !Buffer.isBuffer(maybeGroupKey.data)) { - throw new InvalidGroupKeyError(`${this.name} data must be a buffer: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`${this.name} data must be a Buffer: ${util.inspect(maybeGroupKey)}`, maybeGroupKey) } if (!maybeGroupKey.hex || typeof maybeGroupKey.hex !== 'string') { - throw new InvalidGroupKeyError(`${this.name} hex must be a string: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`${this.name} hex must be a string: ${util.inspect(maybeGroupKey)}`, maybeGroupKey) } if (maybeGroupKey.data.length !== 32) { - throw new InvalidGroupKeyError(`Group key must have a size of 256 bits, not ${maybeGroupKey.data.length * 8}`) + throw new InvalidGroupKeyError(`Group key must have a size of 256 bits, not ${maybeGroupKey.data.length * 8}`, maybeGroupKey) } + } - constructor(groupKeyId, groupKeyBufferOrHexString) { + id: string + hex: string + data: Uint8Array + + constructor(groupKeyId: string, groupKeyBufferOrHexString: Uint8Array | string) { this.id = groupKeyId if (!groupKeyId) { throw new InvalidGroupKeyError(`groupKeyId must not be falsey ${util.inspect(groupKeyId)}`) @@ -75,10 +119,11 @@ export class GroupKey { this.hex = Buffer.from(this.data).toString('hex') } - this.constructor.validate(this) + // eslint-disable-next-line no-extra-semi + ;(this.constructor as typeof GroupKey).validate(this) } - equals(other) { + equals(other: GroupKey) { if (!(other instanceof GroupKey)) { return false } @@ -90,12 +135,20 @@ export class GroupKey { return this.id } + toArray() { + return [this.id, this.hex] + } + + serialize() { + return JSON.stringify(this.toArray()) + } + static generate(id = uuid('GroupKey')) { const keyBytes = crypto.randomBytes(32) return new GroupKey(id, keyBytes) } - static from(maybeGroupKey) { + static from(maybeGroupKey: GroupKey | GroupKeyObject | ConstructorParameters) { if (!maybeGroupKey || typeof maybeGroupKey !== 'object') { throw new InvalidGroupKeyError(`Group key must be object ${util.inspect(maybeGroupKey)}`) } @@ -105,29 +158,35 @@ export class GroupKey { } try { - return new GroupKey(maybeGroupKey.id || maybeGroupKey.groupKeyId, maybeGroupKey.hex || maybeGroupKey.data || maybeGroupKey.groupKeyHex) + if (Array.isArray(maybeGroupKey)) { + return new GroupKey(maybeGroupKey[0], maybeGroupKey[1]) + } + + const groupKeyObj = GroupKeyObjectFromProps(maybeGroupKey) + return new GroupKey(groupKeyObj.id, groupKeyObj.hex || groupKeyObj.data) } catch (err) { if (err instanceof InvalidGroupKeyError) { // wrap err with logging of original object - throw new InvalidGroupKeyError(`${err.message}. From: ${util.inspect(maybeGroupKey)}`) + throw new InvalidGroupKeyError(`${err.stack}. From: ${util.inspect(maybeGroupKey)}`) } throw err } } } -const { StreamMessage } = MessageLayer +export { GroupKey } -function ab2str(buf) { - return String.fromCharCode.apply(null, new Uint8Array(buf)) +function ab2str(...args: any[]) { + // @ts-ignore + return String.fromCharCode.apply(null, new Uint8Array(...args)) } // shim browser btoa for node -function btoa(str) { - if (global.btoa) { return global.btoa(str) } +function btoa(str: string | Uint8Array) { + if (global.btoa) { return global.btoa(str as string) } let buffer - if (str instanceof Buffer) { + if (Buffer.isBuffer(str)) { buffer = str } else { buffer = Buffer.from(str.toString(), 'binary') @@ -136,7 +195,7 @@ function btoa(str) { return buffer.toString('base64') } -async function exportCryptoKey(key, { isPrivate = false } = {}) { +async function exportCryptoKey(key: CryptoKey, { isPrivate = false } = {}) { const WebCrypto = new Crypto() const keyType = isPrivate ? 'pkcs8' : 'spki' const exported = await WebCrypto.subtle.exportKey(keyType, key) @@ -149,28 +208,26 @@ async function exportCryptoKey(key, { isPrivate = false } = {}) { // put all static functions into EncryptionUtilBase, with exception of create, // so it's clearer what the static & instance APIs look like class EncryptionUtilBase { - static validatePublicKey(publicKey) { - if (typeof publicKey !== 'string' || !publicKey.startsWith('-----BEGIN PUBLIC KEY-----') - || !publicKey.endsWith('-----END PUBLIC KEY-----\n')) { - throw new Error('"publicKey" must be a PKCS#8 RSA public key as a string in the PEM format') + static validatePublicKey(publicKey: crypto.KeyLike) { + const keyString = typeof publicKey === 'string' ? publicKey : publicKey.toString('utf8') + if (typeof keyString !== 'string' || !keyString.startsWith('-----BEGIN PUBLIC KEY-----') + || !keyString.endsWith('-----END PUBLIC KEY-----\n')) { + throw new Error('"publicKey" must be a PKCS#8 RSA public key in the PEM format') } } - static validatePrivateKey(privateKey) { - if (typeof privateKey !== 'string' || !privateKey.startsWith('-----BEGIN PRIVATE KEY-----') - || !privateKey.endsWith('-----END PRIVATE KEY-----\n')) { - throw new Error('"privateKey" must be a PKCS#8 RSA public key as a string in the PEM format') + static validatePrivateKey(privateKey: crypto.KeyLike) { + const keyString = typeof privateKey === 'string' ? privateKey : privateKey.toString('utf8') + if (typeof keyString !== 'string' || !keyString.startsWith('-----BEGIN PRIVATE KEY-----') + || !keyString.endsWith('-----END PRIVATE KEY-----\n')) { + throw new Error('"privateKey" must be a PKCS#8 RSA public key in the PEM format') } } - static validateGroupKey(groupKey) { - return GroupKey.validate(groupKey) - } - /* * Returns a Buffer or a hex String */ - static encryptWithPublicKey(plaintextBuffer, publicKey, outputInHex = false) { + static encryptWithPublicKey(plaintextBuffer: Uint8Array, publicKey: crypto.KeyLike, outputInHex = false) { this.validatePublicKey(publicKey) const ciphertextBuffer = crypto.publicEncrypt(publicKey, plaintextBuffer) if (outputInHex) { @@ -182,33 +239,45 @@ class EncryptionUtilBase { /* * Both 'data' and 'groupKey' must be Buffers. Returns a hex string without the '0x' prefix. */ - static encrypt(data, groupKey) { + static encrypt(data: Uint8Array, groupKey: GroupKey) { GroupKey.validate(groupKey) const iv = crypto.randomBytes(16) // always need a fresh IV when using CTR mode const cipher = crypto.createCipheriv('aes-256-ctr', groupKey.data, iv) - return hexlify(iv).slice(2) + cipher.update(data, null, 'hex') + cipher.final('hex') + + return hexlify(iv).slice(2) + cipher.update(data, undefined, 'hex') + cipher.final('hex') } /* - * 'ciphertext' must be a hex string (without '0x' prefix), 'groupKey' must be a Buffer. Returns a Buffer. + * 'ciphertext' must be a hex string (without '0x' prefix), 'groupKey' must be a GroupKey. Returns a Buffer. */ - static decrypt(ciphertext, groupKey) { + static decrypt(ciphertext: string, groupKey: GroupKey) { GroupKey.validate(groupKey) const iv = arrayify(`0x${ciphertext.slice(0, 32)}`) const decipher = crypto.createDecipheriv('aes-256-ctr', groupKey.data, iv) - return Buffer.concat([decipher.update(ciphertext.slice(32), 'hex', null), decipher.final(null)]) + return Buffer.concat([decipher.update(ciphertext.slice(32), 'hex'), decipher.final()]) } /* * Sets the content of 'streamMessage' with the encryption result of the old content with 'groupKey'. */ - static encryptStreamMessage(streamMessage, groupKey) { + static encryptStreamMessage(streamMessage: MessageLayer.StreamMessage, groupKey: GroupKey, nextGroupKey?: GroupKey) { GroupKey.validate(groupKey) /* eslint-disable no-param-reassign */ streamMessage.encryptionType = StreamMessage.ENCRYPTION_TYPES.AES streamMessage.groupKeyId = groupKey.id + + if (nextGroupKey) { + GroupKey.validate(nextGroupKey) + // @ts-expect-error + streamMessage.newGroupKey = nextGroupKey + } + streamMessage.serializedContent = this.encrypt(Buffer.from(streamMessage.getSerializedContent(), 'utf8'), groupKey) + if (nextGroupKey) { + GroupKey.validate(nextGroupKey) + streamMessage.newGroupKey = new EncryptedGroupKey(nextGroupKey.id, this.encrypt(nextGroupKey.data, groupKey)) + } streamMessage.parsedContent = undefined /* eslint-enable no-param-reassign */ } @@ -220,7 +289,7 @@ class EncryptionUtilBase { * message content and returns null. */ - static decryptStreamMessage(streamMessage, groupKey) { + static decryptStreamMessage(streamMessage: MessageLayer.StreamMessage, groupKey: GroupKey) { if ((streamMessage.encryptionType !== StreamMessage.ENCRYPTION_TYPES.AES)) { return null } @@ -239,13 +308,34 @@ class EncryptionUtilBase { streamMessage.serializedContent = serializedContent } catch (err) { streamMessage.encryptionType = StreamMessage.ENCRYPTION_TYPES.AES - throw new UnableToDecryptError(err.message, streamMessage) + throw new UnableToDecryptError(err.stack, streamMessage) + } + + try { + const { newGroupKey } = streamMessage + if (newGroupKey) { + // newGroupKey should be EncryptedGroupKey | GroupKey, but GroupKey is not defined in protocol + // @ts-expect-error + streamMessage.newGroupKey = GroupKey.from([ + newGroupKey.groupKeyId, + this.decrypt(newGroupKey.encryptedGroupKeyHex, groupKey) + ]) + } + } catch (err) { + streamMessage.encryptionType = StreamMessage.ENCRYPTION_TYPES.AES + throw new UnableToDecryptError('Could not decrypt new group key: ' + err.stack, streamMessage) } return null /* eslint-enable no-param-reassign */ } } +// after EncryptionUtil is ready +type InitializedEncryptionUtil = O.Overwrite + /** @internal */ export default class EncryptionUtil extends EncryptionUtilBase { /** @@ -253,15 +343,22 @@ export default class EncryptionUtil extends EncryptionUtilBase { * Convenience. */ - static async create(...args) { + static async create(...args: ConstructorParameters) { const encryptionUtil = new EncryptionUtil(...args) await encryptionUtil.onReady() return encryptionUtil } - constructor(options = {}) { - super(options) - if (options.privateKey && options.publicKey) { + privateKey + publicKey + private _generateKeyPairPromise: Promise | undefined + + constructor(options: { + privateKey: string, + publicKey: string, + } | {} = {}) { + super() + if ('privateKey' in options && 'publicKey' in options) { EncryptionUtil.validatePrivateKey(options.privateKey) EncryptionUtil.validatePublicKey(options.publicKey) this.privateKey = options.privateKey @@ -274,17 +371,14 @@ export default class EncryptionUtil extends EncryptionUtilBase { return this._generateKeyPair() } - isReady() { - return !!this.privateKey + isReady(this: EncryptionUtil): this is InitializedEncryptionUtil { + return !!(this.privateKey && this.publicKey) } // Returns a Buffer - decryptWithPrivateKey(ciphertext, isHexString = false) { + decryptWithPrivateKey(ciphertext: string | Uint8Array, isHexString = false) { if (!this.isReady()) { throw new Error('EncryptionUtil not ready.') } - let ciphertextBuffer = ciphertext - if (isHexString) { - ciphertextBuffer = arrayify(`0x${ciphertext}`) - } + const ciphertextBuffer = isHexString ? arrayify(`0x${ciphertext}`) : ciphertext as Uint8Array return crypto.privateDecrypt(this.privateKey, ciphertextBuffer) } @@ -302,7 +396,7 @@ export default class EncryptionUtil extends EncryptionUtilBase { } async __generateKeyPair() { - if (process.browser) { return this._keyPairBrowser() } + if (typeof window !== 'undefined') { return this._keyPairBrowser() } return this._keyPairServer() } diff --git a/src/stream/KeyExchange.js b/src/stream/KeyExchange.js index 159c5ea6e..8aee4be89 100644 --- a/src/stream/KeyExchange.js +++ b/src/stream/KeyExchange.js @@ -51,21 +51,24 @@ class InvalidContentTypeError extends Error { } */ -function getKeyExchangeStreamId(address) { +export function getKeyExchangeStreamId(address) { if (isKeyExchangeStream(address)) { return address // prevent ever double-handling } return `${KEY_EXCHANGE_STREAM_PREFIX}/${address.toLowerCase()}` } -function GroupKeyStore({ groupKeys }) { - const store = new Map(groupKeys) +function GroupKeyStore({ groupKeys = new Map() }) { + const store = new Map() + groupKeys.forEach((value, key) => { + store.set(key, value) + }) let currentGroupKeyId // current key id if any let nextGroupKey // key to use next, disappears if not actually used. store.forEach((groupKey) => { - GroupKey.validate(groupKey) + GroupKey.validate(GroupKey.from(groupKey)) // use last init key as current currentGroupKeyId = groupKey.id }) @@ -73,7 +76,7 @@ function GroupKeyStore({ groupKeys }) { function storeKey(groupKey) { GroupKey.validate(groupKey) if (store.has(groupKey.id)) { - const existingKey = store.get(groupKey.id) + const existingKey = GroupKey.from(store.get(groupKey.id)) if (!existingKey.equals(groupKey)) { throw new GroupKey.InvalidGroupKeyError( `Trying to add groupKey ${groupKey.id} but key exists & is not equivalent to new GroupKey: ${groupKey}.` @@ -90,12 +93,12 @@ function GroupKeyStore({ groupKeys }) { } return { - has(id) { - if (currentGroupKeyId === id) { return true } + has(groupKeyId) { + if (currentGroupKeyId === groupKeyId) { return true } - if (nextGroupKey && nextGroupKey.id === id) { return true } + if (nextGroupKey && nextGroupKey.id === groupKeyId) { return true } - return store.has(id) + return store.has(groupKeyId) }, isEmpty() { return !nextGroupKey && store.size === 0 @@ -115,10 +118,12 @@ function GroupKeyStore({ groupKeys }) { return this.useGroupKey() } - return store.get(currentGroupKeyId) + return this.get(currentGroupKeyId) }, - get(id) { - return store.get(id) + get(groupKeyId) { + const groupKey = store.get(groupKeyId) + if (!groupKey) { return undefined } + return GroupKey.from(groupKey) }, clear() { currentGroupKeyId = undefined @@ -158,11 +163,10 @@ function waitForSubMessage(sub, matchFn) { } sub.on('message', onMessage) sub.once('error', task.reject) - // eslint-disable-next-line promise/catch-or-return task.finally(() => { sub.off('message', onMessage) sub.off('error', task.reject) - }) + }).catch(() => {}) // prevent unhandled rejection return task } @@ -176,7 +180,9 @@ async function subscribeToKeyExchangeStream(client, onKeyExchangeMessage) { // subscribing to own keyexchange stream const publisherId = await client.getUserId() const streamId = getKeyExchangeStreamId(publisherId) - return client.subscribe(streamId, onKeyExchangeMessage) + const sub = await client.subscribe(streamId, onKeyExchangeMessage) + sub.on('error', () => {}) // errors should not shut down subscription + return sub } async function catchKeyExchangeError(client, streamMessage, fn) { @@ -197,7 +203,7 @@ async function catchKeyExchangeError(client, streamMessage, fn) { } async function PublisherKeyExhangeSubscription(client, getGroupKeyStore) { - async function onKeyExchangeMessage(parsedContent, streamMessage) { + async function onKeyExchangeMessage(_parsedContent, streamMessage) { return catchKeyExchangeError(client, streamMessage, async () => { if (streamMessage.messageType !== StreamMessage.MESSAGE_TYPES.GROUP_KEY_REQUEST) { return Promise.resolve() @@ -279,7 +285,7 @@ export function PublisherKeyExhange(client, { groupKeys = {} } = {}) { if (!sub) { return } const cancelTask = sub.cancel() sub = undefined - await cancelTask() + await cancelTask } } ], () => enabled) @@ -354,6 +360,7 @@ async function SubscriberKeyExhangeSubscription(client, getGroupKeyStore, encryp } sub = await subscribeToKeyExchangeStream(client, onKeyExchangeMessage) + sub.on('error', () => {}) // errors should not shut down subscription return sub } @@ -401,7 +408,7 @@ export function SubscriberKeyExchange(client, { groupKeys = {} } = {}) { cancelTask.then(responseTask.resolve).catch(responseTask.reject) return () => { - cancelTask.resolve({}) + cancelTask.resolve() } }, async () => { const msg = new GroupKeyRequest({ @@ -417,16 +424,17 @@ export function SubscriberKeyExchange(client, { groupKeys = {} } = {}) { response = undefined } }, async () => { - receivedGroupKeys = await getGroupKeysFromStreamMessage(response, encryptionUtil) + receivedGroupKeys = response ? await getGroupKeysFromStreamMessage(response, encryptionUtil) : [] return () => { receivedGroupKeys = [] } }, ], () => enabled && !done, { + id: `requestKeys.${requestId}`, onChange(isGoingUp) { if (!isGoingUp && cancelTask) { - cancelTask.resolve({}) + cancelTask.resolve() } } }) @@ -506,19 +514,20 @@ export function SubscriberKeyExchange(client, { groupKeys = {} } = {}) { const next = Scaffold([ async () => { - [sub] = await Promise.all([ - SubscriberKeyExhangeSubscription(client, getGroupKeyStore, encryptionUtil), - encryptionUtil.onReady(), - ]) + return encryptionUtil.onReady() + }, + async () => { + sub = await SubscriberKeyExhangeSubscription(client, getGroupKeyStore, encryptionUtil) return async () => { mem.clear(getGroupKeyStore) if (!sub) { return } const cancelTask = sub.cancel() sub = undefined - await cancelTask() + await cancelTask } } ], () => enabled, { + id: `SubscriberKeyExhangeSubscription.${client.id}`, async onDone() { // clean up requestKey if (requestKeys.step) { diff --git a/src/subscribe/Decrypt.js b/src/subscribe/Decrypt.js index c2c794927..25ab37f31 100644 --- a/src/subscribe/Decrypt.js +++ b/src/subscribe/Decrypt.js @@ -1,6 +1,5 @@ import { MessageLayer } from 'streamr-client-protocol' -import PushQueue from '../utils/PushQueue' import EncryptionUtil, { UnableToDecryptError } from '../stream/Encryption' import { SubscriberKeyExchange } from '../stream/KeyExchange' @@ -26,29 +25,33 @@ export default function Decrypt(client, options = {}) { } }) - async function* decrypt(src, onError = async (err) => { throw err }) { - yield* PushQueue.transform(src, async (streamMessage) => { + async function* decrypt(src, onError = async () => {}) { + for await (const streamMessage of src) { if (!streamMessage.groupKeyId) { - return streamMessage + yield streamMessage + continue } if (streamMessage.encryptionType !== StreamMessage.ENCRYPTION_TYPES.AES) { - return streamMessage + yield streamMessage + continue } try { - const groupKey = await requestKey(streamMessage) + const groupKey = await requestKey(streamMessage).catch((err) => { + throw new UnableToDecryptError(`Could not get GroupKey: ${streamMessage.groupKeyId} – ${err.message}`, streamMessage) + }) + if (!groupKey) { throw new UnableToDecryptError(`Group key not found: ${streamMessage.groupKeyId}`, streamMessage) } await EncryptionUtil.decryptStreamMessage(streamMessage, groupKey) - return streamMessage } catch (err) { await onError(err, streamMessage) + } finally { + yield streamMessage } - - return streamMessage - }) + } } return Object.assign(decrypt, { diff --git a/src/subscribe/Validator.js b/src/subscribe/Validator.js index 0a30b3c27..af8ff7b9b 100644 --- a/src/subscribe/Validator.js +++ b/src/subscribe/Validator.js @@ -49,7 +49,8 @@ export default function Validator(client, opts) { const validate = pOrderedResolve(async (msg) => { if (msg.messageType === StreamMessage.MESSAGE_TYPES.GROUP_KEY_ERROR_RESPONSE) { const res = GroupKeyErrorResponse.fromArray(msg.getParsedContent()) - const err = new ValidationError(`GroupKeyErrorResponse: ${res.errorMessage}`, msg) + const err = new ValidationError(`${client.id} GroupKeyErrorResponse: ${res.errorMessage}`, msg) + err.streamMessage = msg err.code = res.errorCode throw err } @@ -64,7 +65,14 @@ export default function Validator(client, opts) { } // In all other cases validate using the validator - await validator.validate(msg) // will throw with appropriate validation failure + // will throw with appropriate validation failure + await validator.validate(msg).catch((err) => { + if (!err.streamMessage) { + err.streamMessage = msg // eslint-disable-line no-param-reassign + } + throw err + }) + return msg }) diff --git a/src/subscribe/index.ts b/src/subscribe/index.ts index e2ae9fc8a..10ab40672 100644 --- a/src/subscribe/index.ts +++ b/src/subscribe/index.ts @@ -10,9 +10,15 @@ import MessagePipeline from './pipeline' import Validator from './Validator' import messageStream from './messageStream' import resendStream from './resendStream' -import { MaybeAsync, Todo } from '../types' +import { Todo, MaybeAsync } from '../types' import StreamrClient, { StreamPartDefinition, SubscribeOptions } from '..' +async function defaultOnFinally(err?: Error) { + if (err) { + throw err + } +} + /** * @category Important */ @@ -39,12 +45,12 @@ export class Subscription extends Emitter { /** @internal */ iterated?: Todo - constructor(client: StreamrClient, opts: Todo, onFinally = () => {}) { + constructor(client: StreamrClient, opts: Todo, onFinally = defaultOnFinally) { super() this.client = client this.options = validateOptions(opts) this.key = this.options.key - this.id = counterId(`Subscription.${this.key}`) + this.id = counterId(`Subscription.${this.options.id || ''}${this.key}`) this.streamId = this.options.streamId this.streamPartition = this.options.streamPartition @@ -57,24 +63,44 @@ export class Subscription extends Emitter { this.pipeline = opts.pipeline || MessagePipeline(client, { ...this.options, validate, - onError: (err: Todo) => { + onError: (err: Error) => { this.emit('error', err) }, - // @ts-expect-error }, this.onPipelineEnd) this.msgStream = this.pipeline.msgStream } + emit(event: symbol | string, ...args: any[]) { + if (event !== 'error') { + return super.emit(event, ...args) + } + + try { + if (this.listenerCount('error')) { + // debugger + return super.emit('error', ...args) + } + throw args[0] + } catch (err) { + this.cancel(err) + return false + } + } + /** * Expose cleanup * @internal */ - async onPipelineEnd(err: Todo) { + + async onPipelineEnd(err?: Error) { + let error = err try { - await this._onFinally(err) + await this._onFinally(error) + } catch (onFinallyError) { + error = AggregatedError.from(error, onFinallyError) } finally { - this._onDone.handleErrBack(err) + this._onDone.handleErrBack(error) } } @@ -166,12 +192,16 @@ class SubscriptionSession extends Emitter { subscriptions: Set deletedSubscriptions: Set step?: Todo + _subscribe + _unsubscribe constructor(client: StreamrClient, options: Todo) { super() this.client = client this.options = validateOptions(options) this.validate = Validator(client, this.options) + this._subscribe = this.options.subscribe || subscribe + this._unsubscribe = this.options.unsubscribe || unsubscribe this.subscriptions = new Set() // active subs this.deletedSubscriptions = new Set() // hold so we can clean up @@ -259,13 +289,13 @@ class SubscriptionSession extends Emitter { }, // subscribe async () => { - await subscribe(this.client, this.options) + await this._subscribe(this.client, this.options) this.emit('subscribed') return async () => { if (needsReset) { return } this.emit('unsubscribing') - await unsubscribe(this.client, this.options) + await this._unsubscribe(this.client, this.options) } } // @ts-expect-error @@ -368,7 +398,6 @@ class SubscriptionSession extends Emitter { */ class Subscriptions { - client: StreamrClient subSessions: Map @@ -377,7 +406,7 @@ class Subscriptions { this.subSessions = new Map() } - async add(opts: StreamPartDefinition, onFinally: MaybeAsync<(err?: any) => void> = async () => {}) { + async add(opts: StreamPartDefinition, onFinally: MaybeAsync<(err?: any) => void> = defaultOnFinally) { const options = validateOptions(opts) const { key } = options @@ -389,7 +418,6 @@ class Subscriptions { const sub = new Subscription(this.client, { ...options, validate: subSession.validate, - // @ts-expect-error }, async (err: Todo) => { try { await this.remove(sub) @@ -646,7 +674,6 @@ export class Subscriber { await resendDone // ensure realtime doesn't start until resend ends yield* resendSubscribeSub.realtime }, - // @ts-expect-error ], end) const resendTask = resendMessageStream.subscribe() diff --git a/src/subscribe/pipeline.js b/src/subscribe/pipeline.js index 3e45d7c4b..0beb14294 100644 --- a/src/subscribe/pipeline.js +++ b/src/subscribe/pipeline.js @@ -22,9 +22,9 @@ async function collect(src) { * Subscription message processing pipeline */ -export default function MessagePipeline(client, opts = {}, onFinally = async () => {}) { +export default function MessagePipeline(client, opts = {}, onFinally = async (err) => { if (err) { throw err } }) { const options = validateOptions(opts) - const { key, afterSteps = [], beforeSteps = [], onError = (err) => { throw err } } = options + const { key, afterSteps = [], beforeSteps = [] } = options const id = counterId('MessagePipeline') + key /* eslint-disable object-curly-newline */ @@ -36,6 +36,16 @@ export default function MessagePipeline(client, opts = {}, onFinally = async () } = options /* eslint-enable object-curly-newline */ + const seenErrors = new WeakSet() + const onErrorFn = options.onError ? options.onError : (error) => { throw error } + const onError = async (err) => { + if (seenErrors.has(err)) { + return + } + seenErrors.add(err) + await onErrorFn(err) + } + // re-order messages (ignore gaps) const internalOrderingUtil = OrderMessages(client, { ...options, diff --git a/src/utils/AggregatedError.ts b/src/utils/AggregatedError.ts index a68c5b942..6a0b07437 100644 --- a/src/utils/AggregatedError.ts +++ b/src/utils/AggregatedError.ts @@ -11,20 +11,42 @@ function joinMessages(msgs: (string | undefined)[]): string { return msgs.filter(Boolean).join('\n') } +function getStacks(err: Error | AggregatedError) { + if (err instanceof AggregatedError) { + return [ + err.ownStack, + ...[...err.errors].map(({ stack }) => stack) + ] + } + + return [err.stack] +} + +function joinStackTraces(errs: Error[]): string { + return errs.flatMap((err) => getStacks(err)).filter(Boolean).join('\n') +} + export default class AggregatedError extends Error { errors: Set - ownMessage?: string + ownMessage: string + ownStack?: string constructor(errors: Error[] = [], errorMessage = '') { const message = joinMessages([ errorMessage, ...errors.map((err) => err.message) ]) super(message) + errors.forEach((err) => { + Object.assign(this, err) + }) + this.message = message this.ownMessage = errorMessage this.errors = new Set(errors) if (Error.captureStackTrace) { Error.captureStackTrace(this, this.constructor) } + this.ownStack = this.stack + this.stack = joinStackTraces([this, ...errors]) } /** diff --git a/src/utils/PushQueue.js b/src/utils/PushQueue.ts similarity index 66% rename from src/utils/PushQueue.js rename to src/utils/PushQueue.ts index adab5e0a8..96635b390 100644 --- a/src/utils/PushQueue.js +++ b/src/utils/PushQueue.ts @@ -1,16 +1,40 @@ -import { CancelableGenerator } from './iterators' // eslint-disable-line import/no-cycle +import { pOrderedResolve, Defer, pTimeout } from './index' -import { pOrderedResolve } from './index' +async function endGenerator(gtr: AsyncGenerator, error?: Error) { + return error + ? gtr.throw(error).catch(() => {}) // ignore err + : gtr.return(undefined) +} + +type EndGeneratorTimeoutOptions = { + timeout?: number + error?: Error +} + +async function endGeneratorTimeout( + gtr: AsyncGenerator, + { + timeout = 250, + error, + }: EndGeneratorTimeoutOptions = {} +) { + return pTimeout(endGenerator(gtr, error), { + timeout, + rejectOnTimeout: false, + }) +} export class AbortError extends Error { - constructor(msg = '', ...args) { - super(`The operation was aborted. ${msg}`, ...args) + constructor(msg = '') { + super(`The operation was aborted. ${msg}`) if (Error.captureStackTrace) { Error.captureStackTrace(this, this.constructor) } } } +type AnyIterable = Iterable | AsyncIterable + /** * Async Iterable PushQueue * On throw/abort any items in buffer will be flushed before iteration throws. @@ -41,26 +65,49 @@ export class AbortError extends Error { * */ -export default class PushQueue { - constructor(items = [], { signal, onEnd, timeout = 0, autoEnd = true } = {}) { +type PushQueueOptions = Partial<{ + signal: AbortSignal, + onEnd: (err?: Error, ...args: any[]) => void + timeout: number, + autoEnd: boolean, +}> + +export default class PushQueue { + autoEnd + timeout + signal + iterator + buffer: T[] | [...T[], null] + error?: Error// queued error + nextQueue: (ReturnType)[] = [] // queued promises for next() + finished = false + pending: number = 0 + ended = false + _onEnd: PushQueueOptions['onEnd'] + _onEndCalled = false + _isCancelled = false + + constructor(items: T[] = [], { + signal, + onEnd, + timeout = 0, + autoEnd = true + }: PushQueueOptions = {}) { this.autoEnd = autoEnd - this.buffer = [...items] - this.finished = false - this.error = null // queued error - this.nextQueue = [] // queued promises for next() - this.pending = 0 - this._onEnd = onEnd this.timeout = timeout + this._onEnd = onEnd + this.buffer = [...items] this[Symbol.asyncIterator] = this[Symbol.asyncIterator].bind(this) this.onAbort = this.onAbort.bind(this) this.onEnd = this.onEnd.bind(this) this.cancel = this.cancel.bind(this) + this.isCancelled = this.isCancelled.bind(this) this.end = this.end.bind(this) // abort signal handling - this.signal = signal if (signal) { + this.signal = signal if (signal.aborted) { this.onAbort() } @@ -73,14 +120,14 @@ export default class PushQueue { this.iterator = this.iterate() } - static from(iterable, opts = {}) { - const queue = new PushQueue([], opts) + static from(iterable: AnyIterable, opts = {}) { + const queue = new PushQueue([], opts) queue.from(iterable) return queue } - static transform(src, fn, opts = {}) { - const buffer = new PushQueue([], opts) + static transform(src: AnyIterable, fn: (value: TT) => U, opts = {}) { + const buffer = new PushQueue([], opts) const orderedFn = pOrderedResolve(fn) // push must be run in sequence ;(async () => { // eslint-disable-line semi-style const tasks = [] @@ -104,16 +151,16 @@ export default class PushQueue { return buffer } - async from(iterable, { end = this.autoEnd } = {}) { + async from(iterable: Iterable | AsyncIterable, { end = this.autoEnd } = {}) { try { // detect sync/async iterable and iterate appropriately - if (!iterable[Symbol.asyncIterator]) { - // sync iterables push into buffer immediately - for (const item of iterable) { + if ((Symbol.asyncIterator || Symbol.for('Symbol.asyncIterator')) in iterable) { + for await (const item of iterable as AsyncIterable) { this.push(item) } - } else { - for await (const item of iterable) { + } else if ((Symbol.iterator || Symbol.for('Symbol.iterator')) in iterable) { + // sync iterables push into buffer immediately + for (const item of iterable as Iterable) { this.push(item) } } @@ -122,26 +169,26 @@ export default class PushQueue { } if (end) { - this.end() + await this.end() } return Promise.resolve() } - onEnd(...args) { + onEnd(err?: Error, ...args: any[]) { if (this._onEndCalled || !this._onEnd) { return Promise.resolve() } this._onEndCalled = true - return this._onEnd(...args) + return this._onEnd(err, ...args) } /** * signals no more data should be buffered */ - end(v) { + end(v?: T | null) { if (this.ended) { return } @@ -158,7 +205,7 @@ export default class PushQueue { return this.throw(new AbortError()) } - async next(...args) { + async next(...args:[] | [unknown]) { return this.iterator.next(...args) } @@ -175,7 +222,11 @@ export default class PushQueue { await this._cleanup() } - async throw(err) { + async throw(err: Error) { + if (this.finished) { + return + } + this.finished = true const p = this.nextQueue.shift() if (p) { @@ -185,7 +236,7 @@ export default class PushQueue { this.error = err } - return this.return() + await this._cleanup() } get length() { @@ -193,17 +244,22 @@ export default class PushQueue { return this.ended && count ? count - 1 : count } - _cleanup() { + async _cleanup() { this.finished = true + const { error } = this for (const p of this.nextQueue) { - p.resolve() + if (error) { + p.reject(error) + } else { + p.resolve(undefined) + } } this.pending = 0 this.buffer.length = 0 return this.onEnd(this.error) } - push(...values) { + push(...values: (T | null)[]) { if (this.finished || this.ended) { // do nothing if done return @@ -211,26 +267,29 @@ export default class PushQueue { // if values contains null, treat null as end const nullIndex = values.findIndex((v) => v === null) + let validValues = values as T[] if (nullIndex !== -1) { this.ended = true // include null but trim rest - values = values.slice(0, nullIndex + 1) // eslint-disable-line no-param-reassign + validValues = values.slice(0, nullIndex + 1) as T[] } // resolve pending next calls - while (this.nextQueue.length && values.length) { + while (this.nextQueue.length && validValues.length) { const p = this.nextQueue.shift() - p.resolve(values.shift()) + if (p) { + p.resolve(validValues.shift()) + } } // push any remaining values into buffer - if (values.length) { - this.buffer.push(...values) + if (validValues.length) { + this.buffer.push(...validValues) } } - iterate() { // eslint-disable-line class-methods-use-this - const handleTerminalValues = (value) => { + iterate() { + const handleTerminalValues = (value: null | Error | any) => { // returns final task to perform before returning, or false if (value === null) { return this.return() @@ -243,7 +302,7 @@ export default class PushQueue { return false } - const [cancel, itr] = CancelableGenerator(async function* iterate() { + return async function* iterate(this: PushQueue) { while (true) { /* eslint-disable no-await-in-loop */ // feed from buffer first @@ -265,7 +324,7 @@ export default class PushQueue { // handle queued error if (this.error) { const err = this.error - this.error = null + this.error = undefined throw err } @@ -279,13 +338,8 @@ export default class PushQueue { continue // eslint-disable-line no-continue } - const deferred = new Promise((resolve, reject) => { - // wait for next push - this.nextQueue.push({ - resolve, - reject, - }) - }) + const deferred = Defer() + this.nextQueue.push(deferred) deferred.catch(() => {}) // prevent unhandledrejection const value = await deferred @@ -304,28 +358,29 @@ export default class PushQueue { yield value /* eslint-enable no-await-in-loop */ } - }.call(this), async (err) => { - return this.onEnd(err) - }, { - timeout: this.timeout, - }) - - return Object.assign(itr, { - cancel, - }) + }.call(this) } - pipe(next, opts) { + pipe(next: PushQueue, opts: Parameters['from']>[1]) { return next.from(this, opts) } - async cancel(...args) { + async cancel(error?: Error) { this.finished = true - return this.iterator.cancel(...args) + this._isCancelled = true + if (error) { + this.error = error + } + await endGeneratorTimeout(this.iterator, { + timeout: this.timeout, + error, + }) + + return this.return() } - isCancelled(...args) { - return this.iterator.isCancelled(...args) + isCancelled() { + return this._isCancelled } async* [Symbol.asyncIterator]() { @@ -336,10 +391,9 @@ export default class PushQueue { } finally { this.finished = true if (this.signal) { - this.signal.removeEventListener('abort', this.onAbort, { - once: true, - }) + this.signal.removeEventListener('abort', this.onAbort) } + await this.onEnd(this.error) } } } diff --git a/src/utils/Scaffold.ts b/src/utils/Scaffold.ts index 1adc7ae16..9adec24d0 100644 --- a/src/utils/Scaffold.ts +++ b/src/utils/Scaffold.ts @@ -21,6 +21,7 @@ type ScaffoldOptions = { onError?: (error: Error) => void onDone?: MaybeAsync<(shouldUp: boolean, error?: Error) => void> onChange?: MaybeAsync<(shouldUp: boolean) => void> + id?: string, } const noop = () => {} @@ -28,7 +29,7 @@ const noop = () => {} export default function Scaffold( sequence: Step[] = [], _checkFn: () => Promise, - { onError, onDone, onChange }: ScaffoldOptions = {} + { id = '', onError, onDone, onChange }: ScaffoldOptions = {} ) { let error: Error | undefined // ignore error if check fails @@ -59,7 +60,7 @@ export default function Scaffold( throw err // rethrow } } catch (newErr) { - error = AggregatedError.from(error, newErr) + error = AggregatedError.from(error, newErr, `ScaffoldError:${id}`) } } @@ -103,7 +104,8 @@ export default function Scaffold( collectErrors(err) } onDownSteps.push(onDownStep || (() => {})) - return next() + // eslint-disable-next-line no-return-await + return await next() // return await gives us a better stack trace } } else if (onDownSteps.length) { isDone = false @@ -115,7 +117,8 @@ export default function Scaffold( collectErrors(err) } nextSteps.push(prevSteps.pop() as StepUp) - return next() + // eslint-disable-next-line no-return-await + return await next() // return await gives us a better stack trace } else if (error) { const err = error // eslint-disable-next-line require-atomic-updates diff --git a/src/utils/index.ts b/src/utils/index.ts index 008a0fe0c..203843368 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -19,12 +19,13 @@ export { AggregatedError, Scaffold } const UUID = uuidv4() +export const SEPARATOR = ':' /* * Incrementing + human readable uuid */ export function uuid(label = '') { - return uniqueId(`${UUID}${label ? `.${label}` : ''}`) + return uniqueId(`${UUID}${label ? `${SEPARATOR}${label}` : ''}`) } export function randomString(length = 20) { @@ -62,7 +63,7 @@ export const counterId = (() => { } } - return `${prefix}.${counts[prefix]}` + return `${prefix}${SEPARATOR}${counts[prefix]}` } /** @@ -219,7 +220,7 @@ export function CacheFn(fn: Parameters[0], { type PromiseResolve = L.Compulsory['then']>>[0] type PromiseReject = L.Compulsory['then']>>[1] -export function Defer(executor: (...args: Parameters['then']>) => void = () => {}) { +export function Defer(executor: (...args: Parameters['then']>) => void = () => {}) { let resolve: PromiseResolve = () => {} let reject: PromiseReject = () => {} // eslint-disable-next-line promise/param-names @@ -228,6 +229,7 @@ export function Defer(executor: (...args: Parameters['then']>) => v reject = _reject executor(resolve, reject) }) + p.catch(() => {}) // prevent unhandledrejection function wrap(fn: F.Function) { return async (...args: unknown[]) => { @@ -403,9 +405,20 @@ export async function pTimeout(promise: Promise, ...args: pTimeoutArgs) } let timedOut = false - let t: ReturnType + const p = Defer() + const t = setTimeout(() => { + timedOut = true + if (rejectOnTimeout) { + p.reject(new TimeoutError(message, timeout)) + } else { + p.resolve(undefined) + } + }, timeout) + p.catch(() => {}) + return Promise.race([ Promise.resolve(promise).catch((err) => { + clearTimeout(t) if (timedOut) { // ignore errors after timeout return @@ -413,18 +426,10 @@ export async function pTimeout(promise: Promise, ...args: pTimeoutArgs) throw err }), - new Promise((resolve, reject) => { - t = setTimeout(() => { - timedOut = true - if (rejectOnTimeout) { - reject(new TimeoutError(message, timeout)) - } else { - resolve(undefined) - } - }, timeout) - }) + p ]).finally(() => { clearTimeout(t) + p.resolve(undefined) }) } @@ -453,9 +458,9 @@ export async function sleep(ms: number = 0) { /** * Wait until a condition is true - * @param {function(): Promise|function(): boolean} condition wait until this callback function returns true - * @param {number} [timeOutMs=10000] stop waiting after that many milliseconds, -1 for disable - * @param {number} [pollingIntervalMs=100] check condition between so many milliseconds + * @param condition - wait until this callback function returns true + * @param timeOutMs - stop waiting after that many milliseconds, -1 for disable + * @param pollingIntervalMs - check condition between so many milliseconds */ export async function until(condition: MaybeAsync<() => boolean>, timeOutMs = 10000, pollingIntervalMs = 100) { let timeout = false diff --git a/src/utils/iterators.js b/src/utils/iterators.js index bfb70ffba..73c06a189 100644 --- a/src/utils/iterators.js +++ b/src/utils/iterators.js @@ -17,12 +17,15 @@ export function iteratorFinally(iterable, onFinally) { let started = false let ended = false let error - + let onFinallyTask // ensure finally only runs once - const onFinallyOnce = pMemoize(onFinally, { - cachePromiseRejection: true, // don't run again if failed - cacheKey: () => true // always same key - }) + const onFinallyOnce = (err) => { + if (!onFinallyTask) { + // eslint-disable-next-line promise/no-promise-in-callback + onFinallyTask = Promise.resolve().then(async () => onFinally(err)) + } + return onFinallyTask + } // wraps return/throw to call onFinally even if generator was never started const handleFinally = (originalFn) => async (...args) => { @@ -84,6 +87,8 @@ export function iteratorFinally(iterable, onFinally) { // return a generator that simply runs finally script (once) return (async function* generatorRunFinally() { // eslint-disable-line require-yield try { + // NOTE: native generators do not throw if gen.throw(err) called before started + // so we should do the same here if (typeof iterable.return === 'function') { await iterable.return() // runs onFinally for nested iterable } @@ -128,6 +133,7 @@ const endGeneratorTimeout = pMemoize(async (gtr, error, timeout = 250) => { await cancelGenerator(gtr, error) } }, { + cache: new WeakMap(), cachePromiseRejection: true, }) @@ -276,10 +282,7 @@ export function CancelableGenerator(iterable, onFinally = () => {}, { timeout = isDone: () => finalCalled, }) - return [ - cancelFn, - cancelableGenerator - ] + return cancelableGenerator } /** @@ -290,7 +293,13 @@ const isPipeline = Symbol('isPipeline') const getIsStream = (item) => typeof item.from === 'function' -export function pipeline(iterables = [], onFinally = () => {}, { end, ...opts } = {}) { +async function defaultOnFinally(err) { + if (err) { + throw err + } +} + +export function pipeline(iterables = [], onFinally = defaultOnFinally, { end, ...opts } = {}) { const cancelFns = new Set() let cancelled = false let error @@ -309,7 +318,7 @@ export function pipeline(iterables = [], onFinally = () => {}, { end, ...opts } try { // eslint-disable-next-line promise/no-promise-in-callback await allSettledValues([...cancelFns].map(async ({ isCancelled, cancel }) => ( - isCancelled ? cancel(err) : undefined + !isCancelled() ? cancel(err) : undefined ))) } finally { cancelFns.clear() @@ -321,23 +330,19 @@ export function pipeline(iterables = [], onFinally = () => {}, { end, ...opts } return } - try { - if (cancelled) { - await onCancelDone - return - } - - if (error) { - // eslint-disable-next-line promise/no-promise-in-callback - pipelineValue.throw(error).catch(() => {}) // ignore err - } else { - pipelineValue.return() - } - await cancelAll(err) + if (cancelled) { await onCancelDone - } finally { - await true + return } + + if (error) { + // eslint-disable-next-line promise/no-promise-in-callback + pipelineValue.throw(error).catch(() => {}) // ignore err + } else { + pipelineValue.return() + } + await cancelAll(err) + await onCancelDone } let firstSrc @@ -347,13 +352,10 @@ export function pipeline(iterables = [], onFinally = () => {}, { end, ...opts } } const last = iterables.reduce((_prev, next, index) => { - let prev - let nextIterable - - const [, it] = CancelableGenerator((async function* Gen() { - prev = index === 0 ? firstSrc : _prev + const it = CancelableGenerator((async function* Gen() { + const prev = index === 0 ? firstSrc : _prev // take first "prev" from outer iterator, if one exists - nextIterable = typeof next === 'function' ? next(prev) : next + const nextIterable = typeof next === 'function' ? next(prev) : next if (prev && nextIterable[isPipeline]) { nextIterable.setFirstSource(prev) @@ -367,7 +369,6 @@ export function pipeline(iterables = [], onFinally = () => {}, { end, ...opts } prev.id = prev.id || 'inter-' + nextIterable.id nextIterable.from(prev, { end }) } - yield* nextIterable }()), async (err) => { if (!error && err && error !== err) { diff --git a/test/integration/Encryption.test.js b/test/integration/Encryption.test.js index 8b4475760..aa7cc2029 100644 --- a/test/integration/Encryption.test.js +++ b/test/integration/Encryption.test.js @@ -93,8 +93,8 @@ describe('decryption', () => { } }) - beforeEach(async () => { - client = createClient() + async function setupClient(opts) { + client = createClient(opts) await Promise.all([ client.session.getSessionToken(), client.connect(), @@ -109,6 +109,10 @@ describe('decryption', () => { publishTestMessages = getPublishTestMessages(client, { stream }) + } + + beforeEach(async () => { + await setupClient() }) it('client.subscribe can decrypt encrypted messages if it knows the group key', async () => { @@ -545,4 +549,3 @@ describe('decryption', () => { expect(onSubError).toHaveBeenCalledTimes(1) }) }) - diff --git a/test/integration/MessagePipeline.test.js b/test/integration/MessagePipeline.test.js new file mode 100644 index 000000000..54f807498 --- /dev/null +++ b/test/integration/MessagePipeline.test.js @@ -0,0 +1,275 @@ +import { wait } from 'streamr-test-utils' +import { MessageLayer, ControlLayer } from 'streamr-client-protocol' + +import { fakePrivateKey, addAfterFn } from '../utils' +import { pipeline } from '../../src/utils/iterators' +import PushQueue from '../../src/utils/PushQueue' +import { StreamrClient } from '../../src/StreamrClient' +import Connection from '../../src/Connection' + +import config from './config' +import MessagePipeline from '../../src/subscribe/pipeline' +import { Subscription } from '../../src/subscribe' +import Validator from '../../src/subscribe/Validator' + +const { StreamMessage, MessageID } = MessageLayer +const { BroadcastMessage } = ControlLayer + +const MOCK_INVALID_GROUP_KEY_MESSAGE = new StreamMessage({ + messageId: new MessageID( + 'SYSTEM/keyexchange/0x848f6fc62d8c6471ab0d0dd7ae7439e6bf927cf8', + 0, + 1614960211925, + 0, + '0x320e5461c6521fce2df230e0cdfe715baa01b094', + 'nl9gi01z8qnz4st4frdb' + ), + prevMsgRef: null, + messageType: 31, + contentType: 0, + encryptionType: 0, + groupKeyId: null, + newGroupKey: null, + signatureType: 2, + signature: '0xaf53be7ac333480b9dc2fcc5a171a661e1077738bf8446e36f4dd3214582153403e2f2b57ddd5b3dcb8e8ef77212160d8a0cfeb61b212221361a96391d7582fe1c', + // eslint-disable-next-line max-len + content: [ + // mock INVALID_GROUP_KEY_REQUEST + 'ff7a68fa-acbd-4540-8e7e-11b5e0413e49:GroupKeyRequest15', + 'VLKnLfRcTLGaG5FDEj2qZw', + 'INVALID_GROUP_KEY_REQUEST', + '0x848f6fc62d8c6471ab0d0dd7ae7439e6bf927cf8 is not a subscriber on stream VLKnLfRcTLGaG5FDEj2qZw. Group key request: ...', + ['ff7a68fa-acbd-4540-8e7e-11b5e0413e49:GroupKey11'] + ] +}) + +describe('MessagePipeline', () => { + let expectErrors = 0 // check no errors by default + let errors = [] + + const getOnError = (errs) => jest.fn((err) => { + errs.push(err) + }) + + let onError = jest.fn() + let client + + const createClient = (opts = {}) => { + const c = new StreamrClient({ + ...config.clientOptions, + auth: { + privateKey: fakePrivateKey(), + }, + autoConnect: false, + autoDisconnect: false, + disconnectDelay: 1, + publishAutoDisconnectDelay: 50, + maxRetries: 2, + cache: { + maxAge: 1, + }, + ...opts, + }) + c.onError = jest.fn() + c.on('error', onError) + + return c + } + + const addAfter = addAfterFn() + + beforeEach(() => { + errors = [] + expectErrors = 0 + onError = getOnError(errors) + }) + + afterEach(async () => { + await wait() + // ensure no unexpected errors + expect(errors).toHaveLength(expectErrors) + if (client) { + expect(client.onError).toHaveBeenCalledTimes(expectErrors) + } + }) + + afterEach(async () => { + await wait() + if (client) { + client.debug('disconnecting after test') + await client.disconnect() + } + + const openSockets = Connection.getOpen() + if (openSockets !== 0) { + throw new Error(`sockets not closed: ${openSockets}`) + } + }) + + async function setupClient(opts) { + client = createClient(opts) + await Promise.all([ + client.session.getSessionToken(), + client.connect(), + ]) + } + + beforeEach(async () => { + await setupClient() + }) + + it('handles errors', async () => { + const validate = Validator(client, MOCK_INVALID_GROUP_KEY_MESSAGE.messageId) + let p + const onPipelineError = jest.fn(async (err) => { + await wait(10) + return p.cancel(err) + }) + p = pipeline([ + async function* generate() { + await wait(10) + yield MOCK_INVALID_GROUP_KEY_MESSAGE + }, + async function* ValidateMessages(src) { + for await (const streamMessage of src) { + try { + await validate(streamMessage) + } catch (err) { + await onPipelineError(err) + } + yield streamMessage + } + }, + async function* Delay(src) { + for await (const streamMessage of src) { + await wait(10) + yield streamMessage + } + }, + pipeline([ + async function* ValidateMessages2(src) { + yield* (async function* validate2() { + for await (const streamMessage of src) { + try { + await wait(10) + await validate(streamMessage) + } catch (err) { + await onPipelineError(err) + } + yield streamMessage + } + }()) + }, + ]) + ]) + + const received = [] + await expect(async () => { + for await (const streamMessage of p) { + received.push(streamMessage) + } + }).rejects.toThrow() + expect(received).toHaveLength(0) + expect(onPipelineError).toHaveBeenCalledTimes(1) + }) + + it('handles errors in MessagePipeline', async () => { + const onPipelineError = jest.fn((err) => { + throw err + }) + const msgStream = new PushQueue([ + ]) + const p = MessagePipeline(client, { + ...MOCK_INVALID_GROUP_KEY_MESSAGE.messageId, + msgStream, + onError: onPipelineError, + }, (err) => { + if (err) { + throw err + } + }) + const t = setTimeout(() => { + msgStream.push(new BroadcastMessage({ + streamMessage: MOCK_INVALID_GROUP_KEY_MESSAGE, + requestId: '', + })) + }, 15) + addAfter(() => clearTimeout(t)) + + const received = [] + await expect(async () => { + for await (const streamMessage of p) { + received.push(streamMessage) + } + }).rejects.toThrow() + expect(received).toHaveLength(0) + expect(onPipelineError).toHaveBeenCalledTimes(1) + }) + + it('handles errors in Subscription', async () => { + const onPipelineError = jest.fn((err) => { + throw err + }) + const msgStream = new PushQueue() + const sub = new Subscription(client, { + ...MOCK_INVALID_GROUP_KEY_MESSAGE.messageId, + msgStream, + }, (err) => { + if (err) { + throw err + } + }) + sub.on('error', onPipelineError) + + const t = setTimeout(() => { + msgStream.push(new BroadcastMessage({ + streamMessage: MOCK_INVALID_GROUP_KEY_MESSAGE, + requestId: '', + })) + }, 15) + addAfter(() => clearTimeout(t)) + const received = [] + await expect(async () => { + for await (const streamMessage of sub) { + received.push(streamMessage) + } + }).rejects.toThrow() + expect(received).toHaveLength(0) + expect(onPipelineError).toHaveBeenCalledTimes(1) + }) + + it('handles errors in client.subscribe', async () => { + const onPipelineError = jest.fn((err) => { + throw err + }) + const msgStream = new PushQueue() + const sub = await client.subscribe({ + ...MOCK_INVALID_GROUP_KEY_MESSAGE.messageId, + msgStream, + subscribe: async () => { + await wait(10) + }, + unsubscribe: async () => { + await wait(10) + }, + }) + sub.on('error', onPipelineError) + + const t = setTimeout(() => { + msgStream.push(new BroadcastMessage({ + streamMessage: MOCK_INVALID_GROUP_KEY_MESSAGE, + requestId: '', + })) + }, 15) + addAfter(() => clearTimeout(t)) + + const received = [] + await expect(async () => { + for await (const streamMessage of sub) { + received.push(streamMessage) + } + }).rejects.toThrow() + expect(received).toHaveLength(0) + expect(onPipelineError).toHaveBeenCalledTimes(1) + }) +}) diff --git a/test/integration/Subscriber.test.js b/test/integration/Subscriber.test.js index 55800d3eb..72559389c 100644 --- a/test/integration/Subscriber.test.js +++ b/test/integration/Subscriber.test.js @@ -11,8 +11,9 @@ import config from './config' const { ControlMessage } = ControlLayer const MAX_ITEMS = 2 +const NUM_MESSAGES = 6 -describeRepeats('StreamrClient Stream', () => { +describeRepeats('Subscriber', () => { let expectErrors = 0 // check no errors by default let onError = jest.fn() let client @@ -190,6 +191,184 @@ describeRepeats('StreamrClient Stream', () => { expect(M.count(stream.id)).toBe(0) }) + + describe('subscription error handling', () => { + it('works when error thrown inline', async () => { + const err = new Error('expected') + + const sub = await M.subscribe({ + ...stream, + afterSteps: [ + async function* ThrowError(s) { + let count = 0 + for await (const msg of s) { + if (count === MAX_ITEMS) { + throw err + } + count += 1 + yield msg + } + } + ] + }) + + expect(M.count(stream.id)).toBe(1) + + const published = await publishTestMessages(NUM_MESSAGES, { + timestamp: 111111, + }) + + const received = [] + await expect(async () => { + for await (const m of sub) { + received.push(m.getParsedContent()) + } + }).rejects.toThrow(err) + expect(received).toEqual(published.slice(0, MAX_ITEMS)) + }) + + it('works when multiple steps error', async () => { + const err = new Error('expected') + + const sub = await M.subscribe({ + ...stream, + afterSteps: [ + async function* ThrowError1(s) { + let count = 0 + for await (const msg of s) { + if (count === MAX_ITEMS) { + throw err + } + count += 1 + yield msg + } + }, + async function* ThrowError2(s) { + let count = 0 + for await (const msg of s) { + if (count === MAX_ITEMS) { + throw err + } + count += 1 + yield msg + } + } + ] + }) + + expect(M.count(stream.id)).toBe(1) + + const published = await publishTestMessages(NUM_MESSAGES, { + timestamp: 111111, + }) + + const received = [] + await expect(async () => { + for await (const m of sub) { + received.push(m.getParsedContent()) + } + }).rejects.toThrow(err) + expect(received).toEqual(published.slice(0, MAX_ITEMS)) + }) + + describe('error is bad groupkey', () => { + let sub + const BAD_GROUP_KEY_ID = 'BAD_GROUP_KEY_ID' + + beforeEach(async () => { + await client.publisher.startKeyExchange() + sub = await M.subscribe({ + ...stream, + beforeSteps: [ + async function* ThrowError(s) { + let count = 0 + for await (const msg of s) { + if (count === MAX_ITEMS) { + msg.streamMessage.encryptionType = 2 + msg.streamMessage.groupKeyId = BAD_GROUP_KEY_ID + } + count += 1 + yield msg + } + } + ] + }) + + expect(M.count(stream.id)).toBe(1) + }) + + it('throws subscription loop when encountering bad message', async () => { + const published = await publishTestMessages(NUM_MESSAGES, { + timestamp: 111111, + }) + + const received = [] + await expect(async () => { + for await (const m of sub) { + received.push(m.getParsedContent()) + } + }).rejects.toThrow(BAD_GROUP_KEY_ID) + expect(received).toEqual(published.slice(0, MAX_ITEMS)) + }) + + it('will skip bad message if error handler attached', async () => { + const published = await publishTestMessages(NUM_MESSAGES, { + timestamp: 111111, + }) + + const onSubscriptionError = jest.fn() + sub.on('error', onSubscriptionError) + + const received = [] + let t + for await (const m of sub) { + received.push(m.getParsedContent()) + if (received.length === published.length - 1) { + // eslint-disable-next-line no-loop-func + t = setTimeout(() => { + // give it a moment to incorrectly get messages + sub.cancel() + }, 100) + } + + if (received.length === published.length) { + break + } + } + clearTimeout(t) + expect(received).toEqual([ + ...published.slice(0, MAX_ITEMS), + ...published.slice(MAX_ITEMS + 1) + ]) + expect(onSubscriptionError).toHaveBeenCalledTimes(1) + }) + + it('will not skip bad message if error handler attached & throws', async () => { + expect(M.count(stream.id)).toBe(1) + + const published = await publishTestMessages(NUM_MESSAGES, { + timestamp: 111111, + }) + + const received = [] + const onSubscriptionError = jest.fn((err) => { + throw err + }) + + sub.on('error', onSubscriptionError) + await expect(async () => { + for await (const m of sub) { + received.push(m.getParsedContent()) + if (received.length === published.length) { + break + } + } + }).rejects.toThrow(BAD_GROUP_KEY_ID) + expect(received).toEqual(published.slice(0, MAX_ITEMS)) + expect(onSubscriptionError).toHaveBeenCalledTimes(1) + }) + }) + }) }) describe('ending a subscription', () => { diff --git a/test/integration/SubscriberResends.test.ts b/test/integration/SubscriberResends.test.ts index 82805a78a..f3da36e01 100644 --- a/test/integration/SubscriberResends.test.ts +++ b/test/integration/SubscriberResends.test.ts @@ -162,12 +162,15 @@ describeRepeats('resends', () => { await client.publish(emptyStream.id, msg) const received = [] + let t!: ReturnType for await (const m of sub) { received.push(m.getParsedContent()) - setTimeout(() => { + clearTimeout(t) + t = setTimeout(() => { sub.cancel() }, 250) } + clearTimeout(t) expect(onResent).toHaveBeenCalledTimes(1) expect(received).toEqual([msg]) @@ -293,12 +296,12 @@ describeRepeats('resends', () => { const req = await client.publish(stream.id, newMessage) // should be realtime published.push(newMessage) publishedRequests.push(req) - let t: ReturnType + let t!: ReturnType for await (const msg of sub) { receivedMsgs.push(msg.getParsedContent()) if (receivedMsgs.length === published.length) { await sub.return() - clearTimeout(t!) + clearTimeout(t) t = setTimeout(() => { // await wait() // give resent event a chance to fire onResent.reject(new Error('resent never called')) @@ -307,7 +310,7 @@ describeRepeats('resends', () => { } await onResent - clearTimeout(t!) + clearTimeout(t) expect(receivedMsgs).toHaveLength(published.length) expect(receivedMsgs).toEqual(published) @@ -445,8 +448,8 @@ describeRepeats('resends', () => { published.push(message) publishedRequests.push(req) - let t - let receivedMsgs + let t!: ReturnType + let receivedMsgs: any[] try { receivedMsgs = await collect(sub, async ({ received }) => { if (received.length === published.length) { diff --git a/test/integration/Validation.test.js b/test/integration/Validation.test.js index bcefc8b0f..027d32cab 100644 --- a/test/integration/Validation.test.js +++ b/test/integration/Validation.test.js @@ -105,10 +105,12 @@ describeRepeats('Validation', () => { it('subscribe fails gracefully when signature bad', async () => { const sub = await client.subscribe(stream.id) + + const errs = [] const onSubError = jest.fn((err) => { - expect(err).toBeInstanceOf(Error) - expect(err.message).toMatch('signature') + errs.push(err) }) + sub.on('error', onSubError) const { parse } = client.connection let count = 0 @@ -131,13 +133,27 @@ describeRepeats('Validation', () => { timestamp: 111111, }) + let t const received = [] for await (const m of sub) { received.push(m.getParsedContent()) if (received.length === published.length - 1) { + clearTimeout(t) + // give it a chance to fail + t = setTimeout(() => { + sub.cancel() + }, 500) + } + + if (received.length === published.length) { + // failed + clearTimeout(t) break } } + + clearTimeout(t) + const expectedMessages = [ // remove bad message ...published.slice(0, BAD_INDEX), @@ -147,6 +163,11 @@ describeRepeats('Validation', () => { expect(received).toEqual(expectedMessages) expect(client.connection.getState()).toBe('connected') expect(onSubError).toHaveBeenCalledTimes(1) + expect(errs).toHaveLength(1) + errs.forEach((err) => { + expect(err).toBeInstanceOf(Error) + expect(err.message).toMatch('signature') + }) }, 10000) }) diff --git a/test/unit/Encryption.test.ts b/test/unit/Encryption.test.ts index ba54f3d96..0919aedef 100644 --- a/test/unit/Encryption.test.ts +++ b/test/unit/Encryption.test.ts @@ -163,20 +163,24 @@ function TestEncryptionUtil({ isBrowser = false } = {}) { }) }) - it('validateGroupKey() throws if key is the wrong size', () => { - expect(() => { - EncryptionUtil.validateGroupKey(crypto.randomBytes(16)) - }).toThrow() - }) + describe('GroupKey.validate', () => { + it('throws if key is the wrong size', () => { + expect(() => { + GroupKey.validate(GroupKey.from(['test', crypto.randomBytes(16)])) + }).toThrow('size') + }) - it('validateGroupKey() throws if key is not a buffer', () => { - expect(() => { - EncryptionUtil.validateGroupKey(ethers.utils.hexlify(GroupKey.generate() as any)) - }).toThrow() - }) + it('throws if key is not a buffer', () => { + expect(() => { + // expected error below is desirable, show typecheks working as intended + // @ts-expect-error + GroupKey.validate(GroupKey.from(['test', Array.from(crypto.randomBytes(32))])) + }).toThrow('Buffer') + }) - it('validateGroupKey() does not throw', () => { - EncryptionUtil.validateGroupKey(GroupKey.generate()) + it('does not throw with valid values', () => { + GroupKey.validate(GroupKey.generate()) + }) }) }) } diff --git a/test/unit/IteratorTest.js b/test/unit/IteratorTest.js new file mode 100644 index 000000000..e43a2e12a --- /dev/null +++ b/test/unit/IteratorTest.js @@ -0,0 +1,219 @@ +export const expected = [1, 2, 3, 4, 5, 6, 7, 8] + +export const MAX_ITEMS = 3 + +const wait = (ms) => new Promise((resolve) => setTimeout(resolve, ms)) +const WAIT = 20 + +export default function IteratorTest(name, fn) { + describe(`${name} IteratorTest`, () => { + it('runs to completion', async () => { + const received = [] + const itr = fn({ + items: expected, max: MAX_ITEMS + }) + for await (const msg of itr) { + received.push(msg) + } + expect(received).toEqual(expected) + }) + + it('can return in finally', async () => { + const received = [] + const itr = (async function* Outer() { + const innerItr = fn({ + items: expected, max: MAX_ITEMS + })[Symbol.asyncIterator]() + try { + yield* innerItr + } finally { + await innerItr.return() // note itr.return would block + } + }()) + + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + break + } + } + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('can return mid-iteration', async () => { + const received = [] + for await (const msg of fn({ + items: expected, max: MAX_ITEMS + })) { + received.push(msg) + if (received.length === MAX_ITEMS) { + break + } + } + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('can throw mid-iteration', async () => { + const received = [] + const err = new Error('expected err') + await expect(async () => { + for await (const msg of fn({ + items: expected, max: MAX_ITEMS + })) { + received.push(msg) + if (received.length === MAX_ITEMS) { + throw err + } + } + }).rejects.toThrow(err) + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('can throw() mid-iteration', async () => { + const received = [] + const err = new Error('expected err 2') + await expect(async () => { + const it = fn({ + items: expected, + max: MAX_ITEMS, + errors: [err], + }) + for await (const msg of it) { + received.push(msg) + if (received.length === MAX_ITEMS) { + await it.throw(err) + } + } + }).rejects.toThrow(err) + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('throws parent mid-iteration', async () => { + const received = [] + const err = new Error('expected err') + async function* parentGen() { + const s = fn({ + items: expected, + max: MAX_ITEMS, + errors: [err], + }) + for await (const msg of s) { + yield msg + if (received.length === MAX_ITEMS) { + await s.throw(err) + } + } + } + await expect(async () => { + for await (const msg of parentGen()) { + received.push(msg) + } + }).rejects.toThrow(err) + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('can throw before iterating', async () => { + const received = [] + const itr = fn({ + items: expected, max: MAX_ITEMS + })[Symbol.asyncIterator]() + const err = new Error('expected err') + + await expect(async () => { + await itr.throw(err) + }).rejects.toThrow(err) + + // does not throw + for await (const msg of itr) { + received.push(msg) + } + expect(received).toEqual([]) + }) + + it('can return before iterating', async () => { + const itr = fn({ + items: expected, max: MAX_ITEMS + })[Symbol.asyncIterator]() + await itr.return() + const received = [] + for await (const msg of itr) { + received.push(msg) + } + expect(received).toEqual([]) + }) + + it('can queue next calls', async () => { + const itr = fn({ + items: expected, + max: MAX_ITEMS + })[Symbol.asyncIterator]() + const tasks = expected.map(async () => itr.next()) + const received = await Promise.all(tasks) + expect(received.map(({ value }) => value)).toEqual(expected) + await itr.return() + }) + + it('can queue delayed next calls', async () => { + const itr = fn({ + items: expected, max: MAX_ITEMS + })[Symbol.asyncIterator]() + const tasks = expected.map(async () => { + await wait(WAIT) + return itr.next() + }) + const received = await Promise.all(tasks) + expect(received.map(({ value }) => value)).toEqual(expected) + await itr.return() + }) + + it('can queue delayed next calls resolving out of order', async () => { + const itr = fn({ + items: expected, max: MAX_ITEMS + })[Symbol.asyncIterator]() + const tasks = expected.map(async (v, index, arr) => { + // resolve backwards + const result = await itr.next() + await wait(WAIT + (WAIT * 10 * ((arr.length - index) / arr.length))) + return result + }) + const received = await Promise.all(tasks) + expect(received.map(({ value }) => value)).toEqual(expected) + await itr.return() + }) + + it('can handle error in queued next calls', async () => { + const itr = fn({ + items: expected, + })[Symbol.asyncIterator]() + const err = new Error('expected') + const tasks = expected.map(async (v, index, arr) => { + const result = await itr.next() + await wait(WAIT + WAIT * ((arr.length - index) / arr.length)) + if (index === MAX_ITEMS) { + throw err + } + return result + }) + + const received = await Promise.allSettled(tasks) + + expect(received).toEqual(expected.map((value, index) => { + if (index === MAX_ITEMS) { + return { + status: 'rejected', + reason: err + } + } + + return { + status: 'fulfilled', + value: { + done: false, + value, + } + } + })) + await itr.return() + }) + }) +} diff --git a/test/unit/PushQueue.test.js b/test/unit/PushQueue.test.js index 9245649de..a60d11ee8 100644 --- a/test/unit/PushQueue.test.js +++ b/test/unit/PushQueue.test.js @@ -2,6 +2,9 @@ import { wait } from 'streamr-test-utils' import AbortController from 'node-abort-controller' import PushQueue from '../../src/utils/PushQueue' +import { Defer } from '../../src/utils' + +import IteratorTest from './IteratorTest' const expected = [1, 2, 3, 4, 5, 6, 7, 8] const WAIT = 20 @@ -18,6 +21,14 @@ async function* generate(items = expected) { } describe('PushQueue', () => { + IteratorTest('PushQueue works like regular iterator', ({ items }) => ( + new PushQueue([...items, null]) + )) + + IteratorTest('PushQueue.from works like regular iterator', ({ items }) => ( + PushQueue.from(generate([...items, null])) + )) + it('supports pre-buffering, async push & return', async () => { const q = new PushQueue() expect(q.length).toBe(0) @@ -25,6 +36,7 @@ describe('PushQueue', () => { expect(q.length).toBe(1) q.push(expected[1]) expect(q.length).toBe(2) + const done = Defer() setTimeout(() => { // buffer should have drained by now @@ -34,6 +46,7 @@ describe('PushQueue', () => { setTimeout(() => { q.return(5) // both items above should get through q.push('nope') // this should not + done.resolve() }, 20) }, 10) @@ -46,6 +59,7 @@ describe('PushQueue', () => { expect(i).toBe(4) // buffer should have drained at end expect(q.length).toBe(0) + await done }) it('supports passing initial values to constructor', async () => { diff --git a/test/unit/iterators.test.js b/test/unit/iterators.test.js index 268541443..521ddbd02 100644 --- a/test/unit/iterators.test.js +++ b/test/unit/iterators.test.js @@ -4,7 +4,8 @@ import { iteratorFinally, CancelableGenerator, pipeline } from '../../src/utils/ import { Defer } from '../../src/utils' import PushQueue from '../../src/utils/PushQueue' -const expected = [1, 2, 3, 4, 5, 6, 7, 8] +import IteratorTest, { expected, MAX_ITEMS } from './IteratorTest' + const WAIT = 20 async function* generate(items = expected) { @@ -17,117 +18,19 @@ async function* generate(items = expected) { await wait(WAIT * 0.1) } -const MAX_ITEMS = 3 - -function IteratorTest(name, fn) { - describe(`${name} IteratorTest`, () => { - it('runs to completion', async () => { - const received = [] - const itr = fn() - for await (const msg of itr) { - received.push(msg) - } - expect(received).toEqual(expected) - }) - - it('can return in finally', async () => { - const received = [] - const itr = (async function* Outer() { - const innerItr = fn()[Symbol.asyncIterator]() - try { - yield* innerItr - } finally { - await innerItr.return() // note itr.return would block - } - }()) - - for await (const msg of itr) { - received.push(msg) - if (received.length === MAX_ITEMS) { - break - } - } - expect(received).toEqual(expected.slice(0, MAX_ITEMS)) - }) - - it('can return mid-iteration', async () => { - const received = [] - for await (const msg of fn()) { - received.push(msg) - if (received.length === MAX_ITEMS) { - break - } - } - expect(received).toEqual(expected.slice(0, MAX_ITEMS)) - }) - - it('can throw mid-iteration', async () => { - const received = [] - const err = new Error('expected err') - await expect(async () => { - for await (const msg of fn()) { - received.push(msg) - if (received.length === MAX_ITEMS) { - throw err - } - } - }).rejects.toThrow(err) - expect(received).toEqual(expected.slice(0, MAX_ITEMS)) - }) - - it('throws parent mid-iteration', async () => { - const received = [] - const err = new Error('expected err') - async function* parentGen() { - for await (const msg of fn()) { - yield msg - if (received.length === MAX_ITEMS) { - throw err - } - } - } - await expect(async () => { - for await (const msg of parentGen()) { - received.push(msg) - } - }).rejects.toThrow(err) - expect(received).toEqual(expected.slice(0, MAX_ITEMS)) - }) - - it('can throw before iterating', async () => { - const received = [] - const itr = fn()[Symbol.asyncIterator]() - const err = new Error('expected err') - - await expect(async () => { - await itr.throw(err) - }).rejects.toThrow(err) - - // does not throw - for await (const msg of itr) { - received.push(msg) - } - expect(received).toEqual([]) - }) - - it('can return before iterating', async () => { - const itr = fn()[Symbol.asyncIterator]() - await itr.return() - const received = [] - for await (const msg of itr) { - received.push(msg) - } - expect(received).toEqual([]) - }) - - it('can queue next calls', async () => { - const itr = fn()[Symbol.asyncIterator]() - const tasks = expected.map(async () => itr.next()) - const received = await Promise.all(tasks) - expect(received.map(({ value }) => value)).toEqual(expected) - await itr.return() - }) - }) +async function* generateThrow(items = expected, { max = MAX_ITEMS, err = new Error('expected') }) { + let index = 0 + await wait(WAIT * 0.1) + for await (const item of items) { + index += 1 + await wait(WAIT * 0.1) + if (index > max) { + throw err + } + yield item + await wait(WAIT * 0.1) + } + await wait(WAIT * 0.1) } describe('Iterator Utils', () => { @@ -157,24 +60,29 @@ describe('Iterator Utils', () => { }) it('runs fn when iterator.return() is called asynchronously', async () => { - const received = [] - const itr = iteratorFinally(generate(), onFinally) - const onTimeoutReached = jest.fn() - let receievedAtCallTime - for await (const msg of itr) { - received.push(msg) - if (received.length === MAX_ITEMS) { - // eslint-disable-next-line no-loop-func - setTimeout(() => { - onTimeoutReached() - receievedAtCallTime = received - itr.return() - }) + const done = Defer() + try { + const received = [] + const itr = iteratorFinally(generate(), onFinally) + const onTimeoutReached = jest.fn() + let receievedAtCallTime + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(() => { + onTimeoutReached() + receievedAtCallTime = received + itr.return() + })) + } } - } - expect(onTimeoutReached).toHaveBeenCalledTimes(1) - expect(received).toEqual(receievedAtCallTime) + expect(onTimeoutReached).toHaveBeenCalledTimes(1) + expect(received).toEqual(receievedAtCallTime) + } finally { + await done + } }) it('runs fn when iterator returns + breaks during iteration', async () => { @@ -281,11 +189,72 @@ describe('Iterator Utils', () => { await expect(async () => itr.throw(err)).rejects.toThrow(err) expect(onFinally).toHaveBeenCalledTimes(1) expect(onFinallyAfter).toHaveBeenCalledTimes(1) - // doesn't throw, matches native iterators + // NOTE: doesn't throw, matches native iterators + for await (const msg of itr) { + received.push(msg) + } + expect(received).toEqual([]) + }) + + it('runs fn when inner iterator throws during iteration', async () => { + const received = [] + const err = new Error('expected err') + const itr = iteratorFinally(generateThrow(expected, { + err, + }), onFinally) + await expect(async () => { + for await (const msg of itr) { + received.push(msg) + } + }).rejects.toThrow(err) + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + }) + + it('errored before start iterator works if onFinally is async', async () => { + const received = [] + const errs = [] + const onFinallyDelayed = jest.fn(async (err) => { + errs.push(err) + await wait(100) + return onFinally(err) + }) + const itr = iteratorFinally(generate(), onFinallyDelayed) + const err = new Error('expected err 1') + await expect(async () => { + await itr.throw(err) + }).rejects.toThrow(err) for await (const msg of itr) { received.push(msg) } expect(received).toEqual([]) + expect(onFinallyDelayed).toHaveBeenCalledTimes(1) + expect(errs).toEqual([err]) + }) + + it('errored iterator works if onFinally is async', async () => { + const received = [] + const errs = [] + const onFinallyDelayed = jest.fn(async (err) => { + errs.push(err) + await wait(100) + return onFinally(err) + }) + const itr = iteratorFinally(generate(), onFinallyDelayed) + const err = new Error('expected err 2') + await expect(async () => { + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + await itr.throw(err) + } + } + }).rejects.toThrow(err) + + expect(received).toEqual(expected.slice(0, MAX_ITEMS)) + expect(onFinallyDelayed).toHaveBeenCalledTimes(1) + expect(errs).toEqual([err]) }) describe('nesting', () => { @@ -417,17 +386,16 @@ describe('Iterator Utils', () => { }) IteratorTest('CancelableGenerator', () => { - const [, itr] = CancelableGenerator(generate(), onFinally) - return itr + return CancelableGenerator(generate(), onFinally) }) it('can cancel during iteration', async () => { - const [cancel, itr] = CancelableGenerator(generate(), onFinally) + const itr = CancelableGenerator(generate(), onFinally) const received = [] for await (const msg of itr) { received.push(msg) if (received.length === MAX_ITEMS) { - cancel() + itr.cancel() } } @@ -436,9 +404,9 @@ describe('Iterator Utils', () => { }) it('can cancel before iteration', async () => { - const [cancel, itr] = CancelableGenerator(generate(), onFinally) + const itr = CancelableGenerator(generate(), onFinally) const received = [] - cancel() + itr.cancel() expect(itr.isCancelled()).toEqual(true) for await (const msg of itr) { received.push(msg) @@ -449,12 +417,12 @@ describe('Iterator Utils', () => { }) it('can cancel with error before iteration', async () => { - const [cancel, itr] = CancelableGenerator(generate(), () => { + const itr = CancelableGenerator(generate(), () => { return onFinally() }) const received = [] const err = new Error('expected') - cancel(err) + itr.cancel(err) await expect(async () => { for await (const msg of itr) { received.push(msg) @@ -464,33 +432,71 @@ describe('Iterator Utils', () => { expect(received).toEqual([]) }) + it('cancels with error when iterator.cancel(err) is called asynchronously with error', async () => { + const done = Defer() + try { + const err = new Error('expected') + const received = [] + const itr = CancelableGenerator(generate(), onFinally, { + timeout: WAIT, + }) + let receievedAtCallTime + await expect(async () => { + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(async () => { + receievedAtCallTime = received + await itr.cancel(err) + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + })) + } + } + }).rejects.toThrow(err) + + expect(received).toEqual(receievedAtCallTime) + expect(itr.isCancelled()).toEqual(true) + } catch (err) { + done.reject(err) + } finally { + await done + } + }) + it('cancels when iterator.cancel() is called asynchronously', async () => { - const received = [] - const [cancel, itr] = CancelableGenerator(generate(), onFinally, { - timeout: WAIT, - }) - let receievedAtCallTime - for await (const msg of itr) { - received.push(msg) - if (received.length === MAX_ITEMS) { - // eslint-disable-next-line no-loop-func - setTimeout(async () => { - receievedAtCallTime = received - await cancel() - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }) + const done = Defer() + try { + const received = [] + const itr = CancelableGenerator(generate(), onFinally, { + timeout: WAIT, + }) + let receievedAtCallTime + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(async () => { + receievedAtCallTime = received + await itr.cancel() + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + })) + } } - } - expect(received).toEqual(receievedAtCallTime) - expect(itr.isCancelled()).toEqual(true) + expect(received).toEqual(receievedAtCallTime) + expect(itr.isCancelled()).toEqual(true) + } finally { + await done + } }) it('prevents subsequent .next call', async () => { const received = [] const triggeredForever = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { + const itr = CancelableGenerator((async function* Gen() { yield* expected yield await new Promise(() => { triggeredForever() // should not get here @@ -500,7 +506,7 @@ describe('Iterator Utils', () => { for await (const msg of itr) { received.push(msg) if (received.length === expected.length) { - await cancel() + await itr.cancel() expect(onFinally).toHaveBeenCalledTimes(1) expect(onFinallyAfter).toHaveBeenCalledTimes(1) } @@ -514,11 +520,11 @@ describe('Iterator Utils', () => { it('interrupts outstanding .next call', async () => { const received = [] const triggeredForever = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { + const itr = CancelableGenerator((async function* Gen() { yield* expected yield await new Promise(() => { triggeredForever() - cancel() + itr.cancel() }) // would wait forever }()), onFinally) @@ -532,36 +538,41 @@ describe('Iterator Utils', () => { }) it('interrupts outstanding .next call when called asynchronously', async () => { - const received = [] - const triggeredForever = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { - yield* expected - yield await new Promise(() => { - triggeredForever() - }) // would wait forever - }()), onFinally, { - timeout: WAIT, - }) + const done = Defer() + try { + const received = [] + const triggeredForever = jest.fn() + const itr = CancelableGenerator((async function* Gen() { + yield* expected + yield await new Promise(() => { + triggeredForever() + }) // would wait forever + }()), onFinally, { + timeout: WAIT, + }) - for await (const msg of itr) { - received.push(msg) - if (received.length === expected.length) { - // eslint-disable-next-line no-loop-func - setTimeout(async () => { - await cancel() - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }) + for await (const msg of itr) { + received.push(msg) + if (received.length === expected.length) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(async () => { + await itr.cancel() + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + })) + } } - } - expect(received).toEqual(expected) - expect(itr.isCancelled()).toEqual(true) + expect(received).toEqual(expected) + expect(itr.isCancelled()).toEqual(true) + } finally { + await done + } }) it('stops iterator', async () => { const shouldRunFinally = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { + const itr = CancelableGenerator((async function* Gen() { try { yield 1 await wait(WAIT) @@ -579,7 +590,7 @@ describe('Iterator Utils', () => { for await (const msg of itr) { received.push(msg) if (received.length === 2) { - cancel() + itr.cancel() expect(itr.isCancelled()).toEqual(true) } } @@ -592,125 +603,145 @@ describe('Iterator Utils', () => { }) it('interrupts outstanding .next call with error', async () => { - const received = [] - const [cancel, itr] = CancelableGenerator((async function* Gen() { - yield* expected - yield await new Promise(() => {}) // would wait forever - }()), onFinally, { - timeout: WAIT, - }) - - const err = new Error('expected') + const done = Defer() + try { + const received = [] + const itr = CancelableGenerator((async function* Gen() { + yield* expected + yield await new Promise(() => {}) // would wait forever + }()), onFinally, { + timeout: WAIT, + }) - let receievedAtCallTime - await expect(async () => { - for await (const msg of itr) { - received.push(msg) - if (received.length === MAX_ITEMS) { - // eslint-disable-next-line no-loop-func - setTimeout(async () => { - receievedAtCallTime = received - await cancel(err) + const err = new Error('expected') - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }) + let receievedAtCallTime + await expect(async () => { + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(async () => { + receievedAtCallTime = received + await itr.cancel(err) + + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + })) + } } - } - }).rejects.toThrow(err) + }).rejects.toThrow(err) - expect(received).toEqual(receievedAtCallTime) - expect(itr.isCancelled()).toEqual(true) + expect(received).toEqual(receievedAtCallTime) + expect(itr.isCancelled()).toEqual(true) + } finally { + await done + } }) it('can handle queued next calls', async () => { - const triggeredForever = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { - yield* expected - setTimeout(async () => { - await cancel() + const done = Defer() + try { + const triggeredForever = jest.fn() + const itr = CancelableGenerator((async function* Gen() { + yield* expected + setTimeout(done.wrap(async () => { + await itr.cancel() - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }, WAIT * 2) - yield await new Promise(() => { - triggeredForever() - }) // would wait forever - }()), onFinally, { - timeout: WAIT, - }) + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + }), WAIT * 2) + yield await new Promise(() => { + triggeredForever() + }) // would wait forever + }()), onFinally, { + timeout: WAIT, + }) - const tasks = expected.map(async () => itr.next()) - tasks.push(itr.next()) // one more over the edge (should trigger forever promise) - const received = await Promise.all(tasks) - expect(received.map(({ value }) => value)).toEqual([...expected, undefined]) - expect(triggeredForever).toHaveBeenCalledTimes(1) - expect(itr.isCancelled()).toEqual(true) + const tasks = expected.map(async () => itr.next()) + tasks.push(itr.next()) // one more over the edge (should trigger forever promise) + const received = await Promise.all(tasks) + expect(received.map(({ value }) => value)).toEqual([...expected, undefined]) + expect(triggeredForever).toHaveBeenCalledTimes(1) + expect(itr.isCancelled()).toEqual(true) + } finally { + await done + } }) it('can handle queued next calls resolving out of order', async () => { - const triggeredForever = jest.fn() - const [cancel, itr] = CancelableGenerator((async function* Gen() { - let i = 0 - for await (const v of expected) { - i += 1 - await wait((expected.length - i - 1) * 2 * WAIT) - yield v - } + const done = Defer() + try { + const triggeredForever = jest.fn() + const itr = CancelableGenerator((async function* Gen() { + let i = 0 + for await (const v of expected) { + i += 1 + await wait((expected.length - i - 1) * 2 * WAIT) + yield v + } - setTimeout(async () => { - await cancel() + setTimeout(done.wrap(async () => { + await itr.cancel() - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }, WAIT * 2) + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + }), WAIT * 2) - yield await new Promise(() => { - triggeredForever() - }) // would wait forever - }()), onFinally, { - timeout: WAIT, - }) + yield await new Promise(() => { + triggeredForever() + }) // would wait forever + }()), onFinally, { + timeout: WAIT, + }) - const tasks = expected.map(async () => itr.next()) - tasks.push(itr.next()) // one more over the edge (should trigger forever promise) - const received = await Promise.all(tasks) - expect(received.map(({ value }) => value)).toEqual([...expected, undefined]) - expect(triggeredForever).toHaveBeenCalledTimes(1) + const tasks = expected.map(async () => itr.next()) + tasks.push(itr.next()) // one more over the edge (should trigger forever promise) + const received = await Promise.all(tasks) + expect(received.map(({ value }) => value)).toEqual([...expected, undefined]) + expect(triggeredForever).toHaveBeenCalledTimes(1) + } finally { + await done + } }) it('ignores err if cancelled', async () => { - const received = [] - const err = new Error('expected') - const d = Defer() - const [cancel, itr] = CancelableGenerator((async function* Gen() { - yield* expected - await wait(WAIT * 2) - d.resolve() - throw new Error('should not see this') - }()), onFinally) - - let receievedAtCallTime - await expect(async () => { - for await (const msg of itr) { - received.push(msg) - if (received.length === MAX_ITEMS) { - // eslint-disable-next-line no-loop-func - setTimeout(async () => { - receievedAtCallTime = received - await cancel(err) + const done = Defer() + try { + const received = [] + const err = new Error('expected') + const d = Defer() + const itr = CancelableGenerator((async function* Gen() { + yield* expected + await wait(WAIT * 2) + d.resolve() + throw new Error('should not see this') + }()), onFinally) - expect(onFinally).toHaveBeenCalledTimes(1) - expect(onFinallyAfter).toHaveBeenCalledTimes(1) - }) + let receievedAtCallTime + await expect(async () => { + for await (const msg of itr) { + received.push(msg) + if (received.length === MAX_ITEMS) { + // eslint-disable-next-line no-loop-func + setTimeout(done.wrap(async () => { + receievedAtCallTime = received + await itr.cancel(err) + + expect(onFinally).toHaveBeenCalledTimes(1) + expect(onFinallyAfter).toHaveBeenCalledTimes(1) + })) + } } - } - }).rejects.toThrow(err) + }).rejects.toThrow(err) - await d - await wait(WAIT * 2) + await d + await wait(WAIT * 2) - expect(received).toEqual(receievedAtCallTime) + expect(received).toEqual(receievedAtCallTime) + } finally { + await done + } }) describe('nesting', () => { @@ -732,14 +763,14 @@ describe('Iterator Utils', () => { }) IteratorTest('CancelableGenerator nested', () => { - const [, itrInner] = CancelableGenerator(generate(), onFinallyInner) - const [, itrOuter] = CancelableGenerator(itrInner, onFinally) + const itrInner = CancelableGenerator(generate(), onFinallyInner) + const itrOuter = CancelableGenerator(itrInner, onFinally) return itrOuter }) it('can cancel nested cancellable iterator in finally', async () => { const waitInner = jest.fn() - const [cancelInner, itrInner] = CancelableGenerator((async function* Gen() { + const itrInner = CancelableGenerator((async function* Gen() { yield* generate() yield await new Promise(() => { // should not get here @@ -750,14 +781,14 @@ describe('Iterator Utils', () => { }) const waitOuter = jest.fn() - const [cancelOuter, itrOuter] = CancelableGenerator((async function* Gen() { + const itrOuter = CancelableGenerator((async function* Gen() { yield* itrInner yield await new Promise(() => { // should not get here waitOuter() }) // would wait forever }()), async () => { - await cancelInner() + await itrInner.cancel() expect(onFinallyInner).toHaveBeenCalledTimes(1) expect(onFinallyInnerAfter).toHaveBeenCalledTimes(1) await onFinally() @@ -769,7 +800,7 @@ describe('Iterator Utils', () => { for await (const msg of itrOuter) { received.push(msg) if (received.length === expected.length) { - await cancelOuter() + await itrOuter.cancel() } } @@ -779,57 +810,62 @@ describe('Iterator Utils', () => { }) it('can cancel nested cancellable iterator in finally, asynchronously', async () => { - const waitInner = jest.fn() - const [cancelInner, itrInner] = CancelableGenerator((async function* Gen() { - yield* generate() - yield await new Promise(() => { - // should not get here - waitInner() - }) // would wait forever - }()), onFinallyInner, { - timeout: WAIT, - }) + const done = Defer() + try { + const waitInner = jest.fn() + const itrInner = CancelableGenerator((async function* Gen() { + yield* generate() + yield await new Promise(() => { + // should not get here + waitInner() + }) // would wait forever + }()), onFinallyInner, { + timeout: WAIT, + }) - const waitOuter = jest.fn() - const [cancelOuter, itrOuter] = CancelableGenerator((async function* Gen() { - yield* itrInner - yield await new Promise(() => { - // should not get here - waitOuter() - }) // would wait forever - }()), async () => { - await cancelInner() - await onFinally() - }, { - timeout: WAIT, - }) + const waitOuter = jest.fn() + const itrOuter = CancelableGenerator((async function* Gen() { + yield* itrInner + yield await new Promise(() => { + // should not get here + waitOuter() + }) // would wait forever + }()), async () => { + await itrInner.cancel() + await onFinally() + }, { + timeout: WAIT, + }) - const received = [] - for await (const msg of itrOuter) { - received.push(msg) - if (received.length === expected.length) { - setTimeout(() => { - cancelOuter() - }) + const received = [] + for await (const msg of itrOuter) { + received.push(msg) + if (received.length === expected.length) { + setTimeout(done.wrap(() => { + itrOuter.cancel() + })) + } } - } - expect(waitOuter).toHaveBeenCalledTimes(1) - expect(waitInner).toHaveBeenCalledTimes(1) - expect(received).toEqual(expected) + expect(waitOuter).toHaveBeenCalledTimes(1) + expect(waitInner).toHaveBeenCalledTimes(1) + expect(received).toEqual(expected) + } finally { + await done + } }) }) it('can cancel in parallel and wait correctly for both', async () => { - const [cancel, itr] = CancelableGenerator(generate(), onFinally) + const itr = CancelableGenerator(generate(), onFinally) const ranTests = jest.fn() const received = [] for await (const msg of itr) { received.push(msg) if (received.length === MAX_ITEMS) { - const t1 = cancel() - const t2 = cancel() + const t1 = itr.cancel() + const t2 = itr.cancel() await Promise.race([t1, t2]) expect(onFinally).toHaveBeenCalledTimes(1) expect(onFinallyAfter).toHaveBeenCalledTimes(1) @@ -848,12 +884,22 @@ describe('Iterator Utils', () => { describe('pipeline', () => { let onFinally let onFinallyAfter + let errors = [] beforeEach(() => { - onFinallyAfter = jest.fn() - onFinally = jest.fn(async () => { + errors = [] + const errorsLocal = errors + onFinallyAfter = jest.fn((err) => { + if (errorsLocal !== errors) { return } // cross-contaminated test + + if (err) { + errors.push(err) + } + }) + + onFinally = jest.fn(async (err) => { await wait(WAIT) - onFinallyAfter() + onFinallyAfter(err) }) }) @@ -1058,6 +1104,123 @@ describe('Iterator Utils', () => { expect(afterStep2).toHaveBeenCalledTimes(1) }) + it('feeds items from one to next, stops & errors all when middle .throws()', async () => { + const receivedStep1 = [] + const receivedStep2 = [] + const afterStep1 = jest.fn() + const afterStep2 = jest.fn() + const catchStep1 = jest.fn() + const catchStep2 = jest.fn() + const err = new Error('expected') + + const p = pipeline([ + generate(), + async function* Step1(s) { + try { + for await (const msg of s) { + receivedStep1.push(msg) + yield msg * 2 + if (receivedStep1.length === MAX_ITEMS) { + await s.throw(err) + } + } + } catch (error) { + catchStep1(error) + throw error + } finally { + afterStep1() + } + }, + async function* Step2(s) { + try { + for await (const msg of s) { + receivedStep2.push(msg) + yield msg * 10 + } + } catch (error) { + catchStep2(error) + throw error + } finally { + afterStep2() + } + } + ], onFinally) + + const received = [] + await expect(async () => { + for await (const msg of p) { + received.push(msg) + } + }).rejects.toThrow(err) + + expect(received).toEqual(expected.slice(0, MAX_ITEMS).map((v) => v * 20)) + expect(receivedStep2).toEqual(expected.slice(0, MAX_ITEMS).map((v) => v * 2)) + expect(receivedStep1).toEqual(expected.slice(0, MAX_ITEMS)) + expect(afterStep1).toHaveBeenCalledTimes(1) + expect(afterStep2).toHaveBeenCalledTimes(1) + expect(catchStep1).toHaveBeenCalledTimes(1) + expect(catchStep2).toHaveBeenCalledTimes(1) + }) + + it('feeds items from one to next, stops all when end .throws()', async () => { + const receivedStep1 = [] + const receivedStep2 = [] + const afterStep1 = jest.fn() + const afterStep2 = jest.fn(async () => { + throw new Error('oops') + }) + const catchStep1 = jest.fn() + const catchStep2 = jest.fn() + const err = new Error('expected') + + const p = pipeline([ + generate(), + async function* Step1(s) { + try { + for await (const msg of s) { + receivedStep1.push(msg) + yield msg * 2 + } + } catch (error) { + catchStep1(error) + throw error + } finally { + afterStep1() + } + }, + async function* Step2(s) { + try { + for await (const msg of s) { + receivedStep2.push(msg) + yield msg * 10 + if (receivedStep2.length === MAX_ITEMS) { + await s.throw(err) + } + } + } catch (error) { + catchStep2(error) + throw error + } finally { + await afterStep2() + } + } + ], onFinally) + + const received = [] + await expect(async () => { + for await (const msg of p) { + received.push(msg) + } + }).rejects.toThrow(err) + + expect(received).toEqual(expected.slice(0, MAX_ITEMS).map((v) => v * 20)) + expect(receivedStep2).toEqual(expected.slice(0, MAX_ITEMS).map((v) => v * 2)) + expect(receivedStep1).toEqual(expected.slice(0, MAX_ITEMS)) + expect(afterStep1).toHaveBeenCalledTimes(1) + expect(afterStep2).toHaveBeenCalledTimes(1) + expect(catchStep1).toHaveBeenCalledTimes(1) + expect(catchStep2).toHaveBeenCalledTimes(1) + }) it('handles errors before', async () => { const err = new Error('expected') @@ -2100,8 +2263,8 @@ describe('Iterator Utils', () => { expect(receivedStep1).toEqual(expected.slice(0, MAX_ITEMS)) expect(receivedStep2).toEqual(expected.slice(0, MAX_ITEMS)) // all streams were closed - expect(onFirstStreamClose).toHaveBeenCalledTimes(1) expect(onInputStreamClose).toHaveBeenCalledTimes(1) + expect(onFirstStreamClose).toHaveBeenCalledTimes(1) expect(onFinallyInner).toHaveBeenCalledTimes(1) expect(onFinallyInnerAfter).toHaveBeenCalledTimes(1)