Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 185 additions & 22 deletions src/internal/concurrency/async-abort-controller.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,100 @@
/**
* This special AbortController is used to wait for all the abort handlers to finish before resolving the promise.
*/
type AbortListener = EventListenerOrEventListenerObject

type ListenerRecord = {
wrapped: EventListener
cleanup: () => void
}

export class AsyncAbortController extends AbortController {
protected promises: Promise<any>[] = []
protected runningPromises = new Set<Promise<void>>()
protected abortListeners = new WeakMap<AbortListener, Map<boolean, ListenerRecord>>()
protected _nextGroup?: AsyncAbortController

constructor() {
super()

const originalEventListener = this.signal.addEventListener
const originalAddEventListener = this.signal.addEventListener.bind(this.signal)
const originalRemoveEventListener = this.signal.removeEventListener.bind(this.signal)

// Patch event addEventListener to keep track of listeners and their promises
this.signal.addEventListener = (type: string, listener: any, options: any) => {
this.signal.addEventListener = (
type: string,
listener: EventListenerOrEventListenerObject | null,
options?: boolean | AddEventListenerOptions
) => {
if (!listener) {
return
}

if (type !== 'abort') {
return originalEventListener.call(this.signal, type, listener, options)
return originalAddEventListener(type, listener, options)
}

if (this.signal.aborted) {
return originalAddEventListener(type, listener, options)
}

const capture = getCaptureOption(options)
const existingRecord = this.getAbortListenerRecord(listener, capture)
if (existingRecord) {
return originalAddEventListener(type, existingRecord.wrapped, options)
}

const registrationSignal = getRegistrationSignal(options)
if (registrationSignal?.aborted) {
return
}

let wrapped!: EventListener
const cleanupRegistrationSignal = this.watchListenerRemovalSignal(
registrationSignal,
listener,
capture
)

wrapped = (event: Event) => {
this.deleteAbortListenerRecord(listener, capture)
originalRemoveEventListener(type, wrapped, capture)

const runningPromise = this.invokeAbortListener(listener, event)
this.runningPromises.add(runningPromise)
void runningPromise.finally(() => {
this.runningPromises.delete(runningPromise)
})
}

let resolving: undefined | (() => Promise<void>) = undefined
const promise = new Promise<void>((resolve, reject) => {
resolving = async (): Promise<void> => {
return Promise.resolve()
.then(() => listener())
.then(() => {
resolve()
})
.catch((error) => {
reject(error)
})
}
this.setAbortListenerRecord(listener, capture, {
wrapped,
cleanup: cleanupRegistrationSignal,
})
this.promises.push(promise)

if (!resolving) {
throw new Error('resolve is undefined')
return originalAddEventListener(type, wrapped, options)
}

this.signal.removeEventListener = (
type: string,
listener: EventListenerOrEventListenerObject | null,
options?: boolean | EventListenerOptions
) => {
if (!listener) {
return
}

if (type !== 'abort') {
return originalRemoveEventListener(type, listener, options)
}

const capture = getCaptureOption(options)
const record = this.getAbortListenerRecord(listener, capture)
if (!record) {
return originalRemoveEventListener(type, listener, options)
}

return originalEventListener.call(this.signal, type, resolving, options)
this.deleteAbortListenerRecord(listener, capture)
return originalRemoveEventListener(type, record.wrapped, options)
}
}

Expand All @@ -50,8 +109,8 @@ export class AsyncAbortController extends AbortController {

async abortAsync() {
this.abort()
while (this.promises.length > 0) {
const promises = this.promises.splice(0, 100)
while (this.runningPromises.size > 0) {
const promises = Array.from(this.runningPromises)
await Promise.allSettled(promises)
}
await this.abortNextGroup()
Expand All @@ -62,4 +121,108 @@ export class AsyncAbortController extends AbortController {
await this._nextGroup.abortAsync()
}
}

protected invokeAbortListener(listener: AbortListener, event: Event): Promise<void> {
try {
const result =
typeof listener === 'function'
? listener.call(this.signal, event)
: listener.handleEvent(event)

return Promise.resolve(result).then(() => undefined)
} catch (error) {
return Promise.reject(error)
}
}

protected getAbortListenerRecord(
listener: AbortListener,
capture: boolean
): ListenerRecord | undefined {
return this.abortListeners.get(listener)?.get(capture)
}

protected setAbortListenerRecord(
listener: AbortListener,
capture: boolean,
record: ListenerRecord
) {
const records = this.abortListeners.get(listener) ?? new Map<boolean, ListenerRecord>()
records.set(capture, record)
this.abortListeners.set(listener, records)
}

protected deleteAbortListenerRecord(listener: AbortListener, capture: boolean) {
const records = this.abortListeners.get(listener)
const record = records?.get(capture)
if (!records || !record) {
return
}

record.cleanup()
records.delete(capture)

if (records.size === 0) {
this.abortListeners.delete(listener)
}
}

protected watchListenerRemovalSignal(
signal: AbortSignal | undefined,
listener: AbortListener,
capture: boolean
): () => void {
if (!signal) {
return () => {}
}

const onAbort = () => {
this.deleteAbortListenerRecord(listener, capture)
}

addNativeEventListener(signal, 'abort', onAbort, { once: true })

return () => {
removeNativeEventListener(signal, 'abort', onAbort, { capture: false })
}
}
}

const nativeAddEventListener = EventTarget.prototype.addEventListener
const nativeRemoveEventListener = EventTarget.prototype.removeEventListener

function addNativeEventListener(
target: EventTarget,
type: string,
listener: EventListenerOrEventListenerObject,
options?: boolean | AddEventListenerOptions
) {
nativeAddEventListener.call(target, type, listener, options)
}

function removeNativeEventListener(
target: EventTarget,
type: string,
listener: EventListenerOrEventListenerObject,
options?: boolean | EventListenerOptions
) {
nativeRemoveEventListener.call(target, type, listener, options)
}

function getCaptureOption(options?: boolean | EventListenerOptions): boolean {
if (typeof options === 'boolean') {
return options
}

return options?.capture ?? false
}

function getRegistrationSignal(
options?: boolean | AddEventListenerOptions
): AbortSignal | undefined {
if (typeof options === 'boolean') {
return undefined
}

return options?.signal
}
108 changes: 108 additions & 0 deletions src/test/async-abort-controller.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,112 @@ describe('AsyncAbortController', () => {

expect(order).toEqual(['root:start', 'root:end', 'child', 'grandchild'])
})

it('forwards the real abort event to function listeners with the signal as context', async () => {
const controller = new AsyncAbortController()
const seen: {
target: EventTarget | null
currentTarget: EventTarget | null
context: unknown
} = {
target: null,
currentTarget: null,
context: undefined,
}

controller.signal.addEventListener('abort', function (event) {
seen.target = event.target
seen.currentTarget = event.currentTarget
seen.context = this
})

await controller.abortAsync()

expect(seen.target).toBe(controller.signal)
expect(seen.currentTarget).toBe(controller.signal)
expect(seen.context).toBe(controller.signal)
})

it('waits for handleEvent listeners before aborting nested groups', async () => {
const controller = new AsyncAbortController()
const childGroup = controller.nextGroup
const order: string[] = []
let releaseRootAbort!: () => void
const rootAbortDone = new Promise<void>((resolve) => {
releaseRootAbort = resolve
})
const listener = {
target: null as EventTarget | null,
async handleEvent(event: Event) {
this.target = event.target
order.push('root:start')
await rootAbortDone
order.push('root:end')
},
}

controller.signal.addEventListener('abort', listener)
childGroup.signal.addEventListener('abort', () => {
order.push('child')
})

const abortPromise = controller.abortAsync()

await Promise.resolve()
expect(order).toEqual(['root:start'])

releaseRootAbort()
await abortPromise

expect(listener.target).toBe(controller.signal)
expect(order).toEqual(['root:start', 'root:end', 'child'])
})

it('ignores null abort listeners', async () => {
const controller = new AsyncAbortController()
const nullListener = null as unknown as EventListenerOrEventListenerObject

expect(() => controller.signal.addEventListener('abort', nullListener)).not.toThrow()
await expect(controller.abortAsync()).resolves.toBeUndefined()
})

it('does not invoke or wait on explicitly removed abort listeners', async () => {
const controller = new AsyncAbortController()
const listener = jest.fn()

controller.signal.addEventListener('abort', listener)
controller.signal.removeEventListener('abort', listener)

await expect(controller.abortAsync()).resolves.toBeUndefined()
expect(listener).not.toHaveBeenCalled()
})

it('does not invoke or wait on abort listeners removed by a registration signal', async () => {
const controller = new AsyncAbortController()
const registration = new AbortController()
const listener = jest.fn()

controller.signal.addEventListener('abort', listener, {
signal: registration.signal,
})

registration.abort()

await expect(controller.abortAsync()).resolves.toBeUndefined()
expect(listener).not.toHaveBeenCalled()
})

it('ignores abort listeners registered with an already aborted signal', async () => {
const controller = new AsyncAbortController()
const registration = new AbortController()
const listener = jest.fn()

registration.abort()
controller.signal.addEventListener('abort', listener, {
signal: registration.signal,
})

await expect(controller.abortAsync()).resolves.toBeUndefined()
expect(listener).not.toHaveBeenCalled()
})
})