In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes, so edits to the project source files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the parent directory on the module search path.
# NOTE(review): presumably this is what lets the project modules imported in
# the next cell (dataManipulation, utils, vbpi) resolve -- confirm the layout.
sys.path.extend(['../'])

In [3]:
import os

import numpy as np
import torch

from dataManipulation import *
from utils import tree_summary, summary, summary_raw, get_support_info
from vbpi import VBPI

In [4]:
# load the sequence data and estimate the subsplit support
# Step 1: read the alignment; loadData yields the sequence data and the taxa.
data, taxa = loadData('../data/DENV4/DENV4.nexus', 'nexus')

# Step 2: summarise the short support run (called with burnin=2501) and turn
# its trees into root-split and subsplit support dictionaries.
support_trees, support_wts = summary(
    '../data/DENV4/DENV4_constant_support_short_run',
    'nexus',
    burnin=2501,
)
rootsplit_supp_dict, subsplit_supp_dict = get_support_info(taxa, support_trees)

# The tree summaries are not used after this point, so drop them.
del support_trees, support_wts

In [5]:
# load the ground truth
# Summarise the long "golden" MCMC run (called with burnin=25001); the third
# return value is discarded.
golden_trees, golden_wts, _ = tree_summary(
    '../data/DENV4/DENV4_constant_golden_run.trees', 'nexus', burnin=25001
)

# Map each sampled tree to its weight, inserting entries from the largest
# weight to the smallest (ties keep their original relative order).
ranked_tree_ids = sorted(golden_wts, key=golden_wts.get, reverse=True)
emp_tree_freq = {golden_trees[tree_id]: golden_wts[tree_id] for tree_id in ranked_tree_ids}

# One value per taxon: 1994 minus the year obtained by prefixing '19' to the
# last two characters of the taxon name.
# NOTE(review): presumably taxon names end in a two-digit 19xx sampling year
# and 1994 is the most recent sampling year -- confirm against the data.
sample_info = [1994.0 - float('19' + name[-2:]) for name in taxa]

# Only emp_tree_freq and sample_info are needed downstream.
del golden_trees, golden_wts, ranked_tree_ids

In [6]:
# set up the model
# `np` below is NumPy, imported explicitly in the imports cell; previously this
# cell only worked because `from dataManipulation import *` happened to export
# a name called `np`.
model = VBPI(
    taxa,
    rootsplit_supp_dict,
    subsplit_supp_dict,
    data,
    # Length-4 vector with every entry equal to 0.25.
    # NOTE(review): presumably uniform nucleotide frequencies -- confirm.
    pden=np.ones(4) / 4.,
    subModel=('JC', 1.0),
    # Empirical tree frequencies from the golden run, built in the
    # ground-truth cell above.
    emp_tree_freq=emp_tree_freq,
    root_height_offset=5.0,
    clock_rate=1e-3,
    psp=True,
    # Per-taxon values computed in the ground-truth cell above.
    sample_info=sample_info,
    coalescent_type='constant',
    clock_type='strict',
)

In [7]:
# variational parameters
# List every tensor returned by model.parameters(), one per line, as
# "<dtype> <shape>".
for param in model.parameters():
    print(param.dtype, param.shape)

torch.float32 torch.Size([425])
torch.float32 torch.Size([982, 2])
torch.float32 torch.Size([2, 1])
torch.float32 torch.Size([2, 1])


In [10]:
# run vbpi.
# the trained model will be saved to '../results/DENV4/mcmc_vimco_10_psp_constant_strict.pt'.
checkpoint_path = '../results/DENV4/mcmc_vimco_10_psp_constant_strict.pt'

# Create the output directory before the long run starts, so a missing folder
# cannot abort training at the moment the checkpoint is written.
# NOTE(review): harmless if model.learn already creates it -- not verified here.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# The captured output below logs the training lower bound, log-likelihood,
# root age and KL every 1000 iterations, and a test lower bound every 5000.
# NOTE(review): the first positional argument (0.001) is presumably the
# optimiser step size, and the two return values presumably hold the test
# lower bounds and KL divergences collected during training -- confirm
# against VBPI.learn.
test_lb, test_kl_div = model.learn(
    0.001,
    maxiter=200000,
    n_particles=10,
    warm_start_interval=50000,
    method='vimco',
    save_to_path=checkpoint_path,
)

Iter 1000:(43.3s) Lower Bound: -4391.7688 | Logll: -6822.3584 | Root Age: 292.0729 | KL: 7.046626
Iter 2000:(44.0s) Lower Bound: -4281.3014 | Logll: -4147.5742 | Root Age: 255.1476 | KL: 6.952720
Iter 3000:(43.5s) Lower Bound: -4242.2542 | Logll: -4284.5283 | Root Age: 223.1223 | KL: 6.932994
Iter 4000:(43.9s) Lower Bound: -4218.5626 | Logll: -4156.1045 | Root Age: 192.2396 | KL: 6.914821
Iter 5000:(43.1s) Lower Bound: -4202.6595 | Logll: -4186.6660 | Root Age: 174.3486 | KL: 6.893417
>>> Iter 5000:(2.3s) Test Lower Bound: -4468.6187
Iter 6000:(43.3s) Lower Bound: -4191.2046 | Logll: -4219.0430 | Root Age: 159.8778 | KL: 6.784171
Iter 7000:(44.6s) Lower Bound: -4182.1438 | Logll: -4221.7549 | Root Age: 141.8777 | KL: 6.724048
Iter 8000:(44.8s) Lower Bound: -4174.2072 | Logll: -4166.3833 | Root Age: 127.9925 | KL: 6.628641
Iter 9000:(43.4s) Lower Bound: -4167.1923 | Logll: -4118.9316 | Root Age: 114.2433 | KL: 6.497280
Iter 10000:(44.5s) Lower Bound: -4161.3970 | Logll: -4106.7075 | Roo

Iter 77000:(43.1s) Lower Bound: -4129.4395 | Logll: -4071.2017 | Root Age: 67.5586 | KL: 0.090789
Iter 78000:(43.3s) Lower Bound: -4129.4663 | Logll: -4075.5200 | Root Age: 67.6179 | KL: 0.089983
Iter 79000:(42.1s) Lower Bound: -4129.4532 | Logll: -4077.9500 | Root Age: 67.4607 | KL: 0.089456
Iter 80000:(43.4s) Lower Bound: -4129.4501 | Logll: -4091.4912 | Root Age: 67.4684 | KL: 0.088091
>>> Iter 80000:(2.3s) Test Lower Bound: -4132.5474
Iter 81000:(43.2s) Lower Bound: -4129.4579 | Logll: -4083.4104 | Root Age: 67.6153 | KL: 0.087137
Iter 82000:(42.1s) Lower Bound: -4129.4265 | Logll: -4075.9065 | Root Age: 67.7705 | KL: 0.086991
Iter 83000:(42.7s) Lower Bound: -4129.3986 | Logll: -4070.4470 | Root Age: 67.3839 | KL: 0.087181
Iter 84000:(43.1s) Lower Bound: -4129.4170 | Logll: -4076.7336 | Root Age: 67.3592 | KL: 0.087120
Iter 85000:(42.1s) Lower Bound: -4129.4187 | Logll: -4075.3955 | Root Age: 67.1883 | KL: 0.087114
>>> Iter 85000:(2.3s) Test Lower Bound: -4132.4302
Iter 86000:(43.3

Iter 153000:(42.7s) Lower Bound: -4129.4222 | Logll: -4078.1643 | Root Age: 67.4306 | KL: 0.081617
Iter 154000:(43.2s) Lower Bound: -4129.4409 | Logll: -4077.5940 | Root Age: 67.4009 | KL: 0.081686
Iter 155000:(42.7s) Lower Bound: -4129.4177 | Logll: -4078.7253 | Root Age: 67.3326 | KL: 0.081636
>>> Iter 155000:(2.3s) Test Lower Bound: -4132.2695
Iter 156000:(43.6s) Lower Bound: -4129.4318 | Logll: -4073.6677 | Root Age: 67.2783 | KL: 0.081587
Iter 157000:(42.4s) Lower Bound: -4129.4602 | Logll: -4072.6575 | Root Age: 67.0638 | KL: 0.081504
Iter 158000:(42.3s) Lower Bound: -4129.4488 | Logll: -4076.3843 | Root Age: 67.1319 | KL: 0.081463
Iter 159000:(43.3s) Lower Bound: -4129.4403 | Logll: -4082.3059 | Root Age: 67.3616 | KL: 0.081462
Iter 160000:(42.0s) Lower Bound: -4129.4291 | Logll: -4077.0435 | Root Age: 67.3667 | KL: 0.081508
>>> Iter 160000:(2.3s) Test Lower Bound: -4131.9297
Iter 161000:(42.8s) Lower Bound: -4129.4557 | Logll: -4095.9087 | Root Age: 67.4459 | KL: 0.081458
Iter 

In [8]:
# load a trained model/checkpoint for test
# This is the same file path that the training cell passes as `save_to_path`,
# so that file must already exist (from a previous training run) before this
# cell is executed.
model.load_from('../results/DENV4/mcmc_vimco_10_psp_constant_strict.pt')

In [9]:
# sample effective population size
# Gradients are not needed for sampling, so everything runs under no_grad.
with torch.no_grad():
    # Request 75000 particles; the second return value is discarded.
    raw_draws, _ = model.tree_prior_model.sample_pop_size(n_particles=75000)
    # Exponentiate the draws and drop singleton dimensions.
    # NOTE(review): the exp suggests the sampler works on the log scale -- confirm.
    vbpi_pop_size = torch.squeeze(torch.exp(raw_draws))
    del raw_draws

In [10]:
# Display the sampled population sizes (the notebook shows an abbreviated tensor).
vbpi_pop_size

tensor([31.9729, 26.5911, 21.6218,  ..., 26.8124, 22.4801, 41.4166])