In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import h5py
import corner
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from matplotlib.lines import Line2D
from paltas.Analysis import posterior_functions
import sys
sys.path.insert(0, '/Users/smericks/Desktop/StrongLensing/lens-npe/')
from Inference.network_reweighted_posteriors import NetworkReweightedPosteriors
import visualization_utils

This notebook presents additional exploratory tests of hierarchical re-weighting of NPE posteriors. The material was presented at the KIPAC Strong Lensing x Computer Science Causality Group seminar on December 9, 2024.

In [None]:
# Root of the lens-npe Zenodo data release, relative to this notebook's folder.
# Every data path below is built by plain string concatenation onto this
# prefix, so it must keep its trailing '/'.
zenodo_filepath = '../Paper/lens-npe-data/'

### Visualize the Lenses we Test on ###

In [None]:
# Show the shifted test-set images.
# NOTE(review): (2, 10) is presumably the (rows, cols) grid layout, and the
# trailing None an optional argument left unset — confirm in visualization_utils.
visualization_utils.matrix_plot_from_h5(
    zenodo_filepath + 'test_sets/shifted/image_data.h5',
    (2, 10),
    None,
)

### Load in Full Covariance NPE Predictions for the Shifted Test Set ###

In [None]:
# SHIFTED SET
# Load full-covariance NPE outputs: true parameters, predicted means,
# predicted standard deviations and predicted precision matrices.
# NOTE(review): the file is named narrow_predictions.h5 — presumably the
# "shifted" test set is the "narrow" set; confirm.
file_path = zenodo_filepath+'model_predictions/npe/full/narrow_predictions.h5'
# `Dataset.value` was removed in h5py 3.0; `dset[()]` reads the whole dataset
# into a numpy array on both old and new h5py. The `with` block closes the file
# even if a read fails, and indexing (instead of .get) raises a clear KeyError
# for a missing dataset rather than an AttributeError on None.
with h5py.File(file_path, 'r') as h5f:
    y_test_shifted = h5f['y_test'][()]
    y_pred_shifted = h5f['y_pred'][()]
    std_pred_shifted = h5f['std_pred'][()]
    prec_pred_shifted = h5f['prec_pred'][()]

### Run Hierarchical Re-Weighting for the First 20 Lenses in the Shifted Test Set ###
* We choose the shifted test set because its ground-truth population distribution is known.
* We use full-covariance NPE because it is less susceptible to functional-form mis-specification (unfortunately, full-covariance SNPE was not run for this work).
* We only use 20 lenses because re-weighting even 20 lenses takes about an hour.

In [None]:
# already_computed=True loads cached re-weighting results from
# seq_reweighted_filepath; False recomputes them (slow).
already_computed = True
debug = False # debug=True means only one lens calculated

if not debug and not already_computed:
    print("WARNING: Running re-weighting on all 20 shifted set lenses takes over an hour.")
    print("Debug=True ensures this is run for only one lens")

seq_reweighted_filepath = 'notebook_data/full_npe_reweighted_shifted.h5'

if already_computed:
    samps_list_seq_shifted, weights_list_seq_shifted = NetworkReweightedPosteriors.load_samps_weights(
        seq_reweighted_filepath)
else:
    # in order: theta_E, gamma1, gamma2, gamma, e1, e2, x_lens, y_lens, x_src, y_src
    train_mean = [0.8,0.,0.,2.,0.,0,0.,0.,0.,0.]
    train_scatter = [0.15,.12,0.12,0.2,0.2,0.2,0.07,0.07,0.1,0.1]

    nrp = NetworkReweightedPosteriors({
        'hypermodel_type':'fixed_param',
        'sigmas_log_uniform':False,
        'n_emcee_samps':int(6e3)
    })

    samps_list_seq_shifted, weights_list_seq_shifted = nrp.reweighted_lens_posteriors_small_number(
        y_pred_shifted[:20],prec_pred_shifted[:20],train_mean,train_scatter,
        seq_reweighted_filepath,
        debug=debug)

# Both branches now bind the same names. The un-suffixed aliases keep any cell
# that still refers to samps_list_seq / weights_list_seq working, whichever
# branch ran.
samps_list_seq, weights_list_seq = samps_list_seq_shifted, weights_list_seq_shifted

### Plot Re-Weighted Individual Posteriors ###


In [None]:
# For each re-weighted lens, overlay the NPE-Full posterior (samples drawn from
# the predicted mean and inverse precision) and the re-weighted posterior, with
# the true parameter values marked. Only the first six parameters are shown.
param_labels = [r'$\theta_\mathrm{E}$', r'$\gamma_1$', r'$\gamma_2$',
                r'$\gamma_\mathrm{lens}$', r'$e_1$', r'$e_2$']
n_params_plot = len(param_labels)
cov_pred_shifted = np.linalg.inv(prec_pred_shifted)

npe_color = 'slateblue'
reweighted_color = 'indianred'
n_npe_samps = int(5e3)
# Base seed for the NPE draws; lens i uses SAMPLING_SEED + i, so every figure is
# reproducible on re-run regardless of which lenses are plotted.
SAMPLING_SEED = 0


def _corner_style(color):
    """Return a fresh dict of corner.corner styling kwargs for one layer.

    A new dict (including nested dicts/lists) is built on every call so no
    mutable kwargs are shared between the two corner.corner calls.
    """
    return dict(
        labels=np.asarray(param_labels), bins=20, show_titles=False,
        plot_datapoints=False, label_kwargs=dict(fontsize=50),
        levels=[0.68, 0.95], color=color, fill_contours=True, smooth=1.0,
        hist_kwargs={'density': True, 'color': color, 'lw': 3},
        title_fmt='.2f', max_n_ticks=3)


# Legend handles are identical for every figure, so build them once.
legend_handles = [Line2D([0], [0], color=npe_color, lw=4),
                  Line2D([0], [0], color=reweighted_color, lw=4)]

# Plot exactly the lenses that were re-weighted (20 normally, 1 in debug runs).
for i in range(len(samps_list_seq_shifted)):
    samps_npe = multivariate_normal(
        mean=y_pred_shifted[i, :n_params_plot],
        cov=cov_pred_shifted[i, :n_params_plot, :n_params_plot],
    ).rvs(size=n_npe_samps, random_state=SAMPLING_SEED + i)

    figure = corner.corner(samps_npe, fig=None, **_corner_style(npe_color))
    corner.corner(samps_list_seq_shifted[i][:, :n_params_plot],
                  weights=weights_list_seq_shifted[i], fig=figure,
                  truths=y_test_shifted[i, :n_params_plot], truth_color='black',
                  **_corner_style(reweighted_color))

    axes = np.array(figure.axes).reshape((n_params_plot, n_params_plot))
    # Put the legend in the empty upper-right panel of the corner plot.
    axes[0, n_params_plot - 1].legend(legend_handles, ['NPE-Full','Reweighted NPE-Full'],
                                      frameon=False, fontsize=25)

    figure.suptitle('Narrow %02d'%(i), fontsize=30)
    plt.show()
    # Release each large figure so 20 of them don't accumulate in memory.
    plt.close(figure)

### Show Calibration Curve Before and After Reweighting ###

In [None]:
# Calibration before (NPE-Full) and after re-weighting, computed over the same
# lenses that were re-weighted.
n_calib_lenses = len(samps_list_seq_shifted)

# NPE calib (covariance = inverse of the predicted precision matrices)
shifted_npe_samps = visualization_utils.construct_samps(
    y_pred_shifted[:n_calib_lenses], np.linalg.inv(prec_pred_shifted[:n_calib_lenses]))
calib_figure = posterior_functions.plot_calibration(shifted_npe_samps,
    y_test_shifted[:n_calib_lenses],show_plot=False,color_map=['black','slateblue'])

# Re-weighted calib: per-lens sample arrays (n_samps, n_params) are stacked to
# (n_lenses, n_samps, n_params) and transposed to (n_samps, n_lenses, n_params);
# the per-lens weights are presumably 1-D, giving (n_samps, n_lenses) after .T.
samps_seq_reweighted = np.transpose(np.asarray(samps_list_seq_shifted),axes=[1,0,2])
weights_seq_reweighted = np.asarray(weights_list_seq_shifted).T

posterior_functions.plot_calibration(samps_seq_reweighted,y_test_shifted[:n_calib_lenses],figure=calib_figure,
    color_map=['black','indianred'],
    legend=[r'Perfect Calibration',r'NPE, $\nu_{int}$ Prior',r'NPE, cPDF Prior'],
    title='Calibration of Shifted Set Posteriors',loc='upper left',
    weights=weights_seq_reweighted,show_plot=False)

### Appendix: Sanity Checks ###

In [None]:
# Check that every predicted precision matrix is symmetric. Iterate over all
# lenses (instead of a fixed count) so every matrix is checked and the loop
# cannot run past the end of the array.
asymmetric_lenses = [i for i, prec in enumerate(prec_pred_shifted)
                     if not np.allclose(prec, prec.T)]
if asymmetric_lenses:
    print('uh oh! non-symmetric precision matrices for lenses:', asymmetric_lenses)
else:
    print('all', len(prec_pred_shifted), 'precision matrices are symmetric')

In [None]:
# Check the distribution of the re-weighting sample weights for lens 6.
fig, ax = plt.subplots()
ax.hist(weights_list_seq_shifted[6])
ax.set(xlabel='sample weight', ylabel='count', title='Re-weighting sample weights, lens 06');