Skip to content

Commit

Permalink
feat: allow nested koanf file merging (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Dec 5, 2020
1 parent d200415 commit c1adb0d
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 29 deletions.
64 changes: 58 additions & 6 deletions configx/koanf_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ import (
"context"
"io/ioutil"
"path/filepath"
"strings"

"github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/parsers/yaml"

"github.com/ory/x/stringslice"

"github.com/pkg/errors"

Expand All @@ -12,23 +20,67 @@ import (

// KoanfFile implements a KoanfFile provider.
type KoanfFile struct {
path string
ctx context.Context
subKey string
path string
ctx context.Context
parser koanf.Parser
}

// Provider returns a file provider.
func NewKoanfFile(ctx context.Context, path string) *KoanfFile {
return &KoanfFile{path: filepath.Clean(path), ctx: ctx}
func NewKoanfFile(ctx context.Context, path string) (*KoanfFile, error) {
return NewKoanfFileSubKey(ctx, path, "")
}

func NewKoanfFileSubKey(ctx context.Context, path, subKey string) (*KoanfFile, error) {
kf := &KoanfFile{
path: filepath.Clean(path),
ctx: ctx,
subKey: subKey,
}

switch e := filepath.Ext(path); e {
case ".toml":
kf.parser = toml.Parser()
case ".json":
kf.parser = json.Parser()
case ".yaml", ".yml":
kf.parser = yaml.Parser()
default:
return nil, errors.Errorf("unknown config file extension: %s", e)
}

return kf, nil
}

// ReadBytes reads the contents of a file on disk and returns the bytes.
func (f *KoanfFile) ReadBytes() ([]byte, error) {
return ioutil.ReadFile(f.path)
return nil, errors.New("file provider does not support this method")
}

// Read is not supported by the file provider.
func (f *KoanfFile) Read() (map[string]interface{}, error) {
return nil, errors.New("file provider does not support this method")
fc, err := ioutil.ReadFile(f.path)
if err != nil {
return nil, errors.WithStack(err)
}

v, err := f.parser.Unmarshal(fc)
if err != nil {
return nil, errors.WithStack(err)
}

if f.subKey == "" {
return v, nil
}

path := strings.Split(f.subKey, Delimiter)
for _, k := range stringslice.Reverse(path) {
v = map[string]interface{}{
k: v,
}
}

return v, nil
}

// WatchChannel watches the file and triggers a callback when it changes. It is a
Expand Down
88 changes: 88 additions & 0 deletions configx/koanf_file_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package configx

import (
"context"
"encoding/json"
"io/ioutil"
"testing"

"github.com/ghodss/yaml"
"github.com/pelletier/go-toml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestKoanfFile(t *testing.T) {
setupFile := func(t *testing.T, fn, fc, subKey string) *KoanfFile {
f, err := ioutil.TempFile("", fn)
require.NoError(t, err)
_, err = f.Write([]byte(fc))
require.NoError(t, err)

kf, err := NewKoanfFileSubKey(context.Background(), f.Name(), subKey)
require.NoError(t, err)
return kf
}

t.Run("case=reads json root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "bar",
}
encV, err := json.Marshal(v)
require.NoError(t, err)

kf := setupFile(t, "config*.json", string(encV), "")

actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})

t.Run("case=reads yaml root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "yaml string",
}
encV, err := yaml.Marshal(v)
require.NoError(t, err)

kf := setupFile(t, "config*.yml", string(encV), "")

actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})

t.Run("case=reads toml root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "toml string",
}
encV, err := toml.Marshal(v)
require.NoError(t, err)

kf := setupFile(t, "config*.toml", string(encV), "")

actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})

t.Run("case=reads json file as subkey", func(t *testing.T) {
v := map[string]interface{}{
"bar": "asdf",
}
encV, err := json.Marshal(v)
require.NoError(t, err)

kf := setupFile(t, "config*.json", string(encV), "parent.of.config")

actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, map[string]interface{}{
"parent": map[string]interface{}{
"of": map[string]interface{}{
"config": v,
},
},
}, actual)
})
}
4 changes: 2 additions & 2 deletions configx/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ func WithContext(ctx context.Context) OptionModifier {
}
}

func WithImmutables(immutables []string) OptionModifier {
func WithImmutables(immutables ...string) OptionModifier {
return func(p *Provider) {
p.immutables = immutables
}
}

func OmitKeysFromTracing(keys []string) OptionModifier {
func OmitKeysFromTracing(keys ...string) OptionModifier {
return func(p *Provider) {
p.excludeFieldsFromTracing = keys
}
Expand Down
37 changes: 17 additions & 20 deletions configx/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"time"
Expand All @@ -28,8 +27,6 @@ import (

"github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/confmap"
"github.com/pkg/errors"
"github.com/rs/cors"
Expand All @@ -48,6 +45,15 @@ type Provider struct {
tracer *tracing.Tracer
}

const (
FlagConfig = "config"
Delimiter = "."
)

func RegisterConfigFlag(flags *pflag.FlagSet, fallback []string) {
flags.StringSliceP(FlagConfig, "c", fallback, "Config files to load, overwriting in the order specified.")
}

// New creates a new provider instance or errors.
// Configuration values are loaded in the following order:
//
Expand Down Expand Up @@ -101,7 +107,7 @@ func (p *Provider) validate(k *koanf.Koanf) error {
}

func (p *Provider) newKoanf(ctx context.Context) (*koanf.Koanf, error) {
k := koanf.New(".")
k := koanf.New(Delimiter)

dp, err := NewKoanfSchemaDefaults(p.schema)
if err != nil {
Expand All @@ -118,7 +124,7 @@ func (p *Provider) newKoanf(ctx context.Context) (*koanf.Koanf, error) {
return nil, err
}

paths, err := p.flags.GetStringSlice("config")
paths, _ := p.flags.GetStringSlice(FlagConfig)
for _, configFile := range paths {
if err := p.addConfigFile(ctx, configFile, k); err != nil {
return nil, err
Expand Down Expand Up @@ -186,21 +192,12 @@ func (p *Provider) runOnChanges(e watcherx.Event, err error) {
}

func (p *Provider) addConfigFile(ctx context.Context, path string, k *koanf.Koanf) error {
var parser koanf.Parser

switch e := filepath.Ext(path); e {
case ".toml":
parser = toml.Parser()
case ".json":
parser = json.Parser()
case ".yaml", ".yml":
parser = yaml.Parser()
default:
return errors.Errorf("unknown config file extension: %s", e)
}

ctx, cancel := context.WithCancel(p.ctx)
fp := NewKoanfFile(ctx, path)
fp, err := NewKoanfFile(ctx, path)
if err != nil {
cancel()
return err
}

c := make(watcherx.EventChannel)
go func(c watcherx.EventChannel) {
Expand Down Expand Up @@ -250,7 +247,7 @@ func (p *Provider) addConfigFile(ctx context.Context, path string, k *koanf.Koan
return err
}

return k.Load(fp, parser)
return k.Load(fp, nil)
}

func (p *Provider) Set(key string, value interface{}) {
Expand Down
2 changes: 1 addition & 1 deletion configx/provider_watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestReload(t *testing.T) {
hook := test.NewLocal(l.Entry.Logger)
wg := new(sync.WaitGroup)
p := setup(t, configFile, wg,
WithImmutables([]string{"dsn"}))
WithImmutables("dsn"))

assert.Equal(t, []*logrus.Entry{}, hook.AllEntries())
assert.Equal(t, "memory", p.String("dsn"))
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/dgraph-io/ristretto v0.0.2
github.com/fatih/structs v1.1.0
github.com/fsnotify/fsnotify v1.4.9
github.com/ghodss/yaml v1.0.0
github.com/go-bindata/go-bindata v3.1.1+incompatible
github.com/go-sql-driver/mysql v1.5.0
github.com/gobuffalo/httptest v1.0.2
Expand Down Expand Up @@ -34,6 +35,7 @@ require (
github.com/ory/herodot v0.8.3
github.com/ory/jsonschema/v3 v3.0.1
github.com/pborman/uuid v1.2.0
github.com/pelletier/go-toml v1.7.0
github.com/philhofer/fwd v1.0.0 // indirect
github.com/pkg/errors v0.9.1
github.com/pkg/profile v1.2.1
Expand Down
11 changes: 11 additions & 0 deletions stringslice/reverse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package stringslice

func Reverse(s []string) []string {
r := make([]string, len(s))

for i, j := 0, len(r)-1; i <= j; i, j = i+1, j-1 {
r[i], r[j] = s[j], s[i]
}

return r
}
35 changes: 35 additions & 0 deletions stringslice/reverse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package stringslice

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestReverse(t *testing.T) {
for i, tc := range []struct {
i, e []string
}{
{
i: []string{"a", "b", "c"},
e: []string{"c", "b", "a"},
},
{
i: []string{"foo"},
e: []string{"foo"},
},
{
i: []string{"foo", "bar"},
e: []string{"bar", "foo"},
},
{
i: []string{},
e: []string{},
},
} {
t.Run(fmt.Sprintf("case=%d/input:%v expected:%v", i, tc.i, tc.e), func(t *testing.T) {
assert.Equal(t, tc.e, Reverse(tc.i))
})
}
}

0 comments on commit c1adb0d

Please sign in to comment.