Skip to content

Commit

Permalink
Fix graph crash with partial conditions
Browse files Browse the repository at this point in the history
Graph.rep() assumes that every callable has a `__name__` attribute.
However, callables like `functools.partial` and other classes with a
`__call__()` method have no `__name__`.  This patch modifies
`Graph.rep()` to first check whether the condition is a string and
return it, then try to return `f.__name__` (that is the current
behavior) and if that fails, then it calls `str()` on the argument.

`str()` should never fail and it is now the responsibility of the user
to return something meaningful from `__str__()`.
  • Loading branch information
Synss committed Jul 23, 2017
1 parent 9f7cbe3 commit 1ca302e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
25 changes: 25 additions & 0 deletions tests/test_graphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,31 @@ def test_diagram(self):
graph = m.get_graph(force_new=True, title=False)
self.assertEqual("", graph.graph_attr['label'])

def test_anonymous_callable_conditions(self):
import functools
def check(result):
return result

class Check:
def __init__(self, result):
self.result = result
def __call__(self):
return self.result

m = self.machine_cls(states=self.states,
transitions=self.transitions,
initial='A',
auto_transitions=False,
show_conditions=True,
title='a test')
m.add_state({'name': 'E'})
m.add_transition(trigger='fly', source='D', dest='E',
conditions=Check(True))
m.add_transition(trigger='fly', source='D', dest='E',
unless=functools.partial(check, False))
graph = m.get_graph()
self.assertIsNotNone(graph)

def test_add_custom_state(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False, title='a test')
m.add_state('X')
Expand Down
8 changes: 7 additions & 1 deletion transitions/extensions/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import logging
from functools import partial
from six import string_types
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down Expand Up @@ -117,7 +118,12 @@ def _is_auto_transition(self, event, label):
return False

def rep(self, f):
return f.__name__ if callable(f) else f
if isinstance(f, string_types):
return f
try:
return f.__name__
except AttributeError:
return str(f)

def _transition_label(self, edge_label, tran):
if self.machine.show_conditions and tran.conditions:
Expand Down

0 comments on commit 1ca302e

Please sign in to comment.