This notebook compares quartet inference on simulated data by a softmax regression machine learning model to that by SVDquartets.

## Imports

In [36]:
import numpy as np
import h5py
import re
import random
from itertools import compress
import itertools
import math
from operator import itemgetter
import sys
from Bio import Phylo
import tensorflow as tf

## Function
Reads in sequence data generated on a phylogeny, along with the phylogeny. 

Returns a sequence matrix split on t1, t2 | t3, t4, as well as a three-element array of what the real split is in the tree. 

To define the real split on the tree:

[1, 0, 0] = t1, t2 | t3, t4

[0, 1, 0] = t1, t3 | t2, t4

[0, 0, 1] = t1, t4 | t2, t3

In [180]:
def compare_quint_pred_actual(sequencedata, phylogeny):
    # read in data

    fname = sequencedata
    with open(fname) as f:
        sequences = f.readlines()

    # remove whitespace characters like `\n` at the end of each line

    sequences = [x.strip() for x in sequences] 
    sequences.pop(0)

    # get sequences  and identify quintet taxa
    names = [sequences[i][0:10].strip(" ") for i in range(len(sequences))]
    iso_sequences = [sequences[i][10:].strip(" ") for i in range(len(sequences))]
    
    # so we're only testing one possible quartet per tree... Easy to expand this to test every quintet per tree
    interestednames = ["t1","t2","t3","t4"]
    taxa_ids = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in interestednames]]))
    
    #taxa_ids = [3,2,8,9]
    #fourtaxa = [names[i] for i in taxa_ids]

    tempobj = [iso_sequences[i] for i in taxa_ids]

    # eliminate non-snps

    ind_samples = []
    for i in range(len(tempobj[0])):
        currentbase = ([tempobj[q][i] for q in range(len(tempobj))])
        if (len(set(currentbase)) > 1):
            ind_samples.append(currentbase)
    ind_samples_reset = ind_samples

    # separate sequences by fifth taxon

    ind_samples = np.array(ind_samples_reset)
    ind_samples = np.where(ind_samples=='A',0,ind_samples)
    ind_samples = np.where(ind_samples=='C',1,ind_samples)
    ind_samples = np.where(ind_samples=='G',2,ind_samples)
    ind_samples = np.where(ind_samples=='T',3,ind_samples)
    ind_samples = ind_samples.astype(int)

    # get the matrices
    indexmat = np.array(range(16))
    indexmat.shape=(4,4)
    # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = ind_samples
    for i in range(len(arr0123)):
                # get row number 
        rownum = int(indexmat[arr0123[i][0],arr0123[i][1]])
                # get col number
        colnum = int(indexmat[arr0123[i][2],arr0123[i][3]])
        fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1
    #allmats.append(fullmat0123)

    # predict the true quintet

    # compare with actual quintet

    tree = Phylo.read(phylogeny, 'newick')

    tipnames = [names[i] for i in taxa_ids]
    indexing = np.array([[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]])

    alldists = [tree.distance(tipnames[0],tipnames[1]),
                tree.distance(tipnames[0],tipnames[2]),
                tree.distance(tipnames[0],tipnames[3]),
                tree.distance(tipnames[1],tipnames[2]),
                tree.distance(tipnames[1],tipnames[3]),
                tree.distance(tipnames[2],tipnames[3])]

    min_tree_pairs1, min_pair_val1 = min(enumerate(alldists), key=itemgetter(1))
    
    paired_taxa =  [tipnames[i] for i in list(indexing[min_tree_pairs1])] + [tipnames[i] for i in list(set([0,1,2,3]) ^ set(list(indexing[min_tree_pairs1])))]
    quartet_numbers = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in paired_taxa]]))
    
    # is this a 0123, 0213, or 0312?
    correct_config = np.array([(set([taxa_ids[i] for i in [0,1,2,3]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,1,2,3]][2:4]) == set(quartet_numbers[0:2])),
                                (set([taxa_ids[i] for i in [0,2,1,3]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,2,1,3]][2:4]) == set(quartet_numbers[0:2])),
                                (set([taxa_ids[i] for i in [0,3,1,2]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,3,1,2]][2:4]) == set(quartet_numbers[0:2]))]).astype(int)
    
    return(taxa_ids,quartet_numbers,paired_taxa,correct_config,fullmat0123)
    
    
    
    

Now apply the function to all of our tree/sequence combinations, saving the sequence matrices as `images` and the true splits as `labels`:

In [272]:
images = []
labels = []

for i in range(2001)[1:2001]:
    test = compare_quint_pred_actual(sequencedata="tree_seqs/test" + str(i) + ".dat",phylogeny="random_trees/samp" + str(i) + ".phy")
    images.append(test[4].flatten()/max(test[4].flatten()))
    labels.append(test[3])

Now run a very simple (as in, from the tensorflow tutorial) softmax regression model.

In [279]:
x = tf.placeholder(tf.float32, [None, 256])
W = tf.Variable(tf.zeros([256, 3]))
b = tf.Variable(tf.zeros([3]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 3])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

for _ in range(1000):
  batch = np.random.choice(1000, 50)
  batch_xs, batch_ys = np.array([images[i] for i in batch]),np.array([labels[i] for i in batch])
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: images[1001:2000], y_: labels[1001:2000]}))

0.985986


This shows 98.6% successful prediction of quartet arrangements by our simple machine learning model.

This high rate is easy to accomplish because the seq-gen settings are really basic, and we have lots of loci to work with. With real data, we'd want more sophisticated models and would still probably end up with lower rates of success. Model training relies on simulated sequence data, so making the jump to empirical data might be hard. We'd need a way to test robustness of the model to variation in data. 

Many of the branch lengths on the simulated trees used here ended up being very short, so the high rate of success is still a good sign. 

This is also promising because it could be easily extended beyond four taxa.

## SVDquartets inference on same data

The loop below makes a bunch of quartet decisions on the same set of sequences.

In [316]:
chosenindexlist = []
for w in range(1001,2001):
    sequencedata = "tree_seqs/test" + str(w) + ".dat"
    # read in data

    fname = sequencedata
    with open(fname) as f:
        sequences = f.readlines()

    # remove whitespace characters like `\n` at the end of each line

    sequences = [x.strip() for x in sequences] 
    sequences.pop(0)

    # get sequences  and identify quintet taxa
    names = [sequences[i][0:10].strip(" ") for i in range(len(sequences))]
    iso_sequences = [sequences[i][10:].strip(" ") for i in range(len(sequences))]
    
    # so we're only testing one possible quartet per tree... Easy to expand this to test every quintet per tree
    interestednames = ["t1","t2","t3","t4"]
    taxa_ids = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in interestednames]]))
    
    #taxa_ids = [3,2,8,9]
    #fourtaxa = [names[i] for i in taxa_ids]

    tempobj = [iso_sequences[i] for i in taxa_ids]

    # eliminate non-snps

    ind_samples = []
    for i in range(len(tempobj[0])):
        currentbase = ([tempobj[q][i] for q in range(len(tempobj))])
        if (len(set(currentbase)) > 1):
            ind_samples.append(currentbase)
    ind_samples_reset = ind_samples

    # separate sequences by fifth taxon

    ind_samples = np.array(ind_samples_reset)
    ind_samples = np.where(ind_samples=='A',0,ind_samples)
    ind_samples = np.where(ind_samples=='C',1,ind_samples)
    ind_samples = np.where(ind_samples=='G',2,ind_samples)
    ind_samples = np.where(ind_samples=='T',3,ind_samples)
    ind_samples = ind_samples.astype(int)
    
    possible_configs = [[0,1,2,3],[0,2,1,3],[0,3,1,2]]
    # get the matrices
    indexmat = np.array(range(16))
    indexmat.shape=(4,4)
    # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = ind_samples
    for i in range(len(arr0123)):
                # get row number 
        rownum = int(indexmat[arr0123[i][0],arr0123[i][1]])
                # get col number
        colnum = int(indexmat[arr0123[i][2],arr0123[i][3]])
        fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1
 

    fullmat0213 = np.zeros(shape=(16,16))
    arr0213 = ind_samples[:,possible_configs[1]]
    for i in range(len(arr0213)):
        # get row number 
        rownum = int(indexmat[arr0213[i][0:2][0],arr0213[i][0:2][1]])
        # get col number
        colnum = int(indexmat[arr0213[i][2:4][0],arr0213[i][2:4][1]])
        fullmat0213[rownum,colnum] = fullmat0213[rownum,colnum] + 1

    fullmat0312 = np.zeros(shape=(16,16))
    arr0312 = ind_samples[:,possible_configs[2]]
    for i in range(len(arr0312)):
        # get row number 
        rownum = int(indexmat[arr0312[i][0:2][0],arr0312[i][0:2][1]])
        # get col number
        colnum = int(indexmat[arr0312[i][2:4][0],arr0312[i][2:4][1]])
        fullmat0312[rownum,colnum] = fullmat0312[rownum,colnum] + 1
    #score the matrices here
    scores = [math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0123)[1][10:15]))),math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0213)[1][10:15]))),math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0312)[1][10:15])))]
    #choose best scoring matrix
    min_index, min_value = min(enumerate(scores), key=itemgetter(1))
    chosenindex = np.array([0,0,0])
    chosenindex[min_index] = 1
    chosenindexlist.append(chosenindex)

Now we tally up the correctly inferred quartets...

In [323]:
truequarts = [labels[i] for i in range(1000,2000)]
tally = 0
for w in range(len(truequarts)):
    if (sum(truequarts[w] == chosenindexlist[w]) == 3):
        tally = tally + 1

And, finally, get the percent correct under SVDquartets:

In [326]:
tally / 1000.

0.935

So 93.5 percent of quartets inferred by SVDquartets are correct, although I won't rule out the possibility that I'm doing the scoring incorrectly.

The softmax model had the benefit of being tailored specifically to the simulated data and making predictions on data generated under the same model. SVDquartets quartet selection performs pretty well regardless, which we can't yet say for softmax or a similar machine learning model.