In [1]:
import hashlib

def hash_leaf(leaf: str) -> str:
    """Return the SHA-256 hex digest of ``leaf`` encoded as UTF-8.

    For a human-readable trace while debugging, this can temporarily be made
    to return ``f'h({leaf})'`` instead.
    """
    digest = hashlib.sha256(leaf.encode("utf-8"))
    return digest.hexdigest()

def hash_pair(left: str, right: str) -> str:
    """Return the SHA-256 hex digest of the text ``left`` followed by ``right``.

    The hash is order-sensitive; the tree code in this file orders each pair
    (smaller digest first) before calling this function.
    """
    hasher = hashlib.sha256()
    # Feeding the two halves separately is equivalent to hashing left + right.
    hasher.update(left.encode())
    hasher.update(right.encode())
    return hasher.hexdigest()

def cut(s: str) -> str:
    """Abbreviate ``s`` to its first five characters followed by ``...``."""
    head = s[:5]
    return f"{head}..."

def calc_root(leaves: list[str]) -> list[list[str]]:
    """Build every level of a sorted-pair Merkle tree over ``leaves``.

    Leaf hashes are sorted before pairing, each pair is ordered (smaller hex
    digest first) before hashing, and an unpaired last node on a level is
    hashed together with itself.

    Despite its name, this returns the whole tree with the root level first:
    ``result[0][0]`` is the root and ``result[-1]`` is the sorted leaf-hash
    level. A single leaf yields ``[[hash_leaf(leaf)]]``; an empty ``leaves``
    yields ``[[]]``, which has no root.
    """
    level = sorted(hash_leaf(leaf) for leaf in leaves)
    levels = [level]

    while len(level) > 1:
        last = len(level) - 1
        # Pair neighbours; the odd node out (if any) is paired with itself.
        level = [
            hash_pair(*sorted((level[i], level[min(i + 1, last)])))
            for i in range(0, len(level), 2)
        ]
        levels.append(level)

    return levels[::-1]

def get_proof(leaves: list[str], index: int) -> list[str]:
    """Build an inclusion proof for one leaf of the tree that ``calc_root`` builds.

    Args:
        leaves: The raw leaf strings, exactly as passed to ``calc_root``.
        index: Position of the target leaf in the *sorted* list of leaf
            hashes, not its position in ``leaves``.

    Returns:
        Sibling hashes ordered from the leaf level upward. When the target
        path node is the unpaired last node of a level, the entry is that
        node's own hash, matching how ``calc_root`` pairs it with itself.
        A single-leaf tree yields an empty proof.

    Raises:
        IndexError: If ``index`` is not in ``range(len(leaves))``. Without
            this check, negative or too-large indexes silently produced
            meaningless proofs.
    """
    hashes = sorted(hash_leaf(leaf) for leaf in leaves)
    n = len(hashes)
    if not 0 <= index < n:
        raise IndexError(f"leaf index {index} out of range for {n} leaves")

    proof = []
    k = index

    while n > 1:
        # Sibling is the left neighbour for odd positions and the right
        # neighbour for even ones (or the node itself if it is unpaired).
        sibling = k - 1 if k & 1 else min(k + 1, n - 1)
        proof.append(hashes[sibling])
        k >>= 1

        # Fold this level in place. Slot i >> 1 is written only after slots
        # i and i + 1 have been read, and later iterations read higher slots.
        for i in range(0, n, 2):
            left, right = hashes[i], hashes[min(i + 1, n - 1)]
            if left > right:
                left, right = right, left
            hashes[i >> 1] = hash_pair(left, right)
        n = (n + (n & 1)) >> 1

    return proof

def verify(proof: list[str], root: str, leaf: str) -> bool:
    """Check that ``leaf`` is included under ``root`` using ``proof``.

    Starting from ``hash_leaf(leaf)``, each proof entry is combined with the
    running hash as an ordered pair (smaller hex digest first). The leaf is
    accepted when the final hash equals ``root``.

    NOTE(review): leaves and internal nodes are hashed the same way (SHA-256
    of the text, no prefix), so ``hash_leaf(l + r) == hash_pair(l, r)``. The
    128-character concatenation of two sibling hashes therefore verifies as a
    "leaf" against a shortened proof. If leaves can be attacker-chosen,
    consider domain separation (this would change all roots).
    """
    node = hash_leaf(leaf)
    for sibling in proof:
        node = hash_pair(*sorted((node, sibling)))
    return node == root
    

# Demo: build a tree over seven leaves and check an inclusion proof for "C".
leaves = ["A", "B", "C", "D", "E", "F", "G"]
hash_leaves = sorted(map(hash_leaf, leaves))

# Position of each original leaf within the sorted hash list (what get_proof expects).
indexes = [hash_leaves.index(h) for h in map(hash_leaf, leaves)]
print(indexes)

tree = calc_root(leaves)
root = tree[0][0]
proof = get_proof(leaves, indexes[2])
verify(proof, root, leaves[2])

[2, 5, 3, 1, 4, 6, 0]


True