Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[unstable_after] fixes for waitUntil in edge runtime #66135

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,27 @@ export function getRender({
request: NextRequestHint,
event?: NextFetchEvent
) {
const isAfterEnabled = !!process.env.__NEXT_AFTER

const extendedReq = new WebNextRequest(request)
const extendedRes = new WebNextResponse(
undefined,
// tracking onClose adds overhead, so only do it if `experimental.after` is on.
!!process.env.__NEXT_AFTER
isAfterEnabled
)

handler(extendedReq, extendedRes)
const result = await extendedRes.toResponse()

if (event?.waitUntil) {
if (isAfterEnabled) {
// make sure that NextRequestHint's awaiter stays open long enough
// for late `waitUntil`s called during streaming to get picked up.
event.waitUntil(
new Promise<void>((resolve) => extendedRes.onClose(resolve))
)
}

// TODO(after):
// remove `internal_runWithWaitUntil` and the `internal-edge-wait-until` module
// when consumers switch to `unstable_after`.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import { createAsyncLocalStorage } from './async-local-storage'
import type { LifecycleAsyncStorage } from './lifecycle-async-storage.external'

export const _lifecycleAsyncStorage: LifecycleAsyncStorage =
createAsyncLocalStorage()
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import type { AsyncLocalStorage } from 'async_hooks'
// Share the instance module in the next-shared layer
import { _lifecycleAsyncStorage as lifecycleAsyncStorage } from './lifecycle-async-storage-instance' with { 'turbopack-transition': 'next-shared' }

export interface LifecycleStore {
readonly waitUntil: ((promise: Promise<any>) => void) | undefined
}

export type LifecycleAsyncStorage = AsyncLocalStorage<LifecycleStore>

export { lifecycleAsyncStorage }
8 changes: 7 additions & 1 deletion packages/next/src/server/base-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ import {
getBuiltinRequestContext,
type WaitUntil,
} from './after/builtin-request-context'
import { lifecycleAsyncStorage } from '../client/components/lifecycle-async-storage.external'

export type FindComponentsResult = {
components: LoadComponentsReturnType
Expand Down Expand Up @@ -1669,7 +1670,12 @@ export default abstract class Server<
)
}

private getWaitUntil(): WaitUntil | undefined {
protected getWaitUntil(): WaitUntil | undefined {
const lifecycleStore = lifecycleAsyncStorage.getStore()
if (lifecycleStore) {
return lifecycleStore.waitUntil
}

const builtinRequestContext = getBuiltinRequestContext()
if (builtinRequestContext) {
// the platform provided a request context.
Expand Down
81 changes: 81 additions & 0 deletions packages/next/src/server/lib/awaiter.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { InvariantError } from '../../shared/lib/invariant-error'
import { AwaiterMulti, AwaiterOnce } from './awaiter'

describe('AwaiterOnce/AwaiterMulti', () => {
describe.each([
{ name: 'AwaiterMulti', impl: AwaiterMulti },
{ name: 'AwaiterOnce', impl: AwaiterOnce },
])('$name', ({ impl: AwaiterImpl }) => {
it('awaits promises added by other promises', async () => {
const awaiter = new AwaiterImpl()

const MAX_DEPTH = 5
const promises: TrackedPromise<unknown>[] = []

const waitUntil = (promise: Promise<unknown>) => {
promises.push(trackPromiseSettled(promise))
awaiter.waitUntil(promise)
}

const makeNestedPromise = async () => {
if (promises.length >= MAX_DEPTH) {
return
}
await sleep(100)
waitUntil(makeNestedPromise())
}

waitUntil(makeNestedPromise())

await awaiter.awaiting()

for (const promise of promises) {
expect(promise.isSettled).toBe(true)
}
})

it('calls onError for rejected promises', async () => {
const onError = jest.fn<void, [error: unknown]>()
const awaiter = new AwaiterImpl({ onError })

awaiter.waitUntil(Promise.reject('error 1'))
awaiter.waitUntil(
sleep(100).then(() => awaiter.waitUntil(Promise.reject('error 2')))
)

await awaiter.awaiting()

expect(onError).toHaveBeenCalledWith('error 1')
expect(onError).toHaveBeenCalledWith('error 2')
})
})
})

describe('AwaiterOnce', () => {
it("does not allow calling waitUntil after it's been awaited", async () => {
const awaiter = new AwaiterOnce()
awaiter.waitUntil(Promise.resolve(1))
await awaiter.awaiting()
expect(() => awaiter.waitUntil(Promise.resolve(2))).toThrow(InvariantError)
})
})

type TrackedPromise<T> = Promise<T> & { isSettled: boolean }

function trackPromiseSettled<T>(promise: Promise<T>): TrackedPromise<T> {
const tracked = promise as TrackedPromise<T>
tracked.isSettled = false
tracked.then(
() => {
tracked.isSettled = true
},
() => {
tracked.isSettled = true
}
)
return tracked
}

function sleep(duration: number) {
return new Promise<void>((resolve) => setTimeout(resolve, duration))
}
70 changes: 70 additions & 0 deletions packages/next/src/server/lib/awaiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { InvariantError } from '../../shared/lib/invariant-error'

/**
* Provides a `waitUntil` implementation which gathers promises to be awaited later (via {@link AwaiterMulti.awaiting}).
* Unlike a simple `Promise.all`, {@link AwaiterMulti} works recursively --
* if a promise passed to {@link AwaiterMulti.waitUntil} calls `waitUntil` again,
* that second promise will also be awaited.
*/
export class AwaiterMulti {
private promises: Set<Promise<unknown>> = new Set()
private onError: (error: unknown) => void

constructor({ onError }: { onError?: (error: unknown) => void } = {}) {
this.onError = onError ?? console.error
}

public waitUntil = (promise: Promise<unknown>): void => {
// if a promise settles before we await it, we can drop it.
const cleanup = () => {
this.promises.delete(promise)
}

this.promises.add(
promise.then(cleanup, (err) => {
cleanup()
this.onError(err)
})
)
}

public async awaiting(): Promise<void> {
while (this.promises.size > 0) {
const promises = Array.from(this.promises)
this.promises.clear()
await Promise.all(promises)
}
}
}

/**
* Like {@link AwaiterMulti}, but can only be awaited once.
* If {@link AwaiterOnce.waitUntil} is called after that, it will throw.
*/
export class AwaiterOnce {
private awaiter: AwaiterMulti
private done: boolean = false
private pending: Promise<void> | undefined

constructor(options: { onError?: (error: unknown) => void } = {}) {
this.awaiter = new AwaiterMulti(options)
}

public waitUntil = (promise: Promise<unknown>): void => {
if (this.done) {
throw new InvariantError(
'Cannot call waitUntil() on an AwaiterOnce that was already awaited'
)
}
return this.awaiter.waitUntil(promise)
}

public async awaiting(): Promise<void> {
if (!this.pending) {
this.pending = this.awaiter.awaiting().finally(() => {
this.done = true
})
}
return this.pending
}
}
12 changes: 11 additions & 1 deletion packages/next/src/server/next-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,12 @@ export default class NextNodeServer extends BaseServer<
})

// If we handled the request, we can return early.
if (handled) return true
if (handled) {
const waitUntil = this.getWaitUntil()
waitUntil?.(handled.waitUntil)

return true
}
}

// If the route was detected as being a Pages API route, then handle
Expand Down Expand Up @@ -1934,6 +1939,11 @@ export default class NextNodeServer extends BaseServer<
}
})

const waitUntil = this.getWaitUntil()
if (waitUntil) {
waitUntil(result.waitUntil)
}

const { originalResponse } = params.res
if (result.response.body) {
await pipeToNodeResponse(result.response.body, originalResponse)
Expand Down
29 changes: 25 additions & 4 deletions packages/next/src/server/web/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { getTracer } from '../lib/trace/tracer'
import type { TextMapGetter } from 'next/dist/compiled/@opentelemetry/api'
import { MiddlewareSpan } from '../lib/trace/constants'
import { CloseController } from './web-on-close'
import { lifecycleAsyncStorage } from '../../client/components/lifecycle-async-storage.external'

export class NextRequestHint extends NextRequest {
sourcePage: string
Expand Down Expand Up @@ -213,13 +214,14 @@ export async function adapter(
const isMiddleware =
params.page === '/middleware' || params.page === '/src/middleware'

const isAfterEnabled =
params.request.nextConfig?.experimental?.after ??
!!process.env.__NEXT_AFTER

if (isMiddleware) {
// if we're in an edge function, we only get a subset of `nextConfig` (no `experimental`),
// so we have to inject it via DefinePlugin.
// in `next start` this will be passed normally (see `NextNodeServer.runMiddleware`).
const isAfterEnabled =
params.request.nextConfig?.experimental?.after ??
!!process.env.__NEXT_AFTER

let waitUntil: WrapperRenderOpts['waitUntil'] = undefined
let closeController: CloseController | undefined = undefined
Expand Down Expand Up @@ -279,6 +281,25 @@ export async function adapter(
}
)
}

if (isAfterEnabled) {
// NOTE:
// Currently, `adapter` is expected to return promises passed to `waitUntil`
// as part of its result (i.e. a FetchEventResult).
// Because of this, we override any outer contexts that might provide a real `waitUntil`,
// and provide the `waitUntil` from the NextFetchEvent instead so that we can collect those promises.
// This is not ideal, but until we change this calling convention, it's the least surprising thing to do.
//
// Notably, the only case that currently cares about this ALS is Edge SSR
// (i.e. a handler created via `build/webpack/loaders/next-edge-ssr-loader/render.ts`)
// Other types of handlers will grab the waitUntil from the passed FetchEvent,
// but NextWebServer currently has no interface that'd allow for that.
return lifecycleAsyncStorage.run(
{ waitUntil: event.waitUntil.bind(event) },
() => params.handler(request, event)
)
}

return params.handler(request, event)
})

Expand Down Expand Up @@ -399,7 +420,7 @@ export async function adapter(

return {
response: finalResponse,
waitUntil: Promise.all(event[waitUntilSymbol]),
waitUntil: event[waitUntilSymbol](),
fetchMetrics: request.fetchMetrics,
}
}
7 changes: 7 additions & 0 deletions packages/next/src/server/web/edge-route-module-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ export class EdgeRouteModuleWrapper {
const trackedBody = trackStreamConsumed(res.body, () =>
_closeController.dispatchClose()
)

// make sure that NextRequestHint's awaiter stays open long enough
// for `waitUntil`s called late during streaming to get picked up.
evt.waitUntil(
new Promise<void>((resolve) => _closeController.onClose(resolve))
)

res = new Response(trackedBody, {
status: res.status,
statusText: res.statusText,
Expand Down
4 changes: 4 additions & 0 deletions packages/next/src/server/web/sandbox/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ Learn More: https://nextjs.org/docs/messages/edge-dynamic-code-evaluation`),
context.clearTimeout = (timeout: number) =>
timeoutsManager.remove(timeout)

if (process.env.__NEXT_TEST_MODE) {
context.__next_outer_globalThis__ = globalThis
}

return context
},
})
Expand Down
14 changes: 11 additions & 3 deletions packages/next/src/server/web/spec-extension/fetch-event.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import { AwaiterOnce } from '../../lib/awaiter'
import { PageSignatureError } from '../error'
import type { NextRequest } from './request'

const responseSymbol = Symbol('response')
const passThroughSymbol = Symbol('passThrough')
const awaiterSymbol = Symbol('awaiter')

export const waitUntilSymbol = Symbol('waitUntil')

class FetchEvent {
readonly [waitUntilSymbol]: Promise<any>[] = [];
[responseSymbol]?: Promise<Response>;
[passThroughSymbol] = false
[passThroughSymbol] = false;

[awaiterSymbol] = new AwaiterOnce();

[waitUntilSymbol] = () => {
return this[awaiterSymbol].awaiting()
}

// eslint-disable-next-line @typescript-eslint/no-useless-constructor
constructor(_request: Request) {}
Expand All @@ -24,7 +32,7 @@ class FetchEvent {
}

waitUntil(promise: Promise<any>): void {
this[waitUntilSymbol].push(promise)
this[awaiterSymbol].waitUntil(promise)
}
}

Expand Down
Loading
Loading