diff --git a/tests/test_enum.py b/tests/test_enum.py index b8a3b71c..dcc549dd 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -105,6 +105,22 @@ def goodbye(self): assert s.is_ONE() assert s.message == 'Goodbye' + def test_str_enum(self): + class States(str, enum.Enum): + ONE = "one" + TWO = "two" + + class Stuff(object): + def __init__(self, machine_cls): + self.state = None + self.machine = machine_cls(states=States, initial=States.ONE, model=self) + self.machine.add_transition("advance", States.ONE, States.TWO) + + s = Stuff(self.machine_cls) + assert s.is_ONE() + s.advance() + assert s.is_TWO() + @skipIf(enum is None, "enum is not available") class TestNestedStateEnums(TestEnumsAsStates): diff --git a/tests/test_markup.py b/tests/test_markup.py index 80d50dee..aa03c7e6 100644 --- a/tests/test_markup.py +++ b/tests/test_markup.py @@ -3,13 +3,18 @@ except ImportError: pass +try: + import enum +except ImportError: + enum = None + from transitions.extensions.markup import MarkupMachine, rep from transitions.extensions.factory import HierarchicalMarkupMachine from .utils import Stuff from functools import partial -from unittest import TestCase +from unittest import TestCase, skipIf try: from unittest.mock import MagicMock @@ -158,3 +163,23 @@ def setUp(self): self.machine_cls = HierarchicalMarkupMachine self.num_trans = len(self.transitions) self.num_auto = self.num_trans + 9**2 + + +@skipIf(enum is None, "enum is not available") +class TestMarkupMachineEnum(TestMarkupMachine): + class States(enum.Enum): + A = 1 + B = 2 + C = 3 + D = 4 + + def setUp(self): + self.machine_cls = MarkupMachine + self.states = TestMarkupMachineEnum.States + self.transitions = [ + {'trigger': 'walk', 'source': self.states.A, 'dest': self.states.B}, + {'trigger': 'run', 'source': self.states.B, 'dest': self.states.C}, + {'trigger': 'sprint', 'source': self.states.C, 'dest': self.states.D} + ] + self.num_trans = len(self.transitions) + self.num_auto = self.num_trans + len(self.states)**2 diff --git a/transitions/core.py b/transitions/core.py index 0a44f764..15dad6f0 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -858,7 +858,7 @@ def add_transition(self, trigger, source, dest, conditions=None, for model in self.models: self._add_trigger_to_model(trigger, model) - if isinstance(source, string_types): + if isinstance(source, string_types) and not isinstance(source, Enum): source = list(self.states.keys()) if source == self.wildcard_all else [source] else: source = [s.name if self._has_state(s) or isinstance(s, Enum) else s for s in listify(source)] diff --git a/transitions/extensions/diagrams.py b/transitions/extensions/diagrams.py index aad0fbc7..ab4cdd56 100644 --- a/transitions/extensions/diagrams.py +++ b/transitions/extensions/diagrams.py @@ -3,6 +3,7 @@ import warnings import logging +from enum import Enum from functools import partial _LOGGER = logging.getLogger(__name__) @@ -164,7 +165,10 @@ def _get_graph(self, model, title=None, force_new=False, show_roi=False): grph = self.graph_cls(self, title=title if title is not None else self.title) self.model_graphs[model] = grph try: - self.model_graphs[model].set_node_style(model.state, 'active') + if isinstance(model.state, Enum): + self.model_graphs[model].set_node_style(model.state.name, 'active') + else: + self.model_graphs[model].set_node_style(model.state, 'active') except AttributeError: _LOGGER.info("Could not set active state of diagram") try: diff --git a/transitions/extensions/markup.py b/transitions/extensions/markup.py index de1e0c88..e3d113d1 100644 --- a/transitions/extensions/markup.py +++ b/transitions/extensions/markup.py @@ -2,6 +2,7 @@ from functools import partial import itertools import importlib +from enum import Enum from collections import defaultdict from ..core import Machine @@ -67,7 +68,10 @@ def _convert_states(self, states): markup_states = [] for state in states: s_def = _convert(state, self.state_attributes, self.skip_references) - s_def['name'] = getattr(state, '_name', state.name) + if isinstance(state, Enum): + s_def['name'] = state.name + else: + s_def['name'] = getattr(state, '_name', state.name) if getattr(state, 'children', False): s_def['children'] = self._convert_states(state.children) markup_states.append(s_def) @@ -153,7 +157,9 @@ def _convert(obj, attributes, skip): val = getattr(obj, key, False) if not val: continue - if isinstance(val, string_types): + if isinstance(val, Enum): + s[key] = val.name + elif isinstance(val, string_types): s[key] = val else: try: