In [1]:
# Reload edited Python modules (e.g. the local dataManipulation / utils / vbpi modules)
# automatically before executing each cell, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the directory one level above the kernel's working directory (normally the
# notebook's folder) importable; presumably this is where the local modules imported
# below (dataManipulation, utils, vbpi) live.
# The membership check keeps the cell idempotent: re-running it must not append a
# duplicate entry to sys.path each time.
import sys

PROJECT_ROOT = '../'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from dataManipulation import *
from utils import tree_summary, summary, summary_raw, get_support_info
from vbpi import VBPI

In [3]:
# Load the HCV sequence data and taxa, then build the rootsplit and subsplit support
# dictionaries from the trees of a short MCMC run (burnin=251 is forwarded to `summary`;
# NOTE(review): presumably the number of leading trees discarded — confirm).
data, taxa = loadData('../data/HCV/HCV.nexus', 'nexus')

support_trees, support_tree_ids = summary(
    '../data/HCV/HCV_skyride_support_short_run', 'nexus', burnin=251
)
rootsplit_supp_dict, subsplit_supp_dict = get_support_info(taxa, support_trees)

# Only the two support dictionaries are used downstream; drop the intermediate summary outputs.
del support_trees, support_tree_ids

In [4]:
# Construct the VBPI model from the data and support dictionaries.
#   pden               : uniform 4-state frequency vector (each entry 0.25)
#   subModel           : substitution-model spec ('JC', 1.0), passed through to VBPI
#   root_height_offset : 5.0 (NOTE(review): semantics defined inside VBPI — confirm)
#   clock_rate         : supplied as the constant 7.9e-4
#   psp / coalescent_type : PSP option enabled, 'skyride' coalescent
uniform_freqs = np.ones(4) / 4.
model = VBPI(
    taxa,
    rootsplit_supp_dict,
    subsplit_supp_dict,
    data,
    pden=uniform_freqs,
    subModel=('JC', 1.0),
    root_height_offset=5.0,
    clock_rate=7.9e-4,
    psp=True,
    coalescent_type='skyride',
)

In [5]:
# Show the dtype and shape of each variational parameter tensor held by the model.
param_summaries = [(p.dtype, p.size()) for p in model.parameters()]
for dtype, shape in param_summaries:
    print(dtype, shape)

torch.float32 torch.Size([3673])
torch.float32 torch.Size([9690, 2])
torch.float32 torch.Size([62, 2])


In [6]:
# Train VBPI with the VIMCO objective using 10 particles for 200000 iterations.
# The first positional argument (0.001) is presumably the step size and warm_start_interval
# the length of the warm-up phase — NOTE(review): confirm against VBPI.learn's signature.
# Cost: this cell's logged timings are ~160-180 s per 1000 iterations, i.e. roughly
# 9-10 hours for the full run.
# `test_lb` presumably holds the periodic test lower bounds (logged every 5000 iterations);
# the second return value is unused.
import os

# Single source of truth for the output file, so this comment and the code cannot drift
# (the path is relative to the kernel's working directory).
SAVE_PATH = '../results/HCV/mcmc_vimco_10_psp_skyride_fixed_rate.pt'

# Create the output folder before the multi-hour run so a missing directory cannot make
# the final save fail (NOTE(review): harmless if learn() already creates it).
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

test_lb, _ = model.learn(0.001, maxiter=200000, n_particles=10, warm_start_interval=50000, method='vimco',
                         save_to_path=SAVE_PATH)

Iter 1000:(178.0s) Lower Bound: -8303.1516 | Logprior: -709.4425 | Logll: -7886.3276 | Root Age: 276.3767
Iter 2000:(170.7s) Lower Bound: -8229.5442 | Logprior: -683.7360 | Logll: -8194.3271 | Root Age: 323.8568
Iter 3000:(169.0s) Lower Bound: -8154.3307 | Logprior: -643.6879 | Logll: -8077.0947 | Root Age: 331.0407
Iter 4000:(168.6s) Lower Bound: -8082.2354 | Logprior: -629.9238 | Logll: -7835.6763 | Root Age: 333.3037
Iter 5000:(168.4s) Lower Bound: -8005.8367 | Logprior: -654.0965 | Logll: -8902.6484 | Root Age: 321.8905
>>> Iter 5000:(6.5s) Test Lower Bound: -8581.5615
Iter 6000:(167.9s) Lower Bound: -7947.3722 | Logprior: -602.7484 | Logll: -7665.4434 | Root Age: 323.5475
Iter 7000:(167.9s) Lower Bound: -7899.1694 | Logprior: -570.8593 | Logll: -7841.6689 | Root Age: 320.6855
Iter 8000:(168.9s) Lower Bound: -7861.2569 | Logprior: -576.4910 | Logll: -7574.8428 | Root Age: 313.6620
Iter 9000:(171.5s) Lower Bound: -7829.8090 | Logprior: -589.8684 | Logll: -7835.7935 | Root Age: 308.0

Iter 71000:(161.9s) Lower Bound: -7629.9038 | Logprior: -493.2608 | Logll: -7359.3008 | Root Age: 231.2656
Iter 72000:(162.2s) Lower Bound: -7630.2013 | Logprior: -505.1701 | Logll: -7344.5562 | Root Age: 231.1418
Iter 73000:(162.3s) Lower Bound: -7630.0860 | Logprior: -487.2344 | Logll: -7342.3594 | Root Age: 230.8121
Iter 74000:(162.7s) Lower Bound: -7629.8805 | Logprior: -510.3177 | Logll: -7361.2998 | Root Age: 231.6403
Iter 75000:(162.3s) Lower Bound: -7629.4953 | Logprior: -496.5406 | Logll: -7357.2671 | Root Age: 231.7545
>>> Iter 75000:(6.5s) Test Lower Bound: -7644.0215
Iter 76000:(162.6s) Lower Bound: -7629.6981 | Logprior: -498.8295 | Logll: -7349.1025 | Root Age: 230.5377
Iter 77000:(162.6s) Lower Bound: -7629.4963 | Logprior: -498.1212 | Logll: -7392.4292 | Root Age: 231.4242
Iter 78000:(161.4s) Lower Bound: -7629.5855 | Logprior: -505.1074 | Logll: -7372.7119 | Root Age: 231.7245
Iter 79000:(161.7s) Lower Bound: -7629.5403 | Logprior: -494.3390 | Logll: -7367.4458 | Root 

>>> Iter 140000:(6.5s) Test Lower Bound: -7637.9097
Iter 141000:(160.4s) Lower Bound: -7627.5459 | Logprior: -500.2725 | Logll: -7366.3091 | Root Age: 231.1342
Iter 142000:(160.8s) Lower Bound: -7627.4465 | Logprior: -495.6278 | Logll: -7344.1426 | Root Age: 231.4919
Iter 143000:(160.9s) Lower Bound: -7627.5472 | Logprior: -497.7082 | Logll: -7347.4722 | Root Age: 230.9996
Iter 144000:(160.5s) Lower Bound: -7627.3580 | Logprior: -524.3526 | Logll: -7353.8760 | Root Age: 229.9201
Iter 145000:(161.4s) Lower Bound: -7627.4496 | Logprior: -506.0348 | Logll: -7355.9546 | Root Age: 230.5397
>>> Iter 145000:(6.4s) Test Lower Bound: -7637.2168
Iter 146000:(160.5s) Lower Bound: -7627.4511 | Logprior: -503.9289 | Logll: -7346.7021 | Root Age: 230.6755
Iter 147000:(160.5s) Lower Bound: -7627.3757 | Logprior: -499.1236 | Logll: -7364.0601 | Root Age: 230.2330
Iter 148000:(160.5s) Lower Bound: -7627.3714 | Logprior: -483.7567 | Logll: -7367.0562 | Root Age: 230.9007
Iter 149000:(160.5s) Lower Bound

In [7]:
# Draw effective population size trajectories from the trained model. They are on the
# log scale (the following cell exponentiates them). cut_off=250 is forwarded unchanged
# (NOTE(review): presumably a time horizon in the same units as the ~230 root age — confirm).
n_pop_trajectories = 75000
vbpi_skyride_pop_traj = model.sample_pop_traj(cut_off=250, n_traj=n_pop_trajectories)

In [14]:
# Effective population sizes on the natural scale: element-wise exp of the sampled log trajectories.
vbpi_skyride_pop_size = vbpi_skyride_pop_traj.exp()
vbpi_skyride_pop_size

tensor([[ 138.8319,  138.8319,  138.8319,  ..., 3861.5222, 3861.5222,
         4277.7002],
        [ 113.7947,  113.7947,  113.7947,  ..., 4056.6475, 4056.6475,
         9845.8291],
        [  87.8090,   87.8090,   87.8090,  ..., 2786.1008, 7535.8096,
         4744.2803],
        ...,
        [ 145.4557,  145.4557,  145.4557,  ..., 4610.2266, 4610.2266,
         5466.1299],
        [  39.2343,   39.2343,   39.2343,  ..., 5445.3076, 4209.6978,
         4070.5056],
        [ 134.4723,  134.4723,  134.4723,  ..., 6724.9614, 2584.2993,
         2584.2993]])

In [8]:
# Sample 75000 tree heights from the trained model (compare with the "Root Age" values
# logged during training).
n_height_samples = 75000
vbpi_tree_height = model.sample_tree_height(n_height_samples)

In [12]:
# Display the sampled tree heights (torch abbreviates long tensors with "...").
vbpi_tree_height

tensor([232.4026, 231.2287, 231.0143,  ..., 228.3037, 231.1365, 241.8365])

In [9]:
# Sample 75000 tree log-likelihood values from the trained model
# (NOTE(review): presumably evaluated on trees drawn from the variational approximation).
n_logll_samples = 75000
vbpi_tree_logll = model.sample_tree_loglikelihood(n_logll_samples)

In [13]:
# Display the sampled tree log-likelihoods (torch abbreviates long tensors with "...").
vbpi_tree_logll

tensor([-7353.8950, -7355.3555, -7345.7427,  ..., -7341.6108, -7349.6147,
        -7344.8442])