In [34]:
%run ./BaseUPCDetector.ipynb
%run ./Proxy.ipynb

from abc import ABC, abstractmethod
from overrides import overrides
import copy
from web3 import Web3

In [35]:
class DUPDetector(BaseUPCDetector):

    def __init__(self, impact_var_slots, proxy_detector, bytecode_decompiler, provider_url=None):
        """Create a detector bound to one proxy's impact-variable slots.

        Args:
            impact_var_slots: Iterable of storage-slot identifiers found in the
                proxy contract; stored de-duplicated as a set.
            proxy_detector: Object whose ``is_proxy(address)`` is used by
                ``is_upc`` to follow nested proxies.
            bytecode_decompiler: Object used by ``is_upc`` to decompile the
                implementation contracts it recurses into.
            provider_url: Optional HTTP JSON-RPC endpoint for the Web3 client.
                When omitted, the historical default endpoint is used so
                existing callers keep working.
        """
        super().__init__()
        self.impact_var_slots = set(impact_var_slots)
        self.proxy_detector = proxy_detector
        self.bytecode_decompiler = bytecode_decompiler
        # address -> list of addresses visited on the way to that address
        self.trace = dict()
        if provider_url is None:
            # SECURITY NOTE(review): this default embeds a provider project key
            # in source control. Prefer passing ``provider_url`` explicitly and
            # rotate/remove this key once callers have migrated.
            provider_url = 'https://mainnet.infura.io/v3/bfc43c4acd6b4b15af7d607977d54c8c'
        self.w3 = Web3(Web3.HTTPProvider(provider_url))

    def prepare_env(self):
        """Reset every piece of per-contract parsing state to its empty value.

        Called at the start of ``parse_decompiled_bytecode`` so that results
        from a previously analysed contract never leak into the next one.
        """
        # --- switches -------------------------------------------------------
        self.storage_found_sw = 0            # shows whether the storage section was found in the decompiled bytecode
        self.proxy_found_sw = 0              # set when a relevant delegate is found
        self.fallback_func_found_sw = 0      # set when a fallback function is found
        self.is_piv_of_type_mapping = False

        # --- parsing cursor -------------------------------------------------
        self.curr_func_loc = "global"        # location of the function currently being parsed
        self.curr_func_assign_exps = set()   # assignment expressions of the current function

        # --- everything collected while parsing -----------------------------
        self.state_vars = []                 # all parsed state variables
        self.delegatecalls = []              # all relevant delegatecalls
        self.all_state_vars_code = {}        # state variables pure code
        self.all_assign_exps = {}            # all assignment expressions
        self.all_func_sigs = {}              # <loc, sig>
        self.all_func_start_end_line = {}    # <sig, (start line, end line)>
        self.func_to_assign_exps = {}        # function -> its assignment expressions

        # --- impact variables -----------------------------------------------
        self.piv = set()                     # primary
        self.siv = set()                     # secondary
        self.tiv = set()                     # tertiary
        self.qiv = []                        # quaternary
        self.all_impact_vars = []            # all impact variables by name
        self.implementation_storage = set()  # slots matched in the implementation's storage

        # --- detected upgrade functions (two separate result sets) ----------
        self.upgrade_funcs1 = set()
        self.upgrade_funcs2 = set()
    
    def locate_impact_variable_in_implementation_storage(self):
        """Match parsed state variables against the proxy's impact-variable slots.

        For every state variable, its slot (from ``get_variable_slot``) is
        compared with each proxy slot. When both slot strings are longer than
        five characters (i.e. neither is a short numeric slot) a substring
        match is enough; otherwise the two must be exactly equal. On a match
        the variable name (first element of the state-variable record) is added
        to ``self.piv`` and its slot to ``self.implementation_storage``.
        """
        for var in self.state_vars:
            var_slot = self.get_variable_slot(var)
            for proxy_slot in self.impact_var_slots:
                if len(var_slot) > 5 and len(proxy_slot) > 5:
                    # both look like non-numeric slots: containment suffices
                    matched = var_slot.find(proxy_slot) != -1
                else:
                    # at least one is a short (numeric) slot: require identity
                    matched = var_slot == proxy_slot
                if not matched:
                    continue
                self.piv.add(var[0].strip())
                self.implementation_storage.add(var_slot.strip())
    
    @overrides
    def parse_decompiled_bytecode(self, decompiled_bytecode_lines):
        """Parse decompiled bytecode line by line, stopping early when hopeless.

        State is reset first, then each line is fed to the assignment,
        function, assignment-to-function and storage parsers. The first time
        ``self.storage_found_sw`` equals 2, the proxy's impact-variable slots
        are looked up in the implementation storage exactly once; if none is
        present, parsing is abandoned because nothing further can match.

        Args:
            decompiled_bytecode_lines: Iterable yielding the decompiled source
                one line at a time.
        """
        self.prepare_env()
        slots_checked = False
        for line_no, line in enumerate(decompiled_bytecode_lines):
            self.parse_assignment_exp(line_no, line)
            self.parse_function_exp(line_no, line)
            self.map_assignments_to_function(line_no, line)
            self.parse_contract_storage(line)

            # Nothing to do until the storage has been parsed, and the slot
            # lookup must run only once per contract.
            if slots_checked or self.storage_found_sw != 2:
                continue

            slots_checked = True
            self.locate_impact_variable_in_implementation_storage()
            if not self.piv:
                # None of the proxy's slots exist here: skip the remaining lines.
                break

    
    def is_diamond_upc(self, address):
        """Ask the contract's ``facetAddress`` view which facet serves ``diamondCut``.

        Args:
            address: Address of the contract to query.

        Returns:
            The lower-cased facet address returned by the call, or ``False``
            when the returned address contains ``"0x00000000000000"`` or when
            any step (checksum conversion, contract call, ...) raises.
        """
        try:
            checksummed = self.w3.to_checksum_address(address)
            # Minimal ABI: only the ``facetAddress`` view is needed for the probe.
            abi = """
            [
                {
                    "constant": true,
                    "inputs": [{"name":"_functionSelector","type":"bytes4"}],
                    "name": "facetAddress",
                    "outputs": [{"name":"facetAddress_","type":"address"}],
                    "payable": false,
                    "stateMutability": "view",
                    "type": "function"
                }
            ]
            """
            loupe = self.w3.eth.contract(address=checksummed, abi=abi)
            # The 4-byte selector is the first four bytes of the signature hash.
            selector = self.w3.keccak(text="diamondCut((address,uint8,bytes4[])[],address,bytes)")[:4]
            facet = loupe.functions.facetAddress(selector).call()
            if facet.find("0x00000000000000") != -1:
                return False
            return facet.lower()
        except Exception:
            # Best-effort probe: any failure is treated as "not a diamond".
            return False
    
    @overrides
    def is_upc(self, address, decompiled_bytecode, bytecode):
        """Look for a DUP-style upgrade function in *address* or its nested implementations.

        The decompiled code is parsed; if any of the proxy's impact-variable
        slots is present, secondary/tertiary impact variables and upgrade
        functions are detected. When no upgrade function is found here and
        *address* is itself a proxy, each of its implementation contracts is
        decompiled and examined recursively.

        Args:
            address: Address of the contract being examined.
            decompiled_bytecode: Iterable of decompiled source lines.
            bytecode: Raw bytecode of the contract (not used by this method).

        Returns:
            ``(trace, label, upgrade_funcs, slots)`` on success, where *trace*
            is the list of addresses followed to reach the contract, *label* is
            ``'UPC:DUP:1'`` (both upgrade-function sets non-empty),
            ``'UPC:DUP:2'`` (only the first) or ``'UPC:DUP:3'`` (only the
            second), *upgrade_funcs* is a list of
            ``(function, start/end-line entry)`` pairs and *slots* is the list
            of proxy impact-variable slots. ``None`` when nothing is found.
        """
        if address not in self.trace:
            self.trace[address] = [address]

        self.parse_decompiled_bytecode(decompiled_bytecode)

        if self.piv:
            self.detect_secondary_impact_variables()
            self.detect_teritiary_impact_variables()
            # NOTE(review): prepare_env resets ``all_impact_vars`` while this
            # assigns ``all_impact_variables`` - confirm which name the
            # inherited detection helpers actually read.
            self.all_impact_variables = list(set(self.piv.union(self.siv).union(self.tiv).union(self.qiv)))
            if self.DEBUG_MODE:
                print('primary:', self.piv) 
                print('secondary:',self.siv)
                print('tertiary:', self.tiv)
                print('quaternary:', self.qiv)
                print('storage:',self.impact_var_slots)
                print('all imps:',self.all_impact_variables)
                print()
            self.detect_upgrade_functions()

            # An upgrade function inside the implementation means a DUP style;
            # the label records which of the two result sets were populated.
            if self.upgrade_funcs1 and self.upgrade_funcs2:
                label = 'UPC:DUP:1'
            elif self.upgrade_funcs1:
                label = 'UPC:DUP:2'
            elif self.upgrade_funcs2:
                label = "UPC:DUP:3"
            else:
                label = None

            if label is not None:
                found_funcs = list(self.upgrade_funcs1) + list(self.upgrade_funcs2)
                upgrade_func_to_line = [(func, self.all_func_start_end_line[func]) for func in found_funcs]
                return (self.trace[address], label, upgrade_func_to_line, list(self.impact_var_slots))
            # NOTE: the diamond check (is_diamond_upc / "UPC:DUP:4") is
            # currently disabled.

        # Reached when no impact variable matched OR when none of the matched
        # variables led to an upgrade function. The implementation may itself
        # be a proxy whose own implementation carries the upgrade function.
        # (Previously an ``elif`` made the second case unreachable.)
        proxy_info = self.proxy_detector.is_proxy(address)  # queried once, reused below
        if not proxy_info[0]:
            return None

        for imp in proxy_info[1]:
            if imp == address:
                continue
            # Extend the path that led to ``address`` with this implementation.
            self.trace[imp] = copy.deepcopy(self.trace[address]) + [imp]
            try:
                # Use the injected decompiler rather than a notebook-level global.
                decompiler = self.bytecode_decompiler
                decompiled_bytecode_path = decompiler.decompile_contract(imp)
                if decompiled_bytecode_path.find("Failure") >= 0:
                    continue
                print("\trecursively calling is_upc function for proxy's {}->implementation contract {}".format(address, imp))
                imp_bytecode = decompiler.distinct_bytecodes_hash[decompiler.contracts_bytecodes_hash[imp]]
                # Close the decompiled file as soon as the nested analysis ends.
                with open(decompiled_bytecode_path, 'r') as decompiled_file:
                    res = self.is_upc(imp, decompiled_file, imp_bytecode)
                if res:
                    return res
            except Exception as e:
                # Best effort: a failing implementation must not stop the others.
                print('DUP Detector Error', e)
                continue
        return None