In [None]:
import itertools
from collections import defaultdict
import numpy as np

# Memo table for W(): WE[n] holds the value once computed; missing entries
# read as None (reading one inserts that None).
WE = defaultdict(lambda: None)
WE[0] = 0
WE[1] = 1


def W(n):
    """Return the Wedderburn(-Etherington) number for ``n``.

    This counts binary trees on ``n`` leaves by splitting at the root into
    subtrees with ``i`` and ``n - i`` leaves; when both halves have ``n/2``
    leaves the two subtrees form an unordered pair. Results are memoised in
    ``WE`` and are exact integers.

    >>> [W(k) for k in range(8)]
    [0, 1, 1, 1, 2, 3, 6, 11]
    """
    if n == 0:
        return 0
    if n == 1:
        return 1
    if WE[n] is None:
        # Splits whose smaller part has i < n/2 leaves.
        s = sum(W(i) * W(n - i) for i in range(1, (n + 1) // 2))
        if n % 2 == 0:
            # Both root subtrees have n/2 leaves: multisets of size 2 drawn
            # from w trees, i.e. C(w, 2) + w = w * (w + 1) / 2. The product of
            # consecutive integers is even, so floor division is exact and the
            # value stays an int (true division would give a float).
            w = W(n // 2)
            s += w * (w + 1) // 2
        WE[n] = s
    return WE[n]

def Oset(L):
    """Return the distinct elements of ``L`` in ascending order.

    Duplicates are dropped by hashing, so the elements must be hashable and
    mutually comparable.
    """
    distinct = dict.fromkeys(L)
    return sorted(distinct)

def unordered_pairs(lst1, lst2):
    """Return the sorted, de-duplicated unordered pairs across two lists.

    Every ``x`` in ``lst1`` is paired with every ``y`` in ``lst2``. Each pair is
    normalised so that its smaller element comes first, giving ``(p, q)`` with
    ``p <= q``. Duplicates, such as ``(a, b)`` arising from both ``(a, b)`` and
    ``(b, a)``, are removed. The result is sorted.

    Args:
        lst1, lst2: sized sequences of mutually comparable, hashable items
            (typically ``(n, s)`` integer pairs).

    Returns:
        A sorted list of ``(p, q)`` pairs. It is empty if either input is empty.
    """
    if len(lst1) == 0 or len(lst2) == 0:
        return []
    # One set both canonicalises and de-duplicates; sorting gives a stable order.
    canonical = {(x, y) if x <= y else (y, x) for x, y in itertools.product(lst1, lst2)}
    return sorted(canonical)

def binary_tree(tree_0, tree_1):
    """Return the index pair ``(n, s)`` of the tree formed by joining two trees.

    ``tree_0`` and ``tree_1`` are ``(number_of_leaves, index)`` pairs. The two
    trees are hung below a new root. The result is the index of that tree among
    trees on ``n_0 + n_1`` leaves, using the numbering that
    ``root_split_finder`` decodes. A tree with 0 leaves acts as the identity.
    """
    small, big = sorted((tree_0, tree_1))
    n_0, s_0 = small
    n_1, s_1 = big
    if n_0 == 0:
        return (n_1, s_1)
    if n_1 == 0:
        return (n_0, s_0)
    n = n_0 + n_1
    # Number of trees on n leaves whose smaller root subtree has < n_0 leaves.
    offset = sum([W(n - i) * W(i) for i in range(1, n_0)])
    if n_0 != n_1:
        return (n, offset + W(n_0) * s_1 + s_0)
    # Equal halves: unordered pairs s_0 <= s_1, grouped by the smaller index.
    y = W(n_0)
    return (n, offset + sum([y - i for i in range(s_0)]) + s_1 - s_0)

def root_split_finder(pair):
    """Return the root split of the ``t``-th tree (in lex order) on ``n`` leaves.

    Args:
        pair: ``(n, t)`` with ``0 <= t < W(n)``.

    Returns:
        ``((n_0, s_0), (n_1, s_1))``, the two subtrees below the root. The
        subtree with more leaves comes first; when both halves have ``n/2``
        leaves, the one with the larger index comes first. This is the inverse
        of ``binary_tree``. For ``n == 1`` the result is ``((1, 0), (0, 0))``.

    Raises:
        ValueError: if ``t`` is negative or ``t >= W(n)``.
    """
    n, t = pair[0], pair[1]
    if t < 0 or t >= W(n):
        raise ValueError(f"there is no tree for t = {t}")
    if n == 1:
        return ((1, 0), (0, 0))

    # Skip whole blocks of trees whose smaller root subtree has i leaves
    # (W(n - i) * W(i) trees per block) until t falls inside the current block.
    i = 1
    if n % 2 == 1:
        while W(n - i) * W(i) <= t:
            t -= W(n - i) * W(i)
            i += 1
    else:
        while 2 * i <= n and W(n - i) * W(i) <= t:
            t -= W(n - i) * W(i)
            i += 1

    if 2 * i != n:
        # Unequal sizes: t = s_big * W(i) + s_small.
        (s_1, s_2) = divmod(int(t), int(W(i)))
        return ((n - i, s_1), (i, s_2))

    # Equal halves: unordered pairs (t + j >= j) of trees on n/2 leaves,
    # grouped by the smaller index j; group j has y - j members.
    y = W(n // 2)
    j = 0
    while y - j - 1 < t:
        t -= y - j
        j += 1
    return ((n // 2, t + j), (n // 2, j))

def _nested_none_table():
    """Return an inner table whose missing keys read as None."""
    return defaultdict(lambda: None)


# Two-level lazily-populated table: m[a][b] is None until assigned.
# NOTE(review): not referenced elsewhere in this cell; confirm it is still needed.
m = defaultdict(_nested_none_table)


def draw(pair):
    """Print the tree ``pair = (n, s)`` as ASCII art.

    Rebuilds the tree as a Sage ``BinaryTree`` by applying
    ``root_split_finder`` recursively, then prints ``ascii_art`` of it.
    Requires SageMath.
    """
    def to_sage(node):
        size, index = node[0], node[1]
        if size == 1:
            return BinaryTree([])
        if size > 1:
            left_half, right_half = root_split_finder((size, index))
            return BinaryTree([to_sage(left_half), to_sage(right_half)])
        # size < 1: there is no tree to build.
        return None

    print(ascii_art(to_sage(pair)))
    
def add_one(T):
    """Return the tree ``T ⊕ (1, 0)``.

    ``T = (n, s)`` is joined with a single leaf under a new root; the result is
    a tree ``(n + 1, t)``.
    """
    single_leaf = (1, 0)
    return binary_tree(T, single_leaf)

def add_two(T):
    """Return the tree ``T ⊕ (2, 0)``.

    ``T = (n, s)`` is joined under a new root with the unique two-leaf tree
    ``(2, 0)``; the result is a tree ``(n + 2, t)``.
    """
    # The two-leaf tree is (2, 0), not the single leaf (1, 0).
    return binary_tree(T, (2, 0))

def add_three(T):
    """Return the tree ``T ⊕ (3, 0)``.

    ``T = (n, s)`` is joined under a new root with the unique three-leaf tree
    ``(3, 0)``; the result is a tree ``(n + 3, t)``.
    """
    # The three-leaf tree is (3, 0), not the single leaf (1, 0).
    return binary_tree(T, (3, 0))

def mid(button):
    """Return the midpoint index of ``button = (a, b)``.

    Returns ``a - 1`` when ``a == b``, otherwise ``(a + b - 1) // 2``.
    """
    lo, hi = button[0], button[1]
    return lo - 1 if lo == hi else (lo + hi - 1) // 2

def length_of(button):
    """Return half the number of positions in the closed interval ``button = (a, b)``.

    Uses true division, so the result is ``(b - a + 1) / 2``.
    """
    start, end = button[0], button[1]
    span = end - start + 1
    return span / 2

def perm_canonicalizer(pat, i):
    """Shift every entry of ``pat`` below ``i`` up by one; wrap the result in a Sage ``Permutation``.

    Entries ``>= i`` are kept unchanged.

    NOTE(review): the previous header described ``pat`` as a pattern on
    1,...,i-1,i+1,...,n with output on 1,...,n-1. For that input, this mapping
    produces the values 2,...,n. It yields 1,...,n-1 only when ``pat`` uses
    0-based values 0,...,i-1,i+1,...,n-1. Confirm which indexing callers use.
    """
    def relabel(x):
        return x + 1 if x < i else x

    return Permutation([relabel(x) for x in pat])

def sorted_matrix(M):
    """Sort each row of ``M``, then sort the rows; return the result as a list of lists."""
    return sorted(map(sorted, M))

# Caches of per-tree distance matrices filled by distance_dict_pair, keyed by
# the (n, s) tree pair; a missing key reads as None.
left_distance_matrix_dict, right_distance_matrix_dict = (
    defaultdict(lambda: None),
    defaultdict(lambda: None),
)

def distance_dict_pair(T):
    """Compute distance-histogram invariants of a tanglegram.

    Args:
        T: ``(ltree, rtree, sigma)``. ``ltree`` and ``rtree`` are ``(n, s)``
            tree index pairs on the same number ``n`` of leaves. ``sigma`` is
            indexed as ``sigma[k - 1]``: left leaf ``k`` is matched to right
            leaf ``sigma[k - 1]``.

    Returns:
        ``(distance_matrix_dict, distance_matrix_edges_dict)``. Both are
        defaultdict counters (missing keys read as 0) that map a histogram
        tuple to the number of vertices having that histogram. In a histogram
        tuple, entry ``l - 1`` is how many vertices lie at distance ``l``.

        * ``distance_matrix_dict`` covers every vertex of the combined graph:
          left-tree vertices (indices ``0 .. 2n-2``), right-tree vertices
          (``2n-1 .. 4n-3``, reached through matching edges between left leaf
          ``k`` and right leaf ``sigma[k-1]``), and one extra vertex treated
          as adjacent to the left root (index ``n``).
        * ``distance_matrix_edges_dict`` covers the left leaves. It uses the
          smaller of the left-tree distance and the right-tree distance
          between the matched leaves.

    Side effects:
        Stores each tree's distance matrix in ``left_distance_matrix_dict`` /
        ``right_distance_matrix_dict`` and reuses cached entries.

    Vertex numbering inside a tree matrix (1-based): leaves ``1..n`` from left
    to right, and internal vertices ``n+1..2n-1`` in preorder, with root ``n+1``.
    """
    ltree = T[0]
    rtree = T[1]
    sigma = T[2]
    n = ltree[0]
    if left_distance_matrix_dict[ltree] is None or right_distance_matrix_dict[rtree] is None:
        # Leaf intervals (a, b) whose subtree's two root subtrees are equal;
        # every single leaf (i, i) is included.
        # NOTE(review): these sets are filled but not used in the return value.
        left_buttonizers = {(i, i) for i in range(1, n + 1)}
        right_buttonizers = {(i, i) for i in range(1, n + 1)}

        left_root_number = 1
        right_root_number = 1

        left_distance_matrix = matrix(ZZ, 2 * n - 1, {}, sparse=True)
        right_distance_matrix = matrix(ZZ, 2 * n - 1, {}, sparse=True)

        # Number of the most recently created internal vertex.
        left_internal_vertex = n
        right_internal_vertex = n

        # Number of leaves placed so far (leaves are placed left to right).
        left_leaf_number = 0
        right_leaf_number = 0

        def root_split_finder_with_left_buttons(pair, leaf_interval, left_root_number):
            # Create the internal vertex for subtree `pair` (spanning the leaves
            # in `leaf_interval`) and fill its distances. Then recurse into the
            # subtrees or place the leaf children. `left_root_number` is the
            # parent's vertex number.
            nonlocal left_internal_vertex
            nonlocal left_leaf_number

            left_internal_vertex += 1

            (ltree, rtree) = root_split_finder(pair)

            # The new vertex is one step further than its parent from every
            # vertex placed so far (earlier internal vertices and leaves).
            for i in list(range(n + 1, left_internal_vertex)) + list(range(1, left_leaf_number + 1)):
                left_distance_matrix[i - 1, left_internal_vertex - 1] = left_distance_matrix[i - 1, left_root_number - 1] + 1
                left_distance_matrix[left_internal_vertex - 1, i - 1] = left_distance_matrix[i - 1, left_internal_vertex - 1]

            left_root_number = left_internal_vertex

            if ltree == rtree:
                left_buttonizers.add(leaf_interval)

            if ltree[0] > 1:
                left_bound = leaf_interval[0]
                right_bound = leaf_interval[0] + ltree[0] - 1
                root_split_finder_with_left_buttons(ltree, (left_bound, right_bound), left_root_number)

            if ltree[0] == 1:
                left_leaf_number += 1
                for i in list(range(n + 1, left_internal_vertex + 1)) + list(range(1, leaf_interval[0])):
                    left_distance_matrix[i - 1, leaf_interval[0] - 1] = left_distance_matrix[i - 1, left_root_number - 1] + 1
                    left_distance_matrix[leaf_interval[0] - 1, i - 1] = left_distance_matrix[i - 1, leaf_interval[0] - 1]

            if rtree[0] > 1:
                left_bound = leaf_interval[0] + ltree[0]
                right_bound = leaf_interval[1]
                root_split_finder_with_left_buttons(rtree, (left_bound, right_bound), left_root_number)

            if rtree[0] == 1:
                left_leaf_number += 1
                for i in list(range(n + 1, left_internal_vertex + 1)) + list(range(1, leaf_interval[1])):
                    left_distance_matrix[i - 1, leaf_interval[1] - 1] = left_distance_matrix[i - 1, left_root_number - 1] + 1
                    left_distance_matrix[leaf_interval[1] - 1, i - 1] = left_distance_matrix[i - 1, leaf_interval[1] - 1]

        def root_split_finder_with_right_buttons(pair, leaf_interval, right_root_number):
            # Mirror of the left version, filling right_distance_matrix.
            nonlocal right_internal_vertex
            nonlocal right_leaf_number

            right_internal_vertex += 1

            (ltree, rtree) = root_split_finder(pair)

            for i in list(range(n + 1, right_internal_vertex)) + list(range(1, right_leaf_number + 1)):
                right_distance_matrix[i - 1, right_internal_vertex - 1] = right_distance_matrix[i - 1, right_root_number - 1] + 1
                right_distance_matrix[right_internal_vertex - 1, i - 1] = right_distance_matrix[i - 1, right_internal_vertex - 1]

            right_root_number = right_internal_vertex

            if ltree == rtree:
                right_buttonizers.add(leaf_interval)

            if ltree[0] > 1:
                left_bound = leaf_interval[0]
                right_bound = leaf_interval[0] + ltree[0] - 1
                root_split_finder_with_right_buttons(ltree, (left_bound, right_bound), right_root_number)

            if ltree[0] == 1:
                right_leaf_number += 1
                # Only internal vertices of the RIGHT tree created so far.
                for i in list(range(n + 1, right_internal_vertex + 1)) + list(range(1, leaf_interval[0])):
                    right_distance_matrix[i - 1, leaf_interval[0] - 1] = right_distance_matrix[i - 1, right_root_number - 1] + 1
                    right_distance_matrix[leaf_interval[0] - 1, i - 1] = right_distance_matrix[i - 1, leaf_interval[0] - 1]

            if rtree[0] > 1:
                left_bound = leaf_interval[0] + ltree[0]
                right_bound = leaf_interval[1]
                root_split_finder_with_right_buttons(rtree, (left_bound, right_bound), right_root_number)

            if rtree[0] == 1:
                right_leaf_number += 1
                for i in list(range(n + 1, right_internal_vertex + 1)) + list(range(1, leaf_interval[1])):
                    right_distance_matrix[i - 1, leaf_interval[1] - 1] = right_distance_matrix[i - 1, right_root_number - 1] + 1
                    right_distance_matrix[leaf_interval[1] - 1, i - 1] = right_distance_matrix[i - 1, leaf_interval[1] - 1]

        root_split_finder_with_left_buttons(ltree, (1, n), left_root_number)
        root_split_finder_with_right_buttons(rtree, (1, n), right_root_number)
        left_distance_matrix_dict[ltree] = left_distance_matrix
        right_distance_matrix_dict[rtree] = right_distance_matrix
    else:
        left_distance_matrix = left_distance_matrix_dict[ltree]
        right_distance_matrix = right_distance_matrix_dict[rtree]

    distance_matrix_dict = defaultdict(lambda: 0)

    # Distances from the extra vertex (one step from the left root, index n)
    # to every left vertex (j < 2n-1) and every right vertex (2n-1 <= j < 4n-2).
    numpy_4n_1 = np.zeros(4 * n, dtype=int)
    array_4n_1 = defaultdict(lambda: 0)
    for j in range(2 * n - 1):
        numpy_4n_1[j] = left_distance_matrix[n, j] + 1
        array_4n_1[left_distance_matrix[n, j] + 1] += 1
    for j in range(2 * n - 1, 4 * n - 2):
        numpy_4n_1[j] = min([left_distance_matrix[n, k - 1] + right_distance_matrix[sigma[k - 1] - 1, j - (2 * n - 1)] for k in range(1, n + 1)]) + 1 + 1
        array_4n_1[numpy_4n_1[j]] += 1

    array_4n_1_list = [array_4n_1[l] for l in range(1, max(array_4n_1) + 1)]
    distance_matrix_dict[tuple(array_4n_1_list)] += 1

    # Histogram for each left-tree vertex i (distance 0 to itself is ignored).
    for i in range(2 * n - 1):
        array_i = defaultdict(lambda: 0)
        array_i[numpy_4n_1[i]] += 1
        for j in range(2 * n - 1):
            array_i[left_distance_matrix[i, j]] += 1
        for j in range(2 * n - 1, 4 * n - 2):
            # Shortest path to a right vertex crosses one matching edge.
            sum_list = [left_distance_matrix[i, k - 1] + right_distance_matrix[sigma[k - 1] - 1, j - (2 * n - 1)] for k in range(1, n + 1)]
            array_i[min(sum_list) + 1] += 1

        array_i_list = [array_i[l] for l in range(1, max(array_i) + 1)]
        distance_matrix_dict[tuple(array_i_list)] += 1

    # Histogram for each right-tree vertex i.
    for i in range(2 * n - 1, 4 * n - 2):
        array_i = defaultdict(lambda: 0)
        array_i[numpy_4n_1[i]] += 1
        for j in range(2 * n - 1):
            array_i[min([left_distance_matrix[j, k - 1] + right_distance_matrix[sigma[k - 1] - 1, i - (2 * n - 1)] for k in range(1, n + 1)]) + 1] += 1
        for j in range(2 * n - 1, 4 * n - 2):
            array_i[right_distance_matrix[i - (2 * n - 1), j - (2 * n - 1)]] += 1

        distance_matrix_dict[tuple([array_i[l] for l in range(1, max(array_i) + 1)])] += 1

    # Leaf-level histograms using the smaller of the left and matched-right distances.
    distance_matrix_edges_dict = defaultdict(lambda: 0)
    for i in range(n):
        array_i = defaultdict(lambda: 0)
        for j in range(n):
            array_i[min([left_distance_matrix[i, j], right_distance_matrix[sigma[i] - 1, sigma[j] - 1]])] += 1
        distance_matrix_edges_dict[tuple([array_i[l] for l in range(1, max(array_i) + 1)])] += 1
    return (distance_matrix_dict, distance_matrix_edges_dict)

def check_isom_edge_distances(T_1, T_2):
    """Return True when two tanglegrams have the same trees and equal invariants.

    The left trees (``T[0]``) must be equal and the right trees (``T[1]``) must
    be equal. If so, both results of ``distance_dict_pair`` are compared.

    NOTE(review): equal invariants are treated as isomorphism by the caller;
    confirm the invariant is fine enough for that use.
    """
    same_trees = T_1[0] == T_2[0] and T_1[1] == T_2[1]
    if not same_trees:
        return False
    vertex_hist_1, edge_hist_1 = distance_dict_pair(T_1)
    vertex_hist_2, edge_hist_2 = distance_dict_pair(T_2)
    return vertex_hist_1 == vertex_hist_2 and edge_hist_1 == edge_hist_2


# For each leaf count n, enumerate all tanglegrams (left tree, right tree,
# permutation). A candidate is kept only when no previously kept tanglegram
# with the same pair of trees has matching distance invariants.
for n in range(3, 6):
    print(n)
    k = int(W(n))  # W(n) may be a non-int numeric type; range() needs an int
    list_tan = defaultdict(lambda: [])
    found = 0
    for bl in range(k):
        for br in range(k):
            trees = ((n, bl), (n, br))
            for perm in Permutations(n):
                candidate_tan = (trees[0], trees[1], perm)
                # Stop scanning at the first representative with matching invariants.
                if any(check_isom_edge_distances(candidate_tan, tan_1) for tan_1 in list_tan[trees]):
                    continue
                list_tan[trees].append(candidate_tan)
                found += 1
                if found % 20 == 0:
                    print("non-isomorphic found so far:", found)

    print("number of tanglegrams on", n, "leaves =", found)

3
number of tanglegrams on 3 leaves = 2
4
number of tanglegrams on 4 leaves = 13
5
