diff --git a/string_to_string.go b/string_to_string.go index 64892db0..890a01af 100644 --- a/string_to_string.go +++ b/string_to_string.go @@ -22,11 +22,22 @@ func newStringToStringValue(val map[string]string, p *map[string]string) *string // Format: a=1,b=2 func (s *stringToStringValue) Set(val string) error { - r := csv.NewReader(strings.NewReader(val)) - ss, err := r.Read() - if err != nil { - return err + var ss []string + n := strings.Count(val, "=") + switch n { + case 0: + return fmt.Errorf("%s must be formatted as key=value", val) + case 1: + ss = append(ss, strings.Trim(val, `"`)) + default: + r := csv.NewReader(strings.NewReader(val)) + var err error + ss, err = r.Read() + if err != nil { + return err + } } + out := make(map[string]string, len(ss)) for _, pair := range ss { kv := strings.SplitN(pair, "=", 2) diff --git a/string_to_string_test.go b/string_to_string_test.go index f1aae042..0777f03f 100644 --- a/string_to_string_test.go +++ b/string_to_string_test.go @@ -140,16 +140,20 @@ func TestS2SCalledTwice(t *testing.T) { var s2s map[string]string f := setUpS2SFlagSet(&s2s) - in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f="7,8"`} + in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f=7,8`} expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"} argfmt := "--s2s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - arg3 := fmt.Sprintf(argfmt, in[2]) - err := f.Parse([]string{arg1, arg2, arg3}) + arg0 := fmt.Sprintf(argfmt, in[0]) + arg1 := fmt.Sprintf(argfmt, in[1]) + arg2 := fmt.Sprintf(argfmt, in[2]) + arg3 := fmt.Sprintf(argfmt, in[3]) + err := f.Parse([]string{arg0, arg1, arg2, arg3}) if err != nil { t.Fatal("expected no error; got", err) } + if len(s2s) != len(expected) { + t.Fatalf("expected %d flags; got %d flags", len(expected), len(s2s)) + } for i, v := range s2s { if expected[i] != v { t.Fatalf("expected s2s[%s] to be %s but got: %s", i, expected[i], v)