In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import h5py
import jax_cosmo
import tdc_sampler
from scipy.stats import norm
from lenstronomy.Analysis.kinematics_api import KinematicsAPI
from tdc_utils import jax_kin_distance_ratio
from matplotlib.lines import Line2D
from Utils.inference_utils import median_sigma_from_samples

First, set up the ground-truth kinematics calculation.

In [None]:
# ---- Kinematic observation settings used for the ground-truth sigma_v ----
R_APERTURE = 0.725
PSF_FWHM = 0.5

# Circular ('shell') aperture from r=0 out to R_APERTURE, centered on the lens.
kwargs_aperture = dict(
    aperture_type='shell',
    r_in=0.,
    r_out=R_APERTURE,
    center_ra=0,
    center_dec=0,
)

# Gaussian PSF for the seeing convolution.
kwargs_seeing = dict(psf_type='GAUSSIAN', fwhm=PSF_FWHM)

# Galkin numerical settings.
kwargs_numerics_galkin = dict(
    interpol_grid_num=1000,  # numerical interpolation, should converge -> infinity
    log_integration=True,  # log or linear interpolation of surface brightness and mass models
    max_integrate=100,  # upper bound of numerical integrals
    min_integrate=0.001,  # lower bound of numerical integrals
)

# Power-law ('SPP') mass with a 'SERSIC' lens light profile.
kwargs_model = dict(
    lens_model_list=['SPP'],
    lens_light_model_list=['SERSIC'],
)

anisotropy_model = 'const'

# The leading 0.5 and 2. are fixed fiducial redshifts (z_lens, z_source in
# lenstronomy's KinematicsAPI signature). ground_truth_veldisp() returns the
# matching distance factor so the dispersion can be rescaled per lens.
kinematicsAPI = KinematicsAPI(
    0.5,
    2.,
    kwargs_model,
    kwargs_aperture,
    kwargs_seeing,
    anisotropy_model,
    kwargs_numerics_galkin=kwargs_numerics_galkin,
    lens_model_kinematics_bool=[True, False],
    sampling_number=5000,
    MGE_light=True,
)

# Ground-truth cosmology: h = 0.7, Omega_M = Omega_c + Omega_b = 0.3,
# Omega_k = 0, w0 = -1, wa = 0.
gt_cosmo = jax_cosmo.Cosmology(
    h=jnp.float32(70./100),
    Omega_c=jnp.float32(0.3-0.05),  # "cold dark matter fraction" = Omega_M - Omega_b
    Omega_b=jnp.float32(0.05),  # "baryonic fraction"
    Omega_k=jnp.float32(0.),
    w0=jnp.float32(-1.),
    wa=jnp.float32(0.),
    sigma8=jnp.float32(0.8),
    n_s=jnp.float32(0.96),
)


def ground_truth_veldisp(theta_E,gamma_lens,R_sersic,n_sersic,beta=0.):
    """Compute the ground-truth aperture velocity dispersion for one lens.

    Uses the module-level ``kinematicsAPI`` with an 'SPP' lens (centered at
    the origin) and a 'SERSIC' lens light profile (centered at the origin).

    Args:
        theta_E: Einstein radius, passed to the lens model and to
            ``velocity_dispersion``.
        gamma_lens: power-law slope of the 'SPP' lens model.
        R_sersic: Sersic radius of the lens light; also passed as ``r_eff``.
        n_sersic: Sersic index of the lens light.
        beta: value of the constant anisotropy parameter
            (``anisotropy_model = 'const'``). Defaults to 0.

    Returns:
        Tuple ``(vel_disp_numerical, distance_scaling_factor)``: the output of
        ``kinematicsAPI.velocity_dispersion`` and sqrt(d_s / d_ds) taken from
        the API's own distances. Callers divide the former by the latter to
        obtain c*sqrt(J).
    """
    # sqrt(D_s / D_ds) at the API's fiducial redshifts. This reads the private
    # ``_kwargs_cosmo`` attribute of KinematicsAPI and does not depend on the
    # lens parameters. (The sqrt is already applied here.)
    distance_scaling_factor = np.sqrt(
            kinematicsAPI._kwargs_cosmo['d_s'] / kinematicsAPI._kwargs_cosmo['d_ds'])

    kwargs_anisotropy = {'beta': beta}

    kwargs_lens = [{
        'theta_E':theta_E, 
        'gamma':gamma_lens, 
        "center_x":0., 
        "center_y":0.
    }]

    kwargs_lens_light = [{
        'amp': 10.,
        'R_sersic': R_sersic,
        'n_sersic': n_sersic,
        'center_x': 0.,
        'center_y': 0.,
    }]

    vel_disp_numerical = kinematicsAPI.velocity_dispersion(kwargs_lens, 
        kwargs_lens_light, kwargs_anisotropy, r_eff=R_sersic, theta_E=theta_E)

    return vel_disp_numerical, distance_scaling_factor

Now, load all of the pre-computed data products.

In [None]:
# load in data vector
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
data_vectors_folder = ('/Users/smericks/Desktop/StrongLensing/darkenergy-from-LAGN/'+
    'DataVectors/src_mag_cut_silver_debiased/gold_quads/')

input_keys = ['lens_param_samps','z_lens_truth','z_src_truth',
              'theta_E_truth','gamma_truth','fpd_samps',
            'lens_light_parameters_R_sersic_truth',
            'lens_light_parameters_n_sersic_truth',
            'fpd01_truth','fpd02_truth','fpd03_truth',
            'td01_truth','td02_truth','td03_truth',
            'measured_td','measured_prec']

# The context manager closes the file even if a read fails. Indexing with
# h5f[key] raises a KeyError naming a missing dataset; h5f.get(key) would
# return None and fail later with an opaque TypeError.
with h5py.File(data_vectors_folder + 'gold_quads.h5', 'r') as h5f:
    inputs_dict = {key: h5f[key][:] for key in input_keys}

N_GOLD_LENSES = 10  # lenses processed below (one *_vdisp.npy file each)
N_KIN_SAMPS = 5000  # expected length of each *_vdisp.npy sample array

cJ_samples = np.empty((N_GOLD_LENSES, N_KIN_SAMPS, 1))
sigma_v_truth_list = np.empty(N_GOLD_LENSES)
c_sqrtJ_truth_list = np.empty(N_GOLD_LENSES)
sigma_v_measured = np.empty((N_GOLD_LENSES, 1))
sigma_v_likelihood_prec = np.empty((N_GOLD_LENSES, 1, 1))
sigma_v_measurement_error = 5. # NOTE: assuming 5 km/s measurement error on kinematics...

# Seeded generator so the emulated sigma_v measurements are reproducible
# across Restart & Run All.
SIGMA_V_NOISE_SEED = 0
sigma_v_rng = np.random.default_rng(SIGMA_V_NOISE_SEED)

# fill in c*sqrt(J) samples
for i in range(N_GOLD_LENSES):

    # compute ground truth kinematics
    vdisp_truth, vdisp_dist_scaling = ground_truth_veldisp(inputs_dict['theta_E_truth'][i],
        inputs_dict['gamma_truth'][i],
        inputs_dict['lens_light_parameters_R_sersic_truth'][i],
        inputs_dict['lens_light_parameters_n_sersic_truth'][i])
    # re-scale based on redshift & ground truth cosmology
    cJ_truth = vdisp_truth/vdisp_dist_scaling
    c_sqrtJ_truth_list[i] = cJ_truth
    z_l = inputs_dict['z_lens_truth'][i]
    z_s = inputs_dict['z_src_truth'][i]
    ds_div_dds = jax_kin_distance_ratio(gt_cosmo,z_l,z_s)
    sigma_v_truth = np.sqrt(ds_div_dds)*cJ_truth
    sigma_v_truth_list[i] = sigma_v_truth.item()
    # emulate the measurement from ground truth
    sigma_v_measured[i,0] = norm.rvs(loc=sigma_v_truth_list[i],
        scale=sigma_v_measurement_error, random_state=sigma_v_rng)
    sigma_v_likelihood_prec[i,0,0] = 1/(sigma_v_measurement_error**2)

    # now, track the c*sqrt(J) samples from the mass model
    # (zero-padded index: 'lens000' ... 'lens009', and still correct past 9)
    vdisp_samps_lens = np.load(data_vectors_folder + 'lens%03d_vdisp.npy' % i)
    # NOTE(review): the truth above uses c*sqrt(J) = vdisp / scaling, but the
    # samples here are MULTIPLIED by the scaling. Unless the *_vdisp.npy files
    # were written with the inverse convention, this should be a division --
    # confirm how those files were generated.
    cJ_samples[i,:,0] = vdisp_samps_lens*vdisp_dist_scaling

Now, let's inspect the joint mass-model posterior for a single lens.

In [None]:
lens_idx = 0

# Joint mass-model posterior for one lens, shape (n_samples, 5):
# gamma_lens, the three Fermat potential differences, and c*sqrt(J).
posterior_samps = np.stack((
    inputs_dict['lens_param_samps'][lens_idx,:,3], # gamma_lens
    inputs_dict['fpd_samps'][lens_idx,:,0], # fpd01
    inputs_dict['fpd_samps'][lens_idx,:,1], # fpd02
    inputs_dict['fpd_samps'][lens_idx,:,2], # fpd03
    cJ_samples[lens_idx,:,0]
)).T

truth_params = [
    inputs_dict['gamma_truth'][lens_idx],
    inputs_dict['fpd01_truth'][lens_idx],
    inputs_dict['fpd02_truth'][lens_idx],
    inputs_dict['fpd03_truth'][lens_idx],
    c_sqrtJ_truth_list[lens_idx]
]


import corner

# Raw strings: '\g', '\D', '\p', '\s', '\m' are invalid escape sequences in
# normal string literals (SyntaxWarning on Python 3.12+).
posterior_labels = [r'$\gamma_{lens}$', r'$\Delta \phi_{01}$', r'$\Delta \phi_{02}$',
                    r'$\Delta \phi_{03}$', r'$c\sqrt{\mathcal{J}}$ (km/s)']

figure = corner.corner(posterior_samps,plot_datapoints=False,
            color='goldenrod',levels=[0.68,0.95],fill_contours=True,
            labels=posterior_labels,
            dpi=200,truths=truth_params,truth_color='black',
            fig=None,label_kwargs={'fontsize':20},smooth=0.7,hist_kwargs={'density':True})

# TODO: what happens if we Gaussianize?

print('posterior samps shape: ', posterior_samps.shape)

from scipy.stats import multivariate_normal
def gaussianized_samples(posterior_samps, size=5000, random_state=None):
    """Replace posterior samples by draws from a moment-matched Gaussian.

    Args:
        posterior_samps: array of shape (n_samples, n_params); columns are
            parameters (``rowvar=False``).
        size: number of Gaussian draws to return. Defaults to 5000,
            independent of n_samples.
        random_state: seed or numpy Generator forwarded to
            ``multivariate_normal.rvs`` for reproducible draws. Defaults to
            None (global numpy RNG).

    Returns:
        Draws from N(mean, cov) of the input columns, shape (size, n_params).
        scipy squeezes singleton dimensions, so size=1 or n_params=1 give a
        1-D result.
    """
    Mu = np.mean(posterior_samps, axis=0)
    Cov = np.cov(posterior_samps, rowvar=False)

    return multivariate_normal.rvs(mean=Mu, cov=Cov, size=size,
                                   random_state=random_state)

gaussian_posterior_samps = gaussianized_samples(posterior_samps)
gaussian_posterior_samps500 = gaussianized_samples(posterior_samps[:500])

# Raw strings avoid invalid escape sequences ('\g', '\D', '\s', ...).
gaussianized_labels = [r'$\gamma_{lens}$', r'$\Delta \phi_{01}$', r'$\Delta \phi_{02}$',
                       r'$\Delta \phi_{03}$', r'$c\sqrt{\mathcal{J}}$ (km/s)']

# Overlay both Gaussianized posteriors on the original-samples figure.
for overlay_samps, overlay_color in ((gaussian_posterior_samps, 'maroon'),
                                     (gaussian_posterior_samps500, 'turquoise')):
    corner.corner(overlay_samps,plot_datapoints=False,
            color=overlay_color,levels=[0.68,0.95],fill_contours=True,
            labels=gaussianized_labels,
            dpi=200,truths=truth_params,truth_color='black',
            fig=figure,label_kwargs={'fontsize':20},smooth=0.7,hist_kwargs={'density':True})

custom_lines = [
    Line2D([0], [0], color='goldenrod', lw=4),
    Line2D([0], [0], color='maroon', lw=4),
    Line2D([0], [0], color='turquoise', lw=4)
]

custom_labels = [
    'Original Samples','Gaussianized Samples','Gaussianized from 500'
]

# The corner grid is n_dim x n_dim, where n_dim is the number of columns.
n_dim = posterior_samps.shape[1]
axes = np.array(figure.axes).reshape((n_dim, n_dim))
axes[0,n_dim-1].legend(custom_lines,custom_labels,frameon=False,fontsize=16)
plt.suptitle('Lens %d'%(lens_idx));

Now, let's set up the hierarchical inference.

In [None]:
# Time-delay + kinematics likelihood over the first 10 gold quads, using the
# 'LCDM_lambda_int' model (checks that inclusion of lambda_int works).
quad_kin_lklhd = tdc_sampler.TDCKinLikelihood(
    # time-delay measurements
    td_measured=inputs_dict['measured_td'][0:10],
    td_likelihood_prec=inputs_dict['measured_prec'][0:10],
    # emulated velocity-dispersion measurements
    sigma_v_measured=sigma_v_measured,
    sigma_v_likelihood_prec=sigma_v_likelihood_prec,
    # mass-model posterior samples
    fpd_samples=inputs_dict['fpd_samps'][0:10],
    gamma_pred_samples=inputs_dict['lens_param_samps'][0:10, :, 3],
    kin_pred_samples=cJ_samples,
    # redshifts
    z_lens=inputs_dict['z_lens_truth'][0:10],
    z_src=inputs_dict['z_src_truth'][0:10],
    cosmo_model='LCDM_lambda_int',
)

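Repeat the setup with Gaussianized chains: each lens's joint mass-model posterior is replaced by draws from a multivariate Gaussian with the same mean and covariance.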

In [None]:
from scipy.stats import multivariate_normal
def gaussianized_samples(posterior_samps, size=5000, random_state=None):
    """Replace posterior samples by draws from a moment-matched Gaussian.

    NOTE: this re-defines the helper from the posterior-inspection cell above
    so that this cell can run on its own. Keep the two definitions in sync, or
    move the helper into a module and import it.

    Args:
        posterior_samps: array of shape (n_samples, n_params); columns are
            parameters (``rowvar=False``).
        size: number of Gaussian draws to return. Defaults to 5000,
            independent of n_samples.
        random_state: seed or numpy Generator forwarded to
            ``multivariate_normal.rvs`` for reproducible draws. Defaults to
            None (global numpy RNG).

    Returns:
        Draws from N(mean, cov) of the input columns, shape (size, n_params).
        scipy squeezes singleton dimensions, so size=1 or n_params=1 give a
        1-D result.
    """
    Mu = np.mean(posterior_samps, axis=0)
    Cov = np.cov(posterior_samps, rowvar=False)

    return multivariate_normal.rvs(mean=Mu, cov=Cov, size=size,
                                   random_state=random_state)

# Containers for the Gaussianized mass-model samples of the 10 gold lenses.
fpd_samples_gaussianized = np.empty((10, 5000, 3))
gamma_pred_samples_gaussianized = np.empty((10, 5000))
cJ_samples_gaussianized = np.empty((10, 5000, 1))

for lens_idx in range(10):
    lens_params = inputs_dict['lens_param_samps'][lens_idx]
    fpds = inputs_dict['fpd_samps'][lens_idx]

    # Columns: gamma_lens, fpd01, fpd02, fpd03, c*sqrt(J).
    posterior_samps = np.stack((
        lens_params[:, 3],
        fpds[:, 0],
        fpds[:, 1],
        fpds[:, 2],
        cJ_samples[lens_idx, :, 0],
    )).T

    gaussianized_samps = gaussianized_samples(posterior_samps)

    # Unpack the columns back into the per-quantity containers.
    gamma_pred_samples_gaussianized[lens_idx] = gaussianized_samps[:, 0]
    fpd_samples_gaussianized[lens_idx] = gaussianized_samps[:, 1:4]
    cJ_samples_gaussianized[lens_idx, :, 0] = gaussianized_samps[:, 4]

# Same likelihood as the non-Gaussianized case, fed with the Gaussian draws.
quad_kin_lklhd_gaussianized_samps = tdc_sampler.TDCKinLikelihood(
    td_measured=inputs_dict['measured_td'][0:10],
    td_likelihood_prec=inputs_dict['measured_prec'][0:10],
    sigma_v_measured=sigma_v_measured,
    sigma_v_likelihood_prec=sigma_v_likelihood_prec,
    fpd_samples=fpd_samples_gaussianized,
    gamma_pred_samples=gamma_pred_samples_gaussianized,
    kin_pred_samples=cJ_samples_gaussianized,
    z_lens=inputs_dict['z_lens_truth'][0:10],
    z_src=inputs_dict['z_src_truth'][0:10],
    cosmo_model='LCDM_lambda_int',
)

In [None]:
# Run the sampler on the likelihood built from the original mass-model chains.
imp_sampling_chain = tdc_sampler.fast_TDC(
    [quad_kin_lklhd],
    num_emcee_samps=5000,
    n_walkers=20,
)

In [None]:
# Run the sampler on the likelihood built from the Gaussianized chains.
gaussianized_imp_sampling_chain = tdc_sampler.fast_TDC(
    [quad_kin_lklhd_gaussianized_samps],
    num_emcee_samps=5000,
    n_walkers=20,
)

In [None]:
def plot_convergence_by_walker(samples_mcmc, param_mcmc, n_walkers, verbose = False):
    """Plot, per parameter, the across-walker median and 16-84% band vs. step.

    Args:
        samples_mcmc: array of shape (n_walkers, n_step, n_params); the
            statistics at each step are taken across axis 0.
        param_mcmc: sequence of n_params axis labels.
        n_walkers: unused; kept for backward compatibility.
        verbose: if True, print each parameter's final-step median
            +/- half the 16-84% width.

    Returns:
        The matplotlib Figure, with one row of axes per parameter.
    """
    n_params = samples_mcmc.shape[2]
    n_step = int(samples_mcmc.shape[1])

    # Vectorized across steps and parameters; each result is (n_params, n_step).
    median_pos = np.median(samples_mcmc, axis=0).T
    q16_pos, q84_pos = np.percentile(samples_mcmc, [16., 84.], axis=0).transpose(0, 2, 1)

    fig, ax = plt.subplots(n_params, sharex=True, figsize=(16, 2 * n_params))
    if n_params == 1:
        ax = [ax]
    # Red reference line: median of the per-step medians over the last 10% of the chain.
    burnin = int((9. * n_step) / 10.)
    steps = np.arange(n_step)
    for i in range(n_params):
        if verbose:
            half_width = (q84_pos[i, -1] - q16_pos[i, -1]) / 2
            print(param_mcmc[i], '{:.4f} +/- {:.4f}'.format(median_pos[i, -1], half_width))
        ax[i].plot(median_pos[i], c='g')
        ax[i].axhline(np.median(median_pos[i, burnin:]), c='r', lw=1)
        ax[i].fill_between(steps, q84_pos[i], q16_pos[i], alpha=0.4)
        ax[i].set_ylabel(param_mcmc[i], fontsize=10)
        ax[i].set_xlim(0, n_step)
    return fig

# `test_chain` is never defined in this notebook, so this call raised a
# NameError on a fresh kernel. Plot the importance-sampling chain computed
# above, transposed the same way as in the hyper-parameter corner-plot cell.
# Raw strings avoid invalid escape sequences such as '\O' and '\m'.
convergence_fig = plot_convergence_by_walker(
    np.transpose(imp_sampling_chain, axes=(1, 0, 2)),
    [r'$H_0$', r'$\Omega_M$',
     r'$\mu(\lambda_{int})$', r'$\sigma(\lambda_{int})$',
     r'$\mu(\gamma_{lens})$', r'$\sigma(\gamma_{lens})$'],  # 'w$_0$','w$_a$'
    20)

In [None]:
# Truth values for the gamma_lens population hyper-parameters.
# NOTE(review): these use every lens in the file, while the inference above
# uses only the first 10 -- confirm the file holds exactly 10 lenses.
mu_mean_gold = np.mean(inputs_dict['gamma_truth'])
std_mean_gold = np.std(inputs_dict['gamma_truth'])


exp_chains = [np.transpose(imp_sampling_chain,axes=(1,0,2)),
              np.transpose(gaussianized_imp_sampling_chain,axes=(1,0,2))]
exp_names = ['10 Gold with Kinematics',
             'Gaussianized 10 Gold with Kinematics']
burnin = [1000, 1000]
colors = ['goldenrod', 'indianred']

# Raw strings avoid invalid escape sequences ('\O', '\m', '\l', '\s', '\g').
hyperparam_labels = [r'$H_0$', r'$\Omega_M$',
                     r'$\mu(\lambda_{int})$', r'$\sigma(\lambda_{int})$',
                     r'$\mu(\gamma_{lens})$', r'$\sigma(\gamma_{lens})$']
hyperparam_truths = [70., 0.3, 1., 0., mu_mean_gold, std_mean_gold]

custom_lines = []
custom_labels = []
figure = None
for i,exp_chain in enumerate(exp_chains):

    num_params = exp_chain.shape[2]
    # Drop burn-in steps and flatten (walkers, steps) into one sample set.
    flat_samps = exp_chain[:,burnin[i]:,:].reshape((-1,num_params))

    # The first chain creates the figure (fig=None); later chains overlay on it.
    figure = corner.corner(flat_samps,plot_datapoints=False,
        color=colors[i],levels=[0.68,0.95],fill_contours=True,
        labels=hyperparam_labels,
        dpi=300,truths=hyperparam_truths,truth_color='black',
        fig=figure,label_kwargs={'fontsize':24},smooth=0.7)

    custom_lines.append(Line2D([0], [0], color=colors[i], lw=4))

    # calculate h0 constraint
    h0, h0_sigma = median_sigma_from_samples(exp_chain[:,burnin[i]:,0].reshape((-1,1)),weights=None)
    # construct label ('\n' needs a normal string; the LaTeX part is raw)
    custom_labels.append(exp_names[i] + ':\n' + (r' $H_0$=%.2f$\pm$%.2f' % (h0, h0_sigma)))

# NOTE(review): stale axis-limit snippet kept for reference; it assumes a 3x3
# grid, but the current model has 6 hyper-parameters.
# axes = np.array(figure.axes).reshape((3, 3))
# bounds = [[63,77],[1.91,2.095],[0.0,0.2]]
# for r in range(0,3):
#         for c in range(0,r+1):
#             if bounds is not None:
#                 axes[r,c].set_xlim(bounds[c])
#                 if r != c :
#                     axes[r,c].set_ylim(bounds[r])

axes = np.array(figure.axes).reshape((num_params, num_params))
axes[0,num_params-1].legend(custom_lines,custom_labels,frameon=False,fontsize=16);

Now, let's build the inputs for hierArc so we can compare the results.

In [None]:
# Ground Truth Cosmology
# (Same values as the first configuration cell, re-declared so this cell is
# self-contained.)
gt_cosmo = jax_cosmo.Cosmology(h=jnp.float32(70./100),
                        Omega_c=jnp.float32(0.3-0.05), # "cold dark matter fraction", OmegaM = 0.3
                        Omega_b=jnp.float32(0.05), # "baryonic fraction"
                        Omega_k=jnp.float32(0.),
                        w0=jnp.float32(-1.),
                        wa=jnp.float32(0.),
                        sigma8 = jnp.float32(0.8), n_s=jnp.float32(0.96))

# Kinematic Settings
R_APERTURE = 0.725
PSF_FWHM = 0.5
kwargs_aperture = {
    'aperture_type': 'shell', 
    'r_in': 0., 
    'r_out': R_APERTURE,
    'center_ra': 0, 'center_dec': 0}

kwargs_seeing = {'psf_type': 'GAUSSIAN', 'fwhm': PSF_FWHM}

kwargs_numerics_galkin = { 
    'interpol_grid_num': 1000,  # numerical interpolation, should converge -> infinity
    'log_integration': True,  # log or linear interpolation of surface brightness and mass models
    'max_integrate': 100, 'min_integrate': 0.001}  # lower/upper bound of numerical integrals

kwargs_model = {
    'lens_model_list':['SPP'],
    'lens_light_model_list':['SERSIC']
}

anisotropy_model = 'const'
# Gaussian prior on the constant anisotropy parameter beta.
beta_prior = norm(loc=0.,scale=0.1)

# Dump relevant data vectors into a file
# Per-lens Gaussian summaries (mean, sample std with ddof=1) of the
# lens_param_samps columns 0 (theta_E) and 3 (gamma) for the first 10 lenses.
mu_theta_E = np.mean(inputs_dict['lens_param_samps'][:10,:,0],axis=1)
sigma_theta_E = np.std(inputs_dict['lens_param_samps'][:10,:,0],axis=1,ddof=1)
mu_gamma = np.mean(inputs_dict['lens_param_samps'][:10,:,3],axis=1)
sigma_gamma = np.std(inputs_dict['lens_param_samps'][:10,:,3],axis=1,ddof=1)
R_sersic_truth = inputs_dict['lens_light_parameters_R_sersic_truth'][:10]
n_sersic_truth = inputs_dict['lens_light_parameters_n_sersic_truth'][:10]
# sigma_v_measured and sigma_v_likelihood_prec (computed in the data-loading
# cell) are reused as-is. The bare, no-op expressions for them have been removed.

# Ddt has to be constructed with MCMC
ddt_posterior_chains = []
for ddt_l in range(0,10):
    ddt_chain = tdc_sampler.TDCLikelihood.ddt_posterior_from_td_fpd(
        inputs_dict['measured_td'][ddt_l],
        inputs_dict['measured_prec'][ddt_l],
        inputs_dict['fpd_samps'][ddt_l],
        num_emcee_samps=10000
    )
    ddt_posterior_chains.append(ddt_chain)