In [103]:
%run ./BaseUPCDetector.ipynb
%run ./Proxy.ipynb

from abc import ABC, abstractmethod
from overrides import overrides
import copy

In [105]:
class ESUPDetector(BaseUPCDetector):

    def __init__(self, imp_get_selector, proxy_detector, bytecode_decompiler):
        """Create the detector with its collaborators.

        Args:
            imp_get_selector: selector of the implementation-getter function;
                compared against ``compute_function_selector`` output.
            proxy_detector: object exposing ``is_proxy(address)``, used when the
                getter is not found.
            bytecode_decompiler: object used to decompile implementation
                contracts during the recursive check in ``is_upc``.
        """
        super().__init__()
        self.bytecode_decompiler = bytecode_decompiler
        self.proxy_detector = proxy_detector
        self.imp_get_selector = imp_get_selector
        # address -> list of addresses followed to reach it; never reset here
        self.trace = {}
        # function header text -> set of stripped "return" lines of that function
        self.func_to_return_exps = {}
        # "return" lines gathered for the function currently being parsed
        self.curr_fnuc_return_exps = set()

    def prepare_env(self):
        """Reset per-parse state: the base-class state plus the return-expression maps."""
        super().prepare_env()
        self.curr_fnuc_return_exps = set()
        self.func_to_return_exps = {}
    
    def parse_return_exp(self, curr_line):
        """Record *curr_line* as a return expression of the function being parsed.

        Any line containing the substring ``"return "`` counts; it is stored
        with surrounding whitespace stripped. Other lines are ignored.
        """
        if "return " not in curr_line:
            return
        self.curr_fnuc_return_exps.add(curr_line.strip())
    
    def map_returns_to_function(self, curr_loc, curr_line):
        """On a bare blank line, save the current function's return set under its header.

        ``curr_loc`` is not used by this method. The set is stored by reference,
        so return lines added later for the same function still show up in it.
        """
        # A line equal to "\n" is treated as the end of the current function.
        if curr_line != "\n":
            return
        # Nothing to save while still outside any function definition.
        if self.curr_func_loc == "global":
            return
        header = self.all_func_sigs[self.curr_func_loc].group(0)
        self.func_to_return_exps[header] = self.curr_fnuc_return_exps
    
    def parse_function_exp(self, curr_loc, curr_line):
        """Start tracking a new function if *curr_line* matches ``FUNC_SIG_REGEX``.

        On a match, the match object is stored under the line number, the
        current-function pointer moves to *curr_loc*, the per-function
        assignment/return sets are reset, and ``(curr_loc, 0)`` is recorded as
        the function's start/end line pair.
        """
        sig_match = re.search(self.FUNC_SIG_REGEX, curr_line)
        if not sig_match:
            return
        header = sig_match.group(0)
        self.all_func_sigs[curr_loc] = sig_match
        # keep definitions ordered by line number
        self.all_func_sigs = dict(sorted(self.all_func_sigs.items()))
        self.curr_func_loc = curr_loc
        # a new function begins: start empty per-function collections
        self.curr_func_assign_exps = set()
        self.curr_fnuc_return_exps = set()
        # end line is written as 0 here; presumably updated elsewhere — confirm
        self.all_func_start_end_line[header] = (curr_loc, 0)
    @overrides
    def parse_decompiled_bytecode(self, decompiled_bytecode_lines):
        """Run every per-line parser over the decompiled bytecode.

        Args:
            decompiled_bytecode_lines: iterable of decompiled text lines (``is_upc``
                passes an open file in its recursive call).

        Parser state is reset via ``prepare_env`` first. ``map_returns_to_function``
        only saves a function's return set when a bare ``"\\n"`` line follows it,
        so after the loop the final function is flushed explicitly. Otherwise its
        returns would be lost whenever the input does not end with a blank line.
        """
        self.prepare_env()
        for curr_loc, curr_line in enumerate(decompiled_bytecode_lines):
            self.parse_assignment_exp(curr_loc, curr_line)
            self.parse_function_exp(curr_loc, curr_line)
            self.parse_return_exp(curr_line)
            self.map_assignments_to_function(curr_loc, curr_line)
            self.map_returns_to_function(curr_loc, curr_line)
            self.parse_contract_storage(curr_line)

        # Flush the last function's return expressions (idempotent if a trailing
        # blank line already stored them). getattr/membership guards protect
        # against curr_func_loc not being initialised or being stale.
        # NOTE(review): map_assignments_to_function may have the same
        # end-of-input limitation — confirm in the base class.
        last_func_loc = getattr(self, "curr_func_loc", "global")
        if last_func_loc != "global" and last_func_loc in self.all_func_sigs:
            last_header = self.all_func_sigs[last_func_loc].group(0)
            self.func_to_return_exps[last_header] = self.curr_fnuc_return_exps
    def compute_function_signature(self, func):
        """Reduce a decompiled function header to a canonical ``name(type,...)`` string.

        Example: ``"def transfer(address _to, uint256 _value):"`` becomes
        ``"transfer(address,uint256)"``. Text without ``"("`` is returned
        unchanged. Every parameter is whitespace-trimmed and reduced to its first
        space-separated token (its type). Single- and multi-parameter headers are
        handled the same way.

        Args:
            func: function header text.

        Returns:
            The canonical signature string.
        """
        if "(" not in func:
            return func
        func = func.replace("def ", "").strip()
        # drop everything from the first ":" on (body marker / trailing comment)
        func = func.split(":")[0]
        name, _, rest = func.partition("(")
        name = name.strip()
        # parameter text: up to the next "(" with any ")" removed
        params_text = rest.split("(")[0].replace(")", "")
        if not params_text:
            return name + "()"
        param_types = [param.strip().split(" ")[0] for param in params_text.split(",")]
        return name + "(" + ",".join(param_types) + ")"

    def compute_function_selector(self, func_sig):
        """Return the 4-byte selector of *func_sig* as a ``"0x"``-prefixed hex string.

        Signatures named ``unknownXXXXXXXX(...)`` (presumably decompiler
        placeholders for unresolved functions) carry the selector in their name.
        For those, the 8 characters after ``"unknown"`` are returned. Otherwise
        the selector is the first 4 bytes of the Keccak-256 hash of *func_sig*.
        """
        if func_sig.startswith("unknown"):
            return "0x" + func_sig[len("unknown"): len("unknown") + 8].strip()
        digest = Web3.keccak(text=func_sig)
        # HexBytes.hex() includes "0x" in hexbytes < 1.0 but not in 1.x;
        # bytes.hex() never does, so add the prefix explicitly. Both branches
        # then return the same format.
        return "0x" + bytes(digest[:4]).hex()
        
    def find_imp_get_function(self):
        """Find the parsed function whose selector equals ``self.imp_get_selector``.

        Returns:
            ``(function_header, return_expressions)`` for the first match in
            ``func_to_return_exps`` iteration order, or ``(None, None)``.
        """
        for header, return_exps in self.func_to_return_exps.items():
            signature = self.compute_function_signature(header)
            if self.compute_function_selector(signature) == self.imp_get_selector:
                return header, return_exps
        return None, None

    def detect_primary_impact_variables(self, curr_delegatecall):
        """Record state variables referenced by *curr_delegatecall* as primary impact variables.

        Args:
            curr_delegatecall: a line of decompiled code (``is_upc`` passes the
                getter's return expressions).

        For each ``var`` in ``self.state_vars`` whose stripped name ``var[0]``
        occurs as a whole identifier, this method:
          - adds the name to ``self.piv``,
          - adds its slot to ``self.impact_var_slots``,
          - adds its declaration text to ``self.implementation_storage``,
          - sets ``self.is_piv_of_type_mapping`` if that text contains
            ``" mapping "``.

        Whole-identifier matching prevents e.g. ``stor1`` from matching inside
        ``stor10``, which plain substring search did.
        """
        import re  # local: in this file `re` otherwise only comes from the %run notebooks

        for var in self.state_vars:
            var_name = var[0].strip()
            if not var_name:
                # an empty name would "match" any expression
                continue
            # identifier boundaries: not preceded/followed by a word char or "$"
            pattern = r"(?<![\w$])" + re.escape(var_name) + r"(?![\w$])"
            if re.search(pattern, curr_delegatecall) is None:
                continue
            self.piv.add(var_name)
            self.impact_var_slots.add(self.get_variable_slot(var))
            declaration = self.all_state_vars_code[var_name]
            self.implementation_storage.add(declaration)
            if declaration.find(" mapping ") >= 0:
                self.is_piv_of_type_mapping = True
    
    @overrides
    def is_upc(self, address, decompiled_bytecode, bytecode):
        """Detect the ESUP pattern in the contract at *address*.

        The contract (called the "beacon" / "external contract" in the messages
        below) is parsed from *decompiled_bytecode*. The implementation getter is
        the function whose selector equals ``self.imp_get_selector``. If there is
        no such function:
          - if the contract is itself a proxy, each implementation reported by
            ``proxy_detector`` is decompiled and checked recursively, and the
            trace is extended;
          - otherwise the return expressions of ``_fallback()`` are used.

        Variables read by those return expressions seed the impact-variable
        analysis. The upgrade functions found then choose the label:
          - ``'UPC:ESUP:1'``: both upgrade_funcs1 and upgrade_funcs2 are non-empty;
          - ``'UPC:ESUP:2'``: only upgrade_funcs1 is non-empty;
          - ``'UPC:ESUP:3'``: only upgrade_funcs2 is non-empty.

        Args:
            address: contract address; also the key into ``self.trace``.
            decompiled_bytecode: iterable of decompiled text lines (the recursive
                call passes an open file).
            bytecode: not referenced by the detection branches.

        Returns:
            On detection, a tuple with these elements (collections as lists):
              - trace;
              - label;
              - ``[(upgrade_func, start_end_line), ...]``;
              - impact_var_slots;
              - state_vars;
              - implementation_storage;
              - qiv.
            When nothing is detected, no such tuple is returned.
        """
        if address not in self.trace:
            self.trace[address] = [address]
        self.parse_decompiled_bytecode(decompiled_bytecode)        
        imp_get_func = self.find_imp_get_function()
        # if there is at least one get function in the beacon
        if imp_get_func[0] is not None:
            if len(imp_get_func[1]) > 0:
                for return_exp in imp_get_func[1]:
                    self.detect_primary_impact_variables(return_exp)
                if len(self.piv) > 0:
                    self.detect_secondary_impact_variables()
                    self.detect_teritiary_impact_variables()
                    self.all_impact_variables = list(set(self.piv.union(self.siv).union(self.tiv).union(self.qiv)))
                    self.detect_upgrade_functions()

                    if len(self.upgrade_funcs1) > 0 and len(self.upgrade_funcs2) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs1) + list(self.upgrade_funcs2)]
                        return (self.trace[address], 'UPC:ESUP:1', upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))
                    elif len(self.upgrade_funcs1) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs1)]
                        return (self.trace[address], 'UPC:ESUP:2', upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))
                    elif len(self.upgrade_funcs2) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs2)]
                        return (self.trace[address], "UPC:ESUP:3", upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))
                    else: 
                        print('ESUP DETECTOR: imp get function found but no upgrade functions in the external contract')
                else:
                    print('ESUP DETECTOR: imp get function found but the corresponding no impact variables are found')
            # The getter has no return statement: not expected in practice, kept
            # as an internal sanity check.
            else:
                print("ESUP DETECTOR: imp get function found but no return statements are found")
        
        # No getter in this contract: if the beacon is itself a proxy, run the
        # detector recursively on each of its implementation contracts.
        elif self.proxy_detector.is_proxy(address)[0]:
            for imp in self.proxy_detector.is_proxy(address)[1]:
                if imp == address:
                    continue  # never recurse into the same contract
                if address in self.trace:
                    _trace = copy.deepcopy(self.trace[address])
                    _trace.append(imp)
                    self.trace[imp] = _trace
                try:
                    # Use the decompiler injected in __init__, not a global name.
                    decompiler = self.bytecode_decompiler
                    decompiled_bytecode_path = decompiler.decompile_contract(imp)
                    print(decompiled_bytecode_path)
                    if decompiled_bytecode_path.find("Failure") < 0:
                        print("ESUP detector recursive call!")
                        imp_bytecode = decompiler.distinct_bytecodes_hash[decompiler.contracts_bytecodes_hash[imp]]
                        # The recursive call consumes the file in
                        # parse_decompiled_bytecode, so close it afterwards.
                        with open(decompiled_bytecode_path, 'r') as decompiled_file:
                            res = self.is_upc(imp, decompiled_file, imp_bytecode)
                        if res:
                            return res
                    
                except Exception as e:
                    # best effort: report and move on to the next implementation
                    print('ESUP Detector Error', e)
                    continue
        
        # if beacon is not a proxy then its fallback function is the get function
        elif not self.proxy_detector.is_proxy(address)[0]:
            fallback_return_exp = []
            for func in self.func_to_return_exps.keys():
                if func.find('_fallback()') >=0:
                    fallback_return_exp = self.func_to_return_exps[func]
                    break

            if len(fallback_return_exp) > 0:
                for return_exp in fallback_return_exp:
                    self.detect_primary_impact_variables(return_exp)
                if len(self.piv) > 0:
                    self.detect_secondary_impact_variables()
                    self.detect_teritiary_impact_variables()   
                    self.all_impact_variables = list(set(self.piv.union(self.siv).union(self.tiv).union(self.qiv)))
                    self.detect_upgrade_functions()
                    if len(self.upgrade_funcs1) > 0 and len(self.upgrade_funcs2) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs1) + list(self.upgrade_funcs2)]
                        return (self.trace[address], 'UPC:ESUP:1', upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))
                    elif len(self.upgrade_funcs1) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs1)]
                        return (self.trace[address], 'UPC:ESUP:2', upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))
                    elif len(self.upgrade_funcs2) > 0:
                        upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in list(self.upgrade_funcs2)]
                        return (self.trace[address], "UPC:ESUP:3", upgrade_func_to_line, list(self.impact_var_slots), list(self.state_vars), list(self.implementation_storage), list(self.qiv))