-
Notifications
You must be signed in to change notification settings - Fork 0
/
appcfg.go
149 lines (136 loc) · 4.43 KB
/
appcfg.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package appcfg
import (
"errors"
"flag"
"fmt"
"reflect"
"strings"
"github.com/spf13/pflag"
)
// ProvideStruct fills cfg from the given providers. cfg must be a pointer
// to a struct whose exported fields all implement Value and carry struct
// tags understood by the providers; values already stored in cfg act as
// defaults.
//
// For every exported field the providers are consulted in order, and the
// search stops at the first provider that either reports it provided a
// value or returns an error.
//
// Add cfg fields to a FlagSet only after all other providers have run, so
// that the usage message printed for -h shows the values set by those
// providers as flag defaults.
//
// A provider error is wrapped with the field's name and tags. Remaining
// fields are still processed; if several fields fail, the error for the
// last failing field is returned.
func ProvideStruct(cfg interface{}, providers ...Provider) error {
	var lastErr error
	forStruct(cfg, func(value Value, name string, tags Tags) {
		for i := range providers {
			provided, err := providers[i].Provide(value, name, tags)
			switch {
			case err != nil:
				lastErr = fmt.Errorf("%s: %w", field(name, tags), err)
				return
			case provided:
				// Mark previous value as completed (in case it's a Slice).
				_ = value.Get()
				return
			}
		}
	})
	return lastErr
}
// RequiredError is the error returned by Value(&err) methods when the value
// was never set. The embedded Value identifies which value is missing.
type RequiredError struct {
	Value
}

// Error implements the error interface. The message is fixed; WrapErr and
// WrapPErr can be used to attach the field and flag name.
func (*RequiredError) Error() string {
	const msg = "value required"
	return msg
}
// WrapErr enriches err when it is, or wraps, a *RequiredError. The returned
// error names the cfgs struct field holding the missing Value and, when fs
// has a flag bound to that same Value, the flag as "-name". The result wraps
// the RequiredError itself. Any other err, including nil, is returned
// unchanged. fs may be nil.
func WrapErr(err error, fs *flag.FlagSet, cfgs ...interface{}) error {
	var reqErr *RequiredError
	if !errors.As(err, &reqErr) {
		return err
	}

	flagName := ""
	if fs != nil {
		fs.VisitAll(func(fl *flag.Flag) {
			if fl.Value == reqErr.Value {
				flagName = "-" + fl.Name
			}
		})
	}
	return doWrapErr(reqErr, flagName, cfgs...)
}
// WrapPErr is the pflag counterpart of WrapErr. When err is, or wraps, a
// *RequiredError, the returned error names the cfgs struct field holding
// the missing Value and, when fs has a flag bound to that same Value, the
// flag as "--name". Any other err, including nil, is returned unchanged.
// fs may be nil.
func WrapPErr(err error, fs *pflag.FlagSet, cfgs ...interface{}) error {
	var reqErr *RequiredError
	if !errors.As(err, &reqErr) {
		return err
	}

	flagName := ""
	if fs != nil {
		fs.VisitAll(func(fl *pflag.Flag) {
			if fl.Value == reqErr.Value {
				flagName = "--" + fl.Name
			}
		})
	}
	return doWrapErr(reqErr, flagName, cfgs...)
}
// doWrapErr wraps reqErr with the name and tags of the field, found in any
// of cfgs, whose Value is reqErr.Value, plus flagName (which may be empty).
// When several fields match, the last one visited wins. It panics if no
// cfg contains the Value, because that means the caller passed the wrong
// cfgs.
func doWrapErr(reqErr *RequiredError, flagName string, cfgs ...interface{}) error {
	var wrapped error
	for i := range cfgs {
		forStruct(cfgs[i], func(v Value, name string, tags Tags) {
			if v != reqErr.Value {
				return
			}
			wrapped = fmt.Errorf("%s: %w", field(name, flagName, tags), reqErr)
		})
	}
	if wrapped != nil {
		return wrapped
	}
	panic("required value not found in cfgs")
}
// forStruct calls handle for every exported field of the struct that cfg
// points to, in declaration order. Each call receives a pointer to the
// field (as Value), the field name, and the field's struct tag with every
// '\n' and '\t' replaced by a space. The replacement allows tags to be
// written across several lines: conventional key:"value" tag parsing
// treats only ' ' as a separator.
//
// Unexported fields are skipped. It panics if cfg is not a non-nil pointer
// to a struct, or if an exported field does not implement Value.
func forStruct(cfg interface{}, handle func(Value, string, Tags)) {
	val := reflect.ValueOf(cfg)
	// Checking Kind first also rejects a nil interface (Kind() == Invalid)
	// without calling Type() on the zero Value, which would panic obscurely.
	if val.Kind() != reflect.Ptr || val.Type().Elem().Kind() != reflect.Struct {
		panic("cfg: must be a ptr to struct")
	}
	if val.IsNil() {
		panic("cfg: must be a non-nil ptr to struct")
	}

	tagSpaces := strings.NewReplacer("\n", " ", "\t", " ")
	elem := val.Elem()
	typ := elem.Type()
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		if f.PkgPath != "" {
			continue
		}
		if !implementsValue(f.Type) {
			panic(fmt.Sprintf("cfg.%s: must implements Value", f.Name))
		}
		tag := reflect.StructTag(tagSpaces.Replace(string(f.Tag)))
		// Field(i) addresses the same top-level field as FieldByName(f.Name)
		// without repeating a by-name search for every field.
		value := elem.Field(i).Addr().Interface().(Value) //nolint:forcetypeassert // Want panic.
		handle(value, f.Name, tag)
	}
}
// field formats name for use in error messages. Each source (flag name,
// struct tags, ...) is rendered with its default format; all whitespace runs
// are collapsed to single spaces, and the result is appended in parentheses.
// When the sources render to nothing but whitespace, name is returned alone.
func field(name string, sources ...interface{}) string {
	details := strings.Join(strings.Fields(fmt.Sprintln(sources...)), " ")
	if details == "" {
		return name
	}
	return name + " (" + details + ")"
}
// AddFlag registers value in fs under name with the given usage string.
// Repeating the call with the same fs, value and name does nothing. Any
// other conflict on name is left to fs.Var, which panics on redefinition.
func AddFlag(fs *flag.FlagSet, value flag.Value, name string, usage string) {
	existing := fs.Lookup(name)
	if existing == nil || existing.Value != value {
		fs.Var(value, name, usage)
	}
}
// AddPFlag registers value in the pflag set fs under name with the given
// usage string. Repeating the call with the same fs, value and name does
// nothing. Any other conflict on name is left to fs.Var.
func AddPFlag(fs *pflag.FlagSet, value pflag.Value, name string, usage string) {
	existing := fs.Lookup(name)
	if existing == nil || existing.Value != value {
		fs.Var(value, name, usage)
	}
}