-
Notifications
You must be signed in to change notification settings - Fork 529
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(experimental): Add
Machine.model_override
and `experimental.ut…
…ils.generate_base_model` When `model_override` is set to True, Machine will assign only methods to a model that are already defined. This eases static type checking (#658) and enables tailored helper function assigment. Default value is False which prevents override of already defined model functions.
- Loading branch information
Showing
15 changed files
with
322 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,127 @@ | ||
from typing import TYPE_CHECKING | ||
from unittest import TestCase | ||
from types import ModuleType | ||
from unittest.mock import MagicMock | ||
|
||
from transitions import Machine | ||
from transitions.experimental.utils import generate_base_model | ||
from transitions.extensions import HierarchicalMachine | ||
|
||
from .utils import Stuff | ||
|
||
if TYPE_CHECKING: | ||
from transitions.core import MachineConfig | ||
from typing import Type | ||
|
||
|
||
def import_code(code: str, name: str) -> ModuleType: | ||
module = ModuleType(name) | ||
exec(code, module.__dict__) | ||
return module | ||
|
||
|
||
class TestExperimental(TestCase): | ||
|
||
def setUp(self) -> None: | ||
self.machine_cls = Machine # type: Type[Machine] | ||
|
||
def test_model_override(self): | ||
|
||
class Model: | ||
|
||
def trigger(self, name: str) -> bool: | ||
raise RuntimeError("Should be overridden") | ||
|
||
def is_A(self) -> bool: | ||
raise RuntimeError("Should be overridden") | ||
|
||
def is_C(self) -> bool: | ||
raise RuntimeError("Should be overridden") | ||
|
||
model = Model() | ||
machine = self.machine_cls(model, states=["A", "B"], initial="A", model_override=True) | ||
self.assertTrue(model.is_A()) | ||
with self.assertRaises(AttributeError): | ||
model.to_B() # type: ignore # Should not be assigned to model since its not declared | ||
self.assertTrue(model.trigger("to_B")) | ||
self.assertFalse(model.is_A()) | ||
with self.assertRaises(RuntimeError): | ||
model.is_C() # not overridden yet | ||
machine.add_state("C") | ||
self.assertFalse(model.is_C()) # now it is! | ||
self.assertTrue(model.trigger("to_C")) | ||
self.assertTrue(model.is_C()) | ||
|
||
def test_generate_base_model(self): | ||
simple_config = { | ||
"states": ["A", "B"], | ||
"transitions": [ | ||
["go", "A", "B"], | ||
["back", "*", "A"] | ||
], | ||
"initial": "A", | ||
"model_override": True | ||
} # type: MachineConfig | ||
|
||
mod = import_code(generate_base_model(simple_config), "base_module") | ||
model = mod.BaseModel() | ||
machine = self.machine_cls(model, **simple_config) | ||
self.assertTrue(model.is_A()) | ||
self.assertTrue(model.go()) | ||
self.assertTrue(model.is_B()) | ||
self.assertTrue(model.back()) | ||
self.assertTrue(model.state == "A") | ||
with self.assertRaises(AttributeError): | ||
model.is_C() | ||
|
||
def test_generate_base_model_callbacks(self): | ||
simple_config = { | ||
"states": ["A", "B"], | ||
"transitions": [ | ||
["go", "A", "B"], | ||
], | ||
"initial": "A", | ||
"model_override": True, | ||
"before_state_change": "call_this" | ||
} # type: MachineConfig | ||
|
||
mod = import_code(generate_base_model(simple_config), "base_module") | ||
mock = MagicMock() | ||
|
||
class Model(mod.BaseModel): # type: ignore | ||
|
||
@staticmethod | ||
def call_this() -> None: | ||
mock() | ||
|
||
model = Model() | ||
machine = self.machine_cls(model, **simple_config) | ||
self.assertTrue(model.is_A()) | ||
self.assertTrue(model.go()) | ||
self.assertTrue(mock.called) | ||
|
||
def test_generate_model_no_auto(self): | ||
simple_config: MachineConfig = { | ||
"states": ["A", "B"], | ||
"auto_transitions": False, | ||
"model_override": True, | ||
"transitions": [ | ||
["go", "A", "B"], | ||
["back", "*", "A"] | ||
], | ||
"initial": "A" | ||
} | ||
mod = import_code(generate_base_model(simple_config), "base_module") | ||
model = mod.BaseModel() | ||
machine = self.machine_cls(model, **simple_config) | ||
self.assertTrue(model.is_A()) | ||
self.assertTrue(model.go()) | ||
with self.assertRaises(AttributeError): | ||
model.to_B() | ||
|
||
|
||
class TestHSMExperimental(TestExperimental): | ||
|
||
def setUp(self): | ||
self.machine_cls = HierarchicalMachine # type: Type[HierarchicalMachine] | ||
self.create_trigger_class() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from collections import deque, defaultdict | ||
|
||
from transitions.core import listify | ||
from transitions.extensions.markup import HierarchicalMarkupMachine | ||
|
||
|
||
_placeholder_body = "raise RuntimeError('This should be overridden')" | ||
|
||
|
||
def generate_base_model(config): | ||
m = HierarchicalMarkupMachine(**config) | ||
triggers = set() | ||
markup = m.markup | ||
model_attribute = markup.get("model_attribute", "state") | ||
trigger_block = "" | ||
state_block = "" | ||
callback_block = "" | ||
|
||
callbacks = set( | ||
[cb for cb in markup["prepare_event"]] | ||
+ [cb for cb in markup["before_state_change"]] | ||
+ [cb for cb in markup["after_state_change"]] | ||
+ [cb for cb in markup["on_exception"]] | ||
+ [cb for cb in markup["on_final"]] | ||
+ [cb for cb in markup["finalize_event"]] | ||
) | ||
|
||
for trans in markup["transitions"]: | ||
triggers.add(trans["trigger"]) | ||
|
||
stack = [(markup["states"], markup["transitions"], "")] | ||
has_nested_states = any("children" in state for state in markup["states"]) | ||
while stack: | ||
states, transitions, prefix = stack.pop() | ||
for state in states: | ||
state_name = state["name"] | ||
|
||
state_block += ( | ||
f" def is_{prefix}{state_name}(self{', allow_substates=False' if has_nested_states else ''})" | ||
f" -> bool: {_placeholder_body}\n" | ||
) | ||
if m.auto_transitions: | ||
state_block += ( | ||
f" def to_{prefix}{state_name}(self) -> bool: {_placeholder_body}\n" | ||
f" def may_to_{prefix}{state_name}(self) -> bool: {_placeholder_body}\n" | ||
) | ||
|
||
state_block += "\n" | ||
for tran in transitions: | ||
triggers.add(tran["trigger"]) | ||
new_set = set( | ||
[cb for cb in tran.get("prepare", [])] | ||
+ [cb for cb in tran.get("conditions", [])] | ||
+ [cb for cb in tran.get("unless", [])] | ||
+ [cb for cb in tran.get("before", [])] | ||
+ [cb for cb in tran.get("after", [])] | ||
) | ||
callbacks.update(new_set) | ||
|
||
if "children" in state: | ||
stack.append((state["children"], state.get("transitions", []), prefix + state_name + "_")) | ||
|
||
for trigger_name in triggers: | ||
trigger_block += ( | ||
f" def {trigger_name}(self) -> bool: {_placeholder_body}\n" | ||
f" def may_{trigger_name}(self) -> bool: {_placeholder_body}\n" | ||
) | ||
|
||
extra_params = "event_data: EventData" if m.send_event else "*args: List[Any], **kwargs: Dict[str, Any]" | ||
for callback_name in callbacks: | ||
if isinstance(callback_name, str): | ||
callback_block += (f" @abstractmethod\n" | ||
f" def {callback_name}(self, {extra_params}) -> Optional[bool]: ...\n") | ||
|
||
template = f"""# autogenerated by transitions | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from transitions.core import CallbacksArg, StateIdentifier, EventData | ||
class BaseModel(metaclass=ABCMeta): | ||
{model_attribute}: "StateIdentifier" = "" | ||
def trigger(self, name: str) -> bool: {_placeholder_body} | ||
{trigger_block} | ||
{state_block}\ | ||
{callback_block}""" | ||
|
||
return template | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from typing import Union, Callable, List, Optional, Iterable, Type, ClassVar, Tuple, Dict, Any, DefaultDict, Deque | ||
from transitions.core import StateIdentifier, CallbacksArg, CallbackFunc, Machine, TransitionConfig, MachineConfig | ||
from transitions.extensions.markup import MarkupConfig | ||
|
||
_placeholder_body: str | ||
|
||
def generate_base_model(config: Union[MachineConfig, MarkupConfig]) -> str: ... |
Oops, something went wrong.