Skip to content

Commit

Permalink
Fix unmarshalling when overriding from multiple sources
Browse files Browse the repository at this point in the history
  • Loading branch information
rsafonseca committed Aug 2, 2023
1 parent 389df17 commit 91e94ea
Showing 1 changed file with 72 additions and 5 deletions.
77 changes: 72 additions & 5 deletions config/viper_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
package config

import (
"fmt"
"github.com/mitchellh/mapstructure"
"reflect"
"strings"
"time"

Expand Down Expand Up @@ -189,12 +192,76 @@ func (c *Config) GetStringMapString(s string) map[string]string {
return c.config.GetStringMapString(s)
}

// UnmarshalKey unmarshals key into v
func (c *Config) UnmarshalKey(s string, v interface{}) error {
return c.config.UnmarshalKey(s, v)
}

// Unmarshal unmarshals config into v
func (c *Config) Unmarshal(v interface{}) error {
return c.config.Unmarshal(v)
}

// UnmarshalKey unmarshals key into v
func (c *Config) UnmarshalKey(key string, rawVal interface{}, opts ...viper.DecoderConfigOption) error {
key = strings.ToLower(key)
delimiter := "."
prefix := key + delimiter

i := c.config.Get(key)
if isStringMapInterface(i) {
val := i.(map[string]interface{})
keys := c.config.AllKeys()
for _, k := range keys {
if !strings.HasPrefix(k, prefix) {
continue
}
fmt.Printf("prefix: %v\n", prefix)
mk := strings.TrimPrefix(k, prefix)
fmt.Printf("got key1: %v\n", mk)
mk = strings.Split(mk, delimiter)[0]
fmt.Printf("got key2: %v\n", mk)
if _, exists := val[mk]; exists {
continue
}
mv := c.Get(key + delimiter + mk)
fmt.Printf("got key5: %v\n", mv)
if mv == nil {
continue
}
val[mk] = mv
}
i = val
}
return decode(i, defaultDecoderConfig(rawVal, opts...))
}

func isStringMapInterface(val interface{}) bool {
vt := reflect.TypeOf(val)
return vt.Kind() == reflect.Map &&
vt.Key().Kind() == reflect.String &&
vt.Elem().Kind() == reflect.Interface
}

// A wrapper around mapstructure.Decode that mimics the WeakDecode functionality
func decode(input interface{}, config *mapstructure.DecoderConfig) error {
decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return err
}
return decoder.Decode(input)
}

// defaultDecoderConfig returns default mapstructure.DecoderConfig with support
// of time.Duration values & string slices
func defaultDecoderConfig(output interface{}, opts ...viper.DecoderConfigOption) *mapstructure.DecoderConfig {
c := &mapstructure.DecoderConfig{
Metadata: nil,
Result: output,
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
}
for _, opt := range opts {
opt(c)
}
return c

}

0 comments on commit 91e94ea

Please sign in to comment.