In [1]:
# Enable the autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Put the parent directory on the module search path so the imports in the
# next cell can be resolved from there.
# The membership check keeps this cell idempotent: re-running it does not
# append a duplicate '../' entry to sys.path.
import sys

if '../' not in sys.path:
    sys.path.append('../')

In [3]:
import numpy as np
from dataManipulation import *
from utils import summary, summary_raw, get_support_from_mcmc
from vbpi import VBPI

In [4]:
# Load the DS1 sequence data, then estimate rootsplit/subsplit support from
# the DS1 trees read out of '../data/ufboot_data_DS1-11/'.
data, taxa = loadData('../data/hohna_datasets_fasta/DS1.fasta', 'fasta')
ufboot_trees, ufboot_names = summary_raw('DS1', '../data/ufboot_data_DS1-11/')
support = get_support_from_mcmc(taxa, ufboot_trees, ufboot_names)
rootsplit_supp_dict, subsplit_supp_dict = support
# The raw trees were only needed to build the two support dictionaries.
del ufboot_trees, ufboot_names, support

In [5]:
# Load the ground truth: build a mapping from each tree to its weight,
# pairing tree names and weights by position.
gt_trees, gt_names, gt_weights = summary('DS1', '../data/raw_data_DS1-11/')
emp_tree_freq = {
    gt_trees[name]: gt_weights[idx]
    for idx, name in enumerate(gt_names)
}
# Drop the intermediates; only emp_tree_freq is kept from this cell.
del gt_trees, gt_names, gt_weights

In [6]:
# Set up the model from the taxa, the two support dictionaries and the
# sequence data. pden is a length-4 vector with every entry equal to 0.25.
model = VBPI(
    taxa,
    rootsplit_supp_dict,
    subsplit_supp_dict,
    data,
    pden=np.full(4, 0.25),
    subModel=('JC', 1.0),
    emp_tree_freq=emp_tree_freq,
    psp=True,
)

In [7]:
# Variational parameters: print the dtype and shape of each one.
for param in model.parameters():
    print(param.dtype, param.shape)

torch.float32 torch.Size([8235])
torch.float32 torch.Size([2771, 2])


In [8]:
# Run VBPI with method='vimco' and 10 particles for up to 200000 iterations.
# save_to_path points at '../results/DS1/ufboot_vimco_10_psp.pt', which is
# where the trained model is written.
# NOTE(review): the first positional argument (0.001) is presumably the
# step size -- confirm against VBPI.learn.
test_lb, test_kl_div = model.learn(
    0.001,
    maxiter=200000,
    n_particles=10,
    warm_start_interval=100000,
    method='vimco',
    save_to_path='../results/DS1/ufboot_vimco_10_psp.pt',
)

Iter 1000:(75.1s) Lower Bound: -14251.3226 | Loglikelihood: -10710.3184 | KL: 25.867969
Iter 2000:(73.2s) Lower Bound: -10251.6487 | Loglikelihood: -8553.8311 | KL: 25.476149
Iter 3000:(74.8s) Lower Bound: -8767.9784 | Loglikelihood: -7781.5278 | KL: 24.843501
Iter 4000:(75.8s) Lower Bound: -8134.7211 | Loglikelihood: -7522.0947 | KL: 24.317473
Iter 5000:(76.0s) Lower Bound: -7821.5830 | Loglikelihood: -7337.3716 | KL: 23.771647
>>> Iter 5000:(3.3s) Test Lower Bound: -8237.4082
Iter 6000:(77.0s) Lower Bound: -7639.5721 | Loglikelihood: -7227.2485 | KL: 23.118502
Iter 7000:(76.9s) Lower Bound: -7529.7799 | Loglikelihood: -7170.9951 | KL: 22.624282
Iter 8000:(76.4s) Lower Bound: -7450.8802 | Loglikelihood: -7088.6592 | KL: 21.982051
Iter 9000:(76.9s) Lower Bound: -7400.3036 | Loglikelihood: -7114.2925 | KL: 21.217974
Iter 10000:(76.6s) Lower Bound: -7360.2487 | Loglikelihood: -7086.0581 | KL: 20.417465
>>> Iter 10000:(3.2s) Test Lower Bound: -7739.1187
Iter 11000:(76.6s) Lower Bound: -73

Iter 86000:(70.8s) Lower Bound: -7109.0137 | Loglikelihood: -6898.3477 | KL: 0.194484
Iter 87000:(71.3s) Lower Bound: -7108.9195 | Loglikelihood: -6898.2754 | KL: 0.182749
Iter 88000:(70.7s) Lower Bound: -7108.8470 | Loglikelihood: -6898.0054 | KL: 0.173490
Iter 89000:(71.2s) Lower Bound: -7108.9318 | Loglikelihood: -6895.8335 | KL: 0.164472
Iter 90000:(70.3s) Lower Bound: -7108.8392 | Loglikelihood: -6898.3750 | KL: 0.156428
>>> Iter 90000:(3.2s) Test Lower Bound: -7113.7700
Iter 91000:(70.4s) Lower Bound: -7108.8848 | Loglikelihood: -6895.7695 | KL: 0.147041
Iter 92000:(71.3s) Lower Bound: -7108.8084 | Loglikelihood: -6898.1514 | KL: 0.141001
Iter 93000:(70.2s) Lower Bound: -7108.8330 | Loglikelihood: -6898.5000 | KL: 0.134461
Iter 94000:(71.2s) Lower Bound: -7108.8161 | Loglikelihood: -6896.3804 | KL: 0.128637
Iter 95000:(70.6s) Lower Bound: -7108.8253 | Loglikelihood: -6896.9551 | KL: 0.124088
>>> Iter 95000:(3.1s) Test Lower Bound: -7113.9683
Iter 96000:(70.6s) Lower Bound: -7108.

>>> Iter 170000:(3.1s) Test Lower Bound: -7110.7456
Iter 171000:(69.5s) Lower Bound: -7108.7327 | Loglikelihood: -6896.2480 | KL: 0.075119
Iter 172000:(69.9s) Lower Bound: -7108.7497 | Loglikelihood: -6892.7017 | KL: 0.075090
Iter 173000:(69.8s) Lower Bound: -7108.7567 | Loglikelihood: -6895.7524 | KL: 0.075043
Iter 174000:(69.3s) Lower Bound: -7108.7237 | Loglikelihood: -6897.6465 | KL: 0.074982
Iter 175000:(69.8s) Lower Bound: -7108.7211 | Loglikelihood: -6897.1147 | KL: 0.074886
>>> Iter 175000:(3.1s) Test Lower Bound: -7110.6768
Iter 176000:(69.2s) Lower Bound: -7108.7335 | Loglikelihood: -6895.1670 | KL: 0.074827
Iter 177000:(70.4s) Lower Bound: -7108.7253 | Loglikelihood: -6896.8071 | KL: 0.074842
Iter 178000:(69.9s) Lower Bound: -7108.7071 | Loglikelihood: -6895.8433 | KL: 0.074761
Iter 179000:(69.9s) Lower Bound: -7108.7234 | Loglikelihood: -6896.5327 | KL: 0.074765
Iter 180000:(69.9s) Lower Bound: -7108.7299 | Loglikelihood: -6895.4243 | KL: 0.074754
>>> Iter 180000:(3.1s) Tes

In [9]:
# load a trained model/checkpoint for test
# (this is the same path the training cell passes as save_to_path)
model.load_from('../results/DS1/ufboot_vimco_10_psp.pt')

In [10]:
# compute the KL divergence
# NOTE(review): presumably measured against the emp_tree_freq ground truth
# passed to VBPI at construction -- confirm in vbpi.py.
model.kl_div()

0.07391851477720657

In [15]:
# Evidence lower bound: collect 100 calls to model.lower_bound, each made
# with a single particle and n_runs=1000.
lower_bound_1_sample = np.array([
    model.lower_bound(n_particles=1, n_runs=1000)
    for _ in range(100)
])

In [16]:
# Mean and standard deviation of the collected lower-bound values.
lower_bound_1_sample.mean(), lower_bound_1_sample.std()

(-7111.295024414063, 0.977415126557402)

In [17]:
# Marginal likelihood estimate: collect 100 calls to model.lower_bound, each
# made with 1000 particles and n_runs=1.
marginal_likelihood_est = np.array([
    model.lower_bound(n_particles=1000, n_runs=1)
    for _ in range(100)
])

In [18]:
# Mean and standard deviation of the collected marginal-likelihood estimates.
marginal_likelihood_est.mean(), marginal_likelihood_est.std()

(-7108.396044921875, 0.1608000751888991)