This notebook compares quartet inference on simulated data by a softmax regression machine learning model to that by SVDquartets.

## Imports

In [2]:
import numpy as np
import h5py
import re
import random
from itertools import compress
import itertools
import math
from operator import itemgetter
import sys
from Bio import Phylo
import tensorflow as tf

## Function
Reads in sequence data generated on a phylogeny, along with the phylogeny. 

Returns a sequence matrix split on t1, t2 | t3, t4, as well as a three-element array of what the real split is in the tree. 

To define the real split on the tree:

[1, 0, 0] = t1, t2 | t3, t4

[0, 1, 0] = t1, t3 | t2, t4

[0, 0, 1] = t1, t4 | t2, t3

In [61]:
def compare_quint_pred_actual(sequencedata, phylogeny,tipnames):
    # read in data

    fname = sequencedata
    with open(fname) as f:
        sequences = f.readlines()

    # remove whitespace characters like `\n` at the end of each line

    sequences = [x.strip() for x in sequences] 
    sequences.pop(0)

    # get sequences  and identify quintet taxa
    names = [sequences[i][0:10].strip(" ") for i in range(len(sequences))]
    iso_sequences = [sequences[i][10:].strip(" ") for i in range(len(sequences))]
    
    # so we're only testing one possible quartet per tree... Easy to expand this to test every quintet per tree
    interestednames = tipnames # this should be a list of four tip names... e.g. ["t1","t2","t3","t4]
    taxa_ids = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in interestednames]]))
    
    #taxa_ids = [3,2,8,9]
    #fourtaxa = [names[i] for i in taxa_ids]

    tempobj = [iso_sequences[i] for i in taxa_ids]

    # eliminate non-snps

    ind_samples = []
    for i in range(len(tempobj[0])):
        currentbase = ([tempobj[q][i] for q in range(len(tempobj))])
        if (len(set(currentbase)) > 1):
            ind_samples.append(currentbase)
    ind_samples_reset = ind_samples

    # separate sequences by fifth taxon

    ind_samples = np.array(ind_samples_reset)
    ind_samples = np.where(ind_samples=='A',0,ind_samples)
    ind_samples = np.where(ind_samples=='C',1,ind_samples)
    ind_samples = np.where(ind_samples=='G',2,ind_samples)
    ind_samples = np.where(ind_samples=='T',3,ind_samples)
    ind_samples = ind_samples.astype(int)

    # get the matrices
    indexmat = np.array(range(16))
    indexmat.shape=(4,4)
    # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = ind_samples
    for i in range(len(arr0123)):
                # get row number 
        rownum = int(indexmat[arr0123[i][0],arr0123[i][1]])
                # get col number
        colnum = int(indexmat[arr0123[i][2],arr0123[i][3]])
        fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1
    #allmats.append(fullmat0123)

    # predict the true quintet

    # compare with actual quintet

    tree = Phylo.read(phylogeny, 'newick')

    tipnames = [names[i] for i in taxa_ids]
    indexing = np.array([[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]])

    alldists = [tree.distance(tipnames[0],tipnames[1]),
                tree.distance(tipnames[0],tipnames[2]),
                tree.distance(tipnames[0],tipnames[3]),
                tree.distance(tipnames[1],tipnames[2]),
                tree.distance(tipnames[1],tipnames[3]),
                tree.distance(tipnames[2],tipnames[3])]

    min_tree_pairs1, min_pair_val1 = min(enumerate(alldists), key=itemgetter(1))
    
    paired_taxa =  [tipnames[i] for i in list(indexing[min_tree_pairs1])] + [tipnames[i] for i in list(set([0,1,2,3]) ^ set(list(indexing[min_tree_pairs1])))]
    quartet_numbers = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in paired_taxa]]))
    
    # is this a 0123, 0213, or 0312?
    correct_config = np.array([(set([taxa_ids[i] for i in [0,1,2,3]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,1,2,3]][2:4]) == set(quartet_numbers[0:2])),
                                (set([taxa_ids[i] for i in [0,2,1,3]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,2,1,3]][2:4]) == set(quartet_numbers[0:2])),
                                (set([taxa_ids[i] for i in [0,3,1,2]][2:4]) == set(quartet_numbers[2:4]) or 
                                    set([taxa_ids[i] for i in [0,3,1,2]][2:4]) == set(quartet_numbers[0:2]))]).astype(int)
    
    return(taxa_ids,quartet_numbers,paired_taxa,correct_config,fullmat0123)
    
    
    
    

Now apply the function to all of our tree/sequence combinations, saving the sequence matrices as `images` and the true splits as `labels`:

In [62]:
images = []
labels = []

for i in range(2001)[1:2001]:
    test = compare_quint_pred_actual(sequencedata="tree_seqs/test" + str(i) + ".dat",phylogeny="random_trees/samp" + str(i) + ".phy",tipnames=["t1","t2","t3","t4"])
    images.append(test[4].flatten()/max(test[4].flatten()))
    labels.append(test[3])

Now run a very simple (as in, from the tensorflow tutorial) softmax regression model.

In [63]:
x = tf.placeholder(tf.float32, [None, 256])
W = tf.Variable(tf.zeros([256, 3]))
b = tf.Variable(tf.zeros([3]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 3])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

for _ in range(1000):
  batch = np.random.choice(1000, 50)
  batch_xs, batch_ys = np.array([images[i] for i in batch]),np.array([labels[i] for i in batch])
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: images[1001:2000], y_: labels[1001:2000]}))

0.985986


In [136]:
predictions = sess.run(y, feed_dict={x: [images[7]]})
int(tf.argmax(predictions,1).eval())

0

This shows 98.6% successful prediction of quartet arrangements by our simple machine learning model.

This high rate is easy to accomplish because the seq-gen settings are really basic, and we have lots of loci to work with. With real data, we'd want more sophisticated models and would still probably end up with lower rates of success. Model training relies on simulated sequence data, so making the jump to empirical data might be hard. We'd need a way to test robustness of the model to variation in data. 

Many of the branch lengths on the simulated trees used here ended up being very short, so the high rate of success is still a good sign. 

This is also promising because it could be easily extended beyond four taxa.

## SVDquartets inference on same data

The loop below makes a bunch of quartet decisions on the same set of sequences.

In [316]:
chosenindexlist = []
for w in range(1001,2001):
    sequencedata = "tree_seqs/test" + str(w) + ".dat"
    # read in data

    fname = sequencedata
    with open(fname) as f:
        sequences = f.readlines()

    # remove whitespace characters like `\n` at the end of each line

    sequences = [x.strip() for x in sequences] 
    sequences.pop(0)

    # get sequences  and identify quintet taxa
    names = [sequences[i][0:10].strip(" ") for i in range(len(sequences))]
    iso_sequences = [sequences[i][10:].strip(" ") for i in range(len(sequences))]
    
    # so we're only testing one possible quartet per tree... Easy to expand this to test every quintet per tree
    interestednames = ["t1","t2","t3","t4"]
    taxa_ids = list(itertools.chain.from_iterable([list(compress(range(10),i)) for i in [[q == i for i in names] for q in interestednames]]))
    
    #taxa_ids = [3,2,8,9]
    #fourtaxa = [names[i] for i in taxa_ids]

    tempobj = [iso_sequences[i] for i in taxa_ids]

    # eliminate non-snps

    ind_samples = []
    for i in range(len(tempobj[0])):
        currentbase = ([tempobj[q][i] for q in range(len(tempobj))])
        if (len(set(currentbase)) > 1):
            ind_samples.append(currentbase)
    ind_samples_reset = ind_samples

    # separate sequences by fifth taxon

    ind_samples = np.array(ind_samples_reset)
    ind_samples = np.where(ind_samples=='A',0,ind_samples)
    ind_samples = np.where(ind_samples=='C',1,ind_samples)
    ind_samples = np.where(ind_samples=='G',2,ind_samples)
    ind_samples = np.where(ind_samples=='T',3,ind_samples)
    ind_samples = ind_samples.astype(int)
    
    possible_configs = [[0,1,2,3],[0,2,1,3],[0,3,1,2]]
    # get the matrices
    indexmat = np.array(range(16))
    indexmat.shape=(4,4)
    # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = ind_samples
    for i in range(len(arr0123)):
                # get row number 
        rownum = int(indexmat[arr0123[i][0],arr0123[i][1]])
                # get col number
        colnum = int(indexmat[arr0123[i][2],arr0123[i][3]])
        fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1
 

    fullmat0213 = np.zeros(shape=(16,16))
    arr0213 = ind_samples[:,possible_configs[1]]
    for i in range(len(arr0213)):
        # get row number 
        rownum = int(indexmat[arr0213[i][0:2][0],arr0213[i][0:2][1]])
        # get col number
        colnum = int(indexmat[arr0213[i][2:4][0],arr0213[i][2:4][1]])
        fullmat0213[rownum,colnum] = fullmat0213[rownum,colnum] + 1

    fullmat0312 = np.zeros(shape=(16,16))
    arr0312 = ind_samples[:,possible_configs[2]]
    for i in range(len(arr0312)):
        # get row number 
        rownum = int(indexmat[arr0312[i][0:2][0],arr0312[i][0:2][1]])
        # get col number
        colnum = int(indexmat[arr0312[i][2:4][0],arr0312[i][2:4][1]])
        fullmat0312[rownum,colnum] = fullmat0312[rownum,colnum] + 1
    #score the matrices here
    scores = [math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0123)[1][10:15]))),math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0213)[1][10:15]))),math.sqrt(np.sum(np.square(np.linalg.svd(fullmat0312)[1][10:15])))]
    #choose best scoring matrix
    min_index, min_value = min(enumerate(scores), key=itemgetter(1))
    chosenindex = np.array([0,0,0])
    chosenindex[min_index] = 1
    chosenindexlist.append(chosenindex)

Now we tally up the correctly inferred quartets...

In [323]:
truequarts = [labels[i] for i in range(1000,2000)]
tally = 0
for w in range(len(truequarts)):
    if (sum(truequarts[w] == chosenindexlist[w]) == 3):
        tally = tally + 1

And, finally, get the percent correct under SVDquartets:

In [326]:
tally / 1000.

0.935

So 93.5 percent of quartets inferred by SVDquartets are correct, although I won't rule out the possibility that I'm doing the scoring incorrectly.

The softmax model had the benefit of being tailored specifically to the simulated data and making predictions on data generated under the same model. SVDquartets quartet selection performs pretty well regardless, which we can't yet say for softmax or a similar machine learning model.

## Next steps:

*  Mix together models of sequence evolution for training set, perform inference on mixed simulations.
*  Compare success of different types of trainings on empirical inference (even showing consistency would go a long way).
*  Improve machine learning model past single layer.

In [6]:
import itertools
import os
import ipyrad as ip
import subprocess
from ipyrad.assemble.util import IPyradWarningExit, progressbar, Params
from toytree import ete3mini as ete3
import re

In [104]:
alltips = ["t1","t2","t3","t4","t5","t6","t7","t8","t9","t10"]
alltipcombns=list(itertools.combinations(alltips, 4))

In [105]:
i=100
all_mats = [compare_quint_pred_actual(sequencedata="tree_seqs/test" + str(i) + ".dat",phylogeny="random_trees/samp" + str(i) + ".phy",tipnames=q)[4] for q in alltipcombns]



In [117]:
chosenquarts = []
for i in all_mats:
    predictions = sess.run(y, feed_dict={x: [i.flatten()/max(i.flatten())]})
    chosenquarts.append(int(tf.argmax(predictions,1).eval()))


In [220]:
correctquarts = [[alltipcombns[q][i] for i in [[0,1,2,3],[0,2,1,3],[0,3,1,2]][chosenquarts[q]]] for q in range(len(chosenquarts))]
names = ["t1","t2","t3","t4","t5","t6","t7","t8","t9","t10"]
ids = range(len(names))


In [215]:
correctquarts= np.array(correctquarts)

In [233]:
for q in range(len(names)):
    correctquarts = [[re.sub(r'\b'+ names[q] +r'\b', str(ids[q]), correctquarts[w][i]) for i in range(4)] for w in range(len(correctquarts))]


In [23]:
def dump_qmc(quartets,tempfiledir):
    """
    Writes the inferred quartet sets from the database to a text 
    file to be used as input for QMC. Quartets that had no information
    available (i.e., no SNPs) were written to the database as 0,0,0,0
    and are excluded here from the output.
    """

    ## open the h5 database
    #with h5py.File(self.database.output, 'r') as io5:

        ## create an output file for writing
    tempfile = os.path.join(tempfiledir,"quartets.txt")
    with open(tempfile, 'w') as qdump:

        ## pull from db
        #for idx in xrange(0, self.params.nquartets, self._chunksize):
            #quarts = quartets

            ## shuffle and format for qmc
            #np.random.shuffle(quarts)
            chunk = ["{},{}|{},{}".format(*i) for i in quartets]
            qdump.write("\n".join(chunk)+"\n")


def _run_qmc(tempfiledir, tempfilename,treename,tipnames):
    """
    Runs quartet max-cut QMC on the quartets qdump file.
    """

    ## build command
    thetmptree = os.path.join(tempfiledir, "tmptre.phy")
    cmd = [ip.bins.qmc, "qrtt="+tempfilename, "otre="+thetmptree]

    ## run it
    proc = subprocess.Popen(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
    res = proc.communicate()
    #if proc.returncode:
    #    print(proc.returncode)
    #    raise IPyradWarningExit(res[1])

    ## parse tmp file written by qmc into a tree and rename it
    with open(thetmptree, 'r') as intree:
        tre = ete3.Tree(intree.read().strip())
        names = tre.get_leaves()
        for name in names:
            name.name = tipnames[int(name.name)]
        tmptre = tre.write(format=9)

    ## save the tree to file
    #if boot:
    #    self.trees.boots = os.path.join(self.dirs, self.name+".boots")
    #    with open(self.trees.boots, 'a') as outboot:
    #        outboot.write(tmptre+"\n")
    #else:
    treepath  = os.path.join(tempfiledir, treename+".tree")
    with open(treepath, 'w') as outtree:
        outtree.write(tmptre)

    ## save the file
    #treepath._save()

In [236]:
dump_qmc(correctquarts,"")


In [243]:
_run_qmc(tempfiledir="",tempfilename = "quartets.txt",treename="mytree",tipnames=names)

## Mammals dataset

In [7]:
import numpy as np
from pathlib2 import Path

### Name your current quartet

In [52]:
fourtaxa = [0,1,2,3]

### Get all independent snps for quartet

In [99]:
genesnps = np.array([]).reshape(0,4)
for gene in range(1,447):
    fname = "download_simseqs/song-mammalian-bio_completely_processed/424genes/relabeled_data/"+ str(gene) +".fasta_relabeled.phy"
    if Path(fname).is_file():
        raw = open(fname, 'r')
        snps = file.read(raw) 
        snps = snps.split('\n')
        # remove whitespace characters like `\n` at the end of each line
        snps = [x.strip() for x in snps] 
        snps.pop(0)
        [snps.pop(i) for i in range(len(snps)) if not len(snps[i])]

        snps = [snps[i].split(" ") for i in range(len(snps))]
        snps = [filter(None, snps[i]) for i in range(len(snps))]

        ids = [snps[i][0] for i in range(len(snps))]
        sequences = [snps[i][1] for i in range(len(snps))]
        fourfullseqs = [sequences[i] for i in fourtaxa]
        snpseqs = np.array([]).reshape(0,4)
        for q in range(len(fourfullseqs[0])):
            current4bases=[fourfullseqs[i][q] for i in range(4)]
            if ((len(set(current4bases).union(set(['A','G','C','T']))) == 4) and (len(set(current4bases)) > 1)):
                snpseqs = np.vstack([snpseqs, current4bases])
        if len(snpseqs):
            genesnps = np.vstack([genesnps,snpseqs[np.random.choice(len(snpseqs))]])

### Make the quartet matrix

In [104]:
snps = np.array(genesnps)
possible_configs = [0,1,2,3]
snps = np.where(snps=='A',0,snps)
snps = np.where(snps=='C',1,snps)
snps = np.where(snps=='G',2,snps)
snps = np.where(snps=='T',3,snps)
snps = snps.astype(int)
finalsnps = snps

# make index matrix for each pair of bases. This assigns row / col number for full 16x16 matrix
indexmat = np.array(range(16))
indexmat.shape=(4,4)

        # make 16x16 matrix of zeroes
        # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
        # not good use of space
fullmat0123 = np.zeros(shape=(16,16))
arr0123 = finalsnps[:,possible_configs]
for i in range(len(arr0123)):
            # get row number 
    rownum = int(indexmat[arr0123[i][0:2][0],arr0123[i][0:2][1]])
            # get col number
    colnum = int(indexmat[arr0123[i][2:4][0],arr0123[i][2:4][1]])
    fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1

### Predict the correct quartet configuration

In [116]:
prediction = sess.run(y, feed_dict={x: [fullmat0123.flatten()/max(fullmat0123.flatten())]})

In [120]:
[fourtaxa[i] for i in [[0,1,2,3],[0,2,1,3],[0,3,1,2]][int(tf.argmax(prediction,1).eval())]]

[0, 2, 1, 3]

### Now make a prediction for all quartets!

In [235]:
alltipcombns=list(itertools.combinations(range(37), 4))

In [236]:
len(alltipcombns) # this is a lot. But we're tough.

66045

In [59]:
alltipcombns = alltipcombns.astype(int)

In [237]:
random.shuffle(alltipcombns) # maybe just in case we don't get all the way through.

In [83]:
allpredictedquarts = np.array([]).reshape(0,4)
for currentcombn in [alltipcombns[allcom] for allcom in range(18100,66045)]:
    fourtaxa = np.array(currentcombn).astype(int)
    genesnps = np.array([]).reshape(0,4)
    for gene in range(1,447):
        fname = "download_simseqs/song-mammalian-bio_completely_processed/424genes/relabeled_data/"+ str(gene) +".fasta_relabeled.phy"
        if Path(fname).is_file():
            raw = open(fname, 'r')
            snps = file.read(raw) 
            snps = snps.split('\n')
            # remove whitespace characters like `\n` at the end of each line
            snps = [xs.strip() for xs in snps] 
            snps.pop(0)
            [snps.pop(i) for i in range(len(snps)) if not len(snps[i])]

            snps = [snps[i].split(" ") for i in range(len(snps))]
            snps = [filter(None, snps[i]) for i in range(len(snps))]

            ids = [snps[i][0] for i in range(len(snps))]
            sequences = [snps[i][1] for i in range(len(snps))]
            fourfullseqs = [sequences[i] for i in fourtaxa]
            snpseqs = np.array([]).reshape(0,4)
            for q in range(len(fourfullseqs[0])):
                current4bases=[fourfullseqs[i][q] for i in range(4)]
                if ((len(set(current4bases).union(set(['A','G','C','T']))) == 4) and (len(set(current4bases)) > 1)):
                    snpseqs = np.vstack([snpseqs, current4bases])
            if len(snpseqs):
                genesnps = np.vstack([genesnps,snpseqs[np.random.choice(len(snpseqs))]])

    snps = np.array(genesnps)
    possible_configs = [0,1,2,3]
    snps = np.where(snps=='A',0,snps)
    snps = np.where(snps=='C',1,snps)
    snps = np.where(snps=='G',2,snps)
    snps = np.where(snps=='T',3,snps)
    snps = snps.astype(int)
    finalsnps = snps

    # make index matrix for each pair of bases. This assigns row / col number for full 16x16 matrix
    indexmat = np.array(range(16))
    indexmat.shape=(4,4)

            # make 16x16 matrix of zeroes
            # order across matrix is 00,01,02,03,10,11,12,13,20,21,22,23,30,31,32,33
            # not good use of space
    fullmat0123 = np.zeros(shape=(16,16))
    arr0123 = finalsnps[:,possible_configs]
    for i in range(len(arr0123)):
                # get row number 
        rownum = int(indexmat[arr0123[i][0:2][0],arr0123[i][0:2][1]])
                # get col number
        colnum = int(indexmat[arr0123[i][2:4][0],arr0123[i][2:4][1]])
        fullmat0123[rownum,colnum] = fullmat0123[rownum,colnum] + 1
    prediction = sess.run(y, feed_dict={x: [(fullmat0123.flatten()/max(fullmat0123.flatten()))]})
    allpredictedquarts = np.vstack([allpredictedquarts,[fourtaxa[i] for i in [[0,1,2,3],[0,2,1,3],[0,3,1,2]][int(tf.argmax(prediction,1).eval())]]])
    print(len(allpredictedquarts))
    

TypeError: Cannot interpret feed_dict key as Tensor: The name 'Orn\t37' looks like an (invalid) Operation name, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".

In [73]:
len(allpredictedquarts)

3870

In [12]:
alltipcombns[4830]

array([ 14.,  29.,  33.,  34.])

In [322]:
range(4830,66045)

66044

In [18]:
len(allpredictedquarts)

2200

In [77]:
test = np.loadtxt("download_simseqs/mammal_quarts.gz")

In [78]:
len(test)

18100

In [55]:
alltipcombns = np.loadtxt("download_simseqs/combn_order.gz")

In [76]:
#np.savetxt("download_simseqs/mammal_quarts.gz",np.vstack([test,allpredictedquarts]))
#np.savetxt("download_simseqs/combn_order.gz",alltipcombns)

In [79]:
with open("download_simseqs/song-mammalian-bio_completely_processed/taxa_dict.txt") as f:
    test = f.readlines()
test = [x.strip() for x in test]
nameskey = [test[i].split("\t") for i in range(len(test))]
[i[0] for i in nameskey] # this gives just the names

['Mac',
 'New',
 'Sor',
 'Gor',
 'Oto',
 'Spe',
 'Ory',
 'Tup',
 'Dip',
 'Tur',
 'Mic',
 'Eri',
 'Och',
 'Lox',
 'Fel',
 'Tar',
 'Pro',
 'Ech',
 'Das',
 'Myo',
 'Mus',
 'Rat',
 'Cav',
 'Cho',
 'Bos',
 'Cal',
 'Pon',
 'Hom',
 'Pan',
 'Sus',
 'Vic',
 'Can',
 'Pte',
 'Equ',
 'Gal',
 'Mon',
 'Orn']

In [80]:
allpredictedquarts = np.loadtxt("download_simseqs/mammal_quarts.gz")
allpredictedquarts = allpredictedquarts.astype(int)

In [81]:
dump_qmc(quartets = allpredictedquarts,tempfiledir= "download_simseqs/")

In [82]:
_run_qmc(tempfiledir = "download_simseqs/", 
         tempfilename="download_simseqs/quartets.txt",
         treename="tree4830.phy",
         tipnames=[i[0] for i in nameskey])