Skip to content

Commit

Permalink
fix thrift encoding of structs with pointer fields (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
Achille authored Nov 20, 2021
1 parent 37e513f commit 000439e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
6 changes: 6 additions & 0 deletions thrift/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error {
lastField = x

if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) {
for x.Kind() == reflect.Ptr {
if x.IsNil() {
x.Set(reflect.New(x.Type().Elem()))
}
x = x.Elem()
}
x.SetBool(f.Type == TRUE)
return nil
}
Expand Down
19 changes: 18 additions & 1 deletion thrift/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,21 @@ type structEncoder struct {
union bool
}

func dereference(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
return v
}
v = v.Elem()
}
return v
}

func isTrue(v reflect.Value) bool {
v = dereference(v)
return v.IsValid() && v.Kind() == reflect.Bool && v.Bool()
}

func (enc *structEncoder) encode(w Writer, v reflect.Value, flags flags) error {
useDeltaEncoding := flags.have(useDeltaEncoding)
coalesceBoolFields := flags.have(coalesceBoolFields)
Expand Down Expand Up @@ -280,7 +295,7 @@ encodeFields:
}

skipValue := coalesceBoolFields && field.Type == BOOL
if skipValue && x.Bool() == true {
if skipValue && isTrue(x) == true {
field.Type = TRUE
}

Expand Down Expand Up @@ -376,6 +391,8 @@ func encodeFuncPtrOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
return func(w Writer, v reflect.Value, f flags) error {
if v.IsNil() {
v = zero
} else {
v = v.Elem()
}
return enc(w, v, f)
}
Expand Down
14 changes: 11 additions & 3 deletions thrift/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,17 @@ func forEachStructField(t reflect.Type, index []int, do func(structField)) {
fieldIndex := append(index, i)
fieldIndex = fieldIndex[:len(fieldIndex):len(fieldIndex)]

if f.Anonymous && f.Type.Kind() == reflect.Struct {
forEachStructField(f.Type, fieldIndex, do)
continue
if f.Anonymous {
fieldType := f.Type

for fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}

if fieldType.Kind() == reflect.Struct {
forEachStructField(fieldType, fieldIndex, do)
continue
}
}

tag := f.Tag.Get("thrift")
Expand Down
36 changes: 35 additions & 1 deletion thrift/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ var marshalTestValues = [...]struct {
RecursiveStruct{},
RecursiveStruct{Value: "hello"},
RecursiveStruct{Value: "hello", Next: &RecursiveStruct{}},
RecursiveStruct{Value: "hello", Next: &RecursiveStruct{Value: "world"}},
RecursiveStruct{Value: "hello", Next: &RecursiveStruct{Value: "world", Test: newBool(true)}},
},
},

Expand All @@ -161,6 +161,26 @@ var marshalTestValues = [...]struct {
},
},

{
scenario: "StructWithPointToPointerToBool",
values: []interface{}{
StructWithPointerToPointerToBool{
Test: newBoolPtr(true),
},
},
},

{
scenario: "StructWithEmbeddedStrutPointerWithPointerToPointer",
values: []interface{}{
StructWithEmbeddedStrutPointerWithPointerToPointer{
StructWithPointerToPointerToBool: &StructWithPointerToPointerToBool{
Test: newBoolPtr(true),
},
},
},
},

{
scenario: "Union",
values: []interface{}{
Expand All @@ -180,12 +200,21 @@ type Point2D struct {
type RecursiveStruct struct {
Value string `thrift:"1"`
Next *RecursiveStruct `thrift:"2"`
Test *bool `thrift:"3"`
}

type StructWithEnum struct {
Enum int8 `thrift:"1,enum"`
}

type StructWithPointerToPointerToBool struct {
Test **bool `thrift:"1"`
}

type StructWithEmbeddedStrutPointerWithPointerToPointer struct {
*StructWithPointerToPointerToBool
}

type Union struct {
A bool `thrift:"1"`
B int `thrift:"2"`
Expand All @@ -197,6 +226,11 @@ func newBool(b bool) *bool { return &b }
func newInt(i int) *int { return &i }
func newString(s string) *string { return &s }

func newBoolPtr(b bool) **bool {
p := newBool(b)
return &p
}

func TestMarshalUnmarshal(t *testing.T) {
for _, p := range protocols {
t.Run(p.name, func(t *testing.T) { testMarshalUnmarshal(t, p.proto) })
Expand Down

0 comments on commit 000439e

Please sign in to comment.