## Merkle Tree (Hash Tree)

- The idea behind a Merkle Tree is quite simple, and best illustrated with a problem:
    - Let's suppose I have a huge array of values (e.g. in a database)
    - My database is distributed, so to keep it consistent I need to periodically sync data across its nodes
    - At some point, I notice that one of my nodes has gone down! While it was down, new transactions were written to my DB, and that node never received them
    - Once the node restarts, its data is inconsistent with the others.
    - How would I know what is different?

- A "brute force" solution to something like this would be to completely rewrite the entire database from one of the updated replicas
    - This is $O(N)$ in the amount of data you have, which is not ideal
    - If the number of dropped transactions is small, you are doing a lot of redundant replication, because 99.99% of data could be the same! This is wasteful
    - Can we do better? Yes we can!

- Enter Merkle Trees
    - A Merkle Tree is a binary tree in which every node carries a hash
    - Each leaf node:
        - Contains a value (like a regular binary tree)
        - Contains the hash of that value
    - Each internal node contains the hash of its left child's hash concatenated with its right child's hash
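The hashing rule above can be seen in a few lines. This is a minimal sketch, not the notebook's implementation. It assumes SHA-256 over UTF-8 strings (the notebook uses MD5). It also assumes that a parent hashes the concatenation of its children's hex digests and that the number of leaves is a power of two. `leaf_hash`, `parent_hash` and `merkle_root` are hypothetical helper names.

```python
import hashlib


def leaf_hash(value: str) -> str:
    # A leaf stores the hash of its data value
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


def parent_hash(left: str, right: str) -> str:
    # An internal node stores the hash of its children's hashes, concatenated
    return hashlib.sha256((left + right).encode("utf-8")).hexdigest()


def merkle_root(values: list[str]) -> str:
    # Assumes len(values) is a power of two, so every level pairs up evenly
    level = [leaf_hash(v) for v in values]
    while len(level) > 1:
        level = [parent_hash(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]


root_a = merkle_root(["1", "2", "3", "4"])
root_b = merkle_root(["1", "2", "3", "5"])
print(root_a == root_b)  # False: changing one leaf changes the root
```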

- How does this help us solve the problem?
    - Imagine each database replica constructs and maintains a Merkle tree over its data, which takes $O(N)$ space
        - Let's call the 2 replicas `A` and `B`
    - Imagine replica `A` goes down and it misses some data transactions
        - As a result, the hash value at the root node of `A` will be different from the hash value at the root node of `B`
            - Why? Because at least one leaf value is different, so that leaf's hash is different, and the change bubbles up through the hash of every ancestor to the root
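        - So the roots differ. We now compare one level down:
            - The left child covers the leaves in the first half of the array, and the right child covers the leaves in the second half
        - If the difference is only in the first half of the array, the right children's hashes will still match!
            - Therefore, we know that we only need to keep searching, and eventually copy leaves, in the first half of the array!
        - Repeating this at every level, we only descend into subtrees whose hashes differ, so locating $k$ differing leaves takes $O(k \log N)$ hash comparisons rather than comparing all $N$ values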
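The top-down descent described above can be sketched as follows. This is a minimal sketch under stated assumptions, not the notebook's implementation. Both replicas are assumed to hold the same power-of-two number of leaves, so the two trees have identical shapes. Each tree is stored as a list of hash levels, and SHA-256 is the assumed hash. It uses an explicit stack (depth-first) rather than the notebook's breadth-first queue; the order does not change which leaves are found. `h`, `build_levels` and `diff_leaves` are hypothetical names.

```python
import hashlib


def h(s: str) -> str:
    # SHA-256 hex digest of a string (assumed hash function for this sketch)
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


def build_levels(values: list[str]) -> list[list[str]]:
    # levels[0] holds the leaf hashes and levels[-1] holds the single root hash.
    # Assumes len(values) is a power of two, so each level pairs up evenly.
    levels = [[h(v) for v in values]]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        levels.append([h(prev[i] + prev[i + 1]) for i in range(0, len(prev), 2)])
    return levels


def diff_leaves(a: list[list[str]], b: list[list[str]]) -> tuple[list[int], int]:
    # Starting at the root, only descend into children whose hashes differ.
    # Returns the indices of differing leaves and the number of hash comparisons made.
    # Assumes a and b have the same shape (same number of leaves).
    stack = [(len(a) - 1, 0)]  # (level, index within that level); starts at the root
    differing: list[int] = []
    comparisons = 0
    while stack:
        level, idx = stack.pop()
        comparisons += 1
        if a[level][idx] == b[level][idx]:
            continue  # identical subtree: nothing beneath it needs to be checked
        if level == 0:
            differing.append(idx)  # a leaf whose value differs between replicas
        else:
            # children of node idx sit at positions 2*idx and 2*idx + 1 one level down
            stack.append((level - 1, 2 * idx + 1))
            stack.append((level - 1, 2 * idx))
    return sorted(differing), comparisons


replica_b = [str(i) for i in range(1024)]
replica_a = replica_b.copy()
replica_a[700] = "stale"  # replica A missed the update to leaf 700

print(diff_leaves(build_levels(replica_a), build_levels(replica_b)))  # ([700], 21)
```

The example has 1024 leaves and one stale value, and the search makes only 21 hash comparisons. Those are the root, plus both children of each of the 10 mismatching internal nodes on the path down to leaf 700.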

- Nuances
    - A Merkle Tree should always be a balanced binary tree, with every level split evenly into two halves.
        - What happens when it isn't?
            - Imagine I have leaf nodes [1,2,3,4,5]. From the root:
                - Level 1: root.left = [1,2], root.right = [3,4,5]
                - Level 2: root.left.left = [1], and root.left.right = [2]
                - Level 2: root.right.left = [3], and root.right.right = [4,5]
                - Level 3: root.right.right.left = [4], root.right.right.right = [5]
            
            - Suddenly, this tree is pretty imbalanced! There are 3 levels on the right, but 2 on the left
            
            - Suppose we want to compare this with [1,2,3,4,5,6]
                - Level 1: root.left = [1,2,3], root.right = [4,5,6]
                - Level 2: root.left.left = [1], and root.left.right = [2,3]
                - Level 3: root.left.right.left = [2], and root.left.right.right = [3]
                - Level 2: root.right.left = [4], and root.right.right = [5,6]
                - Level 3: root.right.right.left = [5], root.right.right.right = [6]
            
            - Traversing both trees and comparing hashes:
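                - root does not match
                - root.left does not match ([1,2] vs [1,2,3])
                - root.right does not match ([3,4,5] vs [4,5,6])
                - root.left.left matches ([1] vs [1]), root.left.right does not match ([2] vs [2,3])
                - etc.

            - Because appending a single value shifted the split points, almost every subtree mismatches, so you end up copying many values that did not actually change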
        
        - What if I force an even split at each level, by appending a copy of the last value whenever a list has odd length?
            - Imagine I have leaf nodes [1,2,3,4,5]. From the root:
                - Length is odd, append last value [1,2,3,4,5,5]
                
                - Level 1: root.left = [1,2,3], root.right = [4,5,5]
                    - Lengths are odd, append last value; root.left = [1,2,3,3], root.right = [4,5,5,5]
                - Level 2: root.left.left = [1,2], root.left.right = [3,3]
                - Level 3: root.left.left.left = [1], root.left.left.right = [2], root.left.right.left = [3], root.left.right.right = [3]
                - Level 2: root.right.left = [4,5], root.right.right = [5,5]
                - Level 3: root.right.left.left = [4], root.right.left.right = [5], root.right.right.left = [5], root.right.right.right = [5]
            
            - Imagine I have leaf nodes [1,2,3,4,5,6]. From the root:
                - Level 1: root.left = [1,2,3], root.right = [4,5,6]
                    - Lengths are odd, append last value; root.left = [1,2,3,3], root.right = [4,5,6,6]
                - Level 2: root.left.left = [1,2], root.left.right = [3,3]
                - Level 3: root.left.left.left = [1], root.left.left.right = [2], root.left.right.left = [3], root.left.right.right = [3]
                - Level 2: root.right.left = [4,5], root.right.right = [6,6]
                - Level 3: root.right.left.left = [4], root.right.left.right = [5], root.right.right.left = [6], root.right.right.right = [6]

            - Traversing both trees and comparing hashes:
                - root does not match
                - root.left **MATCHES** (both are [1,2,3,3])
                - root.right does not match ([4,5,5,5] vs [4,5,6,6])
                - Because padding keeps both trees splitting at the same positions in this example, the left halves are identical, and the search space is halved!
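The first split of the padding scheme above can be checked with a couple of list operations. This is a minimal sketch that only computes the top-level halves and mirrors the walkthrough rather than building a tree. `pad_to_even` and `first_split` are hypothetical names.

```python
def pad_to_even(values: list[int]) -> list[int]:
    # Duplicate the last element when the length is odd (the scheme described above)
    return values + values[-1:] if len(values) % 2 == 1 else values


def first_split(values: list[int]) -> tuple[list[int], list[int]]:
    # Pad, split at the midpoint, then pad each half, as in the walkthrough
    padded = pad_to_even(values)
    mid = len(padded) // 2
    return pad_to_even(padded[:mid]), pad_to_even(padded[mid:])


print(first_split([1, 2, 3, 4, 5]))        # ([1, 2, 3, 3], [4, 5, 5, 5])
print(first_split([1, 2, 3, 4, 5, 6]))     # ([1, 2, 3, 3], [4, 5, 6, 6])
print(first_split([1, 2, 3, 4, 5, 6, 7]))  # ([1, 2, 3, 4], [5, 6, 7, 7])
```

The last line shows that whether the halves line up depends on the two lengths. Going from 6 to 7 values moves the first split from after 3 to after 4. So in that case root.left would no longer match, even though the first six values are identical.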

### Implementation

import hashlib
from typing import Any
from copy import copy
from collections import deque

class Node:
    def __init__(self, left, right, hashval: str, content: str, is_copy: bool = False):
        self.left: Node = left
        self.right: Node = right
        self.hashval: str = hashval
        self.content: str = content

        ## Since we are constantly duplicating nodes to ensure our array is even length, keep track of which nodes are copies and which are originals
        self.is_copy: bool = is_copy
    
    @staticmethod
    def hash(content: str) -> str:
        '''
        Return the hex MD5 digest of `content`
        '''
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def __copy__(self):
        return Node(**self.__dict__)
    

class MerkleTree:
    def __init__(self, array: list[Any]):
        self.array: list[Any] = array
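        self.root = self.make_merkle_tree()

    def make_merkle_tree(self) -> Node:
        '''
        Construct a Merkle Tree from the input array and return its root node
        '''
        ## An empty array would otherwise recurse forever in the helper below
        if not self.array:
            raise ValueError("Cannot build a Merkle Tree from an empty array")

        ## Create a leaf Node for every array value
        leaf_nodes = [Node(None, None, Node.hash(str(content)), str(content)) for content in self.array]

        ## Recursively build the tree and return its root node
        root_node = self._recursive_make_tree_helper(leaf_nodes)
        return root_node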
        

    def _make_even_length_array(self, array):
        '''
        Merkle tree array must always be even length to ensure binary structure. See section on `Nuances` for discussion why this must be true
        '''
        if len(array) % 2 == 1:
            last_element_copy = copy(array[-1])
            last_element_copy.is_copy = True
            array.append(last_element_copy)

        return array

    def _recursive_make_tree_helper(self, leaf_nodes: list[Node]) -> Node:
        
        ## Check if leaf
        if len(leaf_nodes) == 1:
            return leaf_nodes[-1]
        
        ## Merkle tree array must always be even length to ensure binary structure. See section on `Nuances` for discussion why this must be true
        leaf_nodes = self._make_even_length_array(leaf_nodes)

        ## If we reach this point, there are at least 2 nodes in the array. Find the midpoint
        midpoint = len(leaf_nodes)//2
        
        ## Recursively make left and right trees
        left_child = self._recursive_make_tree_helper(leaf_nodes[:midpoint])
        right_child = self._recursive_make_tree_helper(leaf_nodes[midpoint:])

        combined_hash = Node.hash(left_child.hashval + right_child.hashval)
        combined_content = f'{left_child.content}, {right_child.content}'
        return Node(left=left_child, right=right_child, hashval=combined_hash, content=combined_content)

def find_discrepant_nodes(t1: MerkleTree, t2: MerkleTree) -> list[tuple[Node, Node]]:
    '''
    Walk both trees breadth-first, descending only into subtrees whose hashes differ.
    Returns the (t1_node, t2_node) pairs whose values should be copied.
    '''
    queue = deque([(t1.root, t2.root)])
    discrepancies = []

    while queue:
        curr1, curr2 = queue.popleft()  # popleft gives FIFO (breadth-first) order

        ## Matching hashes mean identical subtrees: nothing below here needs syncing
        if curr1.hashval == curr2.hashval:
            continue

        is_leaf1 = curr1.left is None and curr1.right is None
        is_leaf2 = curr2.left is None and curr2.right is None

        ## Both are leaves: record the mismatch once, unless both are padding copies
        if is_leaf1 and is_leaf2:
            if not (curr1.is_copy and curr2.is_copy):
                discrepancies.append((curr1, curr2))
            continue

        ## Only one side is a leaf: the shapes differ, so record the pair once and stop descending
        if is_leaf1 or is_leaf2:
            discrepancies.append((curr1, curr2))
            continue

        ## Both are internal nodes (which always have two children): compare their children
        queue.append((curr1.left, curr2.left))
        queue.append((curr1.right, curr2.right))

    return discrepancies

t1 = MerkleTree([1,2,3,4,5])
t2 = MerkleTree([1,2,3,4,5,6])

[(x1.content, x2.content) for (x1,x2) in find_discrepant_nodes(t1, t2)]

[('5', '6')]
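The duplicate-last-element padding used here has a caveat. A padding copy hashes exactly like the original value, so an odd-length array and the same array with its last value genuinely repeated produce identical trees. The sketch below shows this by re-implementing the notebook's pad-and-split hashing in a few lines. `root_hash` and `sha256_hex` are hypothetical names, and the sketch uses SHA-256 rather than MD5, which does not affect the result because the two trees are structurally identical. `find_discrepant_nodes` stops as soon as the root hashes match, so it would report no discrepancies for such a pair.

```python
import hashlib


def sha256_hex(s: str) -> str:
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


def root_hash(values: list[str]) -> str:
    # Same shape as the notebook's tree: a single value is a leaf; otherwise pad an
    # odd-length list by repeating its last element, split at the midpoint, recurse,
    # and hash the concatenation of the two child hashes.
    if not values:
        raise ValueError("empty input")
    if len(values) == 1:
        return sha256_hex(values[0])
    if len(values) % 2 == 1:
        values = values + [values[-1]]
    mid = len(values) // 2
    return sha256_hex(root_hash(values[:mid]) + root_hash(values[mid:]))


five = ["1", "2", "3", "4", "5"]
six_with_dup = ["1", "2", "3", "4", "5", "5"]  # "5" genuinely appears twice here

print(root_hash(five) == root_hash(six_with_dup))  # True
```

One way to tell these cases apart, offered as a suggestion rather than something the notebook does, is to fold the number of leaves into the root hash.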