Skip to content

Commit

Permalink
asyncmachine rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed Aug 27, 2019
1 parent fdd742a commit d9bca39
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 26 deletions.
44 changes: 37 additions & 7 deletions tests/test_async.py
@@ -1,10 +1,16 @@
import asyncio
from unittest.mock import MagicMock
from transitions.extensions.asyncio import AsyncMachine
try:
from transitions.extensions.asyncio import AsyncMachine
import asyncio
except (ImportError, SyntaxError):
asyncio = None


from unittest.mock import MagicMock
from unittest import skipIf
from .test_core import TestTransitions


@skipIf(asyncio is None, "AsyncMachine requires asyncio and contextvars suppport")
class TestAsync(TestTransitions):

@staticmethod
Expand All @@ -17,6 +23,11 @@ async def await_true():
await asyncio.sleep(0.1)
return True

@staticmethod
async def await_never_return():
await asyncio.sleep(100)
return None

@staticmethod
def synced_true():
return True
Expand All @@ -34,14 +45,14 @@ async def async_process():

m = self.machine
m.after_state_change = async_process
m.go()
asyncio.run(m.go())
self.assertEqual(m.state, 'B')
self.assertTrue(mock.called)

def test_async_condition(self):
m = self.machine
m.add_transition('proceed', 'A', 'C', conditions=self.await_true, unless=self.await_false)
m.proceed()
asyncio.run(m.proceed())
self.assertEqual(m.state, 'C')

def test_async_enter_exit(self):
Expand All @@ -59,7 +70,7 @@ async def async_exit():
m = self.machine
m.on_exit_A(async_exit)
m.on_enter_B(async_enter)
m.go()
asyncio.run(m.go())
self.assertTrue(exit_mock.called)
self.assertTrue(enter_mock.called)

Expand All @@ -71,6 +82,25 @@ def sync_process():

m = self.machine
m.add_transition('proceed', 'A', 'C', conditions=self.synced_true, after=sync_process)
m.proceed()
asyncio.run(m.proceed())
self.assertEqual(m.state, 'C')
self.assertTrue(mock.called)

def test_multiple_models(self):
async def fix():
if not m2.is_B():
await m2.fix()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

m1 = AsyncMachine(states=['A', 'B', 'C'], initial='A')
m2 = AsyncMachine(states=['A', 'B', 'C'], initial='A')
m2.add_transition(trigger='go', source='A', dest='B', conditions=self.await_never_return)
m2.add_transition(trigger='fix', source='A', dest='C', conditions=self.await_true)
m1.add_transition(trigger='go', source='A', dest='B', conditions=self.await_true, after='go')
m1.add_transition(trigger='go', source='B', dest='C', after=fix)
loop.run_until_complete(asyncio.gather(m1.go()))

assert m1.is_C()
assert m2.is_C()
263 changes: 244 additions & 19 deletions transitions/extensions/asyncio.py
@@ -1,15 +1,61 @@
from ..core import Condition, Machine, Transition

import itertools
import logging
import asyncio
import contextvars

from functools import partial

from ..core import State, Condition, Transition, EventData
from ..core import Event, MachineError, Machine

_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())

is_subtask = contextvars.ContextVar('is_subtask', default=False)


class AsyncState(State):
"""A persistent representation of a state managed by a ``Machine``. Callback execution is done asynchronously.
Attributes:
name (str): State name which is also assigned to the model(s).
on_enter (list): Callbacks awaited when a state is entered.
on_exit (list): Callbacks awaited when a state is entered.
ignore_invalid_triggers (bool): Indicates if unhandled/invalid triggers should raise an exception.
"""

async def enter(self, event_data):
""" Triggered when a state is entered. """
_LOGGER.debug("%sEntering state %s. Processing callbacks...", event_data.machine.name, self.name)
await event_data.machine.callbacks(self.on_enter, event_data)
_LOGGER.info("%sEntered state %s", event_data.machine.name, self.name)

async def exit(self, event_data):
""" Triggered when a state is exited. """
_LOGGER.debug("%sExiting state %s. Processing callbacks...", event_data.machine.name, self.name)
await event_data.machine.callbacks(self.on_exit, event_data)
_LOGGER.info("%sExited state %s", event_data.machine.name, self.name)


class AsyncCondition(Condition):
""" A helper class to await condition checks in the intended way.
Attributes:
func (callable): The function to call for the condition check
target (bool): Indicates the target state--i.e., when True,
the condition-checking callback should return True to pass,
and when False, the callback should return False to pass.
"""

async def check(self, event_data):
""" Check whether the condition passes.
Args:
event_data (EventData): An EventData instance to pass to the
condition (if event sending is enabled) or to extract arguments
from (if event sending is disabled). Also contains the data
model attached to the current machine which is used to invoke
the condition.
"""
predicate = event_data.machine.resolve_callable(self.func, event_data)
if asyncio.iscoroutinefunction(predicate):
if event_data.machine.send_event:
Expand All @@ -21,44 +67,223 @@ async def check(self, event_data):


class AsyncTransition(Transition):
""" Representation of an asynchronous transition managed by a ``AsyncMachine`` instance.
Attributes:
source (str): Source state of the transition.
dest (str): Destination state of the transition.
prepare (list): Callbacks executed before conditions checks.
conditions (list): Callbacks evaluated to determine if
the transition should be executed.
before (list): Callbacks executed before the transition is executed
but only if condition checks have been successful.
after (list): Callbacks executed after the transition is executed
but only if condition checks have been successful.
"""

condition_cls = AsyncCondition

def _eval_conditions(self, event_data):
loop = asyncio.get_event_loop()
res = loop.run_until_complete(
asyncio.gather(*[
cond.check(event_data) for cond in self.conditions
])
)
async def _eval_conditions(self, event_data):
res = await asyncio.gather(*[cond.check(event_data) for cond in self.conditions])
if not all(res):
_LOGGER.debug("%sTransition condition failed: Transition halted.", event_data.machine.name)
return False
return True

async def execute(self, event_data):
""" Executes the transition.
Args:
event_data: An instance of class EventData.
Returns: boolean indicating whether or not the transition was
successfully executed (True if successful, False if not).
"""
_LOGGER.debug("%sInitiating transition from state %s to state %s...",
event_data.machine.name, self.source, self.dest)

await event_data.machine.callbacks(self.prepare, event_data)
_LOGGER.debug("%sExecuted callbacks before conditions.", event_data.machine.name)

if not await self._eval_conditions(event_data):
return False

# cancel running tasks since the transition will happen
machine = event_data.machine
model = event_data.model
if model in machine._tasks and not machine._tasks[model].done():
if not is_subtask.get():
machine._tasks[model].cancel()
else:
is_subtask.set(True)
machine._tasks[model] = asyncio.current_task()

await event_data.machine.callbacks(itertools.chain(event_data.machine.before_state_change, self.before), event_data)
_LOGGER.debug("%sExecuted callback before transition.", event_data.machine.name)

if self.dest: # if self.dest is None this is an internal transition with no actual state change
await self._change_state(event_data)

await event_data.machine.callbacks(itertools.chain(self.after, event_data.machine.after_state_change), event_data)
_LOGGER.debug("%sExecuted callback after transition.", event_data.machine.name)
return True

async def _change_state(self, event_data):
await event_data.machine.get_state(self.source).exit(event_data)
event_data.machine.set_state(self.dest, event_data.model)
event_data.update(event_data.model.state)
await event_data.machine.get_state(self.dest).enter(event_data)


class AsyncEvent(Event):
""" A collection of transitions assigned to the same trigger """

async def trigger(self, model, *args, **kwargs):
""" Serially execute all transitions that match the current state,
halting as soon as one successfully completes. Note that `AsyncEvent` triggers must be awaited.
Args:
args and kwargs: Optional positional or named arguments that will
be passed onto the EventData object, enabling arbitrary state
information to be passed on to downstream triggered functions.
Returns: boolean indicating whether or not a transition was
successfully executed (True if successful, False if not).
"""
func = partial(self._trigger, model, *args, **kwargs)
t = asyncio.create_task(self.machine._process(func))
try:
return await t
except asyncio.CancelledError:
return False

async def _trigger(self, model, *args, **kwargs):
state = self.machine.get_state(model.state)
if state.name not in self.transitions:
msg = "%sCan't trigger event %s from state %s!" % (self.machine.name, self.name,
state.name)
ignore = state.ignore_invalid_triggers if state.ignore_invalid_triggers is not None \
else self.machine.ignore_invalid_triggers
if ignore:
_LOGGER.warning(msg)
return False
else:
raise MachineError(msg)
event_data = EventData(state, self, self.machine, model, args=args, kwargs=kwargs)
return await self._process(event_data)

async def _process(self, event_data):
await self.machine.callbacks(self.machine.prepare_event, event_data)
_LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", self.machine.name)

try:
for trans in self.transitions[event_data.state.name]:
event_data.transition = trans
if await trans.execute(event_data):
event_data.result = True
break
except Exception as err:
event_data.error = err
raise
finally:
await self.machine.callbacks(self.machine.finalize_event, event_data)
_LOGGER.debug("%sExecuted machine finalize callbacks", self.machine.name)
return event_data.result


class AsyncMachine(Machine):
""" Machine manages states, transitions and models. In case it is initialized without a specific model
(or specifically no model), it will also act as a model itself. Machine takes also care of decorating
models with conveniences functions related to added transitions and states during runtime.
Attributes:
states (OrderedDict): Collection of all registered states.
events (dict): Collection of transitions ordered by trigger/event.
models (list): List of models attached to the machine.
initial (str): Name of the initial state for new models.
prepare_event (list): Callbacks executed when an event is triggered.
before_state_change (list): Callbacks executed after condition checks but before transition is conducted.
Callbacks will be executed BEFORE the custom callbacks assigned to the transition.
after_state_change (list): Callbacks executed after the transition has been conducted.
Callbacks will be executed AFTER the custom callbacks assigned to the transition.
finalize_event (list): Callbacks will be executed after all transitions callbacks have been executed.
Callbacks mentioned here will also be called if a transition or condition check raised an error.
queued (bool): Whether transitions in callbacks should be executed immediately (False) or sequentially.
send_event (bool): When True, any arguments passed to trigger methods will be wrapped in an EventData
object, allowing indirect and encapsulated access to data. When False, all positional and keyword
arguments will be passed directly to all callback methods.
auto_transitions (bool): When True (default), every state will automatically have an associated
to_{state}() convenience trigger in the base model.
ignore_invalid_triggers (bool): When True, any calls to trigger methods that are not valid for the
present state (e.g., calling an a_to_b() trigger when the current state is c) will be silently
ignored rather than raising an invalid transition exception.
name (str): Name of the ``Machine`` instance mainly used for easier log message distinction.
"""

state_cls = AsyncState
transition_cls = AsyncTransition
event_cls = AsyncEvent
_tasks = {}

async def dispatch(self, trigger, *args, **kwargs): # ToDo: not tested
""" Trigger an event on all models assigned to the machine.
Args:
trigger (str): Event name
*args (list): List of arguments passed to the event trigger
**kwargs (dict): Dictionary of keyword arguments passed to the event trigger
Returns:
bool The truth value of all triggers combined with AND
"""
return asyncio.gather(*[getattr(model, trigger)(*args, **kwargs) for model in self.models])

def callbacks(self, funcs, event_data):
loop = asyncio.get_event_loop()
loop.run_until_complete(
asyncio.gather(*[
event_data.machine.callback(func, event_data) for func in funcs
])
)
async def callbacks(self, funcs, event_data):
""" Triggers a list of callbacks """
await asyncio.gather(*[event_data.machine.callback(func, event_data) for func in funcs])

async def callback(self, func, event_data):
""" Trigger a callback function with passed event_data parameters. In case func is a string,
the callable will be resolved from the passed model in event_data. This function is not intended to
be called directly but through state and transition callback definitions.
Args:
func (string, callable): The callback function.
1. First, if the func is callable, just call it
2. Second, we try to import string assuming it is a path to a func
3. Fallback to a model attribute
event_data (EventData): An EventData instance to pass to the
callback (if event sending is enabled) or to extract arguments
from (if event sending is disabled).
"""
func = self.resolve_callable(func, event_data)

if self.send_event:
if asyncio.iscoroutinefunction(func):
if asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(getattr(func, 'func', None)):
await func(event_data)
else:
func(event_data)
else:
if asyncio.iscoroutinefunction(func):
if asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(getattr(func, 'func', None)):
await func(*event_data.args, **event_data.kwargs)
else:
func(*event_data.args, **event_data.kwargs)

async def _process(self, trigger):
# default processing
if not self.has_queue:
if not self._transition_queue:
# if trigger raises an Error, it has to be handled by the Machine.process caller
return await trigger()
else:
raise MachineError("Attempt to process events synchronously while transition queue is not empty!")

# ToDo: test for has_queue
# process queued events
self._transition_queue.append(trigger)
# another entry in the queue implies a running transition; skip immediate execution
if len(self._transition_queue) > 1:
return True

# execute as long as transition queue is not empty ToDo: not tested!
while self._transition_queue:
try:
await self._transition_queue[0]()
self._transition_queue.popleft()
except Exception:
# if a transition raises an exception, clear queue and delegate exception handling
self._transition_queue.clear()
raise
return True

0 comments on commit d9bca39

Please sign in to comment.