In [1]:
import numpy as np
import os
import sys
sys.path.append('../')
import mlx as ml 
import warnings
import operator
import pickle
import glob
import pylab as plt
%matplotlib inline
import pandas as pd
from tqdm import tqdm
DEBUG=False

In [75]:
def dictprod(dict_,a=1.0):
    '''
        scale every value of a distribution by a constant

        Arguments:

        dict_: dict represented as such: {'key1': val1, ... ,'keyn':valn}
        a:     multiplier applied to every value (default 1.0)

        Returns a new dict with the same keys; dict_ itself is not modified.
    '''
    # .items() works on Python 2 and 3 alike; .iteritems() is Python 2 only
    return {key: value*a for (key, value) in dict_.items()}
 
def normalizedict(dict_):
    '''
        given a dict represented as such: {'key1': val1, ... ,'keyn':valn}
        scale all values such that they sum to 1.0

        Returns a new dict with the same keys. An empty dict is returned
        unchanged. A total of zero is not special-cased: the division is
        attempted as-is.
    '''
    # start from 0.0 so the total is a float even for integer values
    s = sum(dict_.values(), 0.0)
    # .items() works on Python 2 and 3 alike; .iteritems() is Python 2 only
    return {key: (value/s) for (key, value) in dict_.items()}

def mergedistributions(dist_):
    '''
        given a dict of dicts, each represented as such: {'key1': val1, ... ,'keyn':valn}
        we return a combined dict, where the value corresponding to key1 is the
        average over all component dicts

        The average always divides by the number of component dicts, so a
        key that is missing from a component counts as 0 for that component.
        An empty dist_ gives an empty result.
    '''
    num = len(dist_)

    # Accumulate directly in a dict. Routing the keys through np.append (as
    # before) coerces them to NumPy types and fails on Python 3 dict views.
    totals = {}
    for component in dist_.values():
        for (key, value) in component.items():
            totals[key] = totals.get(key, 0.0) + value

    return {key: total/(num+0.0) for (key, total) in totals.items()}

def getMergedDistribution(tree,cond={}):
    '''
        get distribution over keys given particular
        constraints (cond) on the decision tree

        Arguments:

        tree: decision tree returned by mlx.py; this function reads its
              feature, children, edge_cond_, num_pass_ and class_pred_
              attributes
        cond: conditions that specify constraints on the decision tree,
              as a dict {feature_name: value}. Each element of the value
              is tested for membership in the edge conditions (for a
              string value that means each character). The dict is only
              read, never modified.

        Returns a dict {key: probability} whose values sum to 1.0

        A child node is "active" when the edge leading to it from a node
        testing a constrained feature matches the constraint. The
        class_pred_ distributions of the active nodes are weighted by
        num_pass_, merged and renormalized. If no edge matches, node 1
        alone is used.
    '''
    # feature name -> ids of every node that tests that feature
    # (.items() works on Python 2 and 3 alike; .iteritems() is Python 2 only)
    node_id_map = {}
    for (i, feature_name) in tree.feature.items():
        node_id_map.setdefault(feature_name, []).append(int(i))

    if DEBUG:
        print(node_id_map)
    #propagate to find current nodes
    children = {feature_name: set() for feature_name in cond.keys()}
    for feature_name in cond.keys():
        for node_id in tree.feature:
            if tree.feature[node_id] == feature_name:
                children[feature_name] = children[feature_name].union(tree.children[node_id])
    if DEBUG:
        print(children)

    current_active_nodes = np.array([], int)
    for feature_name in cond.keys():
        for child in children[feature_name]:
            for parent in node_id_map[feature_name]:
                if (parent, child) in tree.edge_cond_:
                    for edge_var in cond[feature_name]:
                        if edge_var in tree.edge_cond_[(parent, child)]:
                            if DEBUG:
                                print(parent,child,"::",tree.edge_cond_[(parent,child)])
                            # a child is appended once per matching edge_var,
                            # so it can appear more than once here
                            current_active_nodes = np.append(current_active_nodes, child)

    if current_active_nodes.size == 0:
        # nothing matched: fall back to node 1
        # NOTE(review): presumably the root of the tree -- TODO confirm
        current_active_nodes = np.array([1], int)

    S = 0.0
    for i in current_active_nodes:
        S = S + tree.num_pass_[i]

    # keyed by node id, so repeated active nodes collapse to one entry
    indexed_dist = {i: dictprod(tree.class_pred_[i], tree.num_pass_[i]/S) for i in current_active_nodes}
    merged = mergedistributions(indexed_dist)
    dist_ = normalizedict(merged)

    if DEBUG:
        print(children)
        print(current_active_nodes)
        print("ID",indexed_dist)
        print("MD",merged)
        print("ND",dist_)

    return dist_
    
def sampleTree(tree,cond={},sample='mle',DIST=False,NUMSAMPLE=10):
    '''
        draw sample from decision tree
        specified in the format that
        mlx.py returns

        Arguments:

        1. cond: dict of the format {'name': value, 'name1': value1,...}
                 specifies the constraints on the decision tree.
                 example: {'RBM34':'C','SOX2': 'A'}

        Note--> we can use arbitrary cond argument, irrespective of if the
        names are in the decision tree at all or not. Also, we can use
        an empty cond dict, which corresponds to the unconstrained tree.
        In all these cases, it makes sense to ask what is the distribution on the
        keys that the decision tree outputs, and we attempt to compute that.

        2. sample: 'mle'|'random'
                   if 'mle' then return the value with maximum probability.
                   if 'random' then makes random choice NUMSAMPLE times
                   (with replacement) and returns the result as an array.
                   any other value is handed back unchanged.

        3. DIST: TRUE|FALSE
                 if TRUE returns (sample, distribution), where distribution
                 is the one obtained from the tree after applying the
                 constraints

        4. NUMSAMPLE: number of draws when sample is 'random'
    '''
    dist_ = getMergedDistribution(tree, cond=cond)
    # Compare strings by value: `is` only worked through literal interning.
    # elif keeps an argmax label that happens to equal 'random' from
    # triggering the sampling branch as well.
    if sample == 'mle':
        sample = max(dist_.items(), key=operator.itemgetter(1))[0]
    elif sample == 'random':
        # explicit lists: np.random.choice needs 1-D sequences, and keys and
        # probabilities must line up element by element
        keys = list(dist_.keys())
        probs = [dist_[key] for key in keys]

        sample = np.random.choice(keys, NUMSAMPLE, replace=True, p=probs)
    if DIST:
        return sample, dist_
    return sample

def getFmap(PATH_TO_TREES):
    '''
        load every pickled tree matched by a glob pattern

        Arguments:

        PATH_TO_TREES: glob pattern, e.g. 'some/dir/*pkl'

        Returns (F, TREE), both keyed by the tree index, which is the part
        of the file name (extension removed) after the last '_':

        F:    index -> array of the feature names of the nodes that are not
              leaves according to TREE_LEAF (an empty list when the tree
              has none)
        TREE: index -> the unpickled tree object
    '''
    F = {}
    TREE = {}
    for filename in glob.glob(PATH_TO_TREES):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at tree files from a trusted source.
        with open(filename, 'rb') as f:
            # the with-block closes the file; no explicit close() is needed
            TR = pickle.load(f)
        index = os.path.splitext(os.path.basename(filename))[0].split('_')[-1]
        TREE[index] = TR
        # .items() works on Python 2 and 3 alike; .iteritems() is Python 2 only
        internal = [value for (key, value) in TR.feature.items()
                    if not TR.TREE_LEAF[key]]
        # same container types as before: ndarray of names, or [] if none
        F[index] = np.array(internal) if internal else []
    return F, TREE

def getPerturbation(seq,PATH_TO_TREES):
    '''
        for every tree, get the distribution it predicts when constrained
        by the values that seq has at the tree's own features

        Arguments:

        seq:           sequence, indexed with the integer obtained by
                       removing 'P' from a feature name ('P141' -> seq[141])
        PATH_TO_TREES: glob pattern of the pickled trees (see getFmap);
                       the trees are re-read from disk on every call

        Returns a dict: tree index -> list of probabilities, ordered by
        sorted label.
        NOTE(review): the labels themselves are dropped, so lists from two
        calls line up only if both distributions carry the same label set
        -- TODO confirm class_pred_ always lists every label.
    '''
    F, TREES = getFmap(PATH_TO_TREES)
    P = {}
    for KEY in F.keys():
        I = [int(x.replace('P', '')) for x in F[KEY]]
        DICT_ = {'P'+str(i): seq[i] for i in I}
        # Only the distribution is needed, so ask for it directly instead of
        # going through sampleTree(sample='random'), which also drew random
        # samples that were thrown away.
        D = getMergedDistribution(TREES[KEY], cond=DICT_)

        P[KEY] = [D[x] for x in sorted(D.keys())]
    return P

def klscore(p1,p2):
    '''
        Kullback-Leibler divergence KL(p1 || p2) in bits (log base 2)

        Arguments:

        p1, p2: sequences of probabilities of equal length

        Returns np.nan if any entry of p2 is <= 0. Entries where p1 is 0
        contribute 0 (the 0*log(0) = 0 convention); previously a single
        zero in p1 turned the whole sum into nan.
    '''
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)

    if np.any(p2 <= 0):
        return np.nan

    # restrict the sum to the support of p1 so 0*log2(0) is never evaluated
    support = p1 > 0
    return np.sum(p1[support]*np.log2(p1[support]/p2[support]))

def jsdiv(p1,p2,smooth=True):
    '''
        Jensen-Shannon divergence between p1 and p2 in bits (log base 2)

        Arguments:

        p1, p2: sequences of non-negative probabilities of equal length
        smooth: if True (default) add 0.001 to every entry and divide by
                1.001 before comparing, so that no entry is zero.
                if False the inputs are compared as given. This flag used
                to be ignored, and smoothing was always applied.

        NOTE(review): the smoothing does not renormalize. A vector of k
        entries summing to 1 sums to (1 + 0.001*k)/1.001 afterwards. The
        constants are kept as they were so that default results do not
        change -- TODO decide whether to divide by the actual sum.
    '''
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)

    if smooth:
        p1 = (p1+0.001)/1.001
        p2 = (p2+0.001)/1.001

    p = 0.5*(p1+p2)

    def _kl_to_mixture(q):
        # KL(q || p). For non-negative inputs the mixture p is positive
        # wherever q is, so only the 0*log(0) = 0 convention is needed
        # and the unsmoothed case stays finite.
        support = q > 0
        return np.sum(q[support]*np.log2(q[support]/p[support]))

    return 0.5*(_kl_to_mixture(p1)+_kl_to_mixture(p2))

In [47]:
# Glob pattern handed to getFmap to locate the pickled trees, and the CSV of test sequences.
PATH_TO_TREES='../../cchf/cchfl_trees/*pkl'
gn=pd.read_csv('../../cchf/cchfl_test.csv')
# Row 0 of the table serves as the example sequence in the cells below.
seq=gn.loc[0]

In [48]:
# Load every tree once: F maps tree index -> non-leaf feature names, TREES maps tree index -> tree object.
F,TREES=getFmap(PATH_TO_TREES)

In [8]:
# Worked example for one tree: build the constraint dict that getPerturbation builds internally.
KEY='P93'
# Integer positions parsed from the tree's feature names ('P141' -> 141).
I=[int(x.replace('P','')) for x in F[KEY]]
# Values of the example sequence at those positions (not used by the later visible cells).
v=[seq[i] for i in I]
DICT_={'P'+str(i):seq[i] for i in I}
DICT_

{'P141': 'A', 'P4651': 'A', 'P60': 'C'}

In [9]:
# With DIST=True sampleTree returns (sample, distribution); [1] keeps only the distribution.
sampleTree(TREES[KEY],DICT_,sample='random',DIST=True)[1]

{'A': 0.0, 'C': 0.07647202540923528, 'G': 0.0, 'T': 0.9235279745907646}

In [51]:
# Per-tree distributions for the example sequence (this re-reads all trees from disk).
P=getPerturbation(seq,PATH_TO_TREES)

In [52]:
# Probabilities from tree 'P1008', ordered by sorted label.
P['P1008']

[0.0, 0.43321917808219174, 0.0, 0.5667808219178082]

In [35]:
# NOTE(review): Q is not defined in any visible cell, so this depends on leftover
# kernel state and will fail on Restart & Run All -- TODO define Q or drop the cell.
Q[3001]

0.16216216216216217

In [53]:
# Probabilities from tree 'P3000', ordered by sorted label.
P['P3000']

[0.00693069306930693,
 0.02079207920792079,
 0.00693069306930693,
 0.9653465346534653]

In [None]:
# Two rows of the table to compare against each other.
seq0=gn.loc[1]
seq1=gn.loc[20]

In [54]:
# Per-tree distributions for each of the two sequences (each call re-reads all trees).
P0=getPerturbation(seq0,PATH_TO_TREES)
P1=getPerturbation(seq1,PATH_TO_TREES)

In [72]:
# Compare the two sequences at tree 'P3000': the raw distributions, the
# KL divergence, and the JS divergence in both argument orders.
# Single-argument print(...) prints the same on Python 2 and Python 3.
print(P0['P3000'])
print(P1['P3000'])

print(klscore(P1['P3000'],P0['P3000']))
print(jsdiv(P1['P3000'],P0['P3000']))
print(jsdiv(P0['P3000'],P1['P3000']))


[0.731958762886598, 0.18253486961795026, 0.02850212249848393, 0.05700424499696786]
[0.00693069306930693, 0.02079207920792079, 0.00693069306930693, 0.9653465346534653]
3.814558727504776
0.751990812212272
0.751990812212272


In [73]:
def getDistance(seq0,seq1,PATH_TO_TREES):
    '''
        distance between two sequences: the mean, over the trees present
        for both, of jsdiv between the per-tree distributions

        Arguments:

        seq0, seq1:    the two sequences (see getPerturbation)
        PATH_TO_TREES: glob pattern of the pickled trees

        Returns the mean divergence as a float, or 0.0 when the two
        sequences share no tree.

        Each getPerturbation call reads the trees from disk again, so
        one distance costs two full loads.
    '''
    dists0 = getPerturbation(seq0, PATH_TO_TREES)
    dists1 = getPerturbation(seq1, PATH_TO_TREES)

    # one divergence per tree index available for both sequences
    divergences = [jsdiv(dists0[index], dists1[index])
                   for index in dists0.keys() if index in dists1.keys()]

    total = 0.0
    for value in divergences:
        total = total + value

    # with nothing in common, divide by 1 instead of 0
    denominator = len(divergences) if divergences else 1
    return total/(denominator+0.0)

In [76]:
# Distance between row 20 (seq1) and row 69.
getDistance(seq1,gn.loc[69],PATH_TO_TREES)

0.03171281029655304

In [None]:
# Distance between row 20 (seq1) and row 19.
getDistance(seq1,gn.loc[19],PATH_TO_TREES)

In [81]:
# Distance between rows 7 and 33.
getDistance(gn.loc[7],gn.loc[33],PATH_TO_TREES)

0.0136664180544655

In [78]:
# Pairwise distances between the first N_SEQ rows, keyed by (i, j) with i < j.
N_SEQ = 20

# getDistance recomputes both perturbations (and reloads every tree) on each
# call, which is 2 * 190 loads for 20 rows. Compute one perturbation per row
# instead and reuse it for every pair.
perturbations = [getPerturbation(gn.loc[i], PATH_TO_TREES) for i in tqdm(range(N_SEQ))]

D = {}
for i in range(N_SEQ):
    for j in range(i + 1, N_SEQ):
        # same computation as getDistance: mean jsdiv over the trees present
        # for both rows, dividing by 1 when they share none
        shared = [key for key in perturbations[i] if key in perturbations[j]]
        total = sum((jsdiv(perturbations[i][key], perturbations[j][key]) for key in shared), 0.0)
        D[(i, j)] = total / float(len(shared) or 1)


  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Symmetric distance matrix filled from the pairwise dict D; the diagonal and
# any pair missing from D stay 0.
# NOTE(review): 71 is presumably the number of rows in gn, while the visible
# loop that builds D only covers indices below 20 -- TODO confirm the size.
H = np.zeros([71, 71])
# .items() works on Python 2 and 3 alike; .iteritems() is Python 2 only
for key, value in D.items():
    i, j = key[0], key[1]
    H[i, j] = value
    H[j, i] = value