diff --git a/altsrc/flag.go b/altsrc/flag.go index c1fb669393..858e54eaa2 100644 --- a/altsrc/flag.go +++ b/altsrc/flag.go @@ -2,6 +2,7 @@ package altsrc import ( "fmt" + "path/filepath" "strconv" "syscall" @@ -160,6 +161,34 @@ func (f *StringFlag) ApplyInputSourceValue(context *cli.Context, isc InputSource return nil } +// ApplyInputSourceValue applies a Path value to the flagSet if required +func (f *PathFlag) ApplyInputSourceValue(context *cli.Context, isc InputSourceContext) error { + if f.set != nil { + if !(context.IsSet(f.Name) || isEnvVarSet(f.EnvVars)) { + value, err := isc.String(f.PathFlag.Name) + if err != nil { + return err + } + if value != "" { + for _, name := range f.Names() { + + if !filepath.IsAbs(value) && isc.Source() != "" { + basePathAbs, err := filepath.Abs(isc.Source()) + if err != nil { + return err + } + + value = filepath.Join(filepath.Dir(basePathAbs), value) + } + + f.set.Set(name, value) + } + } + } + } + return nil +} + // ApplyInputSourceValue applies a int value to the flagSet if required func (f *IntFlag) ApplyInputSourceValue(context *cli.Context, isc InputSourceContext) error { if f.set != nil { diff --git a/altsrc/flag_generated.go b/altsrc/flag_generated.go index ee2231cd8b..87c7c5d2d4 100644 --- a/altsrc/flag_generated.go +++ b/altsrc/flag_generated.go @@ -268,6 +268,32 @@ func (f *StringFlag) ApplyWithError(set *flag.FlagSet) error { return f.StringFlag.ApplyWithError(set) } +// PathFlag is the flag type that wraps cli.PathFlag to allow +// for other values to be specified +type PathFlag struct { + *cli.PathFlag + set *flag.FlagSet +} + +// NewPathFlag creates a new PathFlag +func NewPathFlag(fl *cli.PathFlag) *PathFlag { + return &PathFlag{PathFlag: fl, set: nil} +} + +// Apply saves the flagSet for later usage calls, then calls the +// wrapped PathFlag.Apply +func (f *PathFlag) Apply(set *flag.FlagSet) { + f.set = set + f.PathFlag.Apply(set) +} + +// ApplyWithError saves the flagSet for later usage calls, then calls the +// wrapped PathFlag.ApplyWithError +func (f *PathFlag) ApplyWithError(set *flag.FlagSet) error { + f.set = set + return f.PathFlag.ApplyWithError(set) +} + // StringSliceFlag is the flag type that wraps cli.StringSliceFlag to allow // for other values to be specified type StringSliceFlag struct { diff --git a/altsrc/flag_test.go b/altsrc/flag_test.go index a4d1d59a26..087e607ee2 100644 --- a/altsrc/flag_test.go +++ b/altsrc/flag_test.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "os" + "runtime" "strings" "testing" "time" @@ -20,6 +21,7 @@ type testApplyInputSource struct { ContextValue flag.Value EnvVarValue string EnvVarName string + SourcePath string MapValue interface{} } @@ -178,6 +180,43 @@ func TestStringApplyInputSourceMethodEnvVarSet(t *testing.T) { }) expect(t, "goodbye", c.String("test")) } +func TestPathApplyInputSourceMethodSet(t *testing.T) { + c := runTest(t, testApplyInputSource{ + Flag: NewPathFlag(&cli.PathFlag{Name: "test"}), + FlagName: "test", + MapValue: "hello", + SourcePath: "/path/to/source/file", + }) + + expected := "/path/to/source/hello" + if runtime.GOOS == "windows" { + expected = `C:\path\to\source\hello` + } + expect(t, expected, c.String("test")) +} + +func TestPathApplyInputSourceMethodContextSet(t *testing.T) { + c := runTest(t, testApplyInputSource{ + Flag: NewPathFlag(&cli.PathFlag{Name: "test"}), + FlagName: "test", + MapValue: "hello", + ContextValueString: "goodbye", + SourcePath: "/path/to/source/file", + }) + expect(t, "goodbye", c.String("test")) +} + +func TestPathApplyInputSourceMethodEnvVarSet(t *testing.T) { + c := runTest(t, testApplyInputSource{ + Flag: NewPathFlag(&cli.PathFlag{Name: "test", EnvVars: []string{"TEST"}}), + FlagName: "test", + MapValue: "hello", + EnvVarName: "TEST", + EnvVarValue: "goodbye", + SourcePath: "/path/to/source/file", + }) + expect(t, "goodbye", c.String("test")) +} func TestIntApplyInputSourceMethodSet(t *testing.T) { c := runTest(t, testApplyInputSource{ @@ -270,7 +309,10 @@ func TestFloat64ApplyInputSourceMethodEnvVarSet(t *testing.T) { } func runTest(t *testing.T, test testApplyInputSource) *cli.Context { - inputSource := &MapInputSource{valueMap: map[interface{}]interface{}{test.FlagName: test.MapValue}} + inputSource := &MapInputSource{ + file: test.SourcePath, + valueMap: map[interface{}]interface{}{test.FlagName: test.MapValue}, + } set := flag.NewFlagSet(test.FlagSetName, flag.ContinueOnError) c := cli.NewContext(nil, set, nil) if test.EnvVarName != "" && test.EnvVarValue != "" { diff --git a/altsrc/input_source_context.go b/altsrc/input_source_context.go index c45ba5ca7d..bb0afdb4db 100644 --- a/altsrc/input_source_context.go +++ b/altsrc/input_source_context.go @@ -8,7 +8,12 @@ import ( // InputSourceContext is an interface used to allow // other input sources to be implemented as needed. +// +// Source returns an identifier for the input source. In case of file source +// it should return path to the file. type InputSourceContext interface { + Source() string + Int(name string) (int, error) Duration(name string) (time.Duration, error) Float64(name string) (float64, error) diff --git a/altsrc/json_source_context.go b/altsrc/json_source_context.go index a197d87e3a..34c7eb9544 100644 --- a/altsrc/json_source_context.go +++ b/altsrc/json_source_context.go @@ -29,7 +29,13 @@ func NewJSONSourceFromFile(f string) (InputSourceContext, error) { if err != nil { return nil, err } - return NewJSONSource(data) + s, err := newJSONSource(data) + if err != nil { + return nil, err + } + + s.file = f + return s, nil } // NewJSONSourceFromReader returns an InputSourceContext suitable for @@ -45,6 +51,10 @@ func NewJSONSourceFromReader(r io.Reader) (InputSourceContext, error) { // NewJSONSource returns an InputSourceContext suitable for retrieving // config variables from raw JSON data. func NewJSONSource(data []byte) (InputSourceContext, error) { + return newJSONSource(data) +} + +func newJSONSource(data []byte) (*jsonSource, error) { var deserialized map[string]interface{} if err := json.Unmarshal(data, &deserialized); err != nil { return nil, err @@ -52,6 +62,10 @@ func NewJSONSource(data []byte) (InputSourceContext, error) { return &jsonSource{deserialized: deserialized}, nil } +func (x *jsonSource) Source() string { + return x.file +} + func (x *jsonSource) Int(name string) (int, error) { i, err := x.getValue(name) if err != nil { @@ -198,5 +212,6 @@ func jsonGetValue(key string, m map[string]interface{}) (interface{}, error) { } type jsonSource struct { + file string deserialized map[string]interface{} } diff --git a/altsrc/map_input_source.go b/altsrc/map_input_source.go index 66a29b6251..37709c77ed 100644 --- a/altsrc/map_input_source.go +++ b/altsrc/map_input_source.go @@ -12,6 +12,7 @@ import ( // MapInputSource implements InputSourceContext to return // data from the map that is loaded. type MapInputSource struct { + file string valueMap map[interface{}]interface{} } @@ -39,6 +40,11 @@ func nestedVal(name string, tree map[interface{}]interface{}) (interface{}, bool return nil, false } +// Source returns the path of the source file +func (fsm *MapInputSource) Source() string { + return fsm.file +} + // Int returns an int from the map if it exists otherwise returns 0 func (fsm *MapInputSource) Int(name string) (int, error) { otherGenericValue, exists := fsm.valueMap[name] diff --git a/altsrc/toml_file_loader.go b/altsrc/toml_file_loader.go index 423e5ff808..1cb2d7b9af 100644 --- a/altsrc/toml_file_loader.go +++ b/altsrc/toml_file_loader.go @@ -86,7 +86,7 @@ func NewTomlSourceFromFile(file string) (InputSourceContext, error) { if err := readCommandToml(tsc.FilePath, &results); err != nil { return nil, fmt.Errorf("Unable to load TOML file '%s': inner error: \n'%v'", tsc.FilePath, err.Error()) } - return &MapInputSource{valueMap: results.Map}, nil + return &MapInputSource{file: file, valueMap: results.Map}, nil } // NewTomlSourceFromFlagFunc creates a new TOML InputSourceContext from a provided flag name and source context. diff --git a/altsrc/yaml_file_loader.go b/altsrc/yaml_file_loader.go index 4c0060b279..37c8d9c7ca 100644 --- a/altsrc/yaml_file_loader.go +++ b/altsrc/yaml_file_loader.go @@ -32,7 +32,7 @@ func NewYamlSourceFromFile(file string) (InputSourceContext, error) { return nil, fmt.Errorf("Unable to load Yaml file '%s': inner error: \n'%v'", ysc.FilePath, err.Error()) } - return &MapInputSource{valueMap: results}, nil + return &MapInputSource{file: file, valueMap: results}, nil } // NewYamlSourceFromFlagFunc creates a new Yaml InputSourceContext from a provided flag name and source context. diff --git a/context_test.go b/context_test.go index edfbaee11d..0509488e20 100644 --- a/context_test.go +++ b/context_test.go @@ -109,6 +109,17 @@ func TestContext_String(t *testing.T) { expect(t, c.String("top-flag"), "hai veld") } +func TestContext_Path(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.String("path", "path/to/file", "path to file") + parentSet := flag.NewFlagSet("test", 0) + parentSet.String("top-path", "path/to/top/file", "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) + expect(t, c.Path("path"), "path/to/file") + expect(t, c.Path("top-path"), "path/to/top/file") +} + func TestContext_Bool(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") diff --git a/flag-types.json b/flag-types.json index f609e07b80..bd5ec3f70b 100644 --- a/flag-types.json +++ b/flag-types.json @@ -68,6 +68,12 @@ "context_default": "\"\"", "parser": "f.Value.String(), error(nil)" }, + { + "name": "Path", + "type": "string", + "context_default": "\"\"", + "parser": "f.Value.String(), error(nil)" + }, { "name": "StringSlice", "type": "*StringSlice", diff --git a/flag.go b/flag.go index d5ab42e981..364362ccf2 100644 --- a/flag.go +++ b/flag.go @@ -488,6 +488,33 @@ func (f *StringFlag) ApplyWithError(set *flag.FlagSet) error { return nil } +// Apply populates the flag given the flag set and environment +// Ignores errors +func (f *PathFlag) Apply(set *flag.FlagSet) { + f.ApplyWithError(set) +} + +// ApplyWithError populates the flag given the flag set and environment +func (f *PathFlag) ApplyWithError(set *flag.FlagSet) error { + if f.EnvVars != nil { + for _, envVar := range f.EnvVars { + if envVal, ok := syscall.Getenv(envVar); ok { + f.Value = envVal + break + } + } + } + + for _, name := range f.Names() { + if f.Destination != nil { + set.StringVar(f.Destination, name, f.Value, f.Usage) + continue + } + set.String(name, f.Value, f.Usage) + } + return nil +} + // Apply populates the flag given the flag set and environment // Ignores errors func (f *IntFlag) Apply(set *flag.FlagSet) { diff --git a/flag_generated.go b/flag_generated.go index 187a6cae47..a95815012e 100644 --- a/flag_generated.go +++ b/flag_generated.go @@ -450,6 +450,51 @@ func lookupString(name string, set *flag.FlagSet) string { return "" } +// PathFlag is a flag with type string +type PathFlag struct { + Name string + Aliases []string + Usage string + EnvVars []string + Hidden bool + Value string + DefaultText string + + Destination *string +} + +// String returns a readable representation of this value +// (for usage defaults) +func (f *PathFlag) String() string { + return FlagStringer(f) +} + +// Names returns the names of the flag +func (f *PathFlag) Names() []string { + return flagNames(f) +} + +// Path looks up the value of a local PathFlag, returns +// "" if not found +func (c *Context) Path(name string) string { + if fs := lookupFlagSet(name, c); fs != nil { + return lookupPath(name, fs) + } + return "" +} + +func lookupPath(name string, set *flag.FlagSet) string { + f := set.Lookup(name) + if f != nil { + parsed, err := f.Value.String(), error(nil) + if err != nil { + return "" + } + return parsed + } + return "" +} + // StringSliceFlag is a flag with type *StringSlice type StringSliceFlag struct { Name string diff --git a/flag_test.go b/flag_test.go index 2c42176430..6ed3a76ed4 100644 --- a/flag_test.go +++ b/flag_test.go @@ -96,6 +96,7 @@ func TestFlagsFromEnv(t *testing.T) { {"foobar", newSetInt64Slice(), &Int64SliceFlag{Name: "seconds", EnvVars: []string{"SECONDS"}}, `could not parse "foobar" as int64 slice value for flag seconds: .*`}, {"foo", "foo", &StringFlag{Name: "name", EnvVars: []string{"NAME"}}, ""}, + {"path", "path", &PathFlag{Name: "path", EnvVars: []string{"PATH"}}, ""}, {"foo,bar", newSetStringSlice("foo", "bar"), &StringSliceFlag{Name: "names", EnvVars: []string{"NAMES"}}, ""}, @@ -206,6 +207,56 @@ func TestStringFlagApply_SetsAllNames(t *testing.T) { expect(t, v, "YUUUU") } +var pathFlagTests = []struct { + name string + aliases []string + usage string + value string + expected string +}{ + {"f", nil, "", "", "-f value\t"}, + {"f", nil, "Path is the `path` of file", "/path/to/file", "-f path\tPath is the path of file (default: \"/path/to/file\")"}, +} + +func TestPathFlagHelpOutput(t *testing.T) { + for _, test := range pathFlagTests { + flag := &PathFlag{Name: test.name, Aliases: test.aliases, Usage: test.usage, Value: test.value} + output := flag.String() + + if output != test.expected { + t.Errorf("%q does not match %q", output, test.expected) + } + } +} + +func TestPathFlagWithEnvVarHelpOutput(t *testing.T) { + clearenv() + os.Setenv("APP_PATH", "/path/to/file") + for _, test := range pathFlagTests { + flag := &PathFlag{Name: test.name, Aliases: test.aliases, Value: test.value, EnvVars: []string{"APP_PATH"}} + output := flag.String() + + expectedSuffix := " [$APP_PATH]" + if runtime.GOOS == "windows" { + expectedSuffix = " [%APP_PATH%]" + } + if !strings.HasSuffix(output, expectedSuffix) { + t.Errorf("%s does not end with"+expectedSuffix, output) + } + } +} + +func TestPathFlagApply_SetsAllNames(t *testing.T) { + v := "mmm" + fl := PathFlag{Name: "path", Aliases: []string{"p", "PATH"}, Destination: &v} + set := flag.NewFlagSet("test", 0) + fl.Apply(set) + + err := set.Parse([]string{"--path", "/path/to/file/path", "-p", "/path/to/file/p", "--PATH", "/path/to/file/PATH"}) + expect(t, err, nil) + expect(t, v, "/path/to/file/PATH") +} + var stringSliceFlagTests = []struct { name string aliases []string