In [None]:
import matplotlib.pyplot as plt
import sys, os
sys.path.append('../')
result_dir = os.path.join(os.path.dirname(os.getcwd()), 'result', 'oup')
figure_dir = os.path.join(os.path.dirname(os.getcwd()), 'figure')
from torch import nn, Tensor
from scipy.stats import entropy

import seaborn as sns

from src.metric import c2st
from src.simulator.oup import oup
import torch
import torch.nn.functional as F
import numpy as np
from src.plot import plot_recovery
import matplotlib
from matplotlib.patches import Patch

plt.rcParams['text.latex.preamble'] = r'\usepackage{times}'

In [3]:
mc_ref = torch.load(os.path.join(result_dir, 'oup_mc_n_10000_four_param.pt'), map_location = torch.device('cpu'))

  mc_ref = torch.load(os.path.join(result_dir, 'oup_mc_n_10000_four_param.pt'), map_location = torch.device('cpu'))
  from .autonotebook import tqdm as notebook_tqdm


In [4]:
simulator = oup(T = 10, dt = 0.1, sample_theta = True, three_param = False)

In [11]:
x = np.load(os.path.join(result_dir, 'oup_x_4_param.npy'))
theta = np.load(os.path.join(result_dir, 'oup_theta_4_param.npy'))
x = torch.from_numpy(x).float()
theta = torch.from_numpy(theta).float()

In [12]:
def obtain_nlpd_for_multi_run(theta, x, n_tl_str: str, n_mlmc_str: str):
    patience_list = [1, 20, 100, 1000]
    n_sim = 20
    mlmc_nlpd = np.zeros((n_sim, ))
    tl_nlpd = np.zeros((n_sim, len(patience_list)))

    for i in range(n_sim):
        for j in range(len(patience_list)):
            patience = patience_list[j]
            tl_net = torch.load(os.path.join(
                result_dir, 'oup_tl_n_' + n_tl_str + '_pa_' + str(patience) + '_' + str(i) + '_four_param.pt'), map_location=torch.device('cpu'))
            tl_nlpd[i, j] = - tl_net.log_prob_unstandardized(theta, x).detach().median().numpy().item()

        mlmc_net = torch.load(os.path.join(
                result_dir, 'oup_mlmc_n_' + n_mlmc_str + '_' + str(i) + '_four_param.pt'), map_location=torch.device('cpu'))
        mlmc_nlpd[i] = - mlmc_net.log_prob_unstandardized(theta, x).detach().median().numpy().item()


    return mlmc_nlpd, tl_nlpd

In [13]:
def obtain_nlpd_for_multi_run_mc(theta, x, n_str: str):
    n_sim = 20
    nlpd = np.zeros((n_sim, ))

    for i in range(n_sim):
        net = torch.load(os.path.join(
                result_dir, 'oup_mc_n_' + n_str + '_four_param_' + str(i) + '.pt'), map_location=torch.device('cpu'))
        nlpd[i] = - net.log_prob_unstandardized(theta, x).detach().median().numpy().item()

    return nlpd

In [14]:
def obtain_kl(approximated_densities, exact_densities, forward = True):
    
    jitter = 1e-20
    exact_densities = np.clip(exact_densities, jitter, None)

    kl = np.zeros((approximated_densities.shape[0]))

    for i in range(approximated_densities.shape[0]):

        if forward:
            kl[i] = entropy(exact_densities[i], approximated_densities[i]) # forward KL divergence 

        else:
            kl[i] = entropy(approximated_densities[i], exact_densities[i])

    return kl

def obtain_avg_kl(net_approx, net_ref, x, num_samples = 2000):

    with torch.no_grad():
        post_ref = net_ref.sample_unstandardized(num_samples = num_samples, condition = x)
        dens_def = torch.stack([net_ref.log_prob_unstandardized(post_ref[:, i, :], condition = x) for i in range(num_samples)])
        dens_approx = torch.stack([net_approx.log_prob_unstandardized(post_ref[:, i, :], condition = x) for i in range(num_samples)])

        kl_avg = obtain_kl(dens_approx, dens_def).mean().item()

    return kl_avg

In [8]:
def obtain_kl_for_multi_run(x, n_tl_str: str, n_mlmc_str: str, ref_net):
    patience_list = [1, 20, 100, 1000]
    n_net = 20
    mlmc_kl = np.zeros((n_net, ))
    tl_kl = np.zeros((n_net, len(patience_list)))

    for i in range(n_net):

        for j in range(len(patience_list)):
            patience = patience_list[j]
            # patience = patience_list[j]
            tl_net = torch.load(os.path.join(
                result_dir, 'oup_tl_n_' + n_tl_str + '_pa_' + str(patience) + '_' + str(i) + '_four_param.pt'), map_location=torch.device('cpu'))
            tl_kl[i, j] =  obtain_avg_kl(tl_net, ref_net, x)

        mlmc_net = torch.load(os.path.join(
                result_dir, 'oup_mlmc_n_' + n_mlmc_str + '_' + str(i) + '_four_param.pt'), map_location=torch.device('cpu'))
        mlmc_kl[i] = obtain_avg_kl(mlmc_net, ref_net, x)


    return mlmc_kl, tl_kl

In [10]:
def obtain_kl_for_multi_run_mc(x, n_str, ref_net):
    n_net = 20
    kl = np.zeros((n_net, ))
    for i in range(n_net):
        net = torch.load(os.path.join(
            result_dir, 'oup_mc_n_' + n_str + '_four_param_' + str(i) + '.pt'), map_location=torch.device('cpu'))
        kl[i] = obtain_avg_kl(net, ref_net, x)

    return kl

In [20]:
# mlmc_kl, tl_kl = obtain_kl_for_multi_run(x, '1100_100', '1000_100', mc_ref)
# mlnc_nlpd, tl_nlpd = obtain_nlpd_for_multi_run(theta, x, '1100_100', '1000_100')

In [18]:
mc_nlpd_100 = obtain_nlpd_for_multi_run_mc(theta, x, '100')
mc_kl_100 = obtain_kl_for_multi_run_mc(x, '100', mc_ref)

  net = torch.load(os.path.join(
  net = torch.load(os.path.join(


In [21]:
np.save(os.path.join(result_dir, 'oup_mc_nlpd_100_four_param.npy'), mc_nlpd_100)
np.save(os.path.join(result_dir, 'oup_mc_kl_100_four_param.npy'), mc_kl_100)
mc_nlpd_100 = np.load(os.path.join(result_dir, 'oup_mc_nlpd_100_four_param.npy'))
mc_kl_100 = np.load(os.path.join(result_dir, 'oup_mc_kl_100_four_param.npy'))

In [19]:
print('MC NLPD 100: ', round(mc_nlpd_100.mean(), 2), '(', (round(mc_nlpd_100.std(), 2)),  ')', '\n'
      'MC KL 100: ', round(mc_kl_100.mean(), 2), '(', (round(mc_kl_100.std(), 2)),  ')')

MC NLPD 100:  -1.21 ( 0.53 ) 
MC KL 100:  1.05 ( 0.14 )


In [None]:
# np.save(os.path.join(result_dir, 'oup_mlmc_kl_100_four_param.npy'), mlmc_kl)
# np.save(os.path.join(result_dir, 'oup_tl_kl_100_four_param.npy'), tl_kl)
# np.save(os.path.join(result_dir, 'oup_mlmc_nlpd_100_four_param.npy'), mlnc_nlpd)
# np.save(os.path.join(result_dir, 'oup_tl_nlpd_100_four_param.npy'), tl_nlpd)

In [105]:
nlpd_tl_100 = np.load(os.path.join(result_dir, 'oup_tl_nlpd_100_four_param.npy'))
nlpd_mlmc_100 = np.load(os.path.join(result_dir, 'oup_mlmc_nlpd_100_four_param.npy'))

In [106]:
kl_tl_100 = np.load(os.path.join(result_dir, 'oup_tl_kl_100_four_param.npy'))
kl_mlmc_100 = np.load(os.path.join(result_dir, 'oup_mlmc_kl_100_four_param.npy'))

In [22]:
# NLPD 
print('MLMC NLPD : ', round(nlpd_mlmc_100.mean(), 2), '(', (round(mc_nlpd_100.std(), 2)),  ')', '\n'
      'TL NLPD: ', round( nlpd_tl_100.mean(axis = 0)[1], 2), '(', (round( nlpd_tl_100.std(axis = 0)[1], 2)),  ')')

MLMC NLPD :  -0.18 ( 0.53 ) 
TL NLPD:  -0.89 ( 0.52 )


In [23]:
# KL 
print('MLMC KL : ', round(kl_mlmc_100.mean(), 2), '(', (round(kl_mlmc_100.std(), 2)),  ')', '\n'
      'TL KL: ', round( kl_tl_100.mean(axis = 0)[1], 2), '(', (round( kl_tl_100.std(axis = 0)[1], 2)),  ')')

MLMC KL :  0.98 ( 0.08 ) 
TL KL:  1.24 ( 0.4 )
