Skip to content

Commit

Permalink
Merge pull request #239 from Synss/graphing-nameless-callable
Browse files Browse the repository at this point in the history
Fix graph crash with partial conditions
  • Loading branch information
aleneum committed Jul 24, 2017
2 parents 9f7cbe3 + 235bfa0 commit 33c15d0
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 5 deletions.
19 changes: 19 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .utils import InheritedStuff
from .utils import Stuff
from functools import partial
import sys
from transitions import Machine, MachineError, State, EventData
from transitions.core import listify, prep_ordered_arg
Expand Down Expand Up @@ -130,6 +131,24 @@ def test_conditions(self):
s.advance()
self.assertEqual(s.state, 'C')

def test_conditions_with_partial(self):
def check(result):
return result

s = self.stuff
s.machine.add_transition('advance', 'A', 'B',
conditions=partial(check, True))
s.machine.add_transition('advance', 'B', 'C',
unless=[partial(check, False)])
s.machine.add_transition('advance', 'C', 'D',
unless=[partial(check, False), partial(check, True)])
s.advance()
self.assertEqual(s.state, 'B')
s.advance()
self.assertEqual(s.state, 'C')
s.advance()
self.assertEqual(s.state, 'C')

def test_multiple_add_transitions_from_state(self):
s = self.stuff
s.machine.add_transition(
Expand Down
58 changes: 57 additions & 1 deletion tests/test_graphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from .utils import Stuff

from transitions.extensions import MachineFactory
from transitions.extensions.diagrams import Diagram
from transitions.extensions.diagrams import Diagram, rep
from transitions.extensions.nesting import NestedState
from unittest import TestCase, skipIf
from functools import partial
import tempfile
import os

Expand All @@ -23,6 +24,61 @@ def edge_label_from_transition_label(label):
return label.split(' | ')[0].split(' [')[0] # if no condition, label is returned; returns first event only


class TestRep(TestCase):

def test_rep_string(self):
self.assertEqual(rep("string"), "string")

def test_rep_function(self):
def check():
return True
self.assertTrue(check())
self.assertEqual(rep(check), "check")

def rest_rep_partial_no_args_no_kwargs(self):
def check():
return True
pcheck = partial(check)
self.assertTrue(pcheck())
self.assertEqual(rep(pcheck), "check()")

def test_rep_partial_with_args(self):
def check(result):
return result
pcheck = partial(check, True)
self.assertTrue(pcheck())
self.assertEqual(rep(pcheck), "check(True)")

def test_rep_partial_with_kwargs(self):
def check(result=True):
return result
pcheck = partial(check, result=True)
self.assertTrue(pcheck())
self.assertEqual(rep(pcheck), "check(result=True)")

def test_rep_partial_with_args_and_kwargs(self):
def check(result, doublecheck=True):
return result == doublecheck
pcheck = partial(check, True, doublecheck=True)
self.assertTrue(pcheck())
self.assertEqual(rep(pcheck), "check(True, doublecheck=True)")

def test_rep_callable_class(self):
class Check(object):
def __init__(self, result):
self.result = result

def __call__(self):
return self.result

def __repr__(self):
return "%s(%r)" % (type(self).__name__, self.result)

ccheck = Check(True)
self.assertTrue(ccheck())
self.assertEqual(rep(ccheck), "Check(True)")


@skipIf(pgv is None, 'Graph diagram requires pygraphviz')
class TestDiagrams(TestCase):

Expand Down
25 changes: 21 additions & 4 deletions transitions/extensions/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,30 @@

import logging
from functools import partial
import itertools
from six import string_types, iteritems
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def rep(f):
"""Return a string representation for `f`."""
if isinstance(f, string_types):
return f
try:
return f.__name__
except AttributeError:
pass
if isinstance(f, partial):
return "%s(%s)" % (
f.func.__name__,
", ".join(itertools.chain(
(str(_) for _ in f.args),
("%s=%s" % (key, value)
for key, value in iteritems(f.keywords if f.keywords else {})))))
return str(f)


class Diagram(object):

def __init__(self, machine):
Expand Down Expand Up @@ -116,15 +136,12 @@ def _is_auto_transition(self, event, label):
return True
return False

def rep(self, f):
return f.__name__ if callable(f) else f

def _transition_label(self, edge_label, tran):
if self.machine.show_conditions and tran.conditions:
return '{edge_label} [{conditions}]'.format(
edge_label=edge_label,
conditions=' & '.join(
self.rep(c.func) if c.target else '!' + self.rep(c.func)
rep(c.func) if c.target else '!' + rep(c.func)
for c in tran.conditions
),
)
Expand Down

0 comments on commit 33c15d0

Please sign in to comment.