Skip to content

Commit

Permalink
refactor: add AwaiterOnce and awaiter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lubieowoce committed Jun 11, 2024
1 parent 3b58ffd commit ac7e6d9
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 15 deletions.
81 changes: 81 additions & 0 deletions packages/next/src/server/lib/awaiter.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { InvariantError } from '../../shared/lib/invariant-error'
import { AwaiterMulti, AwaiterOnce } from './awaiter'

describe('AwaiterOnce/AwaiterMulti', () => {
describe.each([
{ name: 'AwaiterMulti', impl: AwaiterMulti },
{ name: 'AwaiterOnce', impl: AwaiterOnce },
])('$name', ({ impl: AwaiterImpl }) => {
it('awaits promises added by other promises', async () => {
const awaiter = new AwaiterImpl()

const MAX_DEPTH = 5
const promises: TrackedPromise<unknown>[] = []

const waitUntil = (promise: Promise<unknown>) => {
promises.push(trackPromiseSettled(promise))
awaiter.waitUntil(promise)
}

const makeNestedPromise = async () => {
if (promises.length >= MAX_DEPTH) {
return
}
await sleep(100)
waitUntil(makeNestedPromise())
}

waitUntil(makeNestedPromise())

await awaiter.awaiting()

for (const promise of promises) {
expect(promise.isSettled).toBe(true)
}
})

it('calls onError for rejected promises', async () => {
const onError = jest.fn<void, [error: unknown]>()
const awaiter = new AwaiterImpl({ onError })

awaiter.waitUntil(Promise.reject('error 1'))
awaiter.waitUntil(
sleep(100).then(() => awaiter.waitUntil(Promise.reject('error 2')))
)

await awaiter.awaiting()

expect(onError).toHaveBeenCalledWith('error 1')
expect(onError).toHaveBeenCalledWith('error 2')
})
})
})

describe('AwaiterOnce', () => {
it("does not allow calling waitUntil after it's been awaited", async () => {
const awaiter = new AwaiterOnce()
awaiter.waitUntil(Promise.resolve(1))
await awaiter.awaiting()
expect(() => awaiter.waitUntil(Promise.resolve(2))).toThrow(InvariantError)
})
})

type TrackedPromise<T> = Promise<T> & { isSettled: boolean }

function trackPromiseSettled<T>(promise: Promise<T>): TrackedPromise<T> {
const tracked = promise as TrackedPromise<T>
tracked.isSettled = false
tracked.then(
() => {
tracked.isSettled = true
},
() => {
tracked.isSettled = true
}
)
return tracked
}

function sleep(duration: number) {
return new Promise<void>((resolve) => setTimeout(resolve, duration))
}
57 changes: 52 additions & 5 deletions packages/next/src/server/lib/awaiter.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import { InvariantError } from '../../shared/lib/invariant-error'

/**
* The Awaiter class is used to manage and await multiple promises.
* Provides a `waitUntil` implementation which gathers promises to be awaited later (via {@link AwaiterMulti.awaiting}).
* Unlike a simple `Promise.all`, {@link AwaiterMulti} works recursively --
* if a promise passed to {@link AwaiterMulti.waitUntil} calls `waitUntil` again,
* that second promise will also be awaited.
*/
export class Awaiter {
export class AwaiterMulti {
private promises: Set<Promise<unknown>> = new Set()
private onError: ((error: unknown) => void) | undefined
private onError: (error: unknown) => void

constructor({ onError }: { onError?: (error: unknown) => void } = {}) {
this.onError = onError ?? console.error
}

public waitUntil = (promise: Promise<unknown>) => {
this.promises.add(promise.catch(this.onError))
public waitUntil = (promise: Promise<unknown>): void => {
// if a promise settles before we await it, we can drop it.
const cleanup = () => {
this.promises.delete(promise)
}

this.promises.add(
promise.then(cleanup, (err) => {
cleanup()
this.onError(err)
})
)
}

public async awaiting(): Promise<void> {
Expand All @@ -21,3 +36,35 @@ export class Awaiter {
}
}
}

/**
* Like {@link AwaiterMulti}, but can only be awaited once.
* If {@link AwaiterOnce.waitUntil} is called after that, it will throw.
*/
export class AwaiterOnce {
private awaiter: AwaiterMulti
private done: boolean = false
private pending: Promise<void> | undefined

constructor(options: { onError?: (error: unknown) => void } = {}) {
this.awaiter = new AwaiterMulti(options)
}

public waitUntil = (promise: Promise<unknown>): void => {
if (this.done) {
throw new InvariantError(
'Cannot call waitUntil() on an AwaiterOnce that was already awaited'
)
}
return this.awaiter.waitUntil(promise)
}

public async awaiting(): Promise<void> {
if (!this.pending) {
this.pending = this.awaiter.awaiting().finally(() => {
this.done = true
})
}
return this.pending
}
}
11 changes: 3 additions & 8 deletions packages/next/src/server/web/spec-extension/fetch-event.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
import { Awaiter } from '../../lib/awaiter'
import { AwaiterOnce } from '../../lib/awaiter'
import { PageSignatureError } from '../error'
import type { NextRequest } from './request'

const responseSymbol = Symbol('response')
const passThroughSymbol = Symbol('passThrough')
const awaiterSymbol = Symbol('awaiter')
const waitUntilCacheSymbol = Symbol('waitUntil.cache')

export const waitUntilSymbol = Symbol('waitUntil')

class FetchEvent {
[responseSymbol]?: Promise<Response>;
[passThroughSymbol] = false;

[awaiterSymbol] = new Awaiter();
[waitUntilCacheSymbol]: Promise<void> | undefined = undefined;
[awaiterSymbol] = new AwaiterOnce();

[waitUntilSymbol] = () => {
if (!this[waitUntilCacheSymbol]) {
this[waitUntilCacheSymbol] = this[awaiterSymbol].awaiting()
}
return this[waitUntilCacheSymbol]
return this[awaiterSymbol].awaiting()
}

// eslint-disable-next-line @typescript-eslint/no-useless-constructor
Expand Down
4 changes: 2 additions & 2 deletions test/e2e/app-dir/next-after-app/utils/simulated-invocation.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { requestAsyncStorage } from 'next/dist/client/components/request-async-storage.external'
import { Awaiter } from 'next/dist/server/lib/awaiter'
import { AwaiterOnce } from 'next/dist/server/lib/awaiter'
import { cliLog } from './log'

// replaced in tests
Expand Down Expand Up @@ -74,7 +74,7 @@ So for edge, the flow goes like this:
*/

function createInvocationContext() {
const awaiter = new Awaiter()
const awaiter = new AwaiterOnce()

const waitUntil = (promise) => {
awaiter.waitUntil(promise)
Expand Down

0 comments on commit ac7e6d9

Please sign in to comment.