From 14366f7030b89bce4ccd084eecfd5f52b8a96c7c Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Tue, 15 Feb 2022 23:49:41 +0800 Subject: [PATCH] feat: flag action --- app.go | 26 +++++++++++++++++ app_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++ command.go | 4 +++ flag.go | 1 + flag_bool.go | 9 ++++++ flag_duration.go | 9 ++++++ flag_float64.go | 9 ++++++ flag_float64_slice.go | 9 ++++++ flag_generic.go | 9 ++++++ flag_int.go | 9 ++++++ flag_int64.go | 9 ++++++ flag_int64_slice.go | 9 ++++++ flag_int_slice.go | 33 +++++++++++++++++++++ flag_path.go | 9 ++++++ flag_string.go | 9 ++++++ flag_string_slice.go | 9 ++++++ flag_timestamp.go | 9 ++++++ flag_uint.go | 9 ++++++ flag_uint64.go | 9 ++++++ 19 files changed, 257 insertions(+) diff --git a/app.go b/app.go index 2ffacd512c..2a291ae40a 100644 --- a/app.go +++ b/app.go @@ -342,6 +342,10 @@ func (a *App) RunContext(ctx context.Context, arguments []string) (err error) { } } + if err = runFlagActions(cCtx, a.Flags); err != nil { + return err + } + var c *Command args := cCtx.Args() if args.Present() { @@ -523,6 +527,10 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { } } + if err = runFlagActions(cCtx, a.Flags); err != nil { + return err + } + args := cCtx.Args() if args.Present() { name := args.First() @@ -646,6 +654,24 @@ func (a *App) argsWithDefaultCommand(oldArgs Args) Args { return oldArgs } +func runFlagActions(c *Context, fs []Flag) error { + for _, f := range fs { + isSet := false + for _, name := range f.Names() { + if c.IsSet(name) { + isSet = true + break + } + } + if isSet { + if err := f.RunAction(c); err != nil { + return err + } + } + } + return nil +} + // Author represents someone who has contributed to a cli project. type Author struct { Name string // The Authors name diff --git a/app_test.go b/app_test.go index 64316fc79e..9cd067eebc 100644 --- a/app_test.go +++ b/app_test.go @@ -2357,6 +2357,10 @@ func (c *customBoolFlag) Apply(set *flag.FlagSet) error { return nil } +func (c *customBoolFlag) RunAction(*Context) error { + return nil +} + func (c *customBoolFlag) IsSet() bool { return false } @@ -2576,3 +2580,66 @@ func TestSetupInitializesOnlyNilWriters(t *testing.T) { t.Errorf("expected a.Writer to be os.Stdout") } } + +func TestFlagAction(t *testing.T) { + r := []string{} + actionFunc := func(c *Context, s string) error { + r = append(r, s) + return nil + } + + app := &App{ + Name: "command", + Writer: io.Discard, + Flags: []Flag{&StringFlag{Name: "flag", Action: actionFunc}}, + Commands: []*Command{ + { + Name: "command1", + Flags: []Flag{&StringFlag{Name: "flag1", Aliases: []string{"f1"}, Action: actionFunc}}, + Subcommands: []*Command{ + { + Name: "command2", + Flags: []Flag{&StringFlag{Name: "flag2", Action: actionFunc}}, + }, + }, + }, + }, + } + + tests := []struct { + args []string + exp []string + }{ + { + args: []string{"command", "--flag=f"}, + exp: []string{"f"}, + }, + { + args: []string{"command", "command1", "-f1=f1", "command2"}, + exp: []string{"f1"}, + }, + { + args: []string{"command", "command1", "-f1=f1", "command2", "--flag2=f2"}, + exp: []string{"f1", "f2"}, + }, + { + args: []string{"command", "--flag=f", "command1", "-flag1=f1"}, + exp: []string{"f", "f1"}, + }, + { + args: []string{"command", "--flag=f", "command1", "-f1=f1"}, + exp: []string{"f", "f1"}, + }, + { + args: []string{"command", "--flag=f", "command1", "-f1=f1", "command2", "--flag2=f2"}, + exp: []string{"f", "f1", "f2"}, + }, + } + + for _, test := range tests { + r = []string{} + err := app.Run(test.args) + expect(t, err, nil) + expect(t, r, test.exp) + } +} diff --git a/command.go b/command.go index 13b79de46d..cd8ea91c70 100644 --- a/command.go +++ b/command.go @@ -165,6 +165,10 @@ func (c *Command) Run(ctx *Context) (err error) { } } + if err = runFlagActions(cCtx, c.Flags); err != nil { + return err + } + if c.Action == nil { c.Action = helpSubcommand.Action } diff --git a/flag.go b/flag.go index 050bb4b1d1..a1e37fc5a0 100644 --- a/flag.go +++ b/flag.go @@ -92,6 +92,7 @@ type Flag interface { Apply(*flag.FlagSet) error Names() []string IsSet() bool + RunAction(*Context) error } // RequiredFlag is an interface that allows us to mark flags as required diff --git a/flag_bool.go b/flag_bool.go index cb937ae65d..0be27e3aa6 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -92,6 +92,15 @@ func (f *BoolFlag) GetEnvVars() []string { return f.EnvVars } +// RunAction executes flag action if set +func (f *BoolFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Bool(f.Name)) + } + + return nil +} + // Apply populates the flag given the flag set and environment func (f *BoolFlag) Apply(set *flag.FlagSet) error { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { diff --git a/flag_duration.go b/flag_duration.go index 5178c6ae12..31db4102e6 100644 --- a/flag_duration.go +++ b/flag_duration.go @@ -70,6 +70,15 @@ func (f *DurationFlag) Get(ctx *Context) time.Duration { return ctx.Duration(f.Name) } +// RunAction executes flag action if set +func (f *DurationFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Duration(f.Name)) + } + + return nil +} + // Duration looks up the value of a local DurationFlag, returns // 0 if not found func (cCtx *Context) Duration(name string) time.Duration { diff --git a/flag_float64.go b/flag_float64.go index 2d31739bc6..bce26c1958 100644 --- a/flag_float64.go +++ b/flag_float64.go @@ -70,6 +70,15 @@ func (f *Float64Flag) Get(ctx *Context) float64 { return ctx.Float64(f.Name) } +// RunAction executes flag action if set +func (f *Float64Flag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Float64(f.Name)) + } + + return nil +} + // Float64 looks up the value of a local Float64Flag, returns // 0 if not found func (cCtx *Context) Float64(name string) float64 { diff --git a/flag_float64_slice.go b/flag_float64_slice.go index e4aff73da0..2cb5e4adfa 100644 --- a/flag_float64_slice.go +++ b/flag_float64_slice.go @@ -181,6 +181,15 @@ func (f *Float64SliceFlag) stringify() string { return stringifySliceFlag(f.Usage, f.Names(), defaultVals) } +// RunAction executes flag action if set +func (f *Float64SliceFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Float64Slice(f.Name)) + } + + return nil +} + // Float64Slice looks up the value of a local Float64SliceFlag, returns // nil if not found func (cCtx *Context) Float64Slice(name string) []float64 { diff --git a/flag_generic.go b/flag_generic.go index 6a19aef36c..5034728c42 100644 --- a/flag_generic.go +++ b/flag_generic.go @@ -73,6 +73,15 @@ func (f *GenericFlag) Get(ctx *Context) interface{} { return ctx.Generic(f.Name) } +// RunAction executes flag action if set +func (f *GenericFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Generic(f.Name)) + } + + return nil +} + // Generic looks up the value of a local GenericFlag, returns // nil if not found func (cCtx *Context) Generic(name string) interface{} { diff --git a/flag_int.go b/flag_int.go index 0f5c403b3d..af98e936fb 100644 --- a/flag_int.go +++ b/flag_int.go @@ -71,6 +71,15 @@ func (f *IntFlag) Get(ctx *Context) int { return ctx.Int(f.Name) } +// RunAction executes flag action if set +func (f *IntFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Int(f.Name)) + } + + return nil +} + // Int looks up the value of a local IntFlag, returns // 0 if not found func (cCtx *Context) Int(name string) int { diff --git a/flag_int64.go b/flag_int64.go index a392275def..ebe46d21fc 100644 --- a/flag_int64.go +++ b/flag_int64.go @@ -70,6 +70,15 @@ func (f *Int64Flag) Get(ctx *Context) int64 { return ctx.Int64(f.Name) } +// RunAction executes flag action if set +func (f *Int64Flag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Int64(f.Name)) + } + + return nil +} + // Int64 looks up the value of a local Int64Flag, returns // 0 if not found func (cCtx *Context) Int64(name string) int64 { diff --git a/flag_int64_slice.go b/flag_int64_slice.go index ead4e77570..d4a11b6a81 100644 --- a/flag_int64_slice.go +++ b/flag_int64_slice.go @@ -179,6 +179,15 @@ func (f *Int64SliceFlag) stringify() string { return stringifySliceFlag(f.Usage, f.Names(), defaultVals) } +// RunAction executes flag action if set +func (f *Int64SliceFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Int64Slice(f.Name)) + } + + return nil +} + // Int64Slice looks up the value of a local Int64SliceFlag, returns // nil if not found func (cCtx *Context) Int64Slice(name string) []int64 { diff --git a/flag_int_slice.go b/flag_int_slice.go index b40e0d8d1a..68ce483650 100644 --- a/flag_int_slice.go +++ b/flag_int_slice.go @@ -92,6 +92,29 @@ func (i *IntSlice) Get() interface{} { return *i } +<<<<<<< HEAD +======= +// IntSliceFlag is a flag with type *IntSlice +type IntSliceFlag struct { + Name string + Aliases []string + Usage string + EnvVars []string + FilePath string + Required bool + Hidden bool + Value *IntSlice + DefaultText string + HasBeenSet bool + Action func(*Context, []int) error +} + +// IsSet returns whether or not the flag has been set through env or file +func (f *IntSliceFlag) IsSet() bool { + return f.HasBeenSet +} + +>>>>>>> e132f01 (feat: flag action) // String returns a readable representation of this value // (for usage defaults) func (f *IntSliceFlag) String() string { @@ -174,9 +197,19 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error { return nil } +<<<<<<< HEAD // Get returns the flag’s value in the given Context. func (f *IntSliceFlag) Get(ctx *Context) []int { return ctx.IntSlice(f.Name) +======= +// RunAction executes flag action if set +func (f *IntSliceFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.IntSlice(f.Name)) + } + + return nil +>>>>>>> e132f01 (feat: flag action) } func (f *IntSliceFlag) stringify() string { diff --git a/flag_path.go b/flag_path.go index 7c87a8900d..911819db94 100644 --- a/flag_path.go +++ b/flag_path.go @@ -67,6 +67,15 @@ func (f *PathFlag) Get(ctx *Context) string { return ctx.Path(f.Name) } +// RunAction executes flag action if set +func (f *PathFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Path(f.Name)) + } + + return nil +} + // Path looks up the value of a local PathFlag, returns // "" if not found func (cCtx *Context) Path(name string) string { diff --git a/flag_string.go b/flag_string.go index c8da38f92d..b7163ba6f6 100644 --- a/flag_string.go +++ b/flag_string.go @@ -65,6 +65,15 @@ func (f *StringFlag) Get(ctx *Context) string { return ctx.String(f.Name) } +// RunAction executes flag action if set +func (f *StringFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.String(f.Name)) + } + + return nil +} + // String looks up the value of a local StringFlag, returns // "" if not found func (cCtx *Context) String(name string) string { diff --git a/flag_string_slice.go b/flag_string_slice.go index 9d69342db1..7b46a24742 100644 --- a/flag_string_slice.go +++ b/flag_string_slice.go @@ -171,6 +171,15 @@ func (f *StringSliceFlag) stringify() string { return stringifySliceFlag(f.Usage, f.Names(), defaultVals) } +// RunAction executes flag action if set +func (f *StringSliceFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.StringSlice(f.Name)) + } + + return nil +} + // StringSlice looks up the value of a local StringSliceFlag, returns // nil if not found func (cCtx *Context) StringSlice(name string) []string { diff --git a/flag_timestamp.go b/flag_timestamp.go index 16f42dd011..17bc8d7571 100644 --- a/flag_timestamp.go +++ b/flag_timestamp.go @@ -148,6 +148,15 @@ func (f *TimestampFlag) Get(ctx *Context) *time.Time { return ctx.Timestamp(f.Name) } +// RunAction executes flag action if set +func (f *TimestampFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Timestamp(f.Name)) + } + + return nil +} + // Timestamp gets the timestamp from a flag name func (cCtx *Context) Timestamp(name string) *time.Time { if fs := cCtx.lookupFlagSet(name); fs != nil { diff --git a/flag_uint.go b/flag_uint.go index d25ff73ad7..f9acb6d06c 100644 --- a/flag_uint.go +++ b/flag_uint.go @@ -46,6 +46,15 @@ func (f *UintFlag) Apply(set *flag.FlagSet) error { return nil } +// RunAction executes flag action if set +func (f *UintFlag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Uint(f.Name)) + } + + return nil +} + // GetValue returns the flags value as string representation and an empty // string if the flag takes no value at all. func (f *UintFlag) GetValue() string { diff --git a/flag_uint64.go b/flag_uint64.go index 975c73393b..09590d7353 100644 --- a/flag_uint64.go +++ b/flag_uint64.go @@ -46,6 +46,15 @@ func (f *Uint64Flag) Apply(set *flag.FlagSet) error { return nil } +// RunAction executes flag action if set +func (f *Uint64Flag) RunAction(c *Context) error { + if f.Action != nil { + return f.Action(c, c.Uint64(f.Name)) + } + + return nil +} + // GetValue returns the flags value as string representation and an empty // string if the flag takes no value at all. func (f *Uint64Flag) GetValue() string {