In [None]:
# REQUIRED IMPORTS
# FILE MANAGEMENT
import h5py
import os
from IPython.utils import io

%load_ext autoreload
%autoreload 2

# DATA MANIPULATION
import numpy as np
import pandas as pd


# STATISTICS
from scipy.stats import norm, truncnorm, uniform, beta, multivariate_normal

# VISUALIZATION
import matplotlib
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
import matplotlib.colors as mpc
import corner
from matplotlib.patches import Patch
import astropy.visualization as asviz


# Shared font sizes reused by the plotting cells of this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 17, 20, 30

# Push the sizes into matplotlib's global rc settings, one rc group per call.
plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE)   # axes title / x-y labels
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)               # tick labels on both axes
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend entries
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure-level title
plt.rcParams["font.family"] = "DejaVu Serif"

%matplotlib inline

In [None]:
# Names of the lens/source parameters used by this notebook. The prepRes calls
# further down slice this array as learning_params[:8] or learning_params[:10].
learning_params = np.array([
    'main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1',
    'main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma',
    'main_deflector_parameters_e1',
    'main_deflector_parameters_e2',
    'main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'source_parameters_R_sersic',
    'source_parameters_mag_app',
])

# LaTeX label for each entry of learning_params, in the same order.
# Raw strings keep the backslashes literal: "\g", "\m" and "\S" are not valid
# Python escape sequences and trigger a SyntaxWarning on recent interpreters.
labels = np.array([
    r"$\theta_E$", r"$\gamma_1$", r"$\gamma_2$", r"$\gamma_{lens}$", r"$e_1$", r"$e_2$",
    r"$x_D$", r"$y_D$", r"$R_{src}$", r"$m_{i}$",
])

# label[1:-1] drops the surrounding "$" so the bare symbol can be re-wrapped
# inside \mathcal{M}(...) and \Sigma(...).
mu_labels = np.array([r'$\mathcal{M}(' + label[1:-1] + ')$' for label in labels])
std_labels = np.array([r'$\Sigma(' + label[1:-1] + ')$' for label in labels])

# Parameter name -> LaTeX label lookup.
labels_dict = dict(zip(learning_params, labels))

In [None]:
#### latils imports
from latils import make_contour , get_train_data,  get_obj_of_wide_posteriors_obj, make_metrics_table, retrieve_chains_h5, make_results_df_without_training

# from lenstronomy.Util.param_util import ellipticity2phi_q, shear_cartesian2polar


In [None]:
from latils import prepRes

In [None]:
from latils import make_contour , get_train_data,  get_obj_of_wide_posteriors_obj, make_metrics_table, retrieve_chains_h5, make_results_df_without_training

def _prep_early_stopping(*args):
    """Build a prepRes result object with mode_of_stopping='early_stopping'.

    Every positional argument is forwarded to prepRes unchanged; their meaning
    is defined by prepRes itself.
    """
    return prepRes(*args, mode_of_stopping='early_stopping')


### fiducial preparation
ALobj = _prep_early_stopping(
    'full', 'all', 8, learning_params[:8], 'All Light Included', 'rebeccapurple',
    '0118/all', '0325/all_no_single', 'full/0118/all', 'data/0325_results',
)

### orange dataset: network trained with mass-light correlations, applied to a
### test set with different mass-light ellipticity correlations
ALobj_dinos_ml = _prep_early_stopping(
    'full', 'all', 8, learning_params[:8], 'All Light Included', 'rebeccapurple',
    '0118/all', '0429/all_dinos_ml', 'full/0118/all', 'data/dinos_qm_ql',
)

NLobj = _prep_early_stopping(
    'full', 'nolens', 8, learning_params[:8], 'Lens Light Subtracted', 'mediumaquamarine',
    '0118/nolens', '0325/nolens', 'full/0118/nolens', 'data/0325_results',
)

NSobj = _prep_early_stopping(
    'full', 'nosrc', 10, learning_params[:10], 'Lens+AGN Light Subtracted', 'gold',
    '0118/nosrc', '0325/nosrc', 'full/0118/nosrc', 'data/0325_results',
)

### trained without correlations
nc_ALobj = _prep_early_stopping(
    'full', 'nc_all_om10', 8, learning_params[:8], 'NC:All Light Included', 'darkgoldenrod',
    '0310/all', '0325/all_no_single', 'full/0310/all', 'data/0325_train_no_corr',
)
nc_NLobj = _prep_early_stopping(
    'full', 'nc_nolens_om10', 8, learning_params[:8], 'NC:Lens Light Subtracted', 'rebeccapurple',
    '0310/nolens', '0325/nolens', 'full/0310/nolens', 'data/0325_train_no_corr',
)

### valid_results
ALvobj = _prep_early_stopping(
    'full', 'valid_all', 8, learning_params[:8], 'V: All Light Included', 'navajowhite',
    '0118/all', '0118/valid_all', 'full/0118/all', 'data/0118_valid_results',
)
NLvobj = _prep_early_stopping(
    'full', 'valid_nolens', 8, learning_params[:8], 'V: Lens Light Subtracted', 'aquamarine',
    '0118/nolens', '0118/valid_nolens', 'full/0118/nolens', 'data/0118_valid_results',
)

nc_ALvobj = _prep_early_stopping(
    'full', 'nc_all', 8, learning_params[:8], 'NC:All Light Included', 'darkgoldenrod',
    '0310/all', '0118/valid_all', 'full/0310/all', 'data/0320_results_held_out',
)
nc_NLvobj = _prep_early_stopping(
    'full', 'nc_nolens', 8, learning_params[:8], 'NC:Lens Light Subtracted', 'rebeccapurple',
    '0310/nolens', '0118/valid_nolens', 'full/0310/nolens', 'data/0320_results_held_out',
)

### diag_results
ALdobj = _prep_early_stopping(
    'diag', 'all_diag', 8, learning_params[:8], 'D: All Light Included', 'tan',
    '0118/all', '0325/all_no_single', 'diagonal/0118/all', 'data/0118_diag_results',
)

### bright_results
ALbobj = _prep_early_stopping(
    'full', 'allb', 8, learning_params[:8], 'All Light Included', 'gold',
    '0118/all', '0118/all/bright', 'full/0118/all', 'data/0118_bright_results',
)
NLbobj = _prep_early_stopping(
    'full', 'nolensb', 8, learning_params[:8], 'Lens Light Subtracted', 'mediumaquamarine',
    '0118/nolens', '0118/nolens/bright', 'full/0118/nolens', 'data/0118_bright_results',
)
NSbobj = _prep_early_stopping(
    'full', 'nosrcb', 10, learning_params[:10], 'Lens+AGN Light Subtracted', 'rebeccapurple',
    '0118/nosrc', '0118/nosrc/bright', 'full/0118/nosrc', 'data/0118_bright_results',
)

### deconv_results
ALdeobj = _prep_early_stopping(
    'full', 'all', 10, learning_params[:10], 'EM-D: All Light Included', 'gold',
    '0118/deconvolved/all', '0325/deconvolved/all', 'full/0118/deconvolved/all',
    'data/0325_deconv_results',
)
NLdeobj = _prep_early_stopping(
    'full', 'nolens', 10, learning_params[:10], 'EM-D: Lens Light Subtracted', 'mediumaquamarine',
    '0118/deconvolved/nolens', '0325/deconvolved/nolens', 'full/0118/deconvolved/nolens',
    'data/0325_deconv_results',
)
NSdeobj = _prep_early_stopping(
    'full', 'nosrc', 10, learning_params[:10], 'EM-D: Lens+AGN Light Subtracted', 'mediumpurple',
    '0118/deconvolved/nosrc', '0325/deconvolved/nosrc', 'full/0118/deconvolved/nosrc',
    'data/0325_deconv_results',
)



In [None]:
# Result objects collected into one array, grouped here by run family:
# fiducial, bright, deconvolved, trained-without-correlations, validation.
obj_list = np.array([
    ALobj, NLobj,
    ALbobj, NLbobj, NSbobj,
    ALdeobj, NLdeobj, NSdeobj,
    nc_ALobj, nc_NLobj,
    ALvobj, NLvobj,
])

### Figure 5: Parameter Recovery

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Parameter columns shown in the 2x2 grid of panels (filled in row-major order).
params = [0, 3, 1, 4]
obj = ALobj

# These arrays are the same for every panel, so read them once up front.
y_pred, y_test, std_pred = obj.y_pred, obj.y_test, obj.std_pred

# Box style and per-panel anchor (axes coordinates) of the "Mean Error" note;
# panels that are not listed use the upper-right default.
props1 = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
mean_error_xy = {0: (0.03, 0.8), 2: (0.03, 0.09)}
default_mean_error_xy = (0.66, 0.8)

# (half-width in sigma, colour, alpha, label) of the shaded reference bands.
sigma_bands = (
    (1, 'green', 0.15, r'$\pm 1\sigma$'),
    (2, 'orange', 0.10, r'$\pm 2\sigma$'),
)

fig = plt.figure(figsize=(19, 17))
outer = gridspec.GridSpec(2, 2, wspace=0.3, hspace=0.3)  # 4 main panels

for idx, param_idx in enumerate(params):
    truth = y_test[:, param_idx]
    predicted = y_pred[:, param_idx]
    sigma = std_pred[:, param_idx]

    # Per-panel statistics.
    error = predicted - truth
    me_in_sigma = error / sigma
    correlation = np.corrcoef(truth, predicted)[0, 1]
    mean_error = np.mean(error)

    # Three stacked rows inside this main panel: prediction, error, error/sigma.
    inner = gridspec.GridSpecFromSubplotSpec(
        3, 1,
        subplot_spec=outer[idx],
        height_ratios=[2, 1, 1],
        hspace=0.05,
    )

    # Row 1: prediction vs. truth with predicted 1-sigma error bars.
    ax_top = fig.add_subplot(inner[0])
    ax_top.errorbar(truth, predicted, yerr=sigma, fmt='o', alpha=0.7,
                    color='rebeccapurple', ecolor='black')
    ax_top.plot(truth, truth, ls='--', color='r')  # y = x reference line
    ax_top.set_ylabel('Prediction')
    ax_top.set_title(labels[param_idx])
    ax_top.text(0.6, 0.1, f'Correlation: {correlation:.2f}',
                fontsize=MEDIUM_SIZE, transform=ax_top.transAxes)

    # Row 2: raw error (prediction - truth).
    ax_bottom = fig.add_subplot(inner[1], sharex=ax_top)
    ax_bottom.scatter(truth, error, alpha=0.7, color='k')
    ax_bottom.axhline(0, color='r', ls='--', alpha=0.7)
    ax_bottom.set_ylabel(r'Error')

    text_x, text_y = mean_error_xy.get(idx, default_mean_error_xy)
    ax_bottom.text(text_x, text_y, f'Mean Error: {mean_error:.2f}',
                   fontsize=MEDIUM_SIZE, transform=ax_bottom.transAxes, bbox=props1)

    # Row 3: error in units of the predicted sigma, over shaded +-1/+-2 sigma bands.
    ax_bottom2 = fig.add_subplot(inner[2], sharex=ax_top)
    x_span = [np.min(truth), np.max(truth)]
    for n_sigma, band_color, band_alpha, band_label in sigma_bands:
        ax_bottom2.fill_between(
            x=x_span,
            y1=-n_sigma, y2=n_sigma,
            color=band_color, alpha=band_alpha,
            label=band_label,
        )
    ax_bottom2.scatter(truth, me_in_sigma, alpha=0.7, color='k')
    ax_bottom2.axhline(0, color='r', ls='--', alpha=0.7)
    ax_bottom2.set_ylabel(r'Error/$\sigma$')
    ax_bottom2.set_xlabel('Truth')

    for axi in (ax_top, ax_bottom, ax_bottom2):
        axi.minorticks_on()
        axi.tick_params(which='minor', length=4)

fig.tight_layout()
# fig.savefig('pred_vs_truth_4panels.pdf', dpi=600)
plt.show()


### Figure 6: Calibration

In [None]:
from matplotlib.lines import Line2D
plt.rcParams["font.family"] = "DejaVu Serif"
from paltas.Analysis import posterior_functions as pf


# Parameter columns whose calibration curves are overlaid, and one colour each.
params = [0, 3, 1, 4]
fig, ax = plt.subplots(figsize=(9, 8))
chosen_obj = ALobj
colors = ['red', 'orange', 'blue', 'green']

y_test = chosen_obj.y_test
y_pred = chosen_obj.y_pred
std_pred = chosen_obj.std_pred
cov_pred = chosen_obj.cov_pred

# Draw the posterior samples ONCE for all parameters. The draw does not depend
# on which parameter is plotted, so repeating it inside the loop only repeated
# the same expensive work. A seeded generator makes the figure reproducible.
calibration_rng = np.random.default_rng(0)
all_predict_samps = np.array([
    multivariate_normal.rvs(y_pred[k, :], cov_pred[k, :, :], size=5000,
                            random_state=calibration_rng)
    for k in range(y_pred.shape[0])
]).transpose(1, 0, 2)  # -> (sample, lens, parameter)

for i in range(len(params)):
    # Keep only the column of the parameter being plotted.
    predict_samps = all_predict_samps[:, :, params[i]]
    pf.plot_calibration(
        predict_samps=predict_samps,
        y_test=y_test[:, params[i]],
        figure=fig,
        ax=ax,
        title='Calibration Per Parameter',
        block=True,
        legend=None,  # We'll handle the legend manually
        color_map=['k', colors[i]]
    )

# Create custom legend handles for each parameter
legend_handles = [Line2D([0], [0], color=colors[j], lw=4, label=labels[params[j]]) for j in range(len(params))]
ax.legend(handles=legend_handles, loc='upper left', ncols=2, bbox_to_anchor=(0.02, 1), fontsize=MEDIUM_SIZE)
ax.set_xlabel("Percentage of Probability Volume")
ax.set_ylabel("Percent of Lenses With True Value in the Volume")
# Region annotations, positioned in data coordinates.
ax.text(0.08, 0.75, "Underconfident")
ax.text(0.7, 0.3, "Overconfident")
ax.minorticks_on()
ax.tick_params(which='minor', length=4)

### Look at Table values

In [None]:
from latils import make_analysis_table, get_stats


In [None]:
# Summary statistics of the fiducial "all light" run, shown as a table with
# one column per selected parameter.
o = ALobj
params = np.array([0, 3, 4, 5, 1, 2, 6, 7])
all_corr = get_stats(o.y_test, o.y_pred, o.std_pred, params=params)
pd.DataFrame(
    data=all_corr,
    index=["Correlation", "Mean Error", "MAE", "Precision"],
    columns=labels[params],
)

### Figure 8: Comparison of Lens Light Subtraction under different data distribution shifts

In [None]:
import matplotlib as mpl
# 2x2 grid of prediction-vs-truth panels for one parameter. Rows are the light
# treatment, columns are the test set (see the column headers added below).
fig, ax = plt.subplots(2, 2, figsize=(20, 12))
ax = ax.flatten()
param = 3
objects = [ALobj, ALvobj, NLobj, NLvobj]
labels_prefix = ['All Light Included', 'All Light Included', 'Lens Light Subtracted', 'Lens Light Subtracted']
props1 = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

for i in range(4):
    obj = objects[i]
    y_pred = obj.y_pred[:, param]
    y_test = obj.y_test[:, param]
    std_pred = obj.std_pred[:, param]
    # Absolute error in units of the predicted sigma; drives the point colour.
    me_in_sigma = np.abs(y_test - y_pred) / std_pred

    # Label each scatter with its own panel name (was hard-coded to
    # 'All Light Included' for every panel).
    cm = ax[i].scatter(y_test, y_pred, c=me_in_sigma, alpha=0.7, label=labels_prefix[i])
    ax[i].plot(y_test, y_test, ls='--', color='r')  # y = x reference line
    ax[i].set_title(labels_prefix[i] + ': ' + labels[param])
    ax[i].set_xlabel('Truth')
    ax[i].set_ylabel('Prediction')
    # Raw string: "\s" is not a valid Python escape sequence.
    ax[i].figure.colorbar(cm, ax=ax[i], label=r'Error/$\sigma$')
    for spine in ax[i].spines.values():
        spine.set_linewidth(2)

    correlation = np.round(np.corrcoef(y_test, y_pred)[0][1], 2)
    mean_error = np.round(np.mean(y_pred - y_test), 2)
    median_precision = np.round(np.median(std_pred), 2)
    ax[i].text(0.6, 0.1,
               'Correlation: %.2f' % (correlation), {'fontsize': MEDIUM_SIZE}, transform=ax[i].transAxes)
    ax[i].text(0.05, 0.9,
               'Mean Error: %.2f' % (mean_error), {'fontsize': MEDIUM_SIZE}, transform=ax[i].transAxes, bbox=props1)
    ax[i].text(0.05, 0.8,
               'Median Precision: %.2f' % (median_precision), {'fontsize': MEDIUM_SIZE}, transform=ax[i].transAxes, bbox=props1)

fig.tight_layout();
# Column headers, placed in figure coordinates above the two columns.
fig.text(0.2, 1, 'OM10 Test Set', ha='center', fontsize=BIGGER_SIZE)
fig.text(0.7, 1, 'No Distribution Shift', ha='center', fontsize=BIGGER_SIZE);



### Figure 13: Plot re-weighted posteriors

In [None]:
### Load in existing final posteriors (if they exist)
file_path = 'data/final_posteriors_0910.h5'
all_samples = []
all_weights = []
# The context manager closes the file even if one of the reads below raises.
with h5py.File(file_path, 'r') as h5f:
    chain_names = list(h5f.keys())
    print(chain_names)
    # Datasets are read in 'samples_<k>' / 'weights_<k>' pairs, so the number
    # of chains is taken as half the number of keys.
    # NOTE(review): assumes the file holds nothing but those pairs - confirm.
    for ind in range(len(chain_names) // 2):
        samps, weights = h5f[f'samples_{ind}'][:], h5f[f'weights_{ind}'][:]
        all_samples.append(samps)
        all_weights.append(weights)

In [None]:
from latils import get_train_data, learning_params, labels, labels_dict
full_df = ALobj.df
prep = 'all'

# Row indices flagged by get_obj_of_wide_posteriors_obj, and the complementary
# set of row indices that are kept.
wide_post = get_obj_of_wide_posteriors_obj(ALobj)
obj_index = np.array([row for row in range(len(ALobj.y_test)) if row not in wide_post])

# Drop the flagged rows (axis 0) from every per-object array of the result.
y_pred, y_test, std_pred, prec_pred, cov_pred = (
    np.delete(getattr(ALobj, attr_name), wide_post, axis=0)
    for attr_name in ('y_pred', 'y_test', 'std_pred', 'prec_pred', 'cov_pred')
)

# Per-parameter mean and standard deviation of the first eight learning
# parameters in the training table.
train_data = get_train_data(full_df, prep)
train_params = train_data[learning_params[:8]]
train_mean = np.array(train_params.mean(axis=0))
train_scatter = np.array(train_params.std(axis=0))

chain = retrieve_chains_h5('data/all_lsst_0325_uniform.h5')

In [None]:
def reweighted_table(samples_list,weights_list):
    """Summarise weighted chains by their weighted mean and standard deviation.

    Args:
        samples_list: Sequence of 2-D arrays, one per object, each indexed as
            (sample, parameter). The parameter count is taken from the first
            entry.
        weights_list: Sequence of 1-D weight arrays; entry i holds one weight
            per sample of samples_list[i].

    Returns:
        (y_pred, std_pred): two arrays of shape (n_objects, n_parameters) with
        the weighted mean and the weighted standard deviation of each
        parameter for each object.

    Prints the shape of the output array as a side effect.
    """
    n_objects = len(samples_list)
    n_params = samples_list[0].shape[1]

    y_pred = np.empty((n_objects, n_params))
    print(y_pred.shape)
    std_pred = np.empty((n_objects, n_params))

    for row in range(n_objects):
        chain_samples = samples_list[row]
        # Column vector so the weights broadcast across the parameter axis.
        chain_weights = weights_list[row]
        chain_weights = chain_weights.reshape(len(chain_weights), 1)
        weight_total = np.sum(chain_weights, axis=0)

        weighted_mean = np.sum(chain_samples * chain_weights, axis=0) / weight_total
        weighted_var = np.sum(chain_weights * (chain_samples - weighted_mean) ** 2, axis=0) / weight_total

        y_pred[row, :] = weighted_mean
        std_pred[row, :] = np.sqrt(weighted_var)

    return y_pred, std_pred

In [None]:
# Convert the per-chain lists to arrays, then reduce every chain to its
# weighted mean and standard deviation.
all_samples, all_weights = np.array(all_samples), np.array(all_weights)
y_predr, std_predr = reweighted_table(samples_list=all_samples, weights_list=all_weights)

In [None]:
# Parameter columns to plot, one panel each.
params = np.array([0, 3, 1, 4])

fig, ax = plt.subplots(2, len(params)//2, figsize=(12, 10))
ax = ax.flatten()

# Plot a random subset of at most 200 objects. A seeded generator keeps the
# figure reproducible across kernel restarts, and replace=False prevents the
# same object from being drawn (and over-plotted) more than once.
# NOTE(review): assumes row j of y_predr/std_predr describes the same object
# as row j of y_test/y_pred/std_pred - confirm against how the chains were built.
obj_idx = np.arange(len(y_predr))
subset_rng = np.random.default_rng(0)
random_samples = subset_rng.choice(obj_idx, size=min(200, len(obj_idx)), replace=False)

for k, i in enumerate(params):
    # Only the first panel contributes legend entries, so each series is
    # listed once in the figure-level legend.
    first_panel = (k == 0)

    # Interim posteriors (grey).
    ax[k].errorbar(y_test[random_samples, i], y_pred[random_samples, i], yerr=std_pred[random_samples, i],
                   fmt='none', ecolor='gray', elinewidth=3, alpha=0.3)
    ax[k].scatter(y_test[random_samples, i], y_pred[random_samples, i], color='gray', edgecolor='k',
                  alpha=0.6, label='Interim Posteriors' if first_panel else "")

    # Re-weighted (final) posteriors (purple).
    ax[k].errorbar(y_test[random_samples, i], y_predr[random_samples, i], yerr=std_predr[random_samples, i],
                   ms=7, fmt='o', ecolor='rebeccapurple', elinewidth=3, alpha=0.9, mfc='rebeccapurple',
                   mec='k', label='Final Posteriors' if first_panel else "")

    # y = x reference line.
    ax[k].plot(y_test[:, i], y_test[:, i], lw=3, color='r', ls='--')
    ax[k].set_title(labels[i])
    ax[k].minorticks_on()
    ax[k].title.set_fontsize(30)
    ax[k].tick_params(axis='both', which='major', labelsize=22, length=8, width=2)
    ax[k].tick_params(axis='both', which='minor', labelsize=18, length=5, width=1.5)

fig.suptitle("Final Posteriors")
fig.supylabel("Recovered Posterior Means",size=30)
fig.supxlabel("Ground Truth",size=30)
fig.legend(ncols=1,loc=(0.24,0.57))
fig.tight_layout();
# plt.savefig('final_post_scatter.pdf', dpi=300)