In [7]:
import pippenger
import blst
import hashlib
from random import randint, shuffle
from poly_utils import PrimeField
from time import time
from kzg_utils import KzgUtils
from fft import fft
import sys


In [768]:
import blst
import hashlib
from poly_utils import PrimeField
from kzg_utils import KzgUtils
from fft import fft
from time import time
from random import randint, shuffle


# General functions

def int_to_bytes(x: int) -> bytes:
    return x.to_bytes(32, "little")


def int_from_bytes(x: bytes) -> int:
    return int.from_bytes(x, "little")


def hash(x):
    if isinstance(x, bytes):
        return hashlib.sha256(x).digest()
    elif isinstance(x, blst.P1):
        return hash(x.compress())
    b = b""
    for a in x:
        if isinstance(a, bytes):
            b += a
        elif isinstance(a, int):
            b += a.to_bytes(32, "little")
        elif isinstance(a, blst.P1):
            b += hash(a.compress())
    return hash(b)


def hash_to_int(x):
    return int.from_bytes(hash(x), "little")


class KzgIntegration:
    def __init__(self, modulus: int, width: int, primitive_root: int):
        self.modulus = modulus
        self.width = width
        assert pow(primitive_root, (modulus - 1) // width, modulus) != 1
        assert pow(primitive_root, modulus - 1, modulus) == 1
        self.root_of_unity = pow(primitive_root, (modulus - 1) // width, modulus)

    def generate_setup(self, size, secret):
        """
        Generates a setup in the G1 group and G2 group, as well as the Lagrange polynomials in G1 (via FFT)
        """
        g1_setup = [blst.G1().mult(pow(secret, i, self.modulus))
                    for i in range(size)]
        g2_setup = [blst.G2().mult(pow(secret, i, self.modulus))
                    for i in range(size)]
        g1_lagrange = fft(g1_setup, self.modulus, self.root_of_unity, inv=True)
        return {"g1": g1_setup, "g2": g2_setup, "g1_lagrange": g1_lagrange}

    def kzg_utils(self, setup: dict):
        primefield = PrimeField(self.modulus, self.width)
        domain = [pow(self.root_of_unity, i, self.modulus) for i in range(self.width)]
        return KzgUtils(self.modulus, self.width, domain, setup, primefield)


class VerkleBSTNode(object):
    def __init__(self, key: bytes, value: bytes):
        self.value = value
        self.key = key
        self.left = None
        self.right = None
        self.hash = None
        self.commitment = None

    def node_hash(self):
        if self.is_leaf():
            self.hash = hash([self.key, self.value])
        else:
            self.hash = hash([self.commitment.compress(), self.key, self.value])

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


class VerkleBST:
    def __init__(self, setup: dict, kzg: KzgUtils, root: VerkleBSTNode, modulus: int, width: int):
        self.setup = setup
        self.kzg = kzg
        self.root = root
        self.modulus = modulus
        self.width = width

    def _insert(self, node: VerkleBSTNode, key: bytes, value: bytes, update: bool = False):
        """
        Insert command for the tree
        """

        if node is None:
            return VerkleBSTNode(key, value)

        if key == node.key:
            if update:
                node.value = value
        elif key < node.key:
            node.left = self._insert(node.left, key, value)
        elif key > node.key:
            node.right = self._insert(node.right, key, value)
        return node

    def insert_node(self, key: bytes, value: bytes):
        """
        Insert a node into the tree
        """
        self.root = self._insert(self.root, key, value)

    def upsert_verkle_node(self, key: bytes, value: bytes):
        """
        Insert or update a node in the tree and update the hashes/commitments
        """

        root = self.root

        path = self.find_path_to_node(root, key)
        last_node = path[-1][1]

        # Insert
        if last_node is None:
            path.pop()
            self._insert(path[-1][1], key, value)
            new_node = self.find_node(path[-1][1], key)
            new_node.node_hash()
            path.append((None, new_node))
            value_change = int_from_bytes(new_node.hash) % self.modulus

        # Update
        elif last_node.key == key:
            old_hash = last_node.hash
            last_node.value = value
            last_node.node_hash()
            new_hash = last_node.hash
            value_change = (int_from_bytes(new_hash) - int_from_bytes(old_hash) + self.modulus) % self.modulus

        for edge, node in reversed(path):
            if edge is None:
                continue

            old_hash = node.hash
            if node.commitment is None:
                self.add_node_hash(node)
            else:
                node.commitment.add(self.setup["g1_lagrange"][edge].dup().mult(value_change))
                node.node_hash()
            new_hash = node.hash
            value_change = (int_from_bytes(new_hash) - int_from_bytes(old_hash) + self.modulus) % self.modulus

    def delete_verkle_node(self, key: bytes):
        """
        Delete a node in the tree and update the hashes/commitments
        """
        root = self.root

        node = self.find_node(root, key)
        if node is None:
            return

        children = sum(1 for child in [node.left, node.right] if child is not None)

        # Leaf node
        if children == 0:
            path = self.find_path_to_node(root, key)
            node_to_delete = path[-1][1]
            path.pop()
            node_to_update = path[-1][1]
            if path[-1][0] == 0:
                node_to_update.left = None
            elif path[-1][0] == 1:
                node_to_update.right = None
            value_change = (- int_from_bytes(node_to_delete.hash) + self.modulus) % self.modulus
            del node_to_delete

        # Parent with only child
        elif children == 1:
            path = self.find_path_to_node(root, key)
            node_to_delete = path[-1][1]
            node_to_pullup = next(child for child in [node_to_delete.left, node_to_delete.right] if child is not None)
            path.pop()
            node_to_update = path[-1][1]
            if path[-1][0] == 0:
                node_to_update.left = node_to_pullup
            elif path[-1][0] == 1:
                node_to_update.right = node_to_pullup
            value_change = (int_from_bytes(node_to_pullup.hash) - int_from_bytes(node_to_delete.hash) + self.modulus) % self.modulus
            del node_to_delete

        # Parent with two children
        elif children == 2:
            inorder_succ = self.find_min(node.right)
            path = self.find_path_to_node(root, inorder_succ.key)
            node.key = inorder_succ.key
            node.value = inorder_succ.value
            node_to_delete = inorder_succ
            path.pop()
            node_to_update = path[-1][1]
            if path[-1][0] == 0:  # Same as node != node_to_update
                node_to_update.left = node_to_delete.right
            elif path[-1][0] == 1:  # Same as node == node_to_update
                node_to_update.right = node_to_delete.right

            if node_to_delete.is_leaf():
                value_change = (- int_from_bytes(node_to_delete.hash) + self.modulus) % self.modulus
            else:
                value_change = (int_from_bytes(node_to_delete.right.hash) - int_from_bytes(node_to_delete.hash)
                                + self.modulus) % self.modulus
            del node_to_delete

        for edge, node in reversed(path):
            old_hash = node.hash
            if node.commitment is None:
                self.add_node_hash(node)
            else:
                node.commitment.add(self.setup["g1_lagrange"][edge].dup().mult(value_change))
                node.node_hash()
            new_hash = node.hash
            value_change = (int_from_bytes(new_hash) - int_from_bytes(old_hash) + self.modulus) % self.modulus

    def find_min(self, node: VerkleBSTNode):
        """
        Find the minimum node in the tree
        """
        while node.left is not None:
            node = node.left
        return node

    def find_node(self, node: VerkleBSTNode, key: bytes):
        """
        Search for a node in the tree
        """
        while node is not None:
            if key == node.key:
                return node
            elif key < node.key:
                node = node.left
            elif key > node.key:
                node = node.right
        return None

    def find_path_to_node(self, node: VerkleBSTNode, key: bytes):
        """
        Returns the path from node to a node with key with the last element being none if the node does not exist
        """
        path = []
        while node is not None:
            if key == node.key:
                path.append((None, node))
                break
            elif key < node.key:
                edge = 0  # edge 0 for left
            elif key > node.key:
                edge = 1  # edge 1 for right
            path.append((edge, node))
            node = node.left if edge == 0 else node.right

        if node is None:
            path.append((None, None))

        return path

    def add_node_hash(self, node: VerkleBSTNode):
        """
        Add the hash of a node to the node itself
        """
        if node.is_leaf():
            node.node_hash()
        else:
            values = {}
            nodes = [node.left, node.right]
            for i in range(len(nodes)):
                if nodes[i] is None:
                    continue

                if nodes[i].hash is None:
                    self.add_node_hash(nodes[i])
                values[i] = int_from_bytes(nodes[i].hash)
            commitment = self.kzg.compute_commitment_lagrange(values)
            node.commitment = commitment
            node.node_hash()

    def check_valid_tree(self, node: VerkleBSTNode):
        """
        Check if the tree is valid
        """

        if node.is_leaf():
            assert node.hash == hash([node.key, node.value])
        else:
            values = {}
            nodes = [node.left, node.right]
            for i in range(len(nodes)):
                if nodes[i] is None:
                    continue

                if nodes[i].hash is None:
                    self.add_node_hash(nodes[i])
                values[i] = int_from_bytes(nodes[i].hash)
                self.check_valid_tree(nodes[i])
            commitment = self.kzg.compute_commitment_lagrange(values)

            assert node.commitment.is_equal(commitment)
            assert node.hash == hash([node.commitment.compress(), node.key, node.value])

    def inorder_traversal(self, node: VerkleBSTNode, order: list = []):
        """
        Inorder traversal of the tree
        """
        if node is not None:
            self.inorder_traversal(node.left)
            order.append(int_from_bytes(node.key))
            self.inorder_traversal(node.right)

        return order

    def inorder_tree_structure(self, node, level: int = 0, prefix: str = "Root", structure: list = []):
        """
        Print the tree in order
        """

        if node is not None:
            self.inorder_tree_structure(node.left, level + 1, "L")
            info = {"position": prefix + str(level),
                    "key": int_from_bytes(node.key),
                    "value": int_from_bytes(node.value)}
            structure.append(info)
            self.inorder_tree_structure(node.right, level + 1, "R")

        return structure


In [769]:
# BLS12_381 curve modulus\
MODULUS = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001

# Verkle trie parameters
KEY_LENGTH = 256  # bits
WIDTH = 2
PRIMITIVE_ROOT = 7


# Number of key-value pairs to insert
NUMBER_INITIAL_KEYS = 2**15

# Number of keys to insert after computing initial tree
NUMBER_ADDED_KEYS = 512

# Number of keys to delete
NUMBER_DELETED_KEYS = 512

# Number of key/values pair in proof
NUMBER_KEYS_PROOF = 5000

In [770]:
kzg_integration = KzgIntegration(MODULUS, WIDTH, PRIMITIVE_ROOT)
kzg_setup = kzg_integration.generate_setup(WIDTH, 8927347823478352432985)
kzg_utils = kzg_integration.kzg_utils(kzg_setup)

In [771]:
root = VerkleBSTNode(int_to_bytes(10), int_to_bytes(10))
verkle = VerkleBST(kzg_setup, kzg_utils, root, MODULUS, WIDTH)

verkle.insert_node(int_to_bytes(5), int_to_bytes(5))
verkle.insert_node(int_to_bytes(15), int_to_bytes(15))
verkle.insert_node(int_to_bytes(2), int_to_bytes(2))
verkle.insert_node(int_to_bytes(7), int_to_bytes(7))

In [772]:
verkle.inorder_tree_structure(verkle.root)

[{'position': 'L2', 'key': 2, 'value': 2},
 {'position': 'L1', 'key': 5, 'value': 5},
 {'position': 'R2', 'key': 7, 'value': 7},
 {'position': 'Root0', 'key': 10, 'value': 10},
 {'position': 'R1', 'key': 15, 'value': 15}]

In [773]:
verkle.inorder_tranversal(verkle.root)

AttributeError: 'VerkleBST' object has no attribute 'inorder_tranversal'

In [774]:
verkle.find_path_to_node(verkle.root, int_to_bytes(3))

[(0, <__main__.VerkleBSTNode at 0x7f04611af040>),
 (0, <__main__.VerkleBSTNode at 0x7f04611af6d0>),
 (1, <__main__.VerkleBSTNode at 0x7f04611af310>),
 (None, None)]

In [775]:
verkle.add_node_hash(verkle.root)

In [776]:
verkle.check_valid_tree(verkle.root)

In [777]:
verkle.upsert_verkle_node(int_to_bytes(5), int_to_bytes(10))

In [778]:
verkle.check_valid_tree(verkle.root)

In [779]:
verkle.upsert_verkle_node(int_to_bytes(3), int_to_bytes(3))

In [780]:
verkle.check_valid_tree(verkle.root)

In [781]:
verkle.delete_verkle_node(int_to_bytes(3))

In [782]:
verkle.check_valid_tree(verkle.root)

In [783]:
verkle.upsert_verkle_node(int_to_bytes(3), int_to_bytes(3))

In [784]:
verkle.delete_verkle_node(int_to_bytes(2))

In [785]:
verkle.check_valid_tree(verkle.root)

In [786]:
verkle.upsert_verkle_node(int_to_bytes(2), int_to_bytes(2))
verkle.upsert_verkle_node(int_to_bytes(4), int_to_bytes(4))
verkle.upsert_verkle_node(int_to_bytes(6), int_to_bytes(6))
verkle.upsert_verkle_node(int_to_bytes(8), int_to_bytes(8))
verkle.upsert_verkle_node(int_to_bytes(9), int_to_bytes(9))
verkle.upsert_verkle_node(int_to_bytes(1), int_to_bytes(1))

In [787]:
verkle.check_valid_tree(verkle.root)

In [788]:
verkle.delete_verkle_node(int_to_bytes(5))

In [789]:
verkle.check_valid_tree(verkle.root)

In [790]:
verkle.upsert_verkle_node(int_to_bytes(5), int_to_bytes(5))

In [791]:
verkle.check_valid_tree(verkle.root)

In [792]:
verkle.delete_verkle_node(int_to_bytes(3))

In [793]:
verkle.check_valid_tree(verkle.root)

In [804]:
MODULUS = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
WIDTH = 2
PRIMITIVE_ROOT = 7

class TestVerkleBST:
  kzg_integration = KzgIntegration(MODULUS, WIDTH, PRIMITIVE_ROOT)
  kzg_setup = kzg_integration.generate_setup(WIDTH, 8927347823478352432985)
  kzg_utils = kzg_integration.kzg_utils(kzg_setup)

  def setup_method(self):
    root = VerkleBSTNode(int_to_bytes(100), int_to_bytes(100))
    verkle_bst = VerkleBST(kzg_setup, kzg_utils, root, MODULUS, WIDTH)
    self.verkle_bst = verkle_bst
    
  def test_insert(self):
    verkle_bst = self.verkle_bst
    verkle_bst.insert_node(int_to_bytes(50), int_to_bytes(50))
    verkle_bst.insert_node(int_to_bytes(150), int_to_bytes(150))
    verkle_bst.insert_node(int_to_bytes(25), int_to_bytes(25))
    verkle_bst.insert_node(int_to_bytes(75), int_to_bytes(75))
    print(verkle_bst.inorder_traversal(verkle_bst.root))
    verkle_bst.add_node_hash(verkle_bst.root)
    print(verkle_bst.inorder_traversal(verkle_bst.root))

    tree_structure = verkle_bst.inorder_tree_structure(verkle_bst.root)
    print(tree_structure)
    assert tree_structure == [{'position': 'L2', 'key': 25, 'value': 25},
                              {'position': 'L1', 'key': 50, 'value': 50},
                              {'position': 'R2', 'key': 100, 'value': 100},
                              {'position': 'Root0', 'key': 100, 'value': 100}, 
                              {'position': 'R1', 'key': 150, 'value': 150}]

    verkle_bst.check_valid_tree(verkle_bst.root)
  
  def teardown_method(self):
    self.verkle_bst = None
    del self.verkle_bst

[31mF[0m[31m                                                                                            [100%][0m
[31m[1m____________________________________ TestVerkleBST.test_insert _____________________________________[0m

self = <__main__.TestVerkleBST object at 0x7f0460db6a60>

    [94mdef[39;49;00m [92mtest_insert[39;49;00m([96mself[39;49;00m):[90m[39;49;00m
      verkle_bst = [96mself[39;49;00m.verkle_bst[90m[39;49;00m
      verkle_bst.insert_node(int_to_bytes([94m50[39;49;00m), int_to_bytes([94m50[39;49;00m))[90m[39;49;00m
      verkle_bst.insert_node(int_to_bytes([94m150[39;49;00m), int_to_bytes([94m150[39;49;00m))[90m[39;49;00m
      verkle_bst.insert_node(int_to_bytes([94m25[39;49;00m), int_to_bytes([94m25[39;49;00m))[90m[39;49;00m
      verkle_bst.insert_node(int_to_bytes([94m75[39;49;00m), int_to_bytes([94m75[39;49;00m))[90m[39;49;00m
      [96mprint[39;49;00m(verkle_bst.inorder_traversal(verkle_bst.root))[90m[39;49;00m
      

In [805]:
%%run_pytest[clean] -qq

UsageError: %%run_pytest[clean] is a cell magic, but the cell body is empty.
