diff --git a/packages/transport-dom/src/create.ts b/packages/transport-dom/src/create.ts index 4829e4bad3..5658b928a6 100644 --- a/packages/transport-dom/src/create.ts +++ b/packages/transport-dom/src/create.ts @@ -46,6 +46,34 @@ export interface ChannelTransportOptions getPort: () => PromiseLike; } +/** + * For use with `ConnectError.from`, in `rejectOnSignal`. Identifies an + * appropriate error code for an unknown throw. + * - ConnectError.from forwards exising ConnectError codes, ignoring this + * - ConnectError.from uses `Code.Canceled` for an 'AbortError', ignoring this + * - We want to apply `Code.DeadlineExceeded` for any 'TimeoutError' + * - All others should use `Code.Aborted` + */ +const codeForError = (r?: unknown) => { + if (r instanceof DOMException && r.name === 'TimeoutError') { + return Code.DeadlineExceeded; + } else { + return Code.Aborted; + } +}; + +const rejectOnSignal = (...signals: (AbortSignal | undefined)[]) => { + return new Promise((_, reject) => { + const signal = AbortSignal.any(signals.filter(s => s instanceof AbortSignal)); + signal.addEventListener('abort', () => + reject(ConnectError.from(signal.reason, codeForError(signal.reason))), + ); + if (signal.aborted) { + reject(ConnectError.from(signal.reason, codeForError(signal.reason))); + } + }); +}; + export const createChannelTransport = ({ getPort, jsonOptions, @@ -54,7 +82,7 @@ export const createChannelTransport = ({ const pending = new Map void>(); // this is used to recover errors that couldn't be thrown at a caller - const { reject: listenerReject, promise: transportFailure } = Promise.withResolvers(); + const transportFailure = new AbortController(); // port returned by the penumbra global let port: MessagePort | undefined; @@ -67,28 +95,26 @@ export const createChannelTransport = ({ * @returns A promise that resolves when the channel is acquired. */ const connect = async () => { - const initTimeout = new Promise( - (_, reject) => - defaultTimeoutMs && - setTimeout( - reject, - defaultTimeoutMs, - new ConnectError('Channel connection request timed out', Code.Unavailable), - ), - ); - - const gotPort = await Promise.race([getPort(), initTimeout]); + const connectionPort = await Promise.race([ + getPort(), + rejectOnSignal(AbortSignal.timeout(defaultTimeoutMs)).catch(() => + Promise.reject(new ConnectError('Channel connection request timed out', Code.Unavailable)), + ), + ]); - gotPort.addEventListener('message', transportListener); - gotPort.start(); + connectionPort.addEventListener('message', transportListener); + connectionPort.addEventListener('messageerror', (ev: MessageEvent) => + transportFailure.abort(ConnectError.from(ev.data)), + ); + connectionPort.start(); - return gotPort; + return connectionPort; }; const transportListener = ({ data }: MessageEvent) => { - if (!data) { - // likely 'false' indicating a disconnect - listenerReject(new ConnectError('Connection closed', Code.Unavailable)); + if (data === false) { + // 'false' indicating a disconnect + transportFailure.abort(new ConnectError('Connection closed', Code.Unavailable)); } else if (isTransportEvent(data)) { // this is a response to a specific request. the port may be shared, so // it's okay if it contains a requestId we don't know about. the response @@ -96,15 +122,15 @@ export const createChannelTransport = ({ pending.get(data.requestId)?.(data); } else if (isTransportError(data)) { // this is a channel-level error, corresponding to no specific request. - // this will fail this transport, and every client using this transport. - // every transport sharing this port will fail independently, but the - // rejection created here will be delivered to every subsequent request - // attempted on this transport. - listenerReject( + // it will fail this transport, and every client using this transport, and + // every transport using this channel. every transport sharing this port + // will fail independently, but the rejection created here will be + // delivered to every subsequent request attempted on this transport. + transportFailure.abort( errorFromJson(data.error, data.metadata, new ConnectError('Transport failed')), ); } else { - listenerReject( + transportFailure.abort( new ConnectError( 'Unknown item in transport', Code.Unimplemented, @@ -116,23 +142,6 @@ export const createChannelTransport = ({ } }; - const requestFailures = (signal?: AbortSignal, timeoutMs?: number) => ({ - cancel: new Promise((_, reject) => { - signal?.addEventListener('abort', () => - reject(ConnectError.from(signal.reason, Code.Aborted)), - ); - if (signal?.aborted) { - reject(ConnectError.from(signal.reason, Code.Aborted)); - } - }), - deadline: new Promise((_, reject) => { - setTimeout( - () => reject(new ConnectError('Request timed out', Code.DeadlineExceeded)), - timeoutMs, - ); - }), - }); - return { async unary = AnyMessage, O extends Message = AnyMessage>( service: ServiceType, @@ -142,15 +151,19 @@ export const createChannelTransport = ({ header: HeadersInit | undefined, input: PartialMessage, ): Promise> { + transportFailure.signal.throwIfAborted(); port ??= await connect(); const requestId = crypto.randomUUID(); + const requestFailure = new AbortController(); - const { cancel, deadline } = requestFailures(signal, timeoutMs ?? defaultTimeoutMs); const response = Promise.race([ - cancel, - deadline, - transportFailure, + rejectOnSignal( + transportFailure.signal, + requestFailure.signal, + AbortSignal.timeout(timeoutMs ?? defaultTimeoutMs), + signal, + ), new Promise((resolve, reject) => { pending.set(requestId, (tev: TransportEvent) => { if (isTransportMessage(tev, requestId)) { @@ -164,21 +177,26 @@ export const createChannelTransport = ({ }), ]).finally(() => pending.delete(requestId)); - switch (method.kind) { - case MethodKind.Unary: - { - const message = Any.pack(new method.I(input)).toJson(jsonOptions); - port.postMessage({ requestId, message, header } satisfies TransportMessage); + if (!signal?.aborted) { + try { + switch (method.kind) { + case MethodKind.Unary: + { + const message = Any.pack(new method.I(input)).toJson(jsonOptions); + signal?.addEventListener('abort', () => + port?.postMessage({ requestId, abort: true } satisfies TransportAbort), + ); + port.postMessage({ requestId, message, header } satisfies TransportMessage); + } + break; + default: + throw new ConnectError('MethodKind not supported', Code.Unimplemented); } - break; - default: - throw new ConnectError('MethodKind not supported', Code.Unimplemented); + } catch (e) { + requestFailure.abort(e); + } } - void cancel.catch(() => - port?.postMessage({ requestId, abort: true } satisfies TransportAbort), - ); - return { service, method, @@ -201,15 +219,20 @@ export const createChannelTransport = ({ header: HeadersInit | undefined, input: AsyncIterable>, ): Promise> { + transportFailure.signal.throwIfAborted(); port ??= await connect(); const requestId = crypto.randomUUID(); - const { cancel, deadline } = requestFailures(signal, timeoutMs ?? defaultTimeoutMs); + const requestFailure = new AbortController(); + const response = Promise.race([ - cancel, - deadline, - transportFailure, + rejectOnSignal( + transportFailure.signal, + requestFailure.signal, + AbortSignal.timeout(timeoutMs ?? defaultTimeoutMs), + signal, + ), new Promise((resolve, reject) => { pending.set(requestId, (tev: TransportEvent) => { if (isTransportStream(tev, requestId)) { @@ -223,44 +246,51 @@ export const createChannelTransport = ({ }), ]).finally(() => pending.delete(requestId)); - switch (method.kind) { - case MethodKind.ServerStreaming: - // send as a single message - { - const iter = input[Symbol.asyncIterator](); - const [{ value } = { value: null }, { done }] = [await iter.next(), await iter.next()]; - if (done && typeof value === 'object' && value != null) { - const message = Any.pack(new method.I(value as object)).toJson(jsonOptions); - port.postMessage({ requestId, message, header } satisfies TransportMessage); - } else { - throw new ConnectError( - 'MethodKind.ServerStreaming expects a single request message', - Code.OutOfRange, - ); - } + if (!signal?.aborted) { + try { + switch (method.kind) { + case MethodKind.ServerStreaming: + // send as a single message + { + // consume the input stream, which should have only one message + const iter = input[Symbol.asyncIterator](); + const [{ value } = { value: null }, { done }] = [ + await iter.next(), + await iter.next(), + ]; + // confirm the input stream ended after one message with content + if (done && typeof value === 'object' && value !== null) { + const message = Any.pack(new method.I(value as object)).toJson(jsonOptions); + port.postMessage({ requestId, message, header } satisfies TransportMessage); + } else { + throw new ConnectError( + 'MethodKind.ServerStreaming expects a single request message', + Code.OutOfRange, + ); + } + } + break; + case MethodKind.ClientStreaming: + case MethodKind.BiDiStreaming: + // send as an actual stream + { + const stream: ReadableStream = ReadableStream.from(input).pipeThrough( + new TransformStream({ + transform: (chunk: PartialMessage, cont) => + cont.enqueue(Any.pack(new method.I(chunk)).toJson(jsonOptions)), + }), + ); + port.postMessage({ requestId, stream, header } satisfies TransportStream, [stream]); + } + break; + default: + throw new ConnectError('MethodKind not supported', Code.Unimplemented); } - break; - case MethodKind.ClientStreaming: - case MethodKind.BiDiStreaming: - // send as an actual stream - { - const stream: ReadableStream = ReadableStream.from(input).pipeThrough( - new TransformStream({ - transform: (chunk: PartialMessage, cont) => - cont.enqueue(Any.pack(new method.I(chunk)).toJson(jsonOptions)), - }), - ); - port.postMessage({ requestId, stream, header } satisfies TransportStream, [stream]); - } - break; - default: - throw new ConnectError('MethodKind not supported', Code.Unimplemented); + } catch (e) { + requestFailure.abort(e); + } } - void cancel.catch(() => - port?.postMessage({ requestId, abort: true } satisfies TransportAbort), - ); - return { service, method,