From ac7e6d994bd26e12d3ad55c268689354ff5af7cc Mon Sep 17 00:00:00 2001 From: Janka Uryga Date: Mon, 10 Jun 2024 20:32:08 +0200 Subject: [PATCH] refactor: add AwaiterOnce and awaiter tests --- packages/next/src/server/lib/awaiter.test.ts | 81 +++++++++++++++++++ packages/next/src/server/lib/awaiter.ts | 57 +++++++++++-- .../server/web/spec-extension/fetch-event.ts | 11 +-- .../utils/simulated-invocation.js | 4 +- 4 files changed, 138 insertions(+), 15 deletions(-) create mode 100644 packages/next/src/server/lib/awaiter.test.ts diff --git a/packages/next/src/server/lib/awaiter.test.ts b/packages/next/src/server/lib/awaiter.test.ts new file mode 100644 index 0000000000000..89af27bb052b4 --- /dev/null +++ b/packages/next/src/server/lib/awaiter.test.ts @@ -0,0 +1,81 @@ +import { InvariantError } from '../../shared/lib/invariant-error' +import { AwaiterMulti, AwaiterOnce } from './awaiter' + +describe('AwaiterOnce/AwaiterMulti', () => { + describe.each([ + { name: 'AwaiterMulti', impl: AwaiterMulti }, + { name: 'AwaiterOnce', impl: AwaiterOnce }, + ])('$name', ({ impl: AwaiterImpl }) => { + it('awaits promises added by other promises', async () => { + const awaiter = new AwaiterImpl() + + const MAX_DEPTH = 5 + const promises: TrackedPromise[] = [] + + const waitUntil = (promise: Promise) => { + promises.push(trackPromiseSettled(promise)) + awaiter.waitUntil(promise) + } + + const makeNestedPromise = async () => { + if (promises.length >= MAX_DEPTH) { + return + } + await sleep(100) + waitUntil(makeNestedPromise()) + } + + waitUntil(makeNestedPromise()) + + await awaiter.awaiting() + + for (const promise of promises) { + expect(promise.isSettled).toBe(true) + } + }) + + it('calls onError for rejected promises', async () => { + const onError = jest.fn() + const awaiter = new AwaiterImpl({ onError }) + + awaiter.waitUntil(Promise.reject('error 1')) + awaiter.waitUntil( + sleep(100).then(() => awaiter.waitUntil(Promise.reject('error 2'))) + ) + + await awaiter.awaiting() + + expect(onError).toHaveBeenCalledWith('error 1') + expect(onError).toHaveBeenCalledWith('error 2') + }) + }) +}) + +describe('AwaiterOnce', () => { + it("does not allow calling waitUntil after it's been awaited", async () => { + const awaiter = new AwaiterOnce() + awaiter.waitUntil(Promise.resolve(1)) + await awaiter.awaiting() + expect(() => awaiter.waitUntil(Promise.resolve(2))).toThrow(InvariantError) + }) +}) + +type TrackedPromise = Promise & { isSettled: boolean } + +function trackPromiseSettled(promise: Promise): TrackedPromise { + const tracked = promise as TrackedPromise + tracked.isSettled = false + tracked.then( + () => { + tracked.isSettled = true + }, + () => { + tracked.isSettled = true + } + ) + return tracked +} + +function sleep(duration: number) { + return new Promise((resolve) => setTimeout(resolve, duration)) +} diff --git a/packages/next/src/server/lib/awaiter.ts b/packages/next/src/server/lib/awaiter.ts index e8b2ede3bd08f..7f5c32c441302 100644 --- a/packages/next/src/server/lib/awaiter.ts +++ b/packages/next/src/server/lib/awaiter.ts @@ -1,16 +1,31 @@ +import { InvariantError } from '../../shared/lib/invariant-error' + /** - * The Awaiter class is used to manage and await multiple promises. + * Provides a `waitUntil` implementation which gathers promises to be awaited later (via {@link AwaiterMulti.awaiting}). + * Unlike a simple `Promise.all`, {@link AwaiterMulti} works recursively -- + * if a promise passed to {@link AwaiterMulti.waitUntil} calls `waitUntil` again, + * that second promise will also be awaited. */ -export class Awaiter { +export class AwaiterMulti { private promises: Set> = new Set() - private onError: ((error: unknown) => void) | undefined + private onError: (error: unknown) => void constructor({ onError }: { onError?: (error: unknown) => void } = {}) { this.onError = onError ?? console.error } - public waitUntil = (promise: Promise) => { - this.promises.add(promise.catch(this.onError)) + public waitUntil = (promise: Promise): void => { + // if a promise settles before we await it, we can drop it. + const cleanup = () => { + this.promises.delete(promise) + } + + this.promises.add( + promise.then(cleanup, (err) => { + cleanup() + this.onError(err) + }) + ) } public async awaiting(): Promise { @@ -21,3 +36,35 @@ export class Awaiter { } } } + +/** + * Like {@link AwaiterMulti}, but can only be awaited once. + * If {@link AwaiterOnce.waitUntil} is called after that, it will throw. + */ +export class AwaiterOnce { + private awaiter: AwaiterMulti + private done: boolean = false + private pending: Promise | undefined + + constructor(options: { onError?: (error: unknown) => void } = {}) { + this.awaiter = new AwaiterMulti(options) + } + + public waitUntil = (promise: Promise): void => { + if (this.done) { + throw new InvariantError( + 'Cannot call waitUntil() on an AwaiterOnce that was already awaited' + ) + } + return this.awaiter.waitUntil(promise) + } + + public async awaiting(): Promise { + if (!this.pending) { + this.pending = this.awaiter.awaiting().finally(() => { + this.done = true + }) + } + return this.pending + } +} diff --git a/packages/next/src/server/web/spec-extension/fetch-event.ts b/packages/next/src/server/web/spec-extension/fetch-event.ts index 6fcf3ca53d300..43d509bcdd3fa 100644 --- a/packages/next/src/server/web/spec-extension/fetch-event.ts +++ b/packages/next/src/server/web/spec-extension/fetch-event.ts @@ -1,11 +1,10 @@ -import { Awaiter } from '../../lib/awaiter' +import { AwaiterOnce } from '../../lib/awaiter' import { PageSignatureError } from '../error' import type { NextRequest } from './request' const responseSymbol = Symbol('response') const passThroughSymbol = Symbol('passThrough') const awaiterSymbol = Symbol('awaiter') -const waitUntilCacheSymbol = Symbol('waitUntil.cache') export const waitUntilSymbol = Symbol('waitUntil') @@ -13,14 +12,10 @@ class FetchEvent { [responseSymbol]?: Promise; [passThroughSymbol] = false; - [awaiterSymbol] = new Awaiter(); - [waitUntilCacheSymbol]: Promise | undefined = undefined; + [awaiterSymbol] = new AwaiterOnce(); [waitUntilSymbol] = () => { - if (!this[waitUntilCacheSymbol]) { - this[waitUntilCacheSymbol] = this[awaiterSymbol].awaiting() - } - return this[waitUntilCacheSymbol] + return this[awaiterSymbol].awaiting() } // eslint-disable-next-line @typescript-eslint/no-useless-constructor diff --git a/test/e2e/app-dir/next-after-app/utils/simulated-invocation.js b/test/e2e/app-dir/next-after-app/utils/simulated-invocation.js index 642b584461805..f7c3bd026ccaa 100644 --- a/test/e2e/app-dir/next-after-app/utils/simulated-invocation.js +++ b/test/e2e/app-dir/next-after-app/utils/simulated-invocation.js @@ -1,5 +1,5 @@ import { requestAsyncStorage } from 'next/dist/client/components/request-async-storage.external' -import { Awaiter } from 'next/dist/server/lib/awaiter' +import { AwaiterOnce } from 'next/dist/server/lib/awaiter' import { cliLog } from './log' // replaced in tests @@ -74,7 +74,7 @@ So for edge, the flow goes like this: */ function createInvocationContext() { - const awaiter = new Awaiter() + const awaiter = new AwaiterOnce() const waitUntil = (promise) => { awaiter.waitUntil(promise)