In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from scipy.stats import norm as norm_dist
from paltas.Analysis import posterior_functions
import pandas as pd
import numpy as np
import h5py
import copy
import os
import sys
# USER TODO: change this system path 
sys.path.insert(0, '/Users/smericks/Desktop/StrongLensing/lens-npe/')
from LensSystem.lens_system import LensSystem
from Inference.base_hierarchical_inference import BaseHierarchicalInference
from Inference.network_reweighted_posteriors import NetworkReweightedPosteriors
from LensSystem.image_positions_utils import matrix_plot_im_positions
import visualization_utils
import mcmc_utils

This notebook recreates the figures and tables in Erickson et al. ('24).

## TODO Before Running ##

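Download the files from https://zenodo.org/records/13906030 and put them in one folder, then set `zenodo_folder` in the next cell to that folder's path. Also update the `sys.path.insert(...)` line in the first cell to point at your local copy of the lens-npe repository.
For this notebook, you do NOT need to download trained_models.tgz.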

In [37]:
# Specify the filepath here: the folder holding the extracted Zenodo files.
zenodo_folder = 'lens-npe-data/'
# Every data path below is built by plain string concatenation
# (zenodo_folder+'...'), so make sure the folder ends with a separator.
if not zenodo_folder.endswith(('/', os.sep)):
    zenodo_folder += '/'

# Every figure is saved as figures/<name>.pdf; create that folder up front so
# plt.savefig does not fail on a fresh checkout.
os.makedirs('figures', exist_ok=True)

## Figures 1,2 ##

Created in PowerPoint; please contact @smericks for details.

## Figure 3 ##

In [None]:
# Show narrow vs broad training
# Figure 3: the broad training prior (nu_int) vs the narrower, shifted test
# distribution, for theta_E (left) and gamma_lens (right).
prior_dists = [norm_dist(loc=0.8,scale=0.15),norm_dist(loc=2.0,scale=0.2)]
narrow_dists = [norm_dist(loc=0.7,scale=0.08),norm_dist(loc=2.05,scale=0.1)]
# The vertical lines mark each distribution's mean. Read it off the frozen
# distribution so it cannot drift out of sync with the `loc` values above.
prior_means = [float(dist.mean()) for dist in prior_dists]
narrow_means = [float(dist.mean()) for dist in narrow_dists]
x_ranges = [np.arange(0.2,1.4,0.01),np.arange(1.4,2.6,0.01)]
titles = [r'$\theta_E$ (")',r'$\gamma_{lens}$']

fig,axs = plt.subplots(1,2,figsize=(12,5),dpi=300)
plt.subplots_adjust(wspace=0.11)
for i in range(0,2):
    ax = axs[i]
    # broad training prior: dashed grey density plus its mean
    ax.plot(x_ranges[i],prior_dists[i].pdf(x_ranges[i]),label=r'$\nu_{int}$',color='grey',linestyle='dashed',linewidth=3.5)
    ax.axvline(prior_means[i],color='grey',linewidth=2.5,alpha=0.7)
    # shifted test distribution: solid density plus its mean
    ax.plot(x_ranges[i],narrow_dists[i].pdf(x_ranges[i]),label='Shifted',color='slateblue',linewidth=3.5)
    ax.axvline(narrow_means[i],color='slateblue',linewidth=2.5,alpha=0.7)
    if i == 0:
        # a single legend on the left panel covers both
        ax.legend(loc='upper right',bbox_to_anchor=(1.02,1.02),fontsize=20)
    ax.set_ylabel('Density',fontsize=23)
    ax.set_yticks([])
    ax.set_xlabel(titles[i],fontsize=30)
    ax.tick_params(labelsize=14)

plt.savefig('figures/Figure3.pdf',bbox_inches='tight')

## Figure 4 ##

In [None]:
# Figure 4: image matrix of the doppelganger test set. The trailing None entry
# becomes an empty panel in the [2,7] grid.
lens_names = ['ATLAS J2344-3056', 'DES J0405-3308', 'DES J0420-4037',
              'J0029-3814', 'J1131-4419', 'J2145+6345',
              'J2205-3727', 'PS J1606-2333', 'SDSS J0248+1913', 'SDSS J1251+2935',
              'W2M J1042+1641', 'WG0214-2105', 'WISE J0259-1635', None]

# On-disk folder name of each lens, in the same order as lens_names.
file_names = ['ATLASJ2344-3056', 'DESJ0405-3308', 'DESJ0420-4037',
              'J0029-3814', 'J1131-4419', 'J2145+6345',
              'J2205-3727', 'PSJ1606-2333', 'SDSSJ0248+1913', 'SDSSJ1251+2935',
              'W2MJ1042+1641', 'WG0214-2105', 'WISEJ0259-1635', None]

doppel_image_dir = zenodo_folder+'test_sets/doppelganger/doppel_images/'
# A None file name yields None in both lists (blank panel, no title).
image_files = [None if fname is None else doppel_image_dir+fname+'/image_0000000.npy'
               for fname in file_names]
plot_names = [None if fname is None else 'DG '+lname
              for fname, lname in zip(file_names, lens_names)]

visualization_utils.matrix_plot_from_npy(image_files,plot_names,[2,7],'figures/Figure4b.pdf',
    annotate=True,show_one_arcsec=True,rotate_data=True)

## Network Predictions ##

After the NPE and SNPE networks are trained, images are passed through them to
make mass-model predictions. The prediction source code is in
network_predictions.py. We load those predictions here to make the remaining
figures.

See notebook make_predictions.ipynb for more details.

In [15]:
# NPE model predictions (diagonal-covariance network).
# h5py 3.0 removed Dataset.value; `dataset[()]` reads the whole dataset into a
# numpy array on every h5py version. The `with` blocks close each file even if
# a key is missing. Indexing (h5f[key]) raises a clear KeyError where
# h5f.get(key) would return None.
npe_preds_path = zenodo_folder+'model_predictions/npe/diag/'

# SHIFTED SET
file_path = npe_preds_path+'narrow_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_test_shifted = h5f['y_test'][()]
    y_pred_shifted = h5f['y_pred'][()]
    std_pred_shifted = h5f['std_pred'][()]
    prec_pred_shifted = h5f['prec_pred'][()]

# DOPPELGANGER SET
file_path = npe_preds_path+'doppelganger_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_test_doppel = h5f['y_test'][()]
    y_pred_doppel = h5f['y_pred'][()]
    std_pred_doppel = h5f['std_pred'][()]
    prec_pred_doppel = h5f['prec_pred'][()]

# HST DATA (real data: no ground-truth 'y_test')
file_path = npe_preds_path+'HSTdata_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_pred_data = h5f['y_pred'][()]
    std_pred_data = h5f['std_pred'][()]
    prec_pred_data = h5f['prec_pred'][()]

# adjust for different pixel grid conventions
# Columns 6/8 are x-coordinates and 7/9 are y-coordinates. Each is shifted by
# 0.02 and then sign-flipped. This runs on freshly read arrays, so re-running
# the cell is safe.
# x-coords
y_pred_data[:,6] = - (y_pred_data[:,6]-0.02)
y_pred_data[:,8] = - (y_pred_data[:,8]-0.02)
# y-coords
y_pred_data[:,7] = - (y_pred_data[:,7]+0.02)
y_pred_data[:,9] = - (y_pred_data[:,9]+0.02)


In [16]:
# SNPE model predictions (epoch-10 files). Later cells index the leading axis of
# each *_seq_list array with [0]..[3]. Presumably that axis is the SNPE run;
# see the Figure 11 labels.
snpe_preds_path = zenodo_folder+'model_predictions/snpe/'

# SHIFTED SET
y_pred_shifted_seq_list = np.load(snpe_preds_path+'shifted/y_pred_list_epoch10.npy')
std_pred_shifted_seq_list = np.load(snpe_preds_path+'shifted/std_pred_list_epoch10.npy')
prec_pred_shifted_seq_list = np.load(snpe_preds_path+'shifted/prec_pred_list_epoch10.npy')

# DOPPELGANGER SET
y_pred_doppel_seq_list = np.load(snpe_preds_path+'doppelganger/y_pred_list_epoch10.npy')
std_pred_doppel_seq_list = np.load(snpe_preds_path+'doppelganger/std_pred_list_epoch10.npy')
prec_pred_doppel_seq_list = np.load(snpe_preds_path+'doppelganger/prec_pred_list_epoch10.npy')

# HST DATA
y_pred_data_seq_list = np.load(snpe_preds_path+'data/y_pred_list_epoch10.npy')
std_pred_data_seq_list = np.load(snpe_preds_path+'data/std_pred_list_epoch10.npy')
prec_pred_data_seq_list = np.load(snpe_preds_path+'data/prec_pred_list_epoch10.npy')

## Figure 5 ##

In [None]:
# Figure 5: calibration of NPE vs SNPE posteriors.
# For the shifted set, the SNPE predictions are compared against the first 20
# lenses, so NPE is restricted to the same subset.
test_set_indices = range(0,20)

# --- Shifted set (Figure 5a) ---
shifted_precisions = [prec_pred_shifted[test_set_indices,:], prec_pred_shifted_seq_list[2]]
y_pred_final = [y_pred_shifted[test_set_indices,:], y_pred_shifted_seq_list[2]]
# predictions are stored as precision matrices; invert to covariances
cov_pred_list_final = [np.linalg.inv(prec) for prec in shifted_precisions]

visualization_utils.combine_calib_plots(y_pred_final,cov_pred_list_final,y_test_shifted[test_set_indices,:],
    ['slateblue','mediumseagreen'],['Perfect Calibration','NPE','SNPE'],
    plot_title='Calibration of Shifted Set',save_path='figures/Figure5a.pdf')

# --- Doppelganger set (Figure 5b) ---
doppel_precisions = [prec_pred_doppel, prec_pred_doppel_seq_list[2]]
y_pred_final = [y_pred_doppel, y_pred_doppel_seq_list[2]]
cov_pred_list_final = [np.linalg.inv(prec) for prec in doppel_precisions]

visualization_utils.combine_calib_plots(y_pred_final,cov_pred_list_final,y_test_doppel,
    ['slateblue','mediumseagreen'],['Perfect Calibration','NPE','SNPE'],
    plot_title='Calibration of Doppelganger Set',save_path='figures/Figure5b.pdf')



## Table 1 ##

In [None]:
# Table 1: NPE vs SNPE metrics on the first 20 lenses of the shifted set.
shifted_truth_20 = y_test_shifted[:20]

print("NPE: Shifted Set")
visualization_utils.table_metrics(y_pred_shifted[:20], shifted_truth_20, std_pred_shifted[:20], None)
print(" ")
print("SNPE: Shifted Set")
visualization_utils.table_metrics(y_pred_shifted_seq_list[2], shifted_truth_20, std_pred_shifted_seq_list[2], None)

## Table 2 ##

In [None]:
# Table 2: NPE vs SNPE metrics on the doppelganger set.
snpe_doppel_means = y_pred_doppel_seq_list[2]
snpe_doppel_stds = std_pred_doppel_seq_list[2]

print("NPE: Doppelganger Set")
visualization_utils.table_metrics(y_pred_doppel, y_test_doppel, std_pred_doppel, None)
print(" ")
print("SNPE: Doppelganger Set")
visualization_utils.table_metrics(snpe_doppel_means, y_test_doppel, snpe_doppel_stds, None)

## Hierarchical Inference ##

Using the network predictions, we perform hierarchical inference of the lens
mass population model for each test set. The hierarchical-inference source code
is in the Inference folder. We load the resulting MCMC chains here to make the
remaining figures.

See notebook hierarchical_inference.ipynb for more details.

In [20]:
# load in HI chains
# Later cells pick samplers by index, per the sampler_labels they pass:
# shifted/doppelganger use [0]=NPE, [2]=SNPE, [5]=NPE-full; HST data uses
# [0]=NPE, [1]=SNPE; the FM file uses [0].
hi_folder = zenodo_folder+'hierarchical_inference/'
shifted_chains = BaseHierarchicalInference.retrieve_chains_h5(hi_folder+'shifted/HI_NPE_shifted.h5')
doppel_chains = BaseHierarchicalInference.retrieve_chains_h5(hi_folder+'doppelganger/HI_NPE_doppel.h5')
data_chains = BaseHierarchicalInference.retrieve_chains_h5(hi_folder+'data/HI_NPE_data.h5')
data_FM_chains = BaseHierarchicalInference.retrieve_chains_h5(hi_folder+'data/HI_FM_data.h5')

## Figure 6a ##

In [None]:
# Figure 6a: NPE vs SNPE hierarchical-inference contours for the shifted set.
# Plot ranges for the four plotted hyperparameters (shared by both samplers).
bounds = [
    [0.6,0.8], # +/- 0.1
    [1.73,2.37], # +/- 0.32
    [0.001,0.15], # prior
    [0.001,0.2], # prior
]
# Labels for all 8 population hyperparameters, presumably in chain order.
# iofi below selects entries 0, 1, 2, 4: the theta_E/gamma_lens means and
# their scatters.
param_labels = np.asarray([
    r'$\mathcal{M}_{\theta_{\mathrm{E}}}$',
    r'$\mathcal{M}_{\gamma_{\mathrm{lens}}}$',
    r'$\Sigma_{\theta_{\mathrm{E}},\theta_{\mathrm{E}}}$',
    r'$\Sigma_{\gamma_{1/2},\gamma_{1/2}}$',
    r'$\Sigma_{\gamma_{\mathrm{lens}},\gamma_{\mathrm{lens}}}$',
    r'$\Sigma_{e_{1/2},e_{1/2}}$',
    r'$\sigma(x/y_{lens})$',r'$\sigma(x/y_{src})$'])
# Hyperparameters the shifted set was drawn from, in param_labels order.
# The theta_E/gamma_lens entries match the shifted distributions of Figure 3.
true_hyperparameters = np.asarray([0.7,2.05,0.08,0.12,0.1,0.2,0.07,0.1])

mcmc_utils.overlay_contours([shifted_chains[0],shifted_chains[2]],
                            colors_list=['slateblue','mediumseagreen'],
                            iofi=[0,1,2,4],true_params=true_hyperparameters[[0,1,2,4]],param_labels=param_labels[[0,1,2,4]],
                            sampler_labels=['NPE', 'SNPE'],bounds=bounds,
                            save_path='figures/Figure6a.pdf')

## Figure 6b ##

In [None]:
# Figure 6b: NPE vs SNPE hierarchical-inference contours for the doppelganger set.
# Plot ranges for the four plotted hyperparameters (shared by both samplers).
bounds = [
    [0.69,0.89], # +/- 0.1
    [1.74,2.38], # +/- 0.32
    [0.001,0.15], # prior
    [0.001,0.2], # prior
]

# The doppelganger "truth" is the empirical mean / sample std of its labels.
doppel_means = np.mean(y_test_doppel,axis=0)
doppel_stds = np.std(y_test_doppel,axis=0,ddof=1)
# Same 8-entry order as param_labels; paired columns are averaged.
true_params = np.asarray([
    doppel_means[0],                      # theta_E mean
    doppel_means[3],                      # gamma_lens mean
    doppel_stds[0],                       # theta_E scatter
    (doppel_stds[1]+doppel_stds[2])/2,    # columns 1,2
    doppel_stds[3],                       # gamma_lens scatter
    (doppel_stds[4]+doppel_stds[5])/2,    # columns 4,5
    (doppel_stds[6]+doppel_stds[7])/2,    # columns 6,7
    (doppel_stds[8]+doppel_stds[9])/2,    # columns 8,9
])

# y-axis ranges passed to overlay_contours (one per plotted hyperparameter)
y_bounds = [
    [0,14],
    [0,9],
    [0,25],
    [0,25]
]
plotted = [0,1,2,4]
mcmc_utils.overlay_contours([doppel_chains[0],doppel_chains[2]],
                            colors_list=['slateblue','mediumseagreen'],
                            iofi=[0,1,2,4],true_params=true_params[plotted],param_labels=param_labels[plotted],
                            y_bounds=y_bounds,sampler_labels=['NPE','SNPE'],bounds=bounds,
                            save_path='figures/Figure6b.pdf')

## Table 4 ##

In [None]:
# Table 4: summary of the SNPE hierarchical-inference chain for the shifted set.
snpe_shifted_chain = shifted_chains[2]
print('HI from SNPE: Shifted')
mcmc_utils.HI_medians_table(snpe_shifted_chain, param_labels, burnin=1000.0)

## Table 5 ##

In [None]:
# Table 5: summary of the SNPE hierarchical-inference chain for the doppelganger set.
snpe_doppel_chain = doppel_chains[2]
print('HI from SNPE: Doppelganger')
mcmc_utils.HI_medians_table(snpe_doppel_chain, param_labels, burnin=1000.0)

## Figure 7 ##

Creating Figure 7 requires input files that are not publicly available.

See matrix_plot_im_positions() in LensSystem/image_positions_utils.py for the source code.

An example of how I call this function to make the figure is included below:

In [None]:
print("Not functional w/out proprietary input files")
# Set to True only when the non-public inputs (../reduced_data/, ../doppelgangers/)
# are present; the whole branch is skipped otherwise.
has_files = False
if has_files:

    # re-ordering to match paper convention
    # SNPE predictions on HST data (first entry of the *_seq_list arrays)
    y_pred_data_final = y_pred_data_seq_list[0]
    std_pred_data_final = std_pred_data_seq_list[0]
    # save copy of the values
    # Row index 3 is F0530-3730 in the data ordering (see `file_names` below).
    y_pred_0530 = copy.deepcopy(y_pred_data_final[3])
    std_pred_0530 = copy.deepcopy(std_pred_data_final[3])
    # delete from 3rd position
    # (index 3, i.e. the 4th row)
    y_pred_data_final = np.delete(y_pred_data_final, 3, axis=0)
    std_pred_data_final = np.delete(std_pred_data_final, 3, axis=0)
    # re-insert at end position
    # Index 13 is the end assuming 14 lenses; this now matches the ordering of
    # file_names_impos / lens_names_list (F0530-3730 last).
    y_pred_data_final = np.insert(y_pred_data_final, 13, y_pred_0530, axis=0)
    std_pred_data_final = np.insert(std_pred_data_final, 13, std_pred_0530, axis=0)

    # diagonal covariance per lens from the predicted standard deviations
    cov_pred_data_final = []
    for std in std_pred_data_final:
        cov_pred_data_final.append(np.diag(std**2))

    # Paper ordering (F0530-3730 last), used for the FM astrometry files.
    file_names_impos = ['ATLASJ2344-3056', 'DESJ0405-3308','DESJ0420-4037','J0029-3814', 'J1131-4419', 'J2145+6345',
    'J2205-3727','PSJ1606-2333', 'SDSSJ0248+1913', 'SDSSJ1251+2935',
    'W2MJ1042+1641', 'WG0214-2105', 'WISEJ0259-1635','F0530-3730']
    # Data ordering (F0530-3730 at index 3). NOTE(review): this rebinds the
    # global `file_names` from the Figure 4 cell when the branch runs.
    file_names = ['ATLASJ2344-3056', 'DESJ0405-3308','DESJ0420-4037',
        'F0530-3730','J0029-3814', 'J1131-4419', 'J2145+6345',
        'J2205-3727','PSJ1606-2333', 'SDSSJ0248+1913', 'SDSSJ1251+2935',
        'W2MJ1042+1641', 'WG0214-2105', 'WISEJ0259-1635']
    # NOTE(review): fits_file_list and lens_names_list are built but never
    # passed to matrix_plot_im_positions below.
    fits_file_list = []
    fm_file_list = []
    for f in file_names:
        fits_file_list.append('../reduced_data/'+f+'_F814W_drc_sci.fits')
    for f in file_names_impos:
        fm_file_list.append('../doppelgangers/'+f+'_results.txt')
    lens_names_list = ['ATLAS J2344-3056', 'DES J0405-3308','DES J0420-4037','J0029-3814', 'J1131-4419', 'J2145+6345',
    'J2205-3727','PS J1606-2333', 'SDSS J0248+1913', 'SDSS J1251+2935',
    'W2M J1042+1641', 'WG0214-2105', 'WISE J0259-1635', 'DES J0530-3730']

    print("SNPE Method")
    # lens catalog fetched over the network as CSV
    catalog_df = pd.read_csv('https://docs.google.com/spreadsheets/d/'+
        '1jOC60bWMxpp65iJZbANc_6SxouyXwqsESF4ocLAj27E/export?gid=0&format=csv')
    # The integer list presumably selects catalog rows for each lens, in paper
    # order (TODO confirm against image_positions_utils). (2,7) is the grid shape.
    # NOTE(review): output goes to an absolute local path, not figures/.
    matrix_plot_im_positions(y_pred_data_final,'../reduced_data/',catalog_df,
        [2,4,6,9,13,17,18,22,23,24,27,28,30,7],(2,7),'/Users/smericks/Desktop/im_positions_data.pdf',
        show_one_arcsec=False,fm_files_for_astrometry=fm_file_list,
        cov_pred=cov_pred_data_final)

## Figure 8 ##

In [None]:
# Figure 8: population hyperparameters inferred from the HST data with NPE,
# SNPE and the STRIDES23 (Schmidt '23) forward-modeling chain. Real data has
# no ground truth, so true_params=None.
bounds = [
    [0.66,0.86], # +/- 0.1 from 0.75
    [1.68,2.32], # +/- 0.32
    [0.001,0.15], # prior
    [0.001,0.2], # prior
]
# y-axis ranges passed to overlay_contours (one per plotted hyperparameter)
y_bounds = [
    [0,14],
    [0,11],
    [0,31],
    [0,38]
]
data_chain_list = [data_chains[0], data_chains[1], data_FM_chains[0]]
mcmc_utils.overlay_contours(data_chain_list,
                            colors_list=['slateblue','mediumseagreen','lightpink'],
                            iofi=[0,1,2,4],true_params=None,param_labels=param_labels[[0,1,2,4]],
                            sampler_labels=['NPE','SNPE','STRIDES23'],bounds=bounds,
                            y_bounds=y_bounds,
                            save_path='figures/Figure8.pdf')

## Table 6 ##

In [None]:
# Table 6: median/uncertainty of each HST-data hierarchical-inference chain.
HI_BURNIN = 1e3

print("NPE")
mcmc_utils.HI_medians_table(data_chains[0], param_labels, burnin=HI_BURNIN)
print(" ")
print("SNPE")
mcmc_utils.HI_medians_table(data_chains[1], param_labels, burnin=HI_BURNIN)
print(" ")
print("Schmidt '23")
mcmc_utils.HI_medians_table(data_FM_chains[0], param_labels, burnin=HI_BURNIN)

## Figure 9 ##

In [None]:
# Figure 9: per-lens SNPE predictions on the HST data (y-axis) vs the automated
# forward-modeling results of Schmidt '23 / STRIDES23 and, where available,
# Ertl '23 (x-axis). The third panel is a letter -> lens-name key.

# means_list, cov_list are the FM results from Schmidt '23
means_list = np.load(zenodo_folder+'model_predictions/fm/y_pred_FM.npy')
cov_list = np.load(zenodo_folder+'model_predictions/fm/cov_pred_FM.npy')

# Ertl '23 values for the lenses at rows ertl_idx (ascending) of the FM/SNPE arrays.
ertl_thetaE_mu = [0.831,0.75,0.689,0.855]
ertl_thetaE_sigma = [0.002,0.01,0.009,0.004]
# gamma = 2*value + 1, sigma scaled by 2.
# NOTE(review): presumably a conversion from Ertl's slope convention; confirm.
ertl_gamma_mu = np.asarray([0.65,0.35,0.31,0.6])*2 + 1
ertl_gamma_sigma = np.asarray([0.04,0.05,0.01,0.1])*2
ertl_idx = [2,4,8,12]

titles = [r'$\theta_E$',r'$\gamma_{lens}$']
fm_lens_names = ['ATLAS J2344-3056', 'DES J0405-3308','DES J0420-4037','DES J0530-3730','J0029-3814',
                'J1131-4419','J2145+6345','J2205-3727','PS J1606-2333', 'SDSS J0248+1913',
                'SDSS J1251+2935', 'W2M J1042+1641','WG0214-2105', 'WISE J0259-1635']
# One marker letter per lens. Defined before the loops because both the
# scatter panels and the name key in axs[2] use it.
lens_letters = ['A','B','C','D','E','F','G','H','I','J','K','L','M','N']

# SNPE means/stds on the HST data (first entry of the *_seq_list arrays)
snpe_mu = y_pred_data_seq_list[0]
snpe_sigma = std_pred_data_seq_list[0]

# For each plotted column p: the Ertl values to use and the 1:1 line range.
ertl_by_param = {0: (ertl_thetaE_mu, ertl_thetaE_sigma),
                 3: (ertl_gamma_mu, ertl_gamma_sigma)}
one_to_one = {0: [0.49,1.05], 3: [1.6,2.4]}

fig,axs = plt.subplots(1,3,dpi=300,figsize=(15,5))

for i,p in enumerate([0,3]):
    ax = axs[i]
    ax.plot(one_to_one[p],one_to_one[p],color='black',alpha=0.8,linewidth=0.8)
    ertl_mu, ertl_sigma = ertl_by_param[p]

    for fm_idx in range(len(fm_lens_names)):
        snpe_y = snpe_mu[fm_idx,p]
        snpe_err = snpe_sigma[fm_idx,p]
        letter = lens_letters[fm_idx]

        # STRIDES23 automated FM (x) vs SNPE (y)
        ax.errorbar(means_list[fm_idx,p],snpe_y,
            xerr=np.sqrt(cov_list[fm_idx,p,p]),yerr=snpe_err,
            color='lightpink',fmt='.',markersize=22,zorder=200)
        ax.text(means_list[fm_idx,p],snpe_y,letter,ha='center',va='center',zorder=300)

        # Ertl '23 (x) vs SNPE (y), only for the lenses Ertl modeled
        if fm_idx in ertl_idx:
            k = ertl_idx.index(fm_idx)
            ax.errorbar(ertl_mu[k],snpe_y,
                xerr=ertl_sigma[k],yerr=snpe_err,
                color='indianred',fmt='.',markersize=22,zorder=200)
            ax.text(ertl_mu[k],snpe_y,letter,ha='center',va='center',zorder=300)

    ax.set_xlabel('Automated FM',fontsize=15)
    ax.set_ylabel(r'SNPE',fontsize=15)
    ax.set_title(titles[i],fontsize=17)

# custom legend for Schmidt/Ertl colors (empty scatters act as legend handles)
axs[0].scatter([],[],color='lightpink',s=50,label='STRIDES23')
axs[0].scatter([],[],color='indianred',s=50,label='Ertl \'23')
axs[0].legend(fontsize=14,loc='upper left')

# Letter -> lens-name key, listed top to bottom.
for row, (letter, name) in enumerate(zip(lens_letters, fm_lens_names)):
    axs[2].scatter(1,13-row,color='white',s=130,edgecolors='black')
    axs[2].text(1,12.95-row,letter,ha='center',va='center',zorder=300)
    axs[2].text(1.005,12.95-row,name,va='center',fontsize=12)
axs[2].set_xlim([0.99,1.05])
axs[2].axis('off')

plt.savefig('figures/Figure9.pdf')

## Figure 10 ##

In [None]:
# Figure 10: compute cost per lens for NPE vs SNPE as the sample grows.
# everything is in units of hours

# CPU costs
sim_per_lens = 0.0015/60.             # simulation cost per image
broad_sim = 5e5*sim_per_lens          # one-time simulation cost, amortized over n_lenses
seq_sim_per_lens = 5e4*sim_per_lens   # extra simulation cost paid for every lens by SNPE
pred_per_lens = .7/(60**2)            # 0.7 s network prediction per lens

# GPU costs
broad_train = 5.73                    # one-time training cost, amortized over n_lenses
seq_train = 4.6/60                    # extra training cost paid for every lens by SNPE

n_lenses = np.arange(1,1000,1)

# NPE: one-time costs shared across the sample, plus per-lens prediction.
npe_cpu = broad_sim/n_lenses + pred_per_lens
npe_gpu = broad_train/n_lenses
# SNPE: NPE plus a per-lens sequential simulation/training cost.
snpe_cpu = npe_cpu + seq_sim_per_lens
snpe_gpu = npe_gpu + seq_train

fig,axs = plt.subplots(1,2,figsize=(15,6),dpi=200)
panels = [(axs[0], 'CPU', npe_cpu, snpe_cpu),
          (axs[1], 'GPU', npe_gpu, snpe_gpu)]
for ax, unit, npe_curve, snpe_curve in panels:
    ax.set_ylabel(unit+' Hours',fontsize=15)
    ax.set_xlabel('# Lenses',fontsize=15)
    ax.set_title(unit+' Time per Lens',fontsize=15)
    ax.plot(n_lenses,npe_curve,label='NPE',color='slateblue',linewidth=3)
    # dashed zero reference, drawn beneath the curves (zorder=1)
    ax.plot([1.,1000.],[0.,0.],color='darkgrey',linewidth=2.5,linestyle='dashed',zorder=1)
    ax.plot(n_lenses,snpe_curve,label='SNPE',color='mediumseagreen',linewidth=3)
    ax.set_xscale('log')
    ax.set_xlim([1,int(1e3)])
    ax.tick_params(labelsize=13)

# one legend, on the CPU panel
axs[0].legend(fontsize=15)

plt.savefig('figures/Figure10.pdf')

## Figure 11 ##

In [None]:
# Figure 11: shifted-set calibration of NPE and of each SNPE setting.
test_set_indices = range(0,20)

# NPE first, then the SNPE entries at seq_list indices 3, 2, 1, 0
# (labelled m=4, m=2, m=1, m=0 in label_list below).
snpe_order = [3, 2, 1, 0]
y_pred_list = [y_pred_shifted[test_set_indices,:]] + [y_pred_shifted_seq_list[k] for k in snpe_order]
precisions = [prec_pred_shifted[test_set_indices,:]] + [prec_pred_shifted_seq_list[k] for k in snpe_order]
# predictions are stored as precision matrices; invert to covariances
cov_pred_list = [np.linalg.inv(prec) for prec in precisions]

color_list = ['slateblue','#225ea8','mediumseagreen','#c2e699','#fff7bc']
label_list = ['Perfect Calibration','NPE','SNPE m=4','SNPE m=2','SNPE m=1','SNPE m=0']

visualization_utils.combine_calib_plots(y_pred_list,cov_pred_list,y_test_shifted[test_set_indices,:],color_list,label_list,
    plot_title='Calibration of Shifted Set',save_path='figures/Figure11.pdf')

## Figure 12 ##

In [31]:
# first, load in full covariance NPE network predictions
# h5py 3.0 removed Dataset.value; `dataset[()]` reads the whole dataset into a
# numpy array on every h5py version. The `with` blocks close each file even if
# a key is missing.

# NPE-full model predictions
npe_preds_path = zenodo_folder+'model_predictions/npe/full/'

# SHIFTED SET
file_path = npe_preds_path+'narrow_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_pred_shifted_FULL = h5f['y_pred'][()]
    std_pred_shifted_FULL = h5f['std_pred'][()]
    prec_pred_shifted_FULL = h5f['prec_pred'][()]

# DOPPELGANGER SET
file_path = npe_preds_path+'doppelganger_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_pred_doppel_FULL = h5f['y_pred'][()]
    std_pred_doppel_FULL = h5f['std_pred'][()]
    prec_pred_doppel_FULL = h5f['prec_pred'][()]

In [None]:
# Figure 12: calibration of diagonal- vs full-covariance NPE, alongside SNPE.
# Entry order in each list: NPE-diag, NPE-full, SNPE-diag.
test_set_indices = range(0,20)

# --- Shifted set (first 20 lenses) ---
y_pred_list = [
    y_pred_shifted[test_set_indices,:],
    y_pred_shifted_FULL[test_set_indices,:],
    y_pred_shifted_seq_list[2],
]
cov_pred_list = [np.linalg.inv(prec) for prec in (
    prec_pred_shifted[test_set_indices],
    prec_pred_shifted_FULL[test_set_indices],
    prec_pred_shifted_seq_list[2],
)]
color_list = ['slateblue','orange','mediumseagreen']
label_list = ['Perfect Calibration','NPE-diag','NPE-full','SNPE-diag']
visualization_utils.combine_calib_plots(y_pred_list,cov_pred_list,y_test_shifted[test_set_indices,:],color_list,label_list,
    plot_title='Calibration of Shifted Set',save_path='figures/Figure12a.pdf')

# --- Doppelganger set ---
y_pred_list = [
    y_pred_doppel,
    y_pred_doppel_FULL,
    y_pred_doppel_seq_list[2],
]
cov_pred_list = [np.linalg.inv(prec) for prec in (
    prec_pred_doppel,
    prec_pred_doppel_FULL,
    prec_pred_doppel_seq_list[2],
)]
color_list = ['slateblue','orange','mediumseagreen']
label_list = ['Perfect Calibration','NPE-diag','NPE-full','SNPE-diag']
visualization_utils.combine_calib_plots(y_pred_list,cov_pred_list,y_test_doppel,color_list,label_list,
    plot_title='Calibration of Doppelganger Set',save_path='figures/Figure12b.pdf')

## Figure 13a ##

In [None]:
# Figure 13a: shifted-set hierarchical inference with NPE, NPE-full and SNPE.
# Plot ranges for the four plotted hyperparameters (shared by all samplers).
bounds = [
    [0.6,0.8], # +/- 0.1
    [1.73,2.37], # +/- 0.32
    [0.001,0.15], # prior
    [0.001,0.2], # prior
]
# hyperparameters the shifted set was drawn from (param_labels order)
true_hyperparameters = np.asarray([0.7,2.05,0.08,0.12,0.1,0.2,0.07,0.1])

# y-axis ranges passed to overlay_contours (one per plotted hyperparameter)
y_bounds = [
    [0,35],
    [0,15],
    [0,47],
    [0,13]
]

plotted = [0,1,2,4]
shifted_chain_list = [shifted_chains[0], shifted_chains[5], shifted_chains[2]]
mcmc_utils.overlay_contours(shifted_chain_list,
                            colors_list=['slateblue','orange','mediumseagreen'],
                            iofi=[0,1,2,4],true_params=true_hyperparameters[plotted],param_labels=param_labels[plotted],
                            sampler_labels=['NPE', 'NPE-full','SNPE'],bounds=bounds,
                            save_path='figures/Figure13a.pdf',y_bounds=y_bounds)

## Figure 13b ##

In [None]:
# Figure 13b: doppelganger-set hierarchical inference with NPE, NPE-full and SNPE.
# Plot ranges for the four plotted hyperparameters (shared by all samplers).
bounds = [
    [0.69,0.89], # +/- 0.1
    [1.74,2.38], # +/- 0.32
    [0.001,0.15], # prior
    [0.001,0.2], # prior
]

# The doppelganger "truth" is the empirical mean / sample std of its labels,
# arranged in param_labels order with paired columns averaged.
doppel_means = np.mean(y_test_doppel,axis=0)
doppel_stds = np.std(y_test_doppel,axis=0,ddof=1)
true_params = np.asarray([doppel_means[0],doppel_means[3],doppel_stds[0],
               (doppel_stds[1]+doppel_stds[2])/2,doppel_stds[3],
               (doppel_stds[4]+doppel_stds[5])/2,(doppel_stds[6]+doppel_stds[7])/2,
               (doppel_stds[8]+doppel_stds[9])/2])

# y-axis ranges passed to overlay_contours (one per plotted hyperparameter)
y_bounds = [
    [0,15],
    [0,9],
    [0,25],
    [0,25]
]
# Legend label spelled 'NPE-full' to match Figures 12 and 13a.
mcmc_utils.overlay_contours([doppel_chains[0],doppel_chains[5],doppel_chains[2]],
                            colors_list=['slateblue','orange','mediumseagreen'],
                            iofi=[0,1,2,4],true_params=true_params[[0,1,2,4]],param_labels=param_labels[[0,1,2,4]],
                            y_bounds=y_bounds,sampler_labels=['NPE','NPE-full','SNPE'],bounds=bounds,
                            save_path='figures/Figure13b.pdf')

## Figure 14a ##

In [None]:
# Figure 14a: shifted-set calibration for NPE, SNPE, and SNPE reweighted to the
# cPDF prior. The second and third plot_calibration calls are given
# figure=calib_figure; plt.savefig writes matplotlib's current figure.
shifted_rw_filepath = zenodo_folder+'model_predictions/snpe/shifted/reweighted_seq_shifted.h5'
samps_list_seq_shifted,weights_list_seq_shifted = NetworkReweightedPosteriors.load_samps_weights(shifted_rw_filepath)

shifted_truth = y_test_shifted[:20]

# 1) NPE with the nu_int prior
npe_shifted_cov = np.linalg.inv(prec_pred_shifted[:20])
shifted_npe_samps = visualization_utils.construct_samps(y_pred_shifted[:20], npe_shifted_cov)
calib_figure = posterior_functions.plot_calibration(shifted_npe_samps,
    shifted_truth,show_plot=False,color_map=['black','slateblue'])

# 2) SNPE with the nu_int prior
snpe_shifted_cov = np.linalg.inv(prec_pred_shifted_seq_list[2])
shifted_snpe_samps = visualization_utils.construct_samps(y_pred_shifted_seq_list[2], snpe_shifted_cov)
calib_figure = posterior_functions.plot_calibration(shifted_snpe_samps,shifted_truth,
    figure=calib_figure,color_map=['black','mediumseagreen'],
    legend=['Perfect Calibration','Broad','Sequential'],
    title='Calibration of Shifted Set Posteriors',show_plot=False)

# 3) SNPE reweighted to the cPDF prior: swap the first two axes of the stacked
#    samples and transpose the stacked weights before plotting with weights.
samps_seq_reweighted = np.transpose(np.asarray(samps_list_seq_shifted),axes=[1,0,2])
weights_seq_reweighted = np.asarray(weights_list_seq_shifted).T

posterior_functions.plot_calibration(samps_seq_reweighted,shifted_truth,figure=calib_figure,
    color_map=['black','orange'],
    legend=[r'Perfect Calibration',r'NPE, $\nu_{int}$ Prior',r'SNPE, $\nu_{int}$ Prior',r'SNPE, cPDF Prior'],
    title='Calibration of Shifted Set Posteriors',loc='upper left',
    weights=weights_seq_reweighted,show_plot=False)
plt.savefig('figures/Figure14a.pdf')

## Figure 14b ##

In [None]:
# Figure 14b: doppelganger-set calibration for NPE, SNPE, and SNPE reweighted to
# the cPDF prior. The second and third plot_calibration calls are given
# figure=calib_figure; plt.savefig writes matplotlib's current figure.
doppel_rw_filepath = zenodo_folder+'model_predictions/snpe/doppelganger/reweighted_doppel_shifted.h5'
samps_list_seq_doppel,weights_list_seq_doppel = NetworkReweightedPosteriors.load_samps_weights(doppel_rw_filepath)

# 1) NPE with the nu_int prior
npe_doppel_cov = np.linalg.inv(prec_pred_doppel)
doppel_npe_samps = visualization_utils.construct_samps(y_pred_doppel, npe_doppel_cov)
calib_figure = posterior_functions.plot_calibration(doppel_npe_samps,
    y_test_doppel,show_plot=False,color_map=['black','slateblue'])

# 2) SNPE with the nu_int prior
snpe_doppel_cov = np.linalg.inv(prec_pred_doppel_seq_list[2])
doppel_snpe_samps = visualization_utils.construct_samps(y_pred_doppel_seq_list[2], snpe_doppel_cov)
calib_figure = posterior_functions.plot_calibration(doppel_snpe_samps,y_test_doppel,
    figure=calib_figure,color_map=['black','mediumseagreen'],
    legend=['Perfect Calibration','Broad','Sequential'],
    title='Calibration of Doppelganger Set Posteriors',show_plot=False)

# 3) SNPE reweighted to the cPDF prior: swap the first two axes of the stacked
#    samples and transpose the stacked weights before plotting with weights.
samps_seq_reweighted = np.transpose(np.asarray(samps_list_seq_doppel),axes=[1,0,2])
weights_seq_reweighted = np.asarray(weights_list_seq_doppel).T

posterior_functions.plot_calibration(samps_seq_reweighted,y_test_doppel,figure=calib_figure,
    color_map=['black','orange'],
    legend=[r'Perfect Calibration',r'NPE, $\nu_{int}$ Prior',r'SNPE, $\nu_{int}$ Prior',r'SNPE, cPDF Prior'],
    title='Calibration of Doppelganger Set Posteriors',loc='upper left',
    weights=weights_seq_reweighted,show_plot=False)
plt.savefig('figures/Figure14b.pdf')