diff --git a/.travis.yml b/.travis.yml index fabf1244..d28c5f41 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,4 @@ language: python python: - - "3.5" -script: cd tests && python test.py + - "3.6" +script: python -m tests diff --git a/README.md b/README.md index 419235cb..4cab4be1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/python-security/pyt.svg?branch=master)](https://travis-ci.org/python-security/pyt) + # PyT - Python Taint Static analysis of Python web applications based on theoretical foundations (Control flow graphs, fixed point, dataflow analysis) @@ -8,6 +10,15 @@ Features planned: - Detect XSS - Detect directory traversal +Using it like a user: +`python -m pyt -f example/vulnerable_code/XSS_call.py save -du` + +Running the tests: `python -m tests` + +Running an individual test file: `python -m unittest tests.import_test` + +Running an individual test: `python -m unittest tests.import_test.ImportTest.test_import` + Work in progress # Contributions @@ -15,4 +26,46 @@ Join our slack group: https://pyt-dev.slack.com/ [Guidelines](https://github.com/python-security/pyt/blob/master/CONTRIBUTIONS.md) -[![Build Status](https://travis-ci.org/python-security/pyt.svg?branch=master)](https://travis-ci.org/python-security/pyt) +## Virtual env setup guide + +Create a directory to hold the virtual env and project + +`mkdir ~/a_folder` + +`cd ~/a_folder` + +Clone the project into the directory + +`git clone https://github.com/python-security/pyt.git` + +Create the virtual environment + +`python3 -m venv ~/a_folder/` + +Check that you have the right versions + +`python --version` sample output `Python 3.6.0` + +`pip --version` sample output `pip 9.0.1 from /Users/kevinhock/a_folder/lib/python3.6/site-packages (python 3.6)` + +Change to project directory + +`cd pyt` + +Install dependencies + +`pip install -r requirements.txt` + +`pip list` sample output + +``` +gitdb (0.6.4) +GitPython (2.0.8) +graphviz (0.4.10) +pip (9.0.1) +requests (2.10.0) +setuptools (28.8.0) +smmap (0.9.0) +``` + +In the future, just type `source ~/pyt/bin/activate` to start developing. diff --git a/example/cfg_example.py b/example/cfg_example.py index 4b123d24..9b496d14 100644 --- a/example/cfg_example.py +++ b/example/cfg_example.py @@ -1,8 +1,4 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath('../pyt')) -from cfg import CFG, print_CFG, generate_ast +from ..pyt.cfg import CFG, print_CFG, generate_ast ast = generate_ast('example_inputs/example.py') diff --git a/func_counter.py b/func_counter.py index 982398a3..c2f5f59c 100644 --- a/func_counter.py +++ b/func_counter.py @@ -1,13 +1,10 @@ """Module used for counting number of functions to get an estimate og how big the CFG should be""" - import ast -import sys -import os -sys.path.insert(0, os.path.abspath('pyt')) -from cfg import get_call_names_as_string, generate_ast -from project_handler import get_python_modules +from pyt.cfg import get_call_names_as_string, generate_ast +from pyt.project_handler import get_python_modules + function_calls = list() functions = dict() diff --git a/profiling/fine_timer.py b/profiling/fine_timer.py index ac3f1d48..468cf6d8 100644 --- a/profiling/fine_timer.py +++ b/profiling/fine_timer.py @@ -1,13 +1,15 @@ -import pstats import os -from subprocess import run as sub_run, PIPE +import pstats +from subprocess import PIPE, run as sub_run + +KERNPROF = 'kernprof-3.5' +LINE_PROFILER_FILE = 'pyt.py.lprof' PYTHON = 'python3' PYT_PATH = '../pyt/pyt.py' -STATS_FILENAME = 'stats.prof' SNAKEVIZ = 'snakeviz' -KERNPROF = 'kernprof-3.5' -LINE_PROFILER_FILE = 'pyt.py.lprof' +STATS_FILENAME = 'stats.prof' + def clean_up(): if os.path.isfile(STATS_FILENAME): diff --git a/profiling/profiler.py b/profiling/profiler.py index 528ac9c1..c38b247c 100644 --- a/profiling/profiler.py +++ b/profiling/profiler.py @@ -1,6 +1,7 @@ import argparse -import fine_timer +from . import fine_timer + parser = argparse.ArgumentParser() diff --git a/profiling/profiling_runner.py b/profiling/profiling_runner.py index 0f37ad33..4f126217 100644 --- a/profiling/profiling_runner.py +++ b/profiling/profiling_runner.py @@ -2,19 +2,21 @@ Saves the result for future reference. """ -from subprocess import Popen from datetime import datetime from shutil import which +from subprocess import Popen -TRAVIS_PYTHON = 'python' + +FIXED_POINT_FLAG = '-fp' +KERNPROF = 'kernprof-3.5' PROFILER = 'profiler.py' +PROFILING_DB = 'db.txt' TEST_PROJECT_1 = 'test_projects/flaskbb_lite_1/flaskbb/app.py' TEST_PROJECT_2 = 'test_projects/flaskbb_lite_2/flaskbb/app.py' TEST_PROJECT_3 = 'test_projects/flaskbb_lite_3/flaskbb/app.py' TEST_PROJECTS = [TEST_PROJECT_1, TEST_PROJECT_2, TEST_PROJECT_3] -FIXED_POINT_FLAG = '-fp' -PROFILING_DB = 'db.txt' -KERNPROF = 'kernprof-3.5' +TRAVIS_PYTHON = 'python' + if which(KERNPROF) is None: print('You need "kernprof" to run this script. Install: "pip3 install line_profiler".') diff --git a/pydocstyle.py b/pydocstyle.py index 895c0971..618efad2 100644 --- a/pydocstyle.py +++ b/pydocstyle.py @@ -1,5 +1,5 @@ -import subprocess import re +import subprocess import sys import os diff --git a/pyt/__init__.py b/pyt/__init__.py index 4f0a4e46..77876a65 100644 --- a/pyt/__init__.py +++ b/pyt/__init__.py @@ -1,4 +1,5 @@ import pyt + def main(): pyt.main() diff --git a/pyt/pyt.py b/pyt/__main__.py similarity index 89% rename from pyt/pyt.py rename to pyt/__main__.py index e270b8c9..d62eb861 100644 --- a/pyt/pyt.py +++ b/pyt/__main__.py @@ -3,24 +3,35 @@ import argparse import os from datetime import date +from pprint import pprint + +from .argument_helpers import valid_date +from .ast_helper import generate_ast +from .draw import draw_cfgs, draw_lattices +from .constraint_table import initialize_constraint_table, print_table +from .fixed_point import analyse +from .flask_adaptor import FlaskAdaptor +from .github_search import scan_github, set_github_api_token +from .interprocedural_cfg import interprocedural +from .intraprocedural_cfg import intraprocedural +from .lattice import print_lattice +from .liveness import LivenessAnalysis +from .project_handler import get_directory_modules, get_python_modules +from .reaching_definitions import ReachingDefinitionsAnalysis +from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis +from .repo_runner import get_repos +from .save import ( + cfg_to_file, + create_database, + def_use_chain_to_file, + lattice_to_file, + Output, + use_def_chain_to_file, + verbose_cfg_to_file, + vulnerabilities_to_file +) +from .vulnerabilities import find_vulnerabilities -from ast_helper import generate_ast -from interprocedural_cfg import interprocedural -from intraprocedural_cfg import intraprocedural -from draw import draw_cfgs, draw_lattices -from reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from liveness import LivenessAnalysis -from reaching_definitions import ReachingDefinitionsAnalysis -from fixed_point import analyse -from flask_adaptor import FlaskAdaptor -from vulnerabilities import find_vulnerabilities -from project_handler import get_python_modules, get_directory_modules -from save import create_database, def_use_chain_to_file,\ - use_def_chain_to_file, cfg_to_file, verbose_cfg_to_file,\ - lattice_to_file, vulnerabilities_to_file -from constraint_table import initialize_constraint_table -from github_search import scan_github, set_github_api_token -from argument_helpers import valid_date parser = argparse.ArgumentParser() parser.set_defaults(which='') @@ -118,7 +129,7 @@ help='Output everything to file.', action='store_true') -search_parser = subparsers.add_parser('github_search', +search_parser = subparsers.add_parser('github_search', help='Searches through github and runs PyT' ' on found repositories. This can take some time.') search_parser.set_defaults(which='search') @@ -159,7 +170,6 @@ def main(): cfg_list = list() if args.git_repos: - from repo_runner import get_repos repos = get_repos(args.git_repos) for repo in repos: repo.clone() @@ -171,7 +181,7 @@ def main(): if args.which == 'search': set_github_api_token() - if args.start_date: + if args.start_date: scan_github(args.search_string, args.start_date, analysis, analyse_repo, args.csv_path) else: @@ -220,10 +230,8 @@ def main(): else: draw_cfgs(cfg_list) if args.print: - from lattice import print_lattice l = print_lattice(cfg_list, analysis) - from constraint_table import print_table print_table(l) for i, e in enumerate(cfg_list): print('############## CFG number: ', i) @@ -234,7 +242,6 @@ def main(): print(repr(e)) if args.print_project_modules: - from pprint import pprint print('############## PROJECT MODULES ##############') pprint(project_modules) @@ -246,7 +253,6 @@ def main(): # Output to file if args.which == 'save': if args.filename_prefix: - from save import Output Output.filename_prefix = args.filename_prefix if args.save_all: def_use_chain_to_file(cfg_list) diff --git a/pyt/analysis_base.py b/pyt/analysis_base.py index b63df711..bcdb73d1 100644 --- a/pyt/analysis_base.py +++ b/pyt/analysis_base.py @@ -1,5 +1,5 @@ -"""Thos module contains a base class for the analysis component used in PyT.""" -from abc import abstractmethod, ABCMeta +"""This module contains a base class for the analysis component used in PyT.""" +from abc import ABCMeta, abstractmethod class AnalysisBase(metaclass=ABCMeta): diff --git a/pyt/argument_helpers.py b/pyt/argument_helpers.py index d8e00b79..c3685df7 100644 --- a/pyt/argument_helpers.py +++ b/pyt/argument_helpers.py @@ -1,5 +1,6 @@ -from datetime import datetime from argparse import ArgumentTypeError +from datetime import datetime + def valid_date(s): date_format = "%Y-%m-%d" diff --git a/pyt/ast_helper.py b/pyt/ast_helper.py index 2b9d0857..6c5e07cb 100644 --- a/pyt/ast_helper.py +++ b/pyt/ast_helper.py @@ -73,9 +73,9 @@ def get_call_names(node): return reversed(get_call_names_helper(node, result)) -class Arguments(object): +class Arguments(): """Represents arguments of a function.""" - + def __init__(self, args): """Create an Argument container class. diff --git a/pyt/base_cfg.py b/pyt/base_cfg.py index f3c57b5d..b1f52d6b 100644 --- a/pyt/base_cfg.py +++ b/pyt/base_cfg.py @@ -1,9 +1,9 @@ import ast - from collections import namedtuple -from right_hand_side_visitor import RHSVisitor -from label_visitor import LabelVisitor -from ast_helper import get_call_names_as_string, Arguments + +from .ast_helper import Arguments, get_call_names_as_string +from .label_visitor import LabelVisitor +from .right_hand_side_visitor import RHSVisitor ControlFlowNode = namedtuple('ControlFlowNode', @@ -16,11 +16,11 @@ CALL_IDENTIFIER = '¤' -class IgnoredNode(object): +class IgnoredNode(): """Ignored Node sent from a ast node that should not return anything.""" -class Node(object): +class Node(): """A Control Flow Graph node that contains a list of ingoing and outgoing nodes and a list of its variables.""" @@ -43,7 +43,7 @@ def connect(self, successor): """Connect this node to its successor node by setting its outgoing and the successors ingoing.""" if isinstance(self, ConnectToExitNode) and\ - not type(successor) is EntryExitNode: + not isinstance(successor, EntryOrExitNode): return self.outgoing.append(successor) successor.ingoing.append(self) @@ -92,34 +92,34 @@ def __init__(self, ast_node): This node is a dummy node representing a function definition """ - super(FunctionNode, self).__init__(self.__class__.__name__, ast_node) + super().__init__(self.__class__.__name__, ast_node) class RaiseNode(Node, ConnectToExitNode): """CFG Node that represents a Raise statement.""" - + def __init__(self, label, ast_node, *, line_number, path): """Create a Raise node.""" - super(RaiseNode, self).__init__(label, ast_node, line_number=line_number, path=path) + super().__init__(label, ast_node, line_number=line_number, path=path) class BreakNode(Node): """CFG Node that represents a Break node.""" - + def __init__(self, ast_node, *, line_number, path): - super(BreakNode, self).__init__(self.__class__.__name__, ast_node, line_number=line_number, path=path) + super().__init__(self.__class__.__name__, ast_node, line_number=line_number, path=path) + +class EntryOrExitNode(Node): + """CFG Node that represents an Exit or an Entry node.""" -class EntryExitNode(Node): - """CFG Node that represents a Exit or an Entry node.""" - def __init__(self, label): - super(EntryExitNode, self).__init__(label, None, line_number=None, path=None) + super().__init__(label, None, line_number=None, path=None) + - class AssignmentNode(Node): """CFG Node that represents an assignment.""" - + def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path): """Create an Assignment node. @@ -129,12 +129,12 @@ def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, * right_hand_side_variables(list[str]): A list of variables on the right hand side. line_number(Optional[int]): The line of the expression the Node represents. """ - super(AssignmentNode, self).__init__(label, ast_node, line_number=line_number, path=path) + super().__init__(label, ast_node, line_number=line_number, path=path) self.left_hand_side = left_hand_side self.right_hand_side_variables = right_hand_side_variables def __repr__(self): - output_string = super(AssignmentNode, self).__repr__() + output_string = super().__repr__() output_string += '\n' return ''.join((output_string, 'left_hand_side:\t', str(self.left_hand_side), '\n', 'right_hand_side_variables:\t', str(self.right_hand_side_variables))) @@ -143,7 +143,7 @@ class RestoreNode(AssignmentNode): """Node used for handling restore nodes returning from function calls.""" def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path): - """Create an Restore node. + """Create a Restore node. Args: label (str): The label of the node, describing the expression it represents. @@ -151,14 +151,14 @@ def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_num right_hand_side_variables(list[str]): A list of variables on the right hand side. line_number(Optional[int]): The line of the expression the Node represents. """ - super(RestoreNode, self).__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) - + super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path) + class ReturnNode(AssignmentNode, ConnectToExitNode): """CFG node that represents a return from a call.""" - + def __init__(self, label, left_hand_side, right_hand_side_variables, ast_node, *, line_number, path): - """Create an CallReturn node. + """Create a CallReturn node. Args: label (str): The label of the node, describing the expression it represents. @@ -166,12 +166,12 @@ def __init__(self, label, left_hand_side, right_hand_side_variables, ast_node, * right_hand_side_variables(list[str]): A list of variables on the right hand side. line_number(Optional[int]): The line of the expression the Node represents. """ - super(ReturnNode, self).__init__(label, left_hand_side, ast_node, right_hand_side_variables, line_number=line_number, path=path) + super().__init__(label, left_hand_side, ast_node, right_hand_side_variables, line_number=line_number, path=path) - -class Function(object): + +class Function(): """Representation of a function definition in the program.""" - + def __init__(self, nodes, args, decorator_list): """Create a Function representation. @@ -226,8 +226,8 @@ def get_first_statement(self, node_or_tuple): """Find the first statement of the provided object. Returns: - The node if is is a node. The first element in the tuple if it is a tuple. + The node if it is a node. """ if isinstance(node_or_tuple, tuple): return node_or_tuple[0] @@ -236,18 +236,14 @@ def get_first_statement(self, node_or_tuple): def node_to_connect(self, node): """Determine if node should be in the final CFG.""" - if isinstance(node, IgnoredNode): - return False - elif isinstance(node, ControlFlowNode): - return True - elif type(node) is FunctionNode: + if isinstance(node, (FunctionNode, IgnoredNode)): return False else: return True def connect_control_flow_node(self, control_flow_node, next_node): """Connect a ControlFlowNode properly to the next_node.""" - for last in control_flow_node[1]: # listof last nodes in ifs and elifs + for last in control_flow_node[1]: # list of last nodes in ifs and elifs if isinstance(next_node, ControlFlowNode): last.connect(next_node.test) # connect to next if test case else: @@ -260,7 +256,7 @@ def connect_nodes(self, nodes): self.connect_control_flow_node(n, next_node) elif isinstance(next_node, ControlFlowNode): # case for if n.connect(next_node[0]) - elif type(next_node) is RestoreNode: + elif isinstance(next_node, RestoreNode): continue elif CALL_IDENTIFIER in next_node.label: continue @@ -268,7 +264,7 @@ def connect_nodes(self, nodes): n.connect(next_node) def get_last_statements(self, cfg_statements): - """Retrieve the last statements from a cfg_statments list.""" + """Retrieve the last statements from a cfg_statements list.""" if isinstance(cfg_statements[-1], ControlFlowNode): return cfg_statements[-1].last_nodes else: @@ -287,20 +283,21 @@ def stmt_star_handler(self, stmts): if isinstance(node, ControlFlowNode): break_nodes.extend(node.break_statements) - elif type(node) is BreakNode: + elif isinstance(node, BreakNode): break_nodes.append(node) if self.node_to_connect(node): cfg_statements.append(node) - + self.connect_nodes(cfg_statements) - if cfg_statements: # When body of module only contains ignored nodes + if cfg_statements: first_statement = self.get_first_statement(cfg_statements[0]) last_statements = self.get_last_statements(cfg_statements) return ConnectStatements(first_statement=first_statement, last_statements=last_statements, break_statements=break_nodes) - return IgnoredNode() - + else: # When body of module only contains ignored nodes + return IgnoredNode() + def visit_Module(self, node): return self.stmt_star_handler(node.body) @@ -316,7 +313,7 @@ def handle_or_else(self, orelse, test): """Handle the orelse part of an if node. Returns: - The last nodes of the orelse branch + The last nodes of the orelse branch. """ if isinstance(orelse[0], ast.If): control_flow_node = self.visit(orelse[0]) @@ -330,21 +327,21 @@ def handle_or_else(self, orelse, test): def remove_breaks(self, last_statements): """Remove all break statements in last_statements.""" - return [n for n in last_statements if type(n) is not BreakNode] + return [n for n in last_statements if not isinstance(n, BreakNode)] def visit_If(self, node): label_visitor = LabelVisitor() label_visitor.visit(node.test) - test = self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) - + test = self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) + self.add_if_label(test) body_connect_stmts = self.stmt_star_handler(node.body) if isinstance(body_connect_stmts, IgnoredNode): body_connect_stmts = ConnectStatements(first_statement=test, last_statements=[], break_statements=[]) test.connect(body_connect_stmts.first_statement) - + if node.orelse: orelse_last_nodes = self.handle_or_else(node.orelse, test) body_connect_stmts.last_statements.extend(orelse_last_nodes) @@ -373,7 +370,7 @@ def handle_stmt_star_ignore_node(self, body, fallback_cfg_node): except AttributeError: body = ConnectStatements([fallback_cfg_node], [fallback_cfg_node], list()) return body - + def visit_Try(self, node): try_node = self.append_node(Node('Try', node, line_number=node.lineno, path=self.filenames[-1])) @@ -391,7 +388,7 @@ def visit_Try(self, node): body_node.connect(handler_node) handler_body = self.stmt_star_handler(handler.body) handler_body = self.handle_stmt_star_ignore_node(handler_body, handler_node) - last_statements.extend(handler_body.last_statements) + last_statements.extend(handler_body.last_statements) if node.orelse: orelse_last_nodes = self.handle_or_else(node.orelse, body.last_statements[-1]) @@ -438,21 +435,21 @@ def assign_tuple_target(self, node, right_hand_side_variables): new_assignment_nodes = list() for i, target in enumerate(node.targets[0].elts): value = node.value.elts[i] - + label = LabelVisitor() label.visit(target) - + if isinstance(value, ast.Call): new_ast_node = ast.Assign(target, value) new_ast_node.lineno = node.lineno - + new_assignment_nodes.append( self.assignment_call_node(label.result, new_ast_node)) - + else: label.result += ' = ' label.visit(value) - - new_assignment_nodes.append(self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(target), ast.Assign(target, value), right_hand_side_variables, line_number = node.lineno, path=self.filenames[-1]))) + + new_assignment_nodes.append(self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(target), ast.Assign(target, value), right_hand_side_variables, line_number=node.lineno, path=self.filenames[-1]))) self.connect_nodes(new_assignment_nodes) @@ -460,19 +457,19 @@ def assign_tuple_target(self, node, right_hand_side_variables): def assign_multi_target(self, node, right_hand_side_variables): new_assignment_nodes = list() - + for target in node.targets: label = LabelVisitor() label.visit(target) left_hand_side = label.result label.result += ' = ' label.visit(node.value) - - new_assignment_nodes.append(self.append_node(AssignmentNode(label.result, left_hand_side, ast.Assign(target, node.value), right_hand_side_variables, line_number = node.lineno, path=self.filenames[-1]))) + + new_assignment_nodes.append(self.append_node(AssignmentNode(label.result, left_hand_side, ast.Assign(target, node.value), right_hand_side_variables, line_number=node.lineno, path=self.filenames[-1]))) self.connect_nodes(new_assignment_nodes) return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], []) # return the last added node - + def visit_Assign(self, node): rhs_visitor = RHSVisitor() rhs_visitor.visit(node.value) @@ -492,20 +489,20 @@ def visit_Assign(self, node): print('Assignment not properly handled.', 'Could result in not finding a vulnerability.', 'Assignment:', label.result) - return self.append_node(AssignmentNode(label.result, label.result, node, rhs_visitor.result, line_number = node.lineno, path=self.filenames[-1])) - + return self.append_node(AssignmentNode(label.result, label.result, node, rhs_visitor.result, line_number=node.lineno, path=self.filenames[-1])) + elif len(node.targets) > 1: # x = y = 3 return self.assign_multi_target(node, rhs_visitor.result) - else: + else: if isinstance(node.value, ast.Call): # x = call() - + label = LabelVisitor() label.visit(node.targets[0]) return self.assignment_call_node(label.result, node) else: # x = 4 label = LabelVisitor() label.visit(node) - return self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(node.targets[0]), node, rhs_visitor.result, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(node.targets[0]), node, rhs_visitor.result, line_number=node.lineno, path=self.filenames[-1])) def assignment_call_node(self, left_hand_label, ast_node): """Handle assignments that contain a function call on its right side.""" @@ -515,7 +512,7 @@ def assignment_call_node(self, left_hand_label, ast_node): rhs_visitor.visit(ast_node.value) call = self.visit(ast_node.value) - + call_label = '' call_assignment = None if isinstance(call, AssignmentNode): # assignment after returned nonbuiltin @@ -529,30 +526,30 @@ def assignment_call_node(self, left_hand_label, ast_node): self.nodes.append(call_assignment) self.undecided = False - + return call_assignment - + def visit_AugAssign(self, node): label = LabelVisitor() label.visit(node) rhs_visitor = RHSVisitor() rhs_visitor.visit(node.value) - - return self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(node.target), node, rhs_visitor.result, line_number = node.lineno, path=self.filenames[-1])) + + return self.append_node(AssignmentNode(label.result, self.extract_left_hand_side(node.target), node, rhs_visitor.result, line_number=node.lineno, path=self.filenames[-1])) def loop_node_skeleton(self, test, node): """Common handling of looped structures, while and for.""" body_connect_stmts = self.stmt_star_handler(node.body) - test.connect(body_connect_stmts.first_statement) + test.connect(body_connect_stmts.first_statement) test.connect_predecessors(body_connect_stmts.last_statements) # last_nodes is used for making connections to the next node in the parent node # this is handled in stmt_star_handler last_nodes = list() last_nodes.extend(body_connect_stmts.break_statements) - + if node.orelse: orelse_connect_stmts = self.stmt_star_handler(node.orelse) @@ -565,16 +562,16 @@ def loop_node_skeleton(self, test, node): def add_while_label(self, node): """Prepend 'while' and append ':' to the label of a node.""" - node.label = 'while ' + node.label + ':' - + node.label = 'while ' + node.label + ':' + def visit_While(self, node): label_visitor = LabelVisitor() label_visitor.visit(node.test) - test = self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + test = self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) self.add_while_label(test) - + return self.loop_node_skeleton(test, node) def visit_For(self, node): @@ -588,15 +585,15 @@ def visit_For(self, node): target_label = LabelVisitor() target = target_label.visit(node.target) - for_node = self.append_node(Node("for " + target_label.result + " in " + iterator_label.result + ':', node, line_number = node.lineno, path=self.filenames[-1])) + for_node = self.append_node(Node("for " + target_label.result + " in " + iterator_label.result + ':', node, line_number=node.lineno, path=self.filenames[-1])) + + - - if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names: last_node = self.visit(node.iter) last_node.connect(for_node) - - + + return self.loop_node_skeleton(for_node, node) def visit_Expr(self, node): @@ -605,7 +602,7 @@ def visit_Expr(self, node): def add_builtin(self, node): label = LabelVisitor() label.visit(node) - builtin_call = Node(label.result, node, line_number = node.lineno, path=self.filenames[-1]) + builtin_call = Node(label.result, node, line_number=node.lineno, path=self.filenames[-1]) if not self.undecided: self.nodes.append(builtin_call) @@ -616,7 +613,7 @@ def visit_Name(self, node): label = LabelVisitor() label.visit(node) - return self.append_node(Node(label.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_With(self, node): label_visitor = LabelVisitor() @@ -631,46 +628,46 @@ def visit_Str(self, node): return IgnoredNode() def visit_Break(self, node): - return self.append_node(BreakNode(node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(BreakNode(node, line_number=node.lineno, path=self.filenames[-1])) def visit_Pass(self, node): - return self.append_node(Node('pass', node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node('pass', node, line_number=node.lineno, path=self.filenames[-1])) def visit_Continue(self, node): - return self.append_node(Node('continue', node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node('continue', node, line_number=node.lineno, path=self.filenames[-1])) def visit_Delete(self, node): labelVisitor = LabelVisitor() for expr in node.targets: labelVisitor.visit(expr) - return self.append_node(Node('del ' + labelVisitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node('del ' + labelVisitor.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_Assert(self, node): label_visitor = LabelVisitor() label_visitor.visit(node.test) - return self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_Attribute(self, node): label_visitor = LabelVisitor() label_visitor.visit(node) - return self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_Global(self, node): label_visitor = LabelVisitor() label_visitor.visit(node) - return self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_Subscript(self, node): label_visitor = LabelVisitor() label_visitor.visit(node) - return self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) def visit_Tuple(self, node): label_visitor = LabelVisitor() label_visitor.visit(node) - return self.append_node(Node(label_visitor.result, node, line_number = node.lineno, path=self.filenames[-1])) + return self.append_node(Node(label_visitor.result, node, line_number=node.lineno, path=self.filenames[-1])) diff --git a/pyt/cfg.py b/pyt/cfg.py index d7c6a891..88b21b7b 100644 --- a/pyt/cfg.py +++ b/pyt/cfg.py @@ -8,14 +8,16 @@ import ast from collections import namedtuple -import logging -from label_visitor import LabelVisitor -from right_hand_side_visitor import RHSVisitor -from module_definitions import ModuleDefinition, ModuleDefinitions,\ - LocalModuleDefinition -from project_handler import get_directory_modules -from ast_helper import generate_ast, get_call_names_as_string, Arguments +from .ast_helper import Arguments, generate_ast, get_call_names_as_string +from .label_visitor import LabelVisitor +from .module_definitions import ( + LocalModuleDefinition, + ModuleDefinition, + ModuleDefinitions +) +from .project_handler import get_directory_modules +from .right_hand_side_visitor import RHSVisitor CALL_IDENTIFIER = '¤' @@ -25,7 +27,7 @@ class Visitor(ast.NodeVisitor): """A Control Flow Graph containing a list of nodes.""" - + def __init__(self, node, project_modules, local_modules, filename, module_definitions=None, intraprocedural=False): """Create an empty CFG.""" self.nodes = list() @@ -47,15 +49,15 @@ def __init__(self, node, project_modules, local_modules, filename, module_defini self.init_cfg(node) def init_intra_function_cfg(self, node): - self.module_definitions_stack.append(ModuleDefinitions()) + self.module_definitions_stack.append(ModuleDefinitions()) self.function_names.append(node.name) self.function_return_stack.append(node.name) - - entry_node = self.append_node(EntryExitNode("Entry module")) + + entry_node = self.append_node(EntryOrExitNode("Entry module")) module_statements = self.stmt_star_handler(node.body) if isinstance(module_statements, IgnoredNode): - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) entry_node.connect(exit_node) return @@ -63,8 +65,8 @@ def init_intra_function_cfg(self, node): if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryExitNode("Exit module")) - + exit_node = self.append_node(EntryOrExitNode("Exit module")) + last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) diff --git a/pyt/definition_chains.py b/pyt/definition_chains.py index 8d98fb10..6c56248e 100644 --- a/pyt/definition_chains.py +++ b/pyt/definition_chains.py @@ -1,10 +1,10 @@ import ast -from constraint_table import constraint_table -from lattice import Lattice -from reaching_definitions import ReachingDefinitionsAnalysis -from vars_visitor import VarsVisitor -from base_cfg import AssignmentNode +from .base_cfg import AssignmentNode +from .constraint_table import constraint_table +from .lattice import Lattice +from .reaching_definitions import ReachingDefinitionsAnalysis +from .vars_visitor import VarsVisitor def get_vars(node): diff --git a/pyt/draw.py b/pyt/draw.py index 28ab3b65..9231d9a7 100644 --- a/pyt/draw.py +++ b/pyt/draw.py @@ -1,9 +1,11 @@ """Draws CFG.""" +import argparse from graphviz import Digraph -from base_cfg import AssignmentNode from itertools import permutations from subprocess import run -import argparse + +from .base_cfg import AssignmentNode + IGNORED_LABEL_NAME_CHARACHTERS = ':' @@ -78,17 +80,17 @@ def apply_styles(graph, styles): def draw_cfg(cfg, output_filename = 'output'): """Draw CFG and output as pdf.""" graph = Digraph(format='pdf') - + for node in cfg.nodes: stripped_label = node.label.replace(IGNORED_LABEL_NAME_CHARACHTERS, '') - + if 'Exit' in stripped_label: graph.node(stripped_label, 'Exit', shape='none') elif 'Entry' in stripped_label: graph.node(stripped_label, 'Entry', shape='none') else: graph.node(stripped_label, stripped_label) - + for ingoing_node in node.ingoing: graph.edge(ingoing_node.label.replace(IGNORED_LABEL_NAME_CHARACHTERS, ''), stripped_label) @@ -196,10 +198,10 @@ def draw_lattice_from_labels(labels, output_filename): add_anchor(output_filename) run_dot(output_filename) - + def draw_lattices(cfg_list, output_prefix='output'): for i, cfg in enumerate(cfg_list): - draw_lattice(cfg, output_prefix + '_' + str(i)) + draw_lattice(cfg, output_prefix + '_' + str(i)) def draw_cfgs(cfg_list, output_prefix='output'): for i, cfg in enumerate(cfg_list): diff --git a/pyt/fixed_point.py b/pyt/fixed_point.py index c772be68..1a27b162 100644 --- a/pyt/fixed_point.py +++ b/pyt/fixed_point.py @@ -1,5 +1,5 @@ """This module implements the fixed point algorithm.""" -from constraint_table import constraint_table +from .constraint_table import constraint_table class FixedPointAnalysis(): diff --git a/pyt/flask_adaptor.py b/pyt/flask_adaptor.py index 2228f34e..fe05587d 100644 --- a/pyt/flask_adaptor.py +++ b/pyt/flask_adaptor.py @@ -1,11 +1,11 @@ """Adaptor for Flask web applications.""" import ast -from framework_adaptor import FrameworkAdaptor -from ast_helper import get_call_names, Arguments -from interprocedural_cfg import interprocedural -from module_definitions import project_definitions -from framework_adaptor import TaintedNode +from .ast_helper import Arguments, get_call_names +from .framework_adaptor import FrameworkAdaptor, TaintedNode +from .interprocedural_cfg import interprocedural +from .module_definitions import project_definitions + class FlaskAdaptor(FrameworkAdaptor): """The flask adaptor class manipulates the CFG to adapt to flask applications.""" @@ -36,9 +36,9 @@ def get_cfg(self, definition): if args: definition_lineno = definition.node.lineno - cfg.nodes[0].outgoing = [] + cfg.nodes[0].outgoing = [] cfg.nodes[1].ingoing = [] - + for i, argument in enumerate(args, 1): taint = TaintedNode(argument, argument, None, [], line_number=definition_lineno, path=definition.path) previous_node = cfg.nodes[0] @@ -47,7 +47,7 @@ def get_cfg(self, definition): last_inserted = cfg.nodes[i] after_last = cfg.nodes[i+1] - last_inserted.connect(after_last) + last_inserted.connect(after_last) return cfg diff --git a/pyt/framework_adaptor.py b/pyt/framework_adaptor.py index a9999ebb..80535016 100644 --- a/pyt/framework_adaptor.py +++ b/pyt/framework_adaptor.py @@ -1,7 +1,8 @@ """A framework adaptor is a adaptor used to adapt the source code to a specific framework.""" from abc import ABCMeta, abstractmethod -from base_cfg import AssignmentNode + +from .base_cfg import AssignmentNode class FrameworkAdaptor(metaclass=ABCMeta): diff --git a/pyt/github_search.py b/pyt/github_search.py index 6e5379a6..0e200cd5 100644 --- a/pyt/github_search.py +++ b/pyt/github_search.py @@ -1,20 +1,22 @@ -from abc import abstractmethod, ABCMeta import re +import requests import time -from datetime import date, timedelta, datetime +from abc import ABCMeta, abstractmethod +from datetime import date, datetime, timedelta + +from . import repo_runner +from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis +from .repo_runner import add_repo_to_csv, NoEntryPathError +from .save import save_repo_scan +from .vulnerabilities import SinkArgsError -import requests -import repo_runner -from save import save_repo_scan -from vulnerabilities import SinkArgsError -from repo_runner import NoEntryPathError, add_repo_to_csv +DEFAULT_TIMEOUT_IN_SECONDS = 60 GITHUB_API_URL = 'https://api.github.com' GITHUB_OAUTH_TOKEN = None -SEARCH_REPO_URL = GITHUB_API_URL + '/search/repositories' -SEARCH_CODE_URL = GITHUB_API_URL + '/search/code' NUMBER_OF_REQUESTS_ALLOWED_PER_MINUTE = 30 # Rate limit is 10 and 30 with auth -DEFAULT_TIMEOUT_IN_SECONDS = 60 +SEARCH_CODE_URL = GITHUB_API_URL + '/search/code' +SEARCH_REPO_URL = GITHUB_API_URL + '/search/repositories' def set_github_api_token(): @@ -241,7 +243,6 @@ def scan_github(search_string, start_date, analysis_type, analyse_repo_func, csv for x in get_dates(date(2010, 1, 1), interval=93): print(x) exit() - from reaching_definitions_taint import ReachingDefinitionsTaintAnalysis scan_github('flask', ReachingDefinitionsTaintAnalysis) exit() q = Query(SEARCH_REPO_URL, 'flask') diff --git a/pyt/interprocedural_cfg.py b/pyt/interprocedural_cfg.py index 9cc23d99..19b20911 100644 --- a/pyt/interprocedural_cfg.py +++ b/pyt/interprocedural_cfg.py @@ -1,15 +1,28 @@ import ast from collections import namedtuple -from label_visitor import LabelVisitor -from right_hand_side_visitor import RHSVisitor -from module_definitions import ModuleDefinition, ModuleDefinitions,\ - LocalModuleDefinition -from project_handler import get_directory_modules -from ast_helper import generate_ast, get_call_names_as_string, Arguments -from base_cfg import Visitor, EntryExitNode, Node, IgnoredNode,\ - ConnectToExitNode, ReturnNode, AssignmentNode, RestoreNode,\ - CFG, CALL_IDENTIFIER +from .ast_helper import Arguments, generate_ast, get_call_names_as_string +from .base_cfg import ( + AssignmentNode, + CALL_IDENTIFIER, + CFG, + ConnectToExitNode, + EntryOrExitNode, + IgnoredNode, + Node, + RestoreNode, + ReturnNode, + Visitor +) +from .label_visitor import LabelVisitor +from .module_definitions import ( + LocalModuleDefinition, + ModuleDefinition, + ModuleDefinitions +) +from .project_handler import get_directory_modules +from .right_hand_side_visitor import RHSVisitor + SavedVariable = namedtuple('SavedVariable', 'LHS RHS') @@ -18,15 +31,15 @@ class InterproceduralVisitor(Visitor): def __init__(self, node, project_modules, local_modules, filename, module_definitions=None): """Create an empty CFG.""" + self.project_modules = project_modules + self.local_modules = local_modules + self.filenames = [filename] self.nodes = list() self.function_index = 0 self.undecided = False - self.project_modules = project_modules - self.local_modules = local_modules self.function_names = list() self.function_return_stack = list() self.module_definitions_stack = list() - self.filenames = [filename] if module_definitions: self.init_function_cfg(node, module_definitions) @@ -36,7 +49,7 @@ def __init__(self, node, project_modules, local_modules, def init_cfg(self, node): self.module_definitions_stack.append(ModuleDefinitions()) - entry_node = self.append_node(EntryExitNode("Entry module")) + entry_node = self.append_node(EntryOrExitNode("Entry module")) module_statements = self.visit(node) @@ -50,12 +63,12 @@ def init_cfg(self, node): if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) else: - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) entry_node.connect(exit_node) def init_function_cfg(self, node, module_definitions): @@ -64,7 +77,7 @@ def init_function_cfg(self, node, module_definitions): self.function_names.append(node.name) self.function_return_stack.append(node.name) - entry_node = self.append_node(EntryExitNode("Entry module")) + entry_node = self.append_node(EntryOrExitNode("Entry module")) module_statements = self.stmt_star_handler(node.body) @@ -73,7 +86,7 @@ def init_function_cfg(self, node, module_definitions): if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) @@ -313,14 +326,14 @@ def add_function(self, call_node, definition): def get_function_nodes(self, definition): length = len(self.nodes) previous_node = self.nodes[-1] - entry_node = self.append_node(EntryExitNode("Entry " + + entry_node = self.append_node(EntryOrExitNode("Entry " + definition.name)) previous_node.connect(entry_node) function_body_connect_statements = self.stmt_star_handler(definition.node.body) entry_node.connect(function_body_connect_statements.first_statement) - exit_node = self.append_node(EntryExitNode("Exit " + definition.name)) + exit_node = self.append_node(EntryOrExitNode("Exit " + definition.name)) exit_node.connect_predecessors(function_body_connect_statements.last_statements) self.return_connection_handler(self.nodes[length:], exit_node) @@ -351,7 +364,7 @@ def add_class(self, call_node, def_node): previous_node = self.nodes[-1] - entry_node = self.append_node(EntryExitNode("Entry " + def_node.name)) + entry_node = self.append_node(EntryOrExitNode("Entry " + def_node.name)) previous_node.connect(entry_node) @@ -359,7 +372,7 @@ def add_class(self, call_node, def_node): entry_node.connect(function_body_connect_statements.first_statement) - exit_node = self.append_node(EntryExitNode("Exit " + def_node.name)) + exit_node = self.append_node(EntryOrExitNode("Exit " + def_node.name)) exit_node.connect_predecessors(function_body_connect_statements.last_statements) return Node(label_visitor.result, call_node, @@ -378,9 +391,9 @@ def add_module(self, module, module_name, local_names): module_definitions = ModuleDefinitions(local_names, module_name) self.module_definitions_stack.append(module_definitions) - self.append_node(EntryExitNode('Entry ' + module[0])) + self.append_node(EntryOrExitNode('Entry ' + module[0])) self.visit(tree) - exit_node = self.append_node(EntryExitNode('Exit ' + module[0])) + exit_node = self.append_node(EntryOrExitNode('Exit ' + module[0])) self.module_definitions_stack.pop() self.filenames.pop() diff --git a/pyt/intraprocedural_cfg.py b/pyt/intraprocedural_cfg.py index cfefae13..b31cdf05 100644 --- a/pyt/intraprocedural_cfg.py +++ b/pyt/intraprocedural_cfg.py @@ -1,10 +1,17 @@ import ast -from base_cfg import Visitor, Node, CFG, EntryExitNode, IgnoredNode,\ - CALL_IDENTIFIER, ReturnNode -from right_hand_side_visitor import RHSVisitor -from label_visitor import LabelVisitor -from ast_helper import generate_ast, Arguments +from .ast_helper import Arguments, generate_ast +from .base_cfg import ( + CALL_IDENTIFIER, + CFG, + EntryOrExitNode, + IgnoredNode, + Node, + ReturnNode, + Visitor +) +from .label_visitor import LabelVisitor +from .right_hand_side_visitor import RHSVisitor class IntraproceduralVisitor(Visitor): @@ -25,7 +32,7 @@ def __init__(self, node, filename): self.init_module_cfg(node) def init_module_cfg(self, node): - entry_node = self.append_node(EntryExitNode("Entry module")) + entry_node = self.append_node(EntryOrExitNode("Entry module")) module_statements = self.visit(node) @@ -39,21 +46,21 @@ def init_module_cfg(self, node): if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) else: - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) entry_node.connect(exit_node) def init_function_cfg(self, node): - entry_node = self.append_node(EntryExitNode("Entry module")) + entry_node = self.append_node(EntryOrExitNode("Entry module")) module_statements = self.stmt_star_handler(node.body) if isinstance(module_statements, IgnoredNode): - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) entry_node.connect(exit_node) return @@ -61,7 +68,7 @@ def init_function_cfg(self, node): if CALL_IDENTIFIER not in first_node.label: entry_node.connect(first_node) - exit_node = self.append_node(EntryExitNode("Exit module")) + exit_node = self.append_node(EntryOrExitNode("Exit module")) last_nodes = module_statements.last_statements exit_node.connect_predecessors(last_nodes) diff --git a/pyt/label_visitor.py b/pyt/label_visitor.py index e8ab703e..f5c25550 100644 --- a/pyt/label_visitor.py +++ b/pyt/label_visitor.py @@ -1,18 +1,18 @@ -from ast import NodeVisitor import ast -class LabelVisitor(NodeVisitor): + +class LabelVisitor(ast.NodeVisitor): def __init__(self): self.result = '' - + def handle_comma_separated(self, comma_separated_list): if comma_separated_list: for element in range(len(comma_separated_list)-1): self.visit(comma_separated_list[element]) self.result += ', ' - + self.visit(comma_separated_list[-1]) - + def visit_Tuple(self, node): self.result += '(' @@ -24,7 +24,7 @@ def visit_List(self, node): self.result += '[' self.handle_comma_separated(node.elts) - + self.result += ']' def visit_Raise(self, node): @@ -42,17 +42,17 @@ def visit_withitem(self, node): if node.optional_vars: self.result += ' as ' self.visit(node.optional_vars) - + def visit_Return(self, node): if node.value: self.visit(node.value) - + def visit_Assign(self, node): for target in node.targets: self.visit(target) self.result = ' '.join((self.result,'=')) self.insert_space() - + self.visit(node.value) def visit_AugAssign(self, node): @@ -67,7 +67,7 @@ def visit_AugAssign(self, node): def visit_Compare(self,node): self.visit(node.left) self.insert_space() - + for op,com in zip(node.ops,node.comparators): self.visit(op) self.insert_space() @@ -82,7 +82,7 @@ def visit_BinOp(self, node): self.insert_space() self.visit(node.op) self.insert_space() - + self.visit(node.right) def visit_UnaryOp(self, node): @@ -96,7 +96,7 @@ def visit_BoolOp(self, node): else: self.visit(value) self.visit(node.op) - + def comprehensions(self, node): self.visit(node.elt) @@ -105,12 +105,12 @@ def comprehensions(self, node): self.visit(expression.target) self.result += ' in ' self.visit(expression.iter) - + def visit_GeneratorExp(self, node): self.result += '(' self.comprehensions(node) self.result += ')' - + def visit_ListComp(self, node): self.result += '[' self.comprehensions(node) @@ -124,23 +124,23 @@ def visit_SetComp(self, node): def visit_DictComp(self, node): self.result += '{' - + self.visit(node.key) self.result += ' : ' self.visit(node.value) - + for expression in node.generators: self.result += ' for ' self.visit(expression.target) self.result += ' in ' self.visit(expression.iter) - + self.result += '}' def visit_Attribute(self, node): self.visit(node.value) self.result += '.' self.result += node.attr - + def visit_Call(self, node): self.visit(node.func) self.result += '(' @@ -171,7 +171,7 @@ def visit_Subscript(self, node): self.result += '[' self.slicev(node.slice) - + self.result += ']' def slicev(self, node): @@ -188,7 +188,7 @@ def slicev(self, node): self.visit(d) else: self.visit(node.value) - + # operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv def visit_Add(self, node): self.result += '+' diff --git a/pyt/lattice.py b/pyt/lattice.py index 1fcc60bb..c8ee8a67 100644 --- a/pyt/lattice.py +++ b/pyt/lattice.py @@ -1,8 +1,7 @@ -from constraint_table import constraint_table +from .constraint_table import constraint_table class Lattice: - def __init__(self, cfg_nodes, analysis_type): self.el2bv = dict() # Element to bitvector dictionary self.bv2el = list() # Bitvector to element list diff --git a/pyt/liveness.py b/pyt/liveness.py index efe69309..6b4c7d60 100644 --- a/pyt/liveness.py +++ b/pyt/liveness.py @@ -1,18 +1,18 @@ import ast -from base_cfg import AssignmentNode, EntryExitNode -from analysis_base import AnalysisBase -from lattice import Lattice -from constraint_table import constraint_table, constraint_join -from ast_helper import get_call_names_as_string -from vars_visitor import VarsVisitor +from .base_cfg import AssignmentNode, EntryOrExitNode +from .analysis_base import AnalysisBase +from .lattice import Lattice +from .constraint_table import constraint_table, constraint_join +from .ast_helper import get_call_names_as_string +from .vars_visitor import VarsVisitor class LivenessAnalysis(AnalysisBase): """Reaching definitions analysis rules implemented.""" def __init__(self, cfg): - super(LivenessAnalysis, self).__init__(cfg, None) + super().__init__(cfg, None) def join(self, cfg_node): """Joins all constraints of the ingoing nodes and returns them. @@ -84,7 +84,7 @@ def add_vars_conditional(self, JOIN, cfg_node): def fixpointmethod(self, cfg_node): - if isinstance(cfg_node, EntryExitNode) and 'Exit' in cfg_node.label: + if isinstance(cfg_node, EntryOrExitNode) and 'Exit' in cfg_node.label: constraint_table[cfg_node] = 0 elif isinstance(cfg_node, AssignmentNode): JOIN = self.join(cfg_node) diff --git a/pyt/project_handler.py b/pyt/project_handler.py index 03820155..7993ed38 100644 --- a/pyt/project_handler.py +++ b/pyt/project_handler.py @@ -1,11 +1,12 @@ """Generates a list of CFGs from a path. The module finds all python modules and generates an ast for them. -Then +Then """ import ast import os + def is_python_module(path): if os.path.splitext(path)[1] == '.py': return True diff --git a/pyt/reaching_definitions.py b/pyt/reaching_definitions.py index dabdcc0f..b8cdaa9d 100644 --- a/pyt/reaching_definitions.py +++ b/pyt/reaching_definitions.py @@ -1,14 +1,14 @@ -from base_cfg import AssignmentNode -from analysis_base import AnalysisBase -from lattice import Lattice -from constraint_table import constraint_table, constraint_join +from .analysis_base import AnalysisBase +from .base_cfg import AssignmentNode +from .constraint_table import constraint_join, constraint_table +from .lattice import Lattice class ReachingDefinitionsAnalysis(AnalysisBase): """Reaching definitions analysis rules implemented.""" def __init__(self, cfg): - super(ReachingDefinitionsAnalysis, self).__init__(cfg, None) + super().__init__(cfg, None) def join(self, cfg_node): """Joins all constraints of the ingoing nodes and returns them. diff --git a/pyt/reaching_definitions_taint.py b/pyt/reaching_definitions_taint.py index caa877ee..e8138788 100644 --- a/pyt/reaching_definitions_taint.py +++ b/pyt/reaching_definitions_taint.py @@ -1,14 +1,14 @@ -from base_cfg import AssignmentNode -from analysis_base import AnalysisBase -from constraint_table import constraint_table, constraint_join -from lattice import Lattice +from .analysis_base import AnalysisBase +from .base_cfg import AssignmentNode +from .constraint_table import constraint_join, constraint_table +from .lattice import Lattice class ReachingDefinitionsTaintAnalysis(AnalysisBase): """Reaching definitions analysis rules implemented.""" def __init__(self, cfg): - super(ReachingDefinitionsTaintAnalysis, self).__init__(cfg, None) + super().__init__(cfg, None) def join(self, cfg_node): """Joins all constraints of the ingoing nodes and returns them. diff --git a/pyt/repo_runner.py b/pyt/repo_runner.py index 025578a8..78ad414b 100644 --- a/pyt/repo_runner.py +++ b/pyt/repo_runner.py @@ -1,6 +1,6 @@ """This modules runs PyT on a CSV file of git repos.""" -import os import git +import os import shutil diff --git a/pyt/right_hand_side_visitor.py b/pyt/right_hand_side_visitor.py index 65f4ca8f..72e1ce97 100644 --- a/pyt/right_hand_side_visitor.py +++ b/pyt/right_hand_side_visitor.py @@ -2,9 +2,10 @@ Used to find all variables on a right hand side(RHS) of assignment. """ -from ast import NodeVisitor +import ast -class RHSVisitor(NodeVisitor): + +class RHSVisitor(ast.NodeVisitor): """Visitor collecting all names.""" def __init__(self): diff --git a/pyt/save.py b/pyt/save.py index 408f6f98..5ee6ebf2 100644 --- a/pyt/save.py +++ b/pyt/save.py @@ -1,7 +1,9 @@ import os from datetime import datetime -from base_cfg import Node +from .base_cfg import Node +from .definition_chains import build_def_use_chain, build_use_def_chain +from .lattice import Lattice database_file_name = 'db.sql' @@ -78,7 +80,6 @@ def __exit__(self, type, value, traceback): def def_use_chain_to_file(cfg_list): - from definition_chains import build_def_use_chain with Output('def-use_chain.pyt') as fd: for i, cfg in enumerate(cfg_list): fd.write('##### Def-use chain for CFG {} #####{}' @@ -92,7 +93,6 @@ def def_use_chain_to_file(cfg_list): def use_def_chain_to_file(cfg_list): - from definition_chains import build_use_def_chain with Output('use-def_chain.pyt') as fd: for i, cfg in enumerate(cfg_list): fd.write('##### Use-def chain for CFG {} #####{}' @@ -125,7 +125,6 @@ def lattice_to_file(cfg_list, analysis_type): with Output('lattice.pyt') as fd: for i, cfg in enumerate(cfg_list): fd.write('##### Lattice for CFG {} #####{}'.format(i, os.linesep)) - from lattice import Lattice l = Lattice(cfg.nodes, analysis_type) fd.write('# Elements to bitvector #{}'.format(os.linesep)) diff --git a/pyt/trigger_definitions_parser.py b/pyt/trigger_definitions_parser.py index 55521078..65e13f9f 100644 --- a/pyt/trigger_definitions_parser.py +++ b/pyt/trigger_definitions_parser.py @@ -1,7 +1,7 @@ import os - from collections import namedtuple + SANITISER_SEPARATOR = '->' SOURCES_KEYWORD = 'sources:' SINKS_KEYWORD = 'sinks:' diff --git a/pyt/vars_visitor.py b/pyt/vars_visitor.py index dfb462ef..e47f7f98 100644 --- a/pyt/vars_visitor.py +++ b/pyt/vars_visitor.py @@ -1,5 +1,6 @@ import ast -from ast_helper import get_call_names + +from .ast_helper import get_call_names class VarsVisitor(ast.NodeVisitor): @@ -48,7 +49,7 @@ def visit_ListComp(self, node): self.visit(node.elt) for gen in node.generators: self.comprehension(gen) - + def visit_SetComp(self, node): self.visit(node.elt) for gen in node.generators: diff --git a/pyt/vulnerabilities.py b/pyt/vulnerabilities.py index c61fd27a..5313d43b 100644 --- a/pyt/vulnerabilities.py +++ b/pyt/vulnerabilities.py @@ -1,15 +1,18 @@ """Module for finding vulnerabilities based on a definitions file.""" -from collections import namedtuple import ast +from collections import namedtuple -from base_cfg import Node, AssignmentNode, ReturnNode -from framework_adaptor import TaintedNode -from vulnerability_log import Vulnerability, VulnerabilityLog,\ +from .base_cfg import Node, AssignmentNode, ReturnNode +from .framework_adaptor import TaintedNode +from .vulnerability_log import ( + Vulnerability, + VulnerabilityLog, SanitisedVulnerability -from lattice import Lattice -from vars_visitor import VarsVisitor -from trigger_definitions_parser import parse, default_trigger_word_file +) +from .lattice import Lattice +from .vars_visitor import VarsVisitor +from .trigger_definitions_parser import default_trigger_word_file, parse Triggers = namedtuple('Triggers', 'sources sinks sanitiser_dict') @@ -231,7 +234,7 @@ def get_vulnerability(source, sink, triggers, lattice): if lattice.in_constraint(secondary, sink.cfg_node)] trigger_node_in_sink = source_in_sink or secondary_in_sink - + sink_args = get_sink_args(sink.cfg_node) source_lhs_in_sink_args = source.cfg_node.left_hand_side in sink_args\ if sink_args else None diff --git a/pyt/vulnerability_log.py b/pyt/vulnerability_log.py index ac43c2b9..e1d843f9 100644 --- a/pyt/vulnerability_log.py +++ b/pyt/vulnerability_log.py @@ -4,8 +4,7 @@ The log is printed to the standard output. """ - -class VulnerabilityLog(object): +class VulnerabilityLog(): """Log that consists of vulnerabilities.""" def __init__(self): @@ -23,10 +22,10 @@ def print_report(self): print('%s vulnerability found:' % number_of_vulnerabilities) else: print('%s vulnerabilities found:' % number_of_vulnerabilities) - + for i, vulnerability in enumerate(self.vulnerabilities, start=1): print('Vulnerability {}:\n{}\n'.format(i, vulnerability)) - + class Reassigned(): def __init__(self, secondary_nodes): self.secondary_nodes = secondary_nodes @@ -39,7 +38,7 @@ def __str__(self): return secondary -class Vulnerability(object): +class Vulnerability(): """Vulnerability containing the source and the sources trigger word, the sink and the sinks trigger word.""" def __init__(self, source, source_trigger_word, sink, sink_trigger_word, secondary_nodes): @@ -49,7 +48,7 @@ def __init__(self, source, source_trigger_word, sink, sink_trigger_word, seconda self.sink = sink self.sink_trigger_word = sink_trigger_word self.secondary_nodes = secondary_nodes - + self.__remove_sink_from_secondary_nodes() def __remove_sink_from_secondary_nodes(self): @@ -58,7 +57,7 @@ def __remove_sink_from_secondary_nodes(self): self.secondary_nodes.remove(self.sink) except ValueError: pass - + def __str__(self): """Pretty printing of a vulnerability.""" reassigned_str = Reassigned(self.secondary_nodes) @@ -69,11 +68,11 @@ class SanitisedVulnerability(Vulnerability): def __init__(self, source, source_trigger_word, sink, sink_trigger_word, sanitiser, secondary_nodes): """Set source, sink and sanitiser information.""" - super(SanitisedVulnerability, self).__init__(source, source_trigger_word, sink, sink_trigger_word, secondary_nodes) + super().__init__(source, source_trigger_word, sink, sink_trigger_word, secondary_nodes) self.sanitiser = sanitiser def __str__(self): """Pretty printing of a vulnerability.""" - super_str = super(SanitisedVulnerability, self).__str__() + super_str = super().__str__() return super_str + '\nThis vulnerability is potentially sanitised by: {}'.format(self.sanitiser) - + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test.py b/tests/__main__.py similarity index 70% rename from tests/test.py rename to tests/__main__.py index a51445db..df8298a7 100644 --- a/tests/test.py +++ b/tests/__main__.py @@ -1,7 +1,6 @@ -from unittest import TextTestRunner -from unittest import TestSuite -from unittest import TestLoader -from run import check_files +from unittest import TestLoader, TestSuite, TextTestRunner + +from .run import check_files test_suite = TestSuite() diff --git a/tests/analysis_base_test_case.py b/tests/analysis_base_test_case.py index 8182cfc1..1db614a6 100644 --- a/tests/analysis_base_test_case.py +++ b/tests/analysis_base_test_case.py @@ -1,13 +1,11 @@ import unittest -import sys -import os from collections import namedtuple -from base_test_case import BaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from constraint_table import initialize_constraint_table -from fixed_point import FixedPointAnalysis -from lattice import Lattice +from .base_test_case import BaseTestCase +from pyt.constraint_table import initialize_constraint_table +from pyt.fixed_point import FixedPointAnalysis +from pyt.lattice import Lattice + class AnalysisBaseTestCase(BaseTestCase): connection = namedtuple('connection', 'constraintset element') @@ -22,7 +20,7 @@ def assertInCfg(self, connections, lattice): for connection in connections: self.assertEqual(lattice.in_constraint(self.cfg.nodes[connection[0]], self.cfg.nodes[connection[1]]), True, str(connection) + " expected to be connected") nodes = len(self.cfg.nodes) - + for element in range(nodes): for sets in range(nodes): if (element, sets) not in connections: diff --git a/tests/base_test_case.py b/tests/base_test_case.py index 5e7975a1..8b79bb11 100644 --- a/tests/base_test_case.py +++ b/tests/base_test_case.py @@ -1,12 +1,9 @@ """A module that contains a base class that has helper methods for testing PyT.""" import unittest -import sys -import os -sys.path.insert(0, os.path.abspath('../pyt')) -from interprocedural_cfg import interprocedural -from ast_helper import generate_ast -from module_definitions import project_definitions +from pyt.ast_helper import generate_ast +from pyt.interprocedural_cfg import interprocedural +from pyt.module_definitions import project_definitions class BaseTestCase(unittest.TestCase): @@ -36,7 +33,7 @@ def assertConnected(self, node, successor): self.assertIn(successor, node.outgoing, '\n%s was NOT found in the outgoing list of %s containing: ' % (successor.label, node.label) + '[' + ', '.join([x.label for x in node.outgoing]) + ']') - + self.assertIn(node, successor.ingoing, '\n%s was NOT found in the ingoing list of %s containing: ' % (node.label, successor.label) + '[' + ', '.join([x.label for x in successor.ingoing]) + ']') @@ -47,7 +44,7 @@ def assertNotConnected(self, node, successor): self.assertNotIn(successor, node.outgoing, '\n%s was mistakenly found in the outgoing list of %s containing: ' % (successor.label, node.label) + '[' + ', '.join([x.label for x in node.outgoing]) + ']') - + self.assertNotIn(node, successor.ingoing, '\n%s was mistakenly found in the ingoing list of %s containing: ' % (node.label, successor.label) + '[' + ', '.join([x.label for x in successor.ingoing]) + ']') diff --git a/tests/cfg_test.py b/tests/cfg_test.py index 4b70dcdd..d5370ef5 100644 --- a/tests/cfg_test.py +++ b/tests/cfg_test.py @@ -1,60 +1,55 @@ -import os -import sys -from ast import parse - -from base_test_case import BaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from base_cfg import Node, EntryExitNode -from project_handler import get_python_modules +from .base_test_case import BaseTestCase +from pyt.base_cfg import EntryOrExitNode, Node +from pyt.project_handler import get_python_modules class CFGGeneralTest(BaseTestCase): def test_repr_cfg(self): - self.cfg_create_from_file('../example/example_inputs/for_complete.py') - + self.cfg_create_from_file('example/example_inputs/for_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) - + #print(repr(self.cfg)) def test_str_cfg(self): - self.cfg_create_from_file('../example/example_inputs/for_complete.py') - + self.cfg_create_from_file('example/example_inputs/for_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) - + #print(self.cfg) def test_no_tuples(self): - self.cfg_create_from_file('../example/example_inputs/for_complete.py') - + self.cfg_create_from_file('example/example_inputs/for_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) - + for node in self.cfg.nodes: for edge in node.outgoing + node.ingoing: self.assertIsInstance(edge, Node) def test_start_and_exit_nodes(self): - self.cfg_create_from_file('../example/example_inputs/simple.py') + self.cfg_create_from_file('example/example_inputs/simple.py') self.assert_length(self.cfg.nodes, expected_length=3) - + start_node = 0 node = 1 exit_node = 2 self.assertInCfg([(1,0),(2,1)]) - - self.assertEqual(type(self.cfg.nodes[start_node]), EntryExitNode) - self.assertEqual(type(self.cfg.nodes[exit_node]), EntryExitNode) + + self.assertEqual(type(self.cfg.nodes[start_node]), EntryOrExitNode) + self.assertEqual(type(self.cfg.nodes[exit_node]), EntryOrExitNode) def test_start_and_exit_nodes_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/simple.py') + self.cfg_create_from_file('example/example_inputs/simple.py') self.assertLineNumber(self.cfg.nodes[0], None) self.assertLineNumber(self.cfg.nodes[1], 1) self.assertLineNumber(self.cfg.nodes[2], None) def test_str_ignored(self): - self.cfg_create_from_file('../example/example_inputs/str_ignored.py') + self.cfg_create_from_file('example/example_inputs/str_ignored.py') self.assert_length(self.cfg.nodes, expected_length=3) @@ -62,11 +57,11 @@ def test_str_ignored(self): actual_label = self.cfg.nodes[1].label self.assertEqual(expected_label, actual_label) - + class CFGForTest(BaseTestCase): def test_for_complete(self): - self.cfg_create_from_file('../example/example_inputs/for_complete.py') - + self.cfg_create_from_file('example/example_inputs/for_complete.py') + self.assert_length(self.cfg.nodes, expected_length=8) entry = 0 @@ -88,8 +83,8 @@ def test_for_complete(self): self.assertInCfg([(for_node, entry), (body_1, for_node), (else_body_1, for_node), (body_2, body_1), (for_node, body_2), (else_body_2, else_body_1), (next_node, else_body_2), (exit_node, next_node)]) def test_for_no_orelse(self): - self.cfg_create_from_file('../example/example_inputs/for_no_orelse.py') - + self.cfg_create_from_file('example/example_inputs/for_no_orelse.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) self.assert_length(self.cfg.nodes, expected_length=6) @@ -102,9 +97,9 @@ def test_for_no_orelse(self): exit_node = 5 self.assertInCfg([(for_node, entry), (body_1, for_node), (body_2, body_1), (for_node, body_2), (next_node, for_node), (exit_node, next_node)]) - + def test_for_tuple_target(self): - self.cfg_create_from_file('../example/example_inputs/for_tuple_target.py') + self.cfg_create_from_file('example/example_inputs/for_tuple_target.py') self.assert_length(self.cfg.nodes, expected_length = 4) @@ -112,15 +107,15 @@ def test_for_tuple_target(self): for_node = 1 print_node = 2 exit_node = 3 - + self.assertInCfg([(for_node,entry_node),(print_node,for_node),(for_node,print_node),(exit_node,for_node)]) self.assertEqual(self.cfg.nodes[for_node].label, "for (x, y) in [(1, 2), (3, 4)]:") def test_for_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/for_complete.py') + self.cfg_create_from_file('example/example_inputs/for_complete.py') self.assert_length(self.cfg.nodes, expected_length=8) - + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) for_node = self.nodes['for x in range(3):'] body_1 = self.nodes['print(x)'] @@ -128,7 +123,7 @@ def test_for_line_numbers(self): else_body_1 = self.nodes["print('Final: %s' % x)"] else_body_2 = self.nodes['print(y)'] next_node = self.nodes['x = 3'] - + self.assertLineNumber(for_node, 1) self.assertLineNumber(body_1, 2) self.assertLineNumber(body_2, 3) @@ -137,7 +132,7 @@ def test_for_line_numbers(self): self.assertLineNumber(next_node, 7) def test_for_func_iterator(self): - self.cfg_create_from_file('../example/example_inputs/for_func_iterator.py') + self.cfg_create_from_file('example/example_inputs/for_func_iterator.py') self.assert_length(self.cfg.nodes, expected_length=8) @@ -149,7 +144,7 @@ def test_for_func_iterator(self): call_foo = 5 _print = 6 _exit = 7 - + self.assertInCfg([(_for, entry), (_for, call_foo), (_for, _print), (entry_foo, _for), (ret_foo, entry_foo), (exit_foo, ret_foo), (call_foo, exit_foo), (_print, _for), (_exit, _for)]) class CFGTryTest(BaseTestCase): @@ -157,7 +152,7 @@ def connected(self, node, successor): return (successor, node) def test_simple_try(self): - self.cfg_create_from_file('../example/example_inputs/try.py') + self.cfg_create_from_file('example/example_inputs/try.py') self.nodes = self.cfg_list_to_dict(self.cfg.nodes) @@ -178,7 +173,7 @@ def test_simple_try(self): self.connected(try_body, _exit)]) def test_orelse(self): - self.cfg_create_from_file('../example/example_inputs/try_orelse.py') + self.cfg_create_from_file('example/example_inputs/try_orelse.py') self.nodes = self.cfg_list_to_dict(self.cfg.nodes) @@ -189,7 +184,7 @@ def test_orelse(self): try_body = 2 except_im = 3 except_im_body_1 = 4 - print_else = 5 + print_else = 5 _exit = 6 self.assertInCfg([self.connected(entry, try_), @@ -202,7 +197,7 @@ def test_orelse(self): self.connected(print_else, _exit)]) def test_final(self): - self.cfg_create_from_file('../example/example_inputs/try_final.py') + self.cfg_create_from_file('example/example_inputs/try_final.py') self.nodes = self.cfg_list_to_dict(self.cfg.nodes) @@ -213,7 +208,7 @@ def test_final(self): try_body = 2 except_im = 3 except_im_body_1 = 4 - print_final = 5 + print_final = 5 _exit = 6 self.assertInCfg([self.connected(entry, try_), @@ -223,14 +218,14 @@ def test_final(self): self.connected(try_body, _exit), self.connected(except_im, except_im_body_1), self.connected(except_im_body_1, _exit), - self.connected(except_im_body_1, print_final), + self.connected(except_im_body_1, print_final), self.connected(print_final, _exit)]) -class CFGIfTest(BaseTestCase): +class CFGIfTest(BaseTestCase): def test_if_complete(self): - self.cfg_create_from_file('../example/example_inputs/if_complete.py') - + self.cfg_create_from_file('example/example_inputs/if_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) self.assert_length(self.cfg.nodes, expected_length=9) @@ -254,13 +249,13 @@ def test_if_complete(self): self.assertEqual(self.cfg.nodes[next_node].label, 'x += 5') - self.assertInCfg([(test, entry), (eliftest, test), (body_1, test), (body_2, body_1), (next_node, body_2), (else_body, eliftest), (elif_body, eliftest), (next_node, elif_body), (next_node, else_body), (exit_node, next_node)]) - + self.assertInCfg([(test, entry), (eliftest, test), (body_1, test), (body_2, body_1), (next_node, body_2), (else_body, eliftest), (elif_body, eliftest), (next_node, elif_body), (next_node, else_body), (exit_node, next_node)]) + def test_single_if(self): - self.cfg_create_from_file('../example/example_inputs/if.py') - + self.cfg_create_from_file('example/example_inputs/if.py') + self.assert_length(self.cfg.nodes, expected_length=4) - + start_node = 0 test_node = 1 body_node = 2 @@ -268,7 +263,7 @@ def test_single_if(self): self.assertInCfg([(test_node,start_node), (body_node,test_node), (exit_node,test_node), (exit_node,body_node)]) def test_single_if_else(self): - self.cfg_create_from_file('../example/example_inputs/if_else.py') + self.cfg_create_from_file('example/example_inputs/if_else.py') self.assert_length(self.cfg.nodes, expected_length=5) @@ -280,7 +275,7 @@ def test_single_if_else(self): self.assertInCfg([(test_node,start_node), (body_node,test_node), (else_body,test_node), (exit_node,else_body), (exit_node,body_node)]) def test_multiple_if_else(self): - self.cfg_create_from_file('../example/example_inputs/multiple_if_else.py') + self.cfg_create_from_file('example/example_inputs/multiple_if_else.py') self.assert_length(self.cfg.nodes, expected_length=9) @@ -308,7 +303,7 @@ def test_multiple_if_else(self): ]) def test_if_else_elif(self): - self.cfg_create_from_file('../example/example_inputs/if_else_elif.py') + self.cfg_create_from_file('example/example_inputs/if_else_elif.py') self.assert_length(self.cfg.nodes, expected_length=7) @@ -331,7 +326,7 @@ def test_if_else_elif(self): ]) def test_nested_if_else_elif(self): - self.cfg_create_from_file('../example/example_inputs/nested_if_else_elif.py') + self.cfg_create_from_file('example/example_inputs/nested_if_else_elif.py') self.assert_length(self.cfg.nodes, expected_length=12) @@ -365,13 +360,13 @@ def test_nested_if_else_elif(self): (_exit, elif_body) ]) - + def test_if_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/if_complete.py') - + self.cfg_create_from_file('example/example_inputs/if_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) self.assert_length(self.cfg.nodes, expected_length=9) - + test = self.nodes['if x > 0:'] body_1 = self.nodes['x += 1'] body_2 = self.nodes['x += 2'] @@ -390,7 +385,7 @@ def test_if_line_numbers(self): self.assertLineNumber(next_stmt, 8) def test_if_not(self): - self.cfg_create_from_file('../example/example_inputs/if_not.py') + self.cfg_create_from_file('example/example_inputs/if_not.py') self.assert_length(self.cfg.nodes, expected_length=4) @@ -398,15 +393,15 @@ def test_if_not(self): _if = 1 body = 2 _exit = 3 - + self.assertInCfg([(1, 0), (2, 1), (3, 2), (3, 1)]) class CFGWhileTest(BaseTestCase): def test_while_complete(self): - self.cfg_create_from_file('../example/example_inputs/while_complete.py') - + self.cfg_create_from_file('example/example_inputs/while_complete.py') + self.assert_length(self.cfg.nodes, expected_length=8) entry = 0 @@ -421,10 +416,10 @@ def test_while_complete(self): self.assertEqual(self.cfg.nodes[test].label, 'while x > 0:') self.assertInCfg([(test, entry), (body_1, test), (else_body_1, test), ( body_2, body_1), (test, body_2), (else_body_2, else_body_1), (next_node, else_body_2), (exit_node, next_node)]) - + def test_while_no_orelse(self): - self.cfg_create_from_file('../example/example_inputs/while_no_orelse.py') - + self.cfg_create_from_file('example/example_inputs/while_no_orelse.py') + self.assert_length(self.cfg.nodes, expected_length=6) entry = 0 @@ -433,12 +428,12 @@ def test_while_no_orelse(self): body_2 = 3 next_node = 4 exit_node = 5 - + self.assertInCfg([(test, entry), (body_1, test), ( next_node, test), (body_2, body_1), (test, body_2), (exit_node, next_node)]) - + def test_while_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/while_complete.py') - + self.cfg_create_from_file('example/example_inputs/while_complete.py') + self.nodes = self.cfg_list_to_dict(self.cfg.nodes) self.assert_length(self.cfg.nodes, expected_length=8) @@ -459,8 +454,8 @@ def test_while_line_numbers(self): class CFGAssignmentMultiTest(BaseTestCase): def test_assignment_multi_target(self): - self.cfg_create_from_file('../example/example_inputs/assignment_two_targets.py') - + self.cfg_create_from_file('example/example_inputs/assignment_two_targets.py') + self.assert_length(self.cfg.nodes, expected_length=4) start_node = 0 node = 1 @@ -473,8 +468,8 @@ def test_assignment_multi_target(self): self.assertEqual(self.cfg.nodes[node_2].label, 'y = 2') def test_assignment_multi_target_call(self): - self.cfg_create_from_file('../example/example_inputs/assignment_multiple_assign_call.py') - + self.cfg_create_from_file('example/example_inputs/assignment_multiple_assign_call.py') + self.assert_length(self.cfg.nodes, expected_length=4) start_node = self.cfg.nodes[0] node = self.cfg.nodes[1] @@ -482,13 +477,13 @@ def test_assignment_multi_target_call(self): exit_node = self.cfg.nodes[-1] self.assertInCfg([(1,0),(2,1),(3,2)]) - + self.assertEqual(node.label, 'x = int(5)') self.assertEqual(node_2.label, 'y = int(4)') def test_assignment_multi_target_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/assignment_two_targets.py') - + self.cfg_create_from_file('example/example_inputs/assignment_two_targets.py') + node = self.cfg.nodes[1] node_2 = self.cfg.nodes[2] @@ -496,20 +491,20 @@ def test_assignment_multi_target_line_numbers(self): self.assertLineNumber(node_2, 1) def test_assignment_and_builtin(self): - self.cfg_create_from_file('../example/example_inputs/assignmentandbuiltin.py') - + self.cfg_create_from_file('example/example_inputs/assignmentandbuiltin.py') + self.assert_length(self.cfg.nodes, expected_length=4) - + entry = 0 assign = 1 builtin = 2 exit_node = 3 self.assertInCfg([(assign, entry), (builtin, assign), (exit_node, builtin)]) - + def test_assignment_and_builtin_line_numbers(self): - self.cfg_create_from_file('../example/example_inputs/assignmentandbuiltin.py') - + self.cfg_create_from_file('example/example_inputs/assignmentandbuiltin.py') + assign = self.cfg.nodes[1] builtin = self.cfg.nodes[2] @@ -517,10 +512,10 @@ def test_assignment_and_builtin_line_numbers(self): self.assertLineNumber(builtin, 2) def test_multiple_assignment(self): - self.cfg_create_from_file('../example/example_inputs/assignment_multiple_assign.py') + self.cfg_create_from_file('example/example_inputs/assignment_multiple_assign.py') self.assert_length(self.cfg.nodes, expected_length=4) - + start_node = self.cfg.nodes[0] assign_y = self.cfg.nodes[1] assign_x = self.cfg.nodes[2] @@ -530,8 +525,8 @@ def test_multiple_assignment(self): self.assertEqual(assign_y.label, 'y = 5') def test_assign_list_comprehension(self): - self.cfg_create_from_file('../example/example_inputs/generator_expression_assign.py') - + self.cfg_create_from_file('example/example_inputs/generator_expression_assign.py') + length = 3 self.assert_length(self.cfg.nodes, expected_length = length) @@ -543,8 +538,8 @@ def test_assign_list_comprehension(self): self.assertInCfg(list(l)) def test_assignment_tuple_value(self): - self.cfg_create_from_file('../example/example_inputs/assignment_tuple_value.py') - + self.cfg_create_from_file('example/example_inputs/assignment_tuple_value.py') + self.assert_length(self.cfg.nodes, expected_length=3) start_node = 0 node = 1 @@ -558,62 +553,62 @@ def test_assignment_tuple_value(self): class CFGComprehensionTest(BaseTestCase): def test_nodes(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') + self.cfg_create_from_file('example/example_inputs/comprehensions.py') self.assert_length(self.cfg.nodes, expected_length=8) def test_list_comprehension(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') + self.cfg_create_from_file('example/example_inputs/comprehensions.py') listcomp = self.cfg.nodes[1] self.assertEqual(listcomp.label, 'l = [x for x in [1, 2, 3]]') def test_list_comprehension_multi(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') - + self.cfg_create_from_file('example/example_inputs/comprehensions.py') + listcomp = self.cfg.nodes[2] self.assertEqual(listcomp.label, 'll = [(x, y) for x in [1, 2, 3] for y in [4, 5, 6]]') - + def test_dict_comprehension(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') + self.cfg_create_from_file('example/example_inputs/comprehensions.py') dictcomp = self.cfg.nodes[3] self.assertEqual(dictcomp.label, 'd = {i : x for (i, x) in enumerate([1, 2, 3])}') def test_set_comprehension(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') - + self.cfg_create_from_file('example/example_inputs/comprehensions.py') + setcomp = self.cfg.nodes[4] self.assertEqual(setcomp.label, 's = {x for x in [1, 2, 3, 2, 2, 1, 2]}') def test_generator_expression(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') + self.cfg_create_from_file('example/example_inputs/comprehensions.py') listcomp = self.cfg.nodes[5] self.assertEqual(listcomp.label, 'g = (x for x in [1, 2, 3])') def test_dict_comprehension_multi(self): - self.cfg_create_from_file('../example/example_inputs/comprehensions.py') + self.cfg_create_from_file('example/example_inputs/comprehensions.py') listcomp = self.cfg.nodes[6] self.assertEqual(listcomp.label, 'dd = {x + y : y for x in [1, 2, 3] for y in [4, 5, 6]}') - + class CFGFunctionNodeTest(BaseTestCase): def connected(self, node, successor): return (successor, node) def test_simple_function(self): - path = '../example/example_inputs/simple_function.py' + path = 'example/example_inputs/simple_function.py' self.cfg_create_from_file(path) - - + + self.assert_length(self.cfg.nodes, expected_length=8) - + entry = 0 y_assignment = 1 save_y = 2 @@ -632,9 +627,9 @@ def test_simple_function(self): self.connected(y_load, exit_)]) def test_function_line_numbers(self): - path = '../example/example_inputs/simple_function.py' + path = 'example/example_inputs/simple_function.py' self.cfg_create_from_file(path) - + y_assignment = self.cfg.nodes[1] save_y = self.cfg.nodes[2] entry_foo = self.cfg.nodes[3] @@ -648,11 +643,11 @@ def test_function_line_numbers(self): self.assertLineNumber(body_foo, 2) def test_function_parameters(self): - path = '../example/example_inputs/parameters_function.py' + path = 'example/example_inputs/parameters_function.py' self.cfg_create_from_file(path) - + self.assert_length(self.cfg.nodes, expected_length=12) - + entry = 0 y_assignment = 1 save_y = 2 @@ -672,18 +667,18 @@ def test_function_parameters(self): self.connected(bar_y_assignment, bar_print_y), self.connected(bar_print_y, bar_print_x), self.connected(bar_print_x, exit_bar), self.connected(exit_bar, restore_actual_y), self.connected(restore_actual_y, exit_)]) - + def test_function_with_return(self): - path = '../example/example_inputs/simple_function_with_return.py' + path = 'example/example_inputs/simple_function_with_return.py' self.cfg_create_from_file(path) - + self.assert_length(self.cfg.nodes, expected_length=18) l = zip(range(1, len(self.cfg.nodes)), range(len(self.cfg.nodes))) self.assertInCfg(list(l)) def test_function_multiple_return(self): - path = '../example/example_inputs/function_with_multiple_return.py' + path = 'example/example_inputs/function_with_multiple_return.py' self.cfg_create_from_file(path) self.assert_length(self.cfg.nodes, expected_length=9) @@ -707,24 +702,24 @@ def test_function_multiple_return(self): (exit_foo, ret), (call_foo, exit_foo), (_exit, call_foo)]) - + def test_function_line_numbers_2(self): - path = '../example/example_inputs/simple_function_with_return.py' + path = 'example/example_inputs/simple_function_with_return.py' self.cfg_create_from_file(path) # self.cfg = CFG(get_python_modules(path)) # tree = generate_ast(path) # self.cfg.create(tree) - + assignment_with_function = self.cfg.nodes[1] self.assertLineNumber(assignment_with_function, 9) def test_multiple_parameters(self): - path = '../example/example_inputs/multiple_parameters_function.py' + path = 'example/example_inputs/multiple_parameters_function.py' self.cfg_create_from_file(path) - + length = len(self.cfg.nodes) self.assertEqual(length, 21) @@ -733,15 +728,15 @@ def test_multiple_parameters(self): self.assertInCfg(list(l)) def test_call_on_call(self): - path = '../example/example_inputs/call_on_call.py' + path = 'example/example_inputs/call_on_call.py' self.cfg_create_from_file(path) - - + + class CFGCallWithAttributeTest(BaseTestCase): def setUp(self): - self.cfg_create_from_file('../example/example_inputs/call_with_attribute.py') + self.cfg_create_from_file('example/example_inputs/call_with_attribute.py') def test_call_with_attribute(self): length = 14 @@ -761,7 +756,7 @@ def test_call_with_attribute_line_numbers(self): class CFGBreak(BaseTestCase): """Break in while and for and other places""" def test_break(self): - self.cfg_create_from_file('../example/example_inputs/while_break.py') + self.cfg_create_from_file('example/example_inputs/while_break.py') self.assert_length(self.cfg.nodes, expected_length=8) @@ -775,15 +770,15 @@ def test_break(self): _exit = 7 self.assertInCfg([(_while, entry), (_while, print_hest), (_if, _while), (print_x, _if), (_break, print_x), (print_hest, _if), (print_next, _while), (print_next, _break), (_exit, print_next)]) - - + + class CFGNameConstant(BaseTestCase): def setUp(self): - self.cfg_create_from_file('../example/example_inputs/name_constant.py') + self.cfg_create_from_file('example/example_inputs/name_constant.py') def test_name_constant_in_assign(self): self.assert_length(self.cfg.nodes, expected_length=6) - + expected_label = 'x = True' actual_label = self.cfg.nodes[1].label self.assertEqual(expected_label, actual_label) @@ -793,20 +788,20 @@ def test_name_constant_if(self): expected_label = 'if True:' actual_label = self.cfg.nodes[2].label self.assertEqual(expected_label, actual_label) - + class CFGName(BaseTestCase): """Test is Name nodes are properly handled in different contexts""" - + def test_name_if(self): - self.cfg_create_from_file('../example/example_inputs/name_if.py') - + self.cfg_create_from_file('example/example_inputs/name_if.py') + self.assert_length(self.cfg.nodes, expected_length=5) self.assertEqual(self.cfg.nodes[2].label, 'if x:') def test_name_for(self): - self.cfg_create_from_file('../example/example_inputs/name_for.py') - + self.cfg_create_from_file('example/example_inputs/name_for.py') + self.assert_length(self.cfg.nodes, expected_length=4) self.assertEqual(self.cfg.nodes[1].label, 'for x in l:') diff --git a/tests/flask_adaptor_test.py b/tests/flask_adaptor_test.py index 472df017..bdc35517 100644 --- a/tests/flask_adaptor_test.py +++ b/tests/flask_adaptor_test.py @@ -1,20 +1,16 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath('../pyt')) -from flask_adaptor import FlaskAdaptor -from base_test_case import BaseTestCase +from .base_test_case import BaseTestCase +from pyt.flask_adaptor import FlaskAdaptor class FlaskEngineTest(BaseTestCase): def test_find_flask_functions(self): - self.cfg_create_from_file('../example/example_inputs/flask_function_and_normal_function.py') - + self.cfg_create_from_file('example/example_inputs/flask_function_and_normal_function.py') + cfg_list = [self.cfg] flask = FlaskAdaptor(cfg_list, list(), list()) - + #self.assertEqual(len(flask_functions), 1) #self.assertEqual(flask_functions[0], 'flask_function') diff --git a/tests/github_search_test.py b/tests/github_search_test.py index 804eaf4e..f5797f39 100644 --- a/tests/github_search_test.py +++ b/tests/github_search_test.py @@ -1,13 +1,11 @@ import unittest -import sys -import os from datetime import date -sys.path.insert(0, os.path.abspath('../pyt')) -from github_search import get_dates +from pyt.github_search import get_dates + class GetDatesTest(unittest.TestCase): def test_range_shorter_than_interval(self): - date_range = get_dates(date(2016,12,12), date(2016,12,13), 7) - + date_range = get_dates(date(2016,12,12), date(2016,12,13), 7) + diff --git a/tests/import_test.py b/tests/import_test.py index 6976f025..7312e196 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -1,29 +1,26 @@ +import ast import os -import sys -from ast import parse, Name -from base_test_case import BaseTestCase - -sys.path.insert(0, os.path.abspath('../pyt')) -from base_cfg import get_call_names_as_string -from project_handler import get_python_modules, get_directory_modules +from .base_test_case import BaseTestCase +from pyt.base_cfg import get_call_names_as_string +from pyt.project_handler import get_directory_modules, get_python_modules class ImportTest(BaseTestCase): def test_import(self): - path = os.path.normpath('../example/import_test_project/main.py') - + path = os.path.normpath('example/import_test_project/main.py') + project_modules = get_python_modules(os.path.dirname(path)) local_modules = get_directory_modules(os.path.dirname(path)) - + self.cfg_create_from_file(path, project_modules, local_modules) - + cfg_list = [self.cfg] - + #adaptor_type = FlaskAdaptor(cfg_list) def test_get_call_names_single(self): - m = parse('hi(a)') + m = ast.parse('hi(a)') call = m.body[0].value result = get_call_names_as_string(call.func) @@ -31,7 +28,7 @@ def test_get_call_names_single(self): self.assertEqual(result, 'hi') def test_get_call_names_uselesscase(self): - m = parse('defg.hi(a)') + m = ast.parse('defg.hi(a)') call = m.body[0].value result = get_call_names_as_string(call.func) @@ -40,7 +37,7 @@ def test_get_call_names_uselesscase(self): def test_get_call_names_multi(self): - m = parse('abc.defg.hi(a)') + m = ast.parse('abc.defg.hi(a)') call = m.body[0].value result = get_call_names_as_string(call.func) diff --git a/tests/label_visitor_test.py b/tests/label_visitor_test.py index 76b923b6..b4a6e837 100644 --- a/tests/label_visitor_test.py +++ b/tests/label_visitor_test.py @@ -1,17 +1,14 @@ -import os -import sys +import ast import unittest -from ast import parse -sys.path.insert(0, os.path.abspath('../pyt')) -from label_visitor import LabelVisitor +from pyt.label_visitor import LabelVisitor class LabelVisitorTestCase(unittest.TestCase): '''Baseclass for LabelVisitor tests''' def perform_labeling_on_expression(self, expr): - obj = parse(expr) + obj = ast.parse(expr) label = LabelVisitor() label.visit(obj) @@ -33,7 +30,7 @@ def test_compare_simple(self): def test_compare_multi(self): label = self.perform_labeling_on_expression('a > b > c') self.assertEqual(label.result,'a > b > c') - + def test_binop(self): label = self.perform_labeling_on_expression('a / b') self.assertEqual(label.result,'a / b') @@ -42,7 +39,7 @@ def test_call_no_arg(self): label = self.perform_labeling_on_expression('range()') self.assertEqual(label.result,'range()') - + def test_call_single_arg(self): label = self.perform_labeling_on_expression('range(5)') self.assertEqual(label.result,'range(5)') @@ -77,4 +74,4 @@ def test_list_two_elements(self): - + diff --git a/tests/lattice_test.py b/tests/lattice_test.py index c87b8853..5164c8ac 100644 --- a/tests/lattice_test.py +++ b/tests/lattice_test.py @@ -1,11 +1,8 @@ -import sys -import os +from .base_test_case import BaseTestCase +from pyt.constraint_table import constraint_table +from pyt.lattice import Lattice +from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from base_test_case import BaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from lattice import Lattice -from reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from constraint_table import constraint_table class LatticeTest(BaseTestCase): @@ -24,7 +21,7 @@ def __init__(self, value, lattice_element): self.lattice_element = lattice_element def __str__(self): return str(self.value) - + def test_generate_integer_elements(self): one = self.Node(1, True) two = self.Node(2, True) @@ -51,7 +48,7 @@ def test_join(self): b = self.Node('print(x)', False) c = self.Node('x = 3', True) d = self.Node('y = x', True) - + lattice = Lattice([a, c, d], self.AnalysisType) # Constraint results after fixpoint: @@ -83,7 +80,7 @@ def test_meet(self): b = self.Node('print(x)', False) c = self.Node('x = 3', True) d = self.Node('y = x', True) - + lattice = Lattice([a, c, d], self.AnalysisType) # Constraint results after fixpoint: @@ -110,7 +107,7 @@ def test_in_constraint(self): b = self.Node('print(x)', False) c = self.Node('x = 3', True) d = self.Node('y = x', True) - + lattice = Lattice([a, c, d], self.AnalysisType) constraint_table[a] = 0b001 @@ -133,7 +130,7 @@ def test_get_elements(self): b = self.Node('print(x)', False) c = self.Node('x = 3', True) d = self.Node('y = x', True) - + lattice = Lattice([a, c, d], self.AnalysisType) self.assertEqual(set(lattice.get_elements(0b111)), {a,c,d}) diff --git a/tests/liveness_test.py b/tests/liveness_test.py index 8b0de787..2647b666 100644 --- a/tests/liveness_test.py +++ b/tests/liveness_test.py @@ -1,15 +1,11 @@ -import sys -import os - -from analysis_base_test_case import AnalysisBaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from liveness import LivenessAnalysis -from constraint_table import constraint_table +from .analysis_base_test_case import AnalysisBaseTestCase +from pyt.constraint_table import constraint_table +from pyt.liveness import LivenessAnalysis class LivenessTest(AnalysisBaseTestCase): def test_example(self): - lattice = self.run_analysis('../example/example_inputs/example.py', LivenessAnalysis) + lattice = self.run_analysis('example/example_inputs/example.py', LivenessAnalysis) x = 0b1 y = 0b10 diff --git a/tests/project_handler_test.py b/tests/project_handler_test.py index 96969b65..dfe1b67b 100644 --- a/tests/project_handler_test.py +++ b/tests/project_handler_test.py @@ -1,29 +1,27 @@ import os -import sys import unittest -from ast import parse - -sys.path.insert(0, os.path.abspath('../pyt')) -from project_handler import is_python_module, get_python_modules from pprint import pprint +from pyt.project_handler import get_python_modules, is_python_module + + class ProjectHandlerTest(unittest.TestCase): """Tests for the project handler.""" def test_is_python_module(self): python_module = './project_handler_test.py' not_python_module = '../.travis.yml' - + self.assertEqual(is_python_module(python_module), True) self.assertEqual(is_python_module(not_python_module), False) def test_get_python_modules(self): - project_folder = os.path.normpath(os.path.join('..', 'example', 'test_project')) + project_folder = os.path.normpath(os.path.join('example', 'test_project')) project_namespace = 'test_project' folder = 'folder' directory = 'directory' - + modules = get_python_modules(project_folder) app_path = os.path.join(project_folder, 'app.py') @@ -31,18 +29,18 @@ def test_get_python_modules(self): exceptions_path = os.path.join(project_folder, 'exceptions.py') some_path = os.path.join(project_folder, folder, 'some.py') indhold_path = os.path.join(project_folder, folder, 'indhold.py') - + app_name = project_namespace + '.' + 'app' utils_name = project_namespace + '.' + 'utils' exceptions_name = project_namespace + '.' + 'exceptions' some_name = project_namespace + '.' + folder + '.some' indhold_name = project_namespace + '.' + folder + '.' + directory + '.indhold' - + app_tuple = (app_name, app_path) utils_tuple = (utils_name, utils_path) exceptions_tuple = (exceptions_name, exceptions_path) some_tuple = (some_name, some_path) - + self.assertIn(app_tuple, modules) self.assertIn(utils_tuple, modules) self.assertIn(exceptions_tuple, modules) diff --git a/tests/reaching_definitions_taint_test.py b/tests/reaching_definitions_taint_test.py index e4563493..a1867e2b 100644 --- a/tests/reaching_definitions_taint_test.py +++ b/tests/reaching_definitions_taint_test.py @@ -1,15 +1,12 @@ -import sys -import os from collections import namedtuple, OrderedDict -from analysis_base_test_case import AnalysisBaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from reaching_definitions_taint import ReachingDefinitionsTaintAnalysis +from .analysis_base_test_case import AnalysisBaseTestCase +from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis class ReachingDefinitionsTaintTest(AnalysisBaseTestCase): def test_linear_program(self): - lattice = self.run_analysis('../example/example_inputs/linear.py', ReachingDefinitionsTaintAnalysis) + lattice = self.run_analysis('example/example_inputs/linear.py', ReachingDefinitionsTaintAnalysis) self.assertInCfg([(1,1), (1,2), (2,2), @@ -18,7 +15,7 @@ def test_linear_program(self): def test_if_program(self): - lattice = self.run_analysis('../example/example_inputs/if_program.py', ReachingDefinitionsTaintAnalysis) + lattice = self.run_analysis('example/example_inputs/if_program.py', ReachingDefinitionsTaintAnalysis) self.assertInCfg([(1,1), (1,2), @@ -27,7 +24,7 @@ def test_if_program(self): (1,5), (3,5)], lattice) def test_example(self): - lattice = self.run_analysis('../example/example_inputs/example.py', ReachingDefinitionsTaintAnalysis) + lattice = self.run_analysis('example/example_inputs/example.py', ReachingDefinitionsTaintAnalysis) self.assertInCfg([(1,1), (1,2), (2,2), @@ -43,7 +40,7 @@ def test_example(self): *self.constraints([1,2,4,6,7,9,10], 12)], lattice) def test_func_with_params(self): - lattice = self.run_analysis('../example/example_inputs/function_with_params.py', ReachingDefinitionsTaintAnalysis) + lattice = self.run_analysis('example/example_inputs/function_with_params.py', ReachingDefinitionsTaintAnalysis) self.assertInCfg([(1,1), (1,2), (2,2), @@ -57,7 +54,7 @@ def test_func_with_params(self): *self.constraints([2,3,4,6,9], 10)], lattice) def test_while(self): - lattice = self.run_analysis('../example/example_inputs/while.py', ReachingDefinitionsTaintAnalysis) + lattice = self.run_analysis('example/example_inputs/while.py', ReachingDefinitionsTaintAnalysis) self.assertInCfg([(1,1), (1,2), (3,2), diff --git a/tests/reaching_definitions_test.py b/tests/reaching_definitions_test.py index fb2f920a..377a6537 100644 --- a/tests/reaching_definitions_test.py +++ b/tests/reaching_definitions_test.py @@ -1,14 +1,10 @@ -import sys -import os - -from analysis_base_test_case import AnalysisBaseTestCase -sys.path.insert(0, os.path.abspath('../pyt')) -from reaching_definitions import ReachingDefinitionsAnalysis +from .analysis_base_test_case import AnalysisBaseTestCase +from pyt.reaching_definitions import ReachingDefinitionsAnalysis class ReachingDefinitionsTest(AnalysisBaseTestCase): def test_linear_program(self): - lattice = self.run_analysis('../example/example_inputs/linear.py', ReachingDefinitionsAnalysis) + lattice = self.run_analysis('example/example_inputs/linear.py', ReachingDefinitionsAnalysis) self.assertInCfg([(1,1), (1,2), (2,2), @@ -16,7 +12,7 @@ def test_linear_program(self): (1,4), (2,4)], lattice) def test_example(self): - lattice = self.run_analysis('../example/example_inputs/example.py', ReachingDefinitionsAnalysis) + lattice = self.run_analysis('example/example_inputs/example.py', ReachingDefinitionsAnalysis) self.assertInCfg([(1,1), (2,2), @@ -31,4 +27,4 @@ def test_example(self): *self.constraints([2,4,6,9,10], 11), *self.constraints([2,4,6,9,10], 12)], lattice) - + diff --git a/tests/results b/tests/results index 5cd041b1..7a029880 100644 --- a/tests/results +++ b/tests/results @@ -1 +1 @@ -b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS.py\n\t > Line 10: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS.py\n > reaches line 9, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/command_injection.py\n > User input at line 15, trigger word "form[": \n\tparam = request.form[\'suggestion\']\nReassigned in: \n\tFile: ../example/vulnerable_code/command_injection.py\n\t > Line 16: command = \'echo \' + param + \' >> \' + \'menu.txt\'\nFile: ../example/vulnerable_code/command_injection.py\n > reaches line 18, trigger word "subprocess.call(": \n\tsubprocess.call(command,shell=True)\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/path_traversal.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nReassigned in: \n\tFile: ../example/vulnerable_code/path_traversal.py\n\t > Line 10: ret_request.args.get = 404\nFile: ../example/vulnerable_code/path_traversal.py\n > reaches line 11, trigger word "send_file(": \n\tret_request.args.get = send_file(os.path.join(os.getcwd(), image_name))\n\n'#¤%&/()=?b'2 vulnerabilities found:\nVulnerability 1:\nFile: ../example/vulnerable_code/path_traversal_sanitised.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nFile: ../example/vulnerable_code/path_traversal_sanitised.py\n > reaches line 10, trigger word "replace(": \n\timage_name = image_name.replace(\'..\', \'\')\n\nVulnerability 2:\nFile: ../example/vulnerable_code/path_traversal_sanitised.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nFile: ../example/vulnerable_code/path_traversal_sanitised.py\n > reaches line 12, trigger word "send_file(": \n\tret_image_name.replace = send_file(os.path.join(os.getcwd(), image_name))\nThis vulnerability is potentially sanitised by: ["\'..\'", "\'..\' in"]\n\n'#¤%&/()=?b'2 vulnerabilities found:\nVulnerability 1:\nFile: ../example/vulnerable_code/sql/sqli.py\n > User input at line 33, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nFile: ../example/vulnerable_code/sql/sqli.py\n > reaches line 36, trigger word "filter(": \n\tresult = session.query(User).filter(\'username={}\'.format(param))\n\nVulnerability 2:\nFile: ../example/vulnerable_code/sql/sqli.py\n > User input at line 26, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nFile: ../example/vulnerable_code/sql/sqli.py\n > reaches line 27, trigger word "execute(": \n\tresult = db.engine.execute(param)\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_form.py\n > User input at line 14, trigger word "form[": \n\tdata = request.form[\'my_text\']\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_form.py\n\t > Line 17: ret_resp.set_cookie = resp\nFile: ../example/vulnerable_code/XSS_form.py\n > reaches line 15, trigger word "replace(": \n\tresp = make_response(html1.replace(\'{{ data }}\', data))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_url.py\n > User input at line 4, trigger word "Flask function URL parameter": \n\turl\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_url.py\n\t > Line 6: param = url\n\tFile: ../example/vulnerable_code/XSS_url.py\n\t > Line 10: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS_url.py\n > reaches line 9, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'0 vulnerabilities found:\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_reassign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_reassign.py\n\t > Line 8: param = param + \'\'\n\tFile: ../example/vulnerable_code/XSS_reassign.py\n\t > Line 12: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS_reassign.py\n > reaches line 11, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_sanitised.py\n > User input at line 7, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_sanitised.py\n\t > Line 9: param = Markup.escape(param)\n\tFile: ../example/vulnerable_code/XSS_sanitised.py\n\t > Line 13: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS_sanitised.py\n > reaches line 12, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\nThis vulnerability is potentially sanitised by: [\'escape\']\n\n'#¤%&/()=?b'0 vulnerabilities found:\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_variable_assign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_variable_assign.py\n\t > Line 8: other_var = param + \'\'\n\tFile: ../example/vulnerable_code/XSS_variable_assign.py\n\t > Line 12: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS_variable_assign.py\n > reaches line 11, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', other_var))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 8: other_var = param + \'\'\n\tFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 10: not_the_same_var = \'\' + other_var\n\tFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 12: another_one = not_the_same_var + \'\'\n\tFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 17: ret_make_response = resp\nFile: ../example/vulnerable_code/XSS_variable_multiple_assign.py\n > reaches line 15, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', another_one))\n\n'#¤%&/()=? \ No newline at end of file +b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: example/vulnerable_code/XSS.py\n\t > Line 10: ret_make_response = resp\nFile: example/vulnerable_code/XSS.py\n > reaches line 9, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/command_injection.py\n > User input at line 15, trigger word "form[": \n\tparam = request.form[\'suggestion\']\nReassigned in: \n\tFile: example/vulnerable_code/command_injection.py\n\t > Line 16: command = \'echo \' + param + \' >> \' + \'menu.txt\'\nFile: example/vulnerable_code/command_injection.py\n > reaches line 18, trigger word "subprocess.call(": \n\tsubprocess.call(command,shell=True)\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/path_traversal.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nReassigned in: \n\tFile: example/vulnerable_code/path_traversal.py\n\t > Line 10: ret_request.args.get = 404\nFile: example/vulnerable_code/path_traversal.py\n > reaches line 11, trigger word "send_file(": \n\tret_request.args.get = send_file(os.path.join(os.getcwd(), image_name))\n\n'#¤%&/()=?b'2 vulnerabilities found:\nVulnerability 1:\nFile: example/vulnerable_code/path_traversal_sanitised.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nFile: example/vulnerable_code/path_traversal_sanitised.py\n > reaches line 10, trigger word "replace(": \n\timage_name = image_name.replace(\'..\', \'\')\n\nVulnerability 2:\nFile: example/vulnerable_code/path_traversal_sanitised.py\n > User input at line 8, trigger word "get(": \n\timage_name = request.args.get(\'image_name\')\nFile: example/vulnerable_code/path_traversal_sanitised.py\n > reaches line 12, trigger word "send_file(": \n\tret_image_name.replace = send_file(os.path.join(os.getcwd(), image_name))\nThis vulnerability is potentially sanitised by: ["\'..\'", "\'..\' in"]\n\n'#¤%&/()=?b'2 vulnerabilities found:\nVulnerability 1:\nFile: example/vulnerable_code/sql/sqli.py\n > User input at line 26, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nFile: example/vulnerable_code/sql/sqli.py\n > reaches line 27, trigger word "execute(": \n\tresult = db.engine.execute(param)\n\nVulnerability 2:\nFile: example/vulnerable_code/sql/sqli.py\n > User input at line 33, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nFile: example/vulnerable_code/sql/sqli.py\n > reaches line 36, trigger word "filter(": \n\tresult = session.query(User).filter(\'username={}\'.format(param))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_form.py\n > User input at line 14, trigger word "form[": \n\tdata = request.form[\'my_text\']\nReassigned in: \n\tFile: example/vulnerable_code/XSS_form.py\n\t > Line 17: ret_resp.set_cookie = resp\nFile: example/vulnerable_code/XSS_form.py\n > reaches line 15, trigger word "replace(": \n\tresp = make_response(html1.replace(\'{{ data }}\', data))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_url.py\n > User input at line 4, trigger word "Flask function URL parameter": \n\turl\nReassigned in: \n\tFile: example/vulnerable_code/XSS_url.py\n\t > Line 6: param = url\n\tFile: example/vulnerable_code/XSS_url.py\n\t > Line 10: ret_make_response = resp\nFile: example/vulnerable_code/XSS_url.py\n > reaches line 9, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'0 vulnerabilities found:\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_reassign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: example/vulnerable_code/XSS_reassign.py\n\t > Line 8: param = param + \'\'\n\tFile: example/vulnerable_code/XSS_reassign.py\n\t > Line 12: ret_make_response = resp\nFile: example/vulnerable_code/XSS_reassign.py\n > reaches line 11, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_sanitised.py\n > User input at line 7, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: example/vulnerable_code/XSS_sanitised.py\n\t > Line 9: param = Markup.escape(param)\n\tFile: example/vulnerable_code/XSS_sanitised.py\n\t > Line 13: ret_make_response = resp\nFile: example/vulnerable_code/XSS_sanitised.py\n > reaches line 12, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', param))\nThis vulnerability is potentially sanitised by: [\'escape\']\n\n'#¤%&/()=?b'0 vulnerabilities found:\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_variable_assign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: example/vulnerable_code/XSS_variable_assign.py\n\t > Line 8: other_var = param + \'\'\n\tFile: example/vulnerable_code/XSS_variable_assign.py\n\t > Line 12: ret_make_response = resp\nFile: example/vulnerable_code/XSS_variable_assign.py\n > reaches line 11, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', other_var))\n\n'#¤%&/()=?b'1 vulnerability found:\nVulnerability 1:\nFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n > User input at line 6, trigger word "get(": \n\tparam = request.args.get(\'param\', \'not set\')\nReassigned in: \n\tFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 8: other_var = param + \'\'\n\tFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 10: not_the_same_var = \'\' + other_var\n\tFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 12: another_one = not_the_same_var + \'\'\n\tFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n\t > Line 17: ret_make_response = resp\nFile: example/vulnerable_code/XSS_variable_multiple_assign.py\n > reaches line 15, trigger word "replace(": \n\tresp = make_response(html.replace(\'{{ param }}\', another_one))\n\n'#¤%&/()=? diff --git a/tests/run.py b/tests/run.py index 17dad640..32709d6a 100644 --- a/tests/run.py +++ b/tests/run.py @@ -1,11 +1,11 @@ -from subprocess import run, PIPE import argparse +from subprocess import PIPE, run + delimiter = '#¤%&/()=?' -results_file = 'results' -pyt_path = '../pyt/pyt.py' -example_file_path = '../example/vulnerable_code/' -python_name = open('python_name.txt', 'r').read().rstrip() +results_file = 'tests/results' +example_file_path = 'example/vulnerable_code/' +python_name = open('tests/python_name.txt', 'r').read().rstrip() encoding = 'utf-8' parser = argparse.ArgumentParser() @@ -25,7 +25,7 @@ def run_pyt(file_input, stdout=PIPE): - return run([python_name, pyt_path, '-f', file_input], stdout=stdout) + return run([python_name, '-m', 'pyt', '-f', file_input], stdout=stdout) def check_files(): diff --git a/tests/vars_visitor_test.py b/tests/vars_visitor_test.py index 0cbf50ad..4df37acf 100644 --- a/tests/vars_visitor_test.py +++ b/tests/vars_visitor_test.py @@ -1,17 +1,14 @@ -import os -import sys +import ast import unittest -from ast import parse -sys.path.insert(0, os.path.abspath('../pyt')) -from vars_visitor import VarsVisitor +from pyt.vars_visitor import VarsVisitor class VarsVisitorTestCase(unittest.TestCase): '''Baseclass for VarsVisitor tests''' def perform_vars_on_expression(self, expr): - obj = parse(expr) + obj = ast.parse(expr) vars = VarsVisitor() vars.visit(obj) diff --git a/tests/vulnerabilities_test.py b/tests/vulnerabilities_test.py index 23640f14..0b80576e 100644 --- a/tests/vulnerabilities_test.py +++ b/tests/vulnerabilities_test.py @@ -1,16 +1,13 @@ import os -import sys -sys.path.insert(1, os.path.abspath('../pyt')) -import vulnerabilities -import trigger_definitions_parser -from base_test_case import BaseTestCase -from base_cfg import Node -from fixed_point import analyse -from reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from flask_adaptor import FlaskAdaptor -from lattice import Lattice -from constraint_table import constraint_table, initialize_constraint_table +from .base_test_case import BaseTestCase +from pyt import trigger_definitions_parser, vulnerabilities +from pyt.base_cfg import Node +from pyt.constraint_table import constraint_table, initialize_constraint_table +from pyt.fixed_point import analyse +from pyt.flask_adaptor import FlaskAdaptor +from pyt.lattice import Lattice +from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis class EngineTest(BaseTestCase): @@ -22,7 +19,7 @@ def get_lattice_elements(self, cfg_nodes): return cfg_nodes def test_parse(self): - definitions = vulnerabilities.parse(trigger_word_file=os.path.join(os.getcwd().replace('tests','pyt'), 'trigger_definitions', 'test_triggers.pyt')) + definitions = vulnerabilities.parse(trigger_word_file=os.path.join(os.getcwd(), 'pyt', 'trigger_definitions', 'test_triggers.pyt')) self.assert_length(definitions.sources, expected_length=1) self.assert_length(definitions.sinks, expected_length=3) @@ -59,25 +56,25 @@ def test_label_contains(self): self.assertEqual(trigger_node_1.cfg_node, cfg_node) self.assertEqual(trigger_node_2.trigger_word, 'request') self.assertEqual(trigger_node_2.cfg_node, cfg_node) - + cfg_node = Node('request.get("stefan")', None, line_number=None, path=None) trigger_words = [('get', []), ('get', [])] l = list(vulnerabilities.label_contains(cfg_node, trigger_words)) self.assert_length(l, expected_length=2) def test_find_triggers(self): - self.cfg_create_from_file('../example/vulnerable_code/XSS.py') + self.cfg_create_from_file('example/vulnerable_code/XSS.py') cfg_list = [self.cfg] - FlaskAdaptor(cfg_list, [], []) + FlaskAdaptor(cfg_list, [], []) XSS1 = cfg_list[1] trigger_words = [('get', [])] l = vulnerabilities.find_triggers(XSS1.nodes, trigger_words) self.assert_length(l, expected_length=1) - + def test_find_sanitiser_nodes(self): cfg_node = Node(None, None, line_number=None, path=None) @@ -87,16 +84,16 @@ def test_find_sanitiser_nodes(self): result = list(vulnerabilities.find_sanitiser_nodes(sanitiser, [sanitiser_tuple])) self.assert_length(result, expected_length=1) self.assertEqual(result[0], cfg_node) - - + + def test_build_sanitiser_node_dict(self): - self.cfg_create_from_file('../example/vulnerable_code/XSS_sanitised.py') + self.cfg_create_from_file('example/vulnerable_code/XSS_sanitised.py') cfg_list = [self.cfg] - FlaskAdaptor(cfg_list, [], []) + FlaskAdaptor(cfg_list, [], []) cfg = cfg_list[1] - + cfg_node = Node(None, None, line_number=None, path=None) sinks_in_file = [vulnerabilities.TriggerNode('replace', ['escape'], cfg_node)] @@ -119,7 +116,7 @@ def test_is_sanitized_false(self): result = vulnerabilities.is_sanitized(sinks_in_file[0], sanitiser_dict, lattice) self.assertEqual(result, False) - + def test_is_sanitized_true(self): cfg_node_1 = Node('Awesome sanitiser', None, line_number=None, path=None) cfg_node_2 = Node('something.replace("this", "with this")', None, line_number=None, path=None) @@ -132,26 +129,26 @@ def test_is_sanitized_true(self): result = vulnerabilities.is_sanitized(sinks_in_file[0], sanitiser_dict, lattice) self.assertEqual(result, True) - + def test_find_vulnerabilities_no_vuln(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_no_vuln.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_no_vuln.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=0) def test_find_vulnerabilities_sanitised(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_sanitised.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_sanitised.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def test_find_vulnerabilities_vulnerable(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def test_find_vulnerabilities_reassign(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_reassign.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_reassign.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def test_find_vulnerabilities_variable_assign(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_variable_assign.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_variable_assign.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def run_analysis(self, path): @@ -161,24 +158,24 @@ def run_analysis(self, path): FlaskAdaptor(cfg_list, [], []) - initialize_constraint_table(cfg_list) + initialize_constraint_table(cfg_list) analyse(cfg_list, analysis_type=ReachingDefinitionsTaintAnalysis) return vulnerabilities.find_vulnerabilities(cfg_list, ReachingDefinitionsTaintAnalysis) def test_find_vulnerabilities_assign_other_var(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_assign_to_other_var.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_assign_to_other_var.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def test_find_vulnerabilities_variable_multiple_assign(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_variable_multiple_assign.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_variable_multiple_assign.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1) def test_find_vulnerabilities_variable_assign_no_vuln(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/XSS_variable_assign_no_vuln.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/XSS_variable_assign_no_vuln.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=0) def test_find_vulnerabilities_command_injection(self): - vulnerability_log = self.run_analysis('../example/vulnerable_code/command_injection.py') + vulnerability_log = self.run_analysis('example/vulnerable_code/command_injection.py') self.assert_length(vulnerability_log.vulnerabilities, expected_length=1)