Skip to content

Commit

Permalink
Merge pull request #76 from planetscale/empty_bytes
Browse files Browse the repository at this point in the history
Fix handling of optional bytes in proto3
  • Loading branch information
vmg committed Jan 27, 2023
2 parents c793337 + d192986 commit 96ede25
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 52 deletions.
19 changes: 5 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,22 @@ jobs:
steps:
- uses: actions/setup-go@v2
with:
go-version: '^1.18'
go-version: '^1.19'

- uses: actions/checkout@v2

- name: Cache protobuf build
id: protocache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: _vendor/protobuf-3.20.0
key: protobuf-3.20.0
path: _vendor/protobuf-21.12
key: protobuf-21.12

- name: Compile protobuf
if: steps.protocache.outputs.cache-hit != 'true'
run: |
sudo apt-get install -y autoconf automake libtool curl make g++ unzip
mkdir -p _vendor
curl -#fsSL https://github.com/protocolbuffers/protobuf/releases/download/v3.20.0/protobuf-all-3.20.0.tar.gz | tar -xzvf - -C _vendor
cd _vendor/protobuf-3.20.0
./configure
make
make -C conformance
ls -l src/
ls -l conformance/
./protobuf.sh
- run: make install && go mod tidy && go mod verify
- run: git --no-pager diff --exit-code
Expand Down
19 changes: 9 additions & 10 deletions features/marshal/marshalto.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ type marshal struct {
var _ generator.FeatureGenerator = (*marshal)(nil)

func (p *marshal) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
p.message(message)
}
return p.once
}
Expand Down Expand Up @@ -130,9 +129,9 @@ func (p *marshal) mapField(kvField *protogen.Field, varName string) {
}
}

func (p *marshal) field(proto3, oneof bool, numGen *counter, field *protogen.Field) {
func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) {
fieldname := field.GoName
nullable := field.Message != nil || (field.Oneof != nil && field.Oneof.Desc.IsSynthetic()) || (!proto3 && !oneof)
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
Expand Down Expand Up @@ -433,7 +432,7 @@ func (p *marshal) field(proto3, oneof bool, numGen *counter, field *protogen.Fie
p.encodeVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.P(`}`)
} else if !oneof && proto3 {
} else if !oneof && !field.Desc.HasPresence() {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
Expand Down Expand Up @@ -569,9 +568,9 @@ func (p *marshal) methodMarshal() string {
}
}

func (p *marshal) message(proto3 bool, message *protogen.Message) {
func (p *marshal) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
p.message(nested)
}

if message.Desc.IsMapEntry() {
Expand Down Expand Up @@ -631,7 +630,7 @@ func (p *marshal) message(proto3 bool, message *protogen.Message) {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(proto3, false, &numGen, field)
p.field(false, &numGen, field)
} else {
p.P(`if msg, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent.GoName, `); ok {`)
marshalForwardOneOf("msg")
Expand Down Expand Up @@ -663,7 +662,7 @@ func (p *marshal) message(proto3 bool, message *protogen.Message) {
field := message.Fields[i]
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(proto3, false, &numGen, field)
p.field(false, &numGen, field)
}
}
}
Expand All @@ -685,7 +684,7 @@ func (p *marshal) message(proto3 bool, message *protogen.Message) {
p.P(``)
p.P(`func (m *`, ccTypeName, `) `, p.methodMarshalToSizedBuffer(), `(dAtA []byte) (int, error) {`)
p.P(`i := len(dAtA)`)
p.field(proto3, true, &numGen, field)
p.field(true, &numGen, field)
p.P(`return len(dAtA) - i, nil`)
p.P(`}`)
}
Expand Down
17 changes: 8 additions & 9 deletions features/size/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ func (p *size) Name() string {
}

func (p *size) GenerateFile(file *protogen.File) bool {
proto3 := file.Desc.Syntax() == protoreflect.Proto3
for _, message := range file.Messages {
p.message(proto3, message)
p.message(message)
}

return p.once
Expand Down Expand Up @@ -72,9 +71,9 @@ func (p *size) messageSize(varName, sizeName string, message *protogen.Message)
}
}

func (p *size) field(proto3, oneof bool, field *protogen.Field, sizeName string) {
func (p *size) field(oneof bool, field *protogen.Field, sizeName string) {
fieldname := field.GoName
nullable := field.Message != nil || (field.Oneof != nil && field.Oneof.Desc.IsSynthetic()) || (!proto3 && !oneof)
nullable := field.Message != nil || (!oneof && field.Desc.HasPresence())
repeated := field.Desc.Cardinality() == protoreflect.Repeated
if repeated {
p.P(`if len(m.`, fieldname, `) > 0 {`)
Expand Down Expand Up @@ -239,7 +238,7 @@ func (p *size) field(proto3, oneof bool, field *protogen.Field, sizeName string)
p.P(`l = len(b)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov(uint64(l))`)
p.P(`}`)
} else if !oneof && proto3 {
} else if !oneof && !field.Desc.HasPresence() {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov(uint64(l))`)
Expand Down Expand Up @@ -276,9 +275,9 @@ func (p *size) field(proto3, oneof bool, field *protogen.Field, sizeName string)
}
}

func (p *size) message(proto3 bool, message *protogen.Message) {
func (p *size) message(message *protogen.Message) {
for _, nested := range message.Messages {
p.message(proto3, nested)
p.message(nested)
}

if message.Desc.IsMapEntry() {
Expand All @@ -300,7 +299,7 @@ func (p *size) message(proto3 bool, message *protogen.Message) {
for _, field := range message.Fields {
oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
if !oneof {
p.field(proto3, false, field, sizeName)
p.field(false, field, sizeName)
} else {
fieldname := field.Oneof.GoName
if _, ok := oneofs[fieldname]; !ok {
Expand All @@ -327,7 +326,7 @@ func (p *size) message(proto3 bool, message *protogen.Message) {
p.P(`}`)
p.P(`var l int`)
p.P(`_ = l`)
p.field(proto3, true, field, sizeName)
p.field(true, field, sizeName)
p.P(`return n`)
p.P(`}`)
}
Expand Down
7 changes: 5 additions & 2 deletions protobuf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ if [ -f "$PROTOBUF_PATH/protoc" ]; then
exit 0
fi

mkdir -p _vendor
curl -sS -L -o "$ROOT/_vendor/pb.tar.gz" http://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-all-${PROTOBUF_VERSION}.tar.gz

cd "$ROOT/_vendor"
tar zxf pb.tar.gz

cd protobuf-${PROTOBUF_VERSION}
./configure --quiet
./configure
make
cd conformance/ && make
make -C conformance

echo "Dowloaded and compiled protobuf $PROTOBUF_VERSION to $PROTOBUF_PATH"
ls -l src/
ls -l conformance/
28 changes: 11 additions & 17 deletions testproto/proto3opt/opt_vtproto.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions testproto/proto3opt/proto3opt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package proto3opt

import (
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)

func TestEmptyBytesMarshalling(t *testing.T) {
a := &OptionalFieldInProto3{
OptionalBytes: nil,
}
b := &OptionalFieldInProto3{
OptionalBytes: []byte{},
}

type Message interface {
proto.Message
MarshalVT() ([]byte, error)
}

for _, msg := range []Message{a, b} {
vt, err := msg.MarshalVT()
require.NoError(t, err)
goog, err := proto.Marshal(msg)
require.NoError(t, err)
require.Equal(t, goog, vt)
}
}

0 comments on commit 96ede25

Please sign in to comment.