Skip to content

Commit

Permalink
Support customising mapstructure.DecoderConfig for Unmarshal
Browse files Browse the repository at this point in the history
* Added a new `DecoderConfigOption` type allowing the user to write custom
  functions that can override the default mapstructure.DecoderConfig
  settings

* Added a new `DecodeHook` function which returns
  a `DecoderConfigOption`. This allows the user to easily set their own
  Decode hooks when Unmarshaling

* Updated Unmarshal, UnmarshalKey and defaultDecoderConfig to support variadic
  trailing `DecoderConfigOption` functions to allow for customisation of
  the default  mapstructure.DecoderConfig

* Added a test case with example usage
  • Loading branch information
krak3n authored and bep committed Aug 3, 2018
1 parent d493c32 commit 907c19d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 8 deletions.
41 changes: 33 additions & 8 deletions viper.go
Expand Up @@ -113,6 +113,23 @@ func (fnfe ConfigFileNotFoundError) Error() string {
return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations)
}

// A DecoderConfigOption can be passed to viper.Unmarshal to configure
// mapstructure.DecoderConfig options
type DecoderConfigOption func(*mapstructure.DecoderConfig)

// DecodeHook returns a DecoderConfigOption which overrides the default
// DecoderConfig.DecodeHook value, the default is:
//
// mapstructure.ComposeDecodeHookFunc(
// mapstructure.StringToTimeDurationHookFunc(),
// mapstructure.StringToSliceHookFunc(","),
// )
func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption {
return func(c *mapstructure.DecoderConfig) {
c.DecodeHook = hook
}
}

// Viper is a prioritized configuration registry. It
// maintains a set of configuration sources, fetches
// values to populate those, and provides them according
Expand Down Expand Up @@ -745,9 +762,11 @@ func (v *Viper) GetSizeInBytes(key string) uint {
}

// UnmarshalKey takes a single key and unmarshals it into a Struct.
func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) }
func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error {
err := decode(v.Get(key), defaultDecoderConfig(rawVal))
func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
return v.UnmarshalKey(key, rawVal, opts...)
}
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...))

if err != nil {
return err
Expand All @@ -760,9 +779,11 @@ func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error {

// Unmarshal unmarshals the config into a Struct. Make sure that the tags
// on the fields of the structure are properly set.
func Unmarshal(rawVal interface{}) error { return v.Unmarshal(rawVal) }
func (v *Viper) Unmarshal(rawVal interface{}) error {
err := decode(v.AllSettings(), defaultDecoderConfig(rawVal))
func Unmarshal(rawVal interface{}, opts ...DecoderConfigOption) error {
return v.Unmarshal(rawVal, opts...)
}
func (v *Viper) Unmarshal(rawVal interface{}, opts ...DecoderConfigOption) error {
err := decode(v.AllSettings(), defaultDecoderConfig(rawVal, opts...))

if err != nil {
return err
Expand All @@ -775,8 +796,8 @@ func (v *Viper) Unmarshal(rawVal interface{}) error {

// defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot
// of time.Duration values & string slices
func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig {
return &mapstructure.DecoderConfig{
func defaultDecoderConfig(output interface{}, opts ...DecoderConfigOption) *mapstructure.DecoderConfig {
c := &mapstructure.DecoderConfig{
Metadata: nil,
Result: output,
WeaklyTypedInput: true,
Expand All @@ -785,6 +806,10 @@ func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig {
mapstructure.StringToSliceHookFunc(","),
),
}
for _, opt := range opts {
opt(c)
}
return c
}

// A wrapper around mapstructure.Decode that mimics the WeakDecode functionality
Expand Down
38 changes: 38 additions & 0 deletions viper_test.go
Expand Up @@ -7,6 +7,7 @@ package viper

import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
Expand All @@ -18,6 +19,7 @@ import (
"testing"
"time"

"github.com/mitchellh/mapstructure"
"github.com/spf13/afero"
"github.com/spf13/cast"

Expand Down Expand Up @@ -503,6 +505,42 @@ func TestUnmarshal(t *testing.T) {
assert.Equal(t, &config{Name: "Steve", Port: 1234, Duration: time.Second + time.Millisecond}, &C)
}

func TestUnmarshalWithDecoderOptions(t *testing.T) {
Set("credentials", "{\"foo\":\"bar\"}")

opt := DecodeHook(mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
// Custom Decode Hook Function
func(rf reflect.Kind, rt reflect.Kind, data interface{}) (interface{}, error) {
if rf != reflect.String || rt != reflect.Map {
return data, nil
}
m := map[string]string{}
raw := data.(string)
if raw == "" {
return m, nil
}
return m, json.Unmarshal([]byte(raw), &m)
},
))

type config struct {
Credentials map[string]string
}

var C config

err := Unmarshal(&C, opt)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}

assert.Equal(t, &config{
Credentials: map[string]string{"foo": "bar"},
}, &C)
}

func TestBindPFlags(t *testing.T) {
v := New() // create independent Viper object
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
Expand Down

0 comments on commit 907c19d

Please sign in to comment.