From 8b7bb8efbf78c0f2b129b66a39e65246f333881e Mon Sep 17 00:00:00 2001 From: Will Hughes Date: Thu, 18 May 2017 16:47:39 -0700 Subject: [PATCH] Refactor fieldHook function Summary: I was having a hard time reasoning about the fieldHook function so I did some refactoring to make it easier to add the embedded struct fields. - Remove the in-place data value updates (this appears to have been an optimization, but, it's not necessary and make the code hard to reason about) - Extract getting the updates struct into it's own function. - Extract getting decodable Struct fields into it's own function. - Extract applying updates to the srcData into it's own function. - Rename variables to be more explicit - Don't use the `from` field, which should be equivalent to srcData.Type() Test Plan: tests pass --- hooks.go | 197 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 106 insertions(+), 91 deletions(-) diff --git a/hooks.go b/hooks.go index 6086394..334043a 100644 --- a/hooks.go +++ b/hooks.go @@ -210,14 +210,14 @@ func strconvHook(from, to reflect.Type, data reflect.Value) (reflect.Value, erro // fieldHook applies the user-specified FieldHookFunc to all struct fields. func fieldHook(opts *options) DecodeHookFunc { hook := composeFieldHooks(opts.FieldHooks) - return func(from, to reflect.Type, data reflect.Value) (reflect.Value, error) { - if to.Kind() != reflect.Struct || from.Kind() != reflect.Map { - return data, nil + return func(_, destType reflect.Type, srcData reflect.Value) (reflect.Value, error) { + if destType.Kind() != reflect.Struct || srcData.Type().Kind() != reflect.Map { + return srcData, nil } // We can only decode map[string]* and map[interface{}]* into structs. - if k := from.Key().Kind(); k != reflect.String && k != reflect.Interface { - return data, nil + if k := srcData.Type().Key().Kind(); k != reflect.String && k != reflect.Interface { + return srcData, nil } // This map tracks type-changing updates to items in the map. @@ -227,108 +227,123 @@ func fieldHook(opts *options) DecodeHookFunc { // values in-place if a hook changed the type of a value. So we will // make a copy of the source map with a more liberal type and inject // these updates into the copy. - updates := make(map[interface{}]interface{}) - - var errors []error - for i := 0; i < to.NumField(); i++ { - structField := to.Field(i) - if structField.PkgPath != "" && !structField.Anonymous { - // This field is not exported so we won't be able to decode - // into it. - continue - } + updates, err := getMapUpdates(destType, srcData, hook, opts.TagName) + if err != nil { + return srcData, err + } - // This field resolution logic is adapted from mapstructure's own - // logic. - // - // See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741 + // No more changes to make. + if len(updates) == 0 { + return srcData, nil + } - fieldName := structField.Name + return applyUpdates(updates, srcData), nil + } +} - // Field name override was specified. - tagParts := strings.Split(structField.Tag.Get(opts.TagName), ",") - if tagParts[0] != "" { - fieldName = tagParts[0] - } +func getMapUpdates(destType reflect.Type, srcData reflect.Value, hook FieldHookFunc, tagName string) (map[interface{}]interface{}, error) { + updates := make(map[interface{}]interface{}) + var errors []error - // Get the value for this field from the source map, if any. - key := reflect.ValueOf(fieldName) - value := data.MapIndex(key) - if !value.IsValid() { - // Case-insensitive linear search if the name doesn't match - // as-is. - for _, kV := range data.MapKeys() { - // Kind() == Interface if map[interface{}]* so we use - // Interface().(string) to handle interface{} and string - // keys. - k, ok := kV.Interface().(string) - if !ok { - continue - } - - if strings.EqualFold(k, fieldName) { - key = kV - value = data.MapIndex(kV) - break - } - } - } + decodableFields := getDecodableStructFields(destType) + for _, structField := range decodableFields { + // This field resolution logic is adapted from mapstructure's own + // logic. + // + // See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741 - if !value.IsValid() { - // No value specified for this field in source map. - continue - } + fieldName := structField.Name - newValue, err := hook(value.Type(), structField, value) - if err != nil { - errors = append(errors, fmt.Errorf( - "error reading into field %q: %v", fieldName, err)) - continue - } + // Field name override was specified. + tagParts := strings.Split(structField.Tag.Get(tagName), ",") + if tagParts[0] != "" { + fieldName = tagParts[0] + } - if newValue == value { - continue - } + // Get the value for this field from the source map, if any. + key := reflect.ValueOf(fieldName) + value := srcData.MapIndex(key) + if !value.IsValid() { + // Case-insensitive linear search if the name doesn't match + // as-is. + for _, kV := range srcData.MapKeys() { + // Kind() == Interface if map[interface{}]* so we use + // Interface().(string) to handle interface{} and string + // keys. + k, ok := kV.Interface().(string) + if !ok { + continue + } - // If we can, assign in-place. - if newValue.Type().AssignableTo(value.Type()) { - // XXX(abg): Is it okay to make updates to the source map? - data.SetMapIndex(key, newValue) - } else { - updates[key.Interface()] = newValue.Interface() + if strings.EqualFold(k, fieldName) { + key = kV + value = srcData.MapIndex(kV) + break + } } } - if len(errors) > 0 { - return data, multierr.Combine(errors...) + if !value.IsValid() { + // No value specified for this field in source map. + continue } - // No more changes to make. - if len(updates) == 0 { - return data, nil + newValue, err := hook(value.Type(), structField, value) + if err != nil { + errors = append(errors, fmt.Errorf( + "error reading into field %q: %v", fieldName, err)) + continue } - // Equivalent to, - // - // newData := make(map[$key]interface{}) - // for k, v := range data { - // if newV, ok := updates[k]; ok { - // newData[k] = newV - // } else { - // newData[k] = v - // } - // } - newData := reflect.MakeMap(reflect.MapOf(from.Key(), _typeOfEmptyInterface)) - for _, key := range data.MapKeys() { - var value reflect.Value - if v, ok := updates[key.Interface()]; ok { - value = reflect.ValueOf(v) - } else { - value = data.MapIndex(key) - } - newData.SetMapIndex(key, value) + if newValue == value { + continue } - return newData, nil + updates[key.Interface()] = newValue.Interface() + } + + if len(errors) > 0 { + return nil, multierr.Combine(errors...) + } + + return updates, nil +} + +func getDecodableStructFields(structType reflect.Type) []reflect.StructField { + fields := make([]reflect.StructField, 0) + for i := 0; i < structType.NumField(); i++ { + structField := structType.Field(i) + if structField.PkgPath != "" && !structField.Anonymous { + // This field is not exported so we won't be able to decode + // into it. + continue + } + // TODO account for embedded struct fields + fields = append(fields, structField) + } + return fields +} + +func applyUpdates(updates map[interface{}]interface{}, srcData reflect.Value) reflect.Value { + // Equivalent to, + // + // newData := make(map[$key]interface{}) + // for k, v := range data { + // if newV, ok := updates[k]; ok { + // newData[k] = newV + // } else { + // newData[k] = v + // } + // } + newData := reflect.MakeMap(reflect.MapOf(srcData.Type().Key(), _typeOfEmptyInterface)) + for _, key := range srcData.MapKeys() { + var value reflect.Value + if v, ok := updates[key.Interface()]; ok { + value = reflect.ValueOf(v) + } else { + value = srcData.MapIndex(key) + } + newData.SetMapIndex(key, value) } + return newData }