diff --git a/tests/test_enum.py b/tests/test_enum.py index b8a3b71c..dcc549dd 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -105,6 +105,22 @@ def goodbye(self): assert s.is_ONE() assert s.message == 'Goodbye' + def test_str_enum(self): + class States(str, enum.Enum): + ONE = "one" + TWO = "two" + + class Stuff(object): + def __init__(self, machine_cls): + self.state = None + self.machine = machine_cls(states=States, initial=States.ONE, model=self) + self.machine.add_transition("advance", States.ONE, States.TWO) + + s = Stuff(self.machine_cls) + assert s.is_ONE() + s.advance() + assert s.is_TWO() + @skipIf(enum is None, "enum is not available") class TestNestedStateEnums(TestEnumsAsStates): diff --git a/transitions/core.py b/transitions/core.py index 0a44f764..15dad6f0 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -858,7 +858,7 @@ def add_transition(self, trigger, source, dest, conditions=None, for model in self.models: self._add_trigger_to_model(trigger, model) - if isinstance(source, string_types): + if isinstance(source, string_types) and not isinstance(source, Enum): source = list(self.states.keys()) if source == self.wildcard_all else [source] else: source = [s.name if self._has_state(s) or isinstance(s, Enum) else s for s in listify(source)]