In [2]:
%load_ext autoreload
%autoreload 2
from lens_catalog import OM10LensCatalog
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from Utils.inference_utils import median_sigma_from_samples
import h5py
import pandas as pd
import corner
from matplotlib.lines import Line2D
import Utils.mcmc_utils as mcmc_utils

In [None]:
# Load the OM10 lens catalog; the CSV path is relative to the working directory.
om10_catalog_csv = 'MassModels/om10_sample/om10_venkatraman_erickson24.csv'
gt_lenses = OM10LensCatalog(om10_catalog_csv)

### Define Gold/Silver Samples ###

We split the sample into doubles and quads. We already have a sample of ~30 quads (the STRIDES sample), so we take a conservative assumption of 50 quads in the gold sample. We add 200 doubles to the gold sample.
The remaining lenses are added to the silver sample. 

In [4]:
# Split the catalog into doubles & quads (each index list is fetched once).
dbls = gt_lenses.doubles_indices()
quads = gt_lenses.quads_indices()

# Gold sample: 200 doubles + 50 quads (overall quad fraction is 11%, so quads
# are over-represented here). The rest are silver, regardless of time delay.
NUM_GOLD_DBLS = 200
NUM_GOLD_QUADS = 50

gold_dbls, silver_dbls = dbls[:NUM_GOLD_DBLS], dbls[NUM_GOLD_DBLS:]
gold_quads, silver_quads = quads[:NUM_GOLD_QUADS], quads[NUM_GOLD_QUADS:]

# Slicing truncates silently when the catalog is too small; fail loudly instead
# so the gold sample is never smaller than advertised.
assert len(gold_dbls) == NUM_GOLD_DBLS, 'catalog has fewer doubles than requested for gold'
assert len(gold_quads) == NUM_GOLD_QUADS, 'catalog has fewer quads than requested for gold'

# Combined index arrays: doubles first, then quads.
gold = np.append(gold_dbls, gold_quads)
silver = np.append(silver_dbls, silver_quads)

### Make Data Vectors ###

The `make_data_vectors()` function in `make_data_vectors.py` does the following:

1. Emulates time delay measurements
2. Uses NPE posteriors to compute samples from fermat potential difference posteriors
3. Stores all information in format for fast_tdc()

In [7]:
from make_data_vectors import make_data_vectors

# Load in NPE posteriors. Paths are relative to the working directory, matching
# the other data paths used in this notebook (no machine-specific prefix).
npe_dir = 'MassModels/om10_sample/'

hst_mu = np.load(npe_dir+'y_pred_hst.npy')
hst_cov = np.load(npe_dir+'cov_pred_hst.npy')

lsst_mu = np.load(npe_dir+'y_pred_lsst.npy')
lsst_cov = np.load(npe_dir+'cov_pred_lsst.npy')

# where to store this test
exp_folder = 'gold2days_test'

# One data-vector file per gold sub-sample, written in the order quads, doubles.
# Both use the HST NPE posteriors, td_meas_error=2 and 1000 fermat potential
# difference samples.
gold_subsamples = [
    ('gold_quads.h5', gold_quads, 4),
    ('gold_dbls.h5', gold_dbls, 2),
]
for h5_name, lens_indices, n_images in gold_subsamples:
    make_data_vectors(gt_lenses,lens_indices,num_images=n_images,td_meas_error=2,
        npe_mu=hst_mu,npe_cov=hst_cov,
        h5_save_path=('DataVectors/'+exp_folder+'/'+h5_name),
        num_fpd_samps=1000)

### Check chains from make_data_vectors() test ###

In [None]:
def _load_mcmc_chain(h5_path):
    """Read the 'mcmc_chain' dataset of an HDF5 file fully into memory.

    The file is opened in a `with` block so the handle is released even if
    the read fails.
    """
    with h5py.File(h5_path, 'r') as h5f:
        return h5f['mcmc_chain'][:]


gold_chain1 = _load_mcmc_chain('DataVectors/gold2days_test/gold_chain1000.h5')
gold_chain2 = _load_mcmc_chain('DataVectors/gold2days_silver5days/gold_chain1000.h5')

# Sample mean / std (ddof=1) of gamma over the gold lenses; used as the
# reference values for the last two plotted parameters.
gold_gammas = gt_lenses.lens_df.loc[gold,'gamma'].to_numpy().astype(float)
mu_mean_gold = np.mean(gold_gammas)
std_mean_gold = np.std(gold_gammas,ddof=1)

exp_chains = [gold_chain1,gold_chain2]
exp_names = ['NEW code: Gold: 50 Quads+ 200 Dbls, NPE Models, 1e3 fpd samps', #2-day td error
             'OLD code: Gold: 50 Quads+ 200 Dbls, NPE Models, 1e3 fpd samps',
             # no chain is loaded for the entry below; it is never indexed
             '250Gold+1034Silver, NPE Models, 5-day silver TD error']
burnin = int(1e2)
colors = ['purple','turquoise']

corner_labels = ['$H_0$','w$_0$','w$_a$',r'$\mu(\gamma_{lens})$',r'$\sigma(\gamma_{lens})$']
corner_truths = [70.,-1.,0.,mu_mean_gold,std_mean_gold]
num_params = len(corner_labels)

custom_lines = []
custom_labels = []
# corner.corner creates a new figure when fig is None and overplots otherwise,
# so a single call handles both the first and the subsequent chains.
figure = None
for i,exp_chain in enumerate(exp_chains):

    # drop the burn-in along axis 1, then flatten to (n_samples, n_params)
    flat_samples = exp_chain[:,burnin:,:].reshape((-1,num_params))

    figure = corner.corner(flat_samples,plot_datapoints=False,
        color=colors[i],levels=[0.68,0.95],fill_contours=True,
        labels=corner_labels,
        dpi=200,truths=corner_truths,truth_color='black',
        fig=figure,label_kwargs={'fontsize':17})

    custom_lines.append(Line2D([0], [0], color=colors[i], lw=4))

    # calculate h0 constraint (parameter index 0)
    h0, h0_sigma = median_sigma_from_samples(exp_chain[:,burnin:,0].reshape((-1,1)),weights=None)
    # construct label; the LaTeX part is a raw string so '\p' is not treated as
    # an (invalid) escape sequence, while '\n' stays a real newline
    custom_labels.append(exp_names[i]+':\n '+r'$H_0$=%.2f$\pm$%.2f'%(h0, h0_sigma))

# put the legend in the top-right panel of the corner grid
axes = np.array(figure.axes).reshape((num_params, num_params))
axes[0,num_params-1].legend(custom_lines,custom_labels,frameon=False,fontsize=16);

### Look further into chains (debugging option) ###

In [None]:
# Axis labels for the five sampled parameters, and the reference value for each.
param_labels = np.asarray([
    '$H_0$',
    'w$_0$',
    'w$_a$',
    r'$\mu(\gamma_{lens})$',
    r'$\sigma(\gamma_{lens})$',
])
true_hyperparameters = np.asarray([70., -1., 0., mu_mean_gold, std_mean_gold])

# Run the chain diagnostics on the first gold chain, discarding 100 burn-in steps.
mcmc_utils.analyze_chains(
    gold_chain1, param_labels, true_hyperparameters, 'test.txt',
    show_chains=True, burnin=100,
)

## Sanity Checks ##

### Does the assumed training prior match the effective training prior after cuts? ###

In [None]:
from scipy.stats import chisquare
from scipy.stats import gaussian_kde

# Seed for the KDE draw so the chi2 value is reproducible across runs.
KDE_SEED = 0
# Minimum count a bin needs (in both histograms) to be kept un-merged.
MIN_BIN_COUNT = 5

x_range = np.arange(1.4,2.6,0.01)


def _plot_gamma_with_gaussian(metadata_df, title):
    """Histogram the gamma column and overlay a Gaussian fitted by moments.

    Opens a new figure, plots a density histogram of
    metadata_df['main_deflector_parameters_gamma'] and a normal pdf using the
    sample mean and the sample standard deviation (ddof=1).

    Returns:
        (gamma_vals, mu_est, sigma_est)
    """
    gamma = metadata_df['main_deflector_parameters_gamma'].to_numpy().astype(float)
    mu = np.mean(gamma)
    sigma = np.std(gamma,ddof=1)
    plt.figure()
    plt.hist(gamma,density=True,label='Samples')
    # raw string: '\m' and '\s' are not valid escape sequences
    plt.plot(x_range,norm.pdf(x_range,mu,sigma),label=r'$\mu=%.3f$, $\sigma=%.2f$'%(mu,sigma))
    plt.legend()
    plt.title(title)
    return gamma, mu, sigma


hst_val = pd.read_csv('MassModels/hst_validation_metadata.csv')
gamma_vals_val, mu_est_val, sigma_est_val = _plot_gamma_with_gaussian(hst_val,'Validation')

hst_train0 = pd.read_csv('MassModels/hst_train0_metadata.csv')
gamma_vals, mu_est, sigma_est = _plot_gamma_with_gaussian(hst_train0,'Train 0')

plt.figure()

# Let's try a KDE of the training samples.
gamma_kde = gaussian_kde(gamma_vals)
# Draw as many KDE samples as there are training samples: scipy.stats.chisquare
# requires observed and expected counts to have the same total.
kde_samples = gamma_kde.resample(size=gamma_vals.size,seed=KDE_SEED)[0]
# shared bin edges spanning both samples, so every sample lands in some bin
bins = np.histogram(np.hstack((kde_samples,gamma_vals)), bins=40)[1]

counts_exp,_,_ = plt.hist(kde_samples,bins,
            histtype='step',label='KDE Estimate')
counts_obs,_,_ = plt.hist(gamma_vals,bins,histtype='step',
                          label='Training Samples')

# Keep the contiguous range between the first and last bin where both counts
# exceed MIN_BIN_COUNT, and merge everything outside it into one bin per tail.
# Using a contiguous range (rather than only the qualifying bins) keeps the
# observed and expected totals equal.
idx = np.where((counts_exp > MIN_BIN_COUNT) & (counts_obs > MIN_BIN_COUNT))[0]
if idx.size == 0:
    raise ValueError('no histogram bin has more than %d counts in both samples'%MIN_BIN_COUNT)
first_bin, end_bin = idx[0], idx[-1]+1


def _merge_tail_bins(counts):
    """Collapse the bins below first_bin / from end_bin on into single tail bins."""
    return np.concatenate(([np.sum(counts[:first_bin])],
                           counts[first_bin:end_bin],
                           [np.sum(counts[end_bin:])]))


counts_obs_final = _merge_tail_bins(counts_obs)
counts_exp_final = _merge_tail_bins(counts_exp)

# A tail bin that is empty in both histograms would contribute 0/0; drop it.
nonempty = (counts_obs_final + counts_exp_final) > 0
counts_obs_final = counts_obs_final[nonempty]
counts_exp_final = counts_exp_final[nonempty]

chi2,_= chisquare(counts_obs_final,counts_exp_final)
print(chi2)

dof = len(counts_obs_final) - 1
print('chi2/dof:', chi2/dof)
# place the annotation in axes-fraction coordinates so it stays on the plot
# whatever the histogram heights are
ax = plt.gca()
ax.text(0.05,0.9,r'$\frac{\chi^2}{\nu}$: %.2f'%(chi2/dof),
        fontsize=13,transform=ax.transAxes)
plt.legend();