## Can a machine learning algorithm do a good job of predicting correct quartet configurations?

### imports:

In [7]:
import numpy as np
import toytree
import itertools
import copy
from itertools import compress
import toyplot
import tensorflow as tf
import contextlib
import sys
import io
import ete3
import os
import ipyrad as ip
import subprocess
from numba import jit

### functions:

In [17]:
def decompose_to_quartets(tre):
    """Collect the quartets induced by the internal edges of a tree.

    For every node that is neither the root nor a tip, the tip names are
    split into those below the node and all remaining ones; every pairing of
    two names from the remaining set with two names from below the node is a
    candidate quartet (as produced by ``get_all_combs``).

    Parameters
    ----------
    tre : object exposing ``tre.tree`` with ``get_leaf_names()`` and
        ``traverse()`` (the nodes must offer ``is_root()``, ``is_leaf()`` and
        ``get_leaf_names()``).

    Returns
    -------
    set of 4-tuples of tip names. Each distinct group of four names is kept
    only once, in the order in which it was first produced.
    """
    every_tip = set(tre.tree.get_leaf_names())

    # `seen_keys` holds order-independent keys, `quartets` the tuples we return
    seen_keys = set()
    quartets = set()

    for node in tre.tree.traverse():
        # only internal, non-root nodes define a split
        if node.is_root() or node.is_leaf():
            continue

        tips_below = set(node.get_leaf_names())
        tips_above = every_tip - tips_below

        for quartet in get_all_combs(tre, tips_above, tips_below):
            key = tuple(sorted(quartet))
            if key in seen_keys:
                continue
            quartets.add(quartet)
            seen_keys.add(key)

    return quartets
def get_all_combs(self, set1, set2, as_list=False):
    """Pair every 2-combination of ``set1`` with every 2-combination of ``set2``.

    Parameters
    ----------
    self : unused; kept so existing call sites keep working.
    set1, set2 : iterables of items.
    as_list : bool, return a list instead of a generator.

    Returns
    -------
    Generator (default) or list of 4-tuples ``(a1, a2, b1, b2)`` where
    ``(a1, a2)`` comes from ``set1`` and ``(b1, b2)`` from ``set2``.
    """
    pairs_from_first = itertools.combinations(set1, 2)
    pairs_from_second = itertools.combinations(set2, 2)

    # each pair is already a tuple, so concatenating them yields the 4-tuple
    quartets = (
        left + right
        for left, right in itertools.product(pairs_from_first, pairs_from_second)
    )

    # a generator is cheaper; materialise only on request
    return list(quartets) if as_list else quartets

## Show that we should be able to recognize the true quartet visually

Here, we load a random tree and the sequence alignment simulated on that tree, and then pick a random quartet from the tree. The trees have random branch lengths, and the alignments paired with the trees evolve at varying rates and are of varying lengths.

In [50]:
# pick a random tree from file (file numbers are drawn from 1..2000)
treenum = np.random.choice(range(2000))+1
print("Tree number: " + str(treenum))
# NOTE: the tree and the alignment are both opened with `treenum` itself, so
# the number printed above is the file that is actually loaded and the index
# can never run past 2000 (the same convention random_correct_matrix uses).
thetree = toytree.tree('random_trees/samp'+str(treenum)+'.tre')
# get all TRUE splits on the tree
treeqrts = list(decompose_to_quartets(thetree))

fname = ('tree_seqs/test'+str(treenum)+'.dat')
with open(fname) as f:
    # remove whitespace characters like `\n` at the end of each line
    sequences = [line.strip() for line in f]
# the first line is a header, not a sequence
sequences.pop(0)

# separate all sequences and tip names: the first 10 characters of each line
# hold the tip name, the remainder holds the sequence
names = [line[0:10].strip(" ") for line in sequences]
iso_sequences = [line[10:].strip(" ") for line in sequences]

# pick a random quartet
num = int(np.random.choice(range(len(treeqrts)), 1)[0])
print("Quartet number = " + str(num))

# visualize this quartet on the tree: nodes whose label belongs to the chosen
# quartet get the first tree color, every other node the second one
node_labels = thetree.get_node_labels()
in_quartet = [node_labels[idx] in treeqrts[num] for idx in range(len(node_labels))]
colors = [thetree.colors[0] if flag else thetree.colors[1] for flag in in_quartet]
thetree.draw(
    width=300,
    height=490,
    node_labels=False,
    node_color=colors,
    node_size=15,
);


print("the TRUE split is: " + str(treeqrts[num]))

Tree number: 1344
Quartet number = 60579
the TRUE split is: ('T33', 'T30', 'T7', 'T1')


Let's pretend we don't know the true quartet and are just taking a random set of four tips and shuffling them three ways:

In [51]:
# take the quartet drawn above and hide its true pairing by shuffling a copy
qrtnum = num
true_qrt = np.array(treeqrts[qrtnum])
tipnames = copy.deepcopy(true_qrt)
np.random.shuffle(tipnames)

# the three ways to pair up four positions: [0123], [0213] or [0312]
candidate_splits = [[0,1,2,3],[0,2,1,3],[0,3,1,2]]
true_pairs = (set(true_qrt[0:2]), set(true_qrt[2:4]))

# one-hot vector: 1 for the candidate whose first pair matches either true pair
correct_config = []
for split in candidate_splits:
    first_pair = set([tipnames[pos] for pos in split[0:2]])
    matches = (first_pair == true_pairs[0]) or (first_pair == true_pairs[1])
    correct_config.append(int(matches))

print("The true quartet is: " + str(true_qrt))
print("The random configuration of this (i.e. when true quartet unknown) is: " + str(tipnames))
print("The correct configuration of the random configuration is: " + str(candidate_splits[correct_config.index(1)]))


The true quartet is: ['T33' 'T30' 'T7' 'T1']
The random configuration of this (i.e. when true quartet unknown) is: ['T33' 'T1' 'T30' 'T7']
The correct configuration of the random configuration is: [0, 2, 1, 3]


### Now, we can shuffle our random arrangement of the four tips into the three possible splits: [0,1,2,3],[0,2,1,3], and [0,3,1,2]. 

In [52]:
interestednames = tipnames # this is a list of four tip names... e.g. ["t1","t2","t3","t4"]

# row index in the alignment of each of the four tips, in `interestednames` order
taxa_ids = [idx for tip in interestednames
            for idx, name in enumerate(names) if tip == name]

tempobj = [iso_sequences[idx] for idx in taxa_ids]

# eliminate non-snps: keep only the sites at which the four tips are not all identical
site_columns = ([seq[site] for seq in tempobj] for site in range(len(tempobj[0])))
ind_samples_reset = [column for column in site_columns if len(set(column)) > 1]

# recode the bases A, C, G, T as 0, 1, 2, 3
ind_samples = np.array(ind_samples_reset)
for code, base in enumerate("ACGT"):
    ind_samples = np.where(ind_samples == base, code, ind_samples)
ind_samples = ind_samples.astype(int)


print(correct_config)

# indexmat[a, b] is the row/column of the 16x16 matrix for the base pair (a, b)
# order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
indexmat = np.arange(16).reshape(4, 4)

for q in [[0,1,2,3],[0,2,1,3],[0,3,1,2]]:
    # count site patterns: the first two tips pick the row, the last two the column
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = ind_samples[:,q]
    for first, second, third, fourth in arr0123:
        fullmat0123[indexmat[first, second], indexmat[third, fourth]] += 1
    # scale by the largest count before plotting
    toyplot.matrix(fullmat0123 / fullmat0123.max())

[0, 1, 0]


# Training a machine learning model

We can just adapt the above code to train a machine learning model by showing it many different incorrect vs. correct matrices.

In [55]:
# functions to generate a random correctly / incorrectly configured quartet matrices
# integer codes used for the four bases
_BASE_CODES = {'A': 0, 'C': 1, 'G': 2, 'T': 3}


def _load_random_quartet_snps():
    """Draw a random tree and a random true quartet; return that quartet's SNPs.

    Reads ``random_trees/samp<N>.tre`` and ``tree_seqs/test<N>.dat`` for a
    random ``N`` in 1..2000. In the alignment file the first line is skipped,
    the first 10 characters of every other line are the tip name and the rest
    is the sequence.

    Returns
    -------
    numpy int array of shape (n_sites, 4). Columns follow the order of the
    quartet as returned by ``decompose_to_quartets`` (so the true split is
    columns [0, 1] vs [2, 3]); bases are coded A=0, C=1, G=2, T=3. Only sites
    that vary among the four tips are kept. Sites containing any character
    other than A/C/G/T are skipped.

    NOTE(review): the whole tree is decomposed on every call just to sample a
    single quartet; this dominates the run time when building a training set.
    """
    # pick a random tree from file
    treenum = np.random.choice(range(2000)) + 1
    thetree = toytree.tree('random_trees/samp' + str(treenum) + '.tre')
    # get all TRUE splits on the tree
    treeqrts = list(decompose_to_quartets(thetree))

    fname = 'tree_seqs/test' + str(treenum) + '.dat'
    with open(fname) as f:
        # strip trailing whitespace and drop the header line
        sequences = [line.strip() for line in f][1:]

    # separate all sequences and tip names
    names = [line[0:10].strip(" ") for line in sequences]
    iso_sequences = [line[10:].strip(" ") for line in sequences]

    # pick a random quartet
    true_qrt = treeqrts[int(np.random.choice(range(len(treeqrts)), 1)[0])]
    quartet_seqs = [iso_sequences[names.index(tip)] for tip in true_qrt]

    # eliminate non-snps and recode the bases as integers
    snps = []
    for column in zip(*quartet_seqs):
        if len(set(column)) > 1 and all(base in _BASE_CODES for base in column):
            snps.append([_BASE_CODES[base] for base in column])

    # reshape so that "no SNPs at all" is still a well-formed (0, 4) array
    return np.array(snps, dtype=int).reshape(-1, 4)


def _quartet_matrix(snps, config):
    """Count site patterns of ``snps[:, config]`` into a flattened 16x16 matrix.

    The first two columns select the row (4*a + b) and the last two the
    column (4*c + d), i.e. the order across the matrix is
    00,01,02,03,10,11,...,33. The counts are divided by the largest count;
    if there are no sites the all-zero vector is returned instead of NaNs.
    """
    arranged = snps[:, config]
    rows = 4 * arranged[:, 0] + arranged[:, 1]
    cols = 4 * arranged[:, 2] + arranged[:, 3]

    counts = np.zeros(shape=(16, 16))
    # np.add.at accumulates repeated (row, col) pairs, unlike counts[rows, cols] += 1
    np.add.at(counts, (rows, cols), 1)

    peak = counts.max()
    if peak == 0:
        return counts.flatten()
    return counts.flatten() / peak


def random_correct_matrix():
    """Return a length-256 site-pattern vector for a random, correctly paired quartet.

    The four tips are kept in their true order, so the matrix is built from
    the true split (columns [0, 1] vs [2, 3]).
    """
    return _quartet_matrix(_load_random_quartet_snps(), [0, 1, 2, 3])


def random_wrong_matrix():
    """Return a length-256 site-pattern vector for a random, wrongly paired quartet.

    Same as ``random_correct_matrix`` except that the tips are rearranged into
    one of the two incorrect splits, chosen with equal probability.
    """
    snps = _load_random_quartet_snps()
    # pick one of the wrong configs
    wrong_config = [[0, 2, 1, 3], [0, 3, 1, 2]][np.random.binomial(1, .5)]
    return _quartet_matrix(snps, wrong_config)


## Now we can build a training set of vectorized matrices and accompanying "correct" or "incorrect" labels

In [60]:
# images[k] is a flattened matrix, labels[k] its one-hot class:
# [1,0] = built from the correct split, [0,1] = built from a wrong split
images = []
labels = []
for loopnum in range(1000):
    # fair coin flip decides which kind of example to generate
    use_correct = np.random.binomial(1,.5)
    if use_correct:
        make_matrix, one_hot = random_correct_matrix, [1,0]
    else:
        make_matrix, one_hot = random_wrong_matrix, [0,1]
    images.append(make_matrix())
    labels.append(one_hot)



KeyboardInterrupt: 

In [61]:
# number of matrices actually collected by the loop above
len(images)

987

In [62]:
# number of labels collected; should equal len(images)
len(labels)

987

## Now we train and test our model

In [64]:
# samples [0, N_TRAIN) are used for training, everything after that for testing
N_TRAIN = 800
BATCH_SIZE = 50
N_STEPS = 1000
LEARNING_RATE = 0.6

if len(images) <= N_TRAIN:
    raise ValueError("need more than %d samples to leave a test set, got %d"
                     % (N_TRAIN, len(images)))

train_images, train_labels = np.array(images[:N_TRAIN]), np.array(labels[:N_TRAIN])
test_images, test_labels = np.array(images[N_TRAIN:]), np.array(labels[N_TRAIN:])

# softmax regression: 256 inputs (flattened 16x16 matrix) -> 2 classes
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 256])
W = tf.Variable(tf.zeros([256, 2]))
b = tf.Variable(tf.zeros([2]))

logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)

y_ = tf.placeholder(tf.float32, [None, 2])

# Compute the cross-entropy from the logits rather than as
# -sum(y_ * log(softmax(...))): the hand-written form evaluates log(0) and
# turns the loss into NaN as soon as a predicted probability underflows.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)

# fraction of samples whose most probable class equals the labelled class
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# not used within this cell
saver = tf.train.Saver()

init = tf.global_variables_initializer()
# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    for step in range(N_STEPS):
        # random mini-batch (with replacement) from the training portion
        batch = np.random.choice(N_TRAIN, BATCH_SIZE)
        sess.run(train_step, feed_dict={x: train_images[batch], y_: train_labels[batch]})

    # accuracy on the held-out samples
    print(sess.run(accuracy, feed_dict={x: test_images, y_: test_labels}))

0.962567


## Do we see the same structures in real data? Example with mammals.

In [None]:
# functions
def get_patterns(goodbases):
    """Recode every site as a pattern of ranks among its own distinct values.

    Each row is replaced by, for every entry, the index of that entry in the
    sorted list of the row's distinct values, e.g. ``[0, 0, 3, 3]`` becomes
    ``[0, 0, 1, 1]`` and ``[3, 3, 0, 0]`` becomes ``[1, 1, 0, 0]``. The ranks
    are computed explicitly (via ``np.unique``) so the coding never depends on
    the iteration order of a ``set``.

    Parameters
    ----------
    goodbases : 2-D array-like, one row per site.

    Returns
    -------
    int array with the same shape as ``goodbases``; shape (0, 4) when
    ``goodbases`` has no rows.
    """
    patterns = [np.unique(site, return_inverse=True)[1] for site in goodbases]
    if not patterns:
        return np.empty(shape=(0, 4), dtype=int)
    # stack once at the end instead of growing the array row by row
    return np.vstack(patterns).astype(int)

def most_freq_pattern(the_patterns):
    """Return the row that occurs most often in ``the_patterns``.

    Rows are compared as whole units. If several rows are tied, the first one
    in the sorted order produced by ``np.unique`` is returned.
    """
    distinct_rows, counts = np.unique(the_patterns, return_counts=True, axis=0)
    winner = counts.argmax()
    return distinct_rows[winner]

def f(genes_alltaxa,geneidx,fourtaxa):
    """Extract the columns ``fourtaxa`` from gene number ``geneidx``.

    ``genes_alltaxa[geneidx]`` is indexed as ``[site][taxon]``; the result is
    an array with one row per site and one column per entry of ``fourtaxa``,
    in the order given.
    """
    gene = genes_alltaxa[geneidx]
    selected = []
    for site in gene:
        selected.append([site[taxon] for taxon in fourtaxa])
    return np.array(selected)

def exclude(fullgene):
    """Return a boolean mask of the sites to KEEP (despite the name).

    A site (row) is kept when
      * every entry is a base code 0..3, and
      * the entries are not all identical (the site is variable).

    The base-code test is explicit rather than ``sum(row) <= 12``: the sum is
    only a proxy that lets through rows such as ``[4, 4, 4, 0]`` and can wrap
    around when the entries are small unsigned integers.

    Parameters
    ----------
    fullgene : 2-D array-like of integer codes, one row per site.

    Returns
    -------
    1-D bool array with one entry per row (empty bool array for empty input,
    so it can always be used as an index mask).
    """
    sites = np.asarray(fullgene)
    if sites.size == 0:
        return np.zeros(len(sites), dtype=bool)

    only_bases = ((sites >= 0) & (sites <= 3)).all(axis=1)
    # variable <=> some entry differs from the first entry of the row
    variable = (sites != sites[:, :1]).any(axis=1)
    return only_bases & variable

# Making matrices

# alignment (indexed as [site][taxon] below) and the gene boundaries:
# snpmap[0][i] / snpmap[1][i] are the first / one-past-last row of gene i
totalseqs = np.genfromtxt("download_simseqs/concat_mammal_genes.gz",dtype='str')
snpmap = np.loadtxt("download_simseqs/concat_mammal_map.gz").astype(int)

# reinterpret the characters as byte values and recode
# 65 ('A'), 67 ('C'), 71 ('G'), 84 ('T') as 0, 1, 2, 3
totalseqs = totalseqs.view(np.uint8)
for code, byte_value in enumerate((65, 67, 71, 84)):
    totalseqs = np.where(totalseqs==byte_value, code, totalseqs)

genes_alltaxa = [totalseqs[snpmap[0][i]:snpmap[1][i]] for i in range(len(snpmap[0]))]

# every way of choosing four taxa
alltipcombns = np.array(list(itertools.combinations(range(len(totalseqs[0])), 4)))
alltipcombns = alltipcombns.astype(int)
targetlen = len(alltipcombns)

# set your current combination of four taxa
fourtaxa = alltipcombns[np.random.choice(range(targetlen))]

# For every gene keep one SNP showing that gene's most common pattern plus
# ten SNPs drawn at random (with replacement). The pieces are collected in a
# list and stacked once at the end.
gene_chunks = []
for geneidx in range(len(genes_alltaxa)):
    fullgene = f(genes_alltaxa,geneidx,fourtaxa)
    goodbases = fullgene[exclude(fullgene)]
    if len(goodbases) > 0:
        the_patterns = get_patterns(goodbases)
        indices = np.where((the_patterns == most_freq_pattern(the_patterns)).all(axis=1))[0]
        # pick one snp from the most common pattern
        gene_chunks.append(goodbases[np.random.choice(indices,1)[0]])
        # pick some random snps
        gene_chunks.append(goodbases[np.random.choice(range(len(goodbases)),10)])
reducedgene = np.vstack(gene_chunks) if gene_chunks else np.empty(shape=(0,4))

# make index matrix for each pair of bases. This assigns row / col number for full 16x16 matrix
# order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
indexmat = np.arange(16).reshape(4, 4)

possible_configs = [[0,1,2,3],[0,2,1,3],[0,3,1,2]]
arr0123 = reducedgene.astype(int)

three_possible = []
for q in possible_configs:
    temp_rearrangement = arr0123[:,q]
    # the first two taxa select the row, the last two the column
    rownums = indexmat[temp_rearrangement[:,0], temp_rearrangement[:,1]]
    colnums = indexmat[temp_rearrangement[:,2], temp_rearrangement[:,3]]
    fullmat0123 = np.zeros(shape=(16,16))
    # np.add.at accumulates repeated (row, col) pairs correctly
    np.add.at(fullmat0123, (rownums, colnums), 1)
    # scale by the largest count; leave an all-zero matrix as it is (avoids 0/0)
    peak = fullmat0123.max()
    flat = fullmat0123.flatten()
    three_possible.append(flat/peak if peak > 0 else flat)

print("Four taxa: " + str(fourtaxa))
for config_matrix in three_possible:
    toyplot.matrix(config_matrix.reshape(16,16))
