diff --git a/tests/test_core.py b/tests/test_core.py index 5d17f73b..913e65a7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1212,6 +1212,62 @@ def on_exception(event_data): m.to_B() self.assertTrue(mock.called) + def test_may_transition(self): + states = ['A', 'B', 'C'] + d = DummyModel() + m = Machine(model=d, states=states, initial='A', auto_transitions=False) + m.add_transition('walk', 'A', 'B') + m.add_transition('stop', 'B', 'C') + assert d.may_walk() + assert not d.may_stop() + + d.walk() + assert not d.may_walk() + assert d.may_stop() + + def test_may_transition_with_conditions(self): + states = ['A', 'B', 'C'] + d = DummyModel() + m = Machine(model=d, states=states, initial='A', auto_transitions=False) + m.add_transition('walk', 'A', 'B', conditions=[lambda: False]) + m.add_transition('stop', 'B', 'C') + m.add_transition('run', 'A', 'C') + assert not d.may_walk() + assert not d.may_stop() + assert d.may_run() + d.run() + assert not d.may_run() + + def test_may_transition_with_auto_transitions(self): + states = ['A', 'B', 'C'] + d = DummyModel() + Machine(model=d, states=states, initial='A') + assert d.may_to_A() + assert d.may_to_B() + assert d.may_to_C() + + def test_machine_may_transitions(self): + states = ['A', 'B', 'C'] + m = Machine(states=states, initial='A', auto_transitions=False) + m.add_transition('walk', 'A', 'B', conditions=[lambda: False]) + m.add_transition('stop', 'B', 'C') + m.add_transition('run', 'A', 'C') + + assert not m.may_walk() + assert not m.may_stop() + assert m.may_run() + m.run() + assert not m.may_run() + assert not m.may_stop() + assert not m.may_walk() + + def test_may_transition_with_invalid_state(self): + states = ['A', 'B', 'C'] + d = DummyModel() + m = Machine(model=d, states=states, initial='A', auto_transitions=False) + m.add_transition('walk', 'A', 'UNKNOWN') + assert not d.may_walk() + class TestWarnings(TestCase): def test_multiple_machines_per_model(self): diff --git a/tests/test_enum.py b/tests/test_enum.py index 553b936b..374ebf42 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -154,6 +154,17 @@ def test_get_triggers(self): trigger_enum = m.get_triggers(m.state) self.assertEqual(trigger_enum, trigger_name) + def test_may_transition(self): + class TrafficLight(object): + pass + + t = TrafficLight() + m = MachineFactory.get_predefined()(states=self.States, model=t, initial=self.States.RED, auto_transitions=False) + m.add_transition('go', self.States.RED, self.States.GREEN) + m.add_transition('stop', self.States.YELLOW, self.States.RED) + assert t.may_go() + assert not t.may_stop() + @skipIf(enum is None, "enum is not available") class TestNestedStateEnums(TestEnumsAsStates): diff --git a/tests/test_nesting.py b/tests/test_nesting.py index e7b9faea..6d41f4a7 100644 --- a/tests/test_nesting.py +++ b/tests/test_nesting.py @@ -804,6 +804,32 @@ def test_example_two(self): machine.reset() # exit C, enter A self.assertEqual('A', machine.state) + def test_machine_may_transitions(self): + states = ['A', 'B', {'name': 'C', 'children': ['1', '2', '3']}, 'D'] + transitions = [ + {'trigger': 'walk', 'source': 'A', 'dest': 'B'}, + {'trigger': 'run', 'source': 'B', 'dest': 'C'}, + {'trigger': 'run_fast', 'source': 'C', 'dest': 'C{0}1'.format(self.separator)}, + {'trigger': 'sprint', 'source': 'C', 'dest': 'D'} + ] + m = self.stuff.machine_cls( + states=states, transitions=transitions, initial='A', auto_transitions=False + ) + assert m.may_walk() + assert not m.may_run() + assert not m.may_run_fast() + assert not m.may_sprint() + + m.walk() + assert not m.may_walk() + assert m.may_run() + assert not m.may_run_fast() + + m.run() + assert m.may_run_fast() + assert m.may_sprint() + m.run_fast() + class TestSeparatorsSlash(TestSeparatorsBase): separator = '/' diff --git a/transitions/core.py b/transitions/core.py index 58e5bc5f..46ea2b7e 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -874,8 +874,26 @@ def _checked_assignment(self, model, name, func): else: setattr(model, name, func) + def _can_trigger(self, model, trigger, *args, **kwargs): + e = EventData(None, None, self, model, args, kwargs) + state = self.get_model_state(model).name + + for trigger_name in self.get_triggers(state): + if trigger_name != trigger: + continue + for transition in self.events[trigger_name].transitions[state]: + try: + self.get_state(transition.dest) + except ValueError: + continue + + if all(c.check(e) for c in transition.conditions): + return True + return False + def _add_trigger_to_model(self, trigger, model): self._checked_assignment(model, trigger, partial(self.events[trigger].trigger, model)) + self._checked_assignment(model, "may_%s" % trigger, partial(self._can_trigger, model, trigger)) def _get_trigger(self, model, trigger_name, *args, **kwargs): """Convenience function added to the model to trigger events by name. diff --git a/transitions/extensions/nesting.py b/transitions/extensions/nesting.py index 3b53f2c5..740eae01 100644 --- a/transitions/extensions/nesting.py +++ b/transitions/extensions/nesting.py @@ -885,6 +885,7 @@ def _add_trigger_to_model(self, trigger, model): self._checked_assignment(model, 'to_' + path[0], FunctionWrapper(trig_func, path[1:])) else: self._checked_assignment(model, trigger, trig_func) + self._checked_assignment(model, "may_%s" % trigger, partial(self._can_trigger, model, trigger)) # converts a list of current states into a hierarchical state tree def _build_state_tree(self, model_states, separator, tree=None):