In [None]:
# Reload edited Python modules (such as the local `ddx` sources) automatically
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
!pip uninstall -y psycopg2 psycopg2-binary
!pip install psycopg2-binary

In [None]:
import copy
import os
import sys
from typing import Dict

import numpy as np
import pandas as pd
import psycopg2
import simplejson as json
from web3.auto import w3

pd.options.plotting.backend = "plotly"

# Put the parent directory on sys.path *before* any `ddx` import, so that every
# `ddx` import below (including `ddx._rust`) can resolve the package from there.
# Appending keeps an installed `ddx`, if any, at higher priority.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from ddx._rust.common.state import Item, ItemKind
from ddx._rust.decimal import Decimal
from ddx.auditor.auditor_driver import AuditorDriver
from ddx.common.epoch_params import EpochParams
from ddx.common.trade_mining_params import TradeMiningParams

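# Auditor replay

Load a verified-state snapshot and the transaction log from the operator's Postgres database, then hand them to `AuditorDriver` for replay.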

## Postgres

### Connection

In [None]:
# SERVER connection settings for the operator Postgres database, taken from the
# environment (the PG_JUPYTER_READONLY_* variables supply the credentials).
# A missing variable raises KeyError, checked in the order listed below.
db_name, db_host, db_password, db_port, db_username = (
    os.environ[env_var]
    for env_var in (
        'OPERATOR_DB_NAME',
        'PGHOST',
        'PG_JUPYTER_READONLY_PW',
        'PGPORT',
        'PG_JUPYTER_READONLY_ROLE',
    )
)

In [None]:
# Connect to the operator database. The connection is intentionally left open:
# later cells pass `connection` to `pd.read_sql_query`. It is closed here only
# if setup fails part-way, and the original error is then re-raised.
connection = None
try:
    connection = psycopg2.connect(user=db_username,
                                  password=db_password,
                                  host=db_host,
                                  port=db_port,
                                  database=db_name)

    # Cursor for the sanity-check query below (left open, as before)
    cursor = connection.cursor()
    # Print PostgreSQL details (get_dsn_parameters() omits the password)
    print("PostgreSQL server information")
    print(connection.get_dsn_parameters(), "\n")
    # Confirm the session works by asking the server for its version
    cursor.execute("SELECT version();")
    record = cursor.fetchone()
    print("You are connected to - ", record, "\n")
except Exception:
    # `connection` is pre-bound to None, so a failing connect() surfaces its
    # own error instead of a NameError from cleanup code.
    if connection is not None:
        connection.close()
    raise

### SQL library

In [None]:
# All table names visible through information_schema
get_table_names_sql = """SELECT table_name FROM information_schema.tables"""

# The entire tx log table
get_tx_log_table_sql = """SELECT * FROM tx_log"""


def get_tx_log_table_during_epoch_id_sql(epoch_id):
    """Return SQL selecting the tx log rows whose epoch_id equals `epoch_id`.

    `epoch_id` is coerced with `int()` before interpolation, so a value that is
    not an integer (e.g. a string carrying extra SQL) raises ValueError or
    TypeError instead of being spliced into the query.
    """
    return f"""SELECT * FROM tx_log WHERE epoch_id = {int(epoch_id)}"""


def get_tx_log_table_after_epoch_id_sql(epoch_id):
    """Return SQL selecting the tx log rows with epoch_id >= `epoch_id` (inclusive).

    `epoch_id` is coerced with `int()` before interpolation (see above).
    """
    return f"""SELECT * FROM tx_log WHERE epoch_id >= {int(epoch_id)}"""


# The entire request queue (request log)
get_request_log_table_sql = """SELECT * FROM request.queue"""


def get_request_log_table_before_request_index_sql(request_index):
    """Return SQL selecting request.queue rows with request_index <= `request_index`.

    The bound is inclusive. `request_index` is coerced with `int()` before
    interpolation, so non-integer input raises instead of reaching the SQL.
    """
    return f"""SELECT * FROM request.queue WHERE request_index <= {int(request_index)}"""


def get_request_log_table_after_request_index_sql(request_index):
    """Return SQL selecting request.queue rows with request_index >= `request_index`.

    The bound is inclusive. `request_index` is coerced with `int()` before
    interpolation, so non-integer input raises instead of reaching the SQL.
    """
    return f"""SELECT * FROM request.queue WHERE request_index >= {int(request_index)}"""

# Get rows of a table up to and including a given (epoch_id, tx_ordinal)
def get_table_before_max_tx_sql(table_name: str, epoch_id: int, tx_ordinal: int):
    """Return SQL selecting rows of `table_name` at or before a tx position.

    A row matches when its epoch_id is strictly less than `epoch_id`, or equal
    to it with tx_ordinal <= `tx_ordinal`.

    Raises:
        ValueError: if `table_name` is not a plain or dot-qualified identifier
            (e.g. ``tx_log`` or ``verified_state.versions``), or if `epoch_id`
            / `tx_ordinal` cannot be converted with `int()`. This keeps
            arbitrary text from being spliced into the query.
    """
    if not all(part.isidentifier() for part in table_name.split('.')):
        raise ValueError(f"invalid table name: {table_name!r}")
    epoch_id = int(epoch_id)
    tx_ordinal = int(tx_ordinal)
    return f"""SELECT * FROM {table_name} WHERE (epoch_id < {epoch_id}) OR (epoch_id = {epoch_id} AND tx_ordinal <= {tx_ordinal})"""

# Get state snapshot table for epoch id
def get_state_snapshot_table_for_epoch_id_sql(epoch_id):
    """Return SQL for the verified state as of `epoch_id`.

    For every leaf key, the query picks the row of `verified_state.versions`
    with the greatest epoch_id that is <= `epoch_id`. It then joins that row to
    `verified_state.items` on leaf_hash. Result columns: leaf_key, abi_schema,
    leaf.

    `epoch_id` is coerced with `int()` before interpolation, so a value that is
    not an integer raises ValueError or TypeError instead of being spliced into
    the query.
    """
    epoch_id = int(epoch_id)
    return f"""WITH epoch_ver AS (
      SELECT leaf_key, max(epoch_id) as epoch_id
        FROM verified_state.versions
      WHERE epoch_id <= {epoch_id}
      GROUP BY leaf_key
    )
    SELECT m.leaf_key, s.abi_schema, s.leaf
      FROM verified_state.versions m
    INNER JOIN epoch_ver e ON e.leaf_key = m.leaf_key AND e.epoch_id = m.epoch_id
    INNER JOIN verified_state.items s ON s.leaf_hash = m.leaf_hash;"""

In [None]:
# First epoch whose tx log entries are loaded.
epoch_id = 5

# Starting state: for every leaf, its latest version at or before the previous
# epoch (epoch_id - 1).
state_snapshot_df = pd.read_sql_query(
    get_state_snapshot_table_for_epoch_id_sql(epoch_id - 1),
    connection,
)

# Every tx log row from `epoch_id` onwards (inclusive).
tx_log_df = pd.read_sql_query(
    get_tx_log_table_after_epoch_id_sql(epoch_id),
    con=connection,
)

def to_camel_case(snake_str):
    """Convert a snake_case name to camelCase.

    The first underscore-separated word is kept verbatim. Every later word is
    passed through `str.title()`, which capitalizes its first letter and
    lowercases the rest. Example: 'state_root_hash' -> 'stateRootHash'.
    """
    head, *tail = snake_str.split('_')
    return head + ''.join(word.title() for word in tail)

# Render the raw bytea state root hashes as 0x-prefixed hex strings. NULLs
# (None) are kept as None instead of crashing on `.hex()`.
tx_log_df['state_root_hash'] = tx_log_df['state_root_hash'].apply(
    lambda root_hash: None if root_hash is None else f'0x{root_hash.hex()}'
)

# camelCase the column names, then order rows by request index and, within a
# request, by tx ordinal. Reassigning avoids `inplace=True`.
tx_log_df = (
    tx_log_df
    .rename(columns=to_camel_case)
    .sort_values(['requestIndex', 'txOrdinal'])
)

# Leaf key -> ABI-encoded leaf, both as 0x-prefixed hex strings. Columns are
# selected by name, so the mapping does not depend on the query's column order.
state_snapshot_dict = {
    f'0x{leaf_key.hex()}': f'0x{leaf.hex()}'
    for leaf_key, leaf in zip(state_snapshot_df['leaf_key'], state_snapshot_df['leaf'])
}

tx_log = tx_log_df.to_dict(orient='records')

# Last expression, so Jupyter actually displays it (mid-cell it was discarded).
tx_log_df.tail()

In [None]:
# Spot-check decoding of one hard-coded state snapshot leaf (key/value as hex).
state_snapshot_key = '0x0200aa6676733e2e259d879127519d973ac71135f7972576ebd1000000000000'
state_snapshot_value = '0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000014000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000d8ca39d276c05a370174ce872b7278fcf734afac000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000005af3107a40000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
state_snapshot_key = bytes.fromhex(state_snapshot_key[2:])

# The leaf key's first byte is the item discriminant, which says what kind of
# leaf this is. Indexing a bytes object yields that byte as an int directly.
item_discriminant = ItemKind(state_snapshot_key[0])

# Decode the ABI-encoded leaf value according to its item kind.
item = Item.abi_decode_value_into_item(
    item_discriminant,
    bytes.fromhex(state_snapshot_value[2:]),
)

item_discriminant, item

In [None]:
# Build an AuditorDriver and seed it with a state snapshot for `epoch_id`.
# NOTE(review): every argument is positional and undocumented here (an endpoint
# URL, None, EpochParams(...), TradeMiningParams(...), a list-of-lists string,
# and 'geth'). Confirm each one's meaning in ddx.auditor.auditor_driver /
# ddx.common before changing any of these numbers.
auditor_driver = AuditorDriver(
    'http://af1.derivadex.io',
    None,
    EpochParams(
        300,
        30,
        3000,
        9000,
        3000,
        3000,
        100,
    ),
    TradeMiningParams(
        1000,
        (Decimal("35_000_000") / (10 * 365 * 3)).recorded_amount(),
        Decimal("0.2"),
    ),
    '[[1000000, 10000], [0, 10000000]]',
    'geth',
)

# Presumably clears any internal driver state before loading the snapshot.
# TODO confirm in AuditorDriver._reset.
auditor_driver._reset()

# FIXME(review): `chunks` is not defined in any cell above, so this line raises
# NameError on a fresh kernel. It depends on a variable from a deleted or
# out-of-order cell. `state_snapshot_dict` (built earlier from the epoch_id - 1
# snapshot query) may be the intended input. TODO confirm what
# `process_state_snapshot` expects before switching.
auditor_driver.process_state_snapshot(epoch_id, chunks[51])

# Disabled: replay the loaded tx log on top of the snapshot and read back the
# resulting SMT root hash.

# # Initialize the current local state root hash to the SMT's root
# # hash after having loaded the state snapshot
# auditor_driver.current_state_root_hash = f'0x{auditor_driver.smt.root().as_bytes().hex()}'
# auditor_driver.current_batch_state_root_hash = auditor_driver.current_state_root_hash

# auditor_driver.latest_batch_id = tx_log[0]["batchId"]

# for tx_log_event in tx_log:
#     auditor_driver.process_tx_log_event(tx_log_event, True)

# final_root_hash = f'0x{auditor_driver.smt.root().as_bytes().hex()}'