Skip to content

Commit 80113af

Browse files
bcaller (Ben Caller)
authored and committed
Handle Starred assignments where possible
Try to match the targets with the values so that we reduce the number of false positives. Before, all right-hand-side variables were tainting all of the left-hand-side variables. Examples:
`a, *b = _, _, TAINT`: a clean, b tainted.
`a, *b, c = _, _, TAINT, TAINT, _`: a clean, b tainted, c clean.
`a, *b, c = _, *_, *TAINT, *_`: a clean, b tainted, c tainted.
1 parent d2566d2 commit 80113af

File tree

4 files changed

+84
-9
lines changed

4 files changed

+84
-9
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
a, *b, c, d, e = f, *g, *h, f + i, j

pyt/cfg/stmt_visitor.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -327,28 +327,59 @@ def visit_Try(self, node):
327327
return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements)
328328

329329
def assign_tuple_target(self, node, right_hand_side_variables):
330-
new_assignment_nodes = list()
331-
for i, target in enumerate(node.targets[0].elts):
332-
value = node.value.elts[i]
330+
new_assignment_nodes = []
331+
remaining_variables = list(right_hand_side_variables)
332+
remaining_targets = list(node.targets[0].elts)
333+
remaining_values = list(node.value.elts) # May contain duplicates
333334

335+
def visit(target, value):
334336
label = LabelVisitor()
335337
label.visit(target)
336-
338+
rhs_visitor = RHSVisitor()
339+
rhs_visitor.visit(value)
337340
if isinstance(value, ast.Call):
338341
new_ast_node = ast.Assign(target, value)
339-
new_ast_node.lineno = node.lineno
340-
342+
ast.copy_location(new_ast_node, node)
341343
new_assignment_nodes.append(self.assignment_call_node(label.result, new_ast_node))
342-
343344
else:
344345
label.result += ' = '
345346
label.visit(value)
346-
347347
new_assignment_nodes.append(self.append_node(AssignmentNode(
348348
label.result,
349349
extract_left_hand_side(target),
350350
ast.Assign(target, value),
351-
right_hand_side_variables,
351+
rhs_visitor.result,
352+
line_number=node.lineno,
353+
path=self.filenames[-1]
354+
)))
355+
remaining_targets.remove(target)
356+
remaining_values.remove(value)
357+
for var in rhs_visitor.result:
358+
remaining_variables.remove(var)
359+
360+
# Pair targets and values until a Starred node is reached
361+
for target, value in zip(node.targets[0].elts, node.value.elts):
362+
if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
363+
break
364+
visit(target, value)
365+
366+
# If there was a Starred node, pair remaining targets and values from the end
367+
for target, value in zip(reversed(list(remaining_targets)), reversed(list(remaining_values))):
368+
if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
369+
break
370+
visit(target, value)
371+
372+
if remaining_targets:
373+
label = LabelVisitor()
374+
label.handle_comma_separated(remaining_targets)
375+
label.result += ' = '
376+
label.handle_comma_separated(remaining_values)
377+
for target in remaining_targets:
378+
new_assignment_nodes.append(self.append_node(AssignmentNode(
379+
label.result,
380+
extract_left_hand_side(target),
381+
ast.Assign(target, remaining_values[0]),
382+
remaining_variables,
352383
line_number=node.lineno,
353384
path=self.filenames[-1]
354385
)))

pyt/cfg/stmt_visitor_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def _get_names(node, result):
7979
return node.id + result
8080
elif isinstance(node, ast.Subscript):
8181
return result
82+
elif isinstance(node, ast.Starred):
83+
return _get_names(node.value, result)
8284
else:
8385
return _get_names(node.value, result + '.' + node.attr)
8486

tests/cfg/cfg_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ast
2+
13
from .cfg_base_test_case import CFGBaseTestCase
24

35
from pyt.core.node_types import (
@@ -779,6 +781,45 @@ def test_assignment_tuple_value(self):
779781

780782
self.assertEqual(self.cfg.nodes[node].label, 'a = (x, y)')
781783

784+
def test_assignment_starred(self):
785+
self.cfg_create_from_file('examples/example_inputs/assignment_starred.py')
786+
787+
middle_nodes = self.cfg.nodes[1:-1]
788+
self.assert_length(middle_nodes, expected_length=5)
789+
790+
visited = [self.cfg.nodes[0]]
791+
while True:
792+
current_node = visited[-1]
793+
if len(current_node.outgoing) != 1:
794+
break
795+
visited.append(current_node.outgoing[0])
796+
self.assertCountEqual(self.cfg.nodes, visited, msg="Did not complete a path from Entry to Exit")
797+
798+
self.assertEqual(middle_nodes[0].label, 'a = f')
799+
self.assertCountEqual( # We don't assert a specific order for the assignment nodes
800+
[n.label for n in middle_nodes],
801+
['a = f', 'd = f + i', 'e = j'] + ['*b, c = *g, *h'] * 2,
802+
)
803+
self.assertCountEqual(
804+
[(n.left_hand_side, n.right_hand_side_variables) for n in middle_nodes],
805+
[('a', ['f']), ('b', ['g', 'h']), ('c', ['g', 'h']), ('d', ['f', 'i']), ('e', ['j'])],
806+
)
807+
808+
def test_assignment_starred_list(self):
809+
self.cfg_create_from_ast(ast.parse('[a, b, c] = *d, e'))
810+
811+
middle_nodes = self.cfg.nodes[1:-1]
812+
self.assert_length(middle_nodes, expected_length=3)
813+
814+
self.assertCountEqual(
815+
[n.label for n in middle_nodes],
816+
['a, b = *d', 'a, b = *d', 'c = e'],
817+
)
818+
self.assertCountEqual(
819+
[(n.left_hand_side, n.right_hand_side_variables) for n in middle_nodes],
820+
[('a', ['d']), ('b', ['d']), ('c', ['e'])],
821+
)
822+
782823

783824
class CFGComprehensionTest(CFGBaseTestCase):
784825
def test_nodes(self):

0 commit comments

Comments
 (0)