Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

133: Visit functions in while test #186

Merged
merged 5 commits into from
Nov 21, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return True

while foo():
print(x)
x += 1
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator_lhs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return 6

while foo() > x:
print(x)
x += 1
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator_rhs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return 6

while x < foo():
print(x)
x += 1
33 changes: 27 additions & 6 deletions pyt/cfg/stmt_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,23 +555,44 @@ def visit_For(self, node):
path=self.filenames[-1]
))

if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names:
last_node = self.visit(node.iter)
last_node.connect(for_node)
self.process_loop_funcs(node.iter, for_node)

return self.loop_node_skeleton(for_node, node)

def process_loop_funcs(self, comp_n, loop_node):
"""
If the loop test node contains function calls, it connects the loop node to the nodes of
those function calls.

:param comp_n: The test node of a loop that may contain functions.
:param loop_node: The loop node itself to connect to the new function nodes if any
:return: None
"""
if isinstance(comp_n, ast.Call) and get_call_names_as_string(comp_n.func) in self.function_names:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this was already here, but I don't understand why get_call_names_as_string(comp_n.func) in self.function_names is there. It means that range is ignored for instance. I'd be tempted to remove it, but maybe in a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't pay much mind to this detail but if removed, 4 of the test cases or for fail due to functions in the iter node being ignored (e.g. range).

I think the rationale here is to only link the while / for node to the function node if we've visited the function before, but I'm not sure why that would matter since we're immediately visiting the call node.

That said, we don't have a visit_Call method, so the default visitor will be invoked, which will in turn call visit_Name (which we don't have, so default again), then I think a call to visit_Load which will be default again, etc. I wonder if we'd need a visit_Call to handle this case before we can safely remove the in self.function_names.

last_node = self.visit(comp_n)
last_node.connect(loop_node)

def visit_While(self, node):
label_visitor = LabelVisitor()
label_visitor.visit(node.test)
test = node.test # the test condition of the while loop
label_visitor.visit(test)

test = self.append_node(Node(
while_node = self.append_node(Node(
'while ' + label_visitor.result + ':',
node,
path=self.filenames[-1]
))

return self.loop_node_skeleton(test, node)
if isinstance(test, ast.Compare):
comparators = test.comparators
comparators.append(test.left) # quirk. See https://greentreesnakes.readthedocs.io/en/latest/nodes.html#Compare
bcaller marked this conversation as resolved.
Show resolved Hide resolved

for comp in comparators:
self.process_loop_funcs(comp, while_node)
else: # while foo():
self.process_loop_funcs(test, while_node)

return self.loop_node_skeleton(while_node, node)

def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901
"""Processes a blackbox or builtin function when it is called.
Expand Down
90 changes: 90 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,96 @@ def test_while_line_numbers(self):
self.assertLineNumber(else_body_2, 6)
self.assertLineNumber(next_stmt, 7)

def test_while_func_comparator(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while foo():')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])

def test_while_func_comparator_rhs(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_rhs.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])

def test_while_func_comparator_lhs(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_lhs.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while foo() > x:')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])


class CFGAssignmentMultiTest(CFGBaseTestCase):
def test_assignment_multi_target(self):
Expand Down