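Symbolic execution
1. Parse the function's bytecode into a control-flow graph (a DAG of basic blocks).
2. Collect the branch conditions along every path through the graph.
3. Solve each path's conditions with Z3 to obtain a concrete test input for that path.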

In [2]:
import z3 as z3
import z3.z3
import dis
from z3 import z3
from collections import deque
import copy
def deepCopy(s: z3.Solver, operandStack, mapVariableToLatestVersionId, mapVariableToLatestZ3Ref):
    """Return an independent copy of the symbolic-execution state.

    A brand-new solver is seeded with every assertion of ``s``. The operand
    stack and both variable maps are deep-copied, each with its own
    ``copy.deepcopy`` call.

    Returns:
        ``(solver, operandStack, mapVariableToLatestVersionId,
        mapVariableToLatestZ3Ref)``, in that order.
    """
    freshSolver = z3.Solver()
    freshSolver.add(s.assertions())
    clonedParts = [
        copy.deepcopy(part)
        for part in (operandStack, mapVariableToLatestVersionId, mapVariableToLatestZ3Ref)
    ]
    return (freshSolver, *clonedParts)

def createOperationResult(variableType_: type,
                                isInt: bool,
                                variableName: str,
                                mapVariableToLatestVersionId: dict
                                ):
    """Create a fresh, versioned z3 constant for ``variableName`` (SSA style).

    Each successful call bumps the version counter of ``variableName`` in
    ``mapVariableToLatestVersionId`` (the first version is 1). It returns a
    new z3 constant named ``"<variableName>@<version>"``.

    Args:
        variableType_: Type of the value the constant stands for. Accepted
            types are z3 ``ArithRef``/``BoolRef`` or any subclass (e.g.
            ``IntNumRef``), and the Python scalar types ``int``, ``float``
            and ``bool``.
        isInt: For z3 arithmetic types, selects ``Int`` (True) or ``Real``
            (False). It is ignored for every other type.
        variableName: Source-level name, used as the prefix of the z3 name.
        mapVariableToLatestVersionId: Version counter per variable name,
            updated in place.

    Returns:
        The new z3 ``Int``, ``Real`` or ``Bool`` constant.

    Raises:
        TypeError: if ``variableType_`` is not a supported type. The version
            counter is left untouched in that case.
    """
    if not isinstance(variableType_, type):
        raise TypeError("Expected a type, got " + repr(variableType_))
    # Use subclass checks, not equality: z3 may hand back subclasses such as
    # IntNumRef, and both operands of an instruction may be Python constants.
    # `bool` must be tested before `int` because bool is a subclass of int.
    if issubclass(variableType_, (z3.BoolRef, bool)):
        makeConstant = z3.Bool
    elif issubclass(variableType_, z3.ArithRef):
        makeConstant = z3.Int if isInt else z3.Real
    elif issubclass(variableType_, int):
        makeConstant = z3.Int
    elif issubclass(variableType_, float):
        makeConstant = z3.Real
    else:
        raise TypeError("Unhandled type for " + variableName + ": " + str(variableType_))

    # Only create a new version once the type is known to be supported.
    version = mapVariableToLatestVersionId.get(variableName, 0) + 1
    mapVariableToLatestVersionId[variableName] = version
    return makeConstant(variableName + "@" + str(version))


def handleBinaryInstruction(operationName: str,
                            variableName: str,
                            operandsStack: deque,
                            mapVariableToLatestVersionId: dict,
                            mapVariableNameToLatestZ3Ref: dict,
                            solver_: z3.Solver):
    """Symbolically execute a binary arithmetic or comparison instruction.

    The function:
    1. Pops the two operands.
    2. Creates a fresh versioned z3 constant ``res`` for the result.
    3. Adds ``res == left <op> right`` to ``solver_``.
    4. Records ``res`` as the latest reference of ``variableName``.
    5. Pushes ``res`` back onto ``operandsStack``.

    Args:
        operationName: An opcode name (``"BINARY_ADD"``,
            ``"BINARY_SUBTRACT"``, ``"BINARY_MULTIPLY"``) or a ``COMPARE_OP``
            operator (``"=="``, ``"!="``, ``"<"``, ``"<="``, ``">"``,
            ``">="``).
        variableName: Name under which the result is versioned. ``None``
            means ``"Temp"``.
        operandsStack: Symbolic operand stack; two items are popped and one
            is pushed.
        mapVariableToLatestVersionId: Version counters, updated in place.
        mapVariableNameToLatestZ3Ref: Latest z3 reference per name, updated
            in place.
        solver_: Solver that receives the defining constraint.

    Raises:
        Exception: if ``operationName`` is not supported.
    """
    import operator

    # z3 supports Integers, Reals and Bools
    arithmeticOperations = {
        "BINARY_ADD": operator.add,
        "BINARY_SUBTRACT": operator.sub,
        "BINARY_MULTIPLY": operator.mul,
    }
    comparisonOperations = {
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }
    if operationName in arithmeticOperations:
        apply = arithmeticOperations[operationName]
    elif operationName in comparisonOperations:
        apply = comparisonOperations[operationName]
    else:
        # TODO: division and the remaining binary opcodes are not modelled yet.
        raise Exception("Operation " + str(operationName) + " is not handled")

    # CPython pushes the left operand first. The top of the stack is therefore
    # the RIGHT operand and the item below it is the LEFT one (TOS1 op TOS).
    # Popping in the other order would model `a - 3` as `3 - a`.
    right = operandsStack.pop()
    left = operandsStack.pop()
    if variableName is None:  # Variable name may not be available
        variableName = "Temp"

    expression = apply(left, right)
    if operationName in comparisonOperations:
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(expression, z3.ArithRef):
        res = createOperationResult(z3.ArithRef, expression.is_int(), variableName, mapVariableToLatestVersionId)
    else:
        res = createOperationResult(type(expression), False, variableName, mapVariableToLatestVersionId)
    solver_.add(res == expression)
    mapVariableNameToLatestZ3Ref[variableName] = res
    operandsStack.append(res)



from inspect import signature
from z3 import CheckSatResult
from collections import deque

def handleAssignment(       variableName: str,
                            operandsStack: deque,
                            mapVariableToLatestVersionId: dict,
                            mapVariableNameToLatestZ3Ref: dict,
                            solver_: z3.Solver):
    """Symbolically execute ``STORE_FAST variableName``.

    The function:
    1. Pops the value on top of ``operandsStack``.
    2. Creates a new version of ``variableName`` with a matching z3 sort.
    3. Records that version as the latest z3 reference of ``variableName``.
    4. Constrains it in ``solver_`` to equal the popped value.

    Raises:
        Exception: if the operand stack is empty or the popped value has an
            unsupported type.
    """
    if not operandsStack:
        raise Exception("Operand stack is empty, did you push all the loads ?")
    value = operandsStack.pop()
    if isinstance(value, z3.BoolRef):
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(value, z3.ArithRef):
        # Pass the base class, so numeral subclasses (e.g. IntNumRef) are
        # handled like any other arithmetic expression.
        res = createOperationResult(z3.ArithRef, value.is_int(), variableName, mapVariableToLatestVersionId)
    elif isinstance(value, z3.ExprRef):
        res = createOperationResult(type(value), False, variableName, mapVariableToLatestVersionId)
    # Python constants pushed by LOAD_CONST. Test bool before int, because
    # bool is a subclass of int.
    elif isinstance(value, bool):
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(value, int):
        res = createOperationResult(z3.ArithRef, True, variableName, mapVariableToLatestVersionId)
    elif isinstance(value, float):
        res = createOperationResult(z3.ArithRef, False, variableName, mapVariableToLatestVersionId)
    else:
        # str() is required: concatenating a str with a type object raises TypeError.
        raise Exception("Unhandled type:" + str(type(value)))
    mapVariableNameToLatestZ3Ref[variableName] = res
    solver_.add(res == value)


def constructArguments(fn):
    """Create the initial symbolic variable (version 1) for each parameter of ``fn``.

    Every parameter must be annotated with ``int``, ``float`` or ``bool``.
    The annotation may also be the string form (``"int"``, ...), which is
    what postponed evaluation of annotations produces. The mapping is:
    ``int`` becomes a z3 ``Int``, ``float`` a z3 ``Real`` and ``bool`` a
    z3 ``Bool``.

    Returns:
        ``(mapVariableToLatestVersionId, mapVariableToLatestZ3Ref)``: the
        version counters and the z3 constant of each parameter, keyed by
        parameter name.

    Raises:
        Exception: if a parameter is unannotated or has an unsupported
            annotation.
    """
    mapVariableToLatestVersionId = {}
    mapVariableToLatestZ3Ref = {}
    for variableName, parameter in signature(fn).parameters.items():
        annotation = parameter.annotation
        # Annotations are plain strings under `from __future__ import annotations`.
        # Non-class annotations (e.g. typing generics) may lack __name__.
        if isinstance(annotation, str):
            typeName = annotation
        else:
            typeName = getattr(annotation, "__name__", None)

        if typeName in ("int", "float"):
            type_ = z3.ArithRef
            isInt_ = typeName == "int"
        elif typeName == "bool":
            type_ = z3.BoolRef
            isInt_ = False
        else:
            raise Exception("Unhandled type " + str(annotation) + " did you forget to mention the type of the arguments in the function signature ?  Ex: def func1(a:int,b:int):")
        res = createOperationResult(type_, isInt_, variableName, mapVariableToLatestVersionId)
        mapVariableToLatestZ3Ref[variableName] = res
    return mapVariableToLatestVersionId, mapVariableToLatestZ3Ref


from dis import Instruction
from typing import List
import math
class Node:
    """A basic block of the control-flow graph.

    Attributes (set in ``__init__``):
        nodeId: Identifier of the block. ``generateTestCases`` uses the
            bytecode offset of the block's first instruction.
        instructions: The block's ``dis.Instruction`` objects, in order.
        children: Successor blocks.
    """
    def __init__(self, nodeId_, instructions_: List[Instruction], children_: 'List[Node] | None' = None):
        # Create a fresh list per node. A shared default list would make
        # unrelated nodes share successors.
        self.children = [] if children_ is None else children_
        self.nodeId = nodeId_
        self.instructions = instructions_

    def __repr__(self):
        # Show child ids rather than child nodes. Blocks can be reached from
        # several parents, so nested reprs would repeat whole subgraphs.
        childIds = [child.nodeId for child in self.children]
        return f"Node(nodeId={self.nodeId!r}, children={childIds!r})"
def DFS(currentNode:Node,operandsStack:deque,
        mapVariableToLatestVersionId,
        mapVariableToLatestZ3Ref,
        s_: z3.Solver,models):

    # print("Node:"+str(currentNode.nodeId))
    for instr_ in currentNode.instructions:
        match str(instr_.opname):
            case "LOAD_FAST":
                if str(instr_.argval) in mapVariableToLatestVersionId:
                    operandsStack.append(mapVariableToLatestZ3Ref[str(instr_.argval)])
                else:
                    raise Exception(instr_.argval + " not seen earlier")
            case "LOAD_CONST":
                operandsStack.append(instr_.argval)
            case "STORE_FAST":
                variableName = str(instr_.argval)
                handleAssignment(variableName,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_)
            case "BINARY_MULTIPLY" | "BINARY_SUBTRACT" | "BINARY_ADD" :
                handleBinaryInstruction(str(instr_.opname),None,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_)
            case "RETURN_VALUE":
                #Generate test case
                if isinstance(s_.check(), CheckSatResult):
                    print(s_.assertions())
                    model = s_.model()
                    models.append(model)
            case "COMPARE_OP":
                handleBinaryInstruction(str(instr_.argval),None,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_)
            case "POP_JUMP_IF_FALSE" | "POP_JUMP_IF_TRUE":
                #assert there must be two children
                # Pop the top element of the operand stack
                # deepcopy the solver,operand stack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref once for each child
                # Case 1: add an entry that the top element is False into the solver, DFS to the child which has the jump offset
                # Case 2: add an entry to the solver that the top element is True, DFS to the child
                assert len(currentNode.children) == 2, f"Current node {currentNode} doesn't have 2 children"
                boolElement = operandsStack.pop()
                jumpIfFalse = "FALSE" in str(instr_.opname)
                for child in currentNode.children:
                    s_Copy,operandsStackCopy,mapVariableToLatestVersionIdCopy,mapVariableToLatestZ3RefCopy = deepCopy(s_,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref)
                    # print("CHILD!!!!")
                    if str(child.nodeId) == str(instr_.argval): # JUMP
                        if jumpIfFalse:
                            s_Copy.add(boolElement == False)
                        else:
                            s_Copy.add(boolElement == True)
                    else:
                        if jumpIfFalse:
                            s_Copy.add(boolElement == True)
                        else:
                            s_Copy.add(boolElement == False)
                    DFS(child,operandsStackCopy,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_Copy,models)
                    # print(child.nodeId)
                    # print("CHILD!!!!")
            case _:
                raise Exception("Unhandled instruction:"+str(instr_))
        # print("$$$$$$$$$$$$$$")
        # print(instr_)
        # print("------")
        # print(mapVariableToLatestVersionId)
        # print("------")
        # print(mapVariableToLatestZ3Ref)
        # print("------")
        # print(operandsStack)
        # print("$$$$$$$$$$$$$$")


def generateTestCases(fn):
    """Symbolically execute ``fn`` and collect one z3 model per feasible path.

    Every parameter of ``fn`` must be annotated with ``int``, ``float`` or
    ``bool`` (see ``constructArguments``).

    The bytecode is divided into basic blocks, each identified by the offset
    of its first instruction. A block starts at offset 0, at every jump
    target, and right after every instruction whose opcode name contains
    ``JUMP``. A block has an edge:
    * to the target of the jump that ends it,
    * to the instruction right after that jump, and
    * to the next block when it ends in neither a jump nor a return/raise
      (plain fall-through, e.g. the end of an ``if`` body).

    Returns:
        ``(gra, byteCode, models)``:
        * ``gra``: a graphviz ``Digraph`` of the blocks, each labelled with
          its instruction offsets;
        * ``byteCode``: the ``dis.Bytecode`` of ``fn``;
        * ``models``: the list of z3 models, one per satisfiable path that
          reached a ``RETURN_VALUE``.
    """
    s = z3.Solver()
    (mapVariableToLatestVersionId, mapVariableToLatestZ3Ref) = constructArguments(fn)

    byteCode = dis.Bytecode(fn)
    print(byteCode)
    instructionList = list(byteCode)

    # Pass 1: find the block leaders (start offsets) and the edges created by jumps.
    # Offsets increase strictly and each instruction adds at most one leader,
    # so `nodes` comes out sorted and duplicate-free.
    nodes = []
    edges = []
    prevInstructionIsJump = False
    for instr in instructionList:
        if prevInstructionIsJump:
            # Successor reached when the previous jump is not taken. nodes[-1]
            # is still the leader of the block that the jump ended.
            edges.append((nodes[-1], instr.offset))
        # Register the block this instruction starts (if any) BEFORE adding its
        # own jump edge, so that edge leaves the block the jump belongs to.
        if instr.offset == 0 or instr.is_jump_target or prevInstructionIsJump:
            nodes.append(instr.offset)
        prevInstructionIsJump = "JUMP" in instr.opname
        if prevInstructionIsJump:
            edges.append((nodes[-1], int(instr.argval)))

    # Pass 2: block i holds the instructions with offsets in
    # [nodes[i], nodes[i + 1]); the last block runs to the end of the code.
    nodeIdToGraphObjectMapping = {}
    for start, end in zip(nodes, nodes[1:] + [math.inf]):
        blockInstructions = [instr for instr in instructionList if start <= instr.offset < end]
        nodeIdToGraphObjectMapping[start] = Node(start, blockInstructions)

    # Plain fall-through: a block that does not end in a jump, return or raise
    # continues into the next block (e.g. an `if` body flowing into the code
    # after the `if`). Every leader owns at least one instruction, so
    # instructions[-1] is always defined.
    terminators = {"RETURN_VALUE", "RAISE_VARARGS", "RERAISE"}
    for start, nextStart in zip(nodes, nodes[1:]):
        lastOpname = nodeIdToGraphObjectMapping[start].instructions[-1].opname
        if "JUMP" not in lastOpname and lastOpname not in terminators:
            edges.append((start, nextStart))

    # dict.fromkeys drops duplicate edges (e.g. a jump to the very next
    # instruction) while keeping a deterministic child order.
    for (nodeId1, nodeId2) in dict.fromkeys(edges):
        nodeIdToGraphObjectMapping[nodeId1].children.append(nodeIdToGraphObjectMapping[nodeId2])

    root = nodeIdToGraphObjectMapping[0]

    from graphviz import Digraph
    gra = Digraph()
    for nodeId, node in nodeIdToGraphObjectMapping.items():
        gra.node(str(nodeId), ",".join(str(instr.offset) for instr in node.instructions))
    for nodeId, node in nodeIdToGraphObjectMapping.items():
        for child in node.children:
            gra.edge(str(nodeId), str(child.nodeId))

    models = []
    DFS(root, deque([]), mapVariableToLatestVersionId, mapVariableToLatestZ3Ref, s, models)
    print(models)
    return gra, byteCode, models

In [3]:
# Fixture: straight-line code (no branches), so it has a single path.
# The parameters are reassigned, which exercises the per-variable versioning
# (a@1, a@2, ...) together with BINARY_MULTIPLY / BINARY_ADD. `c` is computed
# but never used. The statement order is deliberate, because the executor
# analyses the compiled bytecode of this function.
def func1(a:int,b:int):
    a = a * 1
    b = b + 4
    a = a + 4
    c = a * a
    return a

# Fixture: a single `if` whose condition uses short-circuit `and`.
# The dis listing of func3 shows that the same condition compiles to two
# POP_JUMP_IF_FALSE instructions. The executor therefore has three paths:
# a != 1; a == 1 and b != 2; a == 1 and b == 2.
# The `b = 2` store is dead at runtime but exercises STORE_FAST.
def func2(a:int,b:int):
    if a == 1 and b==2:
        b = 2
        return 233
    return 333

# Fixture: an if / elif / else chain.
# - The bare `return` yields None, which shows up as LOAD_CONST None in the
#   dis listing printed below.
# - The `else` branch assigns `a` and then reaches `return 333`.
# The statement order is deliberate, because the executor analyses the
# compiled bytecode of this function.
def func3(a:int,b:int):
    if a == 1 and b==2:
        b = 2
        return 233
    elif b==3:
        b = 4
        return
    else:
        a = 1
    return 333

# dis.dis() prints the listing itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
dis.dis(func3)
# gra,byteCode,models= generateTestCases(func2)

# gra
# models

 15           0 LOAD_FAST                0 (a)
              2 LOAD_CONST               1 (1)
              4 COMPARE_OP               2 (==)
              6 POP_JUMP_IF_FALSE       12 (to 24)
              8 LOAD_FAST                1 (b)
             10 LOAD_CONST               2 (2)
             12 COMPARE_OP               2 (==)
             14 POP_JUMP_IF_FALSE       12 (to 24)

 16          16 LOAD_CONST               2 (2)
             18 STORE_FAST               1 (b)

 17          20 LOAD_CONST               3 (233)
             22 RETURN_VALUE

 18     >>   24 LOAD_FAST                1 (b)
             26 LOAD_CONST               4 (3)
             28 COMPARE_OP               2 (==)
             30 POP_JUMP_IF_FALSE       20 (to 40)

 19          32 LOAD_CONST               5 (4)
             34 STORE_FAST               1 (b)

 20          36 LOAD_CONST               0 (None)
             38 RETURN_VALUE

 22     >>   40 LOAD_CONST               1 (1)
             42 STORE_F

In [37]:
# from z3 import z3
# a = z3.Int('a')
# s = z3.Solver()
# c = z3.Bool('c')
# s.add(c == True)
# # s2 = z3.Solver()
# # s2.add(s.assertions())
# # s2.add(c == 2)
#
#
# s.check()
# z = s.model()
# z

# s2.check()
# z = s2.model()
# z


#
# mp = {'a': a,'b': c}
# mp['a']
#
# import copy
#
# mp1 = copy.deepcopy(mp)

In [13]:
# z  # scratch: `z` is only bound inside the commented-out experiment above,
#    # so evaluating it when this file runs top to bottom raises NameError.