Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add ability to mark flags as required or exclusive as a group
This change adds two features for dealing with flags:
 - requiring flags be provided as a group (or not at all)
 - requiring flags be mutually exclusive of each other

By utilizing the flag annotations we can mark which flag groups
a flag is a part of and during the parsing process we track which
ones we have seen or not.

A flag may be a part of multiple groups. The list of flags and the
type of group (required together or exclusive) make it a unique group.

Signed-off-by: John Schnake <jschnake@vmware.com>
  • Loading branch information
johnSchnake committed Apr 3, 2022
1 parent d622355 commit 90f9a16
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 7 deletions.
4 changes: 4 additions & 0 deletions command.go
Expand Up @@ -863,6 +863,10 @@ func (c *Command) execute(a []string) (err error) {
if err := c.validateRequiredFlags(); err != nil {
return err
}
if err := c.validateFlagGroups(); err != nil {
return err
}

if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
return err
Expand Down
135 changes: 135 additions & 0 deletions flag_groups.go
@@ -0,0 +1,135 @@
// Copyright © 2022 Steve Francia <spf@spf13.com>.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cobra

import (
"fmt"
"sort"
"strings"

flag "github.com/spf13/pflag"
)

const (
RequiredAsGroup = "cobra_annotation_required_if_others_set"
MutuallyExclusive = "cobra_annotation_mutually_exclusive"
)

func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f.Annotations == nil {
f.Annotations = map[string][]string{}
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
f.Annotations[RequiredAsGroup] = append(f.Annotations[RequiredAsGroup], strings.Join(flagNames, " "))
}
}

func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f.Annotations == nil {
f.Annotations = map[string][]string{}
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
f.Annotations[MutuallyExclusive] = append(f.Annotations[MutuallyExclusive], strings.Join(flagNames, " "))
}
}

// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic.
func (c *Command) validateFlagGroups() error {
if c.DisableFlagParsing {
return nil
}

flags := c.Flags()

// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
groupInfo, found := pflag.Annotations[RequiredAsGroup]
if found {
// If not tracking a group, start.
for _, group := range groupInfo {
if groupStatus[group] == nil {
groupStatus[group] = map[string]bool{}
// Track each flag by name.
flagnames := strings.Split(group, " ")
for _, name := range flagnames {
groupStatus[group][name] = false
}
}

// Record we've seen this flag for each group its in.
groupStatus[group][pflag.Name] = pflag.Changed
}
}

groupInfo, found = pflag.Annotations[MutuallyExclusive]
if found {
// If not tracking a group, start.
for _, group := range groupInfo {
if mutuallyExclusiveGroupStatus[group] == nil {
mutuallyExclusiveGroupStatus[group] = map[string]bool{}
// Track each flag by name.
flagnames := strings.Split(group, " ")
for _, name := range flagnames {
mutuallyExclusiveGroupStatus[group][name] = false
}
}

// Record we've seen this flag for each group its in.
mutuallyExclusiveGroupStatus[group][pflag.Name] = pflag.Changed
}
}
})

// Now review the groups and form errors as needed.
errMsgs := []string{}
for flagList, flagnameAndStatus := range groupStatus {
unset := []string{}
for flagname, isSet := range flagnameAndStatus {
if !isSet {
unset = append(unset, flagname)
}
}
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
continue
}
sort.Strings(unset)
errMsgs = append(errMsgs, fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset))
}

for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
set := []string{}
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
}
}
if len(set) == 0 || len(set) == 1 {
continue
}
sort.Strings(set)
errMsgs = append(errMsgs, fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set))
}

if len(errMsgs) > 0 {
return fmt.Errorf(strings.Join(errMsgs, `, `))
}
return nil
}
91 changes: 91 additions & 0 deletions flag_groups_test.go
@@ -0,0 +1,91 @@
// Copyright © 2022 Steve Francia <spf@spf13.com>.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cobra

import (
"strings"
"testing"
)

func TestValidateFlagGroups(t *testing.T) {
getCmd := func() *Command {
c := &Command{
Use: "testcmd",
Run: func(cmd *Command, args []string) {
}}
// Define lots of flags to utilize for testing.
for _, v := range []string{"a", "b", "c", "d", "e", "f", "g"} {
c.Flags().String(v, "", "")
}
return c
}

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsExclusive []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
},{
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
},{
desc: "Required flag group not satisfied",
flagGroupsRequired: []string{"a b c"},
args: []string{"--a=foo"},
expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]",
},{
desc: "Exclusive flag group not satisfied",
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo","--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
},{
desc: "Multiple required flag group not satisfied",
flagGroupsRequired: []string{"a b c", "a d"},
args: []string{"--c=foo","--d=foo"},
expectErr: "if any flags in the group [a b c] are set they must all be set; missing [a b], if any flags in the group [a d] are set they must all be set; missing [a]",
},{
desc: "Multiple exclusive flag group not satisfied",
flagGroupsExclusive: []string{"a b c", "a d"},
args: []string{"testcmd","--a=foo","--c=foo","--d=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a c] were all set, if any flags in the group [a d] are set none of the others can be; [a d] were all set",
},
}
for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
c := getCmd()
for _, flagGroup := range tc.flagGroupsRequired {
c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsExclusive {
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
c.SetArgs(tc.args)
err := c.Execute()
switch {
case err == nil && len(tc.expectErr) > 0:
t.Errorf("Expected error %q but got nil", tc.expectErr)
case err == nil && len(tc.expectErr) == 0:
case err != nil && err.Error() == tc.expectErr:
case err != nil && err.Error() != tc.expectErr:
t.Errorf("Expected error %q but got %v", tc.expectErr, err)
}
})
}
}
8 changes: 1 addition & 7 deletions go.sum
Expand Up @@ -2,17 +2,11 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKY
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

0 comments on commit 90f9a16

Please sign in to comment.