Skip to content

Commit

Permalink
feat: add option to use async iterables (#605)
Browse files Browse the repository at this point in the history
Adds option useAsyncIterable which uses AsyncIterable instead of Observable.

For example:

  bidirectionalStreamingRequest(
    service: string,
    method: string,
    data: AsyncIterable<Uint8Array>
  ): AsyncIterable<Uint8Array>

Generates Transform async iterables for encoding and decoding:

  // encodeTransform encodes a source of message objects.
  // Transform<TestMessage, Uint8Array>
  async *encodeTransform(
    source: AsyncIterable<TestMessage | TestMessage[]> | Iterable<TestMessage | TestMessage[]>
  ): AsyncIterable<Uint8Array> {
    for await (const pkt of source) {
      if (Array.isArray(pkt)) {
        for (const p of pkt) {
          yield* [TestMessage.encode(p).finish()];
        }
      } else {
        yield* [TestMessage.encode(pkt).finish()];
      }
    }
  },

  // decodeTransform decodes a source of encoded messages.
  // Transform<Uint8Array, TestMessage>
  async *decodeTransform(
    source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
  ): AsyncIterable<TestMessage> {
    for await (const pkt of source) {
      if (Array.isArray(pkt)) {
        for (const p of pkt) {
          yield* [TestMessage.decode(p)];
        }
      } else {
        yield* [TestMessage.decode(pkt)];
      }
    }
  },

Generates RPC service implementations which use the Transform iterators:

  BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage> {
    const data = TestMessage.encodeTransform(request);
    const result = this.rpc.bidirectionalStreamingRequest('simple.Test', 'BidiStreaming', data);
    return TestMessage.decodeTransform(result);
  }

AsyncIterables indicate a stream has ended by closing with an optional error.

Fixes #600

Signed-off-by: Christian Stewart <christian@paral.in>
  • Loading branch information
paralin committed Jul 1, 2022
1 parent b7954f2 commit ca8ea8d
Show file tree
Hide file tree
Showing 20 changed files with 248 additions and 31 deletions.
2 changes: 2 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ Generated code will be placed in the Gradle build directory.

- With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions.

- With `--ts_proto_opt=useAsyncIterable=true`, the generated services will use `AsyncIterable` instead of `Observable`.

- With `--ts_proto_opt=emitImportedFiles=false`, ts-proto will not emit `google/protobuf/*` files unless you explicit add files to `protoc` like this
`protoc --plugin=./node_modules/.bin/protoc-gen-ts_proto my_message.proto google/protobuf/duration.proto`

Expand Down
1 change: 1 addition & 0 deletions integration/async-iterable-services/parameters.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
useAsyncIterable=true
Binary file added integration/async-iterable-services/simple.bin
Binary file not shown.
11 changes: 11 additions & 0 deletions integration/async-iterable-services/simple.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

package simple;

service Test {
rpc BidiStreaming(stream TestMessage) returns (stream TestMessage) {}
}

message TestMessage {
string value = 1;
}
138 changes: 138 additions & 0 deletions integration/async-iterable-services/simple.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/* eslint-disable */
import * as _m0 from 'protobufjs/minimal';

export const protobufPackage = 'simple';

export interface TestMessage {
value: string;
}

function createBaseTestMessage(): TestMessage {
return { value: '' };
}

export const TestMessage = {
encode(message: TestMessage, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.value !== '') {
writer.uint32(10).string(message.value);
}
return writer;
},

decode(input: _m0.Reader | Uint8Array, length?: number): TestMessage {
const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input);
let end = length === undefined ? reader.len : reader.pos + length;
const message = createBaseTestMessage();
while (reader.pos < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1:
message.value = reader.string();
break;
default:
reader.skipType(tag & 7);
break;
}
}
return message;
},

// encodeTransform encodes a source of message objects.
// Transform<TestMessage, Uint8Array>
async *encodeTransform(
source: AsyncIterable<TestMessage | TestMessage[]> | Iterable<TestMessage | TestMessage[]>
): AsyncIterable<Uint8Array> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [TestMessage.encode(p).finish()];
}
} else {
yield* [TestMessage.encode(pkt).finish()];
}
}
},

// decodeTransform decodes a source of encoded messages.
// Transform<Uint8Array, TestMessage>
async *decodeTransform(
source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
): AsyncIterable<TestMessage> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [TestMessage.decode(p)];
}
} else {
yield* [TestMessage.decode(pkt)];
}
}
},

fromJSON(object: any): TestMessage {
return {
value: isSet(object.value) ? String(object.value) : '',
};
},

toJSON(message: TestMessage): unknown {
const obj: any = {};
message.value !== undefined && (obj.value = message.value);
return obj;
},

fromPartial<I extends Exact<DeepPartial<TestMessage>, I>>(object: I): TestMessage {
const message = createBaseTestMessage();
message.value = object.value ?? '';
return message;
},
};

export interface Test {
BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage>;
}

export class TestClientImpl implements Test {
private readonly rpc: Rpc;
constructor(rpc: Rpc) {
this.rpc = rpc;
this.BidiStreaming = this.BidiStreaming.bind(this);
}
BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage> {
const data = TestMessage.encodeTransform(request);
const result = this.rpc.bidirectionalStreamingRequest('simple.Test', 'BidiStreaming', data);
return TestMessage.decodeTransform(result);
}
}

interface Rpc {
request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
clientStreamingRequest(service: string, method: string, data: AsyncIterable<Uint8Array>): Promise<Uint8Array>;
serverStreamingRequest(service: string, method: string, data: Uint8Array): AsyncIterable<Uint8Array>;
bidirectionalStreamingRequest(
service: string,
method: string,
data: AsyncIterable<Uint8Array>
): AsyncIterable<Uint8Array>;
}

type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;

export type DeepPartial<T> = T extends Builtin
? T
: T extends Array<infer U>
? Array<DeepPartial<U>>
: T extends ReadonlyArray<infer U>
? ReadonlyArray<DeepPartial<U>>
: T extends {}
? { [K in keyof T]?: DeepPartial<T[K]> }
: Partial<T>;

type KeysOfUnion<T> = T extends T ? keyof T : never;
export type Exact<P, I extends P> = P extends Builtin
? P
: P & { [K in keyof P]: Exact<P[K], I[K]> } & Record<Exclude<keyof I, KeysOfUnion<P>>, never>;

function isSet(value: any): boolean {
return value !== null && value !== undefined;
}
2 changes: 1 addition & 1 deletion integration/generic-metadata/hero.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable */
import { Observable } from 'rxjs';
import { Foo } from './some-file';
import { Observable } from 'rxjs';
import { map } from 'rxjs/operators';
import * as _m0 from 'protobufjs/minimal';

Expand Down
2 changes: 1 addition & 1 deletion integration/grpc-web-no-streaming-observable/example.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/* eslint-disable */
import { grpc } from '@improbable-eng/grpc-web';
import { Observable } from 'rxjs';
import { BrowserHeaders } from 'browser-headers';
import { take } from 'rxjs/operators';
import { Observable } from 'rxjs';
import * as _m0 from 'protobufjs/minimal';

export const protobufPackage = 'rpx';
Expand Down
2 changes: 1 addition & 1 deletion integration/grpc-web/example.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/* eslint-disable */
import { grpc } from '@improbable-eng/grpc-web';
import { Observable } from 'rxjs';
import { BrowserHeaders } from 'browser-headers';
import { share } from 'rxjs/operators';
import { Observable } from 'rxjs';
import * as _m0 from 'protobufjs/minimal';

export const protobufPackage = 'rpx';
Expand Down
2 changes: 1 addition & 1 deletion integration/nestjs-metadata-grpc-js/hero.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable */
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import { Observable } from 'rxjs';
import { Metadata } from '@grpc/grpc-js';
import { Observable } from 'rxjs';

export const protobufPackage = 'hero';

Expand Down
2 changes: 1 addition & 1 deletion integration/nestjs-metadata-observables/hero.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable */
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import { Observable } from 'rxjs';
import { Metadata } from '@grpc/grpc-js';
import { Observable } from 'rxjs';

export const protobufPackage = 'hero';

Expand Down
2 changes: 1 addition & 1 deletion integration/nestjs-metadata-restparameters/hero.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable */
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import { Observable } from 'rxjs';
import { Metadata } from '@grpc/grpc-js';
import { Observable } from 'rxjs';

export const protobufPackage = 'hero';

Expand Down
2 changes: 1 addition & 1 deletion integration/nestjs-metadata/hero.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable */
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import { Observable } from 'rxjs';
import { Metadata } from '@grpc/grpc-js';
import { Observable } from 'rxjs';

export const protobufPackage = 'hero';

Expand Down
2 changes: 1 addition & 1 deletion integration/nestjs-simple/hero.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable */
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import { Observable } from 'rxjs';
import { Timestamp } from './google/protobuf/timestamp';
import { Observable } from 'rxjs';
import { Empty } from './google/protobuf/empty';

export const protobufPackage = 'hero';
Expand Down
43 changes: 43 additions & 0 deletions src/generate-async-iterable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { code, Code } from 'ts-poet';

/** Creates a function to transform a message Source to a Uint8Array Source. */
export function generateEncodeTransform(fullName: string): Code {
return code`
// encodeTransform encodes a source of message objects.
// Transform<${fullName}, Uint8Array>
async *encodeTransform(
source: AsyncIterable<${fullName} | ${fullName}[]> | Iterable<${fullName} | ${fullName}[]>
): AsyncIterable<Uint8Array> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [${fullName}.encode(p).finish()]
}
} else {
yield* [${fullName}.encode(pkt).finish()]
}
}
}
`;
}

/** Creates a function to transform a Uint8Array Source to a message Source. */
export function generateDecodeTransform(fullName: string): Code {
return code`
// decodeTransform decodes a source of encoded messages.
// Transform<Uint8Array, ${fullName}>
async *decodeTransform(
source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
): AsyncIterable<${fullName}> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [${fullName}.decode(p)]
}
} else {
yield* [${fullName}.decode(pkt)]
}
}
}
`;
}
29 changes: 14 additions & 15 deletions src/generate-grpc-web.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { MethodDescriptorProto, FileDescriptorProto, ServiceDescriptorProto } from 'ts-proto-descriptors';
import { rawRequestType, requestType, responsePromiseOrObservable, responseType } from './types';
import { rawRequestType, requestType, responsePromiseOrObservable, responseType, observableType } from './types';
import { Code, code, imp, joinCode } from 'ts-poet';
import { Context } from './context';
import { assertInstanceOf, FormattedMethodDescriptor, maybePrefixPackage } from './utils';
Expand All @@ -8,12 +8,11 @@ const grpc = imp('grpc@@improbable-eng/grpc-web');
const share = imp('share@rxjs/operators');
const take = imp('take@rxjs/operators');
const BrowserHeaders = imp('BrowserHeaders@browser-headers');
const Observable = imp('Observable@rxjs');

/** Generates a client that uses the `@improbable-web/grpc-web` library. */
export function generateGrpcClientImpl(
ctx: Context,
fileDesc: FileDescriptorProto,
_fileDesc: FileDescriptorProto,
serviceDesc: ServiceDescriptorProto
): Code {
const chunks: Code[] = [];
Expand Down Expand Up @@ -154,18 +153,18 @@ export function addGrpcWebMisc(ctx: Context, hasStreamingMethods: boolean): Code
interface UnaryMethodDefinitionishR extends ${grpc}.UnaryMethodDefinition<any, any> { requestStream: any; responseStream: any; }
`);
chunks.push(code`type UnaryMethodDefinitionish = UnaryMethodDefinitionishR;`);
chunks.push(generateGrpcWebRpcType(options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebImpl(options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebRpcType(ctx, options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebImpl(ctx, options.returnObservable, hasStreamingMethods));
return joinCode(chunks, { on: '\n\n' });
}

/** Makes an `Rpc` interface to decouple from the low-level grpc-web `grpc.invoke and grpc.unary`/etc. methods. */
function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods: boolean): Code {
function generateGrpcWebRpcType(ctx: Context, returnObservable: boolean, hasStreamingMethods: boolean): Code {
const chunks: Code[] = [];

chunks.push(code`interface Rpc {`);

const wrapper = returnObservable ? Observable : 'Promise';
const wrapper = returnObservable ? observableType(ctx) : 'Promise';
chunks.push(code`
unary<T extends UnaryMethodDefinitionish>(
methodDesc: T,
Expand All @@ -180,7 +179,7 @@ function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods:
methodDesc: T,
request: any,
metadata: grpc.Metadata | undefined,
): ${Observable}<any>;
): ${observableType(ctx)}<any>;
`);
}

Expand All @@ -189,7 +188,7 @@ function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods:
}

/** Implements the `Rpc` interface by making calls using the `grpc.unary` method. */
function generateGrpcWebImpl(returnObservable: boolean, hasStreamingMethods: boolean): Code {
function generateGrpcWebImpl(ctx: Context, returnObservable: boolean, hasStreamingMethods: boolean): Code {
const options = code`
{
transport?: grpc.TransportFactory,
Expand All @@ -212,13 +211,13 @@ function generateGrpcWebImpl(returnObservable: boolean, hasStreamingMethods: boo
`);

if (returnObservable) {
chunks.push(createObservableUnaryMethod());
chunks.push(createObservableUnaryMethod(ctx));
} else {
chunks.push(createPromiseUnaryMethod());
}

if (hasStreamingMethods) {
chunks.push(createInvokeMethod());
chunks.push(createInvokeMethod(ctx));
}

chunks.push(code`}`);
Expand Down Expand Up @@ -260,13 +259,13 @@ function createPromiseUnaryMethod(): Code {
`;
}

function createObservableUnaryMethod(): Code {
function createObservableUnaryMethod(ctx: Context): Code {
return code`
unary<T extends UnaryMethodDefinitionish>(
methodDesc: T,
_request: any,
metadata: grpc.Metadata | undefined
): ${Observable}<any> {
): ${observableType(ctx)}<any> {
const request = { ..._request, ...methodDesc.requestType };
const maybeCombinedMetadata =
metadata && this.options.metadata
Expand All @@ -293,13 +292,13 @@ function createObservableUnaryMethod(): Code {
`;
}

function createInvokeMethod() {
function createInvokeMethod(ctx: Context) {
return code`
invoke<T extends UnaryMethodDefinitionish>(
methodDesc: T,
_request: any,
metadata: grpc.Metadata | undefined
): ${Observable}<any> {
): ${observableType(ctx)}<any> {
// Status Response Codes (https://developers.google.com/maps-booking/reference/grpc-api/status_codes)
const upStreamCodes = [2, 4, 8, 9, 10, 13, 14, 15];
const DEFAULT_TIMEOUT_TIME: number = 3_000;
Expand Down
Loading

0 comments on commit ca8ea8d

Please sign in to comment.