In [None]:
import numpy as np
import paltas
from paltas.Substructure.los_dg19 import LOSDG19
from paltas.Substructure.subhalos_dg19 import SubhalosDG19
from paltas.Sources.cosmos import COSMOSIncludeCatalog, COSMOSExcludeCatalog
from paltas.Substructure import nfw_functions
from paltas.MainDeflector.simple_deflectors import PEMDShear
from paltas.Utils.cosmology_utils import get_cosmology, kpc_per_arcsecond
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.colors import LogNorm, SymLogNorm
from matplotlib import colorbar
from matplotlib.lines import Line2D
from astropy.visualization import simple_norm
from astropy.io import fits
from lenstronomy.Util.kernel_util import degrade_kernel
from paltas import generate
from matplotlib.colors import Normalize
from scipy.stats import linregress, norm, uniform
from tqdm import tqdm
import emcee
import pandas as pd
import corner
import numba
import copy
import os

# Directory containing the installed `paltas` package directory. Later cells join
# it with 'datasets/...' and 'paltas/Sources/...' to find bundled data files.
root_path = os.path.dirname(paltas.__path__[0])

# Every figure cell saves into figures/, so make sure it exists on a fresh checkout.
os.makedirs('figures', exist_ok=True)

### __To generate all of the figures in this notebook, you will need to download data.zip [from this Zenodo record](https://zenodo.org/record/6326743#.YiIt_xPML9E) and extract it into this folder.__

### __This notebook is built to work with paltas version 0.0.4. Later versions may not work with this code.__

# Figure 1

In [None]:
# Set up the parameters for our lens and light models
z_lens = 0.5
z_source = 1.5
subhalo_parameters = {'sigma_sub':2e-3,'shmf_plaw_index':-1.87,'m_pivot': 1e10,'m_min': 1e7,'m_max': 1e10,
                      'c_0':18,'conc_zeta':-0.2,'conc_beta':0.8,'conc_m_ref': 1e8,'dex_scatter': 0.1,'k1':0.0, 'k2':0.0}
los_parameters = {'delta_los':1,'m_min':1e7,'m_max':1e10,'z_min':0.01,'dz':0.01,'cone_angle':8.0,'r_min':0.5,'r_max':10.0,'c_0':18,
                  'conc_zeta':-0.2,'conc_beta':0.8,'conc_m_ref': 1e8,'dex_scatter': 0.1,'alpha_dz_factor':5.0}
main_deflector_parameters = {'M200': 2e13, 'z_lens': z_lens,'theta_E':1.3,'center_x':0.0,'center_y':0.0,'gamma': 2.0,'e1':0.05,'e2':0.0,
                             'gamma1':0.0,'gamma2': 0.05,'ra_0':0.0,'dec_0':0.0}
cosmology_parameters = {'cosmology_name': 'planck18'}
output_ab_zeropoint = 25.127
cosmos_folder = root_path + r'/datasets/cosmos/COSMOS_23.5_training_sample/'
source_parameters = {'z_source':z_source,'cosmos_folder':cosmos_folder,'max_z':1.0,'minimum_size_in_pixels':64,
                     'faintest_apparent_mag':20,'smoothing_sigma':0.00,'random_rotation':False,
                     'output_ab_zeropoint':output_ab_zeropoint,'min_flux_radius':10.0,'center_x':0.0,'center_y':0.0,
                     'source_inclusion_list':np.array([882])}

# Set up the parameters for our observation
kwargs_numerics = {'supersampling_factor':2,'supersampling_convolution':True}
kwargs_numerics['point_source_supersampling_factor'] = (kwargs_numerics['supersampling_factor'])
mag_cut = 0.0
add_noise=True
# Read the empirical PSF inside a context manager so the FITS file is closed afterwards.
with fits.open(os.path.join(root_path,'datasets/hst_psf/emp_psf_f814w.fits')) as hdul:
    psf_raw = hdul[0].data[17]
    # Shift the PSF so it has no negative values (its minimum becomes 0), then
    # degrade it by a factor of 2.
    psf_pix_map = degrade_kernel(psf_raw-np.min(psf_raw),2)

# Set up our models
los_class = LOSDG19(los_parameters,main_deflector_parameters,source_parameters,cosmology_parameters)
subhalo_class = SubhalosDG19(subhalo_parameters,main_deflector_parameters,source_parameters,cosmology_parameters)
source_class = COSMOSIncludeCatalog(source_parameters=source_parameters,cosmology_parameters=cosmology_parameters)
main_deflector_class = PEMDShear(main_deflector_parameters=main_deflector_parameters,cosmology_parameters=cosmology_parameters)
cosmo = get_cosmology(cosmology_parameters)
sub_model_list, sub_kwargs_list, sub_z_list = subhalo_class.draw_subhalos()
los_model_list, los_kwargs_list, los_z_list = los_class.draw_los()
source_model_list, source_kwargs_list = source_class.draw_source()

# Generate a fake sample
sample = {'main_deflector_parameters':main_deflector_parameters,'source_parameters':source_parameters,'cosmology_parameters':cosmology_parameters,
          'los_parameters':los_parameters,'subhalo_parameters':subhalo_parameters,
          'psf_parameters':{'psf_type':'PIXEL','kernel_point_source': psf_pix_map,'point_source_supersampling_factor':2
                           },
          'detector_parameters':{'pixel_scale':0.040,'ccd_gain':1.58,'read_noise':3.0,'magnitude_zero_point':output_ab_zeropoint,'exposure_time':1380,
                                 'sky_brightness':21.83,'num_exposures':1,'background_noise':None
                                },
          'drizzle_parameters':{'supersample_pixel_scale':0.020,'output_pixel_scale':0.030,'wcs_distortion':None,
                                'offset_pattern':[(0,0),(0.5,0),(0.0,0.5),(-0.5,-0.5)],'psf_supersample_factor':2
                               }
         }
image,_ = generate.draw_drizzled_image(sample,los_class,subhalo_class,main_deflector_class,source_class,None,None,128,True,
                             kwargs_numerics,mag_cut,add_noise)


def halo_masses_and_positions(kwargs_list,z_list,z_source,cosmo):
    """Convert lenstronomy NFW kwargs back into halo masses and physical positions.

    Args:
        kwargs_list ([dict,...]): lenstronomy NFW kwargs; each dict must contain
            'Rs', 'alpha_Rs', 'center_x' and 'center_y' (angular units).
        z_list ([float,...]): redshift of each halo, aligned with kwargs_list.
        z_source (float): source redshift used in the lenstronomy conversion.
        cosmo: cosmology object as returned by get_cosmology.

    Returns:
        (np.array,np.array): halo masses with shape (n,), and positions with
        shape (n,3) holding (x [kpc], y [kpc], z).
    """
    mass_array = np.zeros(len(kwargs_list))
    pos_array = np.zeros((len(kwargs_list),3))
    for i, (kwargs, z) in enumerate(zip(kwargs_list,z_list)):
        # Convert the lenstronomy parameters to physical NFW parameters, then to mass
        r_scale, rho_nfw = nfw_functions.convert_from_lenstronomy_NFW(kwargs['Rs'],kwargs['alpha_Rs'],z,
                                                                      z_source,cosmo)
        m, _ = nfw_functions.m_c_from_rho_r_scale(rho_nfw,r_scale,cosmo,z)
        kpa = kpc_per_arcsecond(z,cosmo)
        mass_array[i] = m
        pos_array[i] = [kwargs['center_x']*kpa,kwargs['center_y']*kpa,z]
    return mass_array, pos_array


# Convert back to masses and physical positions to generate a plot of our LOS
los_mass_array, los_pos_array = halo_masses_and_positions(los_kwargs_list,los_z_list,z_source,cosmo)
sub_mass_array, sub_pos_array = halo_masses_and_positions(sub_kwargs_list,sub_z_list,z_source,cosmo)

plt.rcParams.update({'font.size': 22})

# Make the scatter plot more informative: marker size grows with log10 mass above m_min
los_marker_size=np.log10(los_mass_array)-np.log10(los_parameters['m_min'])+1e-1
los_marker_size *= 120
los_norm=LogNorm(vmin=los_parameters['m_min']/1e2,vmax=los_parameters['m_max'])

sub_marker_size=np.log10(sub_mass_array)-np.log10(subhalo_parameters['m_min'])+1e-1
sub_marker_size *= 120
sub_norm=LogNorm(vmin=subhalo_parameters['m_min']/1e2,vmax=subhalo_parameters['m_max'])

figsize = (22,12)
f, ax = plt.subplots(1,1, figsize=figsize,dpi=100)
fontsize = 25

# Plot all of our deflectors
elip = Ellipse((z_lens,0),0.05,58,color='grey',zorder=0)
ax.add_patch(elip)
data = ax.scatter(los_pos_array[:,2],los_pos_array[:,0],c=los_mass_array,norm=los_norm,s=los_marker_size,cmap='GnBu',
            label='line-of-sight halos')
ax.scatter(sub_pos_array[:,2],sub_pos_array[:,0],c=sub_mass_array,norm=sub_norm,s=sub_marker_size,cmap='RdPu',
           label='subhalos',marker='^')
ax.set_xlim([-0.35,1.82])
ax.set_ylim([-30,30])

# Create the colorbars: the GnBu bar matches the LOS scatter (los_norm), the
# labelled RdPu bar matches the subhalo scatter (sub_norm).
cbar_ticks = [1e7,1e8,1e9,1e10]
cbar_ax = f.add_axes([0.92, 0.125, 0.02, 0.755])
cbar = colorbar.ColorbarBase(cbar_ax,cmap=plt.get_cmap('GnBu'),norm=los_norm,boundaries=np.logspace(7,10,200))
cbar.set_ticks(cbar_ticks)
cbar.ax.get_yaxis().set_visible(False)
cbar_ax2 = f.add_axes([0.94, 0.125, 0.02, 0.755])
cbar2 = colorbar.ColorbarBase(cbar_ax2,cmap=plt.get_cmap('RdPu'),norm=sub_norm,boundaries=np.logspace(7,10,200))
cbar2.ax.set_ylabel(r'Substructure Mass ($M_\odot$)',rotation=270,labelpad=15,fontsize=fontsize)
cbar2.ax.tick_params(axis='both', labelsize=fontsize)
cbar2.set_ticks(cbar_ticks)

# Plot our source image and the lensed image
image_box_size = 0.2
source_box = f.add_axes([0.786, 0.402, figsize[1]/figsize[0]*image_box_size , image_box_size])
source_box.get_yaxis().set_visible(False)
source_box.get_xaxis().set_visible(False)
source_norm = simple_norm(source_kwargs_list[0]['image'],stretch='asinh')
source_box.imshow(source_kwargs_list[0]['image'],cmap='plasma',norm=source_norm)
image_box = f.add_axes([0.128, 0.402, figsize[1]/figsize[0]*image_box_size , image_box_size])
image_box.get_yaxis().set_visible(False)
image_box.get_xaxis().set_visible(False)
image_norm = simple_norm(image,stretch='asinh')
image_box.imshow(image,cmap='plasma',norm=image_norm)
ax.text(-0.015,-13,'Observational Effects',fontsize=fontsize*1.1,color='white',rotation='vertical',
        bbox=dict(boxstyle="round",ec='k',fc='grey'))
ax.text(-0.19,-13.5,'Observed\nImage',fontsize=fontsize*1.1,color='k',rotation='horizontal',horizontalalignment='center')
ax.text(1.57,-11.5,'Source',fontsize=fontsize*1.1,color='k',rotation='horizontal')

# Draw light rays: the upper (sign=1) and lower (sign=-1) paths are mirror images.
# Each path is two line segments (source -> lens, lens -> observer), each with an
# arrowhead drawn at its midpoint.
kpa_lens = kpc_per_arcsecond(z_lens,cosmo)
kpc_eins = kpa_lens * main_deflector_parameters['theta_E']
arrow_head_props = dict(lw=1,color='k',shrink=0,headwidth=30,headlength=30)
arrow_line_props = dict(lw=1,color='k',shrink=0,headlength=1e-10,headwidth=0)
for sign in (1,-1):
    ax.annotate("",xy=(1.0, sign*(kpc_eins+0.5)/2),xytext=(1.5, sign*0.5),arrowprops=dict(arrow_head_props))
    ax.annotate("",xy=(0.5, sign*kpc_eins),xytext=(1.5, sign*0.5),arrowprops=dict(arrow_line_props))
    ax.annotate("",xy=(0.275, sign*(kpc_eins+0.5)/2),xytext=(0.5, sign*kpc_eins),arrowprops=dict(arrow_head_props))
    ax.annotate("",xy=(0.05, sign*0.5),xytext=(0.5, sign*kpc_eins),arrowprops=dict(arrow_line_props))

# Generate our custom legend
custom_lines = [Line2D([0], [0],color='w', marker='o',ms=fontsize,markerfacecolor=cbar.cmap(200)),
                Line2D([0], [0],color='w', marker='^',ms=fontsize,markerfacecolor=cbar2.cmap(200)),
                Line2D([0], [0],color='w', marker='o',ms=fontsize,markerfacecolor='grey'),
                Line2D([0], [0],color='k', marker='<',ms=fontsize,lw=fontsize/6,markerfacecolor='k')]
ax.legend(custom_lines, ['line-of-sight halos','subhalos','main deflector','light ray'],fontsize=fontsize)
ax.set_ylabel('Physical x (kpc)',fontsize=fontsize)
ax.set_xticks(np.linspace(0,1.5,7))
ax.tick_params(axis='both', labelsize=fontsize)
ax.set_xlabel('z',fontsize=fontsize)

# Keep only the redshift axis visible
ax.get_yaxis().set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_bounds((0, 1.5))

plt.savefig('figures/simulation_summary.pdf',bbox_inches='tight',pad_inches=0.1)
plt.show()

Note that the figure above will not exactly match Figure 1 of the paper because the line-of-sight halos and subhalos are randomly redrawn each time it is run. It should nonetheless capture all of the same features.

# Figure 2

In [None]:
def _catalog_indices(relative_csv_path):
    """Read a one-column csv of COSMOS catalog indices located under root_path."""
    catalog_frame = pd.read_csv(os.path.join(root_path,relative_csv_path),names=['catalog_i'])
    return catalog_frame['catalog_i'].to_numpy()

# Exclude both the known-bad galaxies and the validation galaxies from the draws.
excluded_catalog_i = np.append(_catalog_indices('paltas/Sources/bad_galaxies.csv'),
                               _catalog_indices('paltas/Sources/val_galaxies.csv'))
source_parameters = {'z_source':1.5,'cosmos_folder':cosmos_folder,'max_z':1.0,'minimum_size_in_pixels':64,
                     'faintest_apparent_mag':20,'smoothing_sigma':0.04,'random_rotation':True,
                     'output_ab_zeropoint':output_ab_zeropoint,'min_flux_radius':10.0,'center_x':0.0,'center_y':0.0,
                     'source_exclusion_list':excluded_catalog_i}
CosmosClass = COSMOSExcludeCatalog(source_parameters=source_parameters,cosmology_parameters='planck18')

# Drawn randomly once and then hardcoded for consistency
c_index = [ 2419,  6423, 26844, 36954, 35403, 41070,  8838, 19124, 11472, 40132, 41122, 17043, 18301, 43535, 11367, 35516]
f, ax = plt.subplots(4, 4, figsize=(21.5, 22), sharex=False, sharey=False,gridspec_kw={'hspace': 0.02,'wspace':0.02},dpi=100)
# ax.flat walks the 4x4 grid row by row, one panel per catalog index.
for index, panel in zip(c_index, ax.flat):
    image = CosmosClass.draw_source(index)[1][0]['image']
    # Crop to a square so every panel has the same aspect ratio
    side = np.min(image.shape)
    image = image[:side,:side]
    panel.imshow(image,norm=simple_norm(image,stretch='asinh'),cmap='plasma')
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)

plt.savefig('figures/cosmos_samples.pdf',bbox_inches='tight',pad_inches=0.1)
plt.show()

# Figure 3

Figure 3 combines the plots from Figures 1, 4, and 6 using a vector image editor, so we do not reproduce it here.

# Figure 4

In [None]:
# Load the predictions and truths we need from the data folder.
y_pred = np.load('data/figure4_y_pred.npy')
cov_pred = np.load('data/figure4_cov_pred.npy')
# Per-image, per-parameter standard deviation: square root of each covariance
# matrix's diagonal, stored as float64.
std_pred = np.sqrt(np.diagonal(cov_pred,axis1=1,axis2=2)).astype(np.float64)
images = np.load('data/figure4_images.npy')
y_test = np.load('data/figure4_y_test.npy')

learning_params_print = [r'$\theta_\mathrm{E}$ $[^{\prime\prime}]$',r'$\gamma_1$',r'$\gamma_2$',
                         r'$\gamma_\mathrm{lens}$',r'$e_1$',r'$e_2$',r'$x_\mathrm{lens}$ $[^{\prime\prime}]$',
                         r'$y_\mathrm{lens}$ $[^{\prime\prime}]$',
                         r'$\Sigma_\mathrm{sub}$'+ '\n' + r'$[\mathrm{kpc}^{-2}]$']

samps_for_contour = 1000000
image_index = 20

# Sample the network's Gaussian posterior for a single test image.
posterior_samples=np.random.multivariate_normal(mean=y_pred[image_index],cov=cov_pred[image_index],
                                                size=samps_for_contour)

# Subset of parameters shown in the corner plot (indices into learning_params_print)
keep_indices = [0,1,3,4,6,8]
# rescale sigma_sub for plotting
posterior_samples[:,-1] *= 1e3
truths = np.copy(y_test[image_index])
truths[-1] *= 1e3
corner_param_print = copy.deepcopy(learning_params_print)
corner_param_print[-1]=r'$\Sigma_\mathrm{sub} \times 10^{3}$' + '\n' + r'$[\mathrm{kpc}^{-2}]$'
corner_param_print = np.array(corner_param_print)

# (width, height) in inches
figsize = (12,12)
fontsize = 20

# First create the corner plot
color='#FFAA00'
truth_color = 'k'
hist_kwargs = {'density':True,'color':color,'lw':3}
f = corner.corner(posterior_samples[:,keep_indices],labels=corner_param_print[keep_indices],bins=20,
                  show_titles=False, plot_datapoints=False,label_kwargs=dict(fontsize=fontsize),levels=[0.68,0.95],
                  color=color,fill_contours=True,hist_kwargs=hist_kwargs,title_fmt='.2f',truths=truths[keep_indices],
                  truth_color=truth_color,max_n_ticks=3)

# Resize the figure returned by corner, then restyle the tick labels and axis
# labels on the left column and bottom row of the (row-major) axes grid.
f.set_figwidth(figsize[0])
f.set_figheight(figsize[1])
n_keep = len(keep_indices)
for i in range(n_keep):
    for j in range(n_keep):
        corn_axis = f.axes[i*n_keep+j]
        if j == 0:
            if i > 0:
                corn_axis.tick_params(axis='y', labelsize=fontsize*0.7,labelrotation=75)
                corn_axis.set_ylabel(corner_param_print[keep_indices][i],fontsize=fontsize)
        if i == n_keep-1:
            corn_axis.tick_params(axis='x', labelsize=fontsize*0.7,labelrotation=15)
            corn_axis.set_xlabel(corner_param_print[keep_indices][j],fontsize=fontsize)

# Now add our image
ax = f.add_axes([0.719, 0.73, 0.2, 0.2])
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
image_norm = simple_norm(images[image_index][:,:,0],stretch='asinh')
ax.imshow(images[image_index][:,:,0],cmap='plasma',norm=image_norm)
ax.text(0.5,-0.3,'Observed\nImage',fontsize=fontsize,color='k',rotation='horizontal',horizontalalignment='center',
        transform=ax.transAxes)

handles = [Line2D([0], [0], color=color, lw=10),Line2D([0], [0], color=truth_color, lw=3)]
f.legend(handles,['Network Output','True Value'],loc=(0.7,0.58),fontsize=fontsize)


plt.savefig('figures/corner_plot.pdf',bbox_inches='tight',pad_inches=0.1)
plt.show()

# Figure 5

In [None]:
def _plot_recovery_column(ax_top,ax_res,truth,pred,std,lim_min,lim_max,ticks,tick_labels,
                          res_ylim,res_ticks,res_tick_labels,top_ylabel,res_xlabel,res_ylabel):
    """Draw one predicted-vs-true column of the population figure.

    The top panel contains:
      * a scatter of the network mean against the truth,
      * a y=x reference line,
      * the correlation coefficient,
      * an error bar whose half-length is the mean predicted std.
    The bottom panel scatters the residuals (pred - truth).

    Points in both panels are colored by |pred - truth| / std, using the
    confidence_cmap, sigma_norm and fontsize globals of this cell.
    """
    pull = np.abs((pred-truth)/std)

    # Top panel: prediction vs. truth
    ax_top.scatter(truth,pred,c=pull,cmap=confidence_cmap,norm=sigma_norm)
    diagonal = np.linspace(lim_min,lim_max)
    ax_top.plot(diagonal,diagonal,c='k')
    ax_top.set_xlim((lim_min,lim_max))
    ax_top.set_ylim((lim_min,lim_max))
    ax_top.set_yticks(ticks)
    ax_top.set_yticklabels(tick_labels,fontsize=fontsize,rotation=0)
    ax_top.set_ylabel(top_ylabel,fontsize=fontsize*1.4)

    # Display the correlation coefficient and the average uncertainty
    rval = linregress(truth,pred).rvalue
    ax_top.text(0.75*lim_max+0.25*lim_min,lim_max*0.05+0.95*lim_min,r'$\rho:$' + ' %.3f'%(rval),
                {'fontsize':fontsize})
    ax_top.errorbar(0.41*lim_max+0.59*lim_min,lim_max*0.86+0.14*lim_min,yerr=np.mean(std),
                    c=plt.get_cmap(confidence_cmap).colors[256//2],fmt='.',ms=12,lw=3)
    ax_top.text(0.06*lim_max+0.94*lim_min,lim_max*0.84+0.16*lim_min,r'Mean 68% C.I. :',
                {'fontsize':fontsize})

    # Bottom panel: residuals
    ax_res.scatter(truth,pred-truth,c=pull,cmap=confidence_cmap,norm=sigma_norm)
    ax_res.axhline(0,c='k')
    ax_res.set_ylim(res_ylim)
    ax_res.set_xlabel(res_xlabel,fontsize=fontsize*1.4)
    ax_res.set_xticks(ticks)
    ax_res.set_xticklabels(tick_labels,fontsize=fontsize,rotation=0)
    ax_res.set_yticks(res_ticks)
    ax_res.set_yticklabels(res_tick_labels,fontsize=fontsize,rotation=0)
    ax_res.set_ylabel(res_ylabel,fontsize=fontsize*1.4)


f, ax = plt.subplots(2,2, sharex='col', figsize=(20,10), gridspec_kw={'height_ratios':[3, 1],'hspace':0.00,
                                                                    'wspace':0.35})
confidence_cmap = 'plasma_r'
fontsize = 20

# The sigma_sub pulls set the shared color scale for both columns.
sigma_norms = np.abs((y_pred[:,-1]-y_test[:,-1])/std_pred[:,-1])
sigma_norm = Normalize(vmin=np.min(sigma_norms),vmax=np.max(sigma_norms))

# Left column: sigma_sub
sigma_ticks = [-2e-3,0,2e-3,4e-3,6e-3]
sigma_tick_labels = [r'$-2\times10^{-3}$',r'$0$',r'$2\times10^{-3}$',r'$4\times10^{-3}$',r'$6\times10^{-3}$']
_plot_recovery_column(ax[0,0],ax[1,0],y_test[:,-1],y_pred[:,-1],std_pred[:,-1],-3e-3,7e-3,
                      sigma_ticks,sigma_tick_labels,(-4e-3,4e-3),[-2e-3,0,2e-3],
                      [r'$-2\times10^{-3}$',r'$0$',r'$2\times10^{-3}$'],
                      r'$\Sigma_\mathrm{sub,\mu} \ [\mathrm{kpc}^{-2}]$',
                      r'$\Sigma_\mathrm{sub,true} \ [\mathrm{kpc}^{-2}]$',
                      r'$\Sigma_\mathrm{sub,\mu} - \Sigma_\mathrm{sub,true}$' + '\n' + r'$[\mathrm{kpc}^{-2}]$')

# Right column: theta_E, with limits padded 0.06 arcsec around the true values
theta_ticks = np.linspace(0.6,1.6,6)
_plot_recovery_column(ax[0,1],ax[1,1],y_test[:,0],y_pred[:,0],std_pred[:,0],
                      np.min(y_test[:,0]-6e-2),np.max(y_test[:,0]+6e-2),
                      theta_ticks,['%.1f'%(x) for x in theta_ticks],(-1e-1,1e-1),[-5e-2,0,5e-2],
                      [r'$-0.05$',r'$0$',r'$0.05$'],
                      r'$\theta_\mathrm{E,\mu} \ [^{\prime\prime}]$',
                      r'$\theta_\mathrm{E,true} \ [^{\prime\prime}]$',
                      r'$\theta_\mathrm{E,\mu} - \theta_\mathrm{E,true}$' + '\n' + r'$[^{\prime\prime}]$')

# Create the colorbars
cbar_ticks = [0,1,2,3]
cbar_ax = f.add_axes([0.92, 0.125, 0.02, 0.755])
cbar = colorbar.ColorbarBase(cbar_ax,cmap=plt.get_cmap(confidence_cmap),
                             boundaries=np.linspace(0,3.2,200),norm=sigma_norm)
cbar.set_ticks(cbar_ticks)
cbar.ax.set_ylabel(r'$|\mu-\mathrm{true}|/\sigma$',rotation=90,labelpad=15,fontsize=fontsize*1.4)
cbar.ax.tick_params(axis='both', labelsize=fontsize)

plt.savefig('figures/individual_population.pdf',bbox_inches='tight',pad_inches=0.2)
plt.show()

# Figures 6 and 7

In [None]:
import numba
@numba.njit()
def upsample(image,upsampling):
    """Upsample a 2D image by an integer factor via pixel replication.

    Each input pixel becomes an upsampling x upsampling block of the output.

    Args:
        image (np.array): 2D image of shape (n_rows, n_cols).
        upsampling (int): integer factor applied to both axes.

    Returns:
        (np.array): float64 array of shape (n_rows*upsampling, n_cols*upsampling).
    """
    n_rows, n_cols = image.shape
    new_image = np.zeros((n_rows*upsampling,n_cols*upsampling))
    for i in range(n_rows):
        # Loop over the column count separately so non-square images are filled correctly.
        for j in range(n_cols):
            new_image[i*upsampling:(i+1)*upsampling,j*upsampling:(j+1)*upsampling] = image[i,j]
    return new_image

@numba.njit()
def downsample(image,downsampling):
    """Downsample a 2D image by averaging non-overlapping blocks.

    Trailing rows/columns that do not fill a complete block are dropped.

    Args:
        image (np.array): 2D image of shape (n_rows, n_cols).
        downsampling (int): integer block size applied to both axes.

    Returns:
        (np.array): float64 array of shape (n_rows//downsampling, n_cols//downsampling).
    """
    n_rows = image.shape[0]//downsampling
    n_cols = image.shape[1]//downsampling
    new_image = np.zeros((n_rows,n_cols))
    for i in range(n_rows):
        # Loop over the column count separately so non-square images are filled correctly.
        for j in range(n_cols):
            new_image[i,j] = np.mean(image[i*downsampling:(i+1)*downsampling,j*downsampling:(j+1)*downsampling])
    return new_image

def shmf_draws(sigma_subs,gamma_subs,m_range,m_pivot):
    """Evaluate the power-law subhalo mass function for each (sigma_sub, slope) draw.

    Args:
        sigma_subs (np.array): normalization draws, shape (n_draws,).
        gamma_subs (np.array): power-law slope draws, shape (n_draws,).
        m_range (np.array): masses at which to evaluate, shape (n_masses,).
        m_pivot (float): pivot mass of the power law.

    Returns:
        (np.array): shape (n_draws, n_masses). Values are clipped below at
        1e-16 (e.g. negative sigma_sub draws), so the log10 taken when
        plotting stays finite.
    """
    shmf = sigma_subs[:,None]*1/m_pivot*(m_range/m_pivot)**gamma_subs[:,None]
    # Cut values that are too close to 0
    return np.maximum(shmf,1e-16)


def build_lens_mosaic(images,n_lenses):
    """Tile lens images into the tiered mosaic shown beside the SHMF constraint.

    Layout, in units of the image side length npix (images assumed square):
      * lenses 0-9: 2 rows of 5, each upsampled 10x;
      * lenses 10-99: 9 rows of 10, each upsampled 5x;
      * lenses 100-999: 18 rows of 50 at native resolution.
    Tiles with no lens are left as NaN.

    Args:
        images (np.array): lens images, indexed as images[i,:,:,0].
        n_lenses (int): number of lenses to place (at most 1000 fit).

    Returns:
        (np.array): mosaic of shape (83*npix, 50*npix).
    """
    npix = images.shape[1]
    big = 10*npix
    medium = 5*npix
    mosaic = np.full((83*npix,50*npix),np.nan)

    # Tier 1: the first 10 lenses, 5 per row, upsampled 10x.
    for i in range(min(n_lenses,10)):
        row, col = divmod(i,5)
        mosaic[row*big:(row+1)*big,col*big:(col+1)*big] = upsample(images[i,:,:,0],10)

    # Tiers 2 and 3: lenses 10-49 and 50-99, 10 per row, upsampled 5x.
    for start, offset, count in ((10,2*big,40),(50,2*big+4*medium,50)):
        for i in range(min(n_lenses-start,count)):
            row, col = divmod(i,10)
            mosaic[offset+row*medium:offset+(row+1)*medium,col*medium:(col+1)*medium] = (
                upsample(images[i+start,:,:,0],5))

    # Tier 4: lenses 100-999, 50 per row, at native resolution.
    offset = 65*npix
    for i in range(min(n_lenses-100,900)):
        row, col = divmod(i,50)
        mosaic[offset+row*npix:offset+(row+1)*npix,col*npix:(col+1)*npix] = images[i+100,:,:,0]
    return mosaic


sigma_mean = np.linspace(2e-4,4e-3,20)
n_lenses_list = [10,50,100,1000]
colors = ['#a1dab4','#41b6c4','#2c7fb8','#253494']
burnin = 4000
chains_base = 'data/'

# Plot settings shared by every test set.
# Colored boxes around each lens-count tier of the mosaic, one per n_lenses_list entry.
perm_boxes = [[0.69,0.7,0.205,0.18],[0.69,0.52,0.205,0.36],[0.69,0.295,0.205,0.585],[0.69,0.128,0.205,0.752]]
n_plot_draws = 10000
fontsize = 30
quant_low = 0.05
quant_high = 0.95
spine_width = 7

m_pivot = 1e10
m_min = 1e7
m_max = 1e10
m_range = np.logspace(np.log10(m_min),np.log10(m_max),20)
log_m_range = np.log10(m_range)

for ti in [3,11,5]:
    # Load the images for that test set.
    images_array = np.load(os.path.join(chains_base,'images_test_shift_%d.npy'%(ti+1)))

    # Draws from the population distribution this test set was generated with
    sigma_subs = norm(loc=sigma_mean[ti],scale=1.5e-4).rvs(n_plot_draws)
    gamma_subs = uniform(loc=-1.92,scale=0.1).rvs(n_plot_draws)
    shmf_test = shmf_draws(sigma_subs,gamma_subs,m_range,m_pivot)

    im_norm = simple_norm(images_array,stretch='asinh')

    # Initialize the figure
    f, ax = plt.subplots(1,2, figsize=(22,10),
                     gridspec_kw={'width_ratios':[3, 1.2],'wspace':0.05},dpi=100)

    # Add a boundary for the galaxy occupation distribution. It is drawn once; the
    # alpha equals len(n_lenses_list) stacked alpha=0.1 layers.
    ax[0].axvspan(7.0, 7.93, color='grey', alpha=1-0.9**len(n_lenses_list),zorder=0)
    ax[0].text(7.13,3e-3,'   Potential \n"Dark Halos"',fontsize=fontsize,color='white',
               bbox=dict(boxstyle="round",ec='#969696',fc='#969696'))

    for nli,n_lenses in tqdm(enumerate(n_lenses_list),total=len(n_lenses_list)):

        # Pull up the chains to know what parameter values to sample (read-only:
        # the chain files are inputs and must not be modified).
        chains_path = os.path.join(chains_base,'test_set_%d_lenses_%d.h5'%(ti+1,n_lenses))
        all_chains = emcee.backends.HDFBackend(chains_path,read_only=True).get_chain()[burnin:,:,:]
        all_chains = all_chains.reshape((-1,all_chains.shape[-1]))

        # Draw and evaluate the shmf contour. Column 8 supplies sigma_sub; column 17
        # supplies the log of the scatter added to it.
        sigma_subs = np.random.choice(all_chains[:,8],size=n_plot_draws)
        sigma_subs += np.random.randn(n_plot_draws)*np.exp(np.random.choice(all_chains[:,17],size=n_plot_draws))
        gamma_subs = uniform(loc=-1.92,scale=0.1).rvs(n_plot_draws)
        log_shmf_samples = np.log10(shmf_draws(sigma_subs,gamma_subs,m_range,m_pivot)*m_range)

        # Plot the samples
        ax[0].fill_between(log_m_range,
                 10**np.quantile(log_shmf_samples,quant_low,axis=0),
                 10**np.quantile(log_shmf_samples,quant_high,axis=0),
                 alpha=0.8,color=colors[nli],label='%d Lenses Constraint'%(n_lenses))
        ax_box = f.add_axes(perm_boxes[nli])
        ax_box.get_xaxis().set_visible(False)
        ax_box.get_yaxis().set_visible(False)
        for spine in ax_box.spines.values():
            spine.set_edgecolor(colors[nli])
            spine.set_linewidth(spine_width)
        ax_box.patch.set_alpha(0)

    # Plot the test distribution
    log_shmf_test = np.log10(shmf_test*m_range)
    ax[0].plot(log_m_range,10**np.quantile(log_shmf_test,quant_low,axis=0),color='k',ls='--',lw=2)
    ax[0].plot(log_m_range,10**np.mean(log_shmf_test,axis=0),color='k',ls='-',lw=3,
             label='Test Distribution')
    ax[0].plot(log_m_range,10**np.quantile(log_shmf_test,quant_high,axis=0),color='k',ls='--',lw=2)

    ax[0].set_xlabel(r'$\log_{10} \left(M/M_\odot\right)$',fontsize=fontsize)
    ax[0].set_ylabel(r'$\frac{d^2 N(M)}{d\log (M/M_\odot) \ dA} \ [\mathrm{kpc}^{-2}]$',fontsize=fontsize*1.3)
    ax[0].tick_params(axis='both', labelsize=fontsize)
    ax[0].set_yscale('log')
    ax[0].legend(fontsize=fontsize*0.8,loc=1)
    ax[0].set_xlim(np.min(log_m_range),np.max(log_m_range))
    ax[0].set_ylim(1e-3,1)
    ax[0].set_title('Subhalo Mass Function Constraints for Test Set %d'%(ti+1),fontsize=fontsize)

    # Show every lens used by the largest constraint explicitly, rather than relying
    # on the loop variable left over from the loop above.
    mosaic = build_lens_mosaic(images_array,n_lenses_list[-1])
    ax[1].imshow(mosaic,norm=im_norm,cmap='plasma')
    ax[1].get_xaxis().set_visible(False)
    ax[1].get_yaxis().set_visible(False)

    f.savefig('figures/shmf_constraint_test_%d.pdf'%(ti+1),bbox_inches='tight',pad_inches=0.2)
    plt.show()

# Figure 8

In [None]:
sigma_mean = np.linspace(2e-4,4e-3,20)
n_lenses_list = [10,50,100,1000]
colors = ['#a1dab4','#41b6c4','#2c7fb8','#253494']
fontsize=20
burnin = 4000
lower_quant = 0.16
upper_quant = 0.84
chains_base = 'data/'


def sigma_sub_summary(n_lenses):
    """Summarize the hierarchical sigma_sub chain of every test set for one lens count.

    Args:
        n_lenses (int): number of lenses used in the hierarchical inference.

    Returns:
        (np.array,np.array,np.array): mean, lower-quantile and upper-quantile
        sigma_sub estimates, one entry per test set in sigma_mean. Test sets
        whose post-burn-in chain has at most 1000 samples get the sentinel
        values -10 / -11 / -9, which fall outside the plotted range.
    """
    mean_sigma = np.ones(len(sigma_mean))
    lower_sigma = np.ones(len(sigma_mean))
    upper_sigma = np.ones(len(sigma_mean))
    for ti in range(len(sigma_mean)):
        # Extract the sigma_sub mean estimate from the chains.
        chains_path = os.path.join(chains_base,'test_set_%d_lenses_%d.h5'%(ti+1,n_lenses))
        sigma_sub_chain = emcee.backends.HDFBackend(chains_path,read_only=True).get_chain()[burnin:,:,8].flatten()
        if len(sigma_sub_chain) > 1000:
            mean_sigma[ti] = np.mean(sigma_sub_chain)
            lower_sigma[ti] = np.quantile(sigma_sub_chain,lower_quant)
            upper_sigma[ti] = np.quantile(sigma_sub_chain,upper_quant)
        else:
            mean_sigma[ti] = -10
            lower_sigma[ti] = -11
            upper_sigma[ti] = -9
    return mean_sigma, lower_sigma, upper_sigma


# Read every chain once; the figures below only differ in how many lens counts they show.
sigma_summaries = [sigma_sub_summary(n_lenses) for n_lenses in n_lenses_list]
recovery_label = r'$\Sigma_\mathrm{sub,hier}=\Sigma_\mathrm{sub,pop}$'

# Build the figure incrementally: figure k shows the first k lens counts (k=0 is empty).
for n_lines_plot in range(len(n_lenses_list)+1):
    f, ax = plt.subplots(2,1, sharex='col', sharey='row', figsize=(15,18),
                         gridspec_kw={'height_ratios':[4, 1],'hspace':0.00},dpi=100)

    for li in range(n_lines_plot):
        mean_sigma, lower_sigma, upper_sigma = sigma_summaries[li]
        bars = np.array([mean_sigma-lower_sigma,upper_sigma-mean_sigma])
        # Shift each lens count along the diagonal so the error bars do not overlap,
        # while keeping each point's distance to the y=x line unchanged.
        offset = -6e-5+li*3e-5
        ax[0].errorbar(sigma_mean+offset,mean_sigma+offset,c=colors[li],fmt='.',yerr=bars,
                     label='%d Lenses'%(n_lenses_list[li]),ms=10,lw=3)
        ax[1].errorbar(sigma_mean+offset,mean_sigma-sigma_mean,c=colors[li],fmt='.',yerr=bars,
                     label='%d Lenses'%(n_lenses_list[li]),ms=10,lw=3)

    # Plot some regions for comparison
    ax[0].axvspan(8.6e-4-1.4e-4, 8.6e-4+1.4e-4, color='grey', alpha=0.5)
    ax[1].axvspan(8.6e-4-1.4e-4, 8.6e-4+1.4e-4, color='grey', alpha=0.5)
    ax[0].axvspan(4.8e-4, 1.18e-3, color='grey', alpha=0.5)
    ax[1].axvspan(4.8e-4, 1.18e-3, color='grey', alpha=0.5)
    ax[0].text(8.0e-4,2.5e-3,r'DMO Prediction',fontsize=fontsize*1.1,color='white',
             bbox=dict(boxstyle="round",ec='k',fc='grey',alpha=1.0),rotation='vertical')

    # Plot the perfect recovery line
    sigma_sigma = np.linspace(0,4e-3,100)
    ax[0].plot(sigma_sigma,sigma_sigma,c='k',label=recovery_label)
    ax[1].plot(sigma_sigma,np.zeros(sigma_sigma.shape),c='k',label=recovery_label)

    # Set plot limits and labels
    ax[0].set_xlim(-5e-4,4.3e-3)
    ax[0].set_ylim(-5e-4,4.3e-3)
    ticks = [0,1e-3,2e-3,3e-3,4e-3]
    tick_labels = ['$0$',r'$1\times10^{-3}$',r'$2\times10^{-3}$',r'$3\times10^{-3}$',r'$4\times10^{-3}$']
    ax[0].set_ylabel(r'$\Sigma_\mathrm{sub,hier} \ [\mathrm{kpc}^{-2}]$',fontsize=fontsize*1.4)
    ax[1].set_ylabel(r'$\Sigma_\mathrm{sub,hier}-\Sigma_\mathrm{sub,pop}$' + '\n' + r'$[\mathrm{kpc}^{-2}]$',
                     fontsize=fontsize*1.4)
    ax[1].set_xlabel(r'$\Sigma_\mathrm{sub,pop} \ [\mathrm{kpc}^{-2}]$',fontsize=fontsize*1.4)
    ax[1].set_xticks(ticks)
    ax[1].set_xticklabels(tick_labels,fontsize=fontsize*1.1)
    ax[0].set_yticks(ticks)
    ax[0].set_yticklabels(tick_labels,fontsize=fontsize*1.1)
    ticks = [-5e-4,0,5e-4]
    tick_labels = [r'$-5\times10^{-4}$','$0$',r'$5\times10^{-4}$']
    ax[1].set_yticks(ticks)
    ax[1].set_yticklabels(tick_labels,fontsize=fontsize*1.1)
    ax[0].tick_params(length=7,width=2)
    ax[1].tick_params(length=7,width=2)
    ax[1].set_ylim(-0.6e-3,0.6e-3)
    ax[0].legend(fontsize=fontsize,framealpha=1,loc=4)
    plt.savefig('figures/response_plot_%d.pdf'%(n_lines_plot),bbox_inches='tight',pad_inches=0.2)
    plt.show()