Skip to content
This repository has been archived by the owner on Oct 2, 2020. It is now read-only.

Commit

Permalink
Merge 7ef1471 into 003e700
Browse files Browse the repository at this point in the history
  • Loading branch information
willhug committed May 19, 2017
2 parents 003e700 + 7ef1471 commit c756d00
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 90 deletions.
214 changes: 124 additions & 90 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ func strconvHook(from, to reflect.Type, data reflect.Value) (reflect.Value, erro
// fieldHook applies the user-specified FieldHookFunc to all struct fields.
func fieldHook(opts *options) DecodeHookFunc {
hook := composeFieldHooks(opts.FieldHooks)
return func(from, to reflect.Type, data reflect.Value) (reflect.Value, error) {
if to.Kind() != reflect.Struct || from.Kind() != reflect.Map {
return data, nil
return func(_, destType reflect.Type, srcData reflect.Value) (reflect.Value, error) {
if destType.Kind() != reflect.Struct || srcData.Type().Kind() != reflect.Map {
return srcData, nil
}

// We can only decode map[string]* and map[interface{}]* into structs.
if k := from.Key().Kind(); k != reflect.String && k != reflect.Interface {
return data, nil
if k := srcData.Type().Key().Kind(); k != reflect.String && k != reflect.Interface {
return srcData, nil
}

// This map tracks type-changing updates to items in the map.
Expand All @@ -227,108 +227,142 @@ func fieldHook(opts *options) DecodeHookFunc {
// values in-place if a hook changed the type of a value. So we will
// make a copy of the source map with a more liberal type and inject
// these updates into the copy.
updates := make(map[interface{}]interface{})

var errors []error
for i := 0; i < to.NumField(); i++ {
structField := to.Field(i)
if structField.PkgPath != "" && !structField.Anonymous {
// This field is not exported so we won't be able to decode
// into it.
continue
}
updates, err := getMapUpdates(destType, srcData, hook, opts.TagName)
if err != nil {
return srcData, err
}

// This field resolution logic is adapted from mapstructure's own
// logic.
//
// See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741
// No more changes to make.
if len(updates) == 0 {
return srcData, nil
}

fieldName := structField.Name
return applyUpdates(updates, srcData), nil
}
}

// Field name override was specified.
tagParts := strings.Split(structField.Tag.Get(opts.TagName), ",")
if tagParts[0] != "" {
fieldName = tagParts[0]
}
func getMapUpdates(destType reflect.Type, srcData reflect.Value, hook FieldHookFunc, tagName string) (map[interface{}]interface{}, error) {
updates := make(map[interface{}]interface{})
var errors []error

// Get the value for this field from the source map, if any.
key := reflect.ValueOf(fieldName)
value := data.MapIndex(key)
if !value.IsValid() {
// Case-insensitive linear search if the name doesn't match
// as-is.
for _, kV := range data.MapKeys() {
// Kind() == Interface if map[interface{}]* so we use
// Interface().(string) to handle interface{} and string
// keys.
k, ok := kV.Interface().(string)
if !ok {
continue
}

if strings.EqualFold(k, fieldName) {
key = kV
value = data.MapIndex(kV)
break
}
}
}
decodableFields := getDecodableStructFields(destType, tagName)
for _, structField := range decodableFields {
// This field resolution logic is adapted from mapstructure's own
// logic.
//
// See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741

if !value.IsValid() {
// No value specified for this field in source map.
continue
}
fieldName := structField.Name

newValue, err := hook(value.Type(), structField, value)
if err != nil {
errors = append(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}
// Field name override was specified.
tagParts := strings.Split(structField.Tag.Get(tagName), ",")
if tagParts[0] != "" {
fieldName = tagParts[0]
}

if newValue == value {
continue
}
// Get the value for this field from the source map, if any.
key := reflect.ValueOf(fieldName)
value := srcData.MapIndex(key)
if !value.IsValid() {
// Case-insensitive linear search if the name doesn't match
// as-is.
for _, kV := range srcData.MapKeys() {
// Kind() == Interface if map[interface{}]* so we use
// Interface().(string) to handle interface{} and string
// keys.
k, ok := kV.Interface().(string)
if !ok {
continue
}

// If we can, assign in-place.
if newValue.Type().AssignableTo(value.Type()) {
// XXX(abg): Is it okay to make updates to the source map?
data.SetMapIndex(key, newValue)
} else {
updates[key.Interface()] = newValue.Interface()
if strings.EqualFold(k, fieldName) {
key = kV
value = srcData.MapIndex(kV)
break
}
}
}

if len(errors) > 0 {
return data, multierr.Combine(errors...)
if !value.IsValid() {
// No value specified for this field in source map.
continue
}

// No more changes to make.
if len(updates) == 0 {
return data, nil
newValue, err := hook(value.Type(), structField, value)
if err != nil {
errors = append(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}

// Equivalent to,
//
// newData := make(map[$key]interface{})
// for k, v := range data {
// if newV, ok := updates[k]; ok {
// newData[k] = newV
// } else {
// newData[k] = v
// }
// }
newData := reflect.MakeMap(reflect.MapOf(from.Key(), _typeOfEmptyInterface))
for _, key := range data.MapKeys() {
var value reflect.Value
if v, ok := updates[key.Interface()]; ok {
value = reflect.ValueOf(v)
} else {
value = data.MapIndex(key)
if newValue == value {
continue
}

updates[key.Interface()] = newValue.Interface()
}

if len(errors) > 0 {
return nil, multierr.Combine(errors...)
}

return updates, nil
}

func getDecodableStructFields(structType reflect.Type, tagName string) []reflect.StructField {
fields := make([]reflect.StructField, 0)
for i := 0; i < structType.NumField(); i++ {
structField := structType.Field(i)
if structField.PkgPath != "" && !structField.Anonymous {
// This field is not exported so we won't be able to decode
// into it.
continue
}

// TODO pull in the squash info from the decoder options
squash := structField.Anonymous
tagParts := strings.Split(structField.Tag.Get(tagName), ",")
for _, tag := range tagParts[1:] {
switch tag {
case "squash":
squash = true
case "nosquash": // explicit opt-out
squash = false
default:
continue
}
newData.SetMapIndex(key, value)
break
}

if squash && structField.Type.Kind() == reflect.Struct {
fields = append(fields, getDecodableStructFields(structField.Type, tagName)...)
continue
}
fields = append(fields, structField)
}
return fields
}

return newData, nil
func applyUpdates(updates map[interface{}]interface{}, srcData reflect.Value) reflect.Value {
// Equivalent to,
//
// newData := make(map[$key]interface{})
// for k, v := range data {
// if newV, ok := updates[k]; ok {
// newData[k] = newV
// } else {
// newData[k] = v
// }
// }
newData := reflect.MakeMap(reflect.MapOf(srcData.Type().Key(), _typeOfEmptyInterface))
for _, key := range srcData.MapKeys() {
var value reflect.Value
if v, ok := updates[key.Interface()]; ok {
value = reflect.ValueOf(v)
} else {
value = srcData.MapIndex(key)
}
newData.SetMapIndex(key, value)
}
return newData
}
39 changes: 39 additions & 0 deletions hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,45 @@ func TestMultipleFieldHooks(t *testing.T) {
assert.Equal(t, 42, dest.Int)
}

func TestEmbeddedFieldHooks(t *testing.T) {
type subDest struct {
Int int
}
var dest struct {
subDest
}

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

hook1 := newMockFieldHook(mockCtrl)
hook2 := newMockFieldHook(mockCtrl)

typeOfInt := reflect.TypeOf(42)

hook1.
Expect(_typeOfEmptyInterface, structField{
Name: "Int",
Type: typeOfInt,
}, reflectEq{"FOO"}).
Return(valueOf("BAR"), nil)

hook2.
Expect(reflect.TypeOf(""), structField{
Name: "Int",
Type: typeOfInt,
}, reflectEq{"BAR"}).
Return(valueOf(42), nil)

err := Decode(&dest, map[string]interface{}{"int": "FOO"},
FieldHook(hook1.Hook()),
FieldHook(hook2.Hook()),
)
require.NoError(t, err)

assert.Equal(t, 42, dest.Int)
}

func TestMultipleDecodeHooks(t *testing.T) {
type myStruct struct{ String string }

Expand Down

0 comments on commit c756d00

Please sign in to comment.