Skip to content

Commit

Permalink
Merge c4dae18 into 02076d3
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjbremner committed Oct 31, 2019
2 parents 02076d3 + c4dae18 commit 1828ba4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 5 deletions.
16 changes: 16 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@ def goodbye(self):
assert s.is_ONE()
assert s.message == 'Goodbye'

def test_str_enum(self):
class States(str, enum.Enum):
ONE = "one"
TWO = "two"

class Stuff(object):
def __init__(self, machine_cls):
self.state = None
self.machine = machine_cls(states=States, initial=States.ONE, model=self)
self.machine.add_transition("advance", States.ONE, States.TWO)

s = Stuff(self.machine_cls)
assert s.is_ONE()
s.advance()
assert s.is_TWO()


@skipIf(enum is None, "enum is not available")
class TestNestedStateEnums(TestEnumsAsStates):
Expand Down
27 changes: 26 additions & 1 deletion tests/test_markup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
except ImportError:
pass

try:
import enum
except ImportError:
enum = None

from transitions.extensions.markup import MarkupMachine, rep
from transitions.extensions.factory import HierarchicalMarkupMachine
from .utils import Stuff
from functools import partial


from unittest import TestCase
from unittest import TestCase, skipIf

try:
from unittest.mock import MagicMock
Expand Down Expand Up @@ -158,3 +163,23 @@ def setUp(self):
self.machine_cls = HierarchicalMarkupMachine
self.num_trans = len(self.transitions)
self.num_auto = self.num_trans + 9**2


@skipIf(enum is None, "enum is not available")
class TestMarkupMachineEnum(TestMarkupMachine):
class States(enum.Enum):
A = 1
B = 2
C = 3
D = 4

def setUp(self):
self.machine_cls = MarkupMachine
self.states = TestMarkupMachineEnum.States
self.transitions = [
{'trigger': 'walk', 'source': self.states.A, 'dest': self.states.B},
{'trigger': 'run', 'source': self.states.B, 'dest': self.states.C},
{'trigger': 'sprint', 'source': self.states.C, 'dest': self.states.D}
]
self.num_trans = len(self.transitions)
self.num_auto = self.num_trans + len(self.states)**2
2 changes: 1 addition & 1 deletion transitions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def add_transition(self, trigger, source, dest, conditions=None,
for model in self.models:
self._add_trigger_to_model(trigger, model)

if isinstance(source, string_types):
if isinstance(source, string_types) and not isinstance(source, Enum):
source = list(self.states.keys()) if source == self.wildcard_all else [source]
else:
source = [s.name if self._has_state(s) or isinstance(s, Enum) else s for s in listify(source)]
Expand Down
6 changes: 5 additions & 1 deletion transitions/extensions/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import warnings
import logging
from enum import Enum
from functools import partial

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -164,7 +165,10 @@ def _get_graph(self, model, title=None, force_new=False, show_roi=False):
grph = self.graph_cls(self, title=title if title is not None else self.title)
self.model_graphs[model] = grph
try:
self.model_graphs[model].set_node_style(model.state, 'active')
if isinstance(model.state, Enum):
self.model_graphs[model].set_node_style(model.state.name, 'active')
else:
self.model_graphs[model].set_node_style(model.state, 'active')
except AttributeError:
_LOGGER.info("Could not set active state of diagram")
try:
Expand Down
10 changes: 8 additions & 2 deletions transitions/extensions/markup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial
import itertools
import importlib
from enum import Enum
from collections import defaultdict

from ..core import Machine
Expand Down Expand Up @@ -67,7 +68,10 @@ def _convert_states(self, states):
markup_states = []
for state in states:
s_def = _convert(state, self.state_attributes, self.skip_references)
s_def['name'] = getattr(state, '_name', state.name)
if isinstance(state, Enum):
s_def['name'] = state.name
else:
s_def['name'] = getattr(state, '_name', state.name)
if getattr(state, 'children', False):
s_def['children'] = self._convert_states(state.children)
markup_states.append(s_def)
Expand Down Expand Up @@ -153,7 +157,9 @@ def _convert(obj, attributes, skip):
val = getattr(obj, key, False)
if not val:
continue
if isinstance(val, string_types):
if isinstance(val, Enum):
s[key] = val.name
elif isinstance(val, string_types):
s[key] = val
else:
try:
Expand Down

0 comments on commit 1828ba4

Please sign in to comment.