Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add abort signalling and controls to transport #1517

Merged
merged 5 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/rare-hotels-nail.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@penumbra-zone/transport-chrome': minor
'@penumbra-zone/transport-dom': minor
---

respect transport abort controls
2 changes: 2 additions & 0 deletions eslint.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ export default tseslint.config(
'**/*.story.@(ts|tsx|js|jsx|mjs|cjs)',
],
rules: {
'@typescript-eslint/no-empty-function': 'off',
'@typescript-eslint/no-non-null-assertion': 'off',
'@typescript-eslint/prefer-promise-reject-errors': 'off',
'react/display-name': 'off',
},
},
Expand Down
7 changes: 4 additions & 3 deletions packages/transport-chrome/src/session-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import {
isTransportAbort,
isTransportError,
isTransportMessage,
isTransportStream,
Expand Down Expand Up @@ -102,10 +103,10 @@ export class CRSessionClient {
try {
if (ev.data === false) {
this.disconnectService();
} else if (isTransportMessage(ev.data)) {
} else if (isTransportAbort(ev.data) || isTransportMessage(ev.data)) {
this.servicePort.postMessage(ev.data);
} else if (isTransportStream(ev.data)) {
this.servicePort.postMessage(this.requestChannelStream(ev.data));
this.servicePort.postMessage(this.makeChannelStreamRequest(ev.data));
} else {
console.warn('Unknown item from client', ev.data);
}
Expand Down Expand Up @@ -135,7 +136,7 @@ export class CRSessionClient {
return [{ requestId, stream }, [stream]] satisfies [TransportStream, [Transferable]];
};

private requestChannelStream = ({ requestId, stream }: TransportStream) => {
private makeChannelStreamRequest = ({ requestId, stream }: TransportStream) => {
const channel = nameConnection(this.prefix, ChannelLabel.STREAM);
const sinkListener = (p: chrome.runtime.Port) => {
if (p.name !== channel) {
Expand Down
51 changes: 37 additions & 14 deletions packages/transport-chrome/src/session-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import { isTransportInitChannel, TransportInitChannel } from './message.js';
import { PortStreamSink, PortStreamSource } from './stream.js';
import { ChannelHandlerFn } from '@penumbra-zone/transport-dom/adapter';
import {
isTransportAbort,
isTransportMessage,
TransportEvent,
TransportMessage,
TransportStream,
} from '@penumbra-zone/transport-dom/messages';

interface CRSession {
interface CRSession extends AbortController {
clientId: string;
acont: AbortController;
port: chrome.runtime.Port;
origin: string;
}
Expand Down Expand Up @@ -43,6 +43,7 @@ interface CRSession {
export class CRSessionManager {
private static singleton?: CRSessionManager;
private sessions = new Map<string, CRSession>();
private requests = new Map<string, AbortController>();

private constructor(
private prefix: string,
Expand All @@ -61,6 +62,19 @@ export class CRSessionManager {
*/
public static init = (prefix: string, handler: ChannelHandlerFn) => {
CRSessionManager.singleton ??= new CRSessionManager(prefix, handler);
return CRSessionManager.singleton.sessions;
};

public static killOrigin = (targetOrigin: string) => {
if (CRSessionManager.singleton) {
CRSessionManager.singleton.sessions.forEach(session => {
if (session.origin === targetOrigin) {
session.abort(targetOrigin);
}
});
} else {
throw new Error('No session manager');
}
};

/**
Expand Down Expand Up @@ -100,29 +114,31 @@ export class CRSessionManager {
if (this.sessions.has(clientId)) {
throw new Error(`Session collision: ${clientId}`);
}
const session = {

const session: CRSession = Object.assign(new AbortController(), {
clientId,
acont: new AbortController(),
origin: sender.origin,
port: port,
};
});
this.sessions.set(clientId, session);

session.acont.signal.addEventListener('abort', () => port.disconnect());
port.onDisconnect.addListener(() => session.acont.abort('Disconnect'));
session.signal.addEventListener('abort', () => port.disconnect());
port.onDisconnect.addListener(() => session.abort('Disconnect'));

port.onMessage.addListener((i, p) => {
void (async () => {
try {
if (isTransportMessage(i)) {
p.postMessage(await this.clientMessageHandler(session.acont.signal, i));
if (isTransportAbort(i)) {
this.requests.get(i.requestId)?.abort();
} else if (isTransportMessage(i)) {
p.postMessage(await this.clientMessageHandler(session, i));
} else if (isTransportInitChannel(i)) {
console.warn('Client streaming unimplemented', this.acceptChannelStreamRequest(i));
} else {
console.warn('Unknown item in transport', i);
}
} catch (e) {
session.acont.abort(e);
session.abort(e);
}
})();
});
Expand All @@ -137,13 +153,19 @@ export class CRSessionManager {
* representing an error.
*/
private clientMessageHandler(
signal: AbortSignal,
session: CRSession,
{ requestId, message }: TransportMessage,
): Promise<TransportEvent> {
return this.handler(message)
if (this.requests.has(requestId)) {
throw new Error(`Request collision: ${requestId}`);
}
const requestController = new AbortController();
session.signal.addEventListener('abort', () => requestController.abort());
this.requests.set(requestId, requestController);
return this.handler(message, AbortSignal.any([session.signal, requestController.signal]))
.then(response =>
response instanceof ReadableStream
? this.responseChannelStream(signal, {
? this.responseChannelStream(requestController.signal, {
requestId,
stream: response as unknown,
} as TransportStream)
Expand All @@ -152,7 +174,8 @@ export class CRSessionManager {
.catch((error: unknown) => ({
requestId,
error: errorToJson(ConnectError.from(error), undefined),
}));
}))
.finally(() => this.requests.delete(requestId));
}

/**
Expand Down
6 changes: 4 additions & 2 deletions packages/transport-dom/src/ReadableStream.from.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ const ReadableStreamWithFrom: typeof ReadableStream & { from: ReadableStreamFrom
'from' in ReadableStream
? (ReadableStream as typeof ReadableStream & { from: ReadableStreamFrom })
: Object.assign(ReadableStream, {
from<T>(iterable: Iterable<T> | AsyncIterable<T>): ReadableStream<T> {
if (Symbol.iterator in iterable) {
from<T>(iterable: ReadableStream<T> | Iterable<T> | AsyncIterable<T>): ReadableStream<T> {
if (iterable instanceof ReadableStream) {
return iterable;
} else if (Symbol.iterator in iterable) {
const it = iterable[Symbol.iterator]();
return new ReadableStream({
pull(cont) {
Expand Down
63 changes: 50 additions & 13 deletions packages/transport-dom/src/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import {
} from '@connectrpc/connect';
import {
Any,
AnyMessage,
JsonReadOptions,
JsonValue,
JsonWriteOptions,
MessageType,
MethodInfo,
MethodKind,
ServiceType,
Expand All @@ -33,10 +35,14 @@ import ReadableStream from './ReadableStream.from.js';
// hopefully also simplifies transport call soon
type MethodType = MethodInfo & { service: { typeName: string } };

type ChannelRequest = JsonValue;
type ChannelRequest = JsonValue | ReadableStream<JsonValue>;
type ChannelResponse = JsonValue | ReadableStream<JsonValue>;

export type ChannelHandlerFn = (r: ChannelRequest) => Promise<ChannelResponse>;
export type ChannelHandlerFn = (
request: ChannelRequest,
signal?: AbortSignal,
timeoutMs?: number,
) => Promise<ChannelResponse>;
export type ChannelContextFn = (
h: UniversalServerRequest,
) => Promise<UniversalServerRequest & { contextValues: ContextValues }>;
Expand Down Expand Up @@ -145,7 +151,7 @@ export const connectChannelAdapter = (opt: ChannelAdapterOptions): ChannelHandle
);

// TODO: alternatively, we could have the channelClient provide a requestPath
const I_MethodType = new Map<string, MethodType>(
const methodTypesByName = new Map<string, MethodType>(
router.handlers.map(({ method, service }) => [
method.I.typeName,
{ ...method, service: { typeName: service.typeName } },
Expand All @@ -164,11 +170,28 @@ export const connectChannelAdapter = (opt: ChannelAdapterOptions): ChannelHandle
httpClient: injectRequestContext,
});

return async function channelHandler(message: ChannelRequest) {
const request = Any.fromJson(message, jsonOptions).unpack(jsonOptions.typeRegistry)!;
const requestType = request.getType();
const deserializeRequest = (
message: ChannelRequest,
): { requestType: MessageType; request: AnyMessage | ReadableStream<AnyMessage> } => {
if (message instanceof ReadableStream) {
throw new ConnectError('Streaming request unimplemented', ConnectErrorCode.Unimplemented);
} else {
const request = Any.fromJson(message, jsonOptions).unpack(jsonOptions.typeRegistry);
if (!request) {
throw new ConnectError('Invalid request', ConnectErrorCode.InvalidArgument);
}
return { requestType: request.getType(), request };
}
};

const methodType = I_MethodType.get(requestType.typeName);
return async function channelHandler(
message: ChannelRequest,
signal?: AbortSignal,
timeoutMs?: number,
) {
const { request, requestType } = deserializeRequest(message);

const methodType = methodTypesByName.get(requestType.typeName);
if (!methodType) {
throw new ConnectError(`Method ${requestType.typeName} not found`, ConnectErrorCode.NotFound);
}
Expand All @@ -180,8 +203,8 @@ export const connectChannelAdapter = (opt: ChannelAdapterOptions): ChannelHandle
// only uses service.typeName, so this cast is ok
methodType.service as ServiceType,
methodType satisfies MethodInfo,
undefined, // TODO abort
undefined, // TODO timeout
signal,
timeoutMs,
undefined, // TODO headers
request,
);
Expand All @@ -191,21 +214,35 @@ export const connectChannelAdapter = (opt: ChannelAdapterOptions): ChannelHandle
// only uses service.typeName, so this cast is ok
methodType.service as ServiceType,
methodType satisfies MethodInfo,
undefined, // TODO abort
undefined, // TODO timeout
signal,
timeoutMs,
undefined, // TODO headers
createAsyncIterable([request]),
);
break;
case MethodKind.BiDiStreaming:
case MethodKind.ClientStreaming:
response = await transport.stream(
// only uses service.typeName, so this cast is ok
methodType.service as ServiceType,
methodType satisfies MethodInfo,
signal,
timeoutMs,
undefined, // TODO headers
request as never,
);
break;
default:
throw new ConnectError(
`Unimplemented method kind ${methodType.kind}`,
`Unexpected method kind for ${requestType.typeName}`,
ConnectErrorCode.Unimplemented,
);
}

if (response.stream) {
return ReadableStream.from(response.message).pipeThrough(new MessageToJson(jsonOptions));
return ReadableStream.from(response.message).pipeThrough(new MessageToJson(jsonOptions), {
signal,
});
} else {
return Any.pack(response.message).toJson(jsonOptions);
}
Expand Down
Loading
Loading