Skip to content

Commit

Permalink
Add exported functions to preserve pkg/flag compatibility (#220)
Browse files Browse the repository at this point in the history
* Rename out() to Output()

This brings behavior inline with go's flag library, and allows for
printing output directly to whatever the current FlagSet is using for
output. This change will make it easier to correctly emit output to
stdout or stderr (e.g. a user has requested a help screen, which
should emit to stdout since it's the desired outcome).

* improve compat. with pkg/flag by adding Name()

pkg/flag has a public `Name()` function, which returns the name of the
flag set when called. This commit adds that function, as well as a
test for it.

* Streamline testing Name()

Testing `Name()` will move into its own explicit test, instead of
running inline during `TestAddFlagSet()`.

Co-authored-by: Chloe Kudryavtsev <toast@toast.cafe>

Co-authored-by: Chloe Kudryavtsev <toast@toast.cafe>
  • Loading branch information
mckern and Chloe Kudryavtsev committed May 4, 2020
1 parent 2e9d26c commit 81378bb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
29 changes: 18 additions & 11 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ type FlagSet struct {
args []string // arguments after flags
argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor
output io.Writer // nil means stderr; use Output() accessor
interspersed bool // allow interspersed option/non-option args
normalizeNameFunc func(f *FlagSet, name string) NormalizedName

Expand Down Expand Up @@ -255,13 +255,20 @@ func (f *FlagSet) normalizeFlagName(name string) NormalizedName {
return n(f, name)
}

func (f *FlagSet) out() io.Writer {
// Output returns the destination for usage and error messages. os.Stderr is returned if
// output was not set or was set to nil.
func (f *FlagSet) Output() io.Writer {
if f.output == nil {
return os.Stderr
}
return f.output
}

// Name returns the name of the flag set.
func (f *FlagSet) Name() string {
return f.name
}

// SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used.
func (f *FlagSet) SetOutput(output io.Writer) {
Expand Down Expand Up @@ -358,7 +365,7 @@ func (f *FlagSet) ShorthandLookup(name string) *Flag {
}
if len(name) > 1 {
msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name)
fmt.Fprintf(f.out(), msg)
fmt.Fprintf(f.Output(), msg)
panic(msg)
}
c := name[0]
Expand Down Expand Up @@ -482,7 +489,7 @@ func (f *FlagSet) Set(name, value string) error {
}

if flag.Deprecated != "" {
fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
fmt.Fprintf(f.Output(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
}
return nil
}
Expand Down Expand Up @@ -523,7 +530,7 @@ func Set(name, value string) error {
// otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() {
usages := f.FlagUsages()
fmt.Fprint(f.out(), usages)
fmt.Fprint(f.Output(), usages)
}

// defaultIsZeroValue returns true if the default value for this flag represents
Expand Down Expand Up @@ -758,7 +765,7 @@ func PrintDefaults() {

// defaultUsage is the default function to print a usage message.
func defaultUsage(f *FlagSet) {
fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name)
f.PrintDefaults()
}

Expand Down Expand Up @@ -844,7 +851,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
_, alreadyThere := f.formal[normalizedFlagName]
if alreadyThere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg)
fmt.Fprintln(f.Output(), msg)
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
Expand All @@ -860,7 +867,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
}
if len(flag.Shorthand) > 1 {
msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand)
fmt.Fprintf(f.out(), msg)
fmt.Fprintf(f.Output(), msg)
panic(msg)
}
if f.shorthands == nil {
Expand All @@ -870,7 +877,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
used, alreadyThere := f.shorthands[c]
if alreadyThere {
msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name)
fmt.Fprintf(f.out(), msg)
fmt.Fprintf(f.Output(), msg)
panic(msg)
}
f.shorthands[c] = flag
Expand Down Expand Up @@ -909,7 +916,7 @@ func VarP(value Value, name, shorthand, usage string) {
func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...)
if f.errorHandling != ContinueOnError {
fmt.Fprintln(f.out(), err)
fmt.Fprintln(f.Output(), err)
f.usage()
}
return err
Expand Down Expand Up @@ -1060,7 +1067,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
}

if flag.ShorthandDeprecated != "" {
fmt.Fprintf(f.out(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
fmt.Fprintf(f.Output(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
}

err = fn(flag, value)
Expand Down
21 changes: 21 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ func TestAnnotation(t *testing.T) {
}
}

func TestName(t *testing.T) {
flagSetName := "bob"
f := NewFlagSet(flagSetName, ContinueOnError)

givenName := f.Name()
if givenName != flagSetName {
t.Errorf("Unexpected result when retrieving a FlagSet's name: expected %s, but found %s", flagSetName, givenName)
}
}

func testParse(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
Expand Down Expand Up @@ -854,6 +864,17 @@ func TestSetOutput(t *testing.T) {
}
}

func TestOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer
expect := "an example string"
flags.SetOutput(&buf)
fmt.Fprint(flags.Output(), expect)
if out := buf.String(); !strings.Contains(out, expect) {
t.Errorf("expected output %q; got %q", expect, out)
}
}

// This tests that one can reset the flags. This still works but not well, and is
// superseded by FlagSet.
func TestChangingArgs(t *testing.T) {
Expand Down

0 comments on commit 81378bb

Please sign in to comment.