### Sandbox document for getting Hils results via Tetrad

## Imports

In [44]:
import math

import h5py
import ipyparallel as ipp
import ipyrad as ip
import ipyrad.analysis as ipa
import numba
import numpy as np
import pandas as pd
import scipy.stats as st
import toytree
## before connecting a client, start the engines from a shell: ipcluster start -n20

In [None]:
## conda install ipyrad -c ipyrad
## conda install toytree -c eaton-lab

In [2]:
## up-to-date versions
## str.format + print(...) emits the same "label value" text on Python 2 and 3,
## whereas the bare print statement is a SyntaxError on Python 3.
print('ip {}'.format(ipa.__version__))
print('toytree {}'.format(toytree.__version__))

ip 0.7.14
toytree 0.1.4


In [7]:
## create the Assembly object that the parameters below are applied to
data = ip.Assembly("pedicularis")

## set parameters (applied in the listed order)
assembly_params = [
    ("project_dir", "analysis-ipyrad"),
    ("sorted_fastq_path", "fastqs-Ped/*.fastq.gz"),
    ("clust_threshold", "0.90"),
    ("filter_adapters", "2"),
    ("max_Hs_consens", (5, 5)),
    ("trim_loci", (0, 5, 0, 0)),
    ("output_formats", "psvnkua"),
]
for param_name, param_value in assembly_params:
    data.set_params(param_name, param_value)

## see/print all parameters
data.get_params()

## run steps 1 & 2 of the assembly, then steps 3-6, as two separate calls
for step_group in ("12", "3456"):
    data.run(step_group)


New Assembly: pedicularis
0   assembly_name               pedicularis                                  
1   project_dir                 ./analysis-ipyrad                            
2   raw_fastq_path                                                           
3   barcodes_path                                                            
4   sorted_fastq_path           ./fastqs-Ped/*.fastq.gz                      
5   assembly_method             denovo                                       
6   reference_sequence                                                       
7   datatype                    rad                                          
8   restriction_overhang        ('TGCAG', '')                                
9   max_low_qual_bases          5                                            
10  phred_Qscore_offset         33                                           
11  mindepth_statistical        6                                            
12  mindepth_majrule            6     

In [None]:
## branch the assembly and split its samples on the "prz" substring
pops = data.branch("min11-pops")
ingroup_names = [name for name in pops.samples if "prz" not in name]
outgroup_names = [name for name in pops.samples if "prz" in name]
pops.populations = {
    "ingroup": (11, ingroup_names),
    "outgroup": (0, outgroup_names),
}
pops.run("7")

## create a branch with no missing data and with outgroups removed
## (the sample list is rebuilt here, after pops.run, exactly as before)
kept_names = [name for name in pops.samples if "prz" not in name]
nouts = data.branch("nouts_min11", subsamples=kept_names)
nouts.set_params("min_samples_locus", 11)
nouts.run("7")

In [8]:
## load a saved Assembly from its JSON file; this rebinds `data`, replacing
## the Assembly object created in the earlier cell.
## NOTE(review): absolute, user-specific path -- it will not resolve on
## another machine; consider a path relative to the project directory.
data = ip.load_json("/Users/pmckenz1/Desktop/projects/quartet_proj/analysis-ipyrad/min4.json")

loading Assembly: min4
from saved path: ~/Desktop/projects/intro_python/analysis-ipyrad/min4.json


In [9]:
## init a tetrad analysis object
tet = ipa.tetrad(
    name=data.name,
    data=data.outfiles.snpsphy,
    mapfile=data.outfiles.snpsmap,
    nboots=10,
    save_invariants=True   ## <- new option to save the arrays
    )

loading seq array [13 taxa x 173131 bp]
max unlinked SNPs per quartet (nloci): 39634


In [10]:
## create an ipyparallel client; presumably this needs an ipcluster that is
## already running -- confirm before a fresh run
ipyclient = ipp.Client()
## display the ids reported by the client
ipyclient.ids

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [11]:
## run the tetrad analysis, handing it the ipyparallel client created above
tet.run(ipyclient)

inferring 715 quartet tree sets
host compute node: [20 cores] on Patricks-MacBook-Pro.local
[####################] 100% generating q-sets | 0:00:00 |  
[####################] 100% initial tree      | 0:00:01 |  
[####################] 100% bootstrap trees   | 0:00:11 |  
[####################] 100% calculating stats | 0:00:01 |  


In [16]:
## a 16x16 matrix for one quartet
## The file is opened read-only because this cell only inspects the saved
## arrays. The dangling, body-less `with` header that used to end this cell
## was a SyntaxError and has been removed.
with h5py.File(tet.database.output, 'r') as db:
    idx = 0
    qrt = db['quartets'][idx]
    arr = db['invariants/boot0']
    ## str.format keeps the printed text identical on Python 2 and 3
    print('inferred quartet: {}'.format(qrt))
    print('matrix for ordered set:\n{}'.format(arr[idx, :, :]))

inferred quartet: [0 2 1 3]
matrix for ordered set:
[[  0 145 557 251  16   3   0   1  60   1  12   0  25   0   0   0]
 [ 42  26   1   2   3  42   0   1   0   0   1   0   0   1   0   0]
 [155   1  75   2   0   0   0   0  12   0 106   1   0   0   1   0]
 [ 65   0   0  15   0   0   0   1   0   0   0   0   1   0   1  30]
 [ 16   0   0   1   8  68   1   3   0   0   0   0   0   0   0   0]
 [  2  35   1   2 157   0 112 481   0  21   3   2   2 111   1   8]
 [  0   0   0   0   0  46  13   2   0   1  11   0   0   0   0   0]
 [  0   1   0   0   4 201   7  70   0   0   0   0   0   9   1  68]
 [ 49   0  10   2   1   0   0   0  55   3 189   2   0   0   0   0]
 [  0   0   0   0   0  24   0   1   2  11  32   0   0   0   0   0]
 [  9   1  99   1   0   1  15   0 496 118   0 129   0   1  39   0]
 [  0   0   2   0   0   0   0   0   3   3  71  11   0   1   2  17]
 [ 42   0   0   3   0   3   0   0   0   0   0   0  26   1   3  51]
 [  0   0   0   0   0 127   0  10   0   0   0   0   1  56   2 113]
 [  0   0 

In [19]:
## open the tetrad output database read-only; later cells index into `f`
## NOTE(review): no f.close() appears in the visible cells -- close the
## handle (or use a `with` block) once the exploration is finished.
f = h5py.File(tet.database.output, 'r')

In [158]:
## first 16x16 invariants matrix stored under 'boot0'
arr = f['invariants']['boot0'][0]

## slot 0 holds the matrix as stored; slots 1 and 2 hold rearrangements
mats = np.zeros((3, 16, 16), dtype=np.uint32)
mats[0] = arr

## slot 1: row k of the stored matrix, reshaped to 4x4, fills the k-th
## 4x4 tile (tiles counted left to right, top to bottom)
for k in range(16):
    row0 = 4 * (k // 4)
    col0 = 4 * (k % 4)
    mats[1, row0:row0 + 4, col0:col0 + 4] = mats[0, k].reshape(4, 4)

## slot 2: columns of the stored matrix taken in the order
## 0, 4, 8, 12, 1, 5, 9, 13, ... (the transposed-tile variant is not used here)
for new_col in range(16):
    old_col = 4 * (new_col % 4) + new_col // 4
    mats[2, :, new_col] = mats[0, :, old_col]

[calcHils(mats[0]), calcHils(mats[1]), calcHils(mats[2])]

['Parental taxa are more closely related than hybrid. Discard this.',
 '0.401917492623',
 '0.397271047622']

In [None]:
class Hils(object):
    """
    A Class to calculate the Hils statistic given a matrix of invariants.

    Parameters
    ----------
    database : str
        Path of the HDF5 file to read. It must contain an "invariants"
        group holding arrays named "boot{N}" and a "quartets" array.
    boot : int
        Selects which "boot{N}" array is read (default 0).
    """
    def __init__(self, database, boot=0):
        ## open file handles for accessing database. Read-only: nothing in
        ## this class writes to the file, and an explicit mode does not
        ## depend on the default of the installed h5py version.
        self._open = True
        self._boot = boot
        self.hdf5 = h5py.File(database, "r")
        self.matrix = self.hdf5["invariants"]
        self.quartets = self.hdf5["quartets"]
        self.nquartets = self.quartets.shape[0]


    def close_db(self):
        """close the database file and record that it is closed"""
        if self._open:
            self.hdf5.close()
            self._open = False


    def get_counts_by_idx(self, idx):
        """
        return site counts for a given index (quartet) as a one-row
        DataFrame indexed by `idx` with columns aabb, abba, baba, aaab.
        """
        ## get matrix
        mat = self.matrix["boot{}".format(self._boot)][idx, :, :]
        qrt = self.quartets[idx]

        ## arrange matrix: pick an alternate arrangement depending on how
        ## the stored quartet's members are ordered
        if qrt[1] > qrt[3]:
            mat = alt_mats(mat, 2)
        elif qrt[1] > qrt[2]:
            mat = alt_mats(mat, 1)

        ## get counts and format
        ## NOTE(review): the comments inside count_snps call element 1 "baba"
        ## and element 2 "abba", but the labels below name them in the
        ## opposite order -- confirm which labelling is intended.
        df = pd.DataFrame(
            data=count_snps(mat),
            index=["aabb", "abba", "baba", "aaab"],
            columns=[idx]).T
        return df


    def get_h_by_idx(self, idx):
        """
        calculate Hils. This could be numba-fied, but you'd have to work
        with arrays instead of dataframes. This is fine for now.

        Returns a one-row DataFrame indexed by `idx` holding the four site
        counts, the matching proportions (prefixed with "p"), "Hils" and
        "gamma". When pabba equals pbaba the ratio divides by zero and the
        statistics come out as inf/NaN rather than raising.
        """
        ## get site frequencies
        df = self.get_counts_by_idx(idx)
        nsites = df.sum(axis=1).values[0]
        pdf = df / nsites
        pdf.columns = ["p" + i for i in df.columns]
        data = pd.concat([df, pdf], axis=1)

        ## choose invariant pattern
        f1 = data["paabb"] - data["pbaba"]
        f2 = data["pabba"] - data["pbaba"]
        ratio = f1 / f2

        ## calculate var, covar
        var_f1 = (1. / nsites) * (
            data["paabb"] * (1. - data["paabb"])
            + data["pbaba"] * (1. - data["pbaba"])
            + 2. * data["paabb"] * data["pbaba"])

        var_f2 = (1. / nsites) * (
            data["pabba"] * (1. - data["pabba"])
            + data["pbaba"] * (1. - data["pbaba"])
            + 2. * data["pabba"] * data["pbaba"])

        cov_f1_f2 = (1. / nsites) * (
            -data["paabb"] * data["pabba"]
            + data["paabb"] * data["pbaba"]
            + data["pabba"] * data["pbaba"]
            + data["pbaba"] * (1. - data["pbaba"]))

        ## calculate hils
        num = abs(f2 * ratio)
        ## var_f1 is ADDED after the covariance term is subtracted; the old
        ## parenthesisation subtracted it, disagreeing with calcHils.
        denom = np.sqrt(var_f2 * (ratio**2) - 2. * cov_f1_f2 * ratio + var_f1)
        ## parenthesised on purpose: `f1/f1+f2` evaluates to 1 + f2
        gamma = f1 / (f1 + f2)
        H = pd.DataFrame({"Hils": num / denom, "gamma": gamma}, index=[idx])

        data = pd.concat([df, pdf, H], axis=1)
        return data


    def run(self):
        """calculate Hils and return table for all idxs in database"""
        ## range (not xrange) so the loop runs on both Python 2 and 3
        stats = pd.concat(
            [self.get_h_by_idx(idx) for idx in range(self.nquartets)])
        ## one "a,b|c,d" label per stored quartet, in storage order
        qrts = ["{},{}|{},{}".format(*i) for i in self.quartets[:]]
        qrts = pd.DataFrame(np.array(qrts), columns=["qrts"])
        return pd.concat([stats, qrts], axis=1)
    
    
@numba.jit(nopython=True)
def alt_mats(mat, idx):
    """
    return alternate rearrangement of matrix

    Three 16x16 uint32 arrangements are built from `mat` and the one
    selected by `idx` is returned:

    idx 0 : `mat` itself.
    idx 1 : out[4*a + c, 4*b + d] = mat[4*a + b, 4*c + d]
    idx 2 : out[4*a + d, 4*b + c] = mat[4*a + b, 4*c + d]

    for a, b, c, d in 0..3.
    """
    mats = np.zeros((3, 16, 16), dtype=np.uint32)
    ## use the argument; this previously read the notebook-global `arr`,
    ## so the function ignored whatever matrix it was given
    mats[0] = mat
    x = np.uint8(0)
    for y in np.array([0, 4, 8, 12], dtype=np.uint8):
        for z in np.array([0, 4, 8, 12], dtype=np.uint8):
            ## row x of the input, viewed as 4x4, fills the tile at (y, z);
            ## arrangement 2 takes the same tile transposed
            mats[1, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4)
            mats[2, y:y+np.uint8(4), z:z+np.uint8(4)] = mats[0, x].reshape(4, 4).T
            x += np.uint8(1)
    return mats[idx]
        
        
@numba.jit(nopython=True)
def count_snps(mat):
    """
    JIT func to return counts quickly

    Takes a 16x16 count matrix and returns a length-4 array holding, in
    order: the aabb count, the sum of the diagonal cells whose index is
    not a multiple of 5, the abba count, and the remaining off-diagonal
    total after removing the abba count.
    """
    ## array to store results. 64-bit so large site counts cannot silently
    ## wrap around, as they did with the previous 16-bit accumulator.
    snps = np.zeros(4, dtype=np.uint64)

    ## get concordant (aabb) pis sites
    snps[0] = (
        mat[0, 5] + mat[0, 10] + mat[0, 15] +
        mat[5, 0] + mat[5, 10] + mat[5, 15] +
        mat[10, 0] + mat[10, 5] + mat[10, 15] +
        mat[15, 0] + mat[15, 5] + mat[15, 10])

    ## get discordant (baba) sites: diagonal cells, skipping 0, 5, 10, 15
    for i in range(16):
        if i % 5:
            snps[1] += mat[i, i]

    ## get discordant (abba) sites
    snps[2] = mat[1, 4] + mat[2, 8] + mat[3, 12] +\
              mat[4, 1] + mat[6, 9] + mat[7, 13] +\
              mat[8, 2] + mat[9, 6] + mat[11, 14] +\
              mat[12, 3] + mat[13, 7] + mat[14, 11]

    ## get autapomorphy sites
    ## NOTE(review): only the abba count is removed from the off-diagonal
    ## total, so the aabb cells counted in snps[0] are still included here
    ## -- confirm whether that is intended.
    snps[3] = (mat.sum() - np.diag(mat).sum()) - snps[2]
    return snps


In [155]:
def calcHils(invmat, Nreq = 10, returnf = False, returnp = False, returnall = False,returnnum = False):
    """
    Compute the Hils statistic from a 16x16 matrix of site-pattern counts.

    Row/column k of the matrix stands for the ordered pair (k // 4, k % 4),
    so the pair (i, j) lives at index 4*i + j.

    Parameters
    ----------
    invmat : numpy array, 16x16
        Site-pattern counts; converted to float internally.
    Nreq : number
        The total count must be strictly greater than this.
    returnf, returnp, returnall, returnnum : bool
        Select an alternative return value. They are checked in this order
        and the first one that is true wins.

    Returns
    -------
    str
        A message when the data are rejected (no ijij/ijji counts, too few
        sites, or ijij more frequent than both other patterns); otherwise
        str(H), prefixed with '*' when the ijij and ijji counts are equal.
    list
        [H, f1, f2] if returnf; [H, p_iijj, p_ijji, p_ijij] if returnp;
        [H, f1, f2, p_iijj, p_ijji, p_ijij] if returnall;
        [num_iijj, num_ijji, num_ijij] if returnnum.
    """
    invmat = invmat.astype(float)

    ## all ordered pairs of distinct states; indices are computed directly
    ## because literals such as 01 are a SyntaxError on Python 3
    pairs = [(i, j) for i in range(4) for j in range(4) if i != j]
    num_iijj = sum(invmat[4 * i + i, 4 * j + j] for i, j in pairs)
    num_ijji = sum(invmat[4 * i + j, 4 * j + i] for i, j in pairs)
    num_ijij = sum(invmat[4 * i + j, 4 * i + j] for i, j in pairs)

    if (num_ijij == 0 and num_ijji == 0):
        return("No ijij or ijji are present in data (not enough data)")
    N = sum(map(sum, invmat))
    if (N <= Nreq):
        return("Not enough snps.")
    # calculate probability, add .05 to counts in case some of them are 0
    p_iijj = (num_iijj + .05)/N
    p_ijji = (num_ijji + .05)/N
    p_ijij = (num_ijij + .05)/N

    if (p_ijij > max([p_iijj,p_ijji])):
        return("Parental taxa are more closely related than hybrid. Discard this.")

    f1 = p_iijj - p_ijij
    f2 = p_ijji - p_ijij
    if not(f2):
        # f2 is exactly zero: add one pseudo-count to ijji so the ratio
        # below does not divide by zero
        p_ijji = (num_ijji + 1. + .05)/N
        f2 = p_ijji - p_ijij
    rat_f1_f2 = f1/f2

    var_f1 = (1./N) * ( p_iijj*(1.-p_iijj) + p_ijij*(1.-p_ijij) + 2.*p_iijj*p_ijij )
    var_f2 = (1./N) * ( p_ijji*(1.-p_ijji) + p_ijij*(1.-p_ijij) + 2.*p_ijji*p_ijij )

    cov_f1_f2 = (1./N) * ( -p_iijj*p_ijji + p_iijj*p_ijij + p_ijji*p_ijij + p_ijij*(1.-p_ijij))

    H = abs(f2 * rat_f1_f2) / math.sqrt( var_f2*(rat_f1_f2**2.) - 2.*cov_f1_f2*rat_f1_f2 + var_f1 )
    if returnf:
        return [H, f1, f2]
    if returnp:
        return [H, p_iijj,p_ijji,p_ijij]
    if returnall:
        return [H, f1, f2, p_iijj,p_ijji,p_ijij]
    if returnnum:
        return [num_iijj,num_ijji,num_ijij]
    if(num_ijij-num_ijji == 0):
        # equal raw counts: flag the value with a leading '*'
        return('*'+str(H))
    else:
        return str(H)
def calcp(z):
    """Return twice the standard-normal survival function at abs(z)."""
    upper_tail = st.norm.sf(abs(z))
    return upper_tail * 2
def isfloat(value):
    """Return True if float(value) succeeds, False if it raises ValueError."""
    try:
        float(value)
    except ValueError:
        return False
    else:
        return True

In [156]:
## first 16x16 invariants matrix stored under 'boot0'
arr = f['invariants']['boot0'][0]

## slot 0 holds the matrix as stored; slots 1 and 2 hold rearrangements
mats = np.zeros((3, 16, 16), dtype=np.uint32)
mats[0] = arr

## slot 1: row k of the stored matrix, reshaped to 4x4, fills the k-th
## 4x4 tile (tiles counted left to right, top to bottom)
for k in range(16):
    row0 = 4 * (k // 4)
    col0 = 4 * (k % 4)
    mats[1, row0:row0 + 4, col0:col0 + 4] = mats[0, k].reshape(4, 4)

## slot 2: columns of the stored matrix taken in the order
## 0, 4, 8, 12, 1, 5, 9, 13, ... (the transposed-tile variant is not used here)
for new_col in range(16):
    old_col = 4 * (new_col % 4) + new_col // 4
    mats[2, :, new_col] = mats[0, :, old_col]

[calcHils(mats[0]), calcHils(mats[1]), calcHils(mats[2])]

['Parental taxa are more closely related than hybrid. Discard this.',
 '0.401917492623',
 '0.397271047622']

In [157]:
## grand total over every cell of the stored (slot 0) matrix
sum(sum(mats[0]))

6108

In [146]:
## column 0 of the stored (slot 0) matrix, across all 16 rows
mats[0,:,0]

array([  0,  79, 235,  96,  55,   2,   0,   1, 213,   1,   5,   0, 106,
         1,   0,   1], dtype=uint32)

In [104]:
## rebuild the work array from scratch: slot 0 gets `arr`, slots 1 and 2
## are left as zeros by this cell
mats = np.zeros((3, 16, 16), dtype=np.uint32)
mats[0] = arr
## reset the uint8 counter that the rearrangement loops use
x = np.uint8(0)

In [115]:
## row 0 of the stored matrix viewed as a 4x4 tile, transposed
mats[0,0].reshape(4,4).T

array([[  0,  55, 211,  82],
       [ 65,   3,   1,   0],
       [219,   0,  12,   1],
       [ 90,   0,   0,   2]], dtype=uint32)