Symbolic execution
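1. Parse the function's bytecode into a control-flow graph of basic blocks (a DAG only when the code has no loops)
2. Collect the branch conditions along every path through the graph
3. Solve each path's conditions with Z3 to obtain concrete test inputs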

In [283]:
import z3 as z3
import z3.z3



import dis
# NOTE(review): getMaxValue is not defined in this cell, so this line only runs
# if an earlier notebook cell defined it; `z` is not used in the code shown here.
z = dis.Bytecode(getMaxValue)
from z3 import z3
from collections import deque
def createOperationResult(variableType_: z3.ExprRef,
                          isInt: bool,
                          variableName: str,
                          mapVariableToLatestVersionId: dict
                          ):
    """Create a fresh, versioned z3 constant for ``variableName``.

    Every call bumps the version counter of ``variableName`` (the first call
    yields version 1) and returns a new z3 constant named
    ``"<variableName>@<version>"``, so each assignment gets its own symbol.

    Args:
        variableType_: z3 expression class such as ``z3.ArithRef`` or
            ``z3.BoolRef``. Subclasses (e.g. ``z3.IntNumRef``) and z3
            expression instances are accepted as well.
        isInt: for arithmetic values, True creates an Int and False a Real.
            Ignored for booleans.
        variableName: source-level variable name (or a placeholder like "Temp").
        mapVariableToLatestVersionId: name -> latest version id; updated in place.

    Returns:
        The new z3 Int, Real or Bool constant.

    Raises:
        TypeError: if the type is neither arithmetic nor boolean. In that case
            the version counter is left untouched.
    """
    exprClass = variableType_ if isinstance(variableType_, type) else type(variableType_)
    # issubclass (not ==) so that z3 subclasses such as IntNumRef are accepted;
    # an equality-based match silently fell through and left the result unbound.
    if issubclass(exprClass, z3.ArithRef):
        makeConstant = z3.Int if isInt else z3.Real
    elif issubclass(exprClass, z3.BoolRef):
        makeConstant = z3.Bool
    else:
        raise TypeError("Unsupported z3 type for " + variableName + ": " + str(variableType_))
    # Validate first, then bump the version, so a failure does not skip a version id.
    version = mapVariableToLatestVersionId.get(variableName, 0) + 1
    mapVariableToLatestVersionId[variableName] = version
    return makeConstant(variableName + "@" + str(version))


def handleBinaryInstruction(operationName: str,
                            variableName: str,
                            operandsStack: deque,
                            mapVariableToLatestVersionId: dict,
                            mapVariableNameToLatestZ3Ref: dict,
                            solver_: z3.Solver):
    """Symbolically execute a BINARY_ADD / BINARY_SUBTRACT / BINARY_MULTIPLY.

    Pops the two operands, builds the result expression, binds it to a fresh
    versioned constant, asserts ``constant == expression`` in ``solver_``,
    records the constant under ``variableName`` and pushes it back onto
    ``operandsStack`` (mirroring the interpreter's value stack).

    Args:
        operationName: bytecode opname of the instruction.
        variableName: name to version the result under. ``None`` means no
            name is available, and "Temp" is used instead.
        operandsStack: symbolic value stack; the top two entries are consumed.
        mapVariableToLatestVersionId: name -> latest version id; updated.
        mapVariableNameToLatestZ3Ref: name -> latest z3 constant; updated.
        solver_: solver receiving the defining equality.

    Raises:
        NotImplementedError: for opnames other than the three above.
        Exception: if fewer than two operands are on the stack.
        TypeError: if the result is neither a z3 arithmetic/boolean expression
            nor a Python int/float.
    """
    binaryOperations = {
        "BINARY_ADD": lambda left, right: left + right,
        "BINARY_SUBTRACT": lambda left, right: left - right,
        "BINARY_MULTIPLY": lambda left, right: left * right,
    }
    if operationName not in binaryOperations:
        # TODO: division and COMPARE_OP are not modelled yet.
        raise NotImplementedError("Unsupported binary operation: " + operationName)
    if len(operandsStack) < 2:
        raise Exception("Operand stack needs two values for " + operationName)
    # The interpreter computes TOS1 <op> TOS: the right operand is on top of
    # the stack. The order matters for subtraction.
    rightOperand = operandsStack.pop()
    leftOperand = operandsStack.pop()
    if variableName is None:  # Variable name may not be available
        variableName = "Temp"

    expression = binaryOperations[operationName](leftOperand, rightOperand)
    # z3 supports Integers, Reals and Bools. Canonical classes are passed so
    # that z3 subclasses of the result are handled too.
    if isinstance(expression, z3.ArithRef):
        res = createOperationResult(z3.ArithRef, expression.is_int(), variableName, mapVariableToLatestVersionId)
    elif isinstance(expression, z3.BoolRef):
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(expression, int):
        # Both operands were concrete Python ints.
        res = createOperationResult(z3.ArithRef, True, variableName, mapVariableToLatestVersionId)
    elif isinstance(expression, float):
        res = createOperationResult(z3.ArithRef, False, variableName, mapVariableToLatestVersionId)
    else:
        raise TypeError("Unhandled result type for " + operationName + ": " + str(type(expression)))
    solver_.add(res == expression)
    mapVariableNameToLatestZ3Ref[variableName] = res
    operandsStack.append(res)

# # operation = instr_.opname
# # variableName = str(instr.argval)
# operand1 = z3.Int('a@1')
# operand2 = 1
#
# operandsStack = deque()
# operandsStack.append(operand1)
# operandsStack.append(operand2)
#
# operationName = "BINARY_MULTIPLY"
# mapVariableToLatestVersionId = {"a":1}
# mapVariableToLatestZ3Ref = {"a":operand1}
#
# variableName = "c"
# s = z3.Solver()
# handleBinaryInstruction(operationName,variableName,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s)
# z3.solve(s.assertions())

# s.check()
# Always do check before finding the model
# print(s.check())
# zz = s.model()
# print(zz)
# zz[operand1]




# dis.dis(getMaxValue)



from inspect import signature
from z3 import CheckSatResult
from collections import deque

def handleAssignment(variableName: str,
                     operandsStack: deque,
                     mapVariableToLatestVersionId: dict,
                     mapVariableNameToLatestZ3Ref: dict,
                     solver_: z3.Solver):
    """Symbolically execute a STORE_FAST of the value on top of the stack.

    Pops the value, creates a new versioned z3 constant for ``variableName``,
    records it as the variable's latest reference and asserts
    ``constant == value`` in ``solver_``. The stored value is not pushed back.

    Args:
        variableName: name of the local being assigned.
        operandsStack: symbolic value stack; its top entry is consumed.
        mapVariableToLatestVersionId: name -> latest version id; updated.
        mapVariableNameToLatestZ3Ref: name -> latest z3 constant; updated.
        solver_: solver receiving the defining equality.

    Raises:
        Exception: if the stack is empty or the value's type is unsupported.
    """
    if not operandsStack:
        raise Exception("Operand stack is empty, did you push all the loads ?")
    tempVariable = operandsStack.pop()
    # isinstance (not exact type equality) so z3 subclasses such as IntNumRef
    # are recognised. Canonical classes are passed on to the helper.
    if isinstance(tempVariable, z3.ArithRef):
        res = createOperationResult(z3.ArithRef, tempVariable.is_int(), variableName, mapVariableToLatestVersionId)
    elif isinstance(tempVariable, z3.BoolRef):
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(tempVariable, bool):
        # Checked before int because bool is a subclass of int.
        res = createOperationResult(z3.BoolRef, False, variableName, mapVariableToLatestVersionId)
    elif isinstance(tempVariable, int):
        res = createOperationResult(z3.ArithRef, True, variableName, mapVariableToLatestVersionId)
    elif isinstance(tempVariable, float):
        res = createOperationResult(z3.ArithRef, False, variableName, mapVariableToLatestVersionId)
    else:
        # str() is required: concatenating a str with a type object raised a
        # TypeError that hid this message.
        raise Exception("Unhandled type:" + str(type(tempVariable)))
    mapVariableNameToLatestZ3Ref[variableName] = res
    solver_.add(res == tempVariable)


def constructArguments(fn):
    """Create the initial (version 1) z3 constants for every parameter of ``fn``.

    Each parameter must be annotated with ``int``, ``float`` or ``bool``. The
    annotation can be given as a class or, e.g. under
    ``from __future__ import annotations``, as its string name.

    Returns:
        A tuple ``(mapVariableToLatestVersionId, mapVariableToLatestZ3Ref)``
        mapping each parameter name to its version id and to its z3 constant.

    Raises:
        Exception: if a parameter has a missing or unsupported annotation.
    """
    mapVariableToLatestVersionId = {}
    mapVariableToLatestZ3Ref = {}
    for param in signature(fn).parameters.values():
        annotation = param.annotation
        if isinstance(annotation, str):
            typeName = annotation
        else:
            # A missing annotation is inspect.Parameter.empty ("_empty"), which
            # falls through to the error branch below.
            typeName = getattr(annotation, "__name__", str(annotation))

        if typeName in ("int", "float"):
            type_ = z3.ArithRef
            isInt_ = typeName == "int"
        elif typeName == "bool":
            type_ = z3.BoolRef
            isInt_ = False
        else:
            raise Exception("Unhandled type" + str(param.annotation) + " did you forget to mention the type of the arguments in the function signature ?  Ex: def func1(a:int,b:int):")
        variableName = str(param.name)
        res = createOperationResult(type_, isInt_, variableName, mapVariableToLatestVersionId)
        mapVariableToLatestZ3Ref[variableName] = res
    return mapVariableToLatestVersionId, mapVariableToLatestZ3Ref



def DFS(currentNode:Node,operandsStack:deque,
        mapVariableToLatestVersionId,
        mapVariableToLatestZ3Ref,
        s_: z3.Solver,models):

    for instr_ in currentNode.instructions:
        match str(instr_.opname):
            case "LOAD_FAST":
                if str(instr_.argval) in mapVariableToLatestVersionId:
                    operandsStack.append(mapVariableToLatestZ3Ref[str(instr_.argval)])
                else:
                    raise Exception(instr_.argval + " not seen earlier")
            case "LOAD_CONST":
                operandsStack.append(instr_.argval)
            case "STORE_FAST":
                variableName = str(instr_.argval)
                handleAssignment(variableName,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_)
            case "BINARY_MULTIPLY" | "BINARY_SUBTRACT" | "BINARY_ADD":
                handleBinaryInstruction(str(instr_.opname),None,operandsStack,mapVariableToLatestVersionId,mapVariableToLatestZ3Ref,s_)
            case "RETURN_VALUE":
                #Generate test case
                if isinstance(s_.check(), CheckSatResult):
                    model = s.model()
                    models.append(model)
        print("$$$$$$$$$$$$$$")
        print(instr_)
        print("------")
        print(mapVariableToLatestVersionId)
        print("------")
        print(mapVariableToLatestZ3Ref)
        print("------")
        print(operandsStack)
        print("$$$$$$$$$$$$$$")
# Example target for the symbolic executor. Its parameters carry int
# annotations because constructArguments requires int/float/bool annotations.
# The body is straight-line code with no jumps, so its bytecode forms a
# single block.
def func1(a:int,b:int):
    a = a * 1
    b = b + 4
    a = a + 4
    # c is never read afterwards, but the assignment is still compiled
    # (STORE_FAST c) and therefore still executed symbolically.
    c = a * a
    return a
s = z3.Solver()
mapVariableToLatestVersionId, mapVariableToLatestZ3Ref = constructArguments(func1)

# Divide the bytecode into basic blocks of instructions.
# Each block is uniquely identified by the byte offset of its first instruction.
# A block starts at offset 0, at every jump target, and right after any
# instruction that transfers control (a jump, a return or a raise).
# A block can have more than one parent and at most two children: its jump
# target and/or the block that directly follows it (fall-through).

# Opcodes after which execution never continues with the next instruction.
# The names cover several CPython versions (e.g. JUMP_ABSOLUTE exists in
# 3.10, JUMP_BACKWARD in 3.11+).
NO_FALLTHROUGH_OPNAMES = {
    "RETURN_VALUE", "RETURN_CONST", "RAISE_VARARGS", "RERAISE",
    "JUMP_FORWARD", "JUMP_ABSOLUTE", "JUMP_BACKWARD",
    "JUMP_BACKWARD_NO_INTERRUPT",
}


def _isBranchInstruction(instruction):
    """Return True if ``instruction`` may jump to the offset in its argval."""
    return "JUMP" in instruction.opname or instruction.opname == "FOR_ITER"


edges = []
byteCode = dis.Bytecode(func1)
print(byteCode)
instructionList = list(byteCode)

# Pass 1: collect the leader (first instruction offset) of every block.
leaders = set()
for index, instr in enumerate(instructionList):
    print(instr.offset,instr.opname,instr.arg,instr.argval,instr.is_jump_target)
    if instr.offset == 0 or instr.is_jump_target:
        leaders.add(instr.offset)
    isBranch = _isBranchInstruction(instr)
    if isBranch:
        leaders.add(int(instr.argval))
    if (isBranch or instr.opname in NO_FALLTHROUGH_OPNAMES) and index + 1 < len(instructionList):
        leaders.add(instructionList[index + 1].offset)
nodes = sorted(leaders)

# Pass 2: connect every block to its successors. The owning block is
# determined before any edge is added, so a jump that is itself a jump
# target is attributed to its own block.
blockStarts = set(nodes)
currentBlock = None
for index, instr in enumerate(instructionList):
    if instr.offset in blockStarts:
        currentBlock = instr.offset
    if _isBranchInstruction(instr):
        edges.append((currentBlock, int(instr.argval)))
    if index + 1 < len(instructionList):
        nextOffset = instructionList[index + 1].offset
        # Fall through into the next block unless control never continues
        # (return/raise/unconditional jump). Conditional jumps and ordinary
        # instructions preceding a jump target both fall through.
        if nextOffset in blockStarts and instr.opname not in NO_FALLTHROUGH_OPNAMES:
            edges.append((currentBlock, nextOffset))

edges = sorted(set(edges))

# Constructing a graph where each node/block will include details like
# 1. What are the instructions being used
# 2. What are its child nodes
from dis import Instruction
from typing import List
import math
class Node:
    """One basic block of the control-flow graph built from bytecode.

    Attributes:
        nodeId: identifier of the block (in this notebook, the byte offset of
            the block's first instruction).
        instructions: the instructions belonging to the block, in order.
        children: successor blocks; a new empty list is used when none is given.
    """

    def __init__(self, nodeId_, instructions_: List[Instruction], children_: List['Node'] = None):
        self.nodeId = nodeId_
        self.instructions = instructions_
        # Never share a default list between instances.
        self.children = [] if children_ is None else children_

# Group the instructions into blocks: block k covers offsets
# [nodes[k], nodes[k+1]), and the last block runs to the end of the bytecode.
nodeIdToGraphObjectMapping = {}
blockEnds = nodes[1:] + [math.inf]
for blockStart, blockEnd in zip(nodes, blockEnds):
    blockInstructions = [ins for ins in byteCode if blockStart <= ins.offset < blockEnd]
    nodeIdToGraphObjectMapping[blockStart] = Node(blockStart, blockInstructions)

# Wire up the parent -> child links collected in `edges`.
for parentId, childId in edges:
    parent = nodeIdToGraphObjectMapping[parentId]
    parent.children.append(nodeIdToGraphObjectMapping[childId])

# The entry block is the one starting at offset 0.
root = nodeIdToGraphObjectMapping[0]

from graphviz import Digraph
# Render the control-flow graph: one graphviz node per block, labelled with
# the offsets of its instructions, then one edge per parent -> child link.
gra = Digraph()
for blockId, block in nodeIdToGraphObjectMapping.items():
    offsetLabel = ",".join(str(ins.offset) for ins in block.instructions)
    gra.node(str(blockId), offsetLabel)
for blockId, block in nodeIdToGraphObjectMapping.items():
    for successor in block.children:
        gra.edge(str(blockId), str(successor.nodeId))
# NOTE(review): Jupyter only renders a bare expression when it is the cell's
# last expression, so this line does not display the graph here.
gra

# Run the symbolic executor from the entry block and print the models found.
operandsStack = deque()
models = []
print(mapVariableToLatestVersionId)
print("-----")
print(mapVariableToLatestZ3Ref)
DFS(root, operandsStack, mapVariableToLatestVersionId, mapVariableToLatestZ3Ref, s, models)
print(models)

Bytecode(<function func1 at 0x123f54e50>)
0 LOAD_FAST 0 a False
2 LOAD_CONST 1 1 False
4 BINARY_MULTIPLY None None False
6 STORE_FAST 0 a False
8 LOAD_FAST 1 b False
10 LOAD_CONST 2 4 False
12 BINARY_ADD None None False
14 STORE_FAST 1 b False
16 LOAD_FAST 0 a False
18 LOAD_CONST 2 4 False
20 BINARY_ADD None None False
22 STORE_FAST 0 a False
24 LOAD_FAST 0 a False
26 LOAD_FAST 0 a False
28 BINARY_MULTIPLY None None False
30 STORE_FAST 2 c False
32 LOAD_FAST 0 a False
34 RETURN_VALUE None None False
{'a': 1, 'b': 1}
-----
{'a': a@1, 'b': b@1}
$$$$$$$$$$$$$$
Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='a', argrepr='a', offset=0, starts_line=194, is_jump_target=False)
------
{'a': 1, 'b': 1}
------
{'a': a@1, 'b': b@1}
------
deque([a@1])
$$$$$$$$$$$$$$
$$$$$$$$$$$$$$
Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=1, argrepr='1', offset=2, starts_line=None, is_jump_target=False)
------
{'a': 1, 'b': 1}
------
{'a': a@1, 'b': b@1}
------
deque([a@1, 1])
$$$$$$$$$