/
flag.go
368 lines (332 loc) · 8.24 KB
/
flag.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
package zli
import (
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
)
// ErrFlagUnknown is used when the flag parsing encounters unknown flags.
type ErrFlagUnknown struct{ flag string }

// Error reports the unknown flag, quoted.
func (e ErrFlagUnknown) Error() string {
	return fmt.Sprintf("unknown flag: %q", e.flag)
}

// ErrFlagDouble is used when a flag is given more than once.
type ErrFlagDouble struct{ flag string }

// Error reports the repeated flag, quoted.
func (e ErrFlagDouble) Error() string {
	return fmt.Sprintf("flag given more than once: %q", e.flag)
}

// ErrFlagInvalid is used when a flag has an invalid syntax (e.g. "no" for
// an int flag).
//
// NOTE: Parse constructs this with a positional literal, so the field order
// below matters.
type ErrFlagInvalid struct {
	flag string // The flag argument as it appeared on the command line.
	err  error  // Underlying parse error.
	kind string // Expected kind of value, e.g. "number".
}

// Error reports the flag, the parse error, and the expected kind of value.
func (e ErrFlagInvalid) Error() string {
	return fmt.Sprintf("%s: %s (must be a %s)", e.flag, e.err, e.kind)
}

// Unwrap returns the underlying parse error so errors.Is/errors.As can see it.
func (e ErrFlagInvalid) Unwrap() error { return e.err }
// Flags holds the command-line arguments and the flags registered on them.
type Flags struct {
	Program string   // Program name.
	Args    []string // List of arguments; after parsing this will be reduced to non-flags.

	flags []flagValue
}

// flagValue is one registered flag: every name it can be given as, and the
// flag* value that receives its value.
type flagValue struct {
	names []string
	value interface{}
}

// setter is implemented by all flag* value types; Set reports whether the
// flag was set while parsing.
type setter interface{ Set() bool }

// NewFlags creates a new Flags from a full argument list such as os.Args: the
// base name of the first element becomes Program and the remaining elements
// become Args.
func NewFlags(args []string) Flags {
	f := Flags{}
	if len(args) > 0 {
		f.Program = filepath.Base(args[0])
	}
	if len(args) > 1 {
		f.Args = args[1:]
	}
	return f
}

// Shift removes and returns the first value of the argument list, or "" if
// the list is empty.
func (f *Flags) Shift() string {
	if len(f.Args) == 0 {
		return ""
	}
	a := f.Args[0]
	f.Args = f.Args[1:]
	return a
}

// Sentinel return values for ShiftCommand(). They are control characters, so
// they are unlikely to collide with a real command name.
const (
	CommandNoneGiven = "\x00"
	CommandAmbiguous = "\x01"
	CommandUnknown   = "\x02"
)

// ShiftCommand shifts a value from the argument list, and matches it with the
// list of commands.
//
// Commands can be matched as an abbreviation as long as it's unambiguous; if
// you have "search" and "identify" then "i", "id", etc. will all return
// "identify". An exact match always takes precedence over abbreviations: with
// "search", "seek", and "se" the input "se" returns "se".
//
// If you have the commands "search" and "see", then "s" or "se" are ambiguous,
// and it will return the special CommandAmbiguous sentinel value.
//
// Commands can also contain aliases as "alias=cmd"; for example "ci=commit".
// Only the alias part is matched against the input, and the command after the
// "=" is returned. Abbreviations that only ever resolve to the same command
// (e.g. "c" with "commit" and "ci=commit") are not ambiguous.
//
// The input is lower-cased before matching. It will return CommandNoneGiven if
// there is no command, and CommandUnknown if the command is not found.
func (f *Flags) ShiftCommand(cmds ...string) string {
	cmd := f.Shift()
	if cmd == "" {
		return CommandNoneGiven
	}
	cmd = strings.ToLower(cmd)

	var (
		found     string
		ambiguous bool
	)
	for _, c := range cmds {
		// name is what the user types, target is what we return.
		name, target := c, c
		if i := strings.IndexByte(c, '='); i > -1 {
			name, target = c[:i], c[i+1:]
		}
		if name == cmd {
			return target // Exact match wins, regardless of position.
		}
		if !strings.HasPrefix(name, cmd) {
			continue
		}
		// Keep scanning after an ambiguity: a later exact match may still win.
		if found != "" && found != target {
			ambiguous = true
		}
		found = target
	}

	switch {
	case ambiguous:
		return CommandAmbiguous
	case found == "":
		return CommandUnknown
	default:
		return found
	}
}
// Parse the flags in f.Args and store their values in the flags registered with
// Bool(), String(), Int(), etc.
//
// Grouped single-letter flags are expanded first, so "prog -ab" is the same as
// "prog -a -b". Flag values can be given as "-flag value" or "-flag=value".
// Boolean flags don't consume the next argument, but accept an explicit
// "-flag=value" using any value strconv.ParseBool accepts (e.g. "-v=false").
// Arguments that don't start with "-", a lone "-", and everything after "--"
// are positional; after a successful Parse f.Args contains only these.
//
// Errors returned:
//   - *ErrFlagUnknown if a flag isn't registered;
//   - *ErrFlagDouble if a flag that takes a single value is given twice;
//   - ErrFlagInvalid if a value can't be parsed;
//   - an error ending in "needs an argument" if a flag's value is missing.
func (f *Flags) Parse() error {
	if err := f.expandGroupedFlags(); err != nil {
		return err
	}

	var (
		p    []string
		skip bool // The current argument was consumed as the previous flag's value.
	)
	for i, a := range f.Args {
		if skip {
			skip = false
			continue
		}
		// "-" conventionally means stdin/stdout, so treat it as positional.
		if a == "" || a == "-" || a[0] != '-' {
			p = append(p, a)
			continue
		}
		if a == "--" {
			p = append(p, f.Args[i+1:]...)
			break
		}

		flag, ok := f.match(a)
		if !ok {
			return &ErrFlagUnknown{a}
		}

		// err is set by next() when the value is missing. It's checked
		// before any value is parsed, so users see "needs an argument"
		// rather than a confusing parse error for the empty string.
		var err error
		next := func() (string, bool) {
			if j := strings.IndexByte(a, '='); j > -1 {
				return a[j+1:], true
			}
			if i >= len(f.Args)-1 {
				err = errors.New("needs an argument")
				return "", false
			}
			skip = true
			return f.Args[i+1], true
		}
		invalid := func(perr error, kind string) error {
			// Unwrap *strconv.NumError to just ErrSyntax/ErrRange; the flag
			// name is already part of ErrFlagInvalid.
			if nErr := errors.Unwrap(perr); nErr != nil {
				perr = nErr
			}
			return ErrFlagInvalid{a, perr, kind}
		}

		// TODO: it might make more sense to have two interfaces: singleSetter
		// and multiSetter.
		if set := flag.value.(setter); set.Set() {
			switch flag.value.(type) {
			case flagIntCounter, flagStringList, flagBool: // Not an error.
			default:
				return &ErrFlagDouble{a}
			}
		}

		var val string
		switch v := flag.value.(type) {
		case flagBool:
			b := true
			if j := strings.IndexByte(a, '='); j > -1 {
				var perr error
				if b, perr = strconv.ParseBool(a[j+1:]); perr != nil {
					return invalid(perr, "boolean")
				}
			}
			*v.s = true
			*v.v = b
		case flagString:
			*v.v, *v.s = next()
		case flagInt:
			val, *v.s = next()
			if err != nil {
				break // Missing value; reported after the switch.
			}
			// Parse with the platform's int size so values that don't fit
			// are rejected instead of silently truncated.
			x, perr := strconv.ParseInt(val, 0, strconv.IntSize)
			if perr != nil {
				return invalid(perr, "number")
			}
			*v.v = int(x)
		case flagInt64:
			val, *v.s = next()
			if err != nil {
				break // Missing value; reported after the switch.
			}
			x, perr := strconv.ParseInt(val, 0, 64)
			if perr != nil {
				return invalid(perr, "number")
			}
			*v.v = x
		case flagFloat64:
			val, *v.s = next()
			if err != nil {
				break // Missing value; reported after the switch.
			}
			x, perr := strconv.ParseFloat(val, 64)
			if perr != nil {
				return invalid(perr, "number")
			}
			*v.v = x
		case flagIntCounter:
			*v.s = true
			*v.v++
		case flagStringList:
			val, *v.s = next()
			if err != nil {
				break // Missing value; reported after the switch.
			}
			*v.v = append(*v.v, val)
		}
		if err != nil {
			return fmt.Errorf("%s: %w", a, err)
		}
	}
	f.Args = p
	return nil
}

// expandGroupedFlags rewrites f.Args so grouped single-letter flags are split:
// "prog -ab" becomes "prog -a -b". An argument is only split if it isn't a
// registered flag itself; if any of its letters isn't a registered flag it
// returns *ErrFlagUnknown and leaves f.Args unchanged. Everything from "--"
// onwards is copied verbatim.
func (f *Flags) expandGroupedFlags() error {
	args := make([]string, 0, len(f.Args))
	for i, arg := range f.Args {
		if arg == "--" {
			args = append(args, f.Args[i:]...)
			break
		}
		if !strings.HasPrefix(arg, "-") || len(strings.TrimLeft(arg, "-")) <= 1 {
			args = append(args, arg)
			continue
		}
		if _, ok := f.match(arg); ok {
			args = append(args, arg)
			continue
		}

		split := strings.Split(arg[1:], "")
		for _, s := range split {
			if _, ok := f.match(s); !ok {
				return &ErrFlagUnknown{arg}
			}
		}
		for _, s := range split {
			args = append(args, "-"+s)
		}
	}
	f.Args = args
	return nil
}
// match looks up the registered flag for a command-line argument. All leading
// dashes are stripped first; the remainder must equal one of the flag's names
// exactly, optionally followed by "=value" (so "-n=5" and "--n" both match the
// flag "n"). It returns the zero flagValue and false if nothing matches.
func (f *Flags) match(arg string) (flagValue, bool) {
	bare := strings.TrimLeft(arg, "-")
	for i := range f.flags {
		for _, n := range f.flags[i].names {
			exact := bare == n
			withValue := strings.HasPrefix(bare, n+"=")
			if exact || withValue {
				return f.flags[i], true
			}
		}
	}
	return flagValue{}, false
}
// The flag* types are the handles returned when registering a flag. Each one
// holds a pointer to the flag's value (v) and a pointer to a bool (s) that is
// set while parsing. Pointer returns the value pointer, the type-named
// accessor dereferences it, and Set reports s.

// flagBool is a boolean flag.
type flagBool struct {
	v *bool
	s *bool
}

func (f flagBool) Pointer() *bool { return f.v }
func (f flagBool) Bool() bool     { return *f.v }
func (f flagBool) Set() bool      { return *f.s }

// flagString is a string flag.
type flagString struct {
	v *string
	s *bool
}

func (f flagString) Pointer() *string { return f.v }
func (f flagString) String() string   { return *f.v }
func (f flagString) Set() bool        { return *f.s }

// flagInt is an int flag.
type flagInt struct {
	v *int
	s *bool
}

func (f flagInt) Pointer() *int { return f.v }
func (f flagInt) Int() int      { return *f.v }
func (f flagInt) Set() bool     { return *f.s }

// flagInt64 is an int64 flag.
type flagInt64 struct {
	v *int64
	s *bool
}

func (f flagInt64) Pointer() *int64 { return f.v }
func (f flagInt64) Int64() int64    { return *f.v }
func (f flagInt64) Set() bool       { return *f.s }

// flagFloat64 is a float64 flag.
type flagFloat64 struct {
	v *float64
	s *bool
}

func (f flagFloat64) Pointer() *float64 { return f.v }
func (f flagFloat64) Float64() float64  { return *f.v }
func (f flagFloat64) Set() bool         { return *f.s }

// flagIntCounter is an int flag that counts how often it was given.
type flagIntCounter struct {
	v *int
	s *bool
}

func (f flagIntCounter) Pointer() *int { return f.v }
func (f flagIntCounter) Int() int      { return *f.v }
func (f flagIntCounter) Set() bool     { return *f.s }

// flagStringList is a flag that collects every value it was given.
type flagStringList struct {
	v *[]string
	s *bool
}

func (f flagStringList) Pointer() *[]string { return f.v }
func (f flagStringList) Strings() []string  { return *f.v }
func (f flagStringList) Set() bool          { return *f.s }
// append registers the flag value v under the name n and any aliases a.
func (f *Flags) append(v interface{}, n string, a ...string) {
	f.flags = append(f.flags, flagValue{value: v, names: append([]string{n}, a...)})
}

// Bool adds a new boolean flag with the default value def, which can be given
// on the command line as name or any of the aliases.
func (f *Flags) Bool(def bool, name string, aliases ...string) flagBool {
	v := flagBool{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// String adds a new string flag with the default value def, which can be given
// on the command line as name or any of the aliases.
func (f *Flags) String(def, name string, aliases ...string) flagString {
	v := flagString{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// Int adds a new int flag with the default value def, which can be given on
// the command line as name or any of the aliases.
func (f *Flags) Int(def int, name string, aliases ...string) flagInt {
	v := flagInt{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// Int64 adds a new int64 flag with the default value def, which can be given
// on the command line as name or any of the aliases.
func (f *Flags) Int64(def int64, name string, aliases ...string) flagInt64 {
	v := flagInt64{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// Float64 adds a new float64 flag with the default value def, which can be
// given on the command line as name or any of the aliases.
func (f *Flags) Float64(def float64, name string, aliases ...string) flagFloat64 {
	v := flagFloat64{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// IntCounter adds a flag that counts how often it's given: every occurrence
// increments the value, starting from def.
func (f *Flags) IntCounter(def int, name string, aliases ...string) flagIntCounter {
	v := flagIntCounter{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}

// StringList adds a flag that can be given more than once: every occurrence
// appends its value to the list, starting from def. The caller's def slice is
// never modified.
func (f *Flags) StringList(def []string, name string, aliases ...string) flagStringList {
	// Cap the capacity at the length, so appending values during parsing
	// always copies instead of writing into the caller's backing array (which
	// may be shared with other StringList flags).
	def = def[:len(def):len(def)]
	v := flagStringList{v: &def, s: new(bool)}
	f.append(v, name, aliases...)
	return v
}