In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from collections import defaultdict

In [8]:
from models import SBN
from utils import summary, mcmc_treeprob

In [9]:
# Only the first four datasets (DS1-DS4) are evaluated here; computing the
# KL divergence for the other datasets may require a large amount of memory.
datasets = ['DS{}'.format(idx) for idx in range(1, 5)]

In [None]:
# For every dataset and every short-run replicate, fit an SBN on the short-run
# tree sample and record KL divergences computed by model.kl_div() (the model is
# constructed with the golden-run tree frequencies, emp_tree_freq).
#
# Last axis of kl_div_mat, per (replicate, dataset):
#   0: kl_div['ccd']  after bn_train_prob
#   1: kl_div['bn']   after bn_train_prob
#   2: 'bn' after bn_em_prob with MAP=False
#   3: 'bn' after bn_em_prob with MAP=True on a fresh model built with alpha=0.0001
#   4: kl_div['freq'] after bn_train_prob
n_rep = 10
n_datasets = len(datasets)  # keep the result shapes in sync with the dataset list

# NaN-fill instead of np.empty so that entries of an interrupted / partial run
# are visibly missing rather than uninitialised memory.
kl_div_mat = np.full((n_rep, n_datasets, 5), np.nan)
n_trees_mat = np.full((n_rep, n_datasets), np.nan)

for ds_num, dataset in enumerate(datasets):
    print('dataset {}, golden run loading ...'.format(dataset))
    tree_dict_total, tree_names_total, tree_wts_total = summary(dataset, '../data/raw_data_DS1-11/')
    taxa = tree_dict_total[tree_names_total[0]].get_leaf_names()
    # Map each golden-run tree object to its weight.
    emp_tree_freq = {tree_dict_total[tree_name]: tree_wts_total[i] for i, tree_name in enumerate(tree_names_total)}

    for rep in range(n_rep):
        tree_dict, tree_names, tree_wts = mcmc_treeprob('../data/short_run_data_DS1-11/' + dataset + '/rep_{}/'.format(rep+1) + dataset + '.trprobs', 'nexus')
        # Normalise the weights to sum to one. The float dtype guarantees true
        # division even if the loader returns integer weights under Python 2.
        tree_wts = np.asarray(tree_wts, dtype=float)/sum(tree_wts)
        print(dataset + ', rep {}: {} unique trees'.format(rep+1, len(tree_wts)))

        model = SBN(taxa, emp_tree_freq)
        model.bn_train_prob(tree_dict, tree_names, tree_wts)

        kl_div = model.kl_div()
        kl_div_mat[rep, ds_num, 0], kl_div_mat[rep, ds_num, 1] = kl_div['ccd'], kl_div['bn']
        kl_div_mat[rep, ds_num, 4] = kl_div['freq']

        # EM continues on the model already trained by bn_train_prob above.
        print('running EM >>>>>>')
        logp = model.bn_em_prob(tree_dict, tree_names, tree_wts, maxiter=1000, abstol=1e-05, monitor=True, MAP=False)
        kl_div_mat[rep, ds_num, 2] = model.kl_div(method='bn')['bn']

        # EM-alpha starts from a freshly constructed model with alpha=0.0001.
        print('running EM-alpha >>>>>>')
        model = SBN(taxa, emp_tree_freq, alpha=0.0001)
        logp = model.bn_em_prob(tree_dict, tree_names, tree_wts, maxiter=1000, abstol=1e-05, monitor=True, MAP=True)
        kl_div_mat[rep, ds_num, 3] = model.kl_div(method='bn', MAP=True)['bn']

        n_trees_mat[rep, ds_num] = len(tree_wts)

        print(dataset + ', rep {}'.format(rep+1))
        print('kl_div: {}'.format(kl_div_mat[rep, ds_num, :]))

dataset DS1, golden run loading ...
DS1, rep 1: 1278 unique trees
running EM >>>>>>
Iter 1: current per tree log-likelihood -2.960113
Iter 2: current per tree log-likelihood -2.955872
Iter 3: current per tree log-likelihood -2.954958
Iter 4: current per tree log-likelihood -2.954789
Iter 5: current per tree log-likelihood -2.954687
Iter 6: current per tree log-likelihood -2.954600
Iter 7: current per tree log-likelihood -2.954514
Iter 8: current per tree log-likelihood -2.954425
Iter 9: current per tree log-likelihood -2.954337
Iter 10: current per tree log-likelihood -2.954265
Iter 11: current per tree log-likelihood -2.954215
Iter 12: current per tree log-likelihood -2.954179
Iter 13: current per tree log-likelihood -2.954148
Iter 14: current per tree log-likelihood -2.954115
Iter 15: current per tree log-likelihood -2.954074
Iter 16: current per tree log-likelihood -2.954017
Iter 17: current per tree log-likelihood -2.953934
Iter 18: current per tree log-likelihood -2.953806
Iter 19

Iter 26: current per tree log-likelihood -2.937244
Iter 27: current per tree log-likelihood -2.932259
Iter 28: current per tree log-likelihood -2.927065
Iter 29: current per tree log-likelihood -2.922438
Iter 30: current per tree log-likelihood -2.918847
Iter 31: current per tree log-likelihood -2.916253
Iter 32: current per tree log-likelihood -2.914314
Iter 33: current per tree log-likelihood -2.912691
Iter 34: current per tree log-likelihood -2.911178
Iter 35: current per tree log-likelihood -2.909700
Iter 36: current per tree log-likelihood -2.908254
Iter 37: current per tree log-likelihood -2.906876
Iter 38: current per tree log-likelihood -2.905609
Iter 39: current per tree log-likelihood -2.904487
Iter 40: current per tree log-likelihood -2.903532
Iter 41: current per tree log-likelihood -2.902747
Iter 42: current per tree log-likelihood -2.902122
Iter 43: current per tree log-likelihood -2.901637
Iter 44: current per tree log-likelihood -2.901270
Iter 45: current per tree log-l

In [None]:
# Average over axis 0 of kl_div_mat (the replicate axis), leaving one row per
# dataset and one column per recorded KL-divergence value.
kl_div_mean = kl_div_mat.mean(axis=0)