Skip to content

Commit

Permalink
fix: encode (u)int(16|8)s as varints
Browse files Browse the repository at this point in the history
This commit changes encoding of smaller ints to varint scheme. It should make our enum types compatible with proto enum types.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
  • Loading branch information
DmitriyMV committed Sep 8, 2022
1 parent d8ddbd5 commit 82f0774
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 42 deletions.
12 changes: 6 additions & 6 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) {
putBool(m, val.Bool())

case reflect.Int8, reflect.Int16:
putTag(m, num, protowire.Fixed32Type)
putInt32(m, int32(val.Int()))
putTag(m, num, protowire.VarintType)
putUVarint(m, val.Int())

case reflect.Uint8, reflect.Uint16:
putTag(m, num, protowire.Fixed32Type)
putInt32(m, int32(val.Uint()))
putTag(m, num, protowire.VarintType)
putUVarint(m, val.Uint())

case reflect.Int, reflect.Int32, reflect.Int64:
putTag(m, num, protowire.VarintType)
Expand Down Expand Up @@ -399,12 +399,12 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) {
switch elem.Kind() { //nolint:exhaustive
case reflect.Int8, reflect.Int16:
for i := 0; i < sliceLen; i++ {
putInt32(&result, int32(val.Index(i).Int()))
putUVarint(&result, val.Index(i).Int())
}

case reflect.Uint8, reflect.Uint16:
for i := 0; i < sliceLen; i++ {
putInt32(&result, uint32(val.Index(i).Uint()))
putUVarint(&result, val.Index(i).Uint())
}

case reflect.Bool:
Expand Down
12 changes: 11 additions & 1 deletion messages/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,22 @@ type msg[T any] interface {
proto.Message
}

func runTestPipe[R any, RP msg[R], T any](t *testing.T, original T) {
encoded1 := must(protoenc.Marshal(&original))(t)
decoded := protoUnmarshal[R, RP](t, encoded1)
encoded2 := must(proto.Marshal(decoded))(t)
result := ourUnmarshal[T](t, encoded2)

shouldBeEqual(t, original, result)
}

func protoUnmarshal[T any, V msg[T]](t *testing.T, data []byte) V {
t.Helper()

var msg T

err := proto.Unmarshal(data, V(&msg))
err := proto.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(data, V(&msg))

require.NoError(t, err)

return &msg
Expand Down
42 changes: 19 additions & 23 deletions messages/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
package messages_test

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/siderolabs/protoenc"
Expand All @@ -18,15 +16,6 @@ import (

// TODO: ensure that binary output is also the same

func runTestPipe[R any, RP msg[R], T any](t *testing.T, original T) {
encoded1 := must(protoenc.Marshal(&original))(t)
decoded := protoUnmarshal[R, RP](t, encoded1)
encoded2 := must(proto.Marshal(decoded))(t)
result := ourUnmarshal[T](t, encoded2)

shouldBeEqual(t, original, result)
}

//nolint:govet
type BasicMessage struct {
Int64 int64 `protobuf:"1"`
Expand Down Expand Up @@ -320,29 +309,36 @@ func TestEmptyMessage(t *testing.T) {
})
}

func TestEnumMessage(t *testing.T) {
// This test ensures that we can decode a message with an enum field.
// Even tho we use fixed 32-bit values for encoding enums (unlike protobuf) decoding into int8-16s should still work.
func TestEnumMessage_CompatibleOldScheme(t *testing.T) {
// This test ensures that we can decode a message with an enum field encoded by previus version of our encoder.
t.Parallel()

encoded := []byte{0x0d, 0x01, 0x00, 0x00, 0x00}

type Enum int8

type EnumMessage struct {
EnumField Enum `protobuf:"1"`
}

original := messages.EnumMessage{
EnumField: messages.Enum_ENUM2,
}
dest := EnumMessage{}

encoded, err := proto.Marshal(&original)
err := protoenc.Unmarshal(encoded, &dest)
require.NoError(t, err)

t.Log("\n", hex.Dump(encoded))
require.EqualValues(t, dest.EnumField, 1)
}

decoded := EnumMessage{}
err = protoenc.Unmarshal(encoded, &decoded)
require.NoError(t, err)
func TestEnumMessage(t *testing.T) {
t.Parallel()

type Enum int8

type EnumMessage struct {
EnumField Enum `protobuf:"1"`
}

require.EqualValues(t, original.EnumField, decoded.EnumField)
runTestPipe[messages.EnumMessage](t, EnumMessage{
EnumField: 1,
})
}
2 changes: 1 addition & 1 deletion scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func (s *dataScanner) Wiretype() protowire.Type {
func getDataScannerFor(eltype reflect.Type, buf []byte) (dataScanner, bool, error) {
switch eltype.Kind() { //nolint:exhaustive
case reflect.Uint8, reflect.Uint16, reflect.Int8, reflect.Int16:
return makeDataScanner(protowire.Fixed32Type, buf), true, nil
return makeDataScanner(protowire.VarintType, buf), true, nil

case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int,
reflect.Uint32, reflect.Uint64, reflect.Uint:
Expand Down
23 changes: 12 additions & 11 deletions slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ func testSliceEncodingResult[T any](slc []T, expected []byte) func(t *testing.T)
func TestSmallIntegers(t *testing.T) {
t.Parallel()

encodedBytes := hexToBytes(t, "0a 03 01 FF 03")
encodedFixed := hexToBytes(t, "0a 0c [01 00 00 00] [ff 00 00 00] [03 00 00 00]")
encodedFixedNegative := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff ff ff] [03 00 00 00]")
encodedUint16s := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff 00 00] [03 00 00 00]")
encodedBytes := hexToBytes(t, "0A 03 01 FF 03")
encodedFixed := hexToBytes(t, "0A 04 01 FF 01 03")
encodedFixedNegative := hexToBytes(t, "0A 0C [01] [FF FF FF FF FF FF FF FF FF 01] [03]")
encodedUint16s := hexToBytes(t, "0A 05 [01] [FF FF 03] [03]")

type customByte byte

Expand All @@ -215,7 +215,7 @@ func TestSmallIntegers(t *testing.T) {
CustomByte customByte `protobuf:"5"`
}

encodedCustomType := hexToBytes(t, "0a 19 [0d [ff ff ff ff]] [1d [ff ff 00 00]] [15 [ff ff ff ff]] [25 [ff 00 00 00]] [2d [ff 00 00 00]]")
encodedCustomType := hexToBytes(t, "0a 20 [08 [FF FF FF FF FF FF FF FF FF 01] 18 [FF FF 03] 10 [FF FF FF FF FF FF FF FF FF 01] 20 [FF 01] 28 [FF 01]]")

tests := []struct { //nolint:govet
name string
Expand All @@ -226,31 +226,31 @@ func TestSmallIntegers(t *testing.T) {
testEncodeDecodeWrapped([...]byte{1, 0xFF, 3}, encodedBytes),
},
{
"array of custom byte types should be encoded in 'fixed32' form",
"array of custom byte type should be encoded in 'varint' form",
testEncodeDecodeWrapped([...]customByte{1, 0xFF, 3}, encodedFixed),
},
{
"slice of custom byte type should be encoded in 'fixed32' form",
"slice of custom byte type should be encoded in 'varint' form",
testEncodeDecodeWrapped([]customByte{1, 0xFF, 3}, encodedFixed),
},
{
"slice of int8 should be encoded in 'fixed32' form",
"slice of int8 should be encoded in 'varint' form",
testEncodeDecodeWrapped([]int8{1, -1, 3}, encodedFixedNegative),
},
{
"slice of int16 type should be encoded in 'fixed32' form",
"slice of int16 type should be encoded in 'varint' form",
testEncodeDecodeWrapped([]int16{1, -1, 3}, encodedFixedNegative),
},
{
"slice of uint16 type should be encoded in 'fixed32' form",
"slice of uint16 type should be encoded in 'varint' form",
testEncodeDecodeWrapped([]uint16{1, 0xFFFF, 3}, encodedUint16s),
},
{
"customSlice should be encoded in 'bytes' form",
testEncodeDecodeWrapped(customSlice{1, 0xFF, 3}, encodedBytes),
},
{
"customType should be encoded in 'fixed32' form",
"customType should be encoded in 'varint' form",
testEncodeDecodeWrapped(customType{
Int16: -1,
Uint16: 0xFFFF,
Expand All @@ -269,6 +269,7 @@ func TestSmallIntegers(t *testing.T) {

func testEncodeDecodeWrapped[T any](slc T, expected []byte) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()
t.Parallel()

original := Value[T]{V: slc}
Expand Down
1 change: 1 addition & 0 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ func unmarshalByteSeqeunce(dst reflect.Value, val complexValue) error {
}

func slice(dst reflect.Value, val complexValue) error {
// TODO: this code doesn't support the case when slice is encoded in several chunks across the message
elemType := dst.Type().Elem()

// we only decode bytes as []byte or [n]byte field
Expand Down

0 comments on commit 82f0774

Please sign in to comment.