In [None]:
%load_ext autoreload
%autoreload 2
from lens_catalog import OM10LensCatalog
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from Utils.inference_utils import median_sigma_from_samples
import h5py
import pandas as pd
import corner
from matplotlib.lines import Line2D
import Utils.mcmc_utils as mcmc_utils
import tdc_utils

In [None]:
# Ground-truth lens catalog; every later cell reads its lens_df table through gt_lenses.
gt_lenses = OM10LensCatalog(
    'MassModels/om10_sample/om10_venkatraman_erickson24.csv'
)
# Trailing expression: display the keys of lens_df as the cell output.
gt_lenses.lens_df.keys()

### Define Gold/Silver Samples ###

We split the sample into doubles and quads. We already have a sample of ~30 quads (the STRIDES sample), so we take a conservative assumption of 50 quads in the gold sample. We add 200 doubles to the gold sample.
The remaining lenses are added to the silver sample. 

In [None]:
# Partition the catalog into doubles and quads.
dbls = gt_lenses.doubles_indices()
quads = gt_lenses.quads_indices()

# Optionally require mag_app_src < 24.
mag_limit = True
if mag_limit:
    # NOTE: hardcoded mag limit!!
    src_mag_app_lim = 24.
    bright_idxs = np.where(gt_lenses.lens_df['src_mag_app'] < src_mag_app_lim)[0]

    # keep only the lenses that appear in both index sets
    dbls, quads = (np.intersect1d(idxs, bright_idxs) for idxs in (dbls, quads))

# Gold sample: the last 200 doubles and the last 50 quads (overall quad fraction is 11%,
# so quads are over-represented). Everything else is silver, regardless of time delay.
n_gold_dbls, n_gold_quads = 200, 50
gold_dbls, silver_dbls = dbls[-n_gold_dbls:], dbls[:-n_gold_dbls]
gold_quads, silver_quads = quads[-n_gold_quads:], quads[:-n_gold_quads]

gold = np.append(gold_dbls,gold_quads)
silver = np.append(silver_dbls,silver_quads)


def _gamma_mean_std(lens_idxs):
    """Return (mean, ddof=1 standard deviation) of the catalog 'gamma' column for lens_idxs."""
    gamma_vals = gt_lenses.lens_df.loc[lens_idxs,'gamma'].to_numpy().astype(float)
    return np.mean(gamma_vals), np.std(gamma_vals,ddof=1)


mu_mean_gold, std_mean_gold = _gamma_mean_std(gold)
print('mean gamma of gold sample: ', mu_mean_gold)
print('sigma gamma of gold sample: ', std_mean_gold)

mu_mean_silver, std_mean_silver = _gamma_mean_std(silver)
print('mean gamma of silver sample: ', mu_mean_silver)
print('sigma gamma of silver sample: ', std_mean_silver)

### Make Data Vectors ###

This is all done using the make_data_vectors() function in make_data_vectors.py

1. Emulates time delay measurements
2. Uses NPE posteriors to draw samples from the Fermat potential difference posteriors
3. Stores all information in format for fast_tdc()

In [None]:
from make_data_vectors import make_data_vectors

# load in NPE posteriors
# Paths are relative to the repository root, consistent with the catalog path
# ('MassModels/om10_sample/...') and the 'DataVectors/...' output paths used in this
# notebook, so the cell no longer depends on one user's home directory.
npe_folder = 'MassModels/om10_sample/'
hst_mu = np.load(npe_folder+'y_pred_hst.npy')
hst_cov = np.load(npe_folder+'cov_pred_hst.npy')

lsst_mu = np.load(npe_folder+'y_pred_lsst_DEBIASED2.npy')
lsst_cov = np.load(npe_folder+'cov_pred_lsst_DEBIASED2.npy')

# where to store this test
exp_folder = 'fpd_eval_DV_silverALLDEBIASED'

# One row per output file, in the original call order (gold quads, gold dbls,
# silver quads, silver dbls). Gold lenses use the HST NPE outputs with
# td_meas_error=2; silver lenses use the LSST NPE outputs with td_meas_error=5.
data_vector_runs = [
    # (file stem, lens indices, num_images, td_meas_error, npe_mu, npe_cov)
    ('gold_quads', gold_quads, 4, 2, hst_mu, hst_cov),
    ('gold_dbls', gold_dbls, 2, 2, hst_mu, hst_cov),
    ('silver_quads', silver_quads, 4, 5, lsst_mu, lsst_cov),
    ('silver_dbls', silver_dbls, 2, 5, lsst_mu, lsst_cov),
]

for file_stem, lens_idxs, n_images, td_error, run_mu, run_cov in data_vector_runs:
    make_data_vectors(gt_lenses,lens_idxs,num_images=n_images,td_meas_error=td_error,
        npe_mu=run_mu,npe_cov=run_cov,
        h5_save_path=('DataVectors/'+exp_folder+'/'+file_stem+'.h5'),
        num_fpd_samps=3000,emulated=False)

### Check Fermat potential bias ###

In [None]:
# where data vectors are stored -- relative to the repository root, matching the relative
# 'DataVectors/...' and 'MassModels/...' paths used elsewhere in this notebook
exp_folder = 'DataVectors/fpd_eval_DV_silverALLDEBIASED'

lens_types = ['gold_quads','gold_dbls']

inputs_dict = {
    'gold_quads':{},
    'gold_dbls':{},
}

input_keys = ['measured_td','measured_prec','prefactor','fpd_samps','gamma_samps','z_lens_truth','z_src_truth']

for lens_type in lens_types:
    my_filepath = (exp_folder+'/'+lens_type+'.h5')
    # the context manager closes the file even if reading a dataset fails
    with h5py.File(my_filepath, 'r') as h5f:
        for key in input_keys:
            # h5f[key] raises a KeyError naming the dataset if it is missing
            inputs_dict[lens_type][key] = h5f[key][:]

# Per-lens median and half-width of the central ~68% interval of the fpd samples
# (statistics are taken over axis 1 of fpd_samps).
quads_fpd_samps = inputs_dict['gold_quads']['fpd_samps']
quads_median_fpd = np.median(quads_fpd_samps,axis=1)
low = np.quantile(quads_fpd_samps,q=0.1586,axis=1)
high = np.quantile(quads_fpd_samps,q=0.8413,axis=1)
quads_sigma_fpd = ((high-quads_median_fpd)+(quads_median_fpd-low))/2

dbls_fpd_samps = inputs_dict['gold_dbls']['fpd_samps']
dbls_median_fpd = np.median(dbls_fpd_samps,axis=1)
low = np.quantile(dbls_fpd_samps,q=0.1586,axis=1)
high = np.quantile(dbls_fpd_samps,q=0.8413,axis=1)
dbls_sigma_fpd = ((high-dbls_median_fpd)+(dbls_median_fpd-low))/2

#  QUADS: one panel per fpd (fpd01, fpd02, fpd03), predicted median vs. catalog truth
fig,axs = plt.subplots(1,3,figsize=(13,4),dpi=200)
for i in range(0,3):
    true_fpd = gt_lenses.lens_df.loc[gold_quads,'fpd0%d'%(i+1)].to_numpy().astype(float)
    pred_fpd = quads_median_fpd[:,i]
    sigma_pred_fpd = quads_sigma_fpd[:,i]
    # one-to-one reference line
    axs[i].plot([-1.9,0.2],[-1.9,0.2],color='black')
    axs[i].scatter(true_fpd,pred_fpd,color='goldenrod')
    axs[i].set_title('fpd0%d'%(i+1))
    axs[i].set_xlabel('truth fpd0%d'%(i+1))
    axs[i].set_ylabel('pred fpd0%d'%(i+1))
    mean_error = np.mean(pred_fpd - true_fpd)
    mean_error_sigma = np.mean( (pred_fpd - true_fpd) / sigma_pred_fpd)
    axs[i].text(-0.9,-1.3,'mean_error = %.3f'%(mean_error))
    # raw string: '\s' is not a valid escape sequence in a normal string literal
    axs[i].text(-0.9,-1.5,r'mean_error / $\sigma$ = %.2f'%(mean_error_sigma))

plt.suptitle('Gold Quads fpd recovery')

# DOUBLES: only fpd01 exists, so index 0 is used for all the plots below
plt.figure(dpi=200)
i=0
true_fpd = gt_lenses.lens_df.loc[gold_dbls,'fpd0%d'%(i+1)].to_numpy().astype(float)
pred_fpd = dbls_median_fpd[:,i]
plt.plot([-4.,0.2],[-4.,0.2],color='black')
plt.scatter(true_fpd,pred_fpd,color='goldenrod')
plt.title('Gold Doubles fpd recovery')
plt.xlabel('truth fpd0%d'%(i+1))
plt.ylabel('pred fpd0%d'%(i+1))
mean_error = np.mean(pred_fpd - true_fpd)
mean_error_sigma = np.mean( (pred_fpd - true_fpd) / dbls_sigma_fpd[:,i])
plt.text(-1.5,-3.5,'mean_error = %.3f'%(mean_error))
plt.text(-1.5,-4.,r'mean_error / $\sigma$ = %.2f'%(mean_error_sigma))

# fpd residual vs. true time delay
plt.figure(dpi=200)
true_td = gt_lenses.lens_df.loc[gold_dbls,'td0%d'%(i+1)].to_numpy().astype(float)
plt.errorbar(true_td, pred_fpd - true_fpd, yerr=dbls_sigma_fpd[:, i], fmt='o', color='goldenrod', markersize=5, elinewidth=1)
plt.xlabel('td01 (days)')
plt.ylabel('pred_fpd01 - true_fpd01')
plt.hlines(0.,xmin=-850.,xmax=1.,color='black',zorder=200)
plt.ylim([-1.5,1.5])
plt.title('Gold Doubles')

In [None]:
# where data vectors are stored -- relative to the repository root, matching the relative
# 'DataVectors/...' and 'MassModels/...' paths used elsewhere in this notebook
exp_folder = 'DataVectors/fpd_eval_DV_silverALLDEBIASED'

lens_types = ['gold_quads','gold_dbls','silver_quads','silver_dbls']

inputs_dict = {
    'gold_quads':{},
    'gold_dbls':{},
    'silver_quads':{},
    'silver_dbls':{},
}

# NOTE(review): 'Ddt_Mpc_truth_truth' has a doubled '_truth' suffix -- confirm it matches
# the dataset name actually stored in the .h5 files.
input_keys = ['measured_td','measured_prec','prefactor','fpd_samps','gamma_samps',
    'z_lens_truth','z_src_truth','theta_E_truth','Ddt_Mpc_truth_truth',
    'td01_truth','fpd01_truth','gamma_truth']

for lens_type in lens_types:
    my_filepath = (exp_folder+'/'+lens_type+'.h5')
    # the context manager closes the file even if reading a dataset fails
    with h5py.File(my_filepath, 'r') as h5f:
        for key in input_keys:
            # h5f[key] raises a KeyError naming the dataset if it is missing
            inputs_dict[lens_type][key] = h5f[key][:]

# Per-lens median and half-width of the central ~68% interval (statistics over axis 1).
quads_fpd_samps = inputs_dict['silver_quads']['fpd_samps']
quads_median_fpd = np.median(quads_fpd_samps,axis=1)
low = np.quantile(quads_fpd_samps,q=0.1586,axis=1)
high = np.quantile(quads_fpd_samps,q=0.8413,axis=1)
quads_sigma_fpd = ((high-quads_median_fpd)+(quads_median_fpd-low))/2

dbls_fpd_samps = inputs_dict['silver_dbls']['fpd_samps']
dbls_median_fpd = np.median(dbls_fpd_samps,axis=1)
low = np.quantile(dbls_fpd_samps,q=0.1586,axis=1)
high = np.quantile(dbls_fpd_samps,q=0.8413,axis=1)
dbls_sigma_fpd = ((high-dbls_median_fpd)+(dbls_median_fpd-low))/2

dbls_gamma_samps = inputs_dict['silver_dbls']['gamma_samps']
dbls_median_gamma = np.median(dbls_gamma_samps,axis=1)
low = np.quantile(dbls_gamma_samps,q=0.1586,axis=1)
high = np.quantile(dbls_gamma_samps,q=0.8413,axis=1)
dbls_sigma_gamma = ((high-dbls_median_gamma)+(dbls_median_gamma-low))/2

#  QUADS: one panel per fpd (fpd01, fpd02, fpd03), predicted median vs. catalog truth
fig,axs = plt.subplots(1,3,figsize=(13,4),dpi=200)
for i in range(0,3):
    true_fpd = gt_lenses.lens_df.loc[silver_quads,'fpd0%d'%(i+1)].to_numpy().astype(float)
    pred_fpd = quads_median_fpd[:,i]
    # one-to-one reference line
    axs[i].plot([-1.5,0.2],[-1.5,0.2],color='black')
    axs[i].scatter(true_fpd,pred_fpd,color='silver')
    axs[i].set_title('fpd0%d'%(i+1))
    axs[i].set_xlabel('truth fpd0%d'%(i+1))
    axs[i].set_ylabel('pred fpd0%d'%(i+1))
    mean_error = np.mean(pred_fpd - true_fpd)
    mean_error_sigma = np.mean( (pred_fpd - true_fpd) / quads_sigma_fpd[:,i])
    axs[i].text(-0.9,-1.3,'mean_error = %.3f'%(mean_error))
    # raw string: '\s' is not a valid escape sequence in a normal string literal
    axs[i].text(-0.9,-1.5,r'mean_error / $\sigma$ = %.2f'%(mean_error_sigma))

plt.suptitle('Silver Quads fpd recovery')

# DOUBLES: only fpd01 exists, so index 0 is used for everything below
plt.figure(dpi=200)
i=0
# truth is read from the data vector file here rather than from the catalog
true_fpd = inputs_dict['silver_dbls']['fpd01_truth']
pred_fpd = dbls_median_fpd[:,i]
# shapes must match, otherwise pred_fpd - true_fpd would broadcast
print(true_fpd.shape)
print(pred_fpd.shape)
plt.plot([-5.,0.2],[-5.,0.2],color='black')
plt.scatter(true_fpd,pred_fpd,color='silver')
plt.title('Silver Doubles fpd recovery')
plt.xlabel('truth fpd0%d'%(i+1))
plt.ylabel('pred fpd0%d'%(i+1))
mean_error = np.mean(pred_fpd - true_fpd)
mean_error_sigma = np.mean( (pred_fpd - true_fpd) / dbls_sigma_fpd[:,i])
plt.text(-1.5,-3.5,'mean_error = %.3f'%(mean_error))
plt.text(-1.5,-4.,r'mean_error / $\sigma$ = %.2f'%(mean_error_sigma))

# fpd residual vs. true time delay
plt.figure(dpi=200)
true_td = inputs_dict['silver_dbls']['td01_truth']
plt.errorbar(true_td, pred_fpd - true_fpd, yerr=dbls_sigma_fpd[:, i], fmt='o', color='silver', markersize=5, elinewidth=1)
plt.xlabel('td01 (days)')
plt.ylabel('pred_fpd01 - true_fpd01')
plt.hlines(0.,xmin=-850.,xmax=1.,color='black',zorder=200)
plt.ylim([-1.5,1.5])
plt.title('Silver Doubles')

# fpd residual vs. theta_E, plus the gamma bias on either side of theta_E = 1.2
true_gamma = inputs_dict['silver_dbls']['gamma_truth']
true_theta_E = inputs_dict['silver_dbls']['theta_E_truth']
plt.figure(dpi=200)
plt.errorbar(true_theta_E, pred_fpd - true_fpd, yerr=dbls_sigma_fpd[:, i], fmt='o', color='silver', markersize=5, elinewidth=1)
plt.xlabel('theta_E')
plt.ylabel('pred_fpd01 - true_fpd01')
plt.hlines(0.,xmin=0.4,xmax=2.0,color='black',zorder=200)
plt.ylim([-1.5,1.5])
plt.title('Silver Doubles')

gamma_error = dbls_median_gamma-true_gamma
print('mean error on gamma silver doubles, theta_E>1.2: ', np.mean(gamma_error[true_theta_E>1.2]))
print('median error on gamma silver doubles, theta_E>1.2: ', np.median(gamma_error[true_theta_E>1.2]))

print('mean error on gamma silver doubles, theta_E<1.2: ', np.mean(gamma_error[true_theta_E<1.2]))
print('median error on gamma silver doubles, theta_E<1.2: ', np.median(gamma_error[true_theta_E<1.2]))


# now let's try err(DDt) vs redshift
true_Ddt = inputs_dict['silver_dbls']['Ddt_Mpc_truth_truth']
pred_Ddt = tdc_utils.ddt_from_td_fpd(inputs_dict['silver_dbls']['measured_td'][:,i],pred_fpd)
# signed percent error, shared by the three scatter plots and the outlier search
ddt_pct_error = 100*(pred_Ddt-true_Ddt)/true_Ddt

plt.figure(dpi=200)
plt.scatter(inputs_dict['silver_dbls']['z_lens_truth'],ddt_pct_error,color='silver')
plt.xlabel('z_lens',fontsize=15)
plt.ylabel(r'Ddt \% error (pred-truth)/truth',fontsize=15)
plt.hlines(0.,0.,2.,color='black')
plt.title('Silver Doubles, NPE-Debiased',fontsize=15)

plt.figure(dpi=200)
plt.scatter(inputs_dict['silver_dbls']['z_src_truth'],ddt_pct_error,color='silver')
plt.xlabel('z_src',fontsize=15)
plt.ylabel(r'Ddt \% error (pred-truth)/truth',fontsize=15)
plt.hlines(0.,0.,3.,color='black')
plt.title('Silver Doubles, NPE-Debiased',fontsize=15)


plt.figure(dpi=200)
plt.scatter(np.abs(true_td),ddt_pct_error,color='silver')
plt.xlabel(r'$\Delta t_{01}$ truth',fontsize=15)
plt.ylabel(r'Ddt \% error (pred-truth)/truth',fontsize=15)
plt.hlines(0.,0.,850.,color='black')
plt.title('Silver Doubles, NPE-Debiased',fontsize=15)

# Take |error| BEFORE thresholding. The previous form applied np.abs to the boolean
# comparison, so lenses with an error below -100% were never reported.
print(np.where(np.abs(ddt_pct_error) > 100.)[0])
# NOTE(review): hard-coded indices, presumably copied from an earlier run of the print
# above -- confirm they are still current after regenerating the data vectors.
bad_idx = [  4,  54,  87,  90 ,197 ,298, 324]

### Check chains from make_data_vectors() test ###

In [None]:
# Mean and sample (ddof=1) standard deviation of the true 'gamma' values, per sample.
gold_gamma = gt_lenses.lens_df.loc[gold,'gamma'].to_numpy().astype(float)
mu_mean_gold, std_mean_gold = np.mean(gold_gamma), np.std(gold_gamma,ddof=1)

print('mean gamma of gold sample: ', mu_mean_gold)
print('sigma gamma of gold sample: ', std_mean_gold)

silver_gamma = gt_lenses.lens_df.loc[silver,'gamma'].to_numpy().astype(float)
mu_mean_silver, std_mean_silver = np.mean(silver_gamma), np.std(silver_gamma,ddof=1)

print('mean gamma of silver sample: ', mu_mean_silver)
print('sigma gamma of silver sample: ', std_mean_silver)

In [None]:
# Load the two w0waCDM chains; the context managers close the files even if a read fails.
with h5py.File('DataVectors/fpd_eval_DV_silverALLDEBIASED/silver_doubles_ALL_chain_5e3_w0waCDM.h5', 'r') as h5f:
    silver_doublesALL = h5f['mcmc_chain'][:]

with h5py.File('DataVectors/fpd_eval_DV_silverALLDEBIASED/silver_doubles_REMOVEBAD_chain_5e3_w0waCDM.h5', 'r') as h5f:
    silver_doublesGOOD = h5f['mcmc_chain'][:]

# Truth values for the gamma population parameters of the silver doubles. The gold/silver
# statistics are not recomputed here: they are not used by this cell and are already
# defined by the cell that builds `gold` and `silver`.
silver_dbls_gamma = gt_lenses.lens_df.loc[silver_dbls,'gamma'].to_numpy().astype(float)
mu_mean_silver_dbls = np.mean(silver_dbls_gamma)
std_mean_silver_dbls = np.std(silver_dbls_gamma,ddof=1)

# One entry per chain in each of the lists below.
exp_chains = [silver_doublesALL,silver_doublesGOOD]
exp_names = ['All Silver Doubles: LSST NPE, m_app_src < 24',
    'Bad Removed Silver Doubles: LSST NPE, m_app_src < 24']
burnin = [int(3e3),int(3e3)]
colors = ['indianred','darkcyan']

# Raw strings: '\O', '\m' and '\s' are not valid escapes in normal string literals.
param_labels = [r'$H_0$',r'$\Omega_M$',r'w$_0$',r'w$_a$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$']
num_params = len(param_labels)
truths = [70.,0.3,-1.,0.,mu_mean_silver_dbls,std_mean_silver_dbls]

custom_lines = []
custom_labels = []
# fig=None on the first pass makes corner create the figure; later chains are overlaid on it
figure = None
for i,exp_chain in enumerate(exp_chains):

    # drop the burn-in steps (axis 1) and pool the remaining samples into (N, num_params)
    figure = corner.corner(exp_chain[:,burnin[i]:,:].reshape((-1,num_params)),plot_datapoints=False,
        color=colors[i],levels=[0.68,0.95],fill_contours=True,
        labels=param_labels,
        dpi=200,truths=truths,truth_color='black',
        fig=figure,label_kwargs={'fontsize':30},smooth=0.7,hist_kwargs={'density':True})

    custom_lines.append(Line2D([0], [0], color=colors[i], lw=4))

    # calculate h0 constraint
    h0, h0_sigma = median_sigma_from_samples(exp_chain[:,burnin[i]:,0].reshape((-1,1)),weights=None)
    # construct label ('\n' stays a real newline, so only the LaTeX backslash is escaped)
    custom_labels.append(exp_names[i]+':\n $H_0$=%.2f$\\pm$%.2f'%(h0, h0_sigma))

# the legend goes in the empty top-right panel of the corner grid
axes = np.array(figure.axes).reshape((num_params, num_params))
axes[0,num_params-1].legend(custom_lines,custom_labels,frameon=False,fontsize=20)

# NOTE(review): absolute, user-specific output path -- consider a repo-relative location.
plt.savefig('/Users/smericks/Desktop/forecast_contour.pdf')

# TODO: look into changing sampler to dynesty or some other, faster sampler (nested sampling)
# TODO: move to sherlock and use MPI
# TODO: check silver-only contour (confirm if its wide / constraint is driven by gold)

In [None]:
def HI_medians_table(emcee_chain,param_labels,burnin=1e3):
    """Print the post-burn-in median and ~68% half-width of every chain parameter.

    The first `burnin` entries along axis 1 are discarded and the remaining samples
    are pooled over axes 0 and 1. For each parameter (axis 2) one line is printed
    with the median and half the distance between the 0.1586 and 0.8413 quantiles.

    Args:
        emcee_chain: 3-D array whose axis 1 is trimmed by `burnin` and whose axis 2
            indexes parameters (axis 0 is presumably the walker axis).
        param_labels: sequence with at least one label per parameter; extra
            labels are ignored.
        burnin: number of leading steps to discard; cast to int.

    Raises:
        ValueError: if no samples are left after discarding the burn-in.
    """
    burnin = int(burnin)
    num_params = emcee_chain.shape[2]
    chain = emcee_chain[:,burnin:,:].reshape((-1,num_params))

    # fail loudly instead of printing a table of NaNs
    if chain.shape[0] == 0:
        raise ValueError(
            'burnin=%d leaves no samples: the chain only has %d steps'
            % (burnin, emcee_chain.shape[1]))

    med = np.median(chain,axis=0)
    low = np.quantile(chain,q=0.1586,axis=0)
    high = np.quantile(chain,q=0.8413,axis=0)
    # average of the upper and lower half-widths
    sigma = ((high-med)+(med-low))/2

    for i in range(0,num_params):
        # '\\pm' keeps the LaTeX backslash without relying on an invalid '\p' escape
        print(param_labels[i],': ',med[i],' $\\pm$', sigma[i])
        
# NOTE(review): `gold_chain_bright` is only assigned in a later cell, so this call raises a
# NameError on Restart & Run All. The six labels also suggest a w0waCDM chain was intended
# here (e.g. one of the silver_doubles chains) -- TODO confirm and pass the right chain.
# Raw strings are used because '\O' is not a valid escape in a normal string literal.
HI_medians_table(gold_chain_bright, 
    [r'$H_0$',r'$\Omega_M$',r'w$_0$',r'w$_a$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$'],
    int(4e3))

In [None]:
# Load the LCDM chains; the context managers close the files even if a read fails.
with h5py.File('DataVectors/gold2days_3e3fpd/gold_chain_3e3_LCDM.h5', 'r') as h5f:
    gold_chain = h5f['mcmc_chain'][:]

with h5py.File('DataVectors/gold2days_3e3fpd_bright_src/gold_chain_3e3_LCDM.h5', 'r') as h5f:
    gold_chain_bright = h5f['mcmc_chain'][:]

with h5py.File('DataVectors/gold2days_silver5days_3e3fpd_bright_src/silver_chain_3e3_LCDM.h5', 'r') as h5f:
    silver_chain_bright = h5f['mcmc_chain'][:]

# Truth values for the gamma population parameters (gold sample).
gold_gamma = gt_lenses.lens_df.loc[gold,'gamma'].to_numpy().astype(float)
mu_mean_gold = np.mean(gold_gamma)
std_mean_gold = np.std(gold_gamma,ddof=1)

# One entry per chain in each of the lists below.
exp_chains = [gold_chain_bright,silver_chain_bright]
exp_names = ['250 Gold: NPE fpd, m_app_src < 24.','730 Gold+Silver: NPE fpd, m_app_src < 24.']
burnin = [int(1500),int(1500)]
colors = ['goldenrod','indianred']

# Raw strings: '\O', '\m' and '\s' are not valid escapes in normal string literals.
param_labels = [r'$H_0$',r'$\Omega_M$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$']
num_params = len(param_labels)
truths = [70.,0.3,mu_mean_gold,std_mean_gold]

custom_lines = []
custom_labels = []
# fig=None on the first pass makes corner create the figure; later chains are overlaid on it
figure = None
for i,exp_chain in enumerate(exp_chains):

    # drop the burn-in steps (axis 1) and pool the remaining samples into (N, num_params)
    figure = corner.corner(exp_chain[:,burnin[i]:,:].reshape((-1,num_params)),plot_datapoints=False,
        color=colors[i],levels=[0.68,0.95],fill_contours=True,
        labels=param_labels,
        dpi=300,truths=truths,truth_color='black',
        fig=figure,label_kwargs={'fontsize':24},smooth=0.7)

    custom_lines.append(Line2D([0], [0], color=colors[i], lw=4))

    # calculate h0 constraint
    h0, h0_sigma = median_sigma_from_samples(exp_chain[:,burnin[i]:,0].reshape((-1,1)),weights=None)
    # construct label ('\n' stays a real newline, so only the LaTeX backslash is escaped)
    custom_labels.append(exp_names[i]+':\n $H_0$=%.2f$\\pm$%.2f'%(h0, h0_sigma))

# the legend goes in the empty top-right panel of the corner grid
axes = np.array(figure.axes).reshape((num_params, num_params))
axes[0,num_params-1].legend(custom_lines,custom_labels,frameon=False,fontsize=16)

In [None]:
# Summary table for the bright-source gold LCDM chain, discarding the first 1500 steps.
# Raw strings are used because '\O' is not a valid escape in a normal string literal.
HI_medians_table(gold_chain_bright, 
    [r'$H_0$',r'$\Omega_M$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$'],
    1500)

### Look further into chains (debugging option) ###

In [None]:
def plot_convergence_by_walker(samples_mcmc, param_mcmc, n_walkers, verbose = False):
    """Plot, for each parameter, the per-step median and 16-84 percentile band of a chain.

    Statistics are taken over axis 0 of `samples_mcmc` for every (step, parameter)
    pair; per the original layout note the array is (nwalker, nstep, ndim). Each
    parameter gets its own panel showing the median trace (green), the 16th-84th
    percentile band, and a red line at the median of the median trace over the
    last 10% of the steps.

    Args:
        samples_mcmc: 3-D array; axis 1 is the step axis, axis 2 indexes parameters.
        param_mcmc: sequence of axis labels, one per parameter.
        n_walkers: unused; kept so existing calls keep working.
        verbose: if True, print the median and half the 16-84 percentile width at
            the final step for every parameter.

    Returns:
        The matplotlib figure holding one panel per parameter.
    """
    n_params = samples_mcmc.shape[2]
    n_step = int(samples_mcmc.shape[1])

    # Reduce over axis 0 in one call each, then transpose to (n_params, n_step).
    # This replaces a Python double loop over every (parameter, step) pair.
    median_pos = np.median(samples_mcmc, axis=0).T
    q16_pos = np.percentile(samples_mcmc, 16., axis=0).T
    q84_pos = np.percentile(samples_mcmc, 84., axis=0).T

    fig, ax = plt.subplots(n_params, sharex=True, figsize=(16, 2 * n_params))
    # plt.subplots returns a bare Axes (not an array) when there is a single panel
    if n_params == 1:
        ax = [ax]
    last = n_step
    burnin = int((9.*n_step) / 10.) #get the final value on the last 10% on the chain
    for i in range(n_params):
        if verbose :
            print(param_mcmc[i], '{:.4f} +/- {:.4f}'.format(median_pos[i][last - 1], (q84_pos[i][last - 1] - q16_pos[i][last - 1]) / 2))
        ax[i].plot(median_pos[i][:last], c='g')
        ax[i].axhline(np.median(median_pos[i][burnin:last]), c='r', lw=1)
        ax[i].fill_between(np.arange(last), q84_pos[i][:last], q16_pos[i][:last], alpha=0.4)
        ax[i].set_ylabel(param_mcmc[i], fontsize=10)
        ax[i].set_xlim(0, last)
    return fig

# Convergence traces for the "bad removed" silver doubles chain, skipping the first 100 steps.
# Raw strings are used because '\O' is not a valid escape in a normal string literal.
plot_convergence_by_walker(silver_doublesGOOD[:,100:,:],
    [r'$H_0$',r'$\Omega_M$',r'w$_0$',r'w$_a$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$'],
    20)

## Sanity Checks ##

### Does the assumed training prior match the effective training prior after cuts? ###

In [None]:
# Histogram of the gamma values in each metadata table, overlaid with a normal pdf that
# uses the sample mean and ddof=1 standard deviation.
hst_val = pd.read_csv('MassModels/hst_validation_metadata.csv')
hst_train0 = pd.read_csv('MassModels/hst_train0_metadata.csv')
x_range = np.arange(1.4,2.6,0.01)

for panel_idx,(metadata,panel_title) in enumerate([(hst_val,'Validation'),(hst_train0,'Train 0')]):
    # the first histogram draws on the current figure; later ones open a new figure
    if panel_idx > 0:
        plt.figure()
    gamma_vals = metadata['main_deflector_parameters_gamma'].to_numpy().astype(float)
    plt.hist(gamma_vals,density=True,label='Samples')
    mu_est = np.mean(gamma_vals)
    sigma_est = np.std(gamma_vals,ddof=1)
    # raw string: '\m' and '\s' are not valid escapes in a normal string literal
    plt.plot(x_range,norm.pdf(x_range,mu_est,sigma_est),label=r'$\mu=%.3f$, $\sigma=%.2f$'%(mu_est,sigma_est))
    plt.legend()
    plt.title(panel_title)

# gamma_vals, mu_est and sigma_est now hold the 'Train 0' values used by the code below

from scipy.stats import chisquare
from scipy.stats import gaussian_kde
plt.figure()

# Gaussian draws with the estimated mean/sigma; not used further in this cell.
gaussian_samples = norm.rvs(mu_est,sigma_est,size=10000)

# Let's try a KDE
gamma_kde = gaussian_kde(gamma_vals)
kde_samples = gamma_kde.resample(size=10000)[0]
# shared bin edges spanning both the KDE draws and the training samples
bins=np.histogram(np.hstack((kde_samples,gamma_vals)), bins=40)[1]

counts_exp,_,_ = plt.hist(kde_samples,bins,
            histtype='step',label='KDE Estimate')
counts_obs,_,_ = plt.hist(gamma_vals,bins,histtype='step',
                          label='Training Samples')

# only take bins where counts are greater than 5, then add in an extra bin at beginning and end for the tails
idx = np.where((counts_exp > 5) & (counts_obs > 5))[0]
# tail bins: everything below the first kept bin, and everything above the last kept bin
prepend = np.sum(counts_obs[:idx[0]])
append = np.sum(counts_obs[(idx[-1]+1):])
counts_obs_final = np.concatenate(([prepend],counts_obs[idx],[append]))
prepend = np.sum(counts_exp[:idx[0]])
append = np.sum(counts_exp[(idx[-1]+1):])
counts_exp_final = np.concatenate(([prepend],counts_exp[idx],[append]))

# scipy.stats.chisquare raises a ValueError unless the observed and expected totals agree.
# They differ whenever the number of training samples is not 10000 or an interior bin is
# dropped by the >5 cut, so rescale the expected counts to the observed total.
# When the totals already match the factor is exactly 1 and nothing changes.
counts_exp_final = counts_exp_final*(np.sum(counts_obs_final)/np.sum(counts_exp_final))

chi2_distance = np.sum((counts_obs_final-counts_exp_final)**2/counts_exp_final)
#print(chi2_distance)

chi2,_= chisquare(counts_obs_final,counts_exp_final)
print(chi2)

dof = len(counts_obs_final) - 1
print('chi2/dof:', chi2/dof)
plt.text(1.6,800,r'$\frac{\chi^2}{\nu}$: %.2f'%(chi2/dof),
                        {'fontsize':13})
plt.legend()