Skip to content
This repository has been archived by the owner on Oct 2, 2020. It is now read-only.

Commit

Permalink
Refactor fieldHook function
Browse files Browse the repository at this point in the history
Summary: I was having a hard time reasoning about the fieldHook function
so I did some refactoring to make it easier to add the embedded struct
fields.

- Remove the in-place data value updates (this appears to have been an
optimization, but, it's not necessary and make the code hard to reason
about)
- Extract getting the updates struct into it's own function.
- Extract getting decodable Struct fields into it's own function.
- Extract applying updates to the srcData into it's own function.
- Rename variables to be more explicit
- Don't use the `from` field, which should be equivalent to
srcData.Type()

Test Plan: tests pass
  • Loading branch information
Will Hughes committed May 19, 2017
1 parent 97bfdfc commit 8b7bb8e
Showing 1 changed file with 106 additions and 91 deletions.
197 changes: 106 additions & 91 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ func strconvHook(from, to reflect.Type, data reflect.Value) (reflect.Value, erro
// fieldHook applies the user-specified FieldHookFunc to all struct fields.
func fieldHook(opts *options) DecodeHookFunc {
hook := composeFieldHooks(opts.FieldHooks)
return func(from, to reflect.Type, data reflect.Value) (reflect.Value, error) {
if to.Kind() != reflect.Struct || from.Kind() != reflect.Map {
return data, nil
return func(_, destType reflect.Type, srcData reflect.Value) (reflect.Value, error) {
if destType.Kind() != reflect.Struct || srcData.Type().Kind() != reflect.Map {
return srcData, nil
}

// We can only decode map[string]* and map[interface{}]* into structs.
if k := from.Key().Kind(); k != reflect.String && k != reflect.Interface {
return data, nil
if k := srcData.Type().Key().Kind(); k != reflect.String && k != reflect.Interface {
return srcData, nil
}

// This map tracks type-changing updates to items in the map.
Expand All @@ -227,108 +227,123 @@ func fieldHook(opts *options) DecodeHookFunc {
// values in-place if a hook changed the type of a value. So we will
// make a copy of the source map with a more liberal type and inject
// these updates into the copy.
updates := make(map[interface{}]interface{})

var errors []error
for i := 0; i < to.NumField(); i++ {
structField := to.Field(i)
if structField.PkgPath != "" && !structField.Anonymous {
// This field is not exported so we won't be able to decode
// into it.
continue
}
updates, err := getMapUpdates(destType, srcData, hook, opts.TagName)
if err != nil {
return srcData, err
}

// This field resolution logic is adapted from mapstructure's own
// logic.
//
// See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741
// No more changes to make.
if len(updates) == 0 {
return srcData, nil
}

fieldName := structField.Name
return applyUpdates(updates, srcData), nil
}
}

// Field name override was specified.
tagParts := strings.Split(structField.Tag.Get(opts.TagName), ",")
if tagParts[0] != "" {
fieldName = tagParts[0]
}
func getMapUpdates(destType reflect.Type, srcData reflect.Value, hook FieldHookFunc, tagName string) (map[interface{}]interface{}, error) {
updates := make(map[interface{}]interface{})
var errors []error

// Get the value for this field from the source map, if any.
key := reflect.ValueOf(fieldName)
value := data.MapIndex(key)
if !value.IsValid() {
// Case-insensitive linear search if the name doesn't match
// as-is.
for _, kV := range data.MapKeys() {
// Kind() == Interface if map[interface{}]* so we use
// Interface().(string) to handle interface{} and string
// keys.
k, ok := kV.Interface().(string)
if !ok {
continue
}

if strings.EqualFold(k, fieldName) {
key = kV
value = data.MapIndex(kV)
break
}
}
}
decodableFields := getDecodableStructFields(destType)
for _, structField := range decodableFields {
// This field resolution logic is adapted from mapstructure's own
// logic.
//
// See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741

if !value.IsValid() {
// No value specified for this field in source map.
continue
}
fieldName := structField.Name

newValue, err := hook(value.Type(), structField, value)
if err != nil {
errors = append(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}
// Field name override was specified.
tagParts := strings.Split(structField.Tag.Get(tagName), ",")
if tagParts[0] != "" {
fieldName = tagParts[0]
}

if newValue == value {
continue
}
// Get the value for this field from the source map, if any.
key := reflect.ValueOf(fieldName)
value := srcData.MapIndex(key)
if !value.IsValid() {
// Case-insensitive linear search if the name doesn't match
// as-is.
for _, kV := range srcData.MapKeys() {
// Kind() == Interface if map[interface{}]* so we use
// Interface().(string) to handle interface{} and string
// keys.
k, ok := kV.Interface().(string)
if !ok {
continue
}

// If we can, assign in-place.
if newValue.Type().AssignableTo(value.Type()) {
// XXX(abg): Is it okay to make updates to the source map?
data.SetMapIndex(key, newValue)
} else {
updates[key.Interface()] = newValue.Interface()
if strings.EqualFold(k, fieldName) {
key = kV
value = srcData.MapIndex(kV)
break
}
}
}

if len(errors) > 0 {
return data, multierr.Combine(errors...)
if !value.IsValid() {
// No value specified for this field in source map.
continue
}

// No more changes to make.
if len(updates) == 0 {
return data, nil
newValue, err := hook(value.Type(), structField, value)
if err != nil {
errors = append(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}

// Equivalent to,
//
// newData := make(map[$key]interface{})
// for k, v := range data {
// if newV, ok := updates[k]; ok {
// newData[k] = newV
// } else {
// newData[k] = v
// }
// }
newData := reflect.MakeMap(reflect.MapOf(from.Key(), _typeOfEmptyInterface))
for _, key := range data.MapKeys() {
var value reflect.Value
if v, ok := updates[key.Interface()]; ok {
value = reflect.ValueOf(v)
} else {
value = data.MapIndex(key)
}
newData.SetMapIndex(key, value)
if newValue == value {
continue
}

return newData, nil
updates[key.Interface()] = newValue.Interface()
}

if len(errors) > 0 {
return nil, multierr.Combine(errors...)
}

return updates, nil
}

func getDecodableStructFields(structType reflect.Type) []reflect.StructField {
fields := make([]reflect.StructField, 0)
for i := 0; i < structType.NumField(); i++ {
structField := structType.Field(i)
if structField.PkgPath != "" && !structField.Anonymous {
// This field is not exported so we won't be able to decode
// into it.
continue
}
// TODO account for embedded struct fields
fields = append(fields, structField)
}
return fields
}

func applyUpdates(updates map[interface{}]interface{}, srcData reflect.Value) reflect.Value {
// Equivalent to,
//
// newData := make(map[$key]interface{})
// for k, v := range data {
// if newV, ok := updates[k]; ok {
// newData[k] = newV
// } else {
// newData[k] = v
// }
// }
newData := reflect.MakeMap(reflect.MapOf(srcData.Type().Key(), _typeOfEmptyInterface))
for _, key := range srcData.MapKeys() {
var value reflect.Value
if v, ok := updates[key.Interface()]; ok {
value = reflect.ValueOf(v)
} else {
value = srcData.MapIndex(key)
}
newData.SetMapIndex(key, value)
}
return newData
}

0 comments on commit 8b7bb8e

Please sign in to comment.