Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor flag handling logic #691

Merged
merged 3 commits into from Dec 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
134 changes: 52 additions & 82 deletions command.go
@@ -1,6 +1,7 @@
package cli

import (
"flag"
"fmt"
"io/ioutil"
"sort"
Expand Down Expand Up @@ -110,43 +111,7 @@ func (c Command) Run(ctx *Context) (err error) {
)
}

set, err := flagSet(c.Name, c.Flags)
if err != nil {
return err
}
set.SetOutput(ioutil.Discard)
firstFlagIndex, terminatorIndex := getIndexes(ctx)
flagArgs, regularArgs := getAllArgs(ctx.Args(), firstFlagIndex, terminatorIndex)
if c.UseShortOptionHandling {
flagArgs = translateShortOptions(flagArgs)
}
if c.SkipFlagParsing {
err = set.Parse(append([]string{"--"}, ctx.Args().Tail()...))
} else if !c.SkipArgReorder {
if firstFlagIndex > -1 {
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
} else if c.UseShortOptionHandling {
if terminatorIndex == -1 && firstFlagIndex > -1 {
// Handle shortname AND no options
err = set.Parse(append(regularArgs, flagArgs...))
} else {
// Handle shortname and options
err = set.Parse(flagArgs)
}
} else {
err = set.Parse(append(regularArgs, flagArgs...))
}

nerr := normalizeFlags(c.Flags, set)
if nerr != nil {
fmt.Fprintln(ctx.App.Writer, nerr)
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
return nerr
}
set, err := c.parseFlags(ctx.Args().Tail())

context := NewContext(ctx.App, set, ctx)
context.Command = c
Expand Down Expand Up @@ -205,60 +170,65 @@ func (c Command) Run(ctx *Context) (err error) {
return err
}

func getIndexes(ctx *Context) (int, int) {
firstFlagIndex := -1
terminatorIndex := -1
for index, arg := range ctx.Args() {
if arg == "--" {
terminatorIndex = index
break
} else if arg == "-" {
// Do nothing. A dash alone is not really a flag.
continue
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
firstFlagIndex = index
}
func (c *Command) parseFlags(args Args) (*flag.FlagSet, error) {
set, err := flagSet(c.Name, c.Flags)
if err != nil {
return nil, err
}
if len(ctx.Args()) > 0 && !strings.HasPrefix(ctx.Args()[0], "-") && firstFlagIndex == -1 {
return -1, -1
set.SetOutput(ioutil.Discard)

if c.SkipFlagParsing {
return set, set.Parse(append([]string{c.Name, "--"}, args...))
}

return firstFlagIndex, terminatorIndex
if c.UseShortOptionHandling {
args = translateShortOptions(args)
}

}
if !c.SkipArgReorder {
args = reorderArgs(args)
}

// copyStringslice takes a string slice and copies it
func copyStringSlice(slice []string, start, end int) []string {
newSlice := make([]string, end-start)
copy(newSlice, slice[start:end])
return newSlice
}
err = set.Parse(args)
if err != nil {
return nil, err
}

// getAllArgs extracts and returns two string slices representing
// regularArgs and flagArgs
func getAllArgs(args []string, firstFlagIndex, terminatorIndex int) ([]string, []string) {
var regularArgs []string
// if there are no options, the we set the index to 1 manually
if firstFlagIndex == -1 {
firstFlagIndex = 1
regularArgs = copyStringSlice(args, 0, len(args))
} else {
regularArgs = copyStringSlice(args, 1, firstFlagIndex)
err = normalizeFlags(c.Flags, set)
if err != nil {
return nil, err
}
var flagArgs []string
// a flag terminatorIndex was found in the input. we need to collect
// flagArgs based on it.
if terminatorIndex > -1 {
flagArgs = copyStringSlice(args, firstFlagIndex, terminatorIndex)
additionalRegularArgs := copyStringSlice(args, terminatorIndex, len(args))
regularArgs = append(regularArgs, additionalRegularArgs...)
for _, i := range additionalRegularArgs {
regularArgs = append(regularArgs, i)

return set, nil
}

// reorderArgs moves all flags before arguments as this is what flag expects
func reorderArgs(args []string) []string {
var nonflags, flags []string

readFlagValue := false
for i, arg := range args {
if arg == "--" {
nonflags = append(nonflags, args[i:]...)
break
}

if readFlagValue {
readFlagValue = false
flags = append(flags, arg)
continue
}

if arg != "-" && strings.HasPrefix(arg, "-") {
flags = append(flags, arg)

readFlagValue = !strings.Contains(arg, "=")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this handle commands like

command --foo=bar argument ...

Won't argument end up being marked as a flag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, but you made me realize there wasn't an explicit test for that logic. Added in 0671b16

} else {
nonflags = append(nonflags, arg)
}
} else {
flagArgs = args[firstFlagIndex:]
}
return flagArgs, regularArgs

return append(flags, nonflags...)
}

func translateShortOptions(flagArgs Args) []string {
Expand Down
38 changes: 38 additions & 0 deletions command_test.go
Expand Up @@ -243,3 +243,41 @@ func TestCommand_Run_SubcommandsCanUseErrWriter(t *testing.T) {
t.Fatal(err)
}
}

func TestCommandFlagReordering(t *testing.T) {
cases := []struct {
testArgs []string
expectedValue string
expectedArgs []string
expectedErr error
}{
{[]string{"some-exec", "some-command", "some-arg", "--flag", "foo"}, "foo", []string{"some-arg"}, nil},
{[]string{"some-exec", "some-command", "some-arg", "--flag=foo"}, "foo", []string{"some-arg"}, nil},
{[]string{"some-exec", "some-command", "--flag=foo", "some-arg"}, "foo", []string{"some-arg"}, nil},
}

for _, c := range cases {
value := ""
args := []string{}
app := &App{
Commands: []Command{
{
Name: "some-command",
Flags: []Flag{
StringFlag{Name: "flag"},
},
Action: func(c *Context) {
fmt.Printf("%+v\n", c.String("flag"))
value = c.String("flag")
args = c.Args()
},
},
},
}

err := app.Run(c.testArgs)
expect(t, err, c.expectedErr)
expect(t, value, c.expectedValue)
expect(t, args, c.expectedArgs)
}
}