diff --git a/pyt/__main__.py b/pyt/__main__.py index c5ca05c5..95a48c17 100644 --- a/pyt/__main__.py +++ b/pyt/__main__.py @@ -20,7 +20,6 @@ ) from .github_search import scan_github, set_github_api_token from .interprocedural_cfg import interprocedural -from .intraprocedural_cfg import intraprocedural from .lattice import print_lattice from .liveness import LivenessAnalysis from .project_handler import get_directory_modules, get_modules @@ -106,8 +105,6 @@ def parse_args(args): ' reaching definitions tainted version.', action='store_true') - parser.add_argument('-intra', '--intraprocedural-analysis', - help='Run intraprocedural analysis.', action='store_true') parser.add_argument('-ppm', '--print-project-modules', help='Print project modules.', action='store_true') @@ -159,8 +156,18 @@ def parse_args(args): def analyse_repo(github_repo, analysis_type): cfg_list = list() - project_modules = get_modules(os.path.dirname(github_repo.path)) - intraprocedural(project_modules, cfg_list) + directory = os.path.dirname(github_repo.path) + project_modules = get_modules(directory) + local_modules = get_directory_modules(directory) + tree = generate_ast(github_repo.path, python_2=args.python_2) + interprocedural_cfg = interprocedural( + tree, + project_modules, + local_modules, + github_repo.path + ) + cfg_list.append(interprocedural_cfg) + initialize_constraint_table(cfg_list) analyse(cfg_list, analysis_type=analysis_type) vulnerability_log = find_vulnerabilities(cfg_list, analysis_type) @@ -214,25 +221,23 @@ def main(command_line_args=sys.argv[1:]): tree = generate_ast(path, python_2=args.python_2) cfg_list = list() - - if args.intraprocedural_analysis: - intraprocedural(project_modules, cfg_list) - else: - interprocedural_cfg = interprocedural(tree, - project_modules, - local_modules, - path) - cfg_list.append(interprocedural_cfg) - framework_route_criteria = is_flask_route_function - if args.adaptor: - if args.adaptor.lower().startswith('e'): - framework_route_criteria = is_function - elif args.adaptor.lower().startswith('p'): - framework_route_criteria = is_function_without_leading_ - elif args.adaptor.lower().startswith('d'): - framework_route_criteria = is_django_view_function - # Add all the route functions to the cfg_list - FrameworkAdaptor(cfg_list, project_modules, local_modules, framework_route_criteria) + interprocedural_cfg = interprocedural( + tree, + project_modules, + local_modules, + path + ) + cfg_list.append(interprocedural_cfg) + framework_route_criteria = is_flask_route_function + if args.adaptor: + if args.adaptor.lower().startswith('e'): + framework_route_criteria = is_function + elif args.adaptor.lower().startswith('p'): + framework_route_criteria = is_function_without_leading_ + elif args.adaptor.lower().startswith('d'): + framework_route_criteria = is_django_view_function + # Add all the route functions to the cfg_list + FrameworkAdaptor(cfg_list, project_modules, local_modules, framework_route_criteria) initialize_constraint_table(cfg_list) diff --git a/pyt/alias_helper.py b/pyt/alias_helper.py index 648665fc..be98cfa0 100644 --- a/pyt/alias_helper.py +++ b/pyt/alias_helper.py @@ -10,6 +10,7 @@ def as_alias_handler(alias_list): list_.append(alias.name) return list_ + def handle_aliases_in_calls(name, import_alias_mapping): """Returns either None or the handled alias. Used in add_module. @@ -26,6 +27,7 @@ def handle_aliases_in_calls(name, import_alias_mapping): return name.replace(key, val) return None + def handle_aliases_in_init_files(name, import_alias_mapping): """Returns either None or the handled alias. Used in add_module. @@ -42,6 +44,7 @@ def handle_aliases_in_init_files(name, import_alias_mapping): return name.replace(val, key) return None + def handle_fdid_aliases(module_or_package_name, import_alias_mapping): """Returns either None or the handled alias. Used in add_module. @@ -52,6 +55,7 @@ def handle_fdid_aliases(module_or_package_name, import_alias_mapping): return key return None + def not_as_alias_handler(names_list): """Returns a list of names ignoring any aliases.""" list_ = list() @@ -59,10 +63,11 @@ def not_as_alias_handler(names_list): list_.append(alias.name) return list_ + def retrieve_import_alias_mapping(names_list): """Creates a dictionary mapping aliases to their respective name. import_alias_names is used in module_definitions.py and visit_Call""" - import_alias_names = {} + import_alias_names = dict() for alias in names_list: if alias.asname: diff --git a/pyt/base_cfg.py b/pyt/base_cfg.py index 8fb5ab7d..dfe5a30a 100644 --- a/pyt/base_cfg.py +++ b/pyt/base_cfg.py @@ -1,363 +1,44 @@ import ast import itertools -from collections import namedtuple -from .ast_helper import Arguments, get_call_names_as_string +from .ast_helper import ( + get_call_names_as_string +) +from .base_cfg_helper import ( + CALL_IDENTIFIER, + ConnectStatements, + connect_nodes, + extract_left_hand_side, + get_first_node, + get_first_statement, + get_last_statements, + remove_breaks +) from .label_visitor import LabelVisitor +from .node_types import ( + AssignmentNode, + AssignmentCallNode, + BBorBInode, + BreakNode, + ControlFlowNode, + IgnoredNode, + Node, + RestoreNode +) from .right_hand_side_visitor import RHSVisitor from .vars_visitor import VarsVisitor -ControlFlowNode = namedtuple('ControlFlowNode', - 'test last_nodes break_statements') - -ConnectStatements = namedtuple('ConnectStatements', - 'first_statement' + - ' last_statements' + - ' break_statements') -CALL_IDENTIFIER = '¤' - - -class IgnoredNode(): - """Ignored Node sent from an ast node that should not return anything.""" - - -class Node(): - """A Control Flow Graph node that contains a list of - ingoing and outgoing nodes and a list of its variables.""" - - def __init__(self, label, ast_node, *, line_number, path): - """Create a Node that can be used in a CFG. - - Args: - label(str): The label of the node, describing its expression. - line_number(Optional[int]): The line of the expression of the Node. - """ - self.label = label - self.ast_node = ast_node - self.line_number = line_number - self.path = path - self.ingoing = list() - self.outgoing = list() - - def connect(self, successor): - """Connect this node to its successor node by - setting its outgoing and the successors ingoing.""" - if isinstance(self, ConnectToExitNode) and\ - not isinstance(successor, EntryOrExitNode): - return - - self.outgoing.append(successor) - successor.ingoing.append(self) - - def connect_predecessors(self, predecessors): - """Connect all nodes in predecessors to this node.""" - for n in predecessors: - self.ingoing.append(n) - n.outgoing.append(self) - - def __str__(self): - """Print the label of the node.""" - return ''.join((' Label: ', self.label)) - - - def __repr__(self): - """Print a representation of the node.""" - label = ' '.join(('Label: ', self.label)) - line_number = 'Line number: ' + str(self.line_number) - outgoing = '' - ingoing = '' - if self.ingoing: - ingoing = ' '.join(('ingoing:\t', str([x.label for x in self.ingoing]))) - else: - ingoing = ' '.join(('ingoing:\t', '[]')) - - if self.outgoing: - outgoing = ' '.join(('outgoing:\t', str([x.label for x in self.outgoing]))) - else: - outgoing = ' '.join(('outgoing:\t', '[]')) - - return '\n' + '\n'.join((label, line_number, ingoing, outgoing)) - - -class ConnectToExitNode(): - pass - - -class FunctionNode(Node): - """CFG Node that represents a function definition. - - Used as a dummy for creating a list of function definitions. - """ - - def __init__(self, ast_node): - """Create a function node. - - This node is a dummy node representing a function definition. - """ - super().__init__(self.__class__.__name__, ast_node) - - -class RaiseNode(Node, ConnectToExitNode): - """CFG Node that represents a Raise statement.""" - - def __init__(self, label, ast_node, *, line_number, path): - """Create a Raise node.""" - super().__init__(label, ast_node, line_number=line_number, path=path) - - -class BreakNode(Node): - """CFG Node that represents a Break node.""" - - def __init__(self, ast_node, *, line_number, path): - super().__init__(self.__class__.__name__, ast_node, line_number=line_number, path=path) - - -class EntryOrExitNode(Node): - """CFG Node that represents an Exit or an Entry node.""" - - def __init__(self, label): - super().__init__(label, None, line_number=None, path=None) - - -class AssignmentNode(Node): - """CFG Node that represents an assignment.""" - - def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path): - """Create an Assignment node. - - Args: - label(str): The label of the node, describing the expression it represents. - left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. - ast_node(_ast.Assign, _ast.AugAssign, _ast.Return or None) - right_hand_side_variables(list[str]): A list of variables on the right hand side. - line_number(Optional[int]): The line of the expression the Node represents. - path(string): Current filename. - """ - super().__init__(label, ast_node, line_number=line_number, path=path) - self.left_hand_side = left_hand_side - self.right_hand_side_variables = right_hand_side_variables - - def __repr__(self): - output_string = super().__repr__() - output_string += '\n' - return ''.join((output_string, - 'left_hand_side:\t', str(self.left_hand_side), '\n', - 'right_hand_side_variables:\t', str(self.right_hand_side_variables))) - - -class TaintedNode(AssignmentNode): - pass - - -class RestoreNode(AssignmentNode): - """Node used for handling restore nodes returning from function calls.""" - - def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path): - """Create a Restore node. - - Args: - label(str): The label of the node, describing the expression it represents. - left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. - right_hand_side_variables(list[str]): A list of variables on the right hand side. - line_number(Optional[int]): The line of the expression the Node represents. - path(string): Current filename. - """ - super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) - - -class BBorBInode(AssignmentNode): - """Node used for handling restore nodes returning from blackbox or builtin function calls.""" - - def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path): - """Create a Restore node. - - Args: - label(str): The label of the node, describing the expression it represents. - left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. - right_hand_side_variables(list[str]): A list of variables on the right hand side. - line_number(Optional[int]): The line of the expression the Node represents. - path(string): Current filename. - """ - super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) - self.args = list() - self.inner_most_call = self - - -class AssignmentCallNode(AssignmentNode): - """Node used for X.""" - - def __init__(self, - label, - left_hand_side, - ast_node, - right_hand_side_variables, - vv_result, - *, - line_number, - path, - call_node): - """Create a X. - - Args: - label(str): The label of the node, describing the expression it represents. - left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. - right_hand_side_variables(list[str]): A list of variables on the right hand side. - vv_result(list[str]): Necessary to know `image_name = image_name.replace('..', '')` is a reassignment. - line_number(Optional[int]): The line of the expression the Node represents. - path(string): Current filename. - call_node(BBorBInode or RestoreNode): Used in connect_control_flow_node. - """ - super().__init__(label, left_hand_side, ast_node, right_hand_side_variables, line_number=line_number, path=path) - self.vv_result = vv_result - self.call_node = call_node - self.blackbox = False - - -class ReturnNode(AssignmentNode, ConnectToExitNode): - """CFG node that represents a return from a call.""" - - def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path): - """Create a CallReturn node. - - Args: - label(str): The label of the node, describing the expression it represents. - restore_nodes(list[Node]): List of nodes that were restored in the function call. - right_hand_side_variables(list[str]): A list of variables on the right hand side. - line_number(Optional[int]): The line of the expression the Node represents. - path(string): Current filename. - """ - super().__init__(label, left_hand_side, ast_node, right_hand_side_variables, line_number=line_number, path=path) - - -class Function(): - """Representation of a function definition in the program.""" - - def __init__(self, nodes, args, decorator_list): - """Create a Function representation. - - Args: - nodes(list[Node]): The CFG of the Function. - args(ast.args): The arguments from a function AST node. - decorator_list(list[ast.decorator]): The list of decorators - from a function AST node. - """ - self.nodes = nodes - self.arguments = Arguments(args) - self.decorator_list = decorator_list - - def __repr__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + repr(n), '\n\n')) - return output - - def __str__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + str(n), '\n\n')) - return output - - -class CFG(): - def __init__(self, nodes, blackbox_assignments): - self.nodes = nodes - self.blackbox_assignments = blackbox_assignments - - def __repr__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + repr(n), '\n\n')) - return output - - def __str__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + str(n), '\n\n')) - return output - - class Visitor(ast.NodeVisitor): - def append_node(self, Node): - """Append a node to the CFG and return it.""" - self.nodes.append(Node) - return Node - - def get_first_statement(self, node_or_tuple): - """Find the first statement of the provided object. - - Returns: - The first element in the tuple if it is a tuple. - The node if it is a node. - """ - if isinstance(node_or_tuple, tuple): - return node_or_tuple[0] - else: - return node_or_tuple - - def should_connect_node(self, node): - """Determine if node should be in the final CFG.""" - if isinstance(node, (FunctionNode, IgnoredNode)): - return False - else: - return True - - def get_inner_most_function_call(self, call_node): - # Loop to inner most function call - # e.g. return scrypt.inner in `foo = scrypt.outer(scrypt.inner(image_name))` - old_call_node = None - while call_node != old_call_node: - old_call_node = call_node - if isinstance(call_node, BBorBInode): - call_node = call_node.inner_most_call - else: - try: - call_node = call_node.first_node.inner_most_call - except AttributeError: - try: - call_node = call_node.first_node - except AttributeError: - # No inner calls - # Possible improvement: Make new node for RestoreNode's made in process_function - # and make `self.inner_most_call = self` - pass - return call_node - - def connect_control_flow_node(self, control_flow_node, next_node): - """Connect a ControlFlowNode properly to the next_node.""" - for last in control_flow_node[1]: # list of last nodes in ifs and elifs - if isinstance(next_node, ControlFlowNode): - last.connect(next_node.test) # connect to next if test case - elif isinstance(next_node, AssignmentCallNode): - call_node = next_node.call_node - inner_most_call_node = self.get_inner_most_function_call(call_node) - last.connect(inner_most_call_node) - else: - last.connect(next_node) - - def connect_nodes(self, nodes): - """Connect the nodes in a list linearly.""" - for n, next_node in zip(nodes, nodes[1:]): - if isinstance(n, ControlFlowNode): # case for if - self.connect_control_flow_node(n, next_node) - elif isinstance(next_node, ControlFlowNode): # case for if - n.connect(next_node[0]) - elif isinstance(next_node, RestoreNode): - continue - elif CALL_IDENTIFIER in next_node.label: - continue - else: - n.connect(next_node) - - def get_last_statements(self, cfg_statements): - """Retrieve the last statements from a cfg_statements list.""" - if isinstance(cfg_statements[-1], ControlFlowNode): - return cfg_statements[-1].last_nodes - else: - return [cfg_statements[-1]] + def visit_Module(self, node): + return self.stmt_star_handler(node.body) - def stmt_star_handler(self, stmts, prev_node_to_avoid=None): + def stmt_star_handler( + self, + stmts, + prev_node_to_avoid=None + ): """Handle stmt* expressions in an AST node. Links all statements together in a list of statements, accounting for statements with multiple last nodes. @@ -373,7 +54,7 @@ def stmt_star_handler(self, stmts, prev_node_to_avoid=None): for stmt in stmts: node = self.visit(stmt) - if isinstance(stmt, ast.While) or isinstance(stmt, ast.For): + if isinstance(stmt, (ast.For, ast.While)): self.last_was_loop_stack.append(True) else: self.last_was_loop_stack.append(False) @@ -383,62 +64,38 @@ def stmt_star_handler(self, stmts, prev_node_to_avoid=None): elif isinstance(node, BreakNode): break_nodes.append(node) - if node and not first_node: # (Make sure first_node isn't already set.) - # first_node is always a "node_to_connect", because it won't have ingoing otherwise - # If we have e.g. - # import os # An ignored node - # value = None - # first_node will be `value = None` - if hasattr(node, 'ingoing'): - ingoing = None - current_node = node - while current_node.ingoing: - # e.g. We don't want to step past the Except of an Except BB - if current_node.ingoing[0] == node_not_to_step_past: - break - ingoing = current_node.ingoing - current_node = current_node.ingoing[0] - if ingoing: - # Only set it once - first_node = ingoing[0] - - if node and self.should_connect_node(node): + if not isinstance(node, IgnoredNode): + cfg_statements.append(node) if not first_node: if isinstance(node, ControlFlowNode): first_node = node.test else: - first_node = node - cfg_statements.append(node) + first_node = get_first_node( + node, + node_not_to_step_past + ) if prev_node_to_avoid: self.prev_nodes_to_avoid.pop() self.last_was_loop_stack.pop() - self.connect_nodes(cfg_statements) + connect_nodes(cfg_statements) if cfg_statements: if first_node: first_statement = first_node else: - first_statement = self.get_first_statement(cfg_statements[0]) + first_statement = get_first_statement(cfg_statements[0]) - last_statements = self.get_last_statements(cfg_statements) + last_statements = get_last_statements(cfg_statements) - return ConnectStatements(first_statement=first_statement, - last_statements=last_statements, - break_statements=break_nodes) + return ConnectStatements( + first_statement=first_statement, + last_statements=last_statements, + break_statements=break_nodes + ) else: # When body of module only contains ignored nodes return IgnoredNode() - def visit_Module(self, node): - return self.stmt_star_handler(node.body) - - def add_if_label(self, CFG_node): - """Prepend 'if ' and append ':' to the label of a Node.""" - CFG_node.label = 'if ' + CFG_node.label + ':' - - def add_elif_label(self, CFG_node): - """Add the el to an already add_if_label'ed Node.""" - CFG_node.label = 'el' + CFG_node.label def handle_or_else(self, orelse, test): """Handle the orelse part of an if or try node. @@ -448,7 +105,8 @@ def handle_or_else(self, orelse, test): """ if isinstance(orelse[0], ast.If): control_flow_node = self.visit(orelse[0]) - self.add_elif_label(control_flow_node.test) + control_flow_node.test.label = 'el' + control_flow_node.test.label + test.connect(control_flow_node.test) return control_flow_node.last_nodes else: @@ -456,21 +114,24 @@ def handle_or_else(self, orelse, test): test.connect(else_connect_statements.first_statement) return else_connect_statements.last_statements - def remove_breaks(self, last_statements): - """Remove all break statements in last_statements.""" - return [n for n in last_statements if not isinstance(n, BreakNode)] - def visit_If(self, node): label_visitor = LabelVisitor() label_visitor.visit(node.test) - test = self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) - - self.add_if_label(test) + test = self.append_node(Node( + 'if ' + label_visitor.result + ':', + node, + line_number=node.lineno, + path=self.filenames[-1] + )) body_connect_stmts = self.stmt_star_handler(node.body) if isinstance(body_connect_stmts, IgnoredNode): - body_connect_stmts = ConnectStatements(first_statement=test, last_statements=[], break_statements=[]) + body_connect_stmts = ConnectStatements( + first_statement=test, + last_statements=[], + break_statements=[] + ) test.connect(body_connect_stmts.first_statement) if node.orelse: @@ -479,30 +140,39 @@ def visit_If(self, node): else: body_connect_stmts.last_statements.append(test) # if there is no orelse, test needs an edge to the next_node - last_statements = self.remove_breaks(body_connect_stmts.last_statements) + last_statements = remove_breaks(body_connect_stmts.last_statements) return ControlFlowNode(test, last_statements, break_statements=body_connect_stmts.break_statements) - def visit_NameConstant(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node) - return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) - def visit_Raise(self, node): label = LabelVisitor() label.visit(node) - return self.append_node(RaiseNode(label.result, node, line_number=node.lineno, path=self.filenames[-1])) + return self.append_node(RaiseNode( + label.result, + node, + line_number=node.lineno, + path=self.filenames[-1] + )) def handle_stmt_star_ignore_node(self, body, fallback_cfg_node): try: fallback_cfg_node.connect(body.first_statement) except AttributeError: - body = ConnectStatements([fallback_cfg_node], [fallback_cfg_node], list()) + body = ConnectStatements( + first_statement=[fallback_cfg_node], + last_statements=[fallback_cfg_node], + break_statements=[] + ) return body def visit_Try(self, node): - try_node = self.append_node(Node('Try', node, line_number=node.lineno, path=self.filenames[-1])) + try_node = self.append_node(Node( + 'Try', + node, + line_number=node.lineno, + path=self.filenames[-1] + )) body = self.stmt_star_handler(node.body) body = self.handle_stmt_star_ignore_node(body, try_node) @@ -538,33 +208,10 @@ def visit_Try(self, node): body.last_statements.extend(finalbody.last_statements) - last_statements.extend(self.remove_breaks(body.last_statements)) + last_statements.extend(remove_breaks(body.last_statements)) return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements) - def get_names(self, node, result): - """Recursively finds all names.""" - if isinstance(node, ast.Name): - return node.id + result - elif isinstance(node, ast.Subscript): - return result - else: - return self.get_names(node.value, result + '.' + node.attr) - - def extract_left_hand_side(self, target): - """Extract the left hand side variable from a target. - - Removes list indexes, stars and other left hand side elements. - """ - left_hand_side = self.get_names(target, '') - - left_hand_side.replace('*', '') - if '[' in left_hand_side: - index = left_hand_side.index('[') - left_hand_side = target[0:index] - - return left_hand_side - def assign_tuple_target(self, node, right_hand_side_variables): new_assignment_nodes = list() for i, target in enumerate(node.targets[0].elts): @@ -585,14 +232,14 @@ def assign_tuple_target(self, node, right_hand_side_variables): new_assignment_nodes.append(self.append_node(AssignmentNode( label.result, - self.extract_left_hand_side(target), + extract_left_hand_side(target), ast.Assign(target, value), right_hand_side_variables, line_number=node.lineno, path=self.filenames[-1] ))) - self.connect_nodes(new_assignment_nodes) + connect_nodes(new_assignment_nodes) return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], []) # return the last added node def assign_multi_target(self, node, right_hand_side_variables): @@ -612,7 +259,7 @@ def assign_multi_target(self, node, right_hand_side_variables): line_number=node.lineno, path=self.filenames[-1] ))) - self.connect_nodes(new_assignment_nodes) + connect_nodes(new_assignment_nodes) return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], []) # return the last added node def visit_Assign(self, node): @@ -655,7 +302,7 @@ def visit_Assign(self, node): label.visit(node) return self.append_node(AssignmentNode( label.result, - self.extract_left_hand_side(node.targets[0]), + extract_left_hand_side(node.targets[0]), node, rhs_visitor.result, line_number=node.lineno, @@ -714,54 +361,13 @@ def visit_AugAssign(self, node): return self.append_node(AssignmentNode( label.result, - self.extract_left_hand_side(node.target), + extract_left_hand_side(node.target), node, rhs_visitor.result, line_number=node.lineno, path=self.filenames[-1] )) - def loop_node_skeleton(self, test, node): - """Common handling of looped structures, while and for.""" - body_connect_stmts = self.stmt_star_handler(node.body, prev_node_to_avoid=self.nodes[-1]) - - test.connect(body_connect_stmts.first_statement) - test.connect_predecessors(body_connect_stmts.last_statements) - - # last_nodes is used for making connections to the next node in the parent node - # this is handled in stmt_star_handler - last_nodes = list() - last_nodes.extend(body_connect_stmts.break_statements) - - if node.orelse: - orelse_connect_stmts = self.stmt_star_handler(node.orelse, prev_node_to_avoid=self.nodes[-1]) - - test.connect(orelse_connect_stmts.first_statement) - last_nodes.extend(orelse_connect_stmts.last_statements) - else: - last_nodes.append(test) # if there is no orelse, test needs an edge to the next_node - - return ControlFlowNode(test, last_nodes, list()) - - def add_while_label(self, node): - """Prepend 'while' and append ':' to the label of a node.""" - node.label = 'while ' + node.label + ':' - - def visit_While(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node.test) - - test = self.append_node(Node( - label_visitor.result, - node, - line_number=node.lineno, - path=self.filenames[-1] - )) - - self.add_while_label(test) - - return self.loop_node_skeleton(test, node) - def visit_For(self, node): self.undecided = True # Used for handling functions in for loops @@ -785,8 +391,40 @@ def visit_For(self, node): return self.loop_node_skeleton(for_node, node) - def visit_Expr(self, node): - return self.visit(node.value) + def visit_While(self, node): + label_visitor = LabelVisitor() + label_visitor.visit(node.test) + + test = self.append_node(Node( + 'while ' + label_visitor.result + ':', + node, + line_number=node.lineno, + path=self.filenames[-1] + )) + + return self.loop_node_skeleton(test, node) + + def loop_node_skeleton(self, test, node): + """Common handling of looped structures, while and for.""" + body_connect_stmts = self.stmt_star_handler(node.body, prev_node_to_avoid=self.nodes[-1]) + + test.connect(body_connect_stmts.first_statement) + test.connect_predecessors(body_connect_stmts.last_statements) + + # last_nodes is used for making connections to the next node in the parent node + # this is handled in stmt_star_handler + last_nodes = list() + last_nodes.extend(body_connect_stmts.break_statements) + + if node.orelse: + orelse_connect_stmts = self.stmt_star_handler(node.orelse, prev_node_to_avoid=self.nodes[-1]) + + test.connect(orelse_connect_stmts.first_statement) + last_nodes.extend(orelse_connect_stmts.last_statements) + else: + last_nodes.append(test) # if there is no orelse, test needs an edge to the next_node + + return ControlFlowNode(test, last_nodes, list()) def add_blackbox_or_builtin_call(self, node, blackbox): """Processes a blackbox or builtin function when it is called. @@ -892,12 +530,6 @@ def add_blackbox_or_builtin_call(self, node, blackbox): return call_node - def visit_Name(self, node): - label = LabelVisitor() - label.visit(node) - - return self.append_node(Node(label.result, node, line_number=node.lineno, path=self.filenames[-1])) - def visit_With(self, node): label_visitor = LabelVisitor() label_visitor.visit(node.items[0]) @@ -910,25 +542,14 @@ def visit_With(self, node): )) connect_statements = self.stmt_star_handler(node.body) with_node.connect(connect_statements.first_statement) - return ControlFlowNode(with_node, connect_statements.last_statements, connect_statements.break_statements) - - def visit_Str(self, node): - return IgnoredNode() + return ControlFlowNode( + with_node, + connect_statements.last_statements, + connect_statements.break_statements + ) def visit_Break(self, node): - return self.append_node(BreakNode(node, line_number=node.lineno, path=self.filenames[-1])) - - def visit_Pass(self, node): - return self.append_node(Node( - 'pass', - node, - line_number=node.lineno, - path=self.filenames[-1] - )) - - def visit_Continue(self, node): - return self.append_node(Node( - 'continue', + return self.append_node(BreakNode( node, line_number=node.lineno, path=self.filenames[-1] @@ -957,45 +578,73 @@ def visit_Assert(self, node): )) def visit_Attribute(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node) + return self.visit_miscelleaneous_node( + node + ) - return self.append_node(Node( - label_visitor.result, + def visit_Continue(self, node): + return self.visit_miscelleaneous_node( node, - line_number=node.lineno, - path=self.filenames[-1] - )) + custom_label='continue' + ) def visit_Global(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node) + return self.visit_miscelleaneous_node( + node + ) - return self.append_node(Node( - label_visitor.result, - node, - line_number=node.lineno, - path=self.filenames[-1] - )) + def visit_Name(self, node): + return self.visit_miscelleaneous_node( + node + ) - def visit_Subscript(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node) + def visit_NameConstant(self, node): + return self.visit_miscelleaneous_node( + node + ) - return self.append_node(Node( - label_visitor.result, + def visit_Pass(self, node): + return self.visit_miscelleaneous_node( node, - line_number=node.lineno, - path=self.filenames[-1] - )) + custom_label='pass' + ) + + def visit_Subscript(self, node): + return self.visit_miscelleaneous_node( + node + ) def visit_Tuple(self, node): - label_visitor = LabelVisitor() - label_visitor.visit(node) + return self.visit_miscelleaneous_node( + node + ) + + def visit_miscelleaneous_node( + self, + node, + custom_label=None + ): + if custom_label: + label = custom_label + else: + label_visitor = LabelVisitor() + label_visitor.visit(node) + label = label_visitor.result return self.append_node(Node( - label_visitor.result, + label, node, line_number=node.lineno, path=self.filenames[-1] )) + + def visit_Str(self, node): + return IgnoredNode() + + def visit_Expr(self, node): + return self.visit(node.value) + + def append_node(self, node): + """Append a node to the CFG and return it.""" + self.nodes.append(node) + return node diff --git a/pyt/base_cfg_helper.py b/pyt/base_cfg_helper.py new file mode 100644 index 00000000..fd1bc43f --- /dev/null +++ b/pyt/base_cfg_helper.py @@ -0,0 +1,137 @@ +import ast +from collections import namedtuple + +from .node_types import ( + AssignmentCallNode, + BBorBInode, + BreakNode, + ControlFlowNode, + RestoreNode +) + + +CALL_IDENTIFIER = '¤' + + +ConnectStatements = namedtuple('ConnectStatements', + 'first_statement' + + ' last_statements' + + ' break_statements') + + +def _get_inner_most_function_call(call_node): + # Loop to inner most function call + # e.g. return scrypt.inner in `foo = scrypt.outer(scrypt.inner(image_name))` + old_call_node = None + while call_node != old_call_node: + old_call_node = call_node + if isinstance(call_node, BBorBInode): + call_node = call_node.inner_most_call + else: + try: + call_node = call_node.first_node.inner_most_call + except AttributeError: + try: + call_node = call_node.first_node + except AttributeError: + # No inner calls + # Possible improvement: Make new node for RestoreNode's made in process_function + # and make `self.inner_most_call = self` + pass + return call_node + + +def _connect_control_flow_node(control_flow_node, next_node): + """Connect a ControlFlowNode properly to the next_node.""" + for last in control_flow_node[1]: # list of last nodes in ifs and elifs + if isinstance(next_node, ControlFlowNode): + last.connect(next_node.test) # connect to next if test case + elif isinstance(next_node, AssignmentCallNode): + call_node = next_node.call_node + inner_most_call_node = _get_inner_most_function_call(call_node) + last.connect(inner_most_call_node) + else: + last.connect(next_node) + + +def connect_nodes(nodes): + """Connect the nodes in a list linearly.""" + for n, next_node in zip(nodes, nodes[1:]): + if isinstance(n, ControlFlowNode): # case for if + _connect_control_flow_node(n, next_node) + elif isinstance(next_node, ControlFlowNode): # case for if + n.connect(next_node[0]) + elif isinstance(next_node, RestoreNode): + continue + elif CALL_IDENTIFIER in next_node.label: + continue + else: + n.connect(next_node) + + +def _get_names(node, result): + """Recursively finds all names.""" + if isinstance(node, ast.Name): + return node.id + result + elif isinstance(node, ast.Subscript): + return result + else: + return _get_names(node.value, result + '.' + node.attr) + + +def extract_left_hand_side(target): + """Extract the left hand side variable from a target. + + Removes list indexes, stars and other left hand side elements. + """ + left_hand_side = _get_names(target, '') + + left_hand_side.replace('*', '') + if '[' in left_hand_side: + index = left_hand_side.index('[') + left_hand_side = target[0:index] + + return left_hand_side + + +def get_first_node( + node, + node_not_to_step_past +): + ingoing = None + current_node = node + while current_node.ingoing: + # e.g. We don't want to step past the Except of an Except basic block + if current_node.ingoing[0] == node_not_to_step_past: + break + ingoing = current_node.ingoing + current_node = current_node.ingoing[0] + if ingoing: + return ingoing[0] + return current_node + + +def get_first_statement(node_or_tuple): + """Find the first statement of the provided object. + + Returns: + The first element in the tuple if it is a tuple. + The node if it is a node. + """ + if isinstance(node_or_tuple, tuple): + return node_or_tuple[0] + else: + return node_or_tuple + + +def get_last_statements(cfg_statements): + """Retrieve the last statements from a cfg_statements list.""" + if isinstance(cfg_statements[-1], ControlFlowNode): + return cfg_statements[-1].last_nodes + else: + return [cfg_statements[-1]] + + +def remove_breaks(last_statements): + """Remove all break statements in last_statements.""" + return [n for n in last_statements if not isinstance(n, BreakNode)] diff --git a/pyt/definition_chains.py b/pyt/definition_chains.py index 2ac0e077..c4435488 100644 --- a/pyt/definition_chains.py +++ b/pyt/definition_chains.py @@ -1,8 +1,8 @@ import ast -from .base_cfg import AssignmentNode from .constraint_table import constraint_table from .lattice import Lattice +from .node_types import AssignmentNode from .reaching_definitions import ReachingDefinitionsAnalysis from .vars_visitor import VarsVisitor diff --git a/pyt/draw.py b/pyt/draw.py index 9231d9a7..0e372557 100644 --- a/pyt/draw.py +++ b/pyt/draw.py @@ -4,7 +4,7 @@ from itertools import permutations from subprocess import run -from .base_cfg import AssignmentNode +from .node_types import AssignmentNode IGNORED_LABEL_NAME_CHARACHTERS = ':' diff --git a/pyt/framework_adaptor.py b/pyt/framework_adaptor.py index c56dfe50..78f88ec9 100644 --- a/pyt/framework_adaptor.py +++ b/pyt/framework_adaptor.py @@ -2,12 +2,12 @@ import ast from .ast_helper import Arguments -from .base_cfg import ( +from .interprocedural_cfg import interprocedural +from .module_definitions import project_definitions +from .node_types import ( AssignmentNode, TaintedNode ) -from .interprocedural_cfg import interprocedural -from .module_definitions import project_definitions class FrameworkAdaptor(): diff --git a/pyt/interprocedural_cfg.py b/pyt/interprocedural_cfg.py index a500fa54..062885ae 100644 --- a/pyt/interprocedural_cfg.py +++ b/pyt/interprocedural_cfg.py @@ -1,6 +1,5 @@ import ast import os.path -from collections import namedtuple from .alias_helper import ( as_alias_handler, @@ -10,55 +9,44 @@ not_as_alias_handler, retrieve_import_alias_mapping ) -from .ast_helper import Arguments, generate_ast, get_call_names_as_string +from .ast_helper import ( + Arguments, + generate_ast, + get_call_names_as_string +) from .base_cfg import ( + Visitor +) +from .base_cfg_helper import ( + CALL_IDENTIFIER +) +from .interprocedural_cfg_helper import ( + BUILTINS, + CFG, + return_connection_handler, + SavedVariable +) +from .label_visitor import LabelVisitor +from .module_definitions import ( + LocalModuleDefinition, + ModuleDefinition, + ModuleDefinitions +) +from .node_types import ( AssignmentCallNode, AssignmentNode, BBorBInode, - CALL_IDENTIFIER, - CFG, ConnectToExitNode, EntryOrExitNode, IgnoredNode, Node, RestoreNode, - ReturnNode, - Visitor -) -from .label_visitor import LabelVisitor -from .module_definitions import ( - LocalModuleDefinition, - ModuleDefinition, - ModuleDefinitions + ReturnNode ) from .project_handler import get_directory_modules from .right_hand_side_visitor import RHSVisitor -SavedVariable = namedtuple('SavedVariable', 'LHS RHS') -BUILTINS = ( - 'get', - 'Flask', - 'run', - 'replace', - 'read', - 'set_cookie', - 'make_response', - 'SQLAlchemy', - 'Column', - 'execute', - 'sessionmaker', - 'Session', - 'filter', - 'call', - 'render_template', - 'redirect', - 'url_for', - 'flash', - 'jsonify' -) - - class InterproceduralVisitor(Visitor): def __init__(self, node, project_modules, local_modules, filename, module_definitions=None): @@ -93,19 +81,19 @@ def init_cfg(self, node): raise Exception('Empty module. It seems that your file is empty,' + 'there is nothing to analyse.') - if not isinstance(module_statements, IgnoredNode): - first_node = module_statements.first_statement + exit_node = self.append_node(EntryOrExitNode("Exit module")) - if CALL_IDENTIFIER not in first_node.label: - entry_node.connect(first_node) + if isinstance(module_statements, IgnoredNode): + entry_node.connect(exit_node) + return - exit_node = self.append_node(EntryOrExitNode("Exit module")) + first_node = module_statements.first_statement - last_nodes = module_statements.last_statements - exit_node.connect_predecessors(last_nodes) - else: - exit_node = self.append_node(EntryOrExitNode("Exit module")) - entry_node.connect(exit_node) + if CALL_IDENTIFIER not in first_node.label: + entry_node.connect(first_node) + + last_nodes = module_statements.last_statements + exit_node.connect_predecessors(last_nodes) def init_function_cfg(self, node, module_definitions): self.module_definitions_stack.append(module_definitions) @@ -116,46 +104,26 @@ def init_function_cfg(self, node, module_definitions): entry_node = self.append_node(EntryOrExitNode("Entry function")) module_statements = self.stmt_star_handler(node.body) + exit_node = self.append_node(EntryOrExitNode("Exit function")) + + if isinstance(module_statements, IgnoredNode): + entry_node.connect(exit_node) + return first_node = module_statements.first_statement if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryOrExitNode("Exit function")) - last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) - def visit_ClassDef(self, node): - self.add_to_definitions(node) - - local_definitions = self.module_definitions_stack[-1] - local_definitions.classes.append(node.name) - - parent_definitions = self.get_parent_definitions() - if parent_definitions: - parent_definitions.classes.append(node.name) - - self.stmt_star_handler(node.body) - - local_definitions.classes.pop() - if parent_definitions: - parent_definitions.classes.pop() - - return IgnoredNode() - def get_parent_definitions(self): parent_definitions = None if len(self.module_definitions_stack) > 1: parent_definitions = self.module_definitions_stack[-2] return parent_definitions - def visit_FunctionDef(self, node): - self.add_to_definitions(node) - - return IgnoredNode() - def add_to_definitions(self, node): local_definitions = self.module_definitions_stack[-1] parent_definitions = self.get_parent_definitions() @@ -187,12 +155,28 @@ def add_to_definitions(self, node): self.function_names.append(node.name) - def return_connection_handler(self, nodes, exit_node): - """Connect all return statements to the Exit node.""" - for function_body_node in nodes: - if isinstance(function_body_node, ConnectToExitNode): - if exit_node not in function_body_node.outgoing: - function_body_node.connect(exit_node) + def visit_ClassDef(self, node): + self.add_to_definitions(node) + + local_definitions = self.module_definitions_stack[-1] + local_definitions.classes.append(node.name) + + parent_definitions = self.get_parent_definitions() + if parent_definitions: + parent_definitions.classes.append(node.name) + + self.stmt_star_handler(node.body) + + local_definitions.classes.pop() + if parent_definitions: + parent_definitions.classes.pop() + + return IgnoredNode() + + def visit_FunctionDef(self, node): + self.add_to_definitions(node) + + return IgnoredNode() def visit_Return(self, node): label = LabelVisitor() @@ -322,12 +306,14 @@ def connect_if_allowed(self, previous_node, node_to_connect_to): if not isinstance(previous_node, ReturnNode): previous_node.connect(node_to_connect_to) - def save_def_args_in_temp(self, - call_args, - def_args, - line_number, - saved_function_call_index, - first_node): + def save_def_args_in_temp( + self, + call_args, + def_args, + line_number, + saved_function_call_index, + first_node + ): """Save the arguments of the definition being called. Visit the arguments if they're calls. Args: @@ -346,6 +332,7 @@ def save_def_args_in_temp(self, # Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument for i, call_arg in enumerate(call_args): + # If this results in an IndexError it is invalid Python def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i] return_value_of_nested_call = None @@ -408,11 +395,13 @@ def save_def_args_in_temp(self, return (args_mapping, first_node) - def create_local_scope_from_def_args(self, - call_args, - def_args, - line_number, - saved_function_call_index): + def create_local_scope_from_def_args( + self, + call_args, + def_args, + line_number, + saved_function_call_index + ): """Create the local scope before entering the body of a function call. Args: @@ -439,10 +428,42 @@ def create_local_scope_from_def_args(self, self.nodes[-1].connect(local_scope_node) self.nodes.append(local_scope_node) - def restore_saved_local_scope(self, - saved_variables, - args_mapping, - line_number): + def visit_and_get_function_nodes(self, definition, first_node): + """Visits the nodes of a user defined function. + + Args: + definition(LocalModuleDefinition): Definition of the function being added. + first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function. + + Returns: + the_new_nodes(list[Node]): The nodes added while visiting the function. + first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function. + """ + len_before_visiting_func = len(self.nodes) + previous_node = self.nodes[-1] + entry_node = self.append_node(EntryOrExitNode("Function Entry " + + definition.name)) + if not first_node: + first_node = entry_node + self.connect_if_allowed(previous_node, entry_node) + + function_body_connect_statements = self.stmt_star_handler(definition.node.body) + entry_node.connect(function_body_connect_statements.first_statement) + + exit_node = self.append_node(EntryOrExitNode("Exit " + definition.name)) + exit_node.connect_predecessors(function_body_connect_statements.last_statements) + + the_new_nodes = self.nodes[len_before_visiting_func:] + return_connection_handler(the_new_nodes, exit_node) + + return (the_new_nodes, first_node) + + def restore_saved_local_scope( + self, + saved_variables, + args_mapping, + line_number + ): """Restore the previously saved variables to their original values. Args: @@ -584,36 +605,6 @@ def process_function(self, call_node, definition): return self.nodes[-1] - def visit_and_get_function_nodes(self, definition, first_node): - """Visits the nodes of a user defined function. - - Args: - definition(LocalModuleDefinition): Definition of the function being added. - first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function. - - Returns: - the_new_nodes(list[Node]): The nodes added while visiting the function. - first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function. - """ - len_before_visiting_func = len(self.nodes) - previous_node = self.nodes[-1] - entry_node = self.append_node(EntryOrExitNode("Function Entry " + - definition.name)) - if not first_node: - first_node = entry_node - self.connect_if_allowed(previous_node, entry_node) - - function_body_connect_statements = self.stmt_star_handler(definition.node.body) - entry_node.connect(function_body_connect_statements.first_statement) - - exit_node = self.append_node(EntryOrExitNode("Exit " + definition.name)) - exit_node.connect_predecessors(function_body_connect_statements.last_statements) - - the_new_nodes = self.nodes[len_before_visiting_func:] - self.return_connection_handler(the_new_nodes, exit_node) - - return (the_new_nodes, first_node) - def visit_Call(self, node): _id = get_call_names_as_string(node.func) local_definitions = self.module_definitions_stack[-1] @@ -741,7 +732,14 @@ def add_module(self, module, module_or_package_name, local_names, import_alias_m return exit_node - def from_directory_import(self, module, real_names, local_names, import_alias_mapping, skip_init=False): + def from_directory_import( + self, + module, + real_names, + local_names, + import_alias_mapping, + skip_init=False + ): """ Directories don't need to be packages. """ @@ -785,6 +783,7 @@ def from_directory_import(self, module, real_names, local_names, import_alias_ma import_alias_mapping, from_from=True ) + return IgnoredNode() def import_package(self, module, module_name, local_name, import_alias_mapping): module_path = module[1] @@ -801,40 +800,6 @@ def import_package(self, module, module_name, local_name, import_alias_mapping): else: raise Exception("import directory needs an __init__.py file") - def visit_Import(self, node): - for name in node.names: - for module in self.local_modules: - if name.name == module[0]: - if os.path.isdir(module[1]): - return self.import_package( - module, - name, - name.asname, - retrieve_import_alias_mapping(node.names) - ) - return self.add_module( - module, - name.name, - name.asname, - retrieve_import_alias_mapping(node.names) - ) - for module in self.project_modules: - if name.name == module[0]: - if os.path.isdir(module[1]): - return self.import_package( - module, - name, - name.asname, - retrieve_import_alias_mapping(node.names) - ) - return self.add_module( - module, - name.name, - name.asname, - retrieve_import_alias_mapping(node.names) - ) - return IgnoredNode() - def handle_relative_import(self, node): """ from A means node.level == 0 @@ -886,6 +851,40 @@ def handle_relative_import(self, node): skip_init=skip_init ) + def visit_Import(self, node): + for name in node.names: + for module in self.local_modules: + if name.name == module[0]: + if os.path.isdir(module[1]): + return self.import_package( + module, + name, + name.asname, + retrieve_import_alias_mapping(node.names) + ) + return self.add_module( + module, + name.name, + name.asname, + retrieve_import_alias_mapping(node.names) + ) + for module in self.project_modules: + if name.name == module[0]: + if os.path.isdir(module[1]): + return self.import_package( + module, + name, + name.asname, + retrieve_import_alias_mapping(node.names) + ) + return self.add_module( + module, + name.name, + name.asname, + retrieve_import_alias_mapping(node.names) + ) + return IgnoredNode() + def visit_ImportFrom(self, node): # Is it relative? if node.level > 0: @@ -926,11 +925,20 @@ def visit_ImportFrom(self, node): return IgnoredNode() -def interprocedural(node, project_modules, local_modules, filename, - module_definitions=None): - - visitor = InterproceduralVisitor(node, - project_modules, - local_modules, filename, - module_definitions) - return CFG(visitor.nodes, visitor.blackbox_assignments) +def interprocedural( + node, + project_modules, + local_modules, + filename, + module_definitions=None +): + visitor = InterproceduralVisitor( + node, + project_modules, + local_modules, filename, + module_definitions + ) + return CFG( + visitor.nodes, + visitor.blackbox_assignments + ) diff --git a/pyt/interprocedural_cfg_helper.py b/pyt/interprocedural_cfg_helper.py new file mode 100644 index 00000000..d8ee6c7f --- /dev/null +++ b/pyt/interprocedural_cfg_helper.py @@ -0,0 +1,54 @@ +from collections import namedtuple + +from .node_types import ( + ConnectToExitNode +) + +SavedVariable = namedtuple('SavedVariable', 'LHS RHS') +BUILTINS = ( + 'get', + 'Flask', + 'run', + 'replace', + 'read', + 'set_cookie', + 'make_response', + 'SQLAlchemy', + 'Column', + 'execute', + 'sessionmaker', + 'Session', + 'filter', + 'call', + 'render_template', + 'redirect', + 'url_for', + 'flash', + 'jsonify' +) + + +class CFG(): + def __init__(self, nodes, blackbox_assignments): + self.nodes = nodes + self.blackbox_assignments = blackbox_assignments + + def __repr__(self): + output = '' + for x, n in enumerate(self.nodes): + output = ''.join((output, 'Node: ' + str(x) + ' ' + repr(n), '\n\n')) + return output + + def __str__(self): + output = '' + for x, n in enumerate(self.nodes): + output = ''.join((output, 'Node: ' + str(x) + ' ' + str(n), '\n\n')) + return output + + +def return_connection_handler(nodes, exit_node): + """Connect all return statements to the Exit node.""" + for function_body_node in nodes: + if isinstance(function_body_node, ConnectToExitNode): + if exit_node not in function_body_node.outgoing: + function_body_node.connect(exit_node) diff --git a/pyt/intraprocedural_cfg.py b/pyt/intraprocedural_cfg.py deleted file mode 100644 index 5c990f23..00000000 --- a/pyt/intraprocedural_cfg.py +++ /dev/null @@ -1,179 +0,0 @@ -import ast - -from .ast_helper import Arguments, generate_ast -from .base_cfg import ( - CALL_IDENTIFIER, - CFG, - EntryOrExitNode, - IgnoredNode, - Node, - ReturnNode, - Visitor -) -from .label_visitor import LabelVisitor -from .right_hand_side_visitor import RHSVisitor - - -class IntraproceduralVisitor(Visitor): - - def __init__(self, node, filename): - """Create an empty CFG.""" - self.nodes = list() - self.undecided = False # Check if needed in intraprocedural - - self.function_names = list() - self.filenames = [filename] - - try: - # FunctionDef ast node - self.init_function_cfg(node) - except: # Error?! - # Module ast node - self.init_module_cfg(node) - - def init_module_cfg(self, node): - entry_node = self.append_node(EntryOrExitNode("Entry module")) - - module_statements = self.visit(node) - - if not module_statements: - raise Exception('Empty module. It seems that your file is empty,' + - ' there is nothing to analyse.') - - if not isinstance(module_statements, IgnoredNode): - first_node = module_statements.first_statement - - if CALL_IDENTIFIER not in first_node.label: - entry_node.connect(first_node) - - exit_node = self.append_node(EntryOrExitNode("Exit module")) - - last_nodes = module_statements.last_statements - exit_node.connect_predecessors(last_nodes) - else: - exit_node = self.append_node(EntryOrExitNode("Exit module")) - entry_node.connect(exit_node) - - def init_function_cfg(self, node): - - entry_node = self.append_node(EntryOrExitNode("Entry module")) - - module_statements = self.stmt_star_handler(node.body) - if isinstance(module_statements, IgnoredNode): - exit_node = self.append_node(EntryOrExitNode("Exit module")) - entry_node.connect(exit_node) - return - - first_node = module_statements.first_statement - if CALL_IDENTIFIER not in first_node.label: - entry_node.connect(first_node) - - exit_node = self.append_node(EntryOrExitNode("Exit module")) - - last_nodes = module_statements.last_statements - exit_node.connect_predecessors(last_nodes) - - def visit_ClassDef(self, node): - return self.append_node(Node('class ' + node.name, node, - line_number=node.lineno, - path=self.filenames[-1])) - - def visit_FunctionDef(self, node): - arguments = Arguments(node.args) - return self.append_node(Node('def ' + node.name + '(' + - ','.join(arguments) + '):', - node, - line_number=node.lineno, - path=self.filenames[-1])) - - def visit_Return(self, node): - label = LabelVisitor() - label.visit(node) - - try: - rhs_visitor = RHSVisitor() - rhs_visitor.visit(node.value) - except AttributeError: - rhs_visitor.result = 'EmptyReturn' - - LHS = 'ret_' + 'MAYBE_FUNCTION_NAME' - return self.append_node(ReturnNode(LHS + ' = ' + label.result, - LHS, - node, - rhs_visitor.result, - line_number=node.lineno, - path=self.filenames[-1])) - - def visit_Yield(self, node): - label = LabelVisitor() - label.visit(node) - - try: - rhs_visitor = RHSVisitor() - rhs_visitor.visit(node.value) - except AttributeError: - rhs_visitor.result = 'EmptyYield' - - LHS = 'yield_' + 'MAYBE_FUNCTION_NAME' - return self.append_node(ReturnNode(LHS + ' = ' + label.result, - LHS, - node, - rhs_visitor.result, - line_number=node.lineno, - path=self.filenames[-1])) - - def visit_Call(self, node): - return self.add_builtin(node) - - def visit_Import(self, node): - names = [n.name for n in node.names] - return self.append_node(Node('Import ' + ', '.join(names), node, - line_number=node.lineno, - path=self.filenames[-1])) - - def visit_ImportFrom(self, node): - names = [a.name for a in node.names] - try: - from_import = 'from ' + node.module + ' ' - except TypeError: - from_import = '' - return self.append_node(Node(from_import + 'import ' + - ', '.join(names), - node, - line_number=node.lineno, - path=self.filenames[-1])) - - -class FunctionDefVisitor(ast.NodeVisitor): - def __init__(self): - self.result = list() - - def visit_FunctionDef(self, node): - self.result.append(node) - #def visit_ClassDef(self, node): - # self.result.append(node) - - -def intraprocedural(project_modules, cfg_list): - functions = list() - dup = list() - for module in project_modules: - t = generate_ast(module[1]) - iv = IntraproceduralVisitor(t, filename=module[1]) - cfg_list.append(CFG(iv.nodes)) - dup.append(t) - fdv = FunctionDefVisitor() - fdv.visit(t) - dup.extend(fdv.result) - functions.extend([(f, module[1]) for f in fdv.result]) - - for f in functions: - iv = IntraproceduralVisitor(f[0], filename=f[1]) - cfg_list.append(CFG(iv.nodes)) - - s = set() - for d in dup: - if d in s: - raise Exception('Duplicates in the functions definitions list.') - else: - s.add(d) diff --git a/pyt/liveness.py b/pyt/liveness.py index d9aadde8..8d3fe4f3 100644 --- a/pyt/liveness.py +++ b/pyt/liveness.py @@ -2,16 +2,16 @@ from .analysis_base import AnalysisBase from .ast_helper import get_call_names_as_string -from .base_cfg import ( - AssignmentNode, - BBorBInode, - EntryOrExitNode -) from .constraint_table import ( constraint_join, constraint_table ) from .lattice import Lattice +from .node_types import ( + AssignmentNode, + BBorBInode, + EntryOrExitNode +) from .vars_visitor import VarsVisitor diff --git a/pyt/module_definitions.py b/pyt/module_definitions.py index 0a465d4c..0f7b72a2 100644 --- a/pyt/module_definitions.py +++ b/pyt/module_definitions.py @@ -62,7 +62,7 @@ def __init__(self, import_names=None, module_name=None, is_init=False, filename= self.filename = filename self.definitions = list() self.classes = list() - self.import_alias_mapping = {} + self.import_alias_mapping = dict() def append_if_local_or_in_imports(self, definition): """Add definition to list. diff --git a/pyt/node_types.py b/pyt/node_types.py new file mode 100644 index 00000000..495ac33b --- /dev/null +++ b/pyt/node_types.py @@ -0,0 +1,228 @@ +"""This module contains all of the CFG nodes types.""" +from collections import namedtuple + + +ControlFlowNode = namedtuple('ControlFlowNode', + 'test last_nodes break_statements') + +class IgnoredNode(): + """Ignored Node sent from an ast node that should not return anything.""" + pass + +class ConnectToExitNode(): + pass + + +class Node(): + """A Control Flow Graph node that contains a list of + ingoing and outgoing nodes and a list of its variables.""" + + def __init__(self, label, ast_node, *, line_number, path): + """Create a Node that can be used in a CFG. + + Args: + label(str): The label of the node, describing its expression. + line_number(Optional[int]): The line of the expression of the Node. + """ + self.label = label + self.ast_node = ast_node + self.line_number = line_number + self.path = path + self.ingoing = list() + self.outgoing = list() + + def connect(self, successor): + """Connect this node to its successor node by + setting its outgoing and the successors ingoing.""" + if isinstance(self, ConnectToExitNode) and not isinstance(successor, EntryOrExitNode): + return + + self.outgoing.append(successor) + successor.ingoing.append(self) + + def connect_predecessors(self, predecessors): + """Connect all nodes in predecessors to this node.""" + for n in predecessors: + self.ingoing.append(n) + n.outgoing.append(self) + + def __str__(self): + """Print the label of the node.""" + return ''.join((' Label: ', self.label)) + + + def __repr__(self): + """Print a representation of the node.""" + label = ' '.join(('Label: ', self.label)) + line_number = 'Line number: ' + str(self.line_number) + outgoing = '' + ingoing = '' + if self.ingoing: + ingoing = ' '.join(('ingoing:\t', str([x.label for x in self.ingoing]))) + else: + ingoing = ' '.join(('ingoing:\t', '[]')) + + if self.outgoing: + outgoing = ' '.join(('outgoing:\t', str([x.label for x in self.outgoing]))) + else: + outgoing = ' '.join(('outgoing:\t', '[]')) + + return '\n' + '\n'.join((label, line_number, ingoing, outgoing)) + + +class RaiseNode(Node, ConnectToExitNode): + """CFG Node that represents a Raise statement.""" + + def __init__(self, label, ast_node, *, line_number, path): + """Create a Raise node.""" + super().__init__(label, ast_node, line_number=line_number, path=path) + + +class BreakNode(Node): + """CFG Node that represents a Break node.""" + + def __init__(self, ast_node, *, line_number, path): + super().__init__(self.__class__.__name__, ast_node, line_number=line_number, path=path) + + +class EntryOrExitNode(Node): + """CFG Node that represents an Exit or an Entry node.""" + + def __init__(self, label): + super().__init__(label, None, line_number=None, path=None) + + +class AssignmentNode(Node): + """CFG Node that represents an assignment.""" + + def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path): + """Create an Assignment node. + + Args: + label(str): The label of the node, describing the expression it represents. + left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. + ast_node(_ast.Assign, _ast.AugAssign, _ast.Return or None) + right_hand_side_variables(list[str]): A list of variables on the right hand side. + line_number(Optional[int]): The line of the expression the Node represents. + path(string): Current filename. + """ + super().__init__(label, ast_node, line_number=line_number, path=path) + self.left_hand_side = left_hand_side + self.right_hand_side_variables = right_hand_side_variables + + def __repr__(self): + output_string = super().__repr__() + output_string += '\n' + return ''.join((output_string, + 'left_hand_side:\t', str(self.left_hand_side), '\n', + 'right_hand_side_variables:\t', str(self.right_hand_side_variables))) + + +class TaintedNode(AssignmentNode): + pass + + +class RestoreNode(AssignmentNode): + """Node used for handling restore nodes returning from function calls.""" + + def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path): + """Create a Restore node. + + Args: + label(str): The label of the node, describing the expression it represents. + left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. + right_hand_side_variables(list[str]): A list of variables on the right hand side. + line_number(Optional[int]): The line of the expression the Node represents. + path(string): Current filename. + """ + super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) + + +class BBorBInode(AssignmentNode): + """Node used for handling restore nodes returning from blackbox or builtin function calls.""" + + def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path): + """Create a Restore node. + + Args: + label(str): The label of the node, describing the expression it represents. + left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. + right_hand_side_variables(list[str]): A list of variables on the right hand side. + line_number(Optional[int]): The line of the expression the Node represents. + path(string): Current filename. + """ + super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) + self.args = list() + self.inner_most_call = self + + +class AssignmentCallNode(AssignmentNode): + """Node used for X.""" + + def __init__( + self, + label, + left_hand_side, + ast_node, + right_hand_side_variables, + vv_result, + *, + line_number, + path, + call_node + ): + """Create a X. + + Args: + label(str): The label of the node, describing the expression it represents. + left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. + right_hand_side_variables(list[str]): A list of variables on the right hand side. + vv_result(list[str]): Necessary to know `image_name = image_name.replace('..', '')` is a reassignment. + line_number(Optional[int]): The line of the expression the Node represents. + path(string): Current filename. + call_node(BBorBInode or RestoreNode): Used in connect_control_flow_node. + """ + super().__init__( + label, + left_hand_side, + ast_node, + right_hand_side_variables, + line_number=line_number, + path=path + ) + self.vv_result = vv_result + self.call_node = call_node + self.blackbox = False + + +class ReturnNode(AssignmentNode, ConnectToExitNode): + """CFG node that represents a return from a call.""" + + def __init__( + self, + label, + left_hand_side, + ast_node, + right_hand_side_variables, + *, + line_number, + path + ): + """Create a return from a call node. + + Args: + label(str): The label of the node, describing the expression it represents. + left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis. + ast_node + right_hand_side_variables(list[str]): A list of variables on the right hand side. + line_number(Optional[int]): The line of the expression the Node represents. + path(string): Current filename. + """ + super().__init__( + label, + left_hand_side, + ast_node, + right_hand_side_variables, + line_number=line_number, + path=path + ) diff --git a/pyt/reaching_definitions.py b/pyt/reaching_definitions.py index 0c8317b3..3bf5d075 100644 --- a/pyt/reaching_definitions.py +++ b/pyt/reaching_definitions.py @@ -1,5 +1,5 @@ -from .base_cfg import AssignmentNode from .constraint_table import constraint_table +from .node_types import AssignmentNode from .reaching_definitions_base import ReachingDefinitionsAnalysisBase diff --git a/pyt/reaching_definitions_base.py b/pyt/reaching_definitions_base.py index 48ed533d..9031c12d 100644 --- a/pyt/reaching_definitions_base.py +++ b/pyt/reaching_definitions_base.py @@ -1,7 +1,7 @@ from .analysis_base import AnalysisBase -from .base_cfg import AssignmentNode from .constraint_table import constraint_join from .lattice import Lattice +from .node_types import AssignmentNode class ReachingDefinitionsAnalysisBase(AnalysisBase): diff --git a/pyt/reaching_definitions_taint.py b/pyt/reaching_definitions_taint.py index a93ec544..4e0512b2 100644 --- a/pyt/reaching_definitions_taint.py +++ b/pyt/reaching_definitions_taint.py @@ -1,8 +1,8 @@ -from .base_cfg import ( +from .constraint_table import constraint_table +from .node_types import ( AssignmentCallNode, AssignmentNode ) -from .constraint_table import constraint_table from .reaching_definitions_base import ReachingDefinitionsAnalysisBase diff --git a/pyt/save.py b/pyt/save.py index 5ee6ebf2..a3a61059 100644 --- a/pyt/save.py +++ b/pyt/save.py @@ -1,9 +1,9 @@ import os from datetime import datetime -from .base_cfg import Node from .definition_chains import build_def_use_chain, build_use_def_chain from .lattice import Lattice +from .node_types import Node database_file_name = 'db.sql' diff --git a/pyt/vulnerabilities.py b/pyt/vulnerabilities.py index 968992b7..ab1d130d 100644 --- a/pyt/vulnerabilities.py +++ b/pyt/vulnerabilities.py @@ -3,14 +3,14 @@ import ast from collections import namedtuple -from .base_cfg import ( +from .lattice import Lattice +from .node_types import ( AssignmentCallNode, AssignmentNode, BBorBInode, RestoreNode, TaintedNode ) -from .lattice import Lattice from .right_hand_side_visitor import RHSVisitor from .trigger_definitions_parser import default_trigger_word_file, parse from .vars_visitor import VarsVisitor diff --git a/tests/cfg_test.py b/tests/cfg_test.py index 08edb251..da07706e 100644 --- a/tests/cfg_test.py +++ b/tests/cfg_test.py @@ -1,6 +1,5 @@ from .base_test_case import BaseTestCase -from pyt.base_cfg import EntryOrExitNode, Node -# from pyt.project_handler import get_modules +from pyt.node_types import EntryOrExitNode, Node class CFGGeneralTest(BaseTestCase): @@ -1086,9 +1085,6 @@ def test_multiple_user_defined_calls_in_blackbox_call_after_if(self): def test_function_line_numbers_2(self): path = 'example/example_inputs/simple_function_with_return.py' self.cfg_create_from_file(path) - # self.cfg = CFG(get_modules(path)) - # tree = generate_ast(path) - # self.cfg.create(tree) assignment_with_function = self.cfg.nodes[1] diff --git a/tests/command_line_test.py b/tests/command_line_test.py index 6749677c..3f90594e 100644 --- a/tests/command_line_test.py +++ b/tests/command_line_test.py @@ -27,7 +27,7 @@ def test_no_args(self): [-p | -vp | -trim] [-t TRIGGER_WORD_FILE] [-py2] [-l LOG_LEVEL] [-a ADAPTOR] [-db] [-dl DRAW_LATTICE [DRAW_LATTICE ...]] [-li | -re | -rt] - [-intra] [-ppm] + [-ppm] {save,github_search} ...\n""" + \ "python -m pyt: error: one of the arguments " + \ "-f/--filepath -gr/--git-repos is required\n" diff --git a/tests/vulnerabilities_across_files_test.py b/tests/vulnerabilities_across_files_test.py index 0c9a6153..af57cdd9 100644 --- a/tests/vulnerabilities_across_files_test.py +++ b/tests/vulnerabilities_across_files_test.py @@ -4,12 +4,12 @@ from .base_test_case import BaseTestCase from pyt import trigger_definitions_parser, vulnerabilities from pyt.ast_helper import get_call_names_as_string -from pyt.base_cfg import Node from pyt.constraint_table import constraint_table, initialize_constraint_table from pyt.fixed_point import analyse from pyt.framework_adaptor import FrameworkAdaptor from pyt.framework_helper import is_flask_route_function from pyt.lattice import Lattice +from pyt.node_types import Node from pyt.project_handler import get_directory_modules, get_modules from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis diff --git a/tests/vulnerabilities_test.py b/tests/vulnerabilities_test.py index aeb39478..9fd0dac7 100644 --- a/tests/vulnerabilities_test.py +++ b/tests/vulnerabilities_test.py @@ -2,12 +2,12 @@ from .base_test_case import BaseTestCase from pyt import trigger_definitions_parser, vulnerabilities -from pyt.base_cfg import Node from pyt.constraint_table import constraint_table, initialize_constraint_table from pyt.fixed_point import analyse from pyt.framework_adaptor import FrameworkAdaptor from pyt.framework_helper import is_django_view_function, is_flask_route_function from pyt.lattice import Lattice +from pyt.node_types import Node from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis