Skip to content

Commit

Permalink
perf: generate switch statement for oneof union encode (#767)
Browse files Browse the repository at this point in the history
* Extract writeSnippet create out of generateEncode

* Perf: oneof union encode to switch statement

Previously each oneof union case generated its own `if` statement in the
encode method which becomes very inefficient with more cases as each of
the conditions needed to be evaluated but only one would ever be true.
They were not even `else if (...)` so once a match was found the rest of
the `if` statements would still continue to be evaluated.
  • Loading branch information
antsluts committed Jan 31, 2023
1 parent e8c6d8b commit c3fd1e3
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 127 deletions.
36 changes: 19 additions & 17 deletions integration/oneof-unions-snake/google/protobuf/struct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,25 @@ function createBaseValue(): Value {

export const Value = {
encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.kind?.$case === "null_value") {
writer.uint32(8).int32(message.kind.null_value);
}
if (message.kind?.$case === "number_value") {
writer.uint32(17).double(message.kind.number_value);
}
if (message.kind?.$case === "string_value") {
writer.uint32(26).string(message.kind.string_value);
}
if (message.kind?.$case === "bool_value") {
writer.uint32(32).bool(message.kind.bool_value);
}
if (message.kind?.$case === "struct_value") {
Struct.encode(Struct.wrap(message.kind.struct_value), writer.uint32(42).fork()).ldelim();
}
if (message.kind?.$case === "list_value") {
ListValue.encode(ListValue.wrap(message.kind.list_value), writer.uint32(50).fork()).ldelim();
switch (message.kind?.$case) {
case "null_value":
writer.uint32(8).int32(message.kind.null_value);
break;
case "number_value":
writer.uint32(17).double(message.kind.number_value);
break;
case "string_value":
writer.uint32(26).string(message.kind.string_value);
break;
case "bool_value":
writer.uint32(32).bool(message.kind.bool_value);
break;
case "struct_value":
Struct.encode(Struct.wrap(message.kind.struct_value), writer.uint32(42).fork()).ldelim();
break;
case "list_value":
ListValue.encode(ListValue.wrap(message.kind.list_value), writer.uint32(50).fork()).ldelim();
break;
}
return writer;
},
Expand Down
36 changes: 19 additions & 17 deletions integration/oneof-unions/google/protobuf/struct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,25 @@ function createBaseValue(): Value {

export const Value = {
encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.kind?.$case === "nullValue") {
writer.uint32(8).int32(message.kind.nullValue);
}
if (message.kind?.$case === "numberValue") {
writer.uint32(17).double(message.kind.numberValue);
}
if (message.kind?.$case === "stringValue") {
writer.uint32(26).string(message.kind.stringValue);
}
if (message.kind?.$case === "boolValue") {
writer.uint32(32).bool(message.kind.boolValue);
}
if (message.kind?.$case === "structValue") {
Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim();
}
if (message.kind?.$case === "listValue") {
ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim();
switch (message.kind?.$case) {
case "nullValue":
writer.uint32(8).int32(message.kind.nullValue);
break;
case "numberValue":
writer.uint32(17).double(message.kind.numberValue);
break;
case "stringValue":
writer.uint32(26).string(message.kind.stringValue);
break;
case "boolValue":
writer.uint32(32).bool(message.kind.boolValue);
break;
case "structValue":
Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim();
break;
case "listValue":
ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim();
break;
}
return writer;
},
Expand Down
54 changes: 29 additions & 25 deletions integration/oneof-unions/oneof.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,35 +80,39 @@ export const PleaseChoose = {
if (message.name !== "") {
writer.uint32(10).string(message.name);
}
if (message.choice?.$case === "aNumber") {
writer.uint32(17).double(message.choice.aNumber);
}
if (message.choice?.$case === "aString") {
writer.uint32(26).string(message.choice.aString);
}
if (message.choice?.$case === "aMessage") {
PleaseChoose_Submessage.encode(message.choice.aMessage, writer.uint32(34).fork()).ldelim();
}
if (message.choice?.$case === "aBool") {
writer.uint32(48).bool(message.choice.aBool);
}
if (message.choice?.$case === "bunchaBytes") {
writer.uint32(82).bytes(message.choice.bunchaBytes);
}
if (message.choice?.$case === "anEnum") {
writer.uint32(88).int32(message.choice.anEnum);
switch (message.choice?.$case) {
case "aNumber":
writer.uint32(17).double(message.choice.aNumber);
break;
case "aString":
writer.uint32(26).string(message.choice.aString);
break;
case "aMessage":
PleaseChoose_Submessage.encode(message.choice.aMessage, writer.uint32(34).fork()).ldelim();
break;
case "aBool":
writer.uint32(48).bool(message.choice.aBool);
break;
case "bunchaBytes":
writer.uint32(82).bytes(message.choice.bunchaBytes);
break;
case "anEnum":
writer.uint32(88).int32(message.choice.anEnum);
break;
}
if (message.age !== 0) {
writer.uint32(40).uint32(message.age);
}
if (message.eitherOr?.$case === "either") {
writer.uint32(58).string(message.eitherOr.either);
}
if (message.eitherOr?.$case === "or") {
writer.uint32(66).string(message.eitherOr.or);
}
if (message.eitherOr?.$case === "thirdOption") {
writer.uint32(74).string(message.eitherOr.thirdOption);
switch (message.eitherOr?.$case) {
case "either":
writer.uint32(58).string(message.eitherOr.either);
break;
case "or":
writer.uint32(66).string(message.eitherOr.or);
break;
case "thirdOption":
writer.uint32(74).string(message.eitherOr.thirdOption);
break;
}
if (message.signature.length !== 0) {
writer.uint32(98).bytes(message.signature);
Expand Down
36 changes: 19 additions & 17 deletions integration/use-readonly-types/google/protobuf/struct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,25 @@ function createBaseValue(): Value {

export const Value = {
encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.kind?.$case === "nullValue") {
writer.uint32(8).int32(message.kind.nullValue);
}
if (message.kind?.$case === "numberValue") {
writer.uint32(17).double(message.kind.numberValue);
}
if (message.kind?.$case === "stringValue") {
writer.uint32(26).string(message.kind.stringValue);
}
if (message.kind?.$case === "boolValue") {
writer.uint32(32).bool(message.kind.boolValue);
}
if (message.kind?.$case === "structValue") {
Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim();
}
if (message.kind?.$case === "listValue") {
ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim();
switch (message.kind?.$case) {
case "nullValue":
writer.uint32(8).int32(message.kind.nullValue);
break;
case "numberValue":
writer.uint32(17).double(message.kind.numberValue);
break;
case "stringValue":
writer.uint32(26).string(message.kind.stringValue);
break;
case "boolValue":
writer.uint32(32).bool(message.kind.boolValue);
break;
case "structValue":
Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim();
break;
case "listValue":
ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim();
break;
}
return writer;
},
Expand Down
12 changes: 7 additions & 5 deletions integration/use-readonly-types/use-readonly-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ export const Entity = {
if (message.structValue !== undefined) {
Struct.encode(Struct.wrap(message.structValue), writer.uint32(82).fork()).ldelim();
}
if (message.oneOfValue?.$case === "theStringValue") {
writer.uint32(90).string(message.oneOfValue.theStringValue);
}
if (message.oneOfValue?.$case === "theIntValue") {
writer.uint32(96).int32(message.oneOfValue.theIntValue);
switch (message.oneOfValue?.$case) {
case "theStringValue":
writer.uint32(90).string(message.oneOfValue.theStringValue);
break;
case "theIntValue":
writer.uint32(96).int32(message.oneOfValue.theIntValue);
break;
}
return writer;
},
Expand Down
111 changes: 65 additions & 46 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,49 @@ function generateDecode(ctx: Context, fullName: string, messageDesc: DescriptorP
return joinCode(chunks, { on: "\n" });
}

/** Returns a generic writer.doSomething based on the basic type */
function getEncodeWriteSnippet(ctx: Context, field: FieldDescriptorProto): (place: string) => Code {
const { options, utils } = ctx;
if (isEnum(field) && options.stringEnums) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
const toNumber = getEnumMethod(ctx, field.typeName, "ToNumber");
return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${toNumber}(${place}))`;
} else if (isLong(field) && options.forceLong === LongOption.BIGINT) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place}.toString())`;
} else if (isScalar(field) || isEnum(field)) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place})`;
} else if (isObjectId(field) && options.useMongoObjectId) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field, { keepValueType: true });
return (place) => code`${type}.encode(${utils.toProtoObjectId}(${place}), writer.uint32(${tag}).fork()).ldelim()`;
} else if (isTimestamp(field) && (options.useDate === DateOption.DATE || options.useDate === DateOption.STRING)) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field, { keepValueType: true });
return (place) => code`${type}.encode(${utils.toTimestamp}(${place}), writer.uint32(${tag}).fork()).ldelim()`;
} else if (isValueType(ctx, field)) {
const maybeTypeField = options.outputTypeRegistry ? `$type: '${field.typeName.slice(1)}',` : "";

const type = basicTypeName(ctx, field, { keepValueType: true });
const wrappedValue = (place: string): Code => {
if (isAnyValueType(field) || isListValueType(field) || isStructType(field) || isFieldMaskType(field)) {
return code`${type}.wrap(${place})`;
}
return code`{${maybeTypeField} value: ${place}!}`;
};

const tag = ((field.number << 3) | 2) >>> 0;
return (place) => code`${type}.encode(${wrappedValue(place)}, writer.uint32(${tag}).fork()).ldelim()`;
} else if (isMessage(field)) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field);
return (place) => code`${type}.encode(${place}, writer.uint32(${tag}).fork()).ldelim()`;
} else {
throw new Error(`Unhandled field ${field}`);
}
}

/** Creates a function to encode a message by loop overing the tags. */
function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorProto): Code {
const { options, utils, typeMap } = ctx;
Expand All @@ -1074,52 +1117,20 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP
): ${Writer} {
`);

const processedOneofs = new Set<number>();
const oneOfFieldsDict = messageDesc.field
.filter((field) => isWithinOneOfThatShouldBeUnion(options, field))
.reduce<{ [key: number]: FieldDescriptorProto[] }>(
(result, field) => ((result[field.oneofIndex] || (result[field.oneofIndex] = [])).push(field), result),
{}
);

// then add a case for each field
messageDesc.field.forEach((field) => {
const fieldName = maybeSnakeToCamel(field.name, options);

// get a generic writer.doSomething based on the basic type
let writeSnippet: (place: string) => Code;
if (isEnum(field) && options.stringEnums) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
const toNumber = getEnumMethod(ctx, field.typeName, "ToNumber");
writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${toNumber}(${place}))`;
} else if (isLong(field) && options.forceLong === LongOption.BIGINT) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place}.toString())`;
} else if (isScalar(field) || isEnum(field)) {
const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0;
writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place})`;
} else if (isObjectId(field) && options.useMongoObjectId) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field, { keepValueType: true });
writeSnippet = (place) =>
code`${type}.encode(${utils.toProtoObjectId}(${place}), writer.uint32(${tag}).fork()).ldelim()`;
} else if (isTimestamp(field) && (options.useDate === DateOption.DATE || options.useDate === DateOption.STRING)) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field, { keepValueType: true });
writeSnippet = (place) =>
code`${type}.encode(${utils.toTimestamp}(${place}), writer.uint32(${tag}).fork()).ldelim()`;
} else if (isValueType(ctx, field)) {
const maybeTypeField = options.outputTypeRegistry ? `$type: '${field.typeName.slice(1)}',` : "";

const type = basicTypeName(ctx, field, { keepValueType: true });
const wrappedValue = (place: string): Code => {
if (isAnyValueType(field) || isListValueType(field) || isStructType(field) || isFieldMaskType(field)) {
return code`${type}.wrap(${place})`;
}
return code`{${maybeTypeField} value: ${place}!}`;
};

const tag = ((field.number << 3) | 2) >>> 0;
writeSnippet = (place) => code`${type}.encode(${wrappedValue(place)}, writer.uint32(${tag}).fork()).ldelim()`;
} else if (isMessage(field)) {
const tag = ((field.number << 3) | 2) >>> 0;
const type = basicTypeName(ctx, field);
writeSnippet = (place) => code`${type}.encode(${place}, writer.uint32(${tag}).fork()).ldelim()`;
} else {
throw new Error(`Unhandled field ${field}`);
}
const writeSnippet = getEncodeWriteSnippet(ctx, field);

const isOptional = isOptionalProperty(field, messageDesc.options, options);
if (isRepeated(field)) {
Expand Down Expand Up @@ -1208,12 +1219,20 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP
}
}
} else if (isWithinOneOfThatShouldBeUnion(options, field)) {
let oneofName = maybeSnakeToCamel(messageDesc.oneofDecl[field.oneofIndex].name, options);
chunks.push(code`
if (message.${oneofName}?.$case === '${fieldName}') {
${writeSnippet(`message.${oneofName}.${fieldName}`)};
if (!processedOneofs.has(field.oneofIndex)) {
processedOneofs.add(field.oneofIndex);

const oneofName = maybeSnakeToCamel(messageDesc.oneofDecl[field.oneofIndex].name, options);
chunks.push(code`switch (message.${oneofName}?.$case) {`);
for (const oneOfField of oneOfFieldsDict[field.oneofIndex]) {
const writeSnippet = getEncodeWriteSnippet(ctx, oneOfField);
const oneOfFieldName = maybeSnakeToCamel(oneOfField.name, ctx.options);
chunks.push(code`case "${oneOfFieldName}":
${writeSnippet(`message.${oneofName}.${oneOfFieldName}`)};
break;`);
}
`);
chunks.push(code`}`);
}
} else if (isWithinOneOf(field)) {
// Oneofs don't have a default value check b/c they need to denote which-oneof presence
chunks.push(code`
Expand Down

0 comments on commit c3fd1e3

Please sign in to comment.