In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import h5py
import corner
from matplotlib.lines import Line2D
import sys
sys.path.insert(0, '/Users/smericks/Desktop/StrongLensing/lens-npe/')
from Inference.network_reweighted_posteriors import NetworkReweightedPosteriors
import visualization_utils

This notebook demonstrates how to re-weight NPE posteriors (here, the sequential 
NPE, or SNPE, posteriors) to account for distribution shift between the test-set 
distribution and the training-set distribution. 
We demonstrate this on the shifted test set and on the doppelganger set.

Please note, as discussed in Erickson et al. '24, this technique currently produces 
miscalibrated (overconfident) posteriors, and is not recommended for application.

In [None]:
# Directory holding the Zenodo data products (path relative to this notebook);
# every prediction file loaded in this notebook lives under it.
zenodo_filepath = "../Paper/from_zenodo/"

## Load in Network Predictions ##

In [None]:
# NPE model predictions 
npe_preds_path = zenodo_filepath+'model_predictions/npe/diag/'

# SHIFTED SET
file_path = npe_preds_path+'narrow_predictions.h5'
h5f = h5py.File(file_path, 'r')
y_test_shifted = h5f.get('y_test').value
y_pred_shifted = h5f.get('y_pred').value
std_pred_shifted = h5f.get('std_pred').value
prec_pred_shifted = h5f.get('prec_pred').value
h5f.close()

# DOPPELGANGER SET
file_path = npe_preds_path+'doppelganger_predictions.h5'
h5f = h5py.File(file_path, 'r')
y_test_doppel = h5f.get('y_test').value
y_pred_doppel = h5f.get('y_pred').value
std_pred_doppel = h5f.get('std_pred').value
prec_pred_doppel = h5f.get('prec_pred').value
h5f.close()

# SHIFTED SET
y_pred_shifted_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/shifted/y_pred_list_epoch10.npy')
std_pred_shifted_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/shifted/std_pred_list_epoch10.npy')
prec_pred_shifted_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/shifted/prec_pred_list_epoch10.npy')
# DOPPELGANGER SET 
y_pred_doppel_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/doppelganger/y_pred_list_epoch10.npy')
std_pred_doppel_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/doppelganger/std_pred_list_epoch10.npy')
prec_pred_doppel_seq_list = np.load(zenodo_filepath+'model_predictions/snpe/doppelganger/prec_pred_list_epoch10.npy')

## Shifted Set Re-Weighting ##

In [None]:
# already_computed=True loads stored re-weighted samples instead of re-running
# the slow re-weighting; debug=True means only one lens is calculated.
already_computed = True
debug = True

if not debug and not already_computed:
    print("WARNING: Running re-weighting on all 20 shifted set lenses takes over an hour.")
    print("Debug=True ensures this is run for only one lens")

seq_reweighted_filepath = zenodo_filepath+'model_predictions/snpe/shifted/reweighted_seq_shifted.h5'

if already_computed:
    samps_list_seq_shifted,weights_list_seq_shifted = NetworkReweightedPosteriors.load_samps_weights(seq_reweighted_filepath)

else:
    # Mean / scatter of the training distribution, one entry per lens parameter (10 parameters)
    train_mean = [0.8,0.,0.,2.,0.,0,0.,0.,0.,0.]
    train_scatter = [0.15,.12,0.12,0.2,0.2,0.2,0.07,0.07,0.1,0.1]

    nrp = NetworkReweightedPosteriors({
        'hypermodel_type':'fixed_param',
        'sigmas_log_uniform':False,
        'n_emcee_samps':int(6e3)
    })

    # Pass the `debug` flag through (it was previously hard-coded to False,
    # so the one-lens debug mode advertised above was never used).
    samps_list_seq_shifted,weights_list_seq_shifted = nrp.reweighted_lens_posteriors_small_number(
        y_pred_shifted_seq_list[2],prec_pred_shifted_seq_list[2],train_mean,train_scatter,
        seq_reweighted_filepath,
        debug=debug)

# The plotting cell reads these generic names; bind them in BOTH branches so the
# plot works whether the samples were loaded or freshly computed.
samps_list_seq, weights_list_seq = samps_list_seq_shifted, weights_list_seq_shifted

In [None]:
# Check that resulting posterior is reasonable: overlay the SNPE posterior (green)
# and the re-weighted SNPE-RW posterior (red) for lens i, with the true values in black.
learning_params_names = [r'$\theta_\mathrm{E}$',r'$\gamma_1$',r'$\gamma_2$',r'$\gamma_\mathrm{lens}$',r'$e_1$',
                         r'$e_2$',r'$x_{lens}$',r'$y_{lens}$',r'$x_{src}$',r'$y_{src}$']
n_params = len(learning_params_names)
i=0

# Styling shared by both corner plots so the two posteriors are drawn identically.
corner_kwargs = dict(bins=20,show_titles=False,plot_datapoints=False,label_kwargs=dict(fontsize=50),
                     levels=[0.68,0.95],fill_contours=True,smooth=1.0)

# Samples from the SNPE posterior of lens i; covariance is the inverse of the predicted precision.
snpe_samps = visualization_utils.construct_samps(np.asarray([y_pred_shifted_seq_list[2][i]]),
    np.asarray([np.linalg.inv(prec_pred_shifted_seq_list[2][i])]))

# NOTE(review): construct_samps appears to return (n_samps, n_lenses, n_params);
# [:,0,:] selects the single lens passed in — TODO confirm.
figure = corner.corner(snpe_samps[:,0,:],color='mediumseagreen',**corner_kwargs)

corner.corner(samps_list_seq[i],weights=weights_list_seq[i],color='indianred',
              fig=figure,labels=learning_params_names,
              truths=y_test_shifted[i],truth_color='black',  # truth for the SAME lens i
              **corner_kwargs)

axes = np.array(figure.axes).reshape((n_params, n_params))
custom_lines = [Line2D([0], [0], color='mediumseagreen', lw=4),
    Line2D([0], [0], color='indianred', lw=4)]

# Legend in the empty top-right panel; trailing ';' suppresses the Legend repr.
axes[0,n_params-1].legend(custom_lines,['SNPE','SNPE-RW'],frameon=False,fontsize=30);

## Doppelganger Set Re-Weighting ##

In [None]:
# already_computed=True loads stored re-weighted samples instead of re-running
# the slow re-weighting; debug=True means only one lens is calculated.
already_computed = False
debug = True

if not debug and not already_computed:
    print("WARNING: Running re-weighting on all 13 doppelganger set lenses takes roughly 30 minutes.")
    print("Debug=True ensures this is run for only one lens")

seq_reweighted_filepath = zenodo_filepath+'model_predictions/snpe/doppelganger/reweighted_doppel_shifted.h5'

if already_computed:
    # Doppelganger-specific names: previously these loaded into the *_shifted
    # variables, silently overwriting the shifted-set results.
    samps_list_seq_doppel,weights_list_seq_doppel = NetworkReweightedPosteriors.load_samps_weights(seq_reweighted_filepath)

else:
    # Mean / scatter of the training distribution, one entry per lens parameter (10 parameters)
    train_mean = [0.8,0.,0.,2.,0.,0,0.,0.,0.,0.]
    train_scatter = [0.15,.12,0.12,0.2,0.2,0.2,0.07,0.07,0.1,0.1]

    # debug=True means only one lens calculated
    nrp = NetworkReweightedPosteriors({
        'hypermodel_type':'fixed_param',
        'sigmas_log_uniform':False,
        'n_emcee_samps':int(6e3)
    })

    samps_list_seq_doppel,weights_list_seq_doppel = nrp.reweighted_lens_posteriors_small_number(
        y_pred_doppel_seq_list[2],prec_pred_doppel_seq_list[2],train_mean,train_scatter,
        seq_reweighted_filepath,
        debug=debug)

# The plotting cell reads these generic names; bind them in BOTH branches so the
# plot shows doppelganger results whether they were loaded or freshly computed.
samps_list_seq, weights_list_seq = samps_list_seq_doppel, weights_list_seq_doppel

In [None]:
# Check that resulting posterior is reasonable: overlay the SNPE posterior (green)
# and the re-weighted SNPE-RW posterior (red) for lens i, with the true values in black.
learning_params_names = [r'$\theta_\mathrm{E}$',r'$\gamma_1$',r'$\gamma_2$',r'$\gamma_\mathrm{lens}$',r'$e_1$',
                         r'$e_2$',r'$x_{lens}$',r'$y_{lens}$',r'$x_{src}$',r'$y_{src}$']
n_params = len(learning_params_names)
i=0

# Styling shared by both corner plots so the two posteriors are drawn identically.
corner_kwargs = dict(bins=20,show_titles=False,plot_datapoints=False,label_kwargs=dict(fontsize=50),
                     levels=[0.68,0.95],fill_contours=True,smooth=1.0)

# Samples from the SNPE posterior of lens i; covariance is the inverse of the predicted precision.
snpe_samps = visualization_utils.construct_samps(np.asarray([y_pred_doppel_seq_list[2][i]]),
    np.asarray([np.linalg.inv(prec_pred_doppel_seq_list[2][i])]))

# NOTE(review): construct_samps appears to return (n_samps, n_lenses, n_params);
# [:,0,:] selects the single lens passed in — TODO confirm.
figure = corner.corner(snpe_samps[:,0,:],color='mediumseagreen',**corner_kwargs)

corner.corner(samps_list_seq[i],weights=weights_list_seq[i],color='indianred',
              fig=figure,labels=learning_params_names,
              truths=y_test_doppel[i],truth_color='black',  # truth for the SAME lens i
              **corner_kwargs)

axes = np.array(figure.axes).reshape((n_params, n_params))
custom_lines = [Line2D([0], [0], color='mediumseagreen', lw=4),
    Line2D([0], [0], color='indianred', lw=4)]

# Legend in the empty top-right panel; trailing ';' suppresses the Legend repr.
axes[0,n_params-1].legend(custom_lines,['SNPE','SNPE-RW'],frameon=False,fontsize=30);