In [None]:
import paltas

import jaxstronomy

import jax
import jax.numpy as jnp
import numpy as np
import functools
import matplotlib.pyplot as plt
import copy

The paltas and jaxstronomy codebases perform the same operations, so their outputs should agree. This notebook checks that agreement one component at a time.

There are three components to an update step for our model:

1. Drawing a population of dark matter substructure for the lensing calculations.
2. Generating a lensing image given that dark matter substructure population.
3. Updating the model given a batch of images.

Below we'll create code that should generate the same outputs from both codebases (up to random seed differences for the first bullet point) so that we can
compare them.

## Drawing a population of dark matter substructure for the lensing calculations.

In [None]:
# Set up the paltas and jaxstronomy configurations that the rest of this section compares.
# paltas side: the ConfigHandler is built from the paltas config file on disk.
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas.py')

# jaxstronomy side: the config lives in a Python module that exposes get_config().
from comparison_files import input_config_jaxstronomy
# Base PRNG key (seed 0); later cells derive their per-draw keys from it with jax.random.split.
rng = jax.random.PRNGKey(0)
input_config_jax = input_config_jaxstronomy.get_config()

In [None]:
# A bit of setup for the jaxstronomy code. This draws the specific distribution parameters for the subhalos,
# the main deflector halos, and the source. These are all encoded as constants, so we're just drawing the
# constant values -- which is why the same `rng` key is reused for every call in this cell.
subhalo_params = jaxstronomy.input_pipeline.draw_sample(input_config_jax['lensing_config']['subhalo_params'], rng)
main_deflector_params = jaxstronomy.input_pipeline.draw_sample(input_config_jax['lensing_config']['main_deflector_params'], rng)
source_params = jaxstronomy.input_pipeline.draw_sample(input_config_jax['lensing_config']['source_params'], rng)

# Initialize the cosmology parameters we need for our cosmology calculations and create a jitted function for fast draws.
cosmology_params = jaxstronomy.input_pipeline.intialize_cosmology_params(input_config_jax, rng)
# The two pad lengths are bound with functools.partial before jitting, so the jitted function only
# takes the parameter dictionaries, the cosmology parameters, and a PRNG key.
subhalos_pad_length = 1000
sampling_pad_length = 100000
draw_subhalos_jit = jax.jit(functools.partial(jaxstronomy.subhalos.draw_subhalos, subhalos_pad_length=subhalos_pad_length, 
                                              sampling_pad_length=sampling_pad_length))
# Call once and discard the result -- presumably to trigger jit compilation here rather than inside
# the draw loops of the following cells.
_ = draw_subhalos_jit(main_deflector_params, source_params, subhalo_params, cosmology_params, rng)

In [None]:
# Here I'm drawing from paltas and from jaxstronomy and storing the results of both draws.
n_draws = 1000

# paltas: each call returns the model list, the per-subhalo kwargs, and the redshifts of one draw.
# The subhalo count of a draw is the length of the model list.
paltas_kwargs_list = []
paltas_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    paltas_models, paltas_kwargs, paltas_z = config_handler.subhalo_class.draw_subhalos()
    paltas_n_subhalos[i] = len(paltas_models)
    paltas_kwargs_list.append(paltas_kwargs)

# jaxstronomy: derive a fresh key for every draw by splitting the previous key and keeping the first half.
rng_draw, _ = jax.random.split(rng)
jaxstronomy_kwargs_list = []
jaxstronomy_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    rng_draw, _ = jax.random.split(rng_draw)
    jaxstronomy_z_list, jaxstronomy_kwargs = draw_subhalos_jit(main_deflector_params, source_params, subhalo_params, cosmology_params, rng_draw)
    # Count the entries with alpha_rs > 0 as subhalos.
    # NOTE(review): presumably the output is padded and padded entries have alpha_rs == 0; the
    # plotting cells mask on model_index > -1 instead -- confirm the two criteria agree.
    jaxstronomy_n_subhalos[i] = jnp.sum(jaxstronomy_kwargs['alpha_rs'] > 0.0)
    jaxstronomy_kwargs_list.append(jaxstronomy_kwargs)

In [None]:
# Overlay the per-draw subhalo counts from the two codes; the histograms should look alike.
# The first histogram chooses 30 bins and the second one reuses the bin edges it returned.
bins = 30
for subhalo_counts in (paltas_n_subhalos, jaxstronomy_n_subhalos):
    _, bins, _ = plt.hist(subhalo_counts, bins=bins, histtype='step', lw=3)
plt.xlabel('Number of Subhalos')
plt.ylabel('Count')
plt.legend(['paltas', 'jaxstronomy'])
plt.show()

In [None]:
# Now compare the distribution of the parameters.
def extract_np_array_paltas(param, list_of_draws):
    """Gather one parameter from every subhalo of every paltas draw.

    Args:
        param: Key to read from each subhalo's kwargs dictionary.
        list_of_draws: List of draws, each an iterable of per-subhalo kwargs
            dictionaries.

    Returns:
        Numpy array holding the values in draw order, then subhalo order.
    """
    return np.array([subhalo[param] for draw in list_of_draws for subhalo in draw])

def extract_np_array_jaxstronomy(param, list_of_draws):
    """Concatenate one parameter array across all jaxstronomy draws.

    Args:
        param: Key to read from each draw's kwargs dictionary.
        list_of_draws: List of per-draw dictionaries mapping parameter names
            to array-likes.

    Returns:
        Numpy array with the per-draw arrays joined end to end, in draw order.
    """
    return np.concatenate([draw[param] for draw in list_of_draws])

def extract_without_zeros_jaxstronomy(param, list_of_draws, model_index):
    """Concatenate a jaxstronomy parameter and keep only entries with model_index > -1.

    Args:
        param: Key to read from each draw's kwargs dictionary.
        list_of_draws: List of per-draw dictionaries of parameter arrays.
        model_index: Array aligned with the concatenated parameter values;
            entries that are not greater than -1 are dropped.

    Returns:
        The concatenated values restricted to positions where model_index > -1.
    """
    values = extract_np_array_jaxstronomy(param, list_of_draws)
    keep = model_index > -1
    return values[keep]

# model_index of every entry in every jaxstronomy draw; used as the mask for the values plotted below.
model_index_jax = extract_np_array_jaxstronomy('model_index', jaxstronomy_kwargs_list)

# One row per compared parameter:
# (paltas key, jaxstronomy key, compare the natural log of the values?, x-axis label).
compared_params = (
    ('alpha_Rs', 'alpha_rs', True, r'Mass Proxy $(\alpha_{Rs})$'),
    ('Rs', 'scale_radius', True, r'Size Proxy $(Rs)$'),
    ('r_trunc', 'trunc_radius', True, r'Truncation Radius'),
    ('center_x', 'center_x', False, r'Center x'),
    ('center_y', 'center_y', False, r'Center y'),
)

for paltas_key, jax_key, in_log_space, axis_label in compared_params:
    paltas_values = extract_np_array_paltas(paltas_key, paltas_kwargs_list)
    jax_values = extract_without_zeros_jaxstronomy(jax_key, jaxstronomy_kwargs_list, model_index_jax)
    if in_log_space:
        paltas_values = np.log(paltas_values)
        jax_values = np.log(jax_values)

    # The paltas histogram fixes the bin edges; the jaxstronomy histogram reuses them.
    _, bins, _ = plt.hist(paltas_values, bins=100, histtype='step', log=True, lw=3)
    plt.hist(jax_values, bins=bins, histtype='step', log=True, lw=3)
    plt.xlabel(axis_label)
    plt.ylabel('Count')
    plt.legend(['paltas', 'jaxstronomy'])
    plt.show()

### Modify two of the substructure parameters, slope and normalization, and make sure that everything still agrees

In [None]:
# Override two subhalo parameters in the jaxstronomy config (in place) with new constant values.
input_config_jax['lensing_config']['subhalo_params']['sigma_sub'] = jaxstronomy.input_pipeline.encode_constant(1.0e-3)
input_config_jax['lensing_config']['subhalo_params']['shmf_plaw_index'] = jaxstronomy.input_pipeline.encode_constant(-1.92)
# Re-draw subhalo_params so the next cell's draws see the overridden values.
subhalo_params = jaxstronomy.input_pipeline.draw_sample(input_config_jax['lensing_config']['subhalo_params'], rng)
# Point paltas at a second config file.
# NOTE(review): presumably this file carries the same two overrides -- confirm.
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas_two.py')

In [None]:
# Again, draw from paltas and from jaxstronomy and store the results of both draws, this time with
# the modified subhalo parameters. The result lists and count arrays from the first comparison are overwritten.
n_draws = 1000

# paltas: the subhalo count of a draw is the length of the returned model list.
paltas_kwargs_list = []
paltas_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    paltas_models, paltas_kwargs, paltas_z = config_handler.subhalo_class.draw_subhalos()
    paltas_n_subhalos[i] = len(paltas_models)
    paltas_kwargs_list.append(paltas_kwargs)

# jaxstronomy: restart from the base key, then split once per draw and keep the first half.
rng_draw, _ = jax.random.split(rng)
jaxstronomy_kwargs_list = []
jaxstronomy_n_subhalos = np.zeros(n_draws)
for i in range(n_draws):
    rng_draw, _ = jax.random.split(rng_draw)
    jaxstronomy_z_list, jaxstronomy_kwargs = draw_subhalos_jit(main_deflector_params, source_params, subhalo_params, cosmology_params, rng_draw)
    # Count the entries with alpha_rs > 0 as subhalos.
    # NOTE(review): presumably padded entries have alpha_rs == 0 -- confirm.
    jaxstronomy_n_subhalos[i] = jnp.sum(jaxstronomy_kwargs['alpha_rs'] > 0.0)
    jaxstronomy_kwargs_list.append(jaxstronomy_kwargs)

In [None]:
# Overlay the per-draw subhalo counts from the two codes for the modified parameters.
# The first histogram chooses 30 bins and the second one reuses the bin edges it returned.
bins = 30
for subhalo_counts in (paltas_n_subhalos, jaxstronomy_n_subhalos):
    _, bins, _ = plt.hist(subhalo_counts, bins=bins, histtype='step', lw=3)
plt.xlabel('Number of Subhalos')
plt.ylabel('Count')
plt.legend(['paltas', 'jaxstronomy'])
plt.show()

In [None]:
# model_index of every entry in every jaxstronomy draw; used as the mask for the values plotted below.
model_index_jax = extract_np_array_jaxstronomy('model_index', jaxstronomy_kwargs_list)

# One row per compared parameter:
# (paltas key, jaxstronomy key, compare the natural log of the values?, x-axis label).
# The first label reads 'Mass Proxy', matching the same plot in the first comparison.
compared_params = (
    ('alpha_Rs', 'alpha_rs', True, r'Mass Proxy $(\alpha_{Rs})$'),
    ('Rs', 'scale_radius', True, r'Size Proxy $(Rs)$'),
    ('r_trunc', 'trunc_radius', True, r'Truncation Radius'),
    ('center_x', 'center_x', False, r'Center x'),
    ('center_y', 'center_y', False, r'Center y'),
)

for paltas_key, jax_key, in_log_space, axis_label in compared_params:
    paltas_values = extract_np_array_paltas(paltas_key, paltas_kwargs_list)
    jax_values = extract_without_zeros_jaxstronomy(jax_key, jaxstronomy_kwargs_list, model_index_jax)
    if in_log_space:
        paltas_values = np.log(paltas_values)
        jax_values = np.log(jax_values)

    # The paltas histogram fixes the bin edges; the jaxstronomy histogram reuses them.
    _, bins, _ = plt.hist(paltas_values, bins=100, histtype='step', log=True, lw=3)
    plt.hist(jax_values, bins=bins, histtype='step', log=True, lw=3)
    plt.xlabel(axis_label)
    plt.ylabel('Count')
    plt.legend(['paltas', 'jaxstronomy'])
    plt.show()

## Generating a lensing image and comparing it to no substructure.

In [None]:
# Setup paltas and jaxstronomy configs.
# Two paltas handlers: one built from the "zero" config (the no-substructure reference per the
# section title) and one from the standard comparison config.
config_handler_zero = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas_zero.py')
config_handler = paltas.Configs.config_handler.ConfigHandler('comparison_files/input_config_paltas.py')
# No noise in paltas and jaxstronomy to make things easier
config_handler.add_noise=False
config_handler_zero.add_noise = False

# Re-read the jaxstronomy config and reset the base key.
# NOTE(review): presumably get_config() returns a fresh config, dropping the in-place overrides made earlier -- confirm.
from comparison_files import input_config_jaxstronomy
rng = jax.random.PRNGKey(0)
input_config_jax = input_config_jaxstronomy.get_config()
# Bind everything taken from the config up front; the jitted function is then called with
# (lensing_config, cosmology_params, grid_x, grid_y, rng) in the cells below.
draw_image_and_truth_jit = jax.jit(functools.partial(jaxstronomy.input_pipeline.draw_image_and_truth, all_models=input_config_jax['all_models'],
                                                    principal_md_index=input_config_jax['principal_md_index'], principal_source_index=input_config_jax['principal_source_index'],
                                                    kwargs_simulation=input_config_jax['kwargs_simulation'], kwargs_detector=input_config_jax['kwargs_detector'],
                                                    kwargs_psf=input_config_jax['kwargs_psf'], truth_parameters=input_config_jax['truth_parameters'], normalize_image=False))
# jaxstronomy reference: a deep copy of the lensing config with sigma_sub encoded as the constant 0.0,
# leaving input_config_jax itself untouched.
lensing_config_zero = copy.deepcopy(input_config_jax['lensing_config'])
lensing_config_zero['subhalo_params']['sigma_sub'] = jaxstronomy.input_pipeline.encode_constant(0.0)
cosmology_params = jaxstronomy.input_pipeline.intialize_cosmology_params(input_config_jax, rng)
grid_x, grid_y = jaxstronomy.input_pipeline.generate_grids(input_config_jax)

In [None]:
f, ax = plt.subplots(4, 4, figsize=(21.5, 22), sharex=False, sharey=False,
                     gridspec_kw={'hspace': 0.02, 'wspace': 0.02}, dpi=100)
# paltas reference image drawn from the zero config.
image_zero, _ = config_handler_zero.draw_image()
# ax.flat walks the 4x4 grid row by row, one new paltas draw per panel. Each panel shows the
# fractional difference between that draw and the reference image.
for panel in ax.flat:
    image, metadata = config_handler.draw_image()
    panel.imshow((image - image_zero) / image_zero, cmap='plasma')
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)

plt.show()

In [None]:
f, ax = plt.subplots(4, 4, figsize=(21.5, 22), sharex=False, sharey=False,
                     gridspec_kw={'hspace': 0.02, 'wspace': 0.02}, dpi=100)
# jaxstronomy reference image, drawn with the sigma_sub = 0 lensing config.
rng_draw, _ = jax.random.split(rng)
image_jax_zero, _ = draw_image_and_truth_jit(lensing_config_zero, cosmology_params, grid_x, grid_y, rng_draw)

# ax.flat walks the 4x4 grid row by row. The key is advanced once per panel, and each panel shows
# the fractional difference between that draw and the reference image.
for panel in ax.flat:
    rng_draw, _ = jax.random.split(rng_draw)
    image_jax, truth_jax = draw_image_and_truth_jit(
        input_config_jax['lensing_config'], cosmology_params, grid_x, grid_y, rng_draw)
    panel.imshow((image_jax - image_jax_zero) / image_jax_zero, cmap='plasma')
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)

plt.show()

In [None]:
# Check that the error is essentially floating point without substructure: this plots the ratio of the
# paltas and jaxstronomy reference images, which should be ~1 everywhere. The [2:-2, 2:-2] crop drops a
# 2-pixel border; ignore the edges, that's just different psf treatment.
plt.imshow((image_zero / image_jax_zero)[2:-2,2:-2])
plt.colorbar()

### Get into the guts with just the main deflector

In [None]:
# Start at the lowest level of comparison and work your way up until you find disagreement.

In [None]:
# Make sure we're operating on the same grid: regenerate the coordinate grids from the jaxstronomy
# config so the lenstronomy and jaxstronomy calculations below are evaluated at the same points.
grid_x, grid_y = jaxstronomy.input_pipeline.generate_grids(input_config_jax)

In [None]:
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.ImSim.image_model import ImageModel

# First go back to basics, just do the lensing straight with lenstronomy versus the paltas code and make sure we get the same result
z_source = 1.5
z_lens = 0.5
cosmo = config_handler_zero.source_class.cosmo
# Pass the redshifts by keyword: LensModel's positional order is (lens_model_list, z_lens, z_source),
# so the lens redshift must not be filled with z_source.
lens_model = LensModel(['EPL_NUMBA'], z_lens=z_lens, z_source=z_source, lens_redshift_list=[z_lens],
                       cosmo=cosmo.toAstropy(), multi_plane=True)
# source_redshift_list is a list with one entry per light model.
source_model = LightModel(['SERSIC_ELLIPSE'], source_redshift_list=[z_source])

# These are the kwargs for the lens and the source
kwargs_lens = [{'theta_E': 1.1,
    'gamma': 2.0,
    'e1': 0.05263157894736841,
    'e2': 0.0,
    'center_x': 0.08,
    'center_y': -0.16}]
kwargs_source = [{'R_sersic': 1.5,
    'center_x': 0.16,
    'center_y': -0.08,
    'e1': 0.05263157894736841,
    'e2': 0.0,
    'n_sersic': 1.5,
    'amp': 10.000075423785681}]

# Ray-shoot the image-plane grid back to the source plane and evaluate the source light there.
true_x, true_y = lens_model.ray_shooting(grid_x, grid_y, kwargs_lens)
# NOTE(review): the (256, 256) shape and the 0.020**2 factor are hard-coded; presumably they are the
# image size and the pixel area of the configured detector -- confirm against the config.
lenstronomy_image = source_model.surface_brightness(true_x, true_y, kwargs_source).reshape((256,256)) * 0.020**2

# And there we have the psf-free lenstronomy image
plt.imshow(lenstronomy_image)
plt.colorbar()
plt.show()

In [None]:
# Repeat the same main-deflector-only calculation with the jaxstronomy building blocks.
# Initial ray-shooting state: comoving positions start at zero, the angles start at the grid
# coordinates, and the last slot (unpacked below as z_lens_last) starts at 0.0.
comv_x = jnp.zeros_like(grid_x)
comv_y = jnp.zeros_like(grid_y)
alpha_x = jnp.copy(grid_x)
alpha_y = jnp.copy(grid_y)
state = (comv_x, comv_y, alpha_x, alpha_y, 0.0)

# Single EPL lens with the same theta_E, slope, and center as the lenstronomy cell.
# axis_ratio 0.9 with angle 0.0 gives (1 - 0.9) / (1 + 0.9) = 0.0526..., the e1 value used there.
all_lens_models = (jaxstronomy.lens_models.EPL,)
kwargs_z_lens = {
    'kwargs_lens': {'model_index': 0,
                    'theta_e': 1.1,
                    'slope': 2.0,
                    'center_x': 0.08,
                    'center_y': -0.16,
                    'axis_ratio': 0.9,
                    'angle': 0.0}, 
    'z_lens': 0.5}

# One ray-shooting step through the lens, then propagate the rays over the remaining comoving
# distance to the source and convert the comoving positions back to angles at z_source.
state, _ = jaxstronomy.image_simulation._ray_shooting_step(state, kwargs_z_lens, cosmology_params, z_source, all_lens_models)
comv_x, comv_y, alpha_x, alpha_y, z_lens_last = state
delta_t = jaxstronomy.cosmology_utils.comoving_distance(cosmology_params, z_lens_last, z_source)
comv_x, comv_y = jaxstronomy.image_simulation._ray_step_add(comv_x, comv_y, alpha_x, alpha_y, delta_t)
x_source, y_source = jaxstronomy.cosmology_utils.comoving_to_angle(comv_x, comv_y, cosmology_params, z_source)

# Elliptical Sersic source with the same radius, index, and center as the lenstronomy cell.
# NOTE(review): amp is 10.0 here but 10.000075... in the lenstronomy cell; presumably the two codes
# use different amplitude conventions -- confirm.
all_source_models = (jaxstronomy.source_models.SersicElliptic,)
kwargs_source = {
    'model_index':0,
    'amp': 10.0,
    'sersic_radius': 1.5,
    'n_sersic': 1.5,
    'axis_ratio': 0.9,
    'angle': 0.0,
    'center_x': 0.16,
    'center_y': -0.08
}
# Accumulate the source surface brightness on an empty image, then apply the same reshape and
# 0.02 ** 2 scaling as the lenstronomy image.
jaxstronomy_image = jnp.zeros_like(grid_x)
jaxstronomy_image, _ = jaxstronomy.image_simulation._add_surface_brightness(jaxstronomy_image, kwargs_source, x_source, y_source, all_source_models)
jaxstronomy_image = jaxstronomy_image.reshape((256,256)) * 0.02 ** 2

# And there we have the psf-free jaxstronomy image
plt.imshow(jaxstronomy_image)
plt.colorbar()
plt.show()

In [None]:
# Difference between the psf-free lenstronomy and jaxstronomy images; if the two codes agree this
# should be small compared to the images themselves.
plt.imshow(lenstronomy_image-jaxstronomy_image)
plt.colorbar()
plt.show()

### One step further, let's draw a population of substructure from jaxstronomy and feed it to both pipelines.

In [None]:
# Let's get the parameters for our models. This is copied straight from the input pipeline code.
def get_draw_of_structure(input_config, rng):
    """Draw one realization of the lensing parameters and dark matter structure.

    Follows the sampling steps of the jaxstronomy input pipeline so that the
    drawn structure can be inspected and handed to another code.

    Note: ``cosmology_params`` is read from the notebook's global scope rather
    than passed in, so it must be defined before this function is called.

    Args:
        input_config: jaxstronomy configuration dictionary. The keys
            'kwargs_simulation', 'lensing_config', 'all_models',
            'principal_md_index', and 'principal_source_index' are read.
        rng: jax PRNG key used for all draws.

    Returns:
        Tuple ``(source_params, kwargs_lens_all, lens_light_params)``.
        ``source_params`` and ``lens_light_params`` are the outputs of
        ``extract_multiple_models``; ``kwargs_lens_all`` is a dictionary with
        the line-of-sight, main deflector, and subhalo redshifts and kwargs.
    """
    kwargs_simulation = input_config['kwargs_simulation']
    lensing_config = input_config['lensing_config']
    all_models = input_config['all_models']
    principal_md_index = input_config['principal_md_index']
    principal_source_index = input_config['principal_source_index']

    num_z_bins = kwargs_simulation['num_z_bins']
    los_pad_length = kwargs_simulation['los_pad_length']
    subhalos_pad_length = kwargs_simulation['subhalos_pad_length']
    sampling_pad_length = kwargs_simulation['sampling_pad_length']

    # One key per parameter group, plus a leftover key for the structure draws below.
    rng_md, rng_source, rng_ll, rng_los, rng_sub, rng = jax.random.split(rng, 6)
    main_deflector_params = jaxstronomy.input_pipeline.extract_multiple_models(
        lensing_config['main_deflector_params'], rng_md,
        len(all_models['all_main_deflector_models'])
    )
    source_params = jaxstronomy.input_pipeline.extract_multiple_models(
        lensing_config['source_params'], rng_source,
        len(all_models['all_source_models'])
    )
    # The lens light uses the number of source models, as in the code this was copied from.
    lens_light_params = jaxstronomy.input_pipeline.extract_multiple_models(
        lensing_config['lens_light_params'], rng_ll,
        len(all_models['all_source_models'])
    )
    los_params = jaxstronomy.input_pipeline.draw_sample(lensing_config['los_params'], rng_los)
    subhalo_params = jaxstronomy.input_pipeline.draw_sample(lensing_config['subhalo_params'], rng_sub)

    # Extract the principal models for the redshifts and the substructure draws.
    main_deflector_params_sub = jax.tree_util.tree_map(
        lambda x: x[principal_md_index], main_deflector_params
    )
    source_params_sub = jax.tree_util.tree_map(
        lambda x: x[principal_source_index], source_params
    )

    # Fresh keys for the line-of-sight and subhalo draws, split from the leftover key.
    rng_los, rng_sub = jax.random.split(rng)
    los_before_tuple, los_after_tuple = jaxstronomy.los.draw_los(
        main_deflector_params_sub, source_params_sub, los_params,
        cosmology_params, rng_los, num_z_bins, los_pad_length)
    subhalos_z, subhalos_kwargs = jaxstronomy.subhalos.draw_subhalos(
        main_deflector_params_sub, source_params_sub, subhalo_params,
        cosmology_params, rng_sub, subhalos_pad_length, sampling_pad_length)

    kwargs_lens_all = {
        'z_array_los_before': los_before_tuple[0],
        'kwargs_los_before': los_before_tuple[1],
        'z_array_los_after': los_after_tuple[0],
        'kwargs_los_after': los_after_tuple[1],
        'kwargs_main_deflector': main_deflector_params,
        'z_array_main_deflector': main_deflector_params['z_lens'],
        'z_array_subhalos': subhalos_z, 'kwargs_subhalos': subhalos_kwargs}

    return source_params, kwargs_lens_all, lens_light_params

# Pick an rng key that gives an odd image in the tests above. Splitting the base key twice and
# keeping the first half each time reproduces the key used for the first panel of the
# jaxstronomy 4x4 grid.
rng_draw, _ = jax.random.split(rng)
rng_draw, _ = jax.random.split(rng_draw)
source_params, kwargs_lens_all, lens_light_params = get_draw_of_structure(input_config_jax, rng_draw)

In [None]:
# Now we can patch in this substructure as the output of get_lenstronomy_model_kwargs for our config_handler and we're good to go.
def get_lenstronomy_model_kwargs(new_sample=False):
    """Build lenstronomy model and parameter kwargs from the jaxstronomy draw.

    Meant to be patched onto a paltas config handler so that paltas renders
    the subhalos drawn by jaxstronomy. The main deflector and the source are
    hard-coded to the values used in the main-deflector-only comparison.

    Note: ``source_params`` and ``kwargs_lens_all`` are read from the
    notebook's global scope.

    Args:
        new_sample: Unused; kept so the function can be called with the same
            argument as the handler attribute it replaces.

    Returns:
        Tuple ``(kwargs_model, kwargs_params)`` of dictionaries holding the
        model lists / redshifts and the matching parameter lists.
    """
    z_source = source_params['z_source'][0]

    # Move the subhalo arrays to host memory once instead of converting element by element.
    subhalo_kwargs = kwargs_lens_all['kwargs_subhalos']
    subhalo_z = np.asarray(kwargs_lens_all['z_array_subhalos'])
    alpha_rs = np.asarray(subhalo_kwargs['alpha_rs'])
    scale_radius = np.asarray(subhalo_kwargs['scale_radius'])
    center_x = np.asarray(subhalo_kwargs['center_x'])
    center_y = np.asarray(subhalo_kwargs['center_y'])
    trunc_radius = np.asarray(subhalo_kwargs['trunc_radius'])

    # Select the entries with model_index >= 0 by position rather than assuming that they
    # occupy the first N slots of the arrays.
    subhalo_indices = np.flatnonzero(np.asarray(subhalo_kwargs['model_index']) >= 0)

    # Populate the list of lens model names and parameters for the subhalos.
    lens_model_list = []
    lens_redshift_list = []
    kwargs_lens = []
    for i in subhalo_indices:
        lens_model_list.append('TNFW')
        lens_redshift_list.append(float(subhalo_z[i]))
        kwargs_lens.append({
            'alpha_Rs': float(alpha_rs[i]),
            'Rs': float(scale_radius[i]),
            'center_x': float(center_x[i]),
            'center_y': float(center_y[i]),
            'r_trunc': float(trunc_radius[i])})

    # Now add the main deflector.
    lens_model_list.append('EPL_NUMBA')
    lens_redshift_list.append(0.5)
    kwargs_lens.append({
        'theta_E': 1.1,
        'gamma': 2.0,
        'e1': 0.05263157894736841,
        'e2': 0.0,
        'center_x': 0.08,
        'center_y': -0.16})

    # Add the source.
    kwargs_source = [{'R_sersic': 1.5,
                      'center_x': 0.16,
                      'center_y': -0.08,
                      'e1': 0.05263157894736841,
                      'e2': 0.0,
                      'n_sersic': 1.5,
                      'amp': 10.000075423785681}]

    # There is no lens light and no point source in this comparison, so those lists stay empty.
    kwargs_model = {
        'lens_model_list': lens_model_list,
        'lens_redshift_list': lens_redshift_list,
        'lens_light_model_list': [],
        'point_source_model_list': [],
        'source_light_model_list': ['SERSIC_ELLIPSE'],
        'source_redshift_list': [z_source],
        'multi_plane': True,
        'z_source': z_source,
        'z_source_convention': z_source,
    }
    kwargs_params = {
        'kwargs_lens': kwargs_lens,
        'kwargs_lens_light': [],
        'kwargs_ps': [],
        'kwargs_source': kwargs_source,
    }

    return kwargs_model, kwargs_params

# Replace the handler's get_lenstronomy_models_kwargs with the function above. The assignment is on
# the instance, so the function is stored unbound and is called without a `self` argument.
# NOTE(review): presumably draw_image() calls this attribute to build its models -- confirm.
config_handler.get_lenstronomy_models_kwargs = get_lenstronomy_model_kwargs

In [None]:
# Draw one image from each code with the subhalos forced to be the same: paltas through the
# patched config handler, jaxstronomy with the same key that was used to draw the structure.
image_compare, _ = config_handler.draw_image()
image_compare_jax, _ = draw_image_and_truth_jit(
    input_config_jax['lensing_config'], cosmology_params, grid_x, grid_y, rng_draw)

# Show the paltas image first, then the jaxstronomy image.
for drawn_image in (image_compare, image_compare_jax):
    plt.imshow(drawn_image)
    plt.colorbar()
    plt.show()

In [None]:
# Fractional difference of each image relative to its own code's reference image:
# paltas first, then jaxstronomy.
relative_residuals = (
    (image_compare - image_zero) / image_zero,
    (image_compare_jax - image_jax_zero) / image_jax_zero,
)
for residual in relative_residuals:
    plt.imshow(residual, cmap='plasma')
    plt.colorbar()
    plt.show()

# Direct paltas-minus-jaxstronomy difference with a 2-pixel border cropped off.
cropped_difference = (image_compare - image_compare_jax)[2:-2, 2:-2]
plt.imshow(cropped_difference)
plt.colorbar()
plt.show()