In [1]:
# This is a variational implementation of the hierarchical_2pl model from the Stan examples.

import VariationalBayes as vb
from VariationalBayes.SparseObjectives import Objective
import VariationalBayes.ExponentialFamilies as ef
import VariationalBayes.Modeling as modeling
import hierarchical_2pl_lib as hier_model

import math

import autograd
import autograd.numpy as np
import numpy as onp

import matplotlib.pyplot as plt
%matplotlib inline

import time

from copy import deepcopy
import scipy as sp
from scipy import optimize
from scipy import stats

import time

In [44]:
import pandas as pd
import os
import json

# Choose between a small simulated data set (with known true parameters) and
# the real data exported from the Stan MCMC run.
simulate_data = False

if simulate_data:
    # NOTE: `true_params` is only defined in this branch; downstream cells that
    # compare against the truth must check `simulate_data` first.
    np.random.seed(42)
    num_i = 10
    num_j = 200
    true_params, y, y_prob = hier_model.simulate_data(num_i=num_i, num_j=num_j)
    prior_params = hier_model.get_prior_params()

else:
    # This json file is generated by the script hierarchical_base_mcmc.R in StanSensitivity.
    data_dir = os.path.join(os.environ['GIT_REPO_LOC'],
                            'StanSensitivity/examples/example_models/hierarchical_2pl/')
    json_filename = os.path.join(data_dir, 'hierarchical_2pl_mcmc_results.json')

    # Context manager guarantees the file is closed even if parsing fails.
    with open(json_filename, 'r') as json_file:
        json_dat = json.load(json_file)

    mcmc_df = pd.DataFrame({'mean': json_dat['mcmc_results']['mean'],
                            'par': json_dat['mcmc_results']['pars']})
    stan_data = json_dat['stan_data']

    # Scalars in the exported json appear to be stored as length-one lists,
    # hence the [0] indexing throughout.
    num_i = stan_data['I'][0]
    num_j = stan_data['J'][0]
    # Checked by hand that this is the transpose of the R matrix.
    y = np.reshape(np.array(stan_data['y']), (num_i, num_j))

    # Set the prior parameters from the values used in the Stan run.
    prior_params = hier_model.get_prior_params()

    prior_params['lkj_param'].set(stan_data['lkj_concentration'][0])
    prior_params['theta_mean'].set(stan_data['theta_loc'][0])
    # The prior is parameterized by a variance, Stan by a scale.
    prior_params['theta_var'].set(stan_data['theta_scale'][0] ** 2)

    # Diagonal prior covariance for mu built from the two Stan scales; the
    # model takes its inverse (the information matrix).
    s11 = stan_data['mu_1_scale'][0]
    s22 = stan_data['mu_2_scale'][0]
    mu_cov = np.array([[s11 ** 2, 0.], [0., s22 ** 2]])
    prior_params['mu_info'].set(np.linalg.inv(mu_cov))

    # The same scalar location is used for both components.
    tau_param = np.full(2, stan_data['tau_loc'][0])
    prior_params['tau_param'].set(tau_param)

    mu_loc = np.full(2, stan_data['mu_loc'][0])
    prior_params['mu_mean'].set(mu_loc)



In [45]:
# Number of draws handed to the model (presumably Monte Carlo draws for the KL
# estimate -- confirm in hierarchical_2pl_lib) and how many objective
# evaluations pass between logger progress lines.
num_kl_draws = 10
log_print_every = 10

# Variational parameters; their free (unconstrained) vector is the starting
# point for the optimizer.
vb_params = hier_model.get_vb_params(num_i=num_i, num_j=num_j)
vb_init_par = vb_params.get_free()

# Wrap the model's KL in an Objective so we get free-parameter value,
# gradient and Hessian-vector products for scipy.
model = hier_model.Model(y, vb_params, prior_params, num_draws=num_kl_draws)
objective = Objective(model.vb_params, model.get_kl)
objective.logger.print_every = log_print_every

# Evaluate once at the initial point as a sanity check (shown as cell output).
objective.fun_free(vb_init_par)

array([ 33337.57601947])

In [None]:
# Optionally warm-start the Newton trust-region solver from a BFGS solution.
# Off by default, matching the original behavior.
run_bfgs_first = False

# perf_counter is monotonic, so it is the right clock for elapsed time.
vb_start = time.perf_counter()
objective.logger.initialize()

newton_init_par = vb_init_par
if run_bfgs_first:
    print('Running BFGS')
    vb_opt_bfgs = optimize.minimize(
        lambda par: objective.fun_free(par, verbose=True), vb_init_par,
        method='bfgs', jac=objective.fun_free_grad, tol=1e-6)
    newton_init_par = vb_opt_bfgs.x

print('Running Newton Trust Region')
vb_opt = optimize.minimize(
    lambda par: objective.fun_free(par, verbose=True),
    newton_init_par,
    method='trust-ncg',
    jac=objective.fun_free_grad,
    hessp=objective.fun_free_hvp, tol=1e-6)
print('Done.')
if not vb_opt.success:
    # Surface non-convergence instead of silently using the last iterate.
    print('Warning: optimizer did not report convergence:', vb_opt.message)

# Elapsed wall-clock seconds for the whole optimization.
vb_time = time.perf_counter() - vb_start

vb_opt_par = deepcopy(vb_opt.x)

Running Newton Trust Region
Iter  0  value:  [ 33337.57601947]
Iter  10  value:  [ 9954.91095939]
Iter  20  value:  [ 9515.98665859]
Iter  30  value:  [ 9089.73103795]
Iter  40  value:  [ 8941.13649608]


In [None]:
# Inspect the fitted variational parameters. Uses vb_opt_par, which the
# optimization cell always sets, rather than a BFGS result, which is not
# computed by default.
model.vb_params.set_free(vb_opt_par)
print(model.vb_params)

In [None]:
model.vb_params.set_free(vb_opt_par)

# Expectation of mu under the fitted variational approximation.
print(model.vb_params['mu'].e())


def _plot_truth_vs_vb(truth, vb_estimate, name):
    """Scatter VB estimates against true values with a y = x reference line.

    Args:
        truth: Array of true parameter values.
        vb_estimate: Array of VB estimates, same shape as ``truth``.
        name: Parameter name used in the title and axis labels.

    Returns:
        The matplotlib Axes the scatter was drawn on.
    """
    fig, ax = plt.subplots()
    ax.plot(truth, vb_estimate, 'k.')
    # A perfect estimate lies on the diagonal.
    lo = min(np.min(truth), np.min(vb_estimate))
    hi = max(np.max(truth), np.max(vb_estimate))
    ax.plot([lo, hi], [lo, hi], 'r-', linewidth=1)
    ax.set(title=name, xlabel='true ' + name, ylabel='VB estimate of ' + name)
    return ax


# True parameters exist only for simulated data, so the comparison is skipped
# for the real Stan data (where `true_params` is undefined).
if simulate_data:
    # NOTE(review): compares E[mu] with log(true mu); confirm the simulator's
    # `mu` is on the scale whose log matches the VB parameterization.
    print(np.log(true_params['mu'].get()))

    _plot_truth_vs_vb(true_params['theta'].get(),
                      model.vb_params['theta'].e(), 'theta')
    _plot_truth_vs_vb(true_params['beta'].get(),
                      model.vb_params['beta'].e(), 'beta')
    _plot_truth_vs_vb(true_params['alpha'].get(),
                      model.vb_params['log_alpha'].e_exp(), 'alpha')
else:
    print('simulate_data is False: no true parameters to compare against.')