In [None]:
# REQUIRED IMPORTS
# FILE MANAGEMENT
import h5py
import os
from IPython.utils import io

%load_ext autoreload
%autoreload 2

# DATA MANIPULATION
import numpy as np
import pandas as pd


# STATISTICS
from scipy.stats import norm, truncnorm, uniform, beta, multivariate_normal

# VISUALIZATION
import matplotlib
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
import matplotlib.colors as mpc
import corner
from matplotlib.patches import Patch
import astropy.visualization as asviz


# Font sizes (in points) shared by every figure in this notebook.
SMALL_SIZE = 17
MEDIUM_SIZE = 20
BIGGER_SIZE = 30

# (rc group, option, value) triples, applied in order through plt.rc.
_FONT_RC = [
    ('font', 'size', SMALL_SIZE),          # default text size
    ('axes', 'titlesize', BIGGER_SIZE),    # axes title
    ('axes', 'labelsize', MEDIUM_SIZE),    # x and y axis labels
    ('xtick', 'labelsize', SMALL_SIZE),    # x tick labels
    ('ytick', 'labelsize', SMALL_SIZE),    # y tick labels
    ('legend', 'fontsize', SMALL_SIZE),    # legend entries
    ('figure', 'titlesize', BIGGER_SIZE),  # figure-level title
]
for _group, _option, _value in _FONT_RC:
    plt.rc(_group, **{_option: _value})

plt.rcParams["font.family"] = "DejaVu Serif"

%matplotlib inline

In [None]:
# Column names of the parameters that can be learned. The prepRes calls later in
# the notebook slice this array with [:8] or [:10].
learning_params = np.array([
    'main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1',
    'main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma',
    'main_deflector_parameters_e1',
    'main_deflector_parameters_e2',
    'main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'source_parameters_R_sersic',
    'source_parameters_mag_app',
])

# LaTeX label for each entry of learning_params, in the same order.
# Raw strings keep every backslash literal; the previous non-raw literals relied on
# invalid escape sequences such as "\g", which newer Python versions warn about.
labels = np.array([
    r"$\theta_E$", r"$\gamma_1$", r"$\gamma_2$", r"$\gamma_{lens}$", r"$e_1$", r"$e_2$",
    r'$x_D$', r'$y_D$', r'$R_{src}$', r"$m_{i}$",
])

# Population mean / scatter labels: drop the enclosing '$' of each label ([1:-1])
# before nesting it inside a new math expression.
mu_labels = np.array([r'$\mathcal{M}(' + label[1:-1] + ')$' for label in labels])
std_labels = np.array([r'$\Sigma(' + label[1:-1] + ')$' for label in labels])

# Lookup from parameter column name to its LaTeX label.
labels_dict = dict(zip(learning_params, labels))

In [47]:
#### latils imports
from latils import prepRes, make_contour , get_train_data,  get_obj_of_wide_posteriors_obj, retrieve_chains_h5

# from lenstronomy.Util.param_util import ellipticity2phi_q, shear_cartesian2polar


In [None]:

# Build one result object per experiment with latils.prepRes.
# Variable prefixes follow the display names passed below: AL = 'All Light Included',
# NL = 'Lens Light Subtracted', NS = 'Lens+AGN Light Subtracted'.
# NOTE(review): prepRes is defined outside this notebook. Judging from how the objects are
# used in later cells (obj.prep, obj.num_param, obj.name, obj.color), the positional arguments
# appear to be (posterior type, prep key, number of learned parameters, learned parameter
# names, display name, plot color, ...) followed by several path fragments — TODO confirm
# against the latils source before relying on this.

### fiducial preparation
ALobj = prepRes('full', 'all',8,learning_params[:8],'All Light Included','rebeccapurple','0118/all','0325/all_no_single','full/0118/all', 'data/0325_results',mode_of_stopping='early_stopping')

### orange dataset: applied network trained with mass-light correlated to test with different mass-light ellipticity correlations
ALobj_dinos_ml = prepRes('full', 'all',8,learning_params[:8],'All Light Included','rebeccapurple','0118/all','0429/all_dinos_ml','full/0118/all', 'data/dinos_qm_ql',mode_of_stopping='early_stopping')

NLobj = prepRes('full', 'nolens',8,learning_params[:8],'Lens Light Subtracted','mediumaquamarine','0118/nolens','0325/nolens','full/0118/nolens', 'data/0325_results',mode_of_stopping='early_stopping')

# The nosrc configurations are the only fiducial ones built with all 10 parameters.
NSobj = prepRes('full', 'nosrc',10,learning_params[:10],'Lens+AGN Light Subtracted','gold','0118/nosrc','0325/nosrc','full/0118/nosrc', 'data/0325_results',mode_of_stopping='early_stopping')

### trained without correlations
nc_ALobj = prepRes('full', 'nc_all_om10',8,learning_params[:8],'NC:All Light Included','darkgoldenrod','0310/all','0325/all_no_single','full/0310/all', 'data/0325_train_no_corr',mode_of_stopping='early_stopping')
nc_NLobj = prepRes('full', 'nc_nolens_om10',8,learning_params[:8],'NC:Lens Light Subtracted','rebeccapurple','0310/nolens','0325/nolens','full/0310/nolens', 'data/0325_train_no_corr',mode_of_stopping='early_stopping')

### valid_results
ALvobj = prepRes('full', 'valid_all',8,learning_params[:8],'V: All Light Included','navajowhite','0118/all','0118/valid_all','full/0118/all', 'data/0118_valid_results',mode_of_stopping='early_stopping')
NLvobj = prepRes('full', 'valid_nolens',8,learning_params[:8],'V: Lens Light Subtracted','aquamarine','0118/nolens','0118/valid_nolens','full/0118/nolens', 'data/0118_valid_results',mode_of_stopping='early_stopping')

nc_ALvobj = prepRes('full', 'nc_all',8,learning_params[:8],'NC:All Light Included','darkgoldenrod','0310/all','0118/valid_all','full/0310/all', 'data/0320_results_held_out',mode_of_stopping='early_stopping')
nc_NLvobj = prepRes('full', 'nc_nolens',8,learning_params[:8],'NC:Lens Light Subtracted','rebeccapurple','0310/nolens','0118/valid_nolens','full/0310/nolens', 'data/0320_results_held_out',mode_of_stopping='early_stopping')

### diag_results
# Only object created with 'diag' as its first argument; every other call passes 'full'.
ALdobj = prepRes('diag', 'all_diag',8,learning_params[:8],'D: All Light Included','tan','0118/all','0325/all_no_single','diagonal/0118/all', 'data/0118_diag_results',mode_of_stopping='early_stopping')

### bright_results
ALbobj = prepRes('full', 'allb',8,learning_params[:8],'All Light Included','gold','0118/all','0118/all/bright','full/0118/all', 'data/0118_bright_results',mode_of_stopping='early_stopping')
NLbobj = prepRes('full', 'nolensb',8,learning_params[:8],'Lens Light Subtracted','mediumaquamarine','0118/nolens','0118/nolens/bright','full/0118/nolens', 'data/0118_bright_results',mode_of_stopping='early_stopping')
NSbobj = prepRes('full', 'nosrcb',10,learning_params[:10],'Lens+AGN Light Subtracted','rebeccapurple','0118/nosrc','0118/nosrc/bright','full/0118/nosrc', 'data/0118_bright_results',mode_of_stopping='early_stopping')


### deconv_results
# All three deconvolved configurations are built with 10 parameters.
ALdeobj = prepRes('full', 'all',10,learning_params[:10],'EM-D: All Light Included','gold','0118/deconvolved/all','0325/deconvolved/all','full/0118/deconvolved/all', 'data/0325_deconv_results',mode_of_stopping='early_stopping')
NLdeobj = prepRes('full', 'nolens',10,learning_params[:10],'EM-D: Lens Light Subtracted','mediumaquamarine','0118/deconvolved/nolens','0325/deconvolved/nolens','full/0118/deconvolved/nolens', 'data/0325_deconv_results',mode_of_stopping='early_stopping')
NSdeobj = prepRes('full', 'nosrc',10,learning_params[:10],'EM-D: Lens+AGN Light Subtracted','mediumpurple','0118/deconvolved/nosrc','0325/deconvolved/nosrc','full/0118/deconvolved/nosrc', 'data/0325_deconv_results',mode_of_stopping='early_stopping')



In [None]:
# Result objects gathered into an array so later cells can select them by position
# (e.g. obj_list[[0]] for Figure 7 and obj_list[[0, 1]] for Figure 11). The order matters:
# index 0 is ALobj and index 1 is ALobj_dinos_ml.
obj_list = np.array([ALobj,ALobj_dinos_ml, NLobj, ALbobj, NLbobj, NSbobj, ALdeobj, NLdeobj, NSdeobj, nc_ALobj, nc_NLobj, ALvobj, NLvobj])

### Figure 7: Population-level parameter recovery [fiducial]

In [None]:
def hypermodel_plot(results_df, preps,n_params_learned, h5_files, params,colors, categories,burnin=3000, save_name=None,
                    show_corr = False,obj_list=None):
    """Compare sampled population hyperparameters with the test-set population statistics.

    For each entry of ``preps`` the function loads the test metadata and training table,
    computes the mean/std of the selected parameters, loads the sampler chain from the
    matching ``h5_files`` entry, discards burn-in and stuck walkers, and collects the
    flattened samples. All collected samples are then passed to ``make_contour``.

    Parameters
    ----------
    results_df : sequence of DataFrame
        One table per prep; rows are looked up with ``df.loc[prep, ...]`` for the columns
        ``'learning_params'`` and ``'path_to_test_images'``.
    preps : sequence
        Row keys, one per experiment.
    n_params_learned : sequence of int
        Per experiment, the offset between a parameter's mean column and its scatter
        column in the chain (``params + n_params_learned[i]``).
    h5_files : sequence of str
        Chain file for each experiment, read with ``retrieve_chains_h5``.
    params : array-like of int
        Indices of the parameters to show (into the module-level ``labels``).
    colors, categories : list
        Forwarded unchanged to ``make_contour``.
    burnin : int
        Number of leading steps (axis 1 of the chain) to discard.
    save_name : str or None
        Forwarded to ``make_contour`` as ``save_fig``.
    show_corr : bool
        Forwarded to ``make_contour`` as ``show_correlation``.
    obj_list : sequence or None
        Result objects passed to ``get_obj_of_wide_posteriors_obj`` to find test objects
        to exclude from the truth statistics. With ``None`` nothing is excluded.

    Returns
    -------
    tuple
        ``(mu_test, std_test, mu_train, std_train, errors, err_perc, contour)`` where the
        first four arrays belong to the LAST prep processed, ``errors`` / ``err_perc`` are
        per-prep lists of (chain mean - truth) in absolute and percent terms, and
        ``contour`` is whatever ``make_contour`` returns.

    Raises
    ------
    ValueError
        If ``preps`` is empty, if nothing is left after burn-in, or if every walker is
        stuck.
    """
    if len(preps) == 0:
        raise ValueError("hypermodel_plot needs at least one prep")

    # Accept lists as well as arrays; the arithmetic below needs an ndarray.
    params = np.asarray(params)

    # Strip the enclosing '$' of every label ([1:-1]) before nesting it in a new expression.
    mu_labels = np.array([r'$\mathcal{M}(' + '{' + lab[1:-1] + '})$' for lab in np.array(labels)[params]])
    std_labels = np.array([r'$\Sigma(' + '{' + lab[1:-1] + '})$' for lab in np.array(labels)[params]])
    all_labels = np.append(mu_labels, std_labels)

    list_of_dists = []
    truths, errors, err_perc = [], [], []
    for h5i, prep in enumerate(preps):
        df = results_df[h5i]
        param_names = np.array(df.loc[prep, 'learning_params'])[params]

        test_data = pd.read_csv(os.path.join(df.loc[prep, 'path_to_test_images'], 'metadata.csv'))
        train_data = get_train_data(df, prep)[param_names]
        mu_train = np.array(train_data.loc[:, param_names].mean())
        std_train = np.array(train_data.loc[:, param_names].std())
        print('mu_train:', mu_train)
        print('std_train:', std_train)

        # Test objects flagged by get_obj_of_wide_posteriors_obj are left out of the truths.
        n_lenses = len(test_data)
        if obj_list is None:
            bad_ind = set()
        else:
            bad_ind = set(np.asarray(get_obj_of_wide_posteriors_obj(obj_list[h5i])).ravel().tolist())
        good_ind = [i for i in range(n_lenses) if i not in bad_ind]

        mu_test = np.array(test_data.loc[good_ind, param_names].mean())
        print('mu_test before:', np.array(test_data.loc[:, param_names].mean()))
        print('mu_test after:', mu_test)
        std_test = np.array(test_data.loc[good_ind, param_names].std())
        print('std_test before:', np.array(test_data.loc[:, param_names].std()))
        print('std_test after:', std_test)

        emcee_chain = retrieve_chains_h5(h5_files[h5i])
        print(emcee_chain.shape)
        # Mean columns first, then the matching scatter columns.
        all_params = np.append(params, params + n_params_learned[h5i])
        chain_full = emcee_chain[:, burnin:, all_params]
        if chain_full.shape[1] == 0:
            raise ValueError(f"no samples left in {h5_files[h5i]!r} after burnin={burnin}")

        # Keep a walker when at least one selected column differs between its first and
        # last retained step (same criterion as before, now applied to every walker
        # in the chain instead of a hard-coded 40).
        moved = chain_full[:, 0, :] != chain_full[:, -1, :]
        good_walk = np.flatnonzero(moved.any(axis=1))
        print(prep, len(good_walk))
        if good_walk.size == 0:
            raise ValueError(f"every walker in {h5_files[h5i]!r} looks stuck")
        chain_full = chain_full[good_walk, :, :]

        reshaped = chain_full.reshape(-1, chain_full.shape[-1])
        list_of_dists.append(reshaped)

        truths_arr = np.append(mu_test, std_test)
        truths.append(truths_arr)
        chain_mean = reshaped.mean(axis=0)
        errors.append(chain_mean - truths_arr)
        err_perc.append((chain_mean - truths_arr) * 100 / truths_arr)

    # NOTE(review): the truths of the FIRST prep are always passed exactly twice, whatever
    # len(preps) is. Left unchanged because make_contour is not visible here — confirm
    # whether one entry per distribution is what it expects.
    contour = make_contour(list_of_dists, all_labels, categories, colors, truths_list=[truths[0], truths[0]],show_correlation=show_corr,save_fig=save_name)
    return mu_test, std_test, mu_train, std_train, errors, err_perc, contour


In [None]:
# Position(s) in obj_list of the experiment(s) drawn in this figure.
params = [0]
select_obj = obj_list[params]

h5_files = ['data/all_lsst_0325_uniform.h5']
colors=['rebeccapurple']
preps = [o.prep for o in select_obj]
categories = [o.name for o in select_obj]

# Parameter indices handed to hypermodel_plot; it builds one mean label and one scatter
# label per index, so the contour grid is 2 * len(plot_params) panels on a side.
plot_params = np.array([0,3,4,1])

mu_test, std_test, mu_train, std_train, errors, err_perc, fig = hypermodel_plot(
    [o.df for o in select_obj],
    n_params_learned=[8],
    preps=preps,
    h5_files=h5_files,
    params=plot_params, colors=colors, categories=categories, burnin=3000,
    save_name=None, obj_list=select_obj)

axs = np.array(fig.axes)
for ax in axs:
    for spine in ax.spines.values():
        spine.set_linewidth(3)

# Derived from the number of plotted hyperparameters instead of a hard-coded 8 x 8.
n_panels = 2 * len(plot_params)
axs = axs.reshape(n_panels, n_panels)

nrows, ncols = axs.shape

for row, col in np.ndindex(nrows, ncols):
    ax = axs[row, col]

    xlabel = ax.get_xlabel()
    ylabel = ax.get_ylabel()
    ax.set_ylabel(ylabel, x=-0.4)
    ax.set_xlabel(xlabel, y=-0.4)

    # Only the left column keeps y ticks.
    if col == 0:
        ax.minorticks_on()
        ax.tick_params(axis='y', which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelleft=False, left=False)

    # Only the bottom row keeps x ticks.
    if row == nrows - 1:
        ax.minorticks_on()
        ax.tick_params(axis='x', which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelbottom=False, bottom=False)


### Figure 9: Inference of $\rm \gamma_{lens}$ Population Mean 

In [None]:
# Sanity check: print each result object's display name next to the shape of its y_test array.
for obj in obj_list:
    print(obj.name, obj.y_test.shape)

In [None]:
def _load_flat_chain(path, burnin=3000):
    """Read a chain with retrieve_chains_h5, drop burn-in, and flatten it to 2-D.

    The first ``burnin`` entries along axis 1 are discarded and the remaining samples are
    flattened to ``(n_rows, n_columns)``. The column count is taken from the chain itself
    rather than assumed, so a chain with a different number of hyperparameters is never
    silently re-wrapped into the wrong width.
    """
    chain = retrieve_chains_h5(path)
    return chain[:, burnin:, :].reshape(-1, chain.shape[-1])


# Fiducial test set.
all_lsst = _load_flat_chain('data/all_lsst_0325_uniform.h5')
nolens_lsst = _load_flat_chain('data/nolens_lsst_0325_uniform.h5')

# Bright host galaxies.
all_bright = _load_flat_chain('data/all_0327_bright_uniform.h5')
nolens_bright = _load_flat_chain('data/nolens_0327_bright_uniform.h5')
nosrc_bright = _load_flat_chain('data/nosrc_0327_bright_uniform.h5')

# Deconvolved images.
all_deconv = _load_flat_chain('data/all_0325_uniform.h5')
nolens_deconv = _load_flat_chain('data/nolens_0325_uniform.h5')
nosrc_deconv = _load_flat_chain('data/nosrc_0325_uniform.h5')

# Networks trained without correlations.
nc_all = _load_flat_chain('data/nc_all_om10_0325_uniform.h5')
nc_nolens = _load_flat_chain('data/nc_nolens_om10_0325_uniform.h5')

In [None]:
# Flattened chains and the result objects they belong to, in matching order: entry i of
# list_of_chain is compared against entry i of corresponding_obj_list in the Figure 9 cell,
# whose colors and group separators also depend on this order.
# NOTE(review): np.array stacks the chains into one array, which only works when every
# chain has the same shape — confirm that all runs share the same length and width.
list_of_chain = np.array([all_lsst, nolens_lsst,all_bright, nolens_bright, nosrc_bright,  all_deconv, nolens_deconv, nosrc_deconv, nc_all, nc_nolens])
corresponding_obj_list = np.array([ALobj, NLobj, ALbobj, NLbobj, NSbobj, ALdeobj, NLdeobj, NSdeobj, nc_ALobj, nc_NLobj])

In [None]:
def return_mean_std_pop(obj, params):
    """Return the test-population mean and std of the requested parameter column(s).

    Objects flagged by ``get_obj_of_wide_posteriors_obj`` for ``obj`` are excluded before
    the statistics are taken. The flagged set is now computed from ``obj`` itself; the
    previous version always used the global ``ALobj``, so every other result object was
    filtered with indices belonging to a different test set.

    Parameters
    ----------
    obj : result object
        Must expose ``num_obj`` and a 2-D ``y_test`` indexed as ``y_test[rows, column]``.
    params : int or sequence of int
        Column index or indices into ``obj.y_test``.

    Returns
    -------
    numpy.ndarray
        Flat array ``[mean_0, std_0, mean_1, std_1, ...]``, one pair per entry of ``params``.
    """
    flagged = get_obj_of_wide_posteriors_obj(obj, params=np.arange(8))
    flagged = set(np.asarray(flagged).ravel().tolist())
    good_obj = [i for i in range(obj.num_obj) if i not in flagged]

    stats = []
    # atleast_1d lets a scalar, list, tuple, or array of indices all take the same path.
    for p in np.atleast_1d(params):
        column = obj.y_test[good_obj, p]
        stats.extend([column.mean(), column.std()])
    return np.array(stats, dtype=float)

In [None]:
# Index of the parameter under study: selects both the label in `labels` and the column
# of each flattened chain that is summarised below.
param = 3
x_pos = np.arange(len(list_of_chain))

# Summarise only the column of interest in each chain (no need to reduce the others).
medians = np.array([np.median(chain[:, param]) for chain in list_of_chain])
scatters = 2 * np.array([np.std(chain[:, param]) for chain in list_of_chain])  # 2-sigma bars
all_means_scatters = np.array([return_mean_std_pop(o, param) for o in corresponding_obj_list]).T
true_means, true_scatters = all_means_scatters

# Deviation of each chain median from the matching test-population mean.
offsets = medians - true_means

fig, ax = plt.subplots(figsize=(15,7))

colors=np.array(['rebeccapurple', 'mediumturquoise', 'rebeccapurple','mediumturquoise','red','rebeccapurple','mediumturquoise','red','rebeccapurple','mediumturquoise'])
ax.errorbar(x_pos, offsets, yerr=scatters, fmt='none', ecolor='black', capsize=5)
ax.scatter(x_pos, offsets, color=colors, s=80)

ax.axhline(0, color='k', lw=4, ls='--')
# Separators between the four experiment groups captioned below the axis.
for boundary in (1.5, 4.5, 7.5):
    ax.axvline(boundary, color='gray', lw=2, ls='--')

ax.set_xticks([])
ax.text(0.05, -0.08, r"$\mathbf{Fiducial}$", transform=ax.transAxes)
ax.text(0.35, -0.12, "Bright Host \nGalaxies", ha='center',transform=ax.transAxes)
ax.text(0.65, -0.09, "Deconvolved", ha='center',transform=ax.transAxes)
ax.text(0.93, -0.12, "Trained without\nMass Light Correlations", ha='center', transform=ax.transAxes)

legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', label='All Light Included', markerfacecolor='rebeccapurple', markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='Lens Light Subtracted', markerfacecolor='mediumturquoise', markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='Lens and AGN Light Subtracted', markerfacecolor='red', markersize=10)
]
ax.legend(handles=legend_elements, loc='upper center',  frameon=True, facecolor='white', framealpha=1)

# labels[param] already carries its own '$...$'; strip them ([1:-1]) so the symbol stays
# inside a single math expression instead of being rendered as literal text.
ax.set_ylabel(r"$\Delta(\mathcal{M}(" + labels[param][1:-1] + "))$")
ax.minorticks_on()
ax.tick_params(which='minor', length=4)

# Must be set before plt.show(), otherwise it never reaches the rendered figure.
# NOTE(review): labelpad is a guess intended to clear the group captions above — adjust to taste.
ax.set_xlabel("Experiment", labelpad=55)
plt.show()

### Figure 10: Comparing the inferred cPDF against the true underlying distribution 

In [None]:
from lenstronomy.Util import param_util
def get_shear(param_array):
    """Convert the Cartesian shear columns of ``param_array`` to polar form.

    Columns 1 and 2 are handed to ``param_util.shear_cartesian2polar`` and its result,
    ``(phi, gamma)``, is returned unchanged.
    """
    shear_1 = param_array[:, 1]
    shear_2 = param_array[:, 2]
    return param_util.shear_cartesian2polar(shear_1, shear_2)

def get_ellip(param_array):
    """Convert the ellipticity columns of ``param_array`` to angle and axis ratio.

    Columns 4 and 5 are handed to ``param_util.ellipticity2phi_q`` and its result,
    ``(phi, q)``, is returned unchanged.
    """
    ellip_1 = param_array[:, 4]
    ellip_2 = param_array[:, 5]
    return param_util.ellipticity2phi_q(ellip_1, ellip_2)

In [None]:
from scipy.stats import multivariate_normal
import scipy
import copy
# shear_cartesian2polar,ellipticity2phi_q
def plot_interim_conditional(obj_list, pi=np.array([0,1,2,3,4,5]), pi_to_plot = [0, 3, 1, 4],plot_true_dist = True,plot_q_gamma = False,
                             burnin=3000, n_samples=5000):
    """Overlay the training distribution, optionally the test population, and cPDF samples.

    The first entry of ``obj_list`` supplies the training table (``get_train_data``) and the
    test values (``y_test``). For every entry, the chain in ``obj.h5_file`` is reduced to
    per-parameter medians of the population mean and scatter, and ``n_samples`` draws are
    taken from the diagonal Gaussian they define. Everything is passed to ``make_contour``
    with the test-set mean as the truth marker.

    Parameters
    ----------
    obj_list : sequence of result objects
        Each needs ``h5_file``, ``num_param`` and ``color``; the first also needs ``df``,
        ``prep`` and ``y_test``.
    pi : array-like of int
        Parameter indices that are compared (columns kept from every distribution).
    pi_to_plot : sequence of int
        Parameter indices that are actually drawn; each must also appear in ``pi``.
    plot_true_dist : bool
        Also draw the test population (``y_test``).
    plot_q_gamma : bool
        Show shear as (gamma_ext, phi_gamma) and ellipticity as (q, phi_e) instead of the
        Cartesian components. Requires indices 1, 2, 4 and 5 to be present in ``pi``.
    burnin : int
        Leading chain steps (axis 1) to discard.
    n_samples : int
        Number of draws from each cPDF.

    Returns
    -------
    (fig, axes)
        The object returned by ``make_contour`` and its axes reshaped to a square grid of
        side ``len(pi_to_plot)``.

    Raises
    ------
    ValueError
        If ``pi_to_plot`` (or the indices needed by ``plot_q_gamma``) are not contained in
        ``pi``, or if a chain's width is not ``2 * obj.num_param``.
    """
    ref_obj = obj_list[0]
    pi = np.asarray(pi)
    pi_to_plot = np.asarray(pi_to_plot, dtype=int)

    # Every distribution below stores its columns in `pi` order, so a parameter index has to
    # be translated into a column position before slicing. (With pi = 0..k-1 this is the
    # identity, which is what the previous direct indexing silently relied on.)
    column_of = {int(p): k for k, p in enumerate(pi)}
    missing = [int(p) for p in pi_to_plot if int(p) not in column_of]
    if missing:
        raise ValueError(f"pi_to_plot entries {missing} are not contained in pi")
    plot_cols = np.array([column_of[int(p)] for p in pi_to_plot], dtype=int)

    # A plain list avoids the fixed-width unicode dtype of `labels`, which would silently
    # truncate any replacement label longer than the longest original one.
    labels_new = list(labels)
    param_names = np.array(['main_deflector_parameters_theta_E', 'main_deflector_parameters_gamma1',
                            'main_deflector_parameters_gamma2', 'main_deflector_parameters_gamma',
                            'main_deflector_parameters_e1', 'main_deflector_parameters_e2', 'main_deflector_parameters_center_x',
                            'main_deflector_parameters_center_y', 'source_parameters_R_sersic', 'source_parameters_mag_app'])
    param_names = param_names[:ref_obj.num_param]

    all_train = get_train_data(ref_obj.df, ref_obj.prep)
    # Deep copy so the in-place column replacement below never touches obj.y_test.
    y_test = copy.deepcopy(ref_obj.y_test)

    if plot_q_gamma:
        polar_missing = [k for k in (1, 2, 4, 5) if k not in column_of]
        if polar_missing:
            raise ValueError(f"plot_q_gamma needs parameter indices {polar_missing} in pi")

        phi_e, q = param_util.ellipticity2phi_q(np.array(all_train['main_deflector_parameters_e1']), np.array(all_train['main_deflector_parameters_e2']))
        all_train['main_deflector_parameters_q'] = q
        all_train['main_deflector_parameters_phi_e'] = phi_e

        phi_g, gamma_ext = param_util.shear_cartesian2polar(np.array(all_train['main_deflector_parameters_gamma1']), np.array(all_train['main_deflector_parameters_gamma2']))
        all_train['main_deflector_parameters_gamma_ext'] = gamma_ext
        all_train['main_deflector_parameters_phi_g'] = phi_g

        # Same ordering as above, with the polar quantities in slots 1, 2, 4 and 5.
        param_names = np.array(['main_deflector_parameters_theta_E', 'main_deflector_parameters_gamma_ext','main_deflector_parameters_phi_g',
                                'main_deflector_parameters_gamma', 'main_deflector_parameters_q', 'main_deflector_parameters_phi_e','main_deflector_parameters_center_x',
                                'main_deflector_parameters_center_y', 'source_parameters_R_sersic', 'source_parameters_mag_app'])
        labels_new[1] = r'$\gamma_{ext}$'
        labels_new[2] = r'$\phi_{\gamma}$'
        labels_new[4] = 'q'
        labels_new[5] = r'$\phi_{q}$'

        # y_test holds all parameters, so the fixed slots 1, 2, 4, 5 apply directly.
        phi_e_test, q_test = param_util.ellipticity2phi_q(y_test[:, 4], y_test[:, 5])
        phi_g_test, gamma_ext_test = param_util.shear_cartesian2polar(y_test[:, 1], y_test[:, 2])
        y_test[:, 1], y_test[:, 2], y_test[:, 4], y_test[:, 5] = gamma_ext_test, phi_g_test, q_test, phi_e_test

    train_cols = param_names[pi]
    print(train_cols)

    # Truth marker: test-set mean of every plotted parameter.
    truth = np.asarray(y_test.mean(axis=0))[pi_to_plot]

    list_of_dists = [np.array(all_train[train_cols])]
    if plot_true_dist:
        list_of_dists.append(np.array(y_test[:, pi]))
        categories = ['Interim Prior', "True Population"]
        colors = ['gray', 'forestgreen']
    else:
        categories = ['Training distribution: ' + r'$\rm p(\xi_k|\nu_{int})$: ' + 'samples \nfrom prior assumed for NPE posteriors']
        colors = ['gray']

    for chain_obj in obj_list:
        chain = retrieve_chains_h5(chain_obj.h5_file)
        n_hyper = chain_obj.num_param * 2
        if chain.shape[-1] != n_hyper:
            # A reshape to the wrong width would succeed silently and mix columns.
            raise ValueError(
                f"chain {chain_obj.h5_file!r} has {chain.shape[-1]} columns, expected {n_hyper}")
        chain = chain[:, burnin:, :].reshape(-1, n_hyper)

        # First half of the columns: population means; second half: population scatters.
        cpdf_mu = np.median(chain[:, :chain_obj.num_param], axis=0)
        cpdf_sigma = np.median(chain[:, chain_obj.num_param:], axis=0)
        hypermodel = multivariate_normal(cpdf_mu[pi], cpdf_sigma[pi]**2).rvs(n_samples)
        # rvs drops length-1 axes; restore the (sample, parameter) layout.
        hypermodel = np.asarray(hypermodel).reshape(n_samples, len(pi))

        if plot_q_gamma:
            c1, c2, c4, c5 = (column_of[k] for k in (1, 2, 4, 5))
            phi_gh, gamma_exth = param_util.shear_cartesian2polar(hypermodel[:, c1], hypermodel[:, c2])
            phi_eh, qh = param_util.ellipticity2phi_q(hypermodel[:, c4], hypermodel[:, c5])
            hypermodel[:, c1], hypermodel[:, c2], hypermodel[:, c4], hypermodel[:, c5] = gamma_exth, phi_gh, qh, phi_eh

        list_of_dists.append(hypermodel)
        categories.append(r"cPDF: $p(\xi_k|\nu)$"+": samples from \nprior assumed \nfor final posteriors")
        colors.append(chain_obj.color)

    selected = [np.asarray(dist)[:, plot_cols] for dist in list_of_dists]

    fig = make_contour(selected, np.array(labels_new)[pi_to_plot], categories, colors,
                       truths_list=[truth] * len(selected))
    axes = np.array(fig.axes).reshape((len(pi_to_plot), len(pi_to_plot)))
    return fig, axes

In [None]:
# Parameters drawn in the Figure 10 corner plot (indices into `labels`).
pi_to_plot_params = np.array([0, 1, 2, 3, 4, 5])

# NOTE(review): plot_interim_conditional reads obj.h5_file, and the only visible assignment
# to ALobj.h5_file is in the Figure 13 cell further down. Confirm that prepRes sets a
# default h5_file; otherwise this cell depends on the Figure 13 cell having run first.
fig, axes = plot_interim_conditional([obj_list[0]], pi=np.array([0,1,2,3,4,5,6,7]),plot_true_dist=True,pi_to_plot=pi_to_plot_params,plot_q_gamma=True)


### Figure 11: Population inference with different mass-light correlations - Comparison to Dinos

In [None]:
from matplotlib.gridspec import GridSpec
from scipy.stats import gaussian_kde
from latils import make_analysis_table
from lenstronomy.Util.param_util import ellipticity2phi_q, shear_cartesian2polar

# Positions in obj_list of the two experiments compared here (ALobj and ALobj_dinos_ml).
params = [0,1]
select_obj = obj_list[params]
h5_files = ['data/all_lsst_0325_uniform.h5','data/all_0429_dinos_ml.h5']
colors=['rebeccapurple','orange']
preps = [o.prep for o in select_obj]
categories = [r"Fiducial: $q_{light} \sim q_{mass}$", r"Dinos 2: $q_{mass} \geq q_{light} - 0.1$"]

# Parameter indices handed to hypermodel_plot; it builds one mean label and one scatter
# label per index, so the contour grid is 2 * len(plot_params) panels on a side.
plot_params = np.array([0,3,4,1])

mu_test, std_test, mu_train, std_train, errors, err_perc, fig = hypermodel_plot(
    [o.df for o in select_obj],
    n_params_learned=[8,8],
    preps=preps,
    h5_files=h5_files,
    params=plot_params, colors=colors, categories=categories, burnin=3000,
    save_name=None, obj_list=select_obj)

axs = np.array(fig.axes)
for ax in axs:
    for spine in ax.spines.values():
        spine.set_linewidth(2)

n_panels = 2 * len(plot_params)
axs = axs.reshape(n_panels, n_panels)


def _phi_q(table, component):
    """Return ellipticity2phi_q's (phi, q) for the ``<component>_e1`` / ``_e2`` columns of ``table``."""
    return ellipticity2phi_q(table[f'{component}_e1'].to_numpy(),
                             table[f'{component}_e2'].to_numpy())


# Mass ("m") and light ("l") orientation / axis ratio for the training set ...
all_train = pd.read_csv('data/all_train.csv')
amtr, bmtr = _phi_q(all_train, 'main_deflector_parameters')
altr, bltr = _phi_q(all_train, 'lens_light_parameters')

# ... the fiducial test set (named so that the builtin all() is no longer shadowed) ...
fiducial_test = pd.read_csv('data/fiducial_test_data.csv')
am1, bm1 = _phi_q(fiducial_test, 'main_deflector_parameters')
al1, bl1 = _phi_q(fiducial_test, 'lens_light_parameters')

# ... and the Dinos ellipticity test set.
all2 = pd.read_csv('data/all_test_dinos_ellipticity.csv')
am2, bm2 = _phi_q(all2, 'main_deflector_parameters')
al2, bl2 = _phi_q(all2, 'lens_light_parameters')

# Free the grid cell at (0, 5) and add one axes spanning rows 0-1 and columns 5-6 of the
# same gridspec for the q_light vs q_mass scatter.
fig.delaxes(axs[0, 5])
gs = axs[2, 5].get_gridspec()
new_ax = fig.add_subplot(gs[0:2, 5:7])

# Only the first MAX_TRAIN_POINTS training rows are drawn.
MAX_TRAIN_POINTS = 100000
new_ax.scatter(bltr[:MAX_TRAIN_POINTS], bmtr[:MAX_TRAIN_POINTS], color='gray', alpha=0.4, label='Training Distribution')
new_ax.scatter(bl1, bm1, color='rebeccapurple', label='Fiducial', alpha=0.6)
new_ax.scatter(bl2, bm2, color='orange', label=r'Dinos 2: $q$ Mass-Light Relationship', alpha=0.3)

# NOTE(review): the scatter labels above are never shown because no legend is created on
# new_ax — add new_ax.legend() if the inset should carry its own legend.
new_ax.set_xlabel('$q_{light}$')
new_ax.set_ylabel('$q_{mass}$')
new_ax.set_title('Train vs Test Distribution')
fig.suptitle("Population-Level Parameter Recovery", x=0.3,y=1)

nrows, ncols = axs.shape

for row, col in np.ndindex(nrows, ncols):
    ax = axs[row, col]

    xlabel = ax.get_xlabel()
    ylabel = ax.get_ylabel()
    ax.set_ylabel(ylabel, x=-0.4)
    ax.set_xlabel(xlabel, y=-0.4)

    # Only the left column keeps y ticks.
    if col == 0:
        ax.minorticks_on()
        ax.tick_params(axis='y', which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelleft=False, left=False)

    # Only the bottom row keeps x ticks.
    if row == nrows - 1:
        ax.minorticks_on()
        ax.tick_params(axis='x', which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelbottom=False, bottom=False)

### Figure 12: Relationship between bias on individual posteriors and bias in population recovery

In [None]:
# Load both chains, drop the first 3000 steps, and flatten each to (samples, columns).
chain1 = retrieve_chains_h5('data/all_lsst_0325_uniform.h5')
chain2 = retrieve_chains_h5('data/all_june10_test12.h5')
chain1 = chain1[:, 3000:, :].reshape(-1, chain1.shape[2])
chain2 = chain2[:, 3000:, :].reshape(-1, chain2.shape[2])

# Columns compared between the two runs.
# NOTE(review): 3 and 11 are taken to be the mean and scatter of parameter 3 in a chain laid
# out as 8 means followed by 8 scatters, and 0 and 1 the same two quantities in the second
# chain — confirm against how each chain was produced.
param_indices1 = [3,11]
param_indices2 = [0,1]
colors=['rebeccapurple', 'mediumturquoise']

chains_to_plot = [chain1[:, param_indices1], chain2[:, param_indices2]]
# Raw string: "\m" and "\g" are not valid escape sequences in a normal literal.
categories = ['Fiducial', r"Learned only $\mathcal{M}(\gamma_{lens})$"]

# Truth markers: mean and std of column 3 of the fiducial test values.
mu_test = ALobj.y_test.mean(axis=0)[np.array([3])]
std_test = ALobj.y_test.std(axis=0)[np.array([3])]
truths = np.array([mu_test, std_test]).flatten()

fig = make_contour(
    chains_to_plot,
    labels=[mu_labels[3],std_labels[3]],
    categories=categories,
    colors=colors,
    # One independent copy per distribution, as before.
    truths_list=[truths, truths.copy()]
)

for ax in fig.axes:
    xlabel=ax.get_xlabel()
    ylabel=ax.get_ylabel()
    ax.tick_params(axis='both', which='both', labelsize=40)
    ax.set_ylabel(ylabel,fontsize=70, x=-0.2)
    ax.set_xlabel(xlabel,fontsize=70,y=-0.2)
    ax.minorticks_on()
    ax.tick_params(axis='both', which='minor', length=10, width=1)
fig.tight_layout()
# plt.savefig("npe_post_to_hyper.pdf", dpi=300)

### Figure 13: Plotting the cPDF in cartesian space, displaying against the final posteriors

In [None]:
# Parameters drawn in the Figure 13 corner plot (indices into `labels`), in panel order.
pi_to_plot_params = np.array([0, 3, 1, 4])
# Chain file that plot_interim_conditional reads for the fiducial object.
ALobj.h5_file = 'data/all_lsst_0325_uniform.h5'
fig, axes = plot_interim_conditional([obj_list[0]], pi=np.array([0,1,2,3,4,5,6,7]),plot_true_dist=False,pi_to_plot=pi_to_plot_params,plot_q_gamma=False)

axs = np.array(fig.axes).reshape(len(pi_to_plot_params), len(pi_to_plot_params))
nrows, ncols = axs.shape

labelsize = 40        # tick-label size
axis_label_size = 65  # final axis-label size (the earlier size=40 was immediately overridden)

for row, col in np.ndindex(nrows, ncols):
    ax = axs[row, col]
    for spine in ax.spines.values():
        spine.set_linewidth(4)

    ax.set_ylabel(ax.get_ylabel(), size=axis_label_size, x=-0.28)
    ax.set_xlabel(ax.get_xlabel(), size=axis_label_size, y=-0.3)

    # Only the left column keeps y ticks.
    if col == 0:
        ax.minorticks_on()
        ax.tick_params(axis='y', labelsize=labelsize, which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelleft=False, left=False)

    # Only the bottom row keeps x ticks.
    if row == nrows - 1:
        ax.minorticks_on()
        ax.tick_params(axis='x', labelsize=labelsize, which='both', direction='out', length=6, width=1)
    else:
        ax.tick_params(labelbottom=False, bottom=False)

# Second plotted parameter: same range on every x axis of its column and on the one
# off-diagonal panel of its row.
for ax in axes[:, 1]:
    ax.set_xlim(1.4, 2.6)
axes[1, 0].set_ylim(1.4, 2.6)

# First plotted parameter: x range for its column. (The former loop over axes[0, :0]
# iterated over an empty slice and never ran, so it has been removed.)
for ax in axes[:, 0]:
    ax.set_xlim(0, 4)

leg = fig.legends[0]   # first figure-level legend
leg.set_bbox_to_anchor((0.32, 0.8))  # adjust these numbers to move the legend

# fig.savefig('cPDF_fiducial_gray_purple_e1_g1.pdf',bbox_inches='tight',dpi=300)