diff --git a/hooks.go b/hooks.go index 334043a..687ff00 100644 --- a/hooks.go +++ b/hooks.go @@ -245,7 +245,7 @@ func getMapUpdates(destType reflect.Type, srcData reflect.Value, hook FieldHookF updates := make(map[interface{}]interface{}) var errors []error - decodableFields := getDecodableStructFields(destType) + decodableFields := getDecodableStructFields(destType, tagName) for _, structField := range decodableFields { // This field resolution logic is adapted from mapstructure's own // logic. @@ -309,7 +309,7 @@ func getMapUpdates(destType reflect.Type, srcData reflect.Value, hook FieldHookF return updates, nil } -func getDecodableStructFields(structType reflect.Type) []reflect.StructField { +func getDecodableStructFields(structType reflect.Type, tagName string) []reflect.StructField { fields := make([]reflect.StructField, 0) for i := 0; i < structType.NumField(); i++ { structField := structType.Field(i) @@ -318,7 +318,26 @@ func getDecodableStructFields(structType reflect.Type) []reflect.StructField { // into it. continue } - // TODO account for embedded struct fields + + // TODO pull in the squash info from the decoder options + squash := structField.Anonymous + tagParts := strings.Split(structField.Tag.Get(tagName), ",") + for _, tag := range tagParts[1:] { + switch tag { + case "squash": + squash = true + case "nosquash": // explicit opt-out + squash = false + default: + continue + } + break + } + + if squash && structField.Type.Kind() == reflect.Struct { + fields = append(fields, getDecodableStructFields(structField.Type, tagName)...) + continue + } fields = append(fields, structField) } return fields diff --git a/hooks_test.go b/hooks_test.go index 5676e99..b03e23f 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -65,6 +65,45 @@ func TestMultipleFieldHooks(t *testing.T) { assert.Equal(t, 42, dest.Int) } +func TestEmbeddedFieldHooks(t *testing.T) { + type subDest struct { + Int int + } + var dest struct { + subDest + } + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + hook1 := newMockFieldHook(mockCtrl) + hook2 := newMockFieldHook(mockCtrl) + + typeOfInt := reflect.TypeOf(42) + + hook1. + Expect(_typeOfEmptyInterface, structField{ + Name: "Int", + Type: typeOfInt, + }, reflectEq{"FOO"}). + Return(valueOf("BAR"), nil) + + hook2. + Expect(reflect.TypeOf(""), structField{ + Name: "Int", + Type: typeOfInt, + }, reflectEq{"BAR"}). + Return(valueOf(42), nil) + + err := Decode(&dest, map[string]interface{}{"int": "FOO"}, + FieldHook(hook1.Hook()), + FieldHook(hook2.Hook()), + ) + require.NoError(t, err) + + assert.Equal(t, 42, dest.Int) +} + func TestMultipleDecodeHooks(t *testing.T) { type myStruct struct{ String string }