diff --git a/integration/oneof-unions-snake/google/protobuf/struct.ts b/integration/oneof-unions-snake/google/protobuf/struct.ts index 1cf3d7cc5..14a1efa4f 100644 --- a/integration/oneof-unions-snake/google/protobuf/struct.ts +++ b/integration/oneof-unions-snake/google/protobuf/struct.ts @@ -246,23 +246,25 @@ function createBaseValue(): Value { export const Value = { encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { - if (message.kind?.$case === "null_value") { - writer.uint32(8).int32(message.kind.null_value); - } - if (message.kind?.$case === "number_value") { - writer.uint32(17).double(message.kind.number_value); - } - if (message.kind?.$case === "string_value") { - writer.uint32(26).string(message.kind.string_value); - } - if (message.kind?.$case === "bool_value") { - writer.uint32(32).bool(message.kind.bool_value); - } - if (message.kind?.$case === "struct_value") { - Struct.encode(Struct.wrap(message.kind.struct_value), writer.uint32(42).fork()).ldelim(); - } - if (message.kind?.$case === "list_value") { - ListValue.encode(ListValue.wrap(message.kind.list_value), writer.uint32(50).fork()).ldelim(); + switch (message.kind?.$case) { + case "null_value": + writer.uint32(8).int32(message.kind.null_value); + break; + case "number_value": + writer.uint32(17).double(message.kind.number_value); + break; + case "string_value": + writer.uint32(26).string(message.kind.string_value); + break; + case "bool_value": + writer.uint32(32).bool(message.kind.bool_value); + break; + case "struct_value": + Struct.encode(Struct.wrap(message.kind.struct_value), writer.uint32(42).fork()).ldelim(); + break; + case "list_value": + ListValue.encode(ListValue.wrap(message.kind.list_value), writer.uint32(50).fork()).ldelim(); + break; } return writer; }, diff --git a/integration/oneof-unions/google/protobuf/struct.ts b/integration/oneof-unions/google/protobuf/struct.ts index c135f74d2..e86599eee 100644 --- a/integration/oneof-unions/google/protobuf/struct.ts +++ b/integration/oneof-unions/google/protobuf/struct.ts @@ -246,23 +246,25 @@ function createBaseValue(): Value { export const Value = { encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { - if (message.kind?.$case === "nullValue") { - writer.uint32(8).int32(message.kind.nullValue); - } - if (message.kind?.$case === "numberValue") { - writer.uint32(17).double(message.kind.numberValue); - } - if (message.kind?.$case === "stringValue") { - writer.uint32(26).string(message.kind.stringValue); - } - if (message.kind?.$case === "boolValue") { - writer.uint32(32).bool(message.kind.boolValue); - } - if (message.kind?.$case === "structValue") { - Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim(); - } - if (message.kind?.$case === "listValue") { - ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim(); + switch (message.kind?.$case) { + case "nullValue": + writer.uint32(8).int32(message.kind.nullValue); + break; + case "numberValue": + writer.uint32(17).double(message.kind.numberValue); + break; + case "stringValue": + writer.uint32(26).string(message.kind.stringValue); + break; + case "boolValue": + writer.uint32(32).bool(message.kind.boolValue); + break; + case "structValue": + Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim(); + break; + case "listValue": + ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim(); + break; } return writer; }, diff --git a/integration/oneof-unions/oneof.ts b/integration/oneof-unions/oneof.ts index 92b6a5357..c48473ad7 100644 --- a/integration/oneof-unions/oneof.ts +++ b/integration/oneof-unions/oneof.ts @@ -80,35 +80,39 @@ export const PleaseChoose = { if (message.name !== "") { writer.uint32(10).string(message.name); } - if (message.choice?.$case === "aNumber") { - writer.uint32(17).double(message.choice.aNumber); - } - if (message.choice?.$case === "aString") { - writer.uint32(26).string(message.choice.aString); - } - if (message.choice?.$case === "aMessage") { - PleaseChoose_Submessage.encode(message.choice.aMessage, writer.uint32(34).fork()).ldelim(); - } - if (message.choice?.$case === "aBool") { - writer.uint32(48).bool(message.choice.aBool); - } - if (message.choice?.$case === "bunchaBytes") { - writer.uint32(82).bytes(message.choice.bunchaBytes); - } - if (message.choice?.$case === "anEnum") { - writer.uint32(88).int32(message.choice.anEnum); + switch (message.choice?.$case) { + case "aNumber": + writer.uint32(17).double(message.choice.aNumber); + break; + case "aString": + writer.uint32(26).string(message.choice.aString); + break; + case "aMessage": + PleaseChoose_Submessage.encode(message.choice.aMessage, writer.uint32(34).fork()).ldelim(); + break; + case "aBool": + writer.uint32(48).bool(message.choice.aBool); + break; + case "bunchaBytes": + writer.uint32(82).bytes(message.choice.bunchaBytes); + break; + case "anEnum": + writer.uint32(88).int32(message.choice.anEnum); + break; } if (message.age !== 0) { writer.uint32(40).uint32(message.age); } - if (message.eitherOr?.$case === "either") { - writer.uint32(58).string(message.eitherOr.either); - } - if (message.eitherOr?.$case === "or") { - writer.uint32(66).string(message.eitherOr.or); - } - if (message.eitherOr?.$case === "thirdOption") { - writer.uint32(74).string(message.eitherOr.thirdOption); + switch (message.eitherOr?.$case) { + case "either": + writer.uint32(58).string(message.eitherOr.either); + break; + case "or": + writer.uint32(66).string(message.eitherOr.or); + break; + case "thirdOption": + writer.uint32(74).string(message.eitherOr.thirdOption); + break; } if (message.signature.length !== 0) { writer.uint32(98).bytes(message.signature); diff --git a/integration/use-readonly-types/google/protobuf/struct.ts b/integration/use-readonly-types/google/protobuf/struct.ts index 42c77fe9f..dab354cd9 100644 --- a/integration/use-readonly-types/google/protobuf/struct.ts +++ b/integration/use-readonly-types/google/protobuf/struct.ts @@ -246,23 +246,25 @@ function createBaseValue(): Value { export const Value = { encode(message: Value, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { - if (message.kind?.$case === "nullValue") { - writer.uint32(8).int32(message.kind.nullValue); - } - if (message.kind?.$case === "numberValue") { - writer.uint32(17).double(message.kind.numberValue); - } - if (message.kind?.$case === "stringValue") { - writer.uint32(26).string(message.kind.stringValue); - } - if (message.kind?.$case === "boolValue") { - writer.uint32(32).bool(message.kind.boolValue); - } - if (message.kind?.$case === "structValue") { - Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim(); - } - if (message.kind?.$case === "listValue") { - ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim(); + switch (message.kind?.$case) { + case "nullValue": + writer.uint32(8).int32(message.kind.nullValue); + break; + case "numberValue": + writer.uint32(17).double(message.kind.numberValue); + break; + case "stringValue": + writer.uint32(26).string(message.kind.stringValue); + break; + case "boolValue": + writer.uint32(32).bool(message.kind.boolValue); + break; + case "structValue": + Struct.encode(Struct.wrap(message.kind.structValue), writer.uint32(42).fork()).ldelim(); + break; + case "listValue": + ListValue.encode(ListValue.wrap(message.kind.listValue), writer.uint32(50).fork()).ldelim(); + break; } return writer; }, diff --git a/integration/use-readonly-types/use-readonly-types.ts b/integration/use-readonly-types/use-readonly-types.ts index 89389f902..7f1f6f055 100644 --- a/integration/use-readonly-types/use-readonly-types.ts +++ b/integration/use-readonly-types/use-readonly-types.ts @@ -76,11 +76,13 @@ export const Entity = { if (message.structValue !== undefined) { Struct.encode(Struct.wrap(message.structValue), writer.uint32(82).fork()).ldelim(); } - if (message.oneOfValue?.$case === "theStringValue") { - writer.uint32(90).string(message.oneOfValue.theStringValue); - } - if (message.oneOfValue?.$case === "theIntValue") { - writer.uint32(96).int32(message.oneOfValue.theIntValue); + switch (message.oneOfValue?.$case) { + case "theStringValue": + writer.uint32(90).string(message.oneOfValue.theStringValue); + break; + case "theIntValue": + writer.uint32(96).int32(message.oneOfValue.theIntValue); + break; } return writer; }, diff --git a/src/main.ts b/src/main.ts index 2bcabc23b..235044a37 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1059,6 +1059,49 @@ function generateDecode(ctx: Context, fullName: string, messageDesc: DescriptorP return joinCode(chunks, { on: "\n" }); } +/** Returns a generic writer.doSomething based on the basic type */ +function getEncodeWriteSnippet(ctx: Context, field: FieldDescriptorProto): (place: string) => Code { + const { options, utils } = ctx; + if (isEnum(field) && options.stringEnums) { + const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; + const toNumber = getEnumMethod(ctx, field.typeName, "ToNumber"); + return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${toNumber}(${place}))`; + } else if (isLong(field) && options.forceLong === LongOption.BIGINT) { + const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; + return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place}.toString())`; + } else if (isScalar(field) || isEnum(field)) { + const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; + return (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place})`; + } else if (isObjectId(field) && options.useMongoObjectId) { + const tag = ((field.number << 3) | 2) >>> 0; + const type = basicTypeName(ctx, field, { keepValueType: true }); + return (place) => code`${type}.encode(${utils.toProtoObjectId}(${place}), writer.uint32(${tag}).fork()).ldelim()`; + } else if (isTimestamp(field) && (options.useDate === DateOption.DATE || options.useDate === DateOption.STRING)) { + const tag = ((field.number << 3) | 2) >>> 0; + const type = basicTypeName(ctx, field, { keepValueType: true }); + return (place) => code`${type}.encode(${utils.toTimestamp}(${place}), writer.uint32(${tag}).fork()).ldelim()`; + } else if (isValueType(ctx, field)) { + const maybeTypeField = options.outputTypeRegistry ? `$type: '${field.typeName.slice(1)}',` : ""; + + const type = basicTypeName(ctx, field, { keepValueType: true }); + const wrappedValue = (place: string): Code => { + if (isAnyValueType(field) || isListValueType(field) || isStructType(field) || isFieldMaskType(field)) { + return code`${type}.wrap(${place})`; + } + return code`{${maybeTypeField} value: ${place}!}`; + }; + + const tag = ((field.number << 3) | 2) >>> 0; + return (place) => code`${type}.encode(${wrappedValue(place)}, writer.uint32(${tag}).fork()).ldelim()`; + } else if (isMessage(field)) { + const tag = ((field.number << 3) | 2) >>> 0; + const type = basicTypeName(ctx, field); + return (place) => code`${type}.encode(${place}, writer.uint32(${tag}).fork()).ldelim()`; + } else { + throw new Error(`Unhandled field ${field}`); + } +} + /** Creates a function to encode a message by loop overing the tags. */ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorProto): Code { const { options, utils, typeMap } = ctx; @@ -1074,52 +1117,20 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP ): ${Writer} { `); + const processedOneofs = new Set(); + const oneOfFieldsDict = messageDesc.field + .filter((field) => isWithinOneOfThatShouldBeUnion(options, field)) + .reduce<{ [key: number]: FieldDescriptorProto[] }>( + (result, field) => ((result[field.oneofIndex] || (result[field.oneofIndex] = [])).push(field), result), + {} + ); + // then add a case for each field messageDesc.field.forEach((field) => { const fieldName = maybeSnakeToCamel(field.name, options); // get a generic writer.doSomething based on the basic type - let writeSnippet: (place: string) => Code; - if (isEnum(field) && options.stringEnums) { - const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; - const toNumber = getEnumMethod(ctx, field.typeName, "ToNumber"); - writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${toNumber}(${place}))`; - } else if (isLong(field) && options.forceLong === LongOption.BIGINT) { - const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; - writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place}.toString())`; - } else if (isScalar(field) || isEnum(field)) { - const tag = ((field.number << 3) | basicWireType(field.type)) >>> 0; - writeSnippet = (place) => code`writer.uint32(${tag}).${toReaderCall(field)}(${place})`; - } else if (isObjectId(field) && options.useMongoObjectId) { - const tag = ((field.number << 3) | 2) >>> 0; - const type = basicTypeName(ctx, field, { keepValueType: true }); - writeSnippet = (place) => - code`${type}.encode(${utils.toProtoObjectId}(${place}), writer.uint32(${tag}).fork()).ldelim()`; - } else if (isTimestamp(field) && (options.useDate === DateOption.DATE || options.useDate === DateOption.STRING)) { - const tag = ((field.number << 3) | 2) >>> 0; - const type = basicTypeName(ctx, field, { keepValueType: true }); - writeSnippet = (place) => - code`${type}.encode(${utils.toTimestamp}(${place}), writer.uint32(${tag}).fork()).ldelim()`; - } else if (isValueType(ctx, field)) { - const maybeTypeField = options.outputTypeRegistry ? `$type: '${field.typeName.slice(1)}',` : ""; - - const type = basicTypeName(ctx, field, { keepValueType: true }); - const wrappedValue = (place: string): Code => { - if (isAnyValueType(field) || isListValueType(field) || isStructType(field) || isFieldMaskType(field)) { - return code`${type}.wrap(${place})`; - } - return code`{${maybeTypeField} value: ${place}!}`; - }; - - const tag = ((field.number << 3) | 2) >>> 0; - writeSnippet = (place) => code`${type}.encode(${wrappedValue(place)}, writer.uint32(${tag}).fork()).ldelim()`; - } else if (isMessage(field)) { - const tag = ((field.number << 3) | 2) >>> 0; - const type = basicTypeName(ctx, field); - writeSnippet = (place) => code`${type}.encode(${place}, writer.uint32(${tag}).fork()).ldelim()`; - } else { - throw new Error(`Unhandled field ${field}`); - } + const writeSnippet = getEncodeWriteSnippet(ctx, field); const isOptional = isOptionalProperty(field, messageDesc.options, options); if (isRepeated(field)) { @@ -1208,12 +1219,20 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP } } } else if (isWithinOneOfThatShouldBeUnion(options, field)) { - let oneofName = maybeSnakeToCamel(messageDesc.oneofDecl[field.oneofIndex].name, options); - chunks.push(code` - if (message.${oneofName}?.$case === '${fieldName}') { - ${writeSnippet(`message.${oneofName}.${fieldName}`)}; + if (!processedOneofs.has(field.oneofIndex)) { + processedOneofs.add(field.oneofIndex); + + const oneofName = maybeSnakeToCamel(messageDesc.oneofDecl[field.oneofIndex].name, options); + chunks.push(code`switch (message.${oneofName}?.$case) {`); + for (const oneOfField of oneOfFieldsDict[field.oneofIndex]) { + const writeSnippet = getEncodeWriteSnippet(ctx, oneOfField); + const oneOfFieldName = maybeSnakeToCamel(oneOfField.name, ctx.options); + chunks.push(code`case "${oneOfFieldName}": + ${writeSnippet(`message.${oneofName}.${oneOfFieldName}`)}; + break;`); } - `); + chunks.push(code`}`); + } } else if (isWithinOneOf(field)) { // Oneofs don't have a default value check b/c they need to denote which-oneof presence chunks.push(code`