## Imports

In [183]:
import h5py
import numpy as np
import pandas as pd
import numba

### Data

In [217]:
## a tetrad result file
tree = "./analysis-tetrad/cli.nhx"
invariants = "./analysis-tetrad/cli.output.h5"
invariants = "/home/deren/local/src/ipyrad/tests/analysis-tetrad/tttt.output.h5"
inp = "/home/deren/local/src/ipyrad/tests/analysis-tetrad/tttt.input.h5"

In [221]:
with h5py.File(inp) as io5:
    print io5['quartets'][idx]

[ 5  8 10 11]


In [288]:
## peek at the invariants matrix database.
## The first matrix is stored by tetrad for each quartet.
## meaning it needs to be rearranged if the ordered taxa ids is not the best
idx = 478
with h5py.File(invariants) as io5:
    print 'quartet:', io5["quartets"][idx]
    print '\nmatrix:'
    print io5["invariants"]["boot0"][idx]
    arr = io5["invariants"]["boot0"][idx]

quartet: [ 5 11  8 10]

matrix:
[[ 0 26 30 31 16  0  0  0 22  0  1  0 17  0  0  2]
 [25  0  0  0 10 53  0  0  0  0  0  0  0  0  0  0]
 [20  1  1  0  0  0  0  0  6  0 63  0  0  0  0  0]
 [20  0  0  1  0  0  0  0  0  0  0  0  6  1  0 57]
 [51  5  1  0  2 15  0  0  0  0  0  0  1  0  0  0]
 [ 1 26  0  0 22  0 34 24  0 22  1  0  0 17  0  1]
 [ 0  0  0  0  0 17  2  0  1  5 50  0  0  0  0  0]
 [ 0  0  0  0  1 21  0  5  0  0  0  0  1  3  0 47]
 [55  0  7  0  1  0  0  0  1  0 19  0  0  0  0  0]
 [ 0  0  0  0  0 56  6  0  0  2 13  0  0  0  0  1]
 [ 4  0 27  0  0  2 24  0 29 26  0 22  0  0 23  0]
 [ 1  0  0  0  0  0  0  0  0  0 22  3  0  1 10 51]
 [49  0  0  3  0  0  0  0  0  0  0  0  2  0  0 15]
 [ 0  0  0  0  0 45  0  9  0  0  0  0  0  2  0 28]
 [ 0  0  0  0  0  0  0  0  1  0 51  9  0  0  1 11]
 [ 3  0  0 14  0  0  0 20  0  0  0 25 24 31 30  0]]


In [290]:
## fill the alternates
mats = np.zeros((3, 16, 16), dtype=np.uint32)
mats[0] = arr
x = np.uint8(0)
for y in np.array([0, 4, 8, 12], dtype=np.uint8):
    for z in np.array([0, 4, 8, 12], dtype=np.uint8):
        mats[1, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4)
        mats[2, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4).T
        x += np.uint8(1)

### plot grid

In [286]:
import toyplot
canvas, table = toyplot.matrix(mats[0], width=300, height=300, margin=10)
table.body.gaps.columns[...] = 1
table.body.gaps.rows[...] = 1

canvas, table = toyplot.matrix(mats[1], width=300, height=300, margin=10)
table.body.gaps.columns[...] = 1
table.body.gaps.rows[...] = 1

canvas, table = toyplot.matrix(mats[2], width=300, height=300, margin=10)
table.body.gaps.columns[...] = 1
table.body.gaps.rows[...] = 1

### Functions

In [323]:
class Hils(object):
    """
    A Class to calculate the Hils statistic given a matrix of invariants.
    """
    def __init__(self, database, boot=0):
        ## open file handles for accessing database
        self._open = True
        self._boot = boot
        self.hdf5 = h5py.File(database)
        self.matrix = self.hdf5["invariants"]
        self.quartets = self.hdf5["quartets"]
        self.nquartets = self.quartets.shape[0]
    
    
    def close_db(self):
        """close the database file"""
        self.hdf5.close()
    
    
    def get_counts_by_idx(self, idx):
        """return site counts for a given index (quartet)"""
        ## get matrix
        mat = self.matrix["boot{}".format(self._boot)][idx, :, :]
        qrt = self.quartets[idx]
        
        ## arrange matrix
        if qrt[1] > qrt[3]:
            mat = alt_mats(mat, 2)
        elif qrt[1] > qrt[2]:
            mat = alt_mats(mat, 1)
            
        ## get counts and format
        df = pd.DataFrame(
            data=count_snps(mat), 
            index=["aabb", "abba", "baba", "aaab"], 
            columns=[idx]).T
        return df
    
    
    def get_h_by_idx(self, idx):
        """
        calculate Hils. This could be numba-fied, but you'd have to work
        with arrays instead of dataframes. This is fine for now.
        """
        ## get site frequencies
        df = self.get_counts_by_idx(idx)
        nsites = df.sum(axis=1).values[0]
        pdf = df/nsites
        pdf.columns = ["p"+i for i in df.columns]
        data = pd.concat([df, pdf], axis=1)
        
        ## choose invariant pattern
        f1 = data.paabb - data.pbaba
        f2 = data.pabba - data.pbaba
        ratio = f1 / f2
        
        ## calculate var, covar
        var_f1 = (1. / nsites) * (
                    data.paabb * (1. - data.paabb) \
                  + data.pbaba * (1. - data.pbaba) \
                  + 2. * data.paabb * data.pbaba)

        var_f2 = (1. / nsites) * (
                    data.pabba * (1. - data.pabba) \
                  + data.pbaba * (1. - data.pbaba) \
                  + 2. * data.pabba * data.pbaba)

        cov_f1_f2 = (1. / nsites) * (
                   -data.paabb * data.pabba \
                  + data.paabb * data.pbaba \
                  + data.pabba * data.pbaba \
                  + data.pbaba * (1. - data.pbaba))

        ## calculate hils
        num = abs(f2 * ratio)
        denom = np.sqrt(var_f2 * (ratio**2) - (2 * cov_f1_f2 * ratio + var_f1))
        H = pd.DataFrame({"Hils":num/denom, "gamma":(f1/f1+f2)}, index=[idx])

        data = pd.concat([df, pdf, H], axis=1)
        return data
    
    
    def run(self):
        """calculate Hils and return table for all idxs in database"""
        stats = pd.concat([self.get_h_by_idx(idx) for idx in xrange(self.nquartets)])
        qrts = ["{},{}|{},{}".format(*i) for i in self.quartets[:]]
        qrts = pd.DataFrame(np.array(qrts), columns=["qrts"])
        return pd.concat([stats, qrts], axis=1)
    
    
@numba.jit(nopython=True)   
def alt_mats(mat, idx):
    """ return alternate rearrangement of matrix"""
    mats = np.zeros((3, 16, 16), dtype=np.uint32)
    mats[0] = arr
    x = np.uint8(0)
    for y in np.array([0, 4, 8, 12], dtype=np.uint8):
        for z in np.array([0, 4, 8, 12], dtype=np.uint8):
            mats[1, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4)
            mats[2, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4).T
            x += np.uint8(1)
    return mats[idx]
        
        
@numba.jit(nopython=True)
def count_snps(mat):
    """JIT func to return counts quickly"""
    ## array to store results
    snps = np.zeros(4, dtype=np.uint16)

    ## get concordant (aabb) pis sites
    snps[0] = np.uint16(\
           mat[0, 5] + mat[0, 10] + mat[0, 15] + \
           mat[5, 0] + mat[5, 10] + mat[5, 15] + \
           mat[10, 0] + mat[10, 5] + mat[10, 15] + \
           mat[15, 0] + mat[15, 5] + mat[15, 10])

    ## get discordant (baba) sites
    for i in range(16):
        if i % 5:
            snps[1] += mat[i, i]

    ## get discordant (abba) sites
    snps[2] = mat[1, 4] + mat[2, 8] + mat[3, 12] +\
              mat[4, 1] + mat[6, 9] + mat[7, 13] +\
              mat[8, 2] + mat[9, 6] + mat[11, 14] +\
              mat[12, 3] + mat[13, 7] + mat[14, 11]

    ## get autapomorphy sites
    snps[3] = (mat.sum() - np.diag(mat).sum()) - snps[2]
    return snps


### Calculating Hils

In [325]:
## initialize the database object
hils = Hils(invariants)

In [326]:
## calculate for some index of the database
hils.get_h_by_idx(10)

Unnamed: 0,aabb,abba,baba,aaab,paabb,pabba,pbaba,paaab,Hils,gamma
10,171,2,6,1292,0.116247,0.00136,0.004079,0.878314,1.396125,0.997281


In [328]:
## calculate for all idxs
dd = hils.run()

## print first 10 results
dd.head(10)

Unnamed: 0,aabb,abba,baba,aaab,paabb,pabba,pbaba,paaab,Hils,gamma,qrts
0,103,15,10,1020,0.089721,0.013066,0.008711,0.888502,1.028327,1.004355,"0,1|2,3"
1,109,13,11,1200,0.08177,0.009752,0.008252,0.900225,0.412507,1.0015,"0,1|2,4"
2,111,12,11,1194,0.083584,0.009036,0.008283,0.899096,0.209568,1.000753,"0,1|2,5"
3,109,13,11,1194,0.08214,0.009797,0.008289,0.899774,0.412506,1.001507,"0,1|2,6"
4,106,15,11,1196,0.079819,0.011295,0.008283,0.900602,0.80176,1.003012,"0,1|2,7"
5,106,15,10,1338,0.072158,0.010211,0.006807,0.910824,1.027544,1.003404,"0,1|2,8"
6,106,15,10,1350,0.071573,0.010128,0.006752,0.911546,1.02755,1.003376,"0,1|2,9"
7,106,11,11,1360,0.071237,0.007392,0.007392,0.913978,,1.0,"0,1|2,10"
8,106,14,10,1382,0.070106,0.009259,0.006614,0.914021,0.834304,1.002646,"0,1|2,11"
9,172,3,6,1296,0.116452,0.002031,0.004062,0.877454,0.990638,0.997969,"0,1|3,4"


In [321]:
## print just the ten most significant results
dd.sort_values(by="Hils", ascending=False).head(10)

Unnamed: 0,aabb,abba,baba,aaab,paabb,pabba,pbaba,paaab,Hils,gamma,qrts
181,245,0,8,1372,0.150769,0.0,0.004923,0.844308,2.771739,0.995077,"1,2|5,7"
292,249,7,0,1331,0.1569,0.004411,0.0,0.838689,2.671646,1.004411,"2,3|5,6"
494,97,9,26,1055,0.081719,0.007582,0.021904,0.888795,2.65865,0.985678,"8,9|10,11"
214,170,1,10,1441,0.104809,0.000617,0.006165,0.888409,2.63621,0.994451,"1,3|6,7"
94,166,0,7,1443,0.102723,0.0,0.004332,0.892946,2.583451,0.995668,"0,3|6,7"
183,152,1,9,1614,0.085586,0.000563,0.005068,0.908784,2.460063,0.995495,"1,2|5,9"
182,151,1,9,1604,0.085552,0.000567,0.005099,0.908782,2.459645,0.995467,"1,2|5,8"
175,244,0,6,1380,0.149693,0.0,0.003681,0.846626,2.412031,0.996319,"1,2|4,7"
203,177,0,6,1450,0.108389,0.0,0.003674,0.887936,2.401952,0.996326,"1,3|4,7"
166,83,19,8,1339,0.057281,0.013112,0.005521,0.924086,2.300529,1.007591,"1,2|3,5"


### Plot results

In [331]:
import toyplot

## distribution of Hils across all quartet edges in dataset
c = toyplot.Canvas(width=350)
a = c.cartesian(xlabel="Hils statistics", ylabel="Frequency")
m = a.bars(
    np.histogram(
        #dd.gamma[~dd.gamma.isnull()],
        dd.Hils[~dd.Hils.isnull()],
        density=True,
        ),
    )

## style axes
a.x.ticks.show = True
a.y.ticks.show = True

In [336]:
import toytree
tre = toytree.tree(tree)
tre.root(wildcard="1")

## convert names back into indexes
snames = sorted(tre.get_tip_labels())
for node in tre.tree.traverse():
    if node.is_leaf():
        node.name = snames.index(node.name)
    
## plot tree
c, a = tre.draw(width=300)