diff --git a/plumbum/cli/switches.py b/plumbum/cli/switches.py index e6db1bb9..3321d778 100644 --- a/plumbum/cli/switches.py +++ b/plumbum/cli/switches.py @@ -464,10 +464,16 @@ class MyApp(Application): comparison or not. The default is ``False`` :param csv: splits the input as a comma-separated-value before validating and returning a list. Accepts ``True``, ``False``, or a string for the separator - """ + :param all_markers: when a user inputs any value from this set, it is considered that + he requested iteration over all values. By default `all_markers` + are {"*", "all"}.""" def __init__(self, *values, **kwargs): self.case_sensitive = kwargs.pop("case_sensitive", False) + all_markers = kwargs.pop("all_markers", None) + if all_markers is None: + all_markers = {"*", "all"} - set(values) + self.all_markers = frozenset(all_markers) self.csv = kwargs.pop("csv", False) if self.csv is True: self.csv = "," @@ -475,7 +481,30 @@ def __init__(self, *values, **kwargs): raise TypeError( _("got unexpected keyword argument(s): {0}").format(kwargs.keys()) ) - self.values = values + + str_values = [] + numeric_values = [] + non_primitive_values = [] + for opt in values: + if isinstance(opt, str): + if not self.case_sensitive: + opt = opt.lower() + str_values.append(opt) + elif isinstance(opt, (int, float)): + numeric_values.append(opt) + else: + non_primitive_values.append(opt) + self.str_values = frozenset(str_values) + self.numeric_values = frozenset(numeric_values) + self.non_primitive_values = tuple(non_primitive_values) + + @property + def primitive_values(self): + return self.str_values | self.numeric_values + + @property + def values(self): + return tuple(self.primitive_values) + self.non_primitive_values def __repr__(self): return "{{{0}}}".format( @@ -484,16 +513,29 @@ def __repr__(self): def __call__(self, value, check_csv=True): if self.csv and check_csv: - return [self(v.strip(), check_csv=False) for v in value.split(",")] + vals = value.split(",") + res = [] + for v in vals: + if v in self.all_markers: + res.extend(self.primitive_values) + else: + res.append(self(v.strip(), check_csv=False)) + + return res + if not self.case_sensitive: value = value.lower() - for opt in self.values: - if isinstance(opt, str): - if not self.case_sensitive: - opt = opt.lower() - if opt == value: - return opt # always return original value - continue + + try: + if value in self.str_values: + return value + + if value in self.numeric_values: + return value + except TypeError: + pass + + for opt in self.non_primitive_values: try: return opt(value) except ValueError: @@ -503,11 +545,16 @@ def __call__(self, value, check_csv=True): ) def choices(self, partial=""): - choices = { - opt if isinstance(opt, str) else "({})".format(opt) for opt in self.values - } + choices = ( + self.all_markers + | self.str_values + | {"({})".format(opt) for opt in self.numeric_values} + | {"({})".format(opt) for opt in self.primitive_values} + ) + if partial: choices = {opt for opt in choices if opt.lower().startswith(partial)} + return choices diff --git a/tests/test_cli.py b/tests/test_cli.py index bb32c5e4..e3193d50 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -197,6 +197,9 @@ def test_okay(self): _, rc = SimpleApp.run(["foo", "--bacon=81", "--csv=MAX,MIN,100"], exit=False) assert rc == 0 + _, rc = SimpleApp.run(["foo", "--bacon=81", "--csv=all,100"], exit=False) + assert rc == 0 + _, rc = SimpleApp.run(["foo", "--bacon=81", "--num=100"], exit=False) assert rc == 0