In [None]:
import paltas
import jaxstronomy
import jax
import jax.numpy as jnp
import numpy as np
import functools
import matplotlib.pyplot as plt
import copy

The paltas and jaxstronomy codebases perform the same operations, so their outputs should agree and can be compared directly.

There are three components to an update step for our model:

1. Drawing a population of dark matter substructure for the lensing calculations.
2. Generating a lensing image given that dark matter substructure population.
3. Updating the model given a batch of images.

Below we write code for each step that should generate the same outputs in both codebases (up to random-seed differences for the first step), so that we can
compare them.

## Drawing a population of dark matter substructure for the lensing calculations.

In [None]:
# Build the paltas config handler (from a config file) and the matching
# jaxstronomy config dictionary for the substructure comparison.
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas.py')

from comparison_files import input_config_jaxstronomy
# Root PRNG key for the jaxstronomy draws in the following cells.
rng = jax.random.PRNGKey(0)
input_config_jax = input_config_jaxstronomy.get_config()

In [None]:
# jaxstronomy setup. Every subhalo, main-deflector and source distribution in
# the lensing config is encoded as a constant, so draw_sample just unpacks the
# constant values. Sharing `rng` across the three calls is therefore harmless.
subhalo_params = jaxstronomy.input_pipeline.draw_sample(
    input_config_jax['lensing_config']['subhalo_params'], rng)
main_deflector_params = jaxstronomy.input_pipeline.draw_sample(
    input_config_jax['lensing_config']['main_deflector_params'], rng)
source_params = jaxstronomy.input_pipeline.draw_sample(
    input_config_jax['lensing_config']['source_params'], rng)

# Cosmology parameters used by the cosmology calculations inside the subhalo draws.
cosmology_params = jaxstronomy.input_pipeline.intialize_cosmology_params(input_config_jax, rng)

# The pad lengths are bound with functools.partial, so they are fixed Python
# values rather than traced arguments of the jitted function.
subhalos_pad_length = 1000
sampling_pad_length = 100000
draw_subhalos_jit = jax.jit(
    functools.partial(
        jaxstronomy.subhalos.draw_subhalos,
        subhalos_pad_length=subhalos_pad_length,
        sampling_pad_length=sampling_pad_length,
    )
)

# Call once up front so jit compilation happens here, not in the first loop iteration below.
_ = draw_subhalos_jit(main_deflector_params, source_params, subhalo_params, cosmology_params, rng)

In [None]:
# Draw n_draws subhalo populations from paltas and from jaxstronomy. For each
# draw, store the number of subhalos and the per-subhalo keyword arguments.
n_draws = 1000

# NOTE(review): no seed is set in this notebook for the paltas draws. They
# presumably use NumPy's global RNG, so TODO: confirm and seed for reproducibility.
paltas_kwargs_list = []
paltas_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    paltas_models, paltas_kwargs, paltas_z = config_handler.subhalo_class.draw_subhalos()
    paltas_n_subhalos[i] = len(paltas_models)
    paltas_kwargs_list.append(paltas_kwargs)

# Each JAX key is used exactly once. The carried key `rng_draw` is only ever
# split, and the fresh sub-key `rng_subhalos` is only ever consumed by the draw.
# The original loop drew with rng_draw and then split that same key on the next
# iteration, which reuses a consumed key.
rng_draw, _ = jax.random.split(rng)
jaxstronomy_kwargs_list = []
jaxstronomy_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    rng_draw, rng_subhalos = jax.random.split(rng_draw)
    jaxstronomy_z_list, jaxstronomy_kwargs = draw_subhalos_jit(
        main_deflector_params, source_params, subhalo_params, cosmology_params, rng_subhalos)
    # Outputs have a fixed padded length. Entries with alpha_rs > 0 are counted
    # as real subhalos; the rest are treated as padding.
    jaxstronomy_n_subhalos[i] = jnp.sum(jaxstronomy_kwargs['alpha_rs'] > 0.0)
    jaxstronomy_kwargs_list.append(jaxstronomy_kwargs)

In [None]:
# Start by comparing the number of subhalos per draw and make sure the two
# distributions look similar. The bin edges span BOTH samples. Edges derived
# from the paltas sample alone would silently drop any jaxstronomy counts
# outside the paltas range, hiding exactly the discrepancies we look for.
bins = np.histogram_bin_edges(
    np.concatenate([paltas_n_subhalos, jaxstronomy_n_subhalos]), bins=30)
plt.hist(paltas_n_subhalos, bins=bins, histtype='step', lw=3)
plt.hist(jaxstronomy_n_subhalos, bins=bins, histtype='step', lw=3)
plt.xlabel('Number of Subhalos')
plt.ylabel('Count')
plt.legend(['paltas', 'jaxstronomy'])
plt.show()

In [None]:
# Now compare the distribution of the parameters.
def extract_np_array_paltas(param, list_of_draws):
    """Flatten one subhalo parameter out of a list of paltas draws.

    Args:
        param: Key read from every subhalo's kwargs dict (e.g. 'alpha_Rs').
        list_of_draws: Sequence of draws. Each draw is an iterable of
            per-subhalo kwargs dicts, as collected from paltas draw_subhalos.

    Returns:
        numpy array of `param` for every subhalo of every draw, in draw order.
        It is one-dimensional when the per-subhalo values are scalars.
    """
    return np.array([subhalo[param] for draw in list_of_draws for subhalo in draw])

def extract_np_array_jaxstronomy(param, list_of_draws):
    """Concatenate one parameter array across a list of jaxstronomy draws.

    Args:
        param: Key of the per-draw kwargs dict to read (e.g. 'alpha_rs').
        list_of_draws: Sequence of kwargs dicts, one per draw, whose `param`
            entries are arrays.

    Returns:
        numpy array holding the `param` arrays of all draws joined end to end,
        padding entries included.

    Raises:
        ValueError: from np.concatenate when `list_of_draws` is empty.
    """
    return np.concatenate([draw[param] for draw in list_of_draws])

def extract_without_zeros_jaxstronomy(param, list_of_draws, key):
    """Concatenate `param` across jaxstronomy draws, keeping entries where `key` > 0.

    Args:
        param: Key of the per-draw kwargs dict to read.
        list_of_draws: Sequence of kwargs dicts, one per draw.
        key: Array aligned element-wise with the concatenated `param` values.
            In this notebook it is the concatenated 'alpha_rs' values, whose
            non-positive entries mark padding. Despite its name, it is a mask
            source, not a dictionary key.

    Returns:
        The concatenated `param` values at positions where `key > 0.0`.
    """
    values = extract_np_array_jaxstronomy(param, list_of_draws)
    keep = key > 0.0
    return values[keep]

# jaxstronomy pads its subhalo arrays to a fixed length. Entries with
# alpha_rs <= 0 are treated as padding and removed before comparing distributions.
alpha_rs_jax = extract_np_array_jaxstronomy('alpha_rs', jaxstronomy_kwargs_list)

# Each entry is (paltas kwarg, jaxstronomy kwarg, x-axis label, plot natural log of the values?).
parameter_comparisons = [
    ('alpha_Rs', 'alpha_rs', r'Mass Proxy $\ln(\alpha_{Rs})$', True),
    ('Rs', 'scale_radius', r'Size Proxy $\ln(R_s)$', True),
    ('r_trunc', 'trunc_radius', r'Truncation Radius $\ln(r_{\mathrm{trunc}})$', True),
    ('center_x', 'center_x', r'Center x', False),
    ('center_y', 'center_y', r'Center y', False),
]

for paltas_param, jax_param, x_label, take_log in parameter_comparisons:
    paltas_values = extract_np_array_paltas(paltas_param, paltas_kwargs_list)
    jax_values = extract_without_zeros_jaxstronomy(jax_param, jaxstronomy_kwargs_list, alpha_rs_jax)
    if take_log:
        paltas_values, jax_values = np.log(paltas_values), np.log(jax_values)
    # Bin edges span BOTH samples. Edges from paltas alone would silently drop
    # jaxstronomy values outside the paltas range.
    bins = np.histogram_bin_edges(np.concatenate([paltas_values, jax_values]), bins=100)
    plt.hist(paltas_values, bins=bins, histtype='step', log=True, lw=3)
    plt.hist(jax_values, bins=bins, histtype='step', log=True, lw=3)
    plt.xlabel(x_label)
    plt.ylabel('Count')
    plt.legend(['paltas', 'jaxstronomy'])
    plt.show()

### Modify two of the substructure parameters, slope and normalization, and make sure that everything still agrees

In [None]:
# Change the subhalo mass-function normalization (sigma_sub) and slope
# (shmf_plaw_index). This mutates input_config_jax in place.
subhalo_config_jax = input_config_jax['lensing_config']['subhalo_params']
subhalo_config_jax['sigma_sub'] = jaxstronomy.input_pipeline.encode_constant(1.0e-3)
subhalo_config_jax['shmf_plaw_index'] = jaxstronomy.input_pipeline.encode_constant(-1.92)
# Redraw the (constant) subhalo parameters from the edited config.
subhalo_params = jaxstronomy.input_pipeline.draw_sample(subhalo_config_jax, rng)
# The paltas side loads a separate config file. Presumably it sets the same
# sigma_sub and slope values; TODO confirm against input_config_paltas_two.py.
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas_two.py')

In [None]:
# Again, draw from paltas and from jaxstronomy (now with the modified
# substructure parameters) and store the results of both draws.
n_draws = 1000

# NOTE(review): no seed is set in this notebook for the paltas draws. They
# presumably use NumPy's global RNG, so TODO: confirm and seed for reproducibility.
paltas_kwargs_list = []
paltas_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    paltas_models, paltas_kwargs, paltas_z = config_handler.subhalo_class.draw_subhalos()
    paltas_n_subhalos[i] = len(paltas_models)
    paltas_kwargs_list.append(paltas_kwargs)

# Each JAX key is used exactly once. The carried key `rng_draw` is only ever
# split, and the fresh sub-key `rng_subhalos` is only ever consumed by the draw.
rng_draw, _ = jax.random.split(rng)
jaxstronomy_kwargs_list = []
jaxstronomy_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    rng_draw, rng_subhalos = jax.random.split(rng_draw)
    jaxstronomy_z_list, jaxstronomy_kwargs = draw_subhalos_jit(
        main_deflector_params, source_params, subhalo_params, cosmology_params, rng_subhalos)
    # Entries with alpha_rs > 0 are counted as real subhalos; the rest are treated as padding.
    jaxstronomy_n_subhalos[i] = jnp.sum(jaxstronomy_kwargs['alpha_rs'] > 0.0)
    jaxstronomy_kwargs_list.append(jaxstronomy_kwargs)

In [None]:
# Compare the number of subhalos per draw for the modified parameters. The bin
# edges span BOTH samples, so jaxstronomy counts outside the paltas range are
# not silently dropped.
bins = np.histogram_bin_edges(
    np.concatenate([paltas_n_subhalos, jaxstronomy_n_subhalos]), bins=30)
plt.hist(paltas_n_subhalos, bins=bins, histtype='step', lw=3)
plt.hist(jaxstronomy_n_subhalos, bins=bins, histtype='step', lw=3)
plt.xlabel('Number of Subhalos')
plt.ylabel('Count')
plt.legend(['paltas', 'jaxstronomy'])
plt.show()

In [None]:
# Entries with alpha_rs <= 0 in the padded jaxstronomy arrays are treated as
# padding and removed before comparing distributions.
alpha_rs_jax = extract_np_array_jaxstronomy('alpha_rs', jaxstronomy_kwargs_list)

# Each entry is (paltas kwarg, jaxstronomy kwarg, x-axis label, plot natural log of the values?).
parameter_comparisons = [
    ('alpha_Rs', 'alpha_rs', r'Mass Proxy $\ln(\alpha_{Rs})$', True),
    ('Rs', 'scale_radius', r'Size Proxy $\ln(R_s)$', True),
    ('r_trunc', 'trunc_radius', r'Truncation Radius $\ln(r_{\mathrm{trunc}})$', True),
    ('center_x', 'center_x', r'Center x', False),
    ('center_y', 'center_y', r'Center y', False),
]

for paltas_param, jax_param, x_label, take_log in parameter_comparisons:
    paltas_values = extract_np_array_paltas(paltas_param, paltas_kwargs_list)
    jax_values = extract_without_zeros_jaxstronomy(jax_param, jaxstronomy_kwargs_list, alpha_rs_jax)
    if take_log:
        paltas_values, jax_values = np.log(paltas_values), np.log(jax_values)
    # Bin edges span BOTH samples. Edges from paltas alone would silently drop
    # jaxstronomy values outside the paltas range.
    bins = np.histogram_bin_edges(np.concatenate([paltas_values, jax_values]), bins=100)
    plt.hist(paltas_values, bins=bins, histtype='step', log=True, lw=3)
    plt.hist(jax_values, bins=bins, histtype='step', log=True, lw=3)
    plt.xlabel(x_label)
    plt.ylabel('Count')
    plt.legend(['paltas', 'jaxstronomy'])
    plt.show()

## Generating a lensing image and comparing it to one with no substructure.

In [None]:
# paltas: one handler for the fiducial config and one for the "_zero" config
# (presumably the no-substructure reference).
config_handler_zero = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas_zero.py')
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas.py')
# Disable paltas noise on both handlers to make the difference images easier to read.
for handler in (config_handler, config_handler_zero):
    handler.add_noise = False

# jaxstronomy: reload the config. This assumes get_config() returns a new dict,
# which would undo the in-place subhalo_params edits made earlier; TODO confirm.
from comparison_files import input_config_jaxstronomy
rng = jax.random.PRNGKey(0)
input_config_jax = input_config_jaxstronomy.get_config()

# Bind the config-derived arguments before jitting, so the jitted function
# only takes the lensing config, cosmology, grids and PRNG key.
static_image_kwargs = {
    name: input_config_jax[name]
    for name in ('all_models', 'principal_md_index', 'principal_source_index',
                 'kwargs_simulation', 'kwargs_detector', 'kwargs_psf', 'truth_parameters')
}
draw_image_and_truth_jit = jax.jit(
    functools.partial(jaxstronomy.input_pipeline.draw_image_and_truth, **static_image_kwargs))

# Reference lensing config: a deep copy (so the fiducial config is untouched)
# with the subhalo normalization sigma_sub set to zero.
lensing_config_zero = copy.deepcopy(input_config_jax['lensing_config'])
lensing_config_zero['subhalo_params']['sigma_sub'] = jaxstronomy.input_pipeline.encode_constant(0.0)
cosmology_params = jaxstronomy.input_pipeline.intialize_cosmology_params(input_config_jax, rng)
grid_x, grid_y = jaxstronomy.input_pipeline.generate_grids(input_config_jax)

In [None]:
# paltas: a 4x4 grid of fractional image differences (image - reference) / reference,
# each against a single reference image drawn from config_handler_zero.
f, ax = plt.subplots(
    4, 4, figsize=(21.5, 22), sharex=False, sharey=False,
    gridspec_kw={'hspace': 0.02, 'wspace': 0.02}, dpi=100)
image_zero, _ = config_handler_zero.draw_image()
# ax.flat walks the panels in row-major order (row 0 left-to-right, then row 1, ...).
for panel in ax.flat:
    image, metadata = config_handler.draw_image()
    panel.imshow((image - image_zero) / image_zero, cmap='plasma')
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)

plt.show()

In [None]:
# jaxstronomy: a 4x4 grid of fractional image differences (image - reference) / reference,
# each against a single reference image drawn from lensing_config_zero (sigma_sub = 0).
f, ax = plt.subplots(4, 4, figsize=(21.5, 22), sharex=False, sharey=False, gridspec_kw={'hspace': 0.02, 'wspace': 0.02}, dpi=100)
# Each JAX key is used exactly once. The carried key `rng_draw` is only split,
# and the sub-keys are only consumed by draw_image_and_truth_jit. The original
# drew the reference image with rng_draw and then split that same key.
rng_draw, rng_zero = jax.random.split(rng)
image_jax_zero, _ = draw_image_and_truth_jit(lensing_config_zero, cosmology_params, grid_x, grid_y, rng_zero)
for i in range(16):
    rng_draw, rng_image = jax.random.split(rng_draw)
    image_jax, truth_jax = draw_image_and_truth_jit(input_config_jax['lensing_config'], cosmology_params, grid_x, grid_y, rng_image)
    ax[i // 4, i % 4].imshow((image_jax - image_jax_zero) / image_jax_zero, cmap='plasma')
    ax[i // 4, i % 4].get_xaxis().set_visible(False)
    ax[i // 4, i % 4].get_yaxis().set_visible(False)

plt.show()