In [1]:
import hashlib

# EVM opcode byte values, listed in ascending numeric order and grouped by
# the high nibble of the opcode.

# 0x00-0x0B: halt and arithmetic
STOP = 0x00
ADD = 0x01
MUL = 0x02
SUB = 0x03
DIV = 0x04
SDIV = 0x05
MOD = 0x06
SMOD = 0x07
ADDMOD = 0x08
MULMOD = 0x09
EXP = 0x0A
SIGNEXTEND = 0x0B

# 0x10-0x1D: comparison, bitwise logic and shifts
LT = 0x10
GT = 0x11
SLT = 0x12
SGT = 0x13
EQ = 0x14
ISZERO = 0x15
AND = 0x16
OR = 0x17
XOR = 0x18
NOT = 0x19
BYTE = 0x1A
SHL = 0x1B
SHR = 0x1C
SAR = 0x1D

# 0x20: hashing
SHA3 = 0x20

# 0x30-0x3F: execution environment (accounts, calldata, code, return data)
ADDRESS = 0x30
BALANCE = 0x31
ORIGIN = 0x32
CALLER = 0x33
CALLVALUE = 0x34
CALLDATALOAD = 0x35
CALLDATASIZE = 0x36
CALLDATACOPY = 0x37
CODESIZE = 0x38
CODECOPY = 0x39
GASPRICE = 0x3A
EXTCODESIZE = 0x3B
EXTCODECOPY = 0x3C
RETURNDATASIZE = 0x3D
RETURNDATACOPY = 0x3E
EXTCODEHASH = 0x3F

# 0x40-0x48: block information
BLOCKHASH = 0x40
COINBASE = 0x41
TIMESTAMP = 0x42
NUMBER = 0x43
PREVRANDAO = 0x44
GASLIMIT = 0x45
CHAINID = 0x46
SELFBALANCE = 0x47
BASEFEE = 0x48

# 0x50-0x5B: stack, memory, storage and control flow
POP = 0x50
MLOAD = 0x51
MSTORE = 0x52
MSTORE8 = 0x53
SLOAD = 0x54
SSTORE = 0x55
JUMP = 0x56
JUMPI = 0x57
PC = 0x58
MSIZE = 0x59
JUMPDEST = 0x5B

# 0x5F-0x9F: push / dup / swap families (only the range endpoints are named;
# the interpreter derives the operand size or depth from the offset)
PUSH0 = 0x5F
PUSH1 = 0x60
PUSH32 = 0x7F
DUP1 = 0x80
DUP16 = 0x8F
SWAP1 = 0x90
SWAP16 = 0x9F

# 0xA0-0xA4: logging (the low nibble is the topic count)
LOG0 = 0xA0
LOG1 = 0xA1
LOG2 = 0xA2
LOG3 = 0xA3
LOG4 = 0xA4

# 0xF3: return
RETURN = 0xF3

# In-memory account table keyed by lowercase, 0x-prefixed hex address string.
# Each record carries the account's balance, nonce, storage mapping and code.
account_db = {
    '0x9bbfed6889322e016e0a02ee459d306fc19545d8': dict(
        balance=100,  # wei
        nonce=1,
        storage={},
        # Sample bytecode: PUSH1 0x00 PUSH1 0x00
        code=bytes([0x60, 0x00, 0x60, 0x00]),
    ),
}

class Transaction:
    """Plain record of the transaction fields the interpreter reads.

    All arguments are stored unchanged as attributes of the same name.
    The interpreter's calldata opcodes slice ``data[2:]`` and parse it as
    hex, i.e. they expect a ``0x``-prefixed hex string.
    """

    def __init__(
        self,
        to='',
        value=0,
        data='',
        caller='0x00',
        origin='0x00',
        thisAddr='0x00',
        gasPrice=1,
        gasLimit=21000,
        nonce=0,
        v=0,
        r=0,
        s=0,
    ):
        # Call context: who is involved and what is being sent.
        self.to, self.value, self.data = to, value, data
        self.caller, self.origin, self.thisAddr = caller, origin, thisAddr
        # Fee and ordering fields.
        self.nonce, self.gasPrice, self.gasLimit = nonce, gasPrice, gasLimit
        # Signature components.
        self.v, self.r, self.s = v, r, s

class StopException(Exception):
    """Marker exception type with no behavior beyond ``Exception``.

    NOTE(review): nothing in the visible code raises or catches this;
    presumably intended to signal that execution should halt -- confirm.
    """

class Log:
    """A single log record: emitting address, data payload and topics.

    Args:
        address: Address the log is attributed to.
        data: Payload of the log entry.
        topics: Optional list of topics. When omitted, each instance gets
            its own new empty list.
    """

    def __init__(self, address, data, topics=None):
        self.address = address
        self.data = data
        # A literal ``[]`` default is created once and shared by every call
        # that omits ``topics``; build a fresh list per instance instead.
        self.topics = [] if topics is None else topics

    def __str__(self):
        return f'Log(address={self.address}, data={self.data}, topics={self.topics})'

class EVM:
    """Minimal interpreter for a subset of Ethereum Virtual Machine bytecode.

    State kept per instance: ``code`` and the program counter ``pc``, a word
    ``stack`` (top of stack is the END of the list), byte-addressable
    ``memory``, a ``storage`` dict, emitted ``logs``, a ``returnData`` buffer
    and a hard-coded sample ``current_block``.  Gas is not metered.

    Stack words are unsigned 256-bit integers; the signed opcodes (SDIV,
    SMOD, SLT, SGT, SAR, SIGNEXTEND) reinterpret them as two's complement.
    For binary opcodes the top of stack is the LEFT operand (``a op b`` with
    ``a`` popped first).

    Handlers raise a plain ``Exception('Stack underflow')`` when the stack
    holds too few operands.  ``run`` resolves opcodes through the
    module-level opcode constants, and the account opcodes read the
    module-level ``account_db``.
    """

    _UINT256_MOD = 1 << 256           # words wrap modulo 2**256
    _UINT256_MAX = (1 << 256) - 1
    _SIGN_BIT = 1 << 255              # set => negative in two's complement
    _ADDRESS_MASK = (1 << 160) - 1    # addresses are the low 20 bytes of a word

    def __init__(self, code, txn = None):
        self.code = code
        self.pc = 0
        self.stack = []
        self.memory = bytearray()
        self.storage = {}
        self.txn = txn
        self.logs = []
        self.returnData = bytearray()
        # Fixed sample block context served by the block-information opcodes.
        self.current_block = {
            "blockhash": 0x7527123fc877fe753b3122dc592671b4902ebf2b325dd2c7224a43c0cbeee3ca,
            "coinbase": 0x388C818CA8B9251b393131C08a736A67ccB19297,
            "timestamp": 1625900000,
            "number": 17871709,
            "prevrandao": 0xce124dee50136f3f93f19667fb4198c6b94eecbacfa300469e5280012757be94,
            "gaslimit": 30,
            "chainid": 1,
            "selfbalance": 100,
            "basefee": 30,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_stack(self, count):
        """Raise ``Exception('Stack underflow')`` unless ``count`` items exist."""
        if len(self.stack) < count:
            raise Exception('Stack underflow')

    @classmethod
    def _to_signed(cls, value):
        """Reinterpret an unsigned 256-bit word as a two's-complement int."""
        return value - cls._UINT256_MOD if value & cls._SIGN_BIT else value

    @classmethod
    def _to_unsigned(cls, value):
        """Wrap a (possibly negative) Python int into an unsigned 256-bit word."""
        return value & cls._UINT256_MAX

    @classmethod
    def _address_key(cls, addr_int):
        """Format a stack word as the ``0x``-prefixed hex key used by ``account_db``.

        Only the low 160 bits are used, so oversized words no longer make
        ``to_bytes(20, ...)`` raise ``OverflowError``.
        """
        return '0x' + (addr_int & cls._ADDRESS_MASK).to_bytes(20, byteorder='big').hex()

    def _expand_memory(self, size):
        """Zero-extend ``memory`` so that it is at least ``size`` bytes long."""
        missing = size - len(self.memory)
        if missing > 0:
            self.memory.extend(bytes(missing))

    def _write_memory(self, mem_offset, payload, length):
        """Write exactly ``length`` bytes at ``mem_offset``, zero-padding ``payload``.

        Padding keeps the slice assignment length-preserving: assigning a
        shorter payload to a longer slice would shrink the bytearray.
        """
        self._expand_memory(mem_offset + length)
        self.memory[mem_offset:mem_offset + length] = bytes(payload).ljust(length, b'\x00')

    def _jump_to(self, destination):
        """Set ``pc`` to ``destination`` if it is in range and holds a JUMPDEST."""
        if destination >= len(self.code) or self.code[destination] != JUMPDEST:
            raise Exception('Invalid jump destination')
        self.pc = destination

    # ------------------------------------------------------------------
    # Fetch / basic stack access
    # ------------------------------------------------------------------

    def next_instruction(self):
        """Return the byte at ``pc`` and advance ``pc`` by one."""
        op = self.code[self.pc]
        self.pc += 1
        return op

    def push(self, size):
        """Push the next ``size`` code bytes as one big-endian integer."""
        data = self.code[self.pc:self.pc + size]
        value = int.from_bytes(data, 'big')
        self.stack.append(value)
        self.pc += size

    def pop(self):
        """Pop and return the top of stack; raises on an empty stack."""
        if len(self.stack) == 0:
            raise Exception('Stack underflow')
        return self.stack.pop()

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def add(self):
        """ADD: push ``(a + b) mod 2**256``."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append((a + b) % self._UINT256_MOD)

    def mul(self):
        """MUL: push ``(a * b) mod 2**256``."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append((a * b) % self._UINT256_MOD)

    def sub(self):
        """SUB: push ``(a - b) mod 2**256`` with ``a`` the top of stack."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append((a - b) % self._UINT256_MOD)

    def div(self):
        """DIV: push unsigned ``a // b``, or 0 when the divisor ``b`` is 0."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        # The guard must be on the divisor; guarding the dividend let a zero
        # divisor reach ``//`` and raise ZeroDivisionError.
        self.stack.append(a // b if b != 0 else 0)

    def sdiv(self):
        """SDIV: signed division truncating toward zero; 0 when dividing by 0."""
        self._require_stack(2)
        a = self._to_signed(self.stack.pop())
        b = self._to_signed(self.stack.pop())
        if b == 0:
            res = 0
        else:
            # Python's // floors; divide magnitudes to truncate toward zero.
            quotient = abs(a) // abs(b)
            if (a < 0) != (b < 0):
                quotient = -quotient
            res = self._to_unsigned(quotient)
        self.stack.append(res)

    def mod(self):
        """MOD: push unsigned ``a % b``, or 0 when the modulus ``b`` is 0."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(a % b if b != 0 else 0)

    def smod(self):
        """SMOD: signed remainder taking the sign of ``a``; 0 when ``b`` is 0."""
        self._require_stack(2)
        a = self._to_signed(self.stack.pop())
        b = self._to_signed(self.stack.pop())
        if b == 0:
            res = 0
        else:
            remainder = abs(a) % abs(b)
            if a < 0:
                remainder = -remainder
            res = self._to_unsigned(remainder)
        self.stack.append(res)

    def addmod(self):
        """ADDMOD: push ``(a + b) % n`` without 256-bit wrap; 0 when ``n`` is 0."""
        self._require_stack(3)
        a = self.stack.pop()
        b = self.stack.pop()
        n = self.stack.pop()
        self.stack.append((a + b) % n if n != 0 else 0)

    def mulmod(self):
        """MULMOD: push ``(a * b) % n`` without 256-bit wrap; 0 when ``n`` is 0."""
        self._require_stack(3)
        a = self.stack.pop()
        b = self.stack.pop()
        n = self.stack.pop()
        self.stack.append((a * b) % n if n != 0 else 0)

    def exp(self):
        """EXP: push ``a ** b mod 2**256``."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        # Three-argument pow reduces as it goes; ``pow(a, b) % m`` would first
        # materialise the full power, which is infeasible for large exponents.
        self.stack.append(pow(a, b, self._UINT256_MOD))

    def signextend(self):
        """SIGNEXTEND: sign-extend ``x`` from byte index ``b`` (0 = lowest byte)."""
        self._require_stack(2)
        b = self.stack.pop()
        x = self.stack.pop()
        if b < 31:
            # The sign bit is the top bit of byte ``b``.
            sign_bit = 1 << (8 * b + 7)
            if x & sign_bit:
                # Set every bit from the sign bit upward (stays within 256 bits).
                x |= self._UINT256_MOD - sign_bit
            else:
                # Clear everything above the sign bit.
                x &= sign_bit - 1
        self.stack.append(x)

    # ------------------------------------------------------------------
    # Comparison and bitwise logic
    # ------------------------------------------------------------------

    def lt(self):
        """LT: push 1 if ``a < b`` (unsigned, ``a`` = top of stack), else 0."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(int(a < b))

    def gt(self):
        """GT: push 1 if ``a > b`` (unsigned, ``a`` = top of stack), else 0."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(int(a > b))

    def slt(self):
        """SLT: push 1 if ``a < b`` as two's-complement values, else 0."""
        self._require_stack(2)
        a = self._to_signed(self.stack.pop())
        b = self._to_signed(self.stack.pop())
        self.stack.append(int(a < b))

    def sgt(self):
        """SGT: push 1 if ``a > b`` as two's-complement values, else 0."""
        self._require_stack(2)
        a = self._to_signed(self.stack.pop())
        b = self._to_signed(self.stack.pop())
        self.stack.append(int(a > b))

    def eq(self):
        """EQ: push 1 if the two popped words are equal, else 0."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(int(a == b))

    def iszero(self):
        """ISZERO: push 1 if the popped word is 0, else 0."""
        self._require_stack(1)
        a = self.stack.pop()
        self.stack.append(int(a == 0))

    def and_op(self):
        """AND: push the bitwise AND of the two popped words."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(a & b)

    def or_op(self):
        """OR: push the bitwise OR of the two popped words."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(a | b)

    def xor_op(self):
        """XOR: push the bitwise XOR of the two popped words."""
        self._require_stack(2)
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.append(a ^ b)

    def not_op(self):
        """NOT: push the 256-bit bitwise complement of the popped word."""
        self._require_stack(1)
        a = self.stack.pop()
        self.stack.append(~a % self._UINT256_MOD)

    def byte_op(self):
        """BYTE: push byte ``position`` of ``value`` (0 = most significant)."""
        self._require_stack(2)
        position = self.stack.pop()
        value = self.stack.pop()
        if position >= 32:
            res = 0
        else:
            res = (value >> (8 * (31 - position))) & 0xFF
        self.stack.append(res)

    def shl(self):
        """SHL: push ``value << shift`` truncated to 256 bits (shift = top)."""
        self._require_stack(2)
        shift = self.stack.pop()
        value = self.stack.pop()
        # Shifts of 256+ always clear the word; short-circuit so a huge shift
        # count never builds an enormous intermediate integer.
        res = 0 if shift >= 256 else (value << shift) & self._UINT256_MAX
        self.stack.append(res)

    def shr(self):
        """SHR: push the logical right shift ``value >> shift`` (shift = top)."""
        self._require_stack(2)
        shift = self.stack.pop()
        value = self.stack.pop()
        self.stack.append(0 if shift >= 256 else value >> shift)

    def sar(self):
        """SAR: push the arithmetic (sign-preserving) right shift of ``value``."""
        self._require_stack(2)
        shift = self.stack.pop()
        value = self._to_signed(self.stack.pop())
        if shift >= 256:
            # Everything is shifted out: all ones for negatives, zero otherwise.
            res = self._UINT256_MAX if value < 0 else 0
        else:
            # Python's >> on a negative int is already arithmetic.
            res = self._to_unsigned(value >> shift)
        self.stack.append(res)

    # ------------------------------------------------------------------
    # Memory and storage
    # ------------------------------------------------------------------

    def mstore(self):
        """MSTORE: write ``value`` as a 32-byte big-endian word at ``offset``."""
        self._require_stack(2)
        offset = self.stack.pop()
        value = self.stack.pop()
        self._expand_memory(offset + 32)
        self.memory[offset:offset+32] = value.to_bytes(32, 'big')

    def mstore8(self):
        """MSTORE8: write the low byte of ``value`` at ``offset``."""
        self._require_stack(2)
        offset = self.stack.pop()
        value = self.stack.pop()
        # NOTE(review): memory is grown to offset + 32 although only one byte
        # is written; kept as-is because MSIZE and memory dumps observe it.
        self._expand_memory(offset + 32)
        self.memory[offset] = value & 0xFF

    def mload(self):
        """MLOAD: push the 32-byte big-endian word at ``offset`` (grows memory)."""
        self._require_stack(1)
        offset = self.stack.pop()
        self._expand_memory(offset + 32)
        self.stack.append(int.from_bytes(self.memory[offset:offset+32], 'big'))

    def sload(self):
        """SLOAD: push ``storage[key]``, or 0 when the key was never written."""
        self._require_stack(1)
        key = self.stack.pop()
        self.stack.append(self.storage.get(key, 0))

    def sstore(self):
        """SSTORE: set ``storage[key] = value`` (key = top of stack)."""
        self._require_stack(2)
        key = self.stack.pop()
        value = self.stack.pop()
        self.storage[key] = value

    # ------------------------------------------------------------------
    # Control flow
    # ------------------------------------------------------------------

    def jump(self):
        """JUMP: move ``pc`` to the popped destination, which must be a JUMPDEST."""
        self._require_stack(1)
        self._jump_to(self.stack.pop())

    def jumpi(self):
        """JUMPI: jump to ``destination`` only when ``condition`` is non-zero."""
        self._require_stack(2)
        destination = self.stack.pop()
        condition = self.stack.pop()
        if condition != 0:
            self._jump_to(destination)

    def pc_op(self):
        """PC: push the offset of the PC instruction itself.

        Named ``pc_op`` because a method called ``pc`` is shadowed by the
        ``self.pc`` counter set in ``__init__`` and could never be called on
        an instance.  ``run`` has already advanced ``pc`` past the opcode,
        hence the ``- 1``.
        """
        self.stack.append(self.pc - 1)

    def msize(self):
        """MSIZE: push the current length of ``memory`` in bytes."""
        self.stack.append(len(self.memory))

    def jumpdest(self):
        """JUMPDEST: marker for valid jump targets; does nothing when executed."""
        pass

    # ------------------------------------------------------------------
    # Block information
    # ------------------------------------------------------------------

    def blockhash(self):
        """BLOCKHASH: push the sample hash for the current block number, else 0."""
        self._require_stack(1)
        number = self.stack.pop()
        if number == self.current_block["number"]:
            self.stack.append(self.current_block["blockhash"])
        else:
            self.stack.append(0)

    def coinbase(self):
        """COINBASE: push ``current_block['coinbase']``."""
        self.stack.append(self.current_block["coinbase"])

    def timestamp(self):
        """TIMESTAMP: push ``current_block['timestamp']``."""
        self.stack.append(self.current_block["timestamp"])

    def number(self):
        """NUMBER: push ``current_block['number']``."""
        self.stack.append(self.current_block["number"])

    def prevrandao(self):
        """PREVRANDAO: push ``current_block['prevrandao']``."""
        self.stack.append(self.current_block["prevrandao"])

    def gaslimit(self):
        """GASLIMIT: push ``current_block['gaslimit']``."""
        self.stack.append(self.current_block["gaslimit"])

    def chainid(self):
        """CHAINID: push ``current_block['chainid']``."""
        self.stack.append(self.current_block["chainid"])

    def selfbalance(self):
        """SELFBALANCE: push ``current_block['selfbalance']``."""
        self.stack.append(self.current_block["selfbalance"])

    def basefee(self):
        """BASEFEE: push ``current_block['basefee']``."""
        self.stack.append(self.current_block["basefee"])

    # ------------------------------------------------------------------
    # Stack shuffling
    # ------------------------------------------------------------------

    def dup(self, position):
        """DUPn: push a copy of the ``position``-th item from the top (1 = top)."""
        self._require_stack(position)
        self.stack.append(self.stack[-position])

    def swap(self, position):
        """SWAPn: exchange the top item with the one ``position`` places below it."""
        self._require_stack(position + 1)
        idx1, idx2 = -1, -position - 1
        self.stack[idx1], self.stack[idx2] = self.stack[idx2], self.stack[idx1]

    # ------------------------------------------------------------------
    # Hashing and account queries
    # ------------------------------------------------------------------

    def sha3(self):
        """SHA3: push the hash of ``memory[offset:offset+size]``.

        NOTE(review): ``hashlib.sha3_256`` is the standardised FIPS 202
        SHA-3, whose digests differ from the Keccak-256 used by Ethereum.
        The standard library has no Keccak-256, so this is left unchanged.
        """
        self._require_stack(2)
        offset = self.pop()
        size = self.pop()
        data = self.memory[offset:offset+size]
        hash_value = int.from_bytes(hashlib.sha3_256(data).digest(), 'big')
        self.stack.append(hash_value)

    def balance(self):
        """BALANCE: push the ``account_db`` balance of the popped address, or 0."""
        self._require_stack(1)
        addr_str = self._address_key(self.stack.pop())
        self.stack.append(account_db.get(addr_str, {}).get('balance', 0))

    def extcodesize(self):
        """EXTCODESIZE: push the code length of the popped address, or 0."""
        self._require_stack(1)
        addr_str = self._address_key(self.stack.pop())
        self.stack.append(len(account_db.get(addr_str, {}).get('code', b'')))

    def extcodecopy(self):
        """EXTCODECOPY: copy ``length`` bytes of an account's code into memory.

        Bytes past the end of the code are written as zeros.
        """
        self._require_stack(4)
        addr_int = self.stack.pop()
        mem_offset = self.stack.pop()
        code_offset = self.stack.pop()
        length = self.stack.pop()
        addr_str = self._address_key(addr_int)
        code = account_db.get(addr_str, {}).get('code', b'')[code_offset:code_offset+length]
        self._write_memory(mem_offset, code, length)

    def extcodehash(self):
        """EXTCODEHASH: push the hash of the popped address's code.

        NOTE(review): uses ``hashlib.sha3_256`` (see ``sha3``), not Keccak-256.
        """
        self._require_stack(1)
        addr_str = self._address_key(self.stack.pop())
        code = account_db.get(addr_str, {}).get('code', b'')
        code_hash = int.from_bytes(hashlib.sha3_256(code).digest(), 'big')
        self.stack.append(code_hash)

    # ------------------------------------------------------------------
    # Transaction / call context (all require ``self.txn`` to be set)
    # ------------------------------------------------------------------

    def address(self):
        """ADDRESS: push ``txn.thisAddr`` exactly as stored on the transaction.

        NOTE(review): ``Transaction`` defaults this to a string, so a
        non-integer can end up on the stack -- confirm the intended type.
        """
        self.stack.append(self.txn.thisAddr)

    def origin(self):
        """ORIGIN: push ``txn.origin`` exactly as stored on the transaction."""
        self.stack.append(self.txn.origin)

    def caller(self):
        """CALLER: push ``txn.caller`` exactly as stored on the transaction."""
        self.stack.append(self.txn.caller)

    def callvalue(self):
        """CALLVALUE: push ``txn.value``."""
        self.stack.append(self.txn.value)

    def calldataload(self):
        """CALLDATALOAD: push 32 bytes of calldata from ``offset``, zero-padded."""
        self._require_stack(1)
        offset = self.stack.pop()
        # ``txn.data`` is expected to be a '0x'-prefixed hex string.
        calldata_bytes = bytes.fromhex(self.txn.data[2:])
        word = calldata_bytes[offset:offset + 32].ljust(32, b'\x00')
        self.stack.append(int.from_bytes(word, 'big'))

    def calldatasize(self):
        """CALLDATASIZE: push the calldata length in bytes."""
        # Assuming calldata is a hex string with a '0x' prefix
        size = (len(self.txn.data) - 2) // 2
        self.stack.append(size)

    def calldatacopy(self):
        """CALLDATACOPY: copy ``length`` calldata bytes into memory.

        Bytes past the end of the calldata are written as zeros rather than
        leaving whatever memory held before.
        """
        self._require_stack(3)
        mem_offset = self.stack.pop()
        calldata_offset = self.stack.pop()
        length = self.stack.pop()
        calldata_bytes = bytes.fromhex(self.txn.data[2:])  # Assuming it's prefixed with '0x'
        chunk = calldata_bytes[calldata_offset:calldata_offset + length]
        self._write_memory(mem_offset, chunk, length)

    def codesize(self):
        """CODESIZE: push the length of the code stored for ``txn.thisAddr``."""
        addr = self.txn.thisAddr
        self.stack.append(len(account_db.get(addr, {}).get('code', b'')))

    def codecopy(self):
        """CODECOPY: copy ``length`` bytes of ``txn.thisAddr``'s code into memory.

        Bytes past the end of the code are written as zeros.
        """
        self._require_stack(3)
        mem_offset = self.stack.pop()
        code_offset = self.stack.pop()
        length = self.stack.pop()
        code = account_db.get(self.txn.thisAddr, {}).get('code', b'')
        self._write_memory(mem_offset, code[code_offset:code_offset + length], length)

    def gasprice(self):
        """GASPRICE: push ``txn.gasPrice``."""
        self.stack.append(self.txn.gasPrice)

    # ------------------------------------------------------------------
    # Logging and return data
    # ------------------------------------------------------------------

    def log(self, num_topics):
        """LOGn: append a log dict built from memory and ``num_topics`` topics.

        The entry has keys ``address`` (``txn.thisAddr``), ``data`` (hex of
        the memory slice) and ``topics`` (64-digit ``0x`` hex strings).
        """
        self._require_stack(2 + num_topics)
        mem_offset = self.stack.pop()
        length = self.stack.pop()
        topics = [self.stack.pop() for _ in range(num_topics)]

        data = self.memory[mem_offset:mem_offset + length]
        log_entry = {
            "address": self.txn.thisAddr,
            "data": data.hex(),
            "topics": [f"0x{topic:064x}" for topic in topics]
        }
        self.logs.append(log_entry)

    def return_op(self):
        """RETURN: copy ``memory[mem_offset:mem_offset+length]`` into ``returnData``."""
        self._require_stack(2)
        mem_offset = self.stack.pop()
        length = self.stack.pop()
        self._expand_memory(mem_offset + length)
        self.returnData = self.memory[mem_offset:mem_offset + length]

    def returndatasize(self):
        """RETURNDATASIZE: push ``len(returnData)``."""
        self.stack.append(len(self.returnData))

    def returndatacopy(self):
        """RETURNDATACOPY: copy a slice of ``returnData`` into memory.

        Raises:
            Exception: ``'Invalid returndata size'`` when the requested slice
                extends past the end of ``returnData``.
        """
        self._require_stack(3)
        mem_offset = self.stack.pop()
        return_offset = self.stack.pop()
        length = self.stack.pop()

        if return_offset + length > len(self.returnData):
            raise Exception("Invalid returndata size")

        self._expand_memory(mem_offset + length)
        self.memory[mem_offset:mem_offset + length] = self.returnData[return_offset:return_offset + length]

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def _opcode_handlers(self):
        """Return a dict mapping each fixed single-byte opcode to its handler.

        Built at call time (not as a class attribute) because the opcode
        constants live at module level.
        """
        return {
            POP: self.pop,
            ADD: self.add,
            MUL: self.mul,
            SUB: self.sub,
            DIV: self.div,
            SDIV: self.sdiv,
            MOD: self.mod,
            SMOD: self.smod,
            ADDMOD: self.addmod,
            MULMOD: self.mulmod,
            EXP: self.exp,
            SIGNEXTEND: self.signextend,
            LT: self.lt,
            GT: self.gt,
            SLT: self.slt,
            SGT: self.sgt,
            EQ: self.eq,
            ISZERO: self.iszero,
            AND: self.and_op,
            OR: self.or_op,
            XOR: self.xor_op,
            NOT: self.not_op,
            BYTE: self.byte_op,
            SHL: self.shl,
            SHR: self.shr,
            SAR: self.sar,
            MLOAD: self.mload,
            MSTORE: self.mstore,
            MSTORE8: self.mstore8,
            SLOAD: self.sload,
            SSTORE: self.sstore,
            MSIZE: self.msize,
            JUMP: self.jump,
            JUMPDEST: self.jumpdest,
            JUMPI: self.jumpi,
            PC: self.pc_op,
            BLOCKHASH: self.blockhash,
            COINBASE: self.coinbase,
            TIMESTAMP: self.timestamp,
            NUMBER: self.number,
            PREVRANDAO: self.prevrandao,
            GASLIMIT: self.gaslimit,
            CHAINID: self.chainid,
            SELFBALANCE: self.selfbalance,
            BASEFEE: self.basefee,
            SHA3: self.sha3,
            BALANCE: self.balance,
            EXTCODESIZE: self.extcodesize,
            EXTCODECOPY: self.extcodecopy,
            EXTCODEHASH: self.extcodehash,
            ADDRESS: self.address,
            ORIGIN: self.origin,
            CALLER: self.caller,
            CALLVALUE: self.callvalue,
            CALLDATALOAD: self.calldataload,
            CALLDATASIZE: self.calldatasize,
            CALLDATACOPY: self.calldatacopy,
            CODESIZE: self.codesize,
            CODECOPY: self.codecopy,
            GASPRICE: self.gasprice,
            RETURNDATASIZE: self.returndatasize,
            RETURNDATACOPY: self.returndatacopy,
        }

    def run(self):
        """Execute ``code`` from the current ``pc`` until it halts.

        Execution ends when ``pc`` runs off the end of the code, on STOP
        (which also prints a message) or on RETURN.

        Raises:
            Exception: ``'Invalid opcode'`` for an unsupported byte, plus
                whatever the individual handlers raise.
        """
        handlers = self._opcode_handlers()
        while self.pc < len(self.code):
            op = self.next_instruction()

            # Opcode families whose operand is encoded in the opcode byte.
            if PUSH1 <= op <= PUSH32:
                self.push(op - PUSH1 + 1)
            elif op == PUSH0:
                self.stack.append(0)
            elif DUP1 <= op <= DUP16:
                self.dup(op - DUP1 + 1)
            elif SWAP1 <= op <= SWAP16:
                self.swap(op - SWAP1 + 1)
            elif LOG0 <= op <= LOG4:
                self.log(op - LOG0)
            # Halting opcodes.
            elif op == STOP:
                print('Program has been stopped')
                break
            elif op == RETURN:
                self.return_op()
                # RETURN ends execution; bytes after it must not run.
                break
            elif op in handlers:
                handlers[op]()
            else:
                raise Exception('Invalid opcode')

In [2]:
# RETURN
# Bytecode: PUSH1 0xa2, PUSH1 0x00, MSTORE, PUSH1 0x01, PUSH1 0x1f, RETURN.
# MSTORE writes 0xa2 as a 32-byte big-endian word at offset 0, so the byte
# itself lands at memory[31]; RETURN then copies 1 byte starting at offset
# 0x1f into evm.returnData.
code = b"\x60\xa2\x60\x00\x52\x60\x01\x60\x1f\xf3"
evm = EVM(code)
evm.run()
print(evm.returnData.hex())
# output: a2

a2


In [3]:
# RETURNDATASIZE
# Bytecode: a single RETURNDATASIZE (0x3D), which pushes len(evm.returnData).
# The return-data buffer is seeded by hand with two bytes before running.
code = b"\x3D"
evm = EVM(code)
evm.returnData = b"\xaa\xaa"
evm.run()
print(evm.stack)
# output: [2]

[2]


In [4]:
# RETURNDATACOPY
# Bytecode: PUSH1 0x02, PUSH0, PUSH0, RETURNDATACOPY.
# RETURNDATACOPY pops mem_offset=0, return_offset=0, length=2 and copies the
# two hand-seeded return-data bytes to the start of memory.
code = b"\x60\x02\x5F\x5F\x3E"
evm = EVM(code)
evm.returnData = b"\xaa\xaa"
evm.run()
print(evm.memory.hex())
# output: aaaa

aaaa
