In [67]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
from skbio.tree import TreeNode
import copy
import time 

def mod_nj(feature_matrix, prob_features, names=None, result_constructor=None):
    fm = copy.deepcopy(feature_matrix)

    if names is None:
        names = np.arange(fm.shape[0])
#     # Determine the lcas for each pair of sites
#     lcas = np.zeros((fm.shape[0], fm.shape[0], fm.shape[1]))
#     for i in range(fm.shape[0]):
#         for j in range(i+1, fm.shape[0]):
#             lcas[i,j] = fm[i]*fm[j]
#             lcas[j,i] = lcas[i,j]
    
    names = np.array(names)
    perm = np.random.permutation(len(names))
    
    feature_matrix = feature_matrix[perm]
    names = names[perm]
    
    print(names, perm)
#     log_prob_features = np.log(prob_features)
    log_prob_features = (prob_features)
    log_prob_features[-log_prob_features == np.inf] = -10000 #hacky
    
    
    # Compute the distance matrix
    D = squareform(pdist(fm, lambda u,v: (- (u + v - 2*u*v) * log_prob_features).sum()))
    
    tree_nodes = {}
    
    print('Starting with {0} nodes'.format(len(D)))
    new_name = len(D)
    while len(D) > 2:
        
        print('D', D)
       
        s = time.time()
        # Compute the Q matrix -> Q(ij)=d(ij) - [r(i) + r(j)]/(N-2)
        R  = np.tile(D.sum(0), (fm.shape[0], 1))
        
        R = (R + R.T)
        
        Q = D*(fm.shape[0]-2) - R
        
        print('R', R)
        print('Q', Q)
        
        # Convert Q martix to lower triangular form without the diagonal to avoid merging the same site
        Q[np.tril_indices(Q.shape[0], 0)]  = np.inf
        
        print('min: ', np.argmin(Q, axis=None))
        # Now find the argmin (i,j) of Q. These are the sites the be merged
        min_i, min_j = np.unravel_index(np.argmin(Q, axis=None), Q.shape)
        s = time.time() 
        
        # Now we merge i,j. We need to replace i,j in the feature matrix with lca(i,j).
        lca = fm[min_i]*fm[min_j]
        fm  = np.delete(fm, [min_i,min_j], axis=0)
        fm  = np.vstack([fm, lca])
        
        
        
        # Create a new TreeNode from the merged children
        
#         new_name = 'lca({0},{1})'.format(names[min_i], names[min_j])
        new_name += 1
        child_i = tree_nodes.get(names[min_i], TreeNode(name=str(names[min_i])))
        child_j = tree_nodes.get(names[min_j], TreeNode(name=str(names[min_j])))
        new_node = TreeNode(name=str(new_name), length=None, parent=None, children=[child_i, child_j])
        new_node.lca = lca 
        
        print(names[min_i], names[min_j], 'joined')
        child_i.parent = new_node
        child_j.parent = new_node
        
        tree_nodes[new_name] = new_node
        
        
        names = np.delete(names, [min_i,min_j], axis=0)
        names = np.hstack([names, new_name])
        

        
#         # Remove the entries for i,j in the lca matrix.
#         lcas = np.delete(np.delete(lcas, [min_i,min_j], axis=0), [min_i,min_j], axis=1)

#         # Add a new lca entry for the merged node lca(i,j) -> k for every other k in the feature matrix 
#         new_lcas = np.zeros((fm.shape[0], fm.shape[0], fm.shape[1]))
#         new_lcas[:-1,:-1,:] = lcas
#         j = -1 
#         for i in range(fm.shape[0]):
#             new_lcas[i,j] = fm[i]*fm[j]
#             new_lcas[j,i] = new_lcas[i,j]
#         lcas = new_lcas

        # We also need to replace the distance of each site k to i or j with the distance to lca(i,j)

        D = np.delete(np.delete(D, [min_i,min_j], axis=0), [min_i,min_j], axis=1)

        new_D = np.zeros((fm.shape[0], fm.shape[0]))
        new_D[:-1, :-1] = D

        new_D_row = - ((fm + fm[-1] - 2* fm * fm[-1])*log_prob_features).sum(1)


        new_D[-1, :] = new_D_row
        new_D[:, -1] = new_D_row
        D = new_D
        
    new_name += 1
    # Merge the last two remaining sites to complete the tree
    child1, child2 = tree_nodes[names[0]], tree_nodes[names[1]]
    root = TreeNode(name = str(new_name), children=[child1, child2])
    child1.parent = root
    child2.parent = root
    
    return root



In [68]:
feature_matrix = np.array([[1,1,0,0,0], 
                           [1,1,0,0,0],
                           [1,0,1,0,0], 
                           [1,0,1,0,1],
                           [0,0,0,1,0],
                           [0,0,0,1,0]]
                          )

In [69]:
root = mod_nj(feature_matrix, prob_features, [1,2,3,4,5,6])

[5 3 6 4 2 1] [4 2 5 3 1 0]
Starting with 6 nodes
D [[0. 0. 2. 3. 3. 3.]
 [0. 0. 2. 3. 3. 3.]
 [2. 2. 0. 1. 3. 3.]
 [3. 3. 1. 0. 4. 4.]
 [3. 3. 3. 4. 0. 0.]
 [3. 3. 3. 4. 0. 0.]]
R [[22. 22. 22. 26. 24. 24.]
 [22. 22. 22. 26. 24. 24.]
 [22. 22. 22. 26. 24. 24.]
 [26. 26. 26. 30. 28. 28.]
 [24. 24. 24. 28. 26. 26.]
 [24. 24. 24. 28. 26. 26.]]
Q [[-22. -22. -14. -14. -12. -12.]
 [-22. -22. -14. -14. -12. -12.]
 [-14. -14. -22. -22. -12. -12.]
 [-14. -14. -22. -30. -12. -12.]
 [-12. -12. -12. -12. -26. -26.]
 [-12. -12. -12. -12. -26. -26.]]
min:  34
1 2 joined
D [[0. 0. 2. 3. 3.]
 [0. 0. 2. 3. 3.]
 [2. 2. 0. 1. 3.]
 [3. 3. 1. 0. 4.]
 [3. 3. 3. 4. 0.]]
R [[16. 16. 16. 19. 21.]
 [16. 16. 16. 19. 21.]
 [16. 16. 16. 19. 21.]
 [19. 19. 19. 22. 24.]
 [21. 21. 21. 24. 26.]]
Q [[-16. -16. -10. -10. -12.]
 [-16. -16. -10. -10. -12.]
 [-10. -10. -16. -16. -12.]
 [-10. -10. -16. -22. -12.]
 [-12. -12. -12. -12. -26.]]
min:  5
3 5 joined
D [[0. 1. 3. 2.]
 [1. 0. 4. 3.]
 [3. 4. 0. 3.]
 [2. 3. 3. 0.]]

In [70]:
print(root.ascii_art())

                    /-4
          /9-------|
         |          \-6
         |
-11------|                    /-3
         |          /8-------|
         |         |          \-5
          \10------|
                   |          /-1
                    \7-------|
                              \-2


In [36]:
p = np.random.permutation(6)
p

array([3, 0, 4, 5, 1, 2])

In [37]:
feature_matrix

array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [1, 0, 1, 0, 0],
       [1, 0, 1, 0, 1],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 1, 0]])

In [38]:
feature_matrix[p]

array([[1, 0, 1, 0, 1],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 1, 0],
       [1, 1, 0, 0, 0],
       [1, 0, 1, 0, 0]])

In [65]:
names = np.array([1,2,3,4,5,6])

In [66]:
names[p]

array([4, 1, 5, 6, 2, 3])

In [469]:
Q

array([[ 0.        ,  0.        ,  0.        ],
       [-2.40794561,  0.        ,  0.        ],
       [-2.40794561, -2.40794561,  0.        ]])