Skip to content

Commit

Permalink
feat: Streaming support (#373)
Browse files Browse the repository at this point in the history
* Fix response type handling of server streaming methods

It still won't compile, but this at least makes the class match the
interface.

* Support for streaming RPCs in services

If you define a single streaming RPC (client, server or bidirectional
streaming), the Rpc interface will require you to also implement
clientStreaming, serverStreaming and bidirectionalStreaming. Those
methods will be passed/return Observables.

Previously, the generated file when using streaming RPCs did not
compile.

* Run prettier over new streaming code

* Streaming: Annotate code snippets with code``

This fixes broken generated code that somehow didn't happen for all tests but
only a few.

* Streaming: Run prettier again
  • Loading branch information
Jille committed Nov 2, 2021
1 parent 642d74c commit 459b94f
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 48 deletions.
2 changes: 1 addition & 1 deletion integration/batching-with-context/batching.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import * as DataLoader from 'dataloader';
import * as hash from 'object-hash';
Expand Down
2 changes: 1 addition & 1 deletion integration/batching/batching.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';

export const protobufPackage = 'batching';
Expand Down
12 changes: 8 additions & 4 deletions integration/grpc-web-go-server/example.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { Observable } from 'rxjs';
import { map } from 'rxjs/operators';

export const protobufPackage = 'rpx';

Expand Down Expand Up @@ -676,10 +677,10 @@ export class DashStateClientImpl implements DashState {
return promise.then((data) => DashUserSettingsState.decode(new Reader(data)));
}

ActiveUserSettingsStream(request: Empty): Promise<DashUserSettingsState> {
ActiveUserSettingsStream(request: Empty): Observable<DashUserSettingsState> {
const data = Empty.encode(request).finish();
const promise = this.rpc.request('rpx.DashState', 'ActiveUserSettingsStream', data);
return promise.then((data) => DashUserSettingsState.decode(new Reader(data)));
const result = this.rpc.serverStreamingRequest('rpx.DashState', 'ActiveUserSettingsStream', data);
return result.pipe(map((data) => DashUserSettingsState.decode(new Reader(data))));
}
}

Expand Down Expand Up @@ -723,6 +724,9 @@ export class DashAPICredsClientImpl implements DashAPICreds {

interface Rpc {
request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
clientStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Promise<Uint8Array>;
serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable<Uint8Array>;
bidirectionalStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Observable<Uint8Array>;
}

type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;
Expand Down
2 changes: 1 addition & 1 deletion integration/lower-case-svc-methods/math.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import * as DataLoader from 'dataloader';
import * as hash from 'object-hash';
Expand Down
2 changes: 1 addition & 1 deletion integration/meta-typings/simple.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable */
import { FileDescriptorProto } from 'ts-proto-descriptors';
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { protoMetadata as protoMetadata1, DateMessage } from './google/type/date';
import { protoMetadata as protoMetadata2, StringValue, Int32Value, BoolValue } from './google/protobuf/wrappers';
Expand Down
12 changes: 8 additions & 4 deletions integration/no-proto-package/no-proto-package.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { Observable } from 'rxjs';
import { map } from 'rxjs/operators';

export const protobufPackage = '';

Expand Down Expand Up @@ -110,15 +111,18 @@ export class UserStateClientImpl implements UserState {
this.rpc = rpc;
this.GetUsers = this.GetUsers.bind(this);
}
GetUsers(request: Empty): Promise<User> {
GetUsers(request: Empty): Observable<User> {
const data = Empty.encode(request).finish();
const promise = this.rpc.request('UserState', 'GetUsers', data);
return promise.then((data) => User.decode(new Reader(data)));
const result = this.rpc.serverStreamingRequest('UserState', 'GetUsers', data);
return result.pipe(map((data) => User.decode(new Reader(data))));
}
}

interface Rpc {
request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
clientStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Promise<Uint8Array>;
serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable<Uint8Array>;
bidirectionalStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Observable<Uint8Array>;
}

type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;
Expand Down
2 changes: 1 addition & 1 deletion integration/simple-optionals/simple.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { ImportedThing } from './import_dir/thing';
import { Timestamp } from './google/protobuf/timestamp';
Expand Down
2 changes: 1 addition & 1 deletion integration/simple-snake/simple.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { Timestamp } from './google/protobuf/timestamp';
import { ImportedThing } from './import_dir/thing';
Expand Down
2 changes: 1 addition & 1 deletion integration/simple-unrecognized-enum/simple.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { Timestamp } from './google/protobuf/timestamp';
import { ImportedThing } from './import_dir/thing';
Expand Down
2 changes: 1 addition & 1 deletion integration/simple/simple.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
import { util, configure, Reader, Writer } from 'protobufjs/minimal';
import { util, configure, Writer, Reader } from 'protobufjs/minimal';
import * as Long from 'long';
import { Timestamp } from './google/protobuf/timestamp';
import { ImportedThing } from './import_dir/thing';
Expand Down
7 changes: 2 additions & 5 deletions src/generate-grpc-web.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { MethodDescriptorProto, FileDescriptorProto, ServiceDescriptorProto } from 'ts-proto-descriptors';
import { requestType, responseObservable, responsePromise, responseType } from './types';
import { requestType, responsePromiseOrObservable, responseType } from './types';
import { Code, code, imp, joinCode } from 'ts-poet';
import { Context } from './context';
import { assertInstanceOf, FormattedMethodDescriptor, maybePrefixPackage } from './utils';
Expand Down Expand Up @@ -52,10 +52,7 @@ function generateRpcMethod(ctx: Context, serviceDesc: ServiceDescriptorProto, me
const { options, utils } = ctx;
const inputType = requestType(ctx, methodDesc);
const partialInputType = code`${utils.DeepPartial}<${inputType}>`;
const returns =
options.returnObservable || methodDesc.serverStreaming
? responseObservable(ctx, methodDesc)
: responsePromise(ctx, methodDesc);
const returns = responsePromiseOrObservable(ctx, methodDesc);
const method = methodDesc.serverStreaming ? 'invoke' : 'unary';
return code`
${methodDesc.formattedName}(
Expand Down
87 changes: 64 additions & 23 deletions src/generate-services.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import {
BatchMethod,
detectBatchMethod,
requestType,
responseObservable,
responsePromise,
rawRequestType,
responsePromiseOrObservable,
responseType,
} from './types';
import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage, singular } from './utils';
Expand Down Expand Up @@ -72,15 +72,12 @@ export function generateService(
params.push(code`...rest: any`);
}

// Return observable for interface only configuration, passing returnObservable=true and methodDesc.serverStreaming=true
let returnType: Code;
if (options.returnObservable || methodDesc.serverStreaming) {
returnType = responseObservable(ctx, methodDesc);
} else {
returnType = responsePromise(ctx, methodDesc);
}

chunks.push(code`${methodDesc.formattedName}(${joinCode(params, { on: ',' })}): ${returnType};`);
chunks.push(
code`${methodDesc.formattedName}(${joinCode(params, { on: ',' })}): ${responsePromiseOrObservable(
ctx,
methodDesc
)};`
);

// If this is a batch method, auto-generate the singular version of it
if (options.context) {
Expand Down Expand Up @@ -108,24 +105,51 @@ function generateRegularRpcMethod(
assertInstanceOf(methodDesc, FormattedMethodDescriptor);
const { options } = ctx;
const Reader = imp('Reader@protobufjs/minimal');
const rawInputType = rawRequestType(ctx, methodDesc);
const inputType = requestType(ctx, methodDesc);
const outputType = responseType(ctx, methodDesc);

const params = [...(options.context ? [code`ctx: Context`] : []), code`request: ${inputType}`];
const maybeCtx = options.context ? 'ctx,' : '';

let encode = code`${rawInputType}.encode(request).finish()`;
let decode = code`data => ${outputType}.decode(new ${Reader}(data))`;

if (methodDesc.clientStreaming) {
encode = code`request.pipe(${imp('map@rxjs/operators')}(request => ${encode}))`;
}
let returnVariable: string;
if (options.returnObservable || methodDesc.serverStreaming) {
returnVariable = 'result';
decode = code`result.pipe(${imp('map@rxjs/operators')}(${decode}))`;
} else {
returnVariable = 'promise';
decode = code`promise.then(${decode})`;
}

let rpcMethod: string;
if (methodDesc.clientStreaming && methodDesc.serverStreaming) {
rpcMethod = 'bidirectionalStreamingRequest';
} else if (methodDesc.serverStreaming) {
rpcMethod = 'serverStreamingRequest';
} else if (methodDesc.clientStreaming) {
rpcMethod = 'clientStreamingRequest';
} else {
rpcMethod = 'request';
}

return code`
${methodDesc.formattedName}(
${joinCode(params, { on: ',' })}
): ${responsePromise(ctx, methodDesc)} {
const data = ${inputType}.encode(request).finish();
const promise = this.rpc.request(
): ${responsePromiseOrObservable(ctx, methodDesc)} {
const data = ${encode};
const ${returnVariable} = this.rpc.${rpcMethod}(
${maybeCtx}
"${maybePrefixPackage(fileDesc, serviceDesc.name)}",
"${methodDesc.name}",
data
);
return promise.then(data => ${outputType}.decode(new ${Reader}(data)));
return ${decode};
}
`;
}
Expand Down Expand Up @@ -273,24 +297,41 @@ function generateCachingRpcMethod(
*
* This lets clients pass in their own request-promise-ish client.
*
* This also requires clientStreamingRequest, serverStreamingRequest and
* bidirectionalStreamingRequest methods if any of the RPCs is streaming.
*
* We don't export this because if a project uses multiple `*.proto` files,
* we don't want our the barrel imports in `index.ts` to have multiple `Rpc`
* types.
*/
export function generateRpcType(ctx: Context): Code {
export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Code {
const { options } = ctx;
const maybeContext = options.context ? '<Context>' : '';
const maybeContextParam = options.context ? 'ctx: Context,' : '';
return code`
interface Rpc${maybeContext} {
request(
const methods = [[code`request`, code`Uint8Array`, code`Promise<Uint8Array>`]];
if (hasStreamingMethods) {
const observable = imp('Observable@rxjs');
methods.push([code`clientStreamingRequest`, code`${observable}<Uint8Array>`, code`Promise<Uint8Array>`]);
methods.push([code`serverStreamingRequest`, code`Uint8Array`, code`${observable}<Uint8Array>`]);
methods.push([
code`bidirectionalStreamingRequest`,
code`${observable}<Uint8Array>`,
code`${observable}<Uint8Array>`,
]);
}
const chunks: Code[] = [];
chunks.push(code` interface Rpc${maybeContext} {`);
methods.forEach((method) => {
chunks.push(code`
${method[0]}(
${maybeContextParam}
service: string,
method: string,
data: Uint8Array
): Promise<Uint8Array>;
}
`;
data: ${method[1]}
): ${method[2]};`);
});
chunks.push(code` }`);
return joinCode(chunks, { on: '\n' });
}

export function generateDataLoadersType(): Code {
Expand Down
12 changes: 9 additions & 3 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri
);
}

let hasServerStreamingMethods = false;
let hasStreamingMethods = false;

visitServices(fileDesc, sourceInfo, (serviceDesc, sInfo) => {
Expand Down Expand Up @@ -200,18 +201,23 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri
serviceDesc.method.forEach((method) => {
chunks.push(generateGrpcMethodDesc(ctx, serviceDesc, method));
if (method.serverStreaming) {
hasStreamingMethods = true;
hasServerStreamingMethods = true;
}
});
}
}
serviceDesc.method.forEach((methodDesc, index) => {
if (methodDesc.serverStreaming || methodDesc.clientStreaming) {
hasStreamingMethods = true;
}
});
});

if (options.outputServices === ServiceOption.DEFAULT && options.outputClientImpl && fileDesc.service.length > 0) {
if (options.outputClientImpl === true) {
chunks.push(generateRpcType(ctx));
chunks.push(generateRpcType(ctx, hasStreamingMethods));
} else if (options.outputClientImpl === 'grpc-web') {
chunks.push(addGrpcWebMisc(ctx, hasStreamingMethods));
chunks.push(addGrpcWebMisc(ctx, hasServerStreamingMethods));
}
}

Expand Down
14 changes: 13 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,12 @@ export function detectMapType(
return undefined;
}

export function rawRequestType(ctx: Context, methodDesc: MethodDescriptorProto): Code {
return messageToTypeName(ctx, methodDesc.inputType);
}

export function requestType(ctx: Context, methodDesc: MethodDescriptorProto): Code {
let typeName = messageToTypeName(ctx, methodDesc.inputType);
let typeName = rawRequestType(ctx, methodDesc);
if (methodDesc.clientStreaming) {
return code`${imp('Observable@rxjs')}<${typeName}>`;
}
Expand All @@ -557,6 +561,14 @@ export function responseObservable(ctx: Context, methodDesc: MethodDescriptorPro
return code`${imp('Observable@rxjs')}<${responseType(ctx, methodDesc)}>`;
}

export function responsePromiseOrObservable(ctx: Context, methodDesc: MethodDescriptorProto): Code {
const { options } = ctx;
if (options.returnObservable || methodDesc.serverStreaming) {
return responseObservable(ctx, methodDesc);
}
return responsePromise(ctx, methodDesc);
}

export interface BatchMethod {
methodDesc: MethodDescriptorProto;
// a ${package + service + method name} key to identify this method in caches
Expand Down

0 comments on commit 459b94f

Please sign in to comment.