# In [82]:
import os
import panoramix
import sys
from panoramix.decompiler import decompile_address, decompile_bytecode
import numpy as np
import re
from multiprocessing import Pool
from web3 import Web3
import pandas as pd
from tqdm import tqdm
pd.options.mode.chained_assignment = None
import glob
from abc import ABC, abstractmethod
from overrides import overrides
from google.cloud import storage
from google.cloud import bigquery
from google.cloud.storage import Client, transfer_manager
import os
import shutil
import gzip


class BaseUPCDetector (ABC):
    def __init__(self):    
        self.PAYABLE_FUNC_REGEX = re.compile(r"def\s*(\w*)\s*(\(.*?\))?[^\n]+default function[^\n]*",re.IGNORECASE) # single-slash, foll'd by word: \HOSTNAME  
        self.FUNC_REGEX = re.compile(r"def\s+[^\n]+",re.IGNORECASE)
        self.STORAGE_FUNC_REGEX = re.compile(r"def storage\s*[^\n]+",re.IGNORECASE)
        self.ASSIGNMENT_REGEX = re.compile(r"([^\n]+) (=) ([^\n]+)",re.IGNORECASE)
        self.FUNC_SIG_REGEX = re.compile(r'def (\w*)\s*(\(.*?\))?\s*[^\n]+',re.IGNORECASE)
        self.FUNC_SELECTOR_REGEX = re.compile(r"def (\w*\s*(\(.*?\))?)\s*[^\n]+",re.IGNORECASE)
        self.HASH_BASED_STORAGE_SELECTOR_REGEX = re.compile(r"[^\n]+(stor[^\n]*\[[^\n]+\])[^\n]+",re.IGNORECASE)
        
        self.DEBUG_MODE = False             # sw to enable debug mode with printing 
        self.fallback_func_found_sw = 0     # sw when a fallback function is found
        self.proxy_found_sw = 0             # sw when a relevant delegate is found
        self.storage_found_sw = 0           # sw that shows if the storage is found in the decompiled bytecode
        self.curr_func_loc = "global"       # current function loc
        self.is_piv_of_type_mapping = False # is primary impact variable of type mapping
        self.state_vars = []                # all parsed state variables
        self.delegatecalls = []             # all relevant delegatecalls
        self.all_assign_exps = dict()       # all assignment expressions
        self.all_func_sigs = dict()         # all function signature <loc, sig>
        self.all_func_start_end_line = dict() # all function start line and end line <sig, (start line, end line)>
        self.all_state_vars_code = dict()   # state variables pure code
        self.func_to_assign_exps = dict()   # mapping from function to their assignment expressions
        self.curr_func_assign_exps = set()  # current function assignment expresssions
        self.piv = set()                    # primary impact variables
        self.siv = set()                    # secondary impact variables
        self.tiv = set()                    # teritiary impact variables
        self.qiv = []                       # quaternary impact variables
        self.impact_var_slots = set()       # all the impact variable storage slots
        self.implementation_storage = set() #
        self.upgrade_funcs1  = set()        # all upgrade function found for the current relevant delegatecall
        self.upgrade_funcs2  = set()        # all upgrade funciton found for the current relevant delegatecall
        self.all_impact_vars = list()       # all impact variables by name
    
    def prepare_env(self):
        self.fallback_func_found_sw = 0     # sw when a fallback function is found
        self.proxy_found_sw = 0             # sw when a relevant delegate is found
        self.storage_found_sw = 0           # sw that shows if the storage is found in the decompiled bytecode
        self.curr_func_loc = "global"       # current function loc
        self.is_piv_of_type_mapping = False # is primary impact variable of type mapping
        self.state_vars = []                # all parsed state variables
        self.delegatecalls = []             # all relevant delegatecalls
        self.all_assign_exps = dict()       # all assignment expressions
        self.all_func_sigs = dict()         # all function signature <loc, sig>
        self.all_func_start_end_line = dict() # all function start line and end line <sig, (start line, end line)>
        self.all_state_vars_code = dict()   # state variables pure code
        self.func_to_assign_exps = dict()   # mapping from function to their assignment expressions
        self.curr_func_assign_exps = set()  # current function assignment expresssions
        self.piv = set()                    # primary impact variables
        self.siv = set()                    # secondary impact variables
        self.tiv = set()                    # teritiary impact variables
        self.qiv = []                       # quaternary impact variables
        self.impact_var_slots = set()       # all the impact variable storage slots
        self.implementation_storage = set() #
        self.upgrade_funcs1  = set()        # all upgrade function found for the current relevant delegatecall
        self.upgrade_funcs2  = set()        # all upgrade funciton found for the current relevant delegatecall
        self.all_impact_vars = list()       # all impact variables by name

    def parse_assignment_exp(self, curr_loc, curr_line):
        # check if the current line is an assingment, then store it
        ar = re.search(self.ASSIGNMENT_REGEX,curr_line)
        if ar:
            self.all_assign_exps[curr_loc] = ar
            self.curr_func_assign_exps.add(ar) # add current assingment to the list of assignments for the current function

    def parse_function_exp(self, curr_loc, curr_line):
        # check if the current line is a function, then store it
        afs = re.search(self.FUNC_SIG_REGEX, curr_line)
        if afs:
            self.all_func_sigs[curr_loc] = afs # store the function with its line
            self.all_func_sigs = dict(sorted(self.all_func_sigs.items()))  # sort function definitions in ascending fashion in terms of line number
            self.curr_func_loc = curr_loc # update the current function 
            self.curr_func_assign_exps = set() # reset the list of function's assignment when observe a new function
            self.all_func_start_end_line[afs.group(0)] = (curr_loc,0)
    def map_assignments_to_function(self, curr_loc, curr_line):
        if curr_line == "\n" and self.curr_func_loc != "global": # then it means the previous function is finished
            self.func_to_assign_exps[self.curr_func_loc] = self.curr_func_assign_exps # now save the list of assignment for the function
            curr_func_start_end_line = self.all_func_start_end_line[self.all_func_sigs[self.curr_func_loc].group(0)] # get the current function start and end line
            self.all_func_start_end_line[self.all_func_sigs[self.curr_func_loc].group(0)] = (curr_func_start_end_line[0], curr_loc) # update the end line number
    def fix_storage_variable_name(self, state_var_chunks):
        var_name = state_var_chunks[0] 
        # fix long storage variable issues
        if len(var_name) >= 40:
            var_slot = self.get_variable_slot(state_var_chunks)
            if var_slot.find("0x") >= 0:
                var_name = "stor" + var_slot[2:6].upper()
            state_var_chunks[0] = var_name
        return state_var_chunks

    def get_variable_slot(self, state_var_chunks):
        if state_var_chunks[-2].find("offset") < 0:
            return state_var_chunks[-1].strip()
        else:
            return state_var_chunks[-3].strip()  

    def parse_contract_storage(self, curr_line):
        # Note that a storage function always locates at the begining of the file
        # if a "storage function" is not found yet
        if self.storage_found_sw == 0:
            # check if the current line is the storage function
            sfr = re.search(self.STORAGE_FUNC_REGEX, curr_line)
            # if yes, turn on storage_sw_found
            if sfr:
                self.storage_found_sw = 1
        # if storage function is found and the current line is neither a function or an empty line (i.e., signs that indicate end of the storage func)
        elif self.storage_found_sw == 1 and not re.search(self.FUNC_REGEX, curr_line) and curr_line != "\n":
            # fetch the storage variable and prune some irrelevant keywords
            state_var = re.sub("\s\s+", " ",curr_line.replace(' is ', ' ').replace(' at ', ' ').replace(' storage ', ' ').replace("\n", ""))
            # split it into chunks
            state_var_chunks = state_var.split() 
            # fix the var name if it is long. decompiler issue
            state_var_chunks = self.fix_storage_variable_name(state_var_chunks)
            # store the variable in the list of all storage varialbes
            self.state_vars.append(state_var_chunks)    
            # get the variable name
            state_var_name = state_var_chunks[0]
            # if there are several variables with similar name, keep the one that is an address, otherwise the most recent one.
            if state_var_name in self.all_state_vars_code:
                if self.all_state_vars_code[state_var_name].find(" addr ")>=0:
                    pass
                else:
                    self.all_state_vars_code[state_var_name] = curr_line.rstrip().lstrip().replace(state_var.split()[0], state_var_name)
            else:
                self.all_state_vars_code[state_var_name] = curr_line.rstrip().lstrip().replace(state_var.split()[0], state_var_name)   
        # else a new function is started, or the storage function is ended by the new line, 
        # therefore we need to stop looking for a storage variables
        else:
            self.storage_found_sw = 2

    def find_delegatecall_function(self, delegatecall_loc):
        for func_loc in reversed(sorted(self.all_func_sigs.keys())):
            if func_loc < delegatecall_loc:
                return func_loc, self.all_func_sigs[func_loc]
        raise("DELEGATECALL FUNCTION NOT FOUND!")

    def parse_proxy_function(self, curr_loc, curr_line):
        if curr_line.find('delegate ') >= 0 and curr_line.split()[1].strip() != "0x0":
            self.proxy_found_sw = 1
            self.delegatecalls.append((curr_loc, curr_line.replace('\n', '').strip(), self.curr_func_loc, self.all_func_sigs[self.curr_func_loc].group(0)))
        elif self.proxy_found_sw == 1:
            # get function where the last delegatecall is defined in
            _, delegtecall_func = self.find_delegatecall_function(self.delegatecalls[-1][0]) 
            # extract the function signature
            func_sig = delegtecall_func.group(1)
            func_sig = func_sig[len("unknown"):] if func_sig.find("unknown") >=0 else func_sig
            # the last OR supports cases where the target function is hardcoded yet having a similar signitature to delegatecall's function one
            if curr_line.find("call.data[0 len 4]") >= 0 or curr_line.find("call.data[return_data.size len 4]")>= 0 or self.delegatecalls[-1][1].find(".0x" + func_sig)>=0:   
                self.proxy_found_sw = 0      
            # if similar interface checks fails then this is not a relevant delegatecall and we removeit from our list of relevant delegatecalls.
            else:
                self.delegatecalls.pop()
                self.proxy_found_sw = 0

    def exclude_forwarder_delegatecalls(self, bytecode):
        # =====================================
        # check for any possible forwarders. it can be the case that there are several delegatecalls
        # within a fallback. among those we need to detect the forwarders and exclude them.
        # for each delegatecall we check if the address to which delegatecall is sent is appeard
        # the from_address bytecode
        # =======================================================================================================
        forwarders_index = []
        for idx, delegatecall in enumerate(self.delegatecalls):
            if bytecode.find(delegatecall[1].split()[1][2:]) >= 0:
                forwarders_index.append(idx) 
        self.delegatecalls = [delegatecall for idx, delegatecall in enumerate(self.delegatecalls) if idx not in forwarders_index]

    def exclude_duplicate_delegatecalls(self):
        # remove duplicate delegatecalls
        seen = set()
        temp_delegatecalls = []
        for delegatecall in self.delegatecalls:
            if str(delegatecall[1]) in seen:
                continue
            else:
                seen.add(str(delegatecall[1]))
                temp_delegatecalls.append(delegatecall)       
        self.delegatecalls = temp_delegatecalls
        
    def set_detector_type(self, curr_delegatecall):
        # if implementation address are obtained via an external call, there is a sign that this contract
        # is implementing an UUPS pattern.        
        if curr_delegatecall[1].find('ext_call')>=0:
            return "ESUP"
        elif curr_delegatecall[1].find('mem[')>=0:
            # check to see if mem = ext_call occurs in the delegatecall function
            sw = False
            for assig in self.func_to_assign_exps[self.find_delegatecall_function(curr_delegatecall[0])[0]]:
                if assig.group(1).find("mem[") >=0 and assig.group(3).find("ext_call") >=0:
                    sw = True
                    break
            if sw:
                return "ESUP"
            else:
                return "Keep Checking"

    def get_variable_slot(self, var_chunks):
        if var_chunks[-2].find("offset") < 0:
            return var_chunks[-1].strip()
        else:
            return var_chunks[-3].strip()  

    def detect_primary_impact_variables(self, curr_delegatecall):
        # if the impl variable is coming from a memory and the memory is initalized by a state variable in prior lines, then state variable is primary impact variable
        # because if one can change the state variable then the memory would change and subsequently the imp address
        if curr_delegatecall[1].find('delegate mem[')>=0:
            for assig in self.func_to_assign_exps[self.find_delegatecall_function(curr_delegatecall[0])[0]]:
                if assig.group(1).find("mem[") >=0:
                    for var in self.state_vars:
                        var_name = var[0].strip()
                        if assig.group(3).find(var_name) >=0:
                            self.piv.add(var_name)
                            self.impact_var_slots.add(self.get_variable_slot(var))
                            self.implementation_storage.add(self.all_state_vars_code[var_name])
                        if self.all_state_vars_code[var_name].find(" mapping ")>=0:
                            self.is_piv_of_type_mapping = True
        else:
            for var in self.state_vars:
                # the following two condition ensures that the implementation variable fetched out of delegatecall matches storage variables correctly
                var_name = var[0].strip()
                if curr_delegatecall[1].find(var_name)>=0 and \
                    curr_delegatecall[1][curr_delegatecall[1].find(var_name) + len(var_name)] in [" ", ".", '[', '(', ']', ')']:
                    if self.DEBUG_MODE:
                        print("\"{}\"".format(curr_delegatecall[1][curr_delegatecall[1].find(var_name) + len(var_name)]))
                    self.piv.add(var_name)
                    self.impact_var_slots.add(self.get_variable_slot(var))
                    self.implementation_storage.add(self.all_state_vars_code[var_name])
                    if self.all_state_vars_code[var_name].find(" mapping ")>=0:
                        self.is_piv_of_type_mapping = True
    
    def detect_quaternary_impact_variables(self, curr_delegatecall):
        # =====================================
        # if the delegatecall address is not in the storage variable at all. there are two possibilities
        # either i) the delegate call address is comming from external entity, ii) is stored via hash function 
        # (i.e., sha3, kaack)
        # in anycase, we check if the delegatecall address is upgradeable within the proxy contract. if not, the pattern
        # could be beacon or uups
        # =======================================================================================================        
        if len(self.piv) == 0:
            hbssr = re.search(self.HASH_BASED_STORAGE_SELECTOR_REGEX, curr_delegatecall[1].split()[1])
            if hbssr:
                self.qiv.append(hbssr.group(1))
            # elif curr_delegatecall[1].split()[1] != "0x0":
            #     self.qiv.append(curr_delegatecall[1].split()[1])
            
                self.implementation_storage.add(curr_delegatecall[1].split()[1])

    def detect_secondary_impact_variables(self):
        # =====================================
        # all storage variables with different name but same storage slot as the primary impact variable
        # to the delegatecall address must be examined cause they affect the implementation slot.
        # here we do not care if they have similar types to the primary one.
        # =======================================================================================================            
        for var in self.state_vars:
            var_slot = self.get_variable_slot(var)
            if var[0].strip() not in self.piv and var_slot in self.impact_var_slots:
                self.siv.add(var[0].strip())
    
    def detect_teritiary_impact_variables(self):
        # =====================================
        # for primary + secondary and quaternary variables we examine to all assignments to get the assigned 
        # value for those if any exist
        # =======================================================================================================        
        for impact_var in list(set(self.piv.union(self.siv).union(self.qiv))):
            for loc in self.all_assign_exps.keys():
                # check if the impact variable locates on the left side of the assignment operand
                if re.sub("\s\s+", " ", self.all_assign_exps[loc].group(1)).strip().find(impact_var) >= 0:
                    # get the value of the assihgment
                    right_operand = re.sub("\s\s+", " ", self.all_assign_exps[loc].group(3)).strip()
                    # if value is not numeric
                    if not right_operand.isnumeric():
                        # then we check if assigned value is stored in the storage. 
                        # cause otherwise, it means the assinged value is either and external call or is one of function parameters.
                        # and we do not need to keep them cause later on we evaulate all assingments again
                        for state_var in self.state_vars:
                            if right_operand.find(state_var[0].strip())>=0: # the assigned value to the primary imp must be itself in the storage to be included in tertiary imps
                                self.tiv.add(state_var[0].strip())
                                tiv_slot = self.get_variable_slot(state_var)
                                self.impact_var_slots.add(tiv_slot)
                                # lets double check one more time if there is any other variable with similar slot but different name 
                                # to the just found tertiary one and add it to the tertiary list
                                for var in self.state_vars:
                                    var_slot = self.get_variable_slot(var)
                                    if var_slot == tiv_slot:
                                        self.tiv.add(var[0].strip())
                                break

    def is_right_operand_parameterized(self, right_operand_loc, right_operand_value):
        # if right operand value is numeric return with false as it cannot be parameterized.
        if right_operand_value.isnumeric():
            return False,""
        # get functions loc
        functions_loc = list(self.all_func_sigs.keys())
        # sort functions in ascending order of the loc
        functions_loc.sort()
    
        for idx, func_loc in enumerate(functions_loc):
            # if the current_func_loc is larger than the right_operand_loc we continue with the next function
            if func_loc > right_operand_loc:
                continue
            # otherwise, if right_operand_loc is after the current_function_loc
            else:
                # if there is one more next function in the list
                if idx+1 < len(functions_loc):
                    # and if this next_function_loc do not starts the right_operand_loc then we need to continue with the next
                    # function as this current function is not upgrade function yet
                    if functions_loc[idx+1] < right_operand_loc:
                        continue
                # Otherwise, the current function is the upgrade/setter function
                # get its signature  
                func_sig = self.all_func_sigs[func_loc]
                # parse its parameters
                func_params = func_sig.group(2)
                # if there is any param
                if func_params != None and func_params != '()':
                    # remove ()
                    func_params = func_params.replace('(', '').replace(')', '').strip()
                    # split by ","
                    func_params = func_params.split(',')
                    # for each param
                    for param in func_params:
                        # get its identifier
                        param_identifier = param.strip().split(' ')[-1]
                        # check if the right operand is similar to the param
                        if right_operand_value.find(param_identifier.strip()) >= 0 or right_operand_value.find("create.") >=0:
                            if self.DEBUG_MODE:
                                print(right_operand_value, right_operand_loc,"==>", func_loc, func_sig.group(0), param_identifier)
                            return True, func_sig.group(0)
                        elif right_operand_value.find("ext_call") >=0:
                            return False, func_sig.group(0)
                # if param is empty or none
                else:
                    # function with parameters of type calldata do not have parameter in the decompiled bytecode, and instead cd[], or call. stands for call data, represents that one of the function parameters are being used
                    if right_operand_value.find("cd[") >=0 or right_operand_value.find("call.func_hash") >=0: 
                        return True, func_sig.group(0)
                    else:
                        return False, func_sig.group(0)
        return False, ""

    def detect_upgrade_functions(self):
        # =====================================
        # for all impact variables, we again recheck the assingments to get any other variable that is assigned to 
        # input variables. we then check if the assigned value is coming from the method parameters or external call.
        # in the former case, we know that this is a smup. 
        # =======================================================================================================  
        for impact_var in self.all_impact_variables:
            for loc in self.all_assign_exps.keys():
                if re.sub("\s\s+", " ", self.all_assign_exps[loc].group(1)).strip().find(impact_var) >= 0:
                    right_operand_value = re.sub("\s\s+", " ", self.all_assign_exps[loc].group(3)).strip()
                    res, func_sig = self.is_right_operand_parameterized(loc, right_operand_value)
                    if res:
                        self.upgrade_funcs1.add(func_sig)
                    # Example: proxy:0x633c4861a4e9522353eda0bb652878b079fb75fd.sol
                    # imp: 0x9b3be0cc5dd26fd0254088d03d8206792715588b
                    # this means if the assingment to the implementation variable is via external call 
                    # then it means that the external entity overrides/update proxies storage;
                    # yet still the upgrade function is implemented inside the proxy; thus, smup. This is very rare though
                    elif not res and right_operand_value.find("ext_call") >=0:
                        self.upgrade_funcs2.add(func_sig)

    @abstractmethod
    def parse_decompiled_bytecode(self, decompiled_bytecode_lines):
        pass 

    @abstractmethod
    def is_upc(self, address, decompiled_bytecode, bytecode):
        pass