diff --git a/thrift/decode.go b/thrift/decode.go index 14f6528..9343ed2 100644 --- a/thrift/decode.go +++ b/thrift/decode.go @@ -393,6 +393,12 @@ func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error { lastField = x if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) { + for x.Kind() == reflect.Ptr { + if x.IsNil() { + x.Set(reflect.New(x.Type().Elem())) + } + x = x.Elem() + } x.SetBool(f.Type == TRUE) return nil } diff --git a/thrift/encode.go b/thrift/encode.go index 6faa334..bd2c3a9 100644 --- a/thrift/encode.go +++ b/thrift/encode.go @@ -243,6 +243,21 @@ type structEncoder struct { union bool } +func dereference(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + return v + } + v = v.Elem() + } + return v +} + +func isTrue(v reflect.Value) bool { + v = dereference(v) + return v.IsValid() && v.Kind() == reflect.Bool && v.Bool() +} + func (enc *structEncoder) encode(w Writer, v reflect.Value, flags flags) error { useDeltaEncoding := flags.have(useDeltaEncoding) coalesceBoolFields := flags.have(coalesceBoolFields) @@ -280,7 +295,7 @@ encodeFields: } skipValue := coalesceBoolFields && field.Type == BOOL - if skipValue && x.Bool() == true { + if skipValue && isTrue(x) == true { field.Type = TRUE } @@ -376,6 +391,8 @@ func encodeFuncPtrOf(t reflect.Type, seen encodeFuncCache) encodeFunc { return func(w Writer, v reflect.Value, f flags) error { if v.IsNil() { v = zero + } else { + v = v.Elem() } return enc(w, v, f) } diff --git a/thrift/struct.go b/thrift/struct.go index 4bc61eb..aa556f3 100644 --- a/thrift/struct.go +++ b/thrift/struct.go @@ -60,9 +60,17 @@ func forEachStructField(t reflect.Type, index []int, do func(structField)) { fieldIndex := append(index, i) fieldIndex = fieldIndex[:len(fieldIndex):len(fieldIndex)] - if f.Anonymous && f.Type.Kind() == reflect.Struct { - forEachStructField(f.Type, fieldIndex, do) - continue + if f.Anonymous { + fieldType := f.Type + + for fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + if fieldType.Kind() == reflect.Struct { + forEachStructField(fieldType, fieldIndex, do) + continue + } } tag := f.Tag.Get("thrift") diff --git a/thrift/thrift_test.go b/thrift/thrift_test.go index e8f46e8..04c8efd 100644 --- a/thrift/thrift_test.go +++ b/thrift/thrift_test.go @@ -148,7 +148,7 @@ var marshalTestValues = [...]struct { RecursiveStruct{}, RecursiveStruct{Value: "hello"}, RecursiveStruct{Value: "hello", Next: &RecursiveStruct{}}, - RecursiveStruct{Value: "hello", Next: &RecursiveStruct{Value: "world"}}, + RecursiveStruct{Value: "hello", Next: &RecursiveStruct{Value: "world", Test: newBool(true)}}, }, }, @@ -161,6 +161,26 @@ var marshalTestValues = [...]struct { }, }, + { + scenario: "StructWithPointToPointerToBool", + values: []interface{}{ + StructWithPointerToPointerToBool{ + Test: newBoolPtr(true), + }, + }, + }, + + { + scenario: "StructWithEmbeddedStrutPointerWithPointerToPointer", + values: []interface{}{ + StructWithEmbeddedStrutPointerWithPointerToPointer{ + StructWithPointerToPointerToBool: &StructWithPointerToPointerToBool{ + Test: newBoolPtr(true), + }, + }, + }, + }, + { scenario: "Union", values: []interface{}{ @@ -180,12 +200,21 @@ type Point2D struct { type RecursiveStruct struct { Value string `thrift:"1"` Next *RecursiveStruct `thrift:"2"` + Test *bool `thrift:"3"` } type StructWithEnum struct { Enum int8 `thrift:"1,enum"` } +type StructWithPointerToPointerToBool struct { + Test **bool `thrift:"1"` +} + +type StructWithEmbeddedStrutPointerWithPointerToPointer struct { + *StructWithPointerToPointerToBool +} + type Union struct { A bool `thrift:"1"` B int `thrift:"2"` @@ -197,6 +226,11 @@ func newBool(b bool) *bool { return &b } func newInt(i int) *int { return &i } func newString(s string) *string { return &s } +func newBoolPtr(b bool) **bool { + p := newBool(b) + return &p +} + func TestMarshalUnmarshal(t *testing.T) { for _, p := range protocols { t.Run(p.name, func(t *testing.T) { testMarshalUnmarshal(t, p.proto) })