In [10]:
import random
from collections import deque
from copy import deepcopy

import dwave_networkx as dnx
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pulp

from tree import *

from lpSolver import solve
from uccgGenerator import tree_insertion, Graph
from plotNetwork import plotGraph, plotCoupling

## Binary tree

The number of binary trees with $n$ nodes is the $n$-th Catalan number:
$$C_{n}=\frac{1}{n+1}\binom{2n}{n}=\frac{(2 n) !}{(n+1) ! n !}=\prod_{k=2}^{n} \frac{n+k}{k}$$

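A full binary tree with $n$ internal nodes (each with exactly two children) has $n+1$ leaves (nodes with no children), i.e. $2n+1$ nodes in total. Removing the leaves gives a bijection with binary trees on $n$ nodes, so there are $C_n$ such trees.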

In [11]:
def rotate_left_keep_left(u_node, parent):
    """Rewire ``u_node`` (expected to be ``parent.left``) so ``parent`` becomes its right child.

    ``u_node`` keeps its left subtree. Its former right subtree moves into the slot
    ``u_node`` vacated and becomes ``parent.left``. Only these child pointers and the
    moved subtree's ``parent`` pointer are updated here. ``rotate`` fixes
    ``u_node.parent``, ``parent.parent`` and the grandparent link afterwards.
    """
    # u is node, not idx
    u_right = u_node.right
    u_node.right = parent
    parent.left = u_right
    if u_right is not None:  # an empty subtree has no parent pointer to update
        u_right.parent = parent

In [12]:
def rotate_right_keep_right(u_node, parent):
    """Rewire ``u_node`` (expected to be ``parent.right``) so ``parent`` becomes its left child.

    ``u_node`` keeps its right subtree. Its former left subtree moves into the slot
    ``u_node`` vacated and becomes ``parent.right``. Only these child pointers and the
    moved subtree's ``parent`` pointer are updated here. ``rotate`` fixes
    ``u_node.parent``, ``parent.parent`` and the grandparent link afterwards.
    """
    u_left = u_node.left
    u_node.left = parent
    parent.right = u_left
    if u_left is not None:  # an empty subtree has no parent pointer to update
        u_left.parent = parent

In [13]:
def rotate_left_keep_right(u_node, parent):
    """Rewire ``u_node`` (expected to be ``parent.left``) so ``parent`` becomes its left child.

    ``u_node`` keeps its right subtree. Its former left subtree takes ``u_node``'s old
    slot and becomes ``parent.left``. ``parent.right`` is untouched. Only these child
    pointers and the moved subtree's ``parent`` pointer are updated here. ``rotate``
    fixes ``u_node.parent``, ``parent.parent`` and the grandparent link afterwards.
    """
    u_left = u_node.left
    u_node.left = parent
    parent.left = u_left
    if u_left is not None:  # an empty subtree has no parent pointer to update
        u_left.parent = parent

In [14]:
def rotate_right_keep_left(u_node, parent):
    """Rewire ``u_node`` (expected to be ``parent.right``) so ``parent`` becomes its right child.

    ``u_node`` keeps its left subtree. Its former right subtree takes ``u_node``'s old
    slot and becomes ``parent.right``. ``parent.left`` is untouched. Only these child
    pointers and the moved subtree's ``parent`` pointer are updated here. ``rotate``
    fixes ``u_node.parent``, ``parent.parent`` and the grandparent link afterwards.
    """
    u_right = u_node.right
    u_node.right = parent
    parent.right = u_right
    if u_right is not None:  # an empty subtree has no parent pointer to update
        u_right.parent = parent

In [15]:
left, right = 'left', 'right'

# Dispatch table: (which child of its parent u is, which of u's own children u keeps)
# -> the local rewiring helper for that case. Explicit mapping instead of eval().
rotate_func = {
    (left, left): rotate_left_keep_left,
    (left, right): rotate_left_keep_right,
    (right, left): rotate_right_keep_left,
    (right, right): rotate_right_keep_right,
}


def rotate(t, u, orient):
    """Rotate the node identified by ``u`` above its parent in tree ``t``, in place.

    ``u`` is resolved with ``t.get_node``. The rewiring helper is picked from
    ``rotate_func`` by which child of its parent the node is and by ``orient`` (``left``
    or ``right``), the child of ``u`` that stays attached to it. Afterwards the node takes
    its former parent's place: the old parent becomes one of its children, and either the
    grandparent's matching child pointer or ``t.root`` (when the parent was the root) is
    redirected to it.

    Raises:
        ValueError: if ``u`` is the root (there is no parent to rotate with).
    """
    u_node = t.get_node(u)
    parent = u_node.parent
    if parent is None:
        raise ValueError(f"cannot rotate the root node {u!r}")
    # Identity, not equality: sibling subtrees may compare equal structurally, and we
    # need to know which pointer actually holds u_node.
    u_prop = left if u_node is parent.left else right
    p_parent = parent.parent
    rotate_func[(u_prop, orient)](u_node, parent)
    parent.parent = u_node
    u_node.parent = p_parent
    if p_parent is None:
        t.root = u_node
    elif p_parent.left is parent:
        p_parent.left = u_node
    else:
        p_parent.right = u_node

In [16]:
def get_children(t):
    """Return the one-step transitions of the rotation Markov chain from tree ``t``.

    For each inner node ``u`` returned by ``t.get_inner_nodes()`` and each orient in
    (``left``, ``right``), a deep copy of ``t`` with ``rotate(copy, u, orient)`` applied is
    produced. Each copy gets ``prob = 1 / (4 * n)``, where ``n`` is the number of inner
    nodes. A final deep copy of ``t`` (the "stay" move) carries the remaining mass, so the
    ``prob`` fields sum to 1 up to floating-point rounding.

    Returns:
        A list of ``2n + 1`` trees, the stay copy last. If ``t`` has no inner nodes, the
        list holds only the stay copy, with ``prob = 1``. ``t`` itself is not modified.
    """
    # inner node should not be a leaf or the root
    inner_nodes = t.get_inner_nodes()
    n = len(inner_nodes)
    if n == 0:
        # No rotation is possible; avoid dividing by zero and just stay put.
        t_stay = deepcopy(t)
        t_stay.prob = 1.0
        return [t_stay]
    prob = 1 / (4 * n)
    children = []
    for u in inner_nodes:
        for orient in (left, right):
            child = deepcopy(t)
            rotate(child, u, orient)
            child.prob = prob
            children.append(child)
    t_stay = deepcopy(t)
    t_stay.prob = 1 - prob * len(children)
    return children + [t_stay]

In [17]:
def markov_chain_dist(t1, t2):
    """Breadth-first-search distance from tree ``t1`` to tree ``t2`` in the transition graph.

    Neighbours come from ``get_children``. Trees are compared with ``==`` and kept in a
    ``set``, so they must be hashable with equality consistent with the hash.

    Returns:
        The minimum number of transitions from ``t1`` to ``t2`` (0 when ``t1 == t2``),
        or -1 if ``t2`` is unreachable.
    """
    # Mark the start as visited so it is not re-enqueued via its own "stay" child.
    visited = {t1}
    queue = deque([(t1, 0)])  # deque: O(1) pops from the front
    while queue:
        head, dist = queue.popleft()
        if head == t2:
            return dist
        for child in get_children(head):
            if child not in visited:
                visited.add(child)
                queue.append((child, dist + 1))
    return -1

In [18]:
def generate_pair(h=2, n_steps=50):
    """Sample a random tree and a distinct tree one chain step away from it.

    Starts from ``Tree(h)`` and takes ``n_steps`` uniformly random picks from
    ``get_children`` (the ``prob`` weights are not used for this walk). It then picks
    ``t2`` uniformly among the children of ``t1`` that are not equal to ``t1``.

    Args:
        h: Argument forwarded to ``Tree``.
        n_steps: Number of random-walk steps before ``t1`` is taken (default 50).

    Returns:
        ``(t1, t2)`` with ``t2 != t1``.

    Raises:
        ValueError: if every child of ``t1`` equals ``t1`` (no distinct neighbour exists).
    """
    t1 = Tree(h)
    for _ in range(n_steps):
        t1 = random.choice(get_children(t1))
    # Compute the children once and filter, instead of re-deep-copying them on every
    # rejection; this is the same distribution as the rejection loop.
    candidates = [child for child in get_children(t1) if child != t1]
    if not candidates:
        raise ValueError("t1 has no neighbour distinct from itself")
    t2 = random.choice(candidates)
    return t1, t2

In [19]:
# Fix the RNG so the sampled pair (and the distance computed below) is reproducible.
SEED = 0
random.seed(SEED)

# Sample a tree t1 and a neighbour t2, then list the one-step successors of each.
t1, t2 = generate_pair(3)
t1_children = get_children(t1)
t2_children = get_children(t2)
l1, l2 = len(t1_children), len(t2_children)

In [20]:
# Cost matrix for `solve`: BFS distance in the chain between every successor of t1
# (rows) and every successor of t2 (columns).
dists = np.zeros((l1, l2))
for i, child1 in enumerate(t1_children):
    for j, child2 in enumerate(t2_children):
        dists[i, j] = markov_chain_dist(child1, child2)
    print(i)  # progress: one line per finished row
# markov_chain_dist returns -1 for an unreachable pair; a negative cost would silently
# corrupt the optimisation, so fail loudly instead.
assert (dists >= 0).all(), "markov_chain_dist found an unreachable pair (distance -1)"
status, dist, joint_prob = solve(t1_children, t2_children, dists)
print(dist)

0
1
2
3
4
5
6
7
8
9
10
11
12
0.9166666669999999
