Skip to content
125 changes: 103 additions & 22 deletions internal/protogen/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func newBookExporter(protoPackage string, protoFileOptions map[string]string, ou
}

func (x *bookExporter) GetProtoFilePath() string {
return genProtoFilePath(x.wb.Name, x.FilenameSuffix)
return genProtoFilePath(x.wb.GetName(), x.FilenameSuffix)
}

func (x *bookExporter) export(checkProtoFileConflicts bool) error {
Expand Down Expand Up @@ -203,11 +203,17 @@ func (x *sheetExporter) exportStruct() error {
opts := &tableaupb.StructOptions{Name: x.ws.GetOptions().GetName(), Note: x.ws.Note}
x.p.P(" option (tableau.struct) = {", marshalToText(opts), "};")
x.p.P("")

oldMD := x.findMDFromGeneratedProtos(x.ws.Name)
x.assignFieldNumbers(x.ws.Fields, oldMD)
// generate the fields
depth := 1
for i, field := range x.ws.Fields {
tagid := i + 1
if err := x.exportField(depth, tagid, field, x.ws.Name); err != nil {
for _, field := range x.ws.Fields {
var fd protoreflect.FieldDescriptor
if oldMD != nil {
fd = oldMD.Fields().ByNumber(protoreflect.FieldNumber(field.GetNumber()))
}
if err := x.exportField(depth, field, x.ws.Name, fd); err != nil {
return err
}
}
Expand Down Expand Up @@ -273,16 +279,16 @@ func (x *sheetExporter) exportUnion() error {
x.p.P(" message ", typ, " {")
// generate the fields
depth := 2
tagid := 1
fieldNumber := int32(1)
for _, field := range msgField.Fields {
if err := x.exportField(depth, tagid, field, msgField.Name); err != nil {
field.Number = fieldNumber
cross := max(field.GetOptions().GetProp().GetCross(), 1)
fieldNumber += cross
}
for _, field := range msgField.Fields {
if err := x.exportField(depth, field, msgField.Name, nil); err != nil {
return err
}
cross := int(field.GetOptions().GetProp().GetCross())
if cross < 1 {
cross = 1
}
tagid += cross
}
x.p.P(" }")
}
Expand All @@ -294,6 +300,58 @@ func (x *sheetExporter) exportUnion() error {
return nil
}

// findMDFromGeneratedProtos finds the MessageDescriptor in the generated proto
// files by message name. It returns nil if not found.
//
// NOTE: Even if the message is moved to another proto file, we still can find it
// in the generated proto files.
func (x *sheetExporter) findMDFromGeneratedProtos(name string) protoreflect.MessageDescriptor {
if !x.be.gen.OutputOpt.PreserveFieldNumbers {
return nil
}
fullName := protoreflect.FullName(x.be.ProtoPackage).Append(protoreflect.Name(name))
descriptor, err := x.be.gen.ProtoRegistryFiles.FindDescriptorByName(fullName)
if err != nil {
return nil
}
return descriptor.(protoreflect.MessageDescriptor)
}

// assignFieldNumbers assigns the field numbers to the fields. It uses the old
// MD to preserve field numbers if provided, otherwise it assigns field numbers
// in sequence starting from 1.
func (*sheetExporter) assignFieldNumbers(fields []*internalpb.Field, oldMD protoreflect.MessageDescriptor) {
if oldMD == nil {
fieldNumber := int32(1)
for _, field := range fields {
field.Number = fieldNumber
fieldNumber++
}
return
}
fieldNameNumberMap := make(map[string]int32)
var maxFieldNumber int32
for i := 0; i < oldMD.Fields().Len(); i++ {
fd := oldMD.Fields().Get(i)
for _, field := range fields {
if string(fd.Name()) == field.Name {
fieldNameNumberMap[field.Name] = int32(fd.Number())
}
}
maxFieldNumber = max(maxFieldNumber, int32(fd.Number()))
}
for _, field := range fields {
if number, ok := fieldNameNumberMap[field.Name]; ok {
// for existing field, use the old field number.
field.Number = number
} else {
// for new field, assign the max field number plus 1 in the same level.
maxFieldNumber++
field.Number = maxFieldNumber
}
}
}

func (x *sheetExporter) exportMessager() error {
// log.Debugf("workbook: %s", x.ws.String())
if x.be.messagerPatternRegexp != nil && !x.be.messagerPatternRegexp.MatchString(x.ws.Name) {
Expand All @@ -302,11 +360,17 @@ func (x *sheetExporter) exportMessager() error {
x.p.P("message ", x.ws.Name, " {")
x.p.P(" option (tableau.worksheet) = {", marshalToText(x.ws.Options), "};")
x.p.P("")

md := x.findMDFromGeneratedProtos(x.ws.Name)
x.assignFieldNumbers(x.ws.Fields, md)
// generate the fields
depth := 1
for i, field := range x.ws.Fields {
tagid := i + 1
if err := x.exportField(depth, tagid, field, x.ws.Name); err != nil {
for _, field := range x.ws.Fields {
var fd protoreflect.FieldDescriptor
if md != nil {
fd = md.Fields().ByNumber(protoreflect.FieldNumber(field.GetNumber()))
}
if err := x.exportField(depth, field, x.ws.Name, fd); err != nil {
return err
}
}
Expand All @@ -317,7 +381,7 @@ func (x *sheetExporter) exportMessager() error {
return nil
}

func (x *sheetExporter) exportField(depth int, tagid int, field *internalpb.Field, prefix string) error {
func (x *sheetExporter) exportField(depth int, field *internalpb.Field, prefix string, oldFD protoreflect.FieldDescriptor) error {
label := ""
if x.ws.GetOptions().GetFieldPresence() &&
types.IsScalarType(field.FullType) &&
Expand All @@ -328,17 +392,29 @@ func (x *sheetExporter) exportField(depth int, tagid int, field *internalpb.Fiel
if field.Note != "" {
note = " // " + field.Note
}
x.p.P(printer.Indent(depth), label, field.FullType, " ", field.Name, " = ", tagid, " ", genFieldOptionsString(field.Options), ";", note)
x.p.P(printer.Indent(depth), label, field.FullType, " ", field.Name, " = ", field.Number, " ", genFieldOptionsString(field.Options), ";", note)

var oldMD protoreflect.MessageDescriptor
typeName := field.Type
fullTypeName := field.FullType
if field.ListEntry != nil {
typeName = field.ListEntry.ElemType
fullTypeName = field.ListEntry.ElemFullType
}
if field.MapEntry != nil {
if oldFD != nil {
oldMD = oldFD.Message()
}
} else if field.MapEntry != nil {
typeName = field.MapEntry.ValueType
fullTypeName = field.MapEntry.ValueFullType
if oldFD != nil {
if v := oldFD.MapValue(); v != nil {
oldMD = v.Message()
}
}
} else {
if oldFD != nil {
oldMD = oldFD.Message()
}
}

if types.IsWellKnownMessage(fullTypeName) {
Expand Down Expand Up @@ -372,14 +448,19 @@ func (x *sheetExporter) exportField(depth int, tagid int, field *internalpb.Fiel
default:
return nil
}
// bookkeeping this nested msessage, so we can check if we can reuse it later.
// bookkeeping this nested message, so we can check if we can reuse it later.
x.nestedMessages[nestedMsgName] = field

// x.g.P("")
x.p.P(printer.Indent(depth), "message ", typeName, " {")
for i, f := range field.Fields {
tagid := i + 1
if err := x.exportField(depth+1, tagid, f, nestedMsgName); err != nil {

x.assignFieldNumbers(field.Fields, oldMD)
for _, subField := range field.Fields {
var nestedOldFD protoreflect.FieldDescriptor
if oldMD != nil {
nestedOldFD = oldMD.Fields().ByNumber(protoreflect.FieldNumber(subField.GetNumber()))
}
if err := x.exportField(depth+1, subField, nestedMsgName, nestedOldFD); err != nil {
return err
}
}
Expand Down
Loading
Loading