## Fast Fit Strong Lens model on ECDFS_G15422 with JAXtronomy

- Special shift of central lens
- author : Sylvie Dagoret-Campagne
- creation date : 2025-06-25
- python kernel : conda_py311_jax
- update : 2025-06-22
- last update : 2025-06-25
From https://github.com/lenstronomy/JAXtronomy/blob/main/notebooks/modeling_a_simple_Einstein_ring.ipynb


In [None]:
# For Angle conversion
import matplotlib.pyplot as plt
from astropy.coordinates import Angle
import astropy.units as u
from astropy.coordinates import SkyCoord
import pandas as pd
import numpy as np
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.nddata import Cutout2D
from astropy.visualization import simple_norm, ZScaleInterval,PercentileInterval

In [None]:
import copy

In [None]:
from IPython.display import Image, display

In [None]:
# import main simulation class of lenstronomy
from lenstronomy.Util import util, image_util
from lenstronomy.Data.psf import PSF
from lenstronomy.Workflow.fitting_sequence import FittingSequence as FittingSequence_ref

from jaxtronomy.Workflow.fitting_sequence import FittingSequence
from jaxtronomy.LensModel.lens_model import LensModel
from jaxtronomy.LightModel.light_model import LightModel
from jaxtronomy.PointSource.point_source import PointSource
from jaxtronomy.Data.imaging_data import ImageData
from jaxtronomy.ImSim.image_model import ImageModel

# Currently, jaxtronomy supports the following deflector and source profiles:
from jaxtronomy.LensModel.profile_list_base import (
    _JAXXED_MODELS as _JAXXED_DEFLECTOR_MODELS,
)
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as _JAXXED_LIGHT_MODELS,
)

# List the deflector (lens mass) and light profiles that jaxtronomy supports.
print("Deflector models:", _JAXXED_DEFLECTOR_MODELS)
print("Light models:", _JAXXED_LIGHT_MODELS)

In [None]:
# Large default figure size and fonts for every plot in this notebook.
plt.rcParams.update({
    "figure.figsize": (12, 8),
    "axes.labelsize": 'xx-large',
    "axes.titlesize": 'xx-large',
    "xtick.labelsize": 'xx-large',
    "ytick.labelsize": 'xx-large',
    "legend.fontsize": 12,
    "font.size": 12,
})

## Configuration

### Target in ECDFS

- path to article : https://arxiv.org/pdf/1104.0931
- visual selection of the tiles : https://archive.stsci.edu/prepds/gems/browser.html
- path file download : https://archive.stsci.edu/pub/hlsp/gems/v_mk1/

In [None]:
# PNG images of candidate tables 1 and 2, displayed in the next cells.
fig_table1, fig_table2 = (f"input_figs/table{num}_gemcandidates.png" for num in (1, 2))

In [None]:
# Display candidate table 1 (PNG) in the notebook.
Image(url= fig_table1,width=800)

In [None]:
# Display candidate table 2 (PNG) in the notebook.
Image(url= fig_table2,width=500)

In [None]:
# Sexagesimal coordinates and GEMS tile number of each lens candidate.
# Comment format: GEMS id, [tile from the table], RA, Dec.

# 15422 44 03:32:38.21 -27:56:53.2
ra1, dec1, tile1 = "03:32:38.21 hours", "-27:56:53.2 degrees", 44

# 34244 94 03:32:06.45 -27:47:28.6
ra2, dec2, tile2 = "03:32:06.45 hours", "-27:47:28.6 degrees", 94
# Tile 94 could not be found ==> tile 32 was identified with FindTileForCutoutGEM.
tile2 = 32

# 40173 35 03:33:19.45 -27:44:50.0
ra3, dec3, tile3 = "03:33:19.45 hours", "-27:44:50.0 degrees", 35

# 43242 45 03:31:55.35 -27:43:23.5
ra4, dec4, tile4 = "03:31:55.35 hours", "-27:43:23.5 degrees", 45

# 46446 47 03:31:35.94 -27:41:48.2
ra5, dec5, tile5 = "03:31:35.94 hours", "-27:41:48.2 degrees", 47

# 12589 03:31:24.89 -27:58:07.0
ra6, dec6, tile6 = "03:31:24.89 hours", "-27:58:07.0 degrees", 17

# 43797 03:31:31.74 -27:43:00.8
ra7, dec7, tile7 = "03:31:31.74 hours", "-27:43:00.8 degrees", 47

# 28294 03:31:50.54 -27:50:28.4
ra8, dec8, tile8 = "03:31:50.54 hours", "-27:50:28.4 degrees", 33

# 36857 03:31:53.24 -27:46:18.9
ra9, dec9, tile9 = "03:31:53.24 hours", "-27:46:18.9 degrees", 38

# 36714 03:32:59.78 -27:46:26.4
ra10, dec10, tile10 = "03:32:59.78 hours", "-27:46:26.4 degrees", 37


In [None]:
# Catalogue of GEMS lens candidates: "ECDFS_G<id>" -> field name, RA/Dec [deg], tile.
# NOTE(review): "6857" is presumably object 36857 (its coordinate comment reads
# "#36857"); the key is left unchanged because cutout file names are built from it.
lsstcomcam_targets = {
    f"ECDFS_G{gems_id}": {"field_name": f"GEMS-{gems_id}", "ra": ra_deg, "dec": dec_deg, "tile": tile}
    for gems_id, ra_deg, dec_deg, tile in (
        # high rank
        ("15422", 53.159208333333325, -27.94811111111111, tile1),
        ("34244", 53.02687499999999, -27.79127777777778, tile2),
        ("40173", 53.33104166666666, -27.747222222222224, tile3),
        ("43242", 52.980624999999996, -27.72319444444444, tile4),
        ("46446", 52.89975, -27.696722222222224, tile5),
        # low rank
        ("12589", 52.85370833333333, -27.96861111111111, tile6),
        ("43797", 52.88224999999999, -27.71688888888889, tile7),
        ("28294", 52.960583333333325, -27.84122222222222, tile8),
        ("6857", 52.97183333333333, -27.771916666666666, tile9),
        ("36714", 53.249083333333324, -27.773999999999997, tile10),
    )
}


In [None]:
# Targets as a table: one row per target key, columns field_name / ra / dec / tile.
df = pd.DataFrame(lsstcomcam_targets).T

In [None]:
# Select the target to model: leave exactly one `key` line uncommented.
# candidates
key = "ECDFS_G15422"
#key = "ECDFS_G34244"
#key = "ECDFS_G40173"
#key= "ECDFS_G43242"
#key= "ECDFS_G46446"

# unknown
#key = "ECDFS_G12589"
#key = "ECDFS_G43797"
#key = "ECDFS_G28294"
#key = "ECDFS_G6857"
#key = "ECDFS_G36714"

the_target = lsstcomcam_targets[key]
target_name = the_target["field_name"]
target_ra, target_dec, tile_num = (the_target[field] for field in ("ra", "dec", "tile"))
# Sky position of the target (degrees).
coord = SkyCoord(ra=target_ra, dec=target_dec, unit=(u.deg, u.deg))

# Plot title: field name followed by the rounded coordinates.
target_title = f"{target_name} (ra,dec) = ({target_ra:.2f},{target_dec:.2f}) "

In [None]:
# Show the GEMS tile number of the selected target.
tile_num

### Access to remote files

- visual selection of the tiles : https://archive.stsci.edu/prepds/gems/browser.html
- path file download : https://archive.stsci.edu/pub/hlsp/gems/v_mk1/

### Config

- http://archive.stsci.edu/pub/hlsp/gems/v_mk1/h_gems_v35_mk1.fits
- url = f"http://archive.stsci.edu/pub/hlsp/gems/z_mk1/h_gems_z{tile_num}_mk1.fits"

In [None]:
# Available GEMS bands; SELECT_BAND is the one used for the cutout file names.
SELECT_VBAND, SELECT_ZBAND = "v", "z"
SELECT_BAND = SELECT_ZBAND

#### define the cutoutfilename path

In [None]:
# Relative paths of the cutout image and of its weight map for the selected target and band.
cutout_inputpath = f"cutout_gems_{key}_b{SELECT_BAND}.fits"
cutout_inputpath_wt = f"cutout_gems_{key}_b{SELECT_BAND}_wt.fits"

### Load image and weight


🌈 Liste des stretch disponibles dans simple_norm
Tu peux utiliser ces chaînes de caractères comme valeur de stretch :
Valeur stretch	Description
- "linear"	Échelle linéaire simple (valeurs brutes entre vmin et vmax)
- "sqrt"	Échelle racine carrée – met en valeur les faibles intensités
- "power"	Échelle puissance (x^power)
- "log"	Échelle logarithmique – bon pour étendre la dynamique
- "asinh"	Échelle arcsinus hyperbolique – lisse, bien adaptée aux images astronomiques
- "sinh"	Sinus hyperbolique – moins courant
- "histeq"	Égalisation d'histogramme – utile pour répartir les contrastes

In [None]:
# Automatic display cuts with ZScale (kept for reference; all lines are commented out)
#interval = ZScaleInterval()
#interval = PercentileInterval(98) 
#vmin_v, vmax_v = interval.get_limits(data_v)

# Normalization with an asinh stretch
#norm_v = simple_norm(data_v, stretch='asinh', min_cut=vmin_v, max_cut=vmax_v)
#norm_v = simple_norm(data_v, 'asinh')

In [None]:
def getsimple_norm(image):
    """Return an asinh display normalisation of ``image`` with ZScale cuts.

    Parameters
    ----------
    image : array-like
        2-D image to normalise.

    Returns
    -------
    The normalisation built by ``simple_norm``, to pass as ``imshow(..., norm=...)``.
    """
    vmin, vmax = ZScaleInterval().get_limits(image)
    # Build the norm from the argument itself so it matches the image being shown,
    # not whatever module-level array happens to be named ``data``.
    norm = simple_norm(image, stretch='asinh', vmin=vmin, vmax=vmax)
    return norm

In [None]:
def load_fits_image_with_norm(path):
    """Read the primary HDU of a FITS file and prepare it for display.

    Parameters
    ----------
    path : str
        Path of the FITS file.

    Returns
    -------
    tuple
        ``(data, wcs, norm, header)``: the primary image array, its ``WCS``,
        an asinh normalisation with ZScale cuts, and the primary header.
    """
    with fits.open(path) as hdul:
        primary = hdul[0]
        header = primary.header
        wcs = WCS(header)
        data = primary.data
    low, high = ZScaleInterval().get_limits(data)
    norm = simple_norm(data, stretch='asinh', vmin=low, vmax=high)
    return data, wcs, norm, header

In [None]:
def add_EastNorthFrame(x0, y0, ax, width=30, dwidth=5, color='white'):
    """Draw North and East compass arrows on an image axis.

    Both arrows start at (x0, y0) in the data coordinates of ``ax``: North points
    towards +y and East towards -x, i.e. the usual sky orientation with Dec
    increasing upwards and RA increasing to the left.
    NOTE(review): the arrows do not follow any WCS rotation; on a rotated image
    they only approximate the true N/E directions.

    Parameters
    ----------
    x0, y0 : float
        Common base point of the arrows, in data (pixel) coordinates.
    ax : matplotlib Axes
        Axis to draw on.
    width : float
        Arrow length in data units.
    dwidth : float
        Gap between an arrow tip and its letter, in data units.
    color : str
        Colour of arrows and letters (default 'white').
    """
    # North arrow, with its letter just above the tip.
    ax.annotate('', xy=(x0, y0 + width), xytext=(x0, y0),
                arrowprops=dict(facecolor=color, width=3, headwidth=10))
    ax.text(x0, y0 + (width + dwidth), 'N', color=color, ha='center', va='bottom',
            fontsize=14, fontweight="bold")

    # East arrow, with its letter just left of the tip. ha='right' anchors the
    # letter's right edge there, so it sits beyond the tip (mirroring va='bottom'
    # for 'N'); ha='left' made the letter extend back over the arrow head.
    ax.annotate('', xy=(x0 - width, y0), xytext=(x0, y0),
                arrowprops=dict(facecolor=color, width=3, headwidth=10))
    ax.text(x0 - (width + dwidth), y0, 'E', color=color, ha='right', va='center',
            fontsize=14, fontweight="bold")
    

In [None]:
# Load the cutout image and its weight map (data, WCS, display norm, header),
# then show them side by side on WCS (RA/Dec) axes.

data, wcs, norm,header = load_fits_image_with_norm(cutout_inputpath)
data_wt, wcs_wt, norm_wt,header_wt = load_fits_image_with_norm(cutout_inputpath_wt)


fig = plt.figure(figsize=(12, 6))


# Left panel: flux image.
ax1 = fig.add_subplot(1, 2, 1, projection=wcs)
im1 = ax1.imshow(data, origin='lower', cmap='gray', norm=norm)
ax1.set_title(f'Flux {key} {SELECT_BAND} band')
ax1.set_xlabel('RA')
ax1.set_ylabel('Dec')
ax1.grid(color="blue",linestyle=":")
add_EastNorthFrame(50,150,ax1)



# Right panel: weight map.
ax2 = fig.add_subplot(1, 2, 2, projection=wcs_wt)
im2 = ax2.imshow(data_wt, origin='lower', cmap='gray', norm=norm_wt)
ax2.set_title(f'Weight {key} {SELECT_BAND}  band')
ax2.set_xlabel('RA')
ax2.set_ylabel('')
ax2.grid(color="blue",linestyle=":")
add_EastNorthFrame(50,150,ax2)

# Layout
plt.tight_layout()
plt.show()


## Study peaks 

The goal is to locate the possible images of the lensed source in the image.

#### Work on WCS

In [None]:
from astropy.wcs.utils import proj_plane_pixel_scales

In [None]:
# Angular size of a pixel along each image axis, from the WCS.
pixscale_deg = proj_plane_pixel_scales(wcs)   # [deg/pixel] along X and Y
pixscale_arcsec = 3600 * pixscale_deg          # [arcsec/pixel]
print(f"Pixel scale : {pixscale_arcsec[0]:.5f} x {pixscale_arcsec[1]:.5f} arcsec/pixel")
# Single pixel scale in arcsec: mean over the two axes.
pixel_scale = np.mean(pixscale_arcsec)

🧠 Détails techniques :
La matrice CD (coordinate description) est une matrice 2×2 qui donne la transformation linéaire (rotation, échelle, cisaillement) entre les pixels et les coordonnées en degrés (RA/Dec, etc).
Si l’en-tête FITS contient seulement PCi_j et CDELTi, la matrice CD est implicitement obtenue par :
$ CD_{ij} = CDELT_i \, PC_{ij} $
Astropy fournit la propriété `WCS.pixel_scale_matrix` (diag(CDELT) · PC) pour obtenir la matrice CD dans tous les cas ; `wcs.wcs.has_cd()` indique si l’en-tête contient explicitement les CDi_j.

In [None]:
# Linear pixel -> sky-angle transform (the CD matrix, degrees per pixel).
# Wcsprm.cd is only defined when the header carries CDi_j keywords (checked with
# has_cd()); accessing it otherwise raises instead of returning None. In that case
# WCS.pixel_scale_matrix composes the same matrix from CDELTi and PCi_j.
if wcs.wcs.has_cd():
    cd = wcs.wcs.cd
else:
    cd = wcs.pixel_scale_matrix

In [None]:
# Show the CD matrix (degrees per pixel) below its caption.
print("Matrice CD (de pixels vers angles en degrÃ©s):", cd, sep="\n")

In [None]:
# CD matrix converted to arcsec per pixel: the pixel -> angle transform
# passed to lenstronomy as `transform_pix2angle`.
transform_pix2angle = 3600. * cd
print("Matrice CD (de pixels vers angles en arcsec):", transform_pix2angle, sep="\n")

In [None]:
# Convert pixel (0, 0) to sky coordinates (RA, Dec) in degrees
ra_at_xy_0, dec_at_xy_0 = wcs.wcs_pix2world(0, 0, 0)  # 3rd argument = 0 -> 0-based pixel origin, as in Python

In [None]:
# Offset (arcsec) of pixel (0, 0) from the target position: origin of the lenstronomy grid.
# NOTE(review): the RA difference is not multiplied by cos(target_dec), whereas
# transform_pix2angle (CD * 3600) gives true angular offsets per the FITS WCS
# convention; at Dec ~ -28 deg RA offsets are stretched by ~13%. A fix must be
# applied consistently to every delta_ra_* computation in this notebook.
delta_ra_at_xy_0 = (ra_at_xy_0 - target_ra)*3600
delta_dec_at_xy_0 = (dec_at_xy_0 - target_dec)*3600
print(f"relative position of corner x=0 y=0 is delta_ra = {delta_ra_at_xy_0:.3f} arcsec, delta_dec = {delta_dec_at_xy_0:.3f} arcsec ")

### Image smoothing

In [None]:
# Gaussian smoothing width derived from the expected PSF FWHM.
fwhm_arcsec = 0.11 # arcsec
fwhm_pixel = fwhm_arcsec / pixel_scale
# FWHM -> sigma (2.36 ~ 2*sqrt(2 ln 2)), rounded up to a whole number of pixels.
sigma_pixel_smoothing = int(np.ceil(fwhm_pixel / 2.36))
print(f"sigma_pixel_smoothing = {sigma_pixel_smoothing}")

In [None]:
from scipy.ndimage import gaussian_filter
# Smooth the image to suppress pixel noise before searching for local maxima.
smoothed_image = gaussian_filter(data, sigma= sigma_pixel_smoothing )
print(f"smoothed image shape : {smoothed_image.shape}")

In [None]:
from skimage.feature import peak_local_max

# Find isolated local maxima (min_distance = 10 pixels) in the smoothed image.
# Each row of `coordinates` is a (row, col) = (y, x) index pair.
coordinates = peak_local_max(smoothed_image, min_distance=10)

In [None]:
# Smoothed-image values at the detected maxima
values = smoothed_image[coordinates[:, 0], coordinates[:, 1]]

In [None]:
# Combine the peaks into an (x, y, value) table; peak indices come as (row, col) = (y, x).
y, x = coordinates[:, 0], coordinates[:, 1]
results = np.column_stack([x, y, values])  # shape: (N, 3)

In [None]:
# Peaks ordered from the highest to the lowest smoothed value.
df_maxima = pd.DataFrame(results, columns=["x", "y", "max"]).sort_values(by="max", ascending=False)

In [None]:
# 2x2 panels: raw (left) and smoothed (right) flux with NCONT contour levels;
# top row on WCS (RA/Dec) axes, bottom row in pixel coordinates.
fig = plt.figure(figsize=(16, 16))

NCONT = 50

ax1 = fig.add_subplot(2, 2, 1, projection=wcs)
im1 = ax1.imshow(data, origin='lower', cmap='jet', norm=norm)
ax1.set_title(f'Flux {key} {SELECT_BAND} band')
ax1.set_xlabel('RA')
ax1.set_ylabel('Dec')
ax1.grid(color="blue",linestyle=":")
ax1.contour(data, levels=NCONT, colors='white', linewidths=1)
add_EastNorthFrame(40,155,ax1)



ax2 = fig.add_subplot(2, 2, 2, projection=wcs)
im2 = ax2.imshow(smoothed_image, origin='lower', cmap='jet', norm=norm)
ax2.set_title(f'Smoothed flux {key} {SELECT_BAND}  band')
ax2.set_xlabel('RA')
ax2.set_ylabel('')
ax2.grid(color="blue",linestyle=":")
ax2.contour(smoothed_image, levels=NCONT, colors='white', linewidths=1)
add_EastNorthFrame(40,155,ax2)


ax3 = fig.add_subplot(2, 2, 3)
im3 = ax3.imshow(data, origin='lower', cmap='jet', norm=norm)
ax3.set_title(f'Flux {key} {SELECT_BAND} band')
ax3.set_xlabel('X pix')
ax3.set_ylabel('Y pix')
ax3.grid(color="blue",linestyle=":")
ax3.contour(data, levels=NCONT, colors='white', linewidths=1)
add_EastNorthFrame(40,155,ax3)



ax4 = fig.add_subplot(2, 2, 4)
im4 = ax4.imshow(smoothed_image, origin='lower', cmap='jet', norm=norm)
#im4 = ax4.imshow(smoothed_image, origin='lower', cmap='jet')
ax4.set_title(f'Smoothed flux {key} {SELECT_BAND}  band')
ax4.set_xlabel('X pix')
ax4.grid(color="blue",linestyle=":")
ax4.contour(smoothed_image, levels=NCONT, colors='white', linewidths=1)
add_EastNorthFrame(40,155,ax4)


# Layout
plt.tight_layout()
plt.show()


In [None]:
# List of maxima, highest smoothed value first
df_maxima

#### Most interesting maxima
- central maximum at (x,y) = (93, 105)
- One local maximum at (x,y) = (114,46)
- One local maximum at (x,y) = (146,126)

In [None]:
# Integer pixel position of the highest maximum (the central maximum).
xpix0, ypix0 = map(int, df_maxima.iloc[0][["x", "y"]])
print(f"(x0,y0) = ({xpix0},{ypix0})")

In [None]:
# Sky position of the central maximum and its offset (arcsec) from the target position.
# NOTE(review): the RA offset is not multiplied by cos(target_dec), unlike the
# angular offsets given by the WCS CD matrix (~13% stretch at Dec ~ -28 deg);
# fix consistently in every delta_ra_* computation.
ra_at_xy_centralmax, dec_at_xy_centralmax = wcs.wcs_pix2world(xpix0,ypix0, 0)  # 3rd argument = 0 -> 0-based pixel origin, as in Python
delta_ra_at_xy_centralmax = (ra_at_xy_centralmax - target_ra)*3600
delta_dec_at_xy_centralmax = (dec_at_xy_centralmax - target_dec)*3600
print(f"relative position of centralmax x={xpix0} y={ypix0} is delta_ra = {delta_ra_at_xy_centralmax:.3f} arcsec, delta_dec = {delta_dec_at_xy_centralmax:.3f} arcsec ")

In [None]:
# Integer pixel position of the second-highest maximum (candidate source image 1).
xpix1, ypix1 = map(int, df_maxima.iloc[1][["x", "y"]])
print(f"(x1,y1) = ({xpix1},{ypix1})")

In [None]:
# Sky position of candidate source image 1 and its offset (arcsec) from the target position.
# NOTE(review): the RA offset is not multiplied by cos(target_dec), unlike the
# angular offsets given by the WCS CD matrix (~13% stretch at Dec ~ -28 deg);
# fix consistently in every delta_ra_* computation.
ra_at_xy_s1, dec_at_xy_s1 = wcs.wcs_pix2world(xpix1, ypix1, 0)  # 3rd argument = 0 -> 0-based pixel origin, as in Python
delta_ra_at_xy_s1 = (ra_at_xy_s1 - target_ra)*3600
delta_dec_at_xy_s1 = (dec_at_xy_s1 - target_dec)*3600
print(f"relative position of source1  x={xpix1} y={ypix1} is delta_ra = {delta_ra_at_xy_s1:.3f} arcsec, delta_dec = {delta_dec_at_xy_s1:.3f} arcsec ")

In [None]:
# Integer pixel position of the third-highest maximum (candidate source image 2).
xpix2, ypix2 = map(int, df_maxima.iloc[2][["x", "y"]])
print(f"(x2,y2) = ({xpix2},{ypix2})")

In [None]:
# Sky position of candidate source image 2 and its offset (arcsec) from the target position.
# NOTE(review): the RA offset is not multiplied by cos(target_dec), unlike the
# angular offsets given by the WCS CD matrix (~13% stretch at Dec ~ -28 deg);
# fix consistently in every delta_ra_* computation.
ra_at_xy_s2, dec_at_xy_s2 = wcs.wcs_pix2world(xpix2,ypix2, 0)  # 3rd argument = 0 -> 0-based pixel origin, as in Python
delta_ra_at_xy_s2 = (ra_at_xy_s2 - target_ra)*3600
delta_dec_at_xy_s2 = (dec_at_xy_s2 - target_dec)*3600
# Report the source-2 offsets computed just above.
print(f"relative position of source2 x={xpix2} y={ypix2} is delta_ra = {delta_ra_at_xy_s2:.3f} arcsec, delta_dec = {delta_dec_at_xy_s2:.3f} arcsec ")

#### loop on images of sources
- Extract coordinates in ra,dec of images of the source

In [None]:
# Offsets (arcsec) from the target position of the Nsource_images maxima that follow
# the central one (rows 1..Nsource_images of df_maxima), taken as source images.
# NOTE(review): the RA offset is not multiplied by cos(target_dec), unlike the
# angular offsets given by the WCS CD matrix; fix consistently with the other cells.
# Nsource_images = 2 or 4
Nsource_images = 2
all_delta_ra_at_xy_s = np.zeros(Nsource_images)
all_delta_dec_at_xy_s = np.zeros(Nsource_images)
for idx in range(Nsource_images):
    # idx+1 skips row 0, the central maximum
    xpix = df_maxima.iloc[idx+1]['x']
    ypix = df_maxima.iloc[idx+1]['y']
    ra_at_xy_s, dec_at_xy_s = wcs.wcs_pix2world(xpix, ypix, 0) 

    # offset of the image relative to the target (catalogue) position, in arcsec
    delta_ra_at_xy_s = (ra_at_xy_s - target_ra)*3600
    delta_dec_at_xy_s = (dec_at_xy_s - target_dec)*3600

    all_delta_ra_at_xy_s[idx] = delta_ra_at_xy_s
    all_delta_dec_at_xy_s[idx] = delta_dec_at_xy_s

In [None]:
# RA offsets (arcsec) of the candidate source images
all_delta_ra_at_xy_s

In [None]:
# Dec offsets (arcsec) of the candidate source images
all_delta_dec_at_xy_s

## Header

### Find total exposure time

In [None]:
#header

In [None]:
# Total exposure time read from the FITS header keyword EXPTIME (seconds)
exp_time = header["EXPTIME"]
print(f"Total exposure time of the image {exp_time} seconds")

## Plot sigma

In [None]:
# Per-pixel sigma vs flux, treating the weight map as inverse variance
# (sigma = 1/sqrt(weight)), and the histogram of sigma with the cut used below.
# NOTE(review): zero-weight pixels (if any) give sigma = inf, which breaks the
# automatic histogram range.
X = data.flatten()
Y = (1./np.sqrt(data_wt)).flatten()
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10, 4))
ax1.scatter(X, Y,marker="o",color="b",alpha=0.1)
ax1.set_yscale("log")
ax1.set_xscale("log")
ax1.set_xlim(0.00001,1)
ax1.set_ylim(0.028,0.06)
# Raw strings: "\s" is not a valid Python escape sequence (DeprecationWarning,
# SyntaxWarning from Python 3.12); the resulting text is unchanged.
title = rf'{key} {SELECT_BAND} band $\sigma_F$ vs Flux'
ax1.set_title(title)
ax1.set_xlabel("Flux")
ax1.set_ylabel(r"$\sigma_{Flux}$")

ax2.hist(Y,bins=50,histtype="step",color="b",lw=3)
ax2.set_title("standard deviation")
ax2.set_xlabel(r"$\sigma_{Flux}$")
# Pixels with sigma above this value form the "high sigma tail" studied below.
cut_sigma = 0.042
ax2.axvline(cut_sigma,color="k",ls=":")

### In high sigma tail

In [None]:
# Mask of pixels in the high-sigma tail (sigma above cut_sigma)
pixel_select = Y > cut_sigma

In [None]:
# Same sigma-vs-flux plot and sigma histogram, restricted to the high-sigma tail.
X1 = X[pixel_select]
Y1 = Y[pixel_select]
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10, 4))
ax1.scatter(X1, Y1,marker="o",color="b",alpha=0.1)
ax1.set_yscale("log")
ax1.set_xscale("log")
ax1.set_xlim(0.00001,1)
#ax1.set_ylim(cut_sigma,0.06)
# Raw strings: "\s" is not a valid Python escape sequence; the text is unchanged.
title = rf'{key} {SELECT_BAND} band $\sigma_F$ vs Flux'
ax1.set_title(title)
ax1.set_xlabel("Flux")
ax1.set_ylabel(r"$\sigma_{Flux}$")

ax2.hist(Y1,bins=50,histtype="step",color="b",lw=3)
ax2.set_title("standard deviation")
ax2.set_xlabel(r"$\sigma_{Flux}$")

In [None]:
# Flux image and per-pixel sigma (weight map treated as inverse variance)
F = data
sig = (1./np.sqrt(data_wt))

In [None]:
# Keep only high-sigma pixels; all other pixels are set to 0
Fsel = np.where(sig>cut_sigma, F,0.0)

In [None]:
# Full flux image (left) next to the high-sigma pixels only (right), on WCS axes.
fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(1, 2, 1, projection=wcs)
im1 = ax1.imshow(data, origin='lower', cmap='gray', norm=norm)
ax1.set_title(f'Flux {key} {SELECT_BAND} band')
ax1.set_xlabel('RA')
ax1.set_ylabel('Dec')
ax1.grid(color="blue",linestyle=":")
add_EastNorthFrame(40,155,ax1)

ax2 = fig.add_subplot(1, 2, 2, projection=wcs_wt)
im2 = ax2.imshow(Fsel, origin='lower', cmap='gray', norm=norm)
ax2.set_title(f'Flux sel {key} {SELECT_BAND}  band')
ax2.set_xlabel('RA')
ax2.set_ylabel('')
add_EastNorthFrame(40,155,ax2)

# Layout
plt.tight_layout()
plt.show()

### Remove high sigma tail

In [None]:
# Mask of pixels outside the high-sigma tail (sigma below cut_sigma)
pixel_select = Y < cut_sigma

In [None]:
# Same sigma-vs-flux plot and sigma histogram, with the high-sigma tail removed.
X1 = X[pixel_select]
Y1 = Y[pixel_select]
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10, 4))
ax1.scatter(X1, Y1,marker="o",color="b",alpha=0.1)
ax1.set_yscale("log")
ax1.set_xscale("log")
ax1.set_xlim(0.00001,1)
#ax1.set_ylim(cut_sigma,0.06)
# Raw strings: "\s" is not a valid Python escape sequence; the text is unchanged.
title = rf'{key} {SELECT_BAND} band $\sigma_F$ vs Flux'
ax1.set_title(title)
ax1.set_xlabel("Flux")
ax1.set_ylabel(r"$\sigma_{Flux}$")

ax2.hist(Y1,bins=50,histtype="step",color="b",lw=3)
ax2.set_title("standard deviation")
ax2.set_xlabel(r"$\sigma_{Flux}$")

### Check removed high sigma tail

In [None]:
# Flux image and per-pixel sigma (weight map treated as inverse variance)
F = data
sig = (1./np.sqrt(data_wt))

In [None]:
# Keep only pixels below the sigma cut; high-sigma pixels are set to 0
Fsel = np.where(sig<cut_sigma, F,0.0)

In [None]:
# Full flux image (left) next to the image without high-sigma pixels (right), on WCS axes.
fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(1, 2, 1, projection=wcs)
im1 = ax1.imshow(data, origin='lower', cmap='gray', norm=norm)
ax1.set_title(f'Flux {key} {SELECT_BAND} band')
ax1.set_xlabel('RA')
ax1.set_ylabel('Dec')
ax1.grid(color="blue",linestyle=":")
add_EastNorthFrame(40,155,ax1)


ax2 = fig.add_subplot(1, 2, 2, projection=wcs_wt)
im2 = ax2.imshow(Fsel, origin='lower', cmap='gray', norm=norm)
ax2.set_title(f'Flux sel {key} {SELECT_BAND}  band')
ax2.set_xlabel('RA')
ax2.set_ylabel('')
ax2.grid(color="blue",linestyle=":")
add_EastNorthFrame(40,155,ax2)


# Layout
plt.tight_layout()
plt.show()

In [None]:
# Same comparison as the previous figure, in plain pixel coordinates.
fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(1, 2, 1)
im1 = ax1.imshow(data, origin='lower', cmap='gray', norm=norm)
ax1.set_title(f'Flux {key} {SELECT_BAND} band')
ax1.set_xlabel('pixels (along RA)')
ax1.set_ylabel('pixels (along Dec)')
ax1.grid(color="blue",linestyle=":")
add_EastNorthFrame(40,155,ax1)


ax2 = fig.add_subplot(1, 2, 2)
im2 = ax2.imshow(Fsel, origin='lower', cmap='gray', norm=norm)
ax2.set_title(f'Flux sel {key} {SELECT_BAND}  band')
ax2.set_xlabel('pixels (along RA)')
ax2.set_ylabel('')
ax2.grid(color="blue",linestyle=":")
add_EastNorthFrame(40,155,ax2)


# Layout
plt.tight_layout()
plt.show()

#### Remark about coordinate (ra,dec)
- RA increase toward left
- Dec increase toward up

### LensModel module

$\texttt{LensModel}$ and its sub-packages execute all the purely lensing-related tasks of *lenstronomy*. This includes ray-shooting, solving the lens equation, arrival time computation and non-linear solvers to optimize lens models for specific image configurations. The module allows consistent integration with single and multi-plane lensing and arbitrary superpositions of lens models. There is a wide range of lens models available. For details we refer to the online documentation.

### Single plane lensing
As an example, we initialize a single-plane lens model as a superposition of an elliptical power-law (EPL) mass distribution and an external shear, and execute some basic routines:

In [None]:
# Sky coordinates (deg) of the central maximum
ra_at_xy_centralmax, dec_at_xy_centralmax

In [None]:
# import the LensModel class #
# NOTE(review): this lenstronomy import rebinds the name LensModel that was
# imported from jaxtronomy at the top of the notebook.
from lenstronomy.LensModel.lens_model import LensModel

# specify the choice of lens models #
lens_model_list = ['EPL', 'SHEAR']

# setup lens model class with the list of lens models #
lensModel = LensModel(lens_model_list=lens_model_list)

# define parameter values of lens models #
# Initial guess: EPL with slope gamma = 2 centred on the central maximum, plus an external shear.
#kwargs_spep = {'theta_E': 1.1, 'e1': 0.1, 'e2': 0.1, 'gamma': 2., 'center_x': 0.0, 'center_y': 0}
kwargs_spep = {'theta_E': 1.1, 'e1': 0.1, 'e2': 0.2, 'gamma': 2., 'center_x': delta_ra_at_xy_centralmax , 'center_y': delta_dec_at_xy_centralmax}
kwargs_shear = {'gamma1': -0.01, 'gamma2': .03}
kwargs_lens = [kwargs_spep, kwargs_shear]


# image plane coordinate #
# Initial guess of one image position: candidate source image 1
# theta_ra, theta_dec = target_ra, target_dec
#theta_ra, theta_dec = -1.3, .0
# Note this does not work later if I choose s2 !
theta_ra, theta_dec  = delta_ra_at_xy_s1, delta_dec_at_xy_s1

# source plane coordinate #
beta_ra, beta_dec = lensModel.ray_shooting(theta_ra, theta_dec, kwargs_lens)
# Fermat potential #
fermat_pot = lensModel.fermat_potential(x_image=theta_ra, y_image=theta_dec, x_source=beta_ra, y_source=beta_dec, kwargs_lens=kwargs_lens)

# Magnification #
mag = lensModel.magnification(theta_ra, theta_dec, kwargs_lens)

In [None]:
# Source-plane position obtained by ray-shooting the guessed image position
print(f"angle-beta : ra = {beta_ra:.3f} arcsec, dec = {beta_dec:.3f} arcsec")

- Starting from an initial guess of one image position, ray-shooting it back to the source plane gives the inferred source position; from it, the positions of the other images can be found.

In [None]:
# Show the documentation of lens_model_plot. lens_plot is imported here because,
# run in order, this cell comes before the cell that imports it.
from lenstronomy.Plots import lens_plot
help(lens_plot.lens_model_plot)

In [None]:
from lenstronomy.Plots import lens_plot
# Lens model with critical curves/caustics and the images of a point source at
# (beta_ra, beta_dec); coord_inverse=True draws the x axis right-to-left, as for RA.
f, axex = plt.subplots(1, 1, figsize=(5, 5), sharex=False, sharey=False)
lens_plot.lens_model_plot(axex, lensModel=lensModel, kwargs_lens=kwargs_lens, sourcePos_x=beta_ra, sourcePos_y=beta_dec, point_source=True, with_caustics=True, fast_caustic=True, coord_inverse=True)
#lens_plot.lens_model_plot(axex[1], lensModel=lensModel_mp, kwargs_lens=kwargs_lens_mp, sourcePos_x=beta_ra, sourcePos_y=beta_dec, point_source=True, with_caustics=True, fast_caustic=True, coord_inverse=False)
axex.grid()
# plt.show(), as in the other cells: Figure.show() only emits a warning on
# non-GUI backends such as the notebook inline backend.
plt.show()

In [None]:
# keep the imports and variables from above #
# import the lens equation solver class #
from lenstronomy.LensModel.Solver.lens_equation_solver import LensEquationSolver

# specify the lens model class to deal with #
solver = LensEquationSolver(lensModel)

# solve for image positions provided a lens model and the source position #
# (replaces the single guessed image position in theta_ra / theta_dec)
theta_ra, theta_dec = solver.image_position_from_source(beta_ra, beta_dec, kwargs_lens)

# the magnification of the point source images #
mag = lensModel.magnification(theta_ra, theta_dec, kwargs_lens)

- The solver then returns the arrays of image positions (theta_ra, theta_dec).

In [None]:
# Image positions predicted by the lens equation solver
print(f"theta_ra = {theta_ra} (arcsec), theta_dec = {theta_dec} (arcsec)")

#### $\texttt{LightModel}$ module

The $\texttt{LightModel}$ class provides the functionality to describe galaxy surface brightnesses. $\texttt{LightModel}$ supports various analytic profiles as well as representations in shapelet basis sets. Any superposition of different profiles is allowed.

The parameter setting the amplitude of the surface brightness of a given profile is named 'amp'. Its units are not further specified; it effectively reflects a surface-brightness quantity integrated over the unit of angle squared. In the $\texttt{SimulationAPI}$ module, the user can conveniently choose astronomical magnitudes as inputs to the profiles, provided the magnitude zero point is declared. Have a look at the specific notebook to see a demonstration.

In [None]:
# import the LightModel class #
# NOTE(review): this lenstronomy import rebinds the name LightModel that was
# imported from jaxtronomy at the top of the notebook.
from lenstronomy.LightModel.light_model import LightModel
# set up the list of light models to be used #
source_light_model_list = ['SERSIC']
lightModel_source = LightModel(light_model_list=source_light_model_list)
lens_light_model_list = ['SERSIC_ELLIPSE']
lightModel_lens = LightModel(light_model_list=lens_light_model_list)
# define the parameters #
# Source light is centred on the ray-shot source position (beta_ra, beta_dec).
kwargs_light_source = [{'amp': 100, 'R_sersic': 0.1, 'n_sersic': 1.5, 'center_x': beta_ra, 'center_y': beta_dec}]
import lenstronomy.Util.param_util as param_util
# Ellipticity components from orientation angle phi and axis ratio q.
e1, e2 = param_util.phi_q2_ellipticity(phi=0.5, q=0.7)
# Lens light is centred on the central maximum, like the lens mass.
kwargs_light_lens = [{'amp': 1000, 'R_sersic': 0.1, 'n_sersic': 2.5, 'e1': e1, 'e2': e2, 'center_x': delta_ra_at_xy_centralmax, 'center_y': delta_dec_at_xy_centralmax}]

# evaluate surface brightness at a specific position #
flux = lightModel_lens.surface_brightness(x=1, y=1, kwargs_list=kwargs_light_lens)

### $\texttt{PointSource}$ module
To accurately predict and model the position and flux of point sources, one has to apply different numerical procedures than for extended surface brightness features. The $\texttt{PointSource}$ class manages the different options in describing point sources (e.g. in the image plane or source plane, with fixed magnification or allowed with individual variations thereof) and provides a homogeneous interface to access image positions and magnifications.

In [None]:
# import the PointSource class #
from lenstronomy.PointSource.point_source import PointSource

# unlensed source position #
point_source_model_list = ['SOURCE_POSITION']
pointSource = PointSource(point_source_type_list=point_source_model_list, lens_model=lensModel, fixed_magnification_list=[True])
kwargs_ps = [{'ra_source': beta_ra, 'dec_source': beta_dec, 'source_amp': 100}]
# return image positions and amplitudes #
x_pos, y_pos = pointSource.image_position(kwargs_ps=kwargs_ps, kwargs_lens=kwargs_lens)
point_amp = pointSource.image_amplitude(kwargs_ps=kwargs_ps, kwargs_lens=kwargs_lens)

# lensed image positions (solution of the lens equation) #
# This second description overwrites pointSource, kwargs_ps, x_pos, y_pos and
# point_amp; later cells see the LENSED_POSITION version.
point_source_model_list = ['LENSED_POSITION']
pointSource = PointSource(point_source_type_list=point_source_model_list, lens_model=lensModel, fixed_magnification_list=[False])
kwargs_ps = [{'ra_image': theta_ra, 'dec_image': theta_dec, 'point_amp': np.abs(mag)*30}]
# return image positions and amplitudes #
x_pos, y_pos = pointSource.image_position(kwargs_ps=kwargs_ps, kwargs_lens=kwargs_lens)
point_amp = pointSource.image_amplitude(kwargs_ps=kwargs_ps, kwargs_lens=kwargs_lens)

### Make the Data module

In [None]:
# Image size in pixels (numpy shape is (rows, cols) = (y, x))
Npix_y,Npix_x = data.shape

In [None]:
# Per-pixel 1-sigma noise, treating the weight map as inverse variance
noise_map = 1./np.sqrt(data_wt)

#### Make the PSF

La caméra ACS/WFC (Wide Field Channel) du Hubble Space Telescope a une échelle angulaire (pixel scale) de :

- ACS/WFC : ~0.0495 arcsec/pixel
- ACS/HRC (High Resolution Channel, maintenant inactive) : ~0.028 arcsec/pixel

In [None]:
from lenstronomy.Data.psf import PSF
PSF_FWHM = 0.11 # arcsec
# 1. kwargs for the PSF
# NOTE(review): kwargs_psf_gaussian is not used in the visible part of the
# notebook; psf_class below is built directly from PSF_FWHM and pixel_scale.
kwargs_psf_gaussian = {
    'psf_type': 'GAUSSIAN',
    'fwhm': PSF_FWHM,
    'pixel_size': pixel_scale
}

In [None]:
# 2. Define the PSF as an approximate Gaussian

#psf_class = PSF(psf_type='PIXEL', kernel_point_source=psf_kernel)
# Typical z-band FWHM (ACS): 0.11 arcsec

psf_class = PSF(psf_type='GAUSSIAN', fwhm=PSF_FWHM, pixel_size=pixel_scale)

In [None]:
# Pixelised PSF kernel, kept for inspection
kernel = psf_class.kernel_point_source

#### Make the data (True observations)

In [None]:
from lenstronomy.Data.imaging_data import ImageData

In [None]:
# Observed image with its per-pixel noise map, on a grid whose origin is the
# offset (arcsec) of pixel (0, 0) from the target and whose axes follow the
# WCS-derived transform_pix2angle (arcsec/pixel).
kwargs_data = {
    'image_data': data,
    'noise_map': noise_map,
    # Single 'exposure_time' entry: a duplicate 'exposure_time': None key was
    # silently overridden by this value anyway.
    'exposure_time': exp_time,
    'background_rms': None,
    'ra_at_xy_0': delta_ra_at_xy_0,
    'dec_at_xy_0': delta_dec_at_xy_0,
    'transform_pix2angle': transform_pix2angle
}

In [None]:
# 3. Define the ImageData object
# NOTE(review): ImageData here is lenstronomy's class (imported in the previous
# cell), which rebinds the jaxtronomy ImageData imported at the top.

# --- Create the ImageData object ---
data_class = ImageData(**kwargs_data)

### Check the Forward model of the predicted image

#### Pixel grid

In [None]:
# import the PixelGrid() class #
from lenstronomy.Data.pixel_grid import PixelGrid
deltaPix = pixel_scale
# Pixel -> angle transform of the simulation grid: reuse the observed image's
# WCS-derived matrix (CD * 3600, arcsec/pixel) so that images simulated on this
# grid share the orientation and per-axis scale of the data they are compared
# with. The idealised np.array([[-1, 0], [0, 1]]) * deltaPix ignores any WCS
# rotation. A copy avoids aliasing transform_pix2angle.
transform_pix2anglepixgrid = np.array(transform_pix2angle, dtype=float)
kwargs_pixel = {'nx': Npix_x, 'ny': Npix_y,  # number of pixels per axis
                'ra_at_xy_0': delta_ra_at_xy_0,  # RA at pixel (0,0)
                'dec_at_xy_0': delta_dec_at_xy_0,  # DEC at pixel (0,0)
                'transform_pix2angle': transform_pix2anglepixgrid}
pixel_grid = PixelGrid(**kwargs_pixel)
# return the list of pixel coordinates #
x_coords, y_coords = pixel_grid.pixel_coordinates
# compute pixel value of a coordinate position #
x_pos, y_pos = pixel_grid.map_coord2pix(ra=0, dec=0)
# compute the coordinate value of a pixel position #
ra_pos, dec_pos = pixel_grid.map_pix2coord(x=20, y=10)

### $\texttt{ImSim}$ module
The $\texttt{ImSim}$ module simulates images. At its core is the $\texttt{ImageModel}$ class. It is the interface to combine all the different components, $\texttt{LensModel}$, $\texttt{LightModel}$, $\texttt{PointSource}$ and $\texttt{Data}$ to model images. The $\texttt{LightModel}$ can be used to model lens light (unlensed) or source light (lensed). $\texttt{ImSim}$ fully supports all functionalities of each component. Additionally, numerical precision arguments can be set to control how the image is modelled.

#### Simulate image
We simulate an image with an instance of $\texttt{ImageModel}$ that got instances of the classes we created above. We can define two different $\texttt{LightModel}$ instances for the lens and source light. Additionally, we can define the sub-pixel ray-tracing resolution and whether the PSF convolution is applied on the higher resolution ray-tracing image or on the degraded pixel image.

In this example, we do not simulate point sources. You can look at the notebooks dedicated to modelling quads and doubles, and also to time-delay cosmography.

In [None]:
# import the ImageModel class #
# NOTE(review): this lenstronomy import rebinds the jaxtronomy ImageModel imported at the top.
from lenstronomy.ImSim.image_model import ImageModel
# define the numerics #
kwargs_numerics = {'supersampling_factor': 1, # each pixel gets super-sampled (in each axis direction) 
                  'supersampling_convolution': False}

# initialize the Image model class by combining the modules we created above #
imageModel = ImageModel(data_class=pixel_grid, psf_class=psf_class, lens_model_class=lensModel,
                        source_model_class=lightModel_source,
                        lens_light_model_class=lightModel_lens,
                        point_source_class=None, # in this example, we do not simulate point source.
                        kwargs_numerics=kwargs_numerics)
# simulate image with the parameters we have defined above #
# (kwargs_ps is passed but, with point_source_class=None, presumably not used)
image = imageModel.image(kwargs_lens=kwargs_lens, kwargs_source=kwargs_light_source,
                         kwargs_lens_light=kwargs_light_lens, kwargs_ps=kwargs_ps)

# we can also add noise #
import lenstronomy.Util.image_util as image_util
# exp_time here is the header EXPTIME value, not the commented-out 100
#exp_time = 100  # exposure time to quantify the Poisson noise level
background_rms = 0.0005  # background rms value
poisson = image_util.add_poisson(image, exp_time=exp_time)
bkg = image_util.add_background(image, sigma_bkd=background_rms)
image_noisy = image + bkg + poisson

# Noise-free (left) and noisy (right) mock images in log10 scale;
# pixels <= 0 in the noisy image give NaN/-inf and appear blank.
f, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=False, sharey=False)
axes[0].matshow(np.log10(image), origin='lower')
axes[1].matshow(np.log10(image_noisy), origin='lower')

add_EastNorthFrame(40,155,axes[0])
add_EastNorthFrame(40,155,axes[1])


f.tight_layout()
plt.show()

#### Linear inversion
Parameters corresponding to an amplitude of a surface brightness distribution have a linear response on the predicted flux values in the data and can be inferred by a linear inversion. Depending on the source complexity to be modelled, this can significantly reduce the number of non-linear parameters.

The $\texttt{ImageLinearFit}$ class performs this computation. The class inherits from the $\texttt{ImageModel}$ class, but instead of the $\texttt{PixelGrid}$ instance it requires an instance of the full $\texttt{ImageData}$ class, including the data (for which we use the mock image created above) and its noise properties.

In the example of this notebook, we have 6 linear parameters, the 4 point source amplitudes and the amplitudes of the Sersic profile of the lens and source. *lenstronomy* automatically identifies those parameters and can recover those values (data permitting).

In [None]:
# update the data with the noisy image and its noise properties #
kwargs_data = {'image_data': image_noisy,
               'background_rms': background_rms,
               'exposure_time': exp_time,
               'ra_at_xy_0': delta_ra_at_xy_0,
               'dec_at_xy_0': delta_dec_at_xy_0,
               'transform_pix2angle': transform_pix2angle}

data_class = ImageData(**kwargs_data)

from lenstronomy.ImSim.image_linear_solve import ImageLinearFit
image_linear_fit = ImageLinearFit(data_class=data_class, psf_class=psf_class, lens_model_class=lensModel,
                                  source_model_class=lightModel_source,
                                  lens_light_model_class=lightModel_lens,
                                  point_source_class=pointSource,
                                  kwargs_numerics=kwargs_numerics)

# We do not require the knowledge of the linear parameters: the source
# amplitude is removed and recovered by the linear inversion below.
# pop(..., None) keeps the cell re-runnable; a plain `del` raises KeyError the
# second time the cell is executed.
kwargs_light_source[0].pop('amp', None)
# apply the linear inversion to fit for the noisy image #
image_reconstructed, _, _, _ = image_linear_fit.image_linear_solve(kwargs_lens=kwargs_lens, kwargs_source=kwargs_light_source,
                                                                   kwargs_lens_light=kwargs_light_lens, kwargs_ps=kwargs_ps)

In [None]:
# illustrate fit #
# Render the data and the model side by side with lenstronomy's ModelPlot.
from lenstronomy.Plots.model_plot import ModelPlot
# Model lists for the plot; the point-source list is deliberately left out
# (see the commented-out entry at the end of the next statement).
kwargs_model = {'lens_model_list': lens_model_list, 'source_light_model_list': source_light_model_list,
               'lens_light_model_list': lens_light_model_list} #, 'point_source_model_list': point_source_model_list}
# kwargs_light_source may no longer contain 'amp' (removed in the linear-inversion cell).
kwargs_params = {'kwargs_lens': kwargs_lens, 'kwargs_source': kwargs_light_source,
                 'kwargs_lens_light': kwargs_light_lens, 'kwargs_ps': kwargs_ps}
# A single band: [data, PSF, numerics].
lensPlot = ModelPlot([[kwargs_data, kwargs_psf_gaussian, kwargs_numerics]], kwargs_model, kwargs_params, arrow_size=0.02)

# Left panel: observed (noisy) data; right panel: model image.
f, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=False, sharey=False)
lensPlot.data_plot(ax=axes[0])
lensPlot.model_plot(ax=axes[1])
f.tight_layout()
plt.show()

## Simulation choices

In [None]:
# data specifics
# background_rms and exp_time are reused from earlier cells; the local
# definitions are kept below (commented out) for reference only.
#background_rms = 0.015  #  background noise per pixel
#exp_time = 500.0  #  exposure time (arbitrary units, flux per pixel is in units #photons/exp_time unit)
numPix = Npix_x  #  cutout pixel size per axis
# pixel_scale (pixel size in arcsec, area per pixel = pixel_scale**2) is reused
# unchanged from the earlier cell that defines it.
fwhm = PSF_FWHM  # full width at half maximum of PSF


# lensing quantities
lens_model_list = ["EPL", "SHEAR"]
kwargs_epl = {
    "theta_E": 1.2,
    "gamma": 2.0,
    "e1": 0.1,
    "e2": 0.2,
    "center_x": delta_ra_at_xy_centralmax,
    "center_y": delta_dec_at_xy_centralmax,
}
kwargs_shear = {"gamma1": -0.01, "gamma2": 0.03}

# An SIE has no free power-law slope, so drop "gamma" if the model is switched.
if lens_model_list[0] == "SIE":
    kwargs_epl.pop("gamma")

kwargs_lens = [kwargs_epl, kwargs_shear]
lens_model_class = LensModel(lens_model_list)


# Sersic parameters in the initial simulation for the source
kwargs_sersic = {
    "amp": 26.0,
    "R_sersic": 1.1,
    "n_sersic": 1.0,
    "e1": -0.1,
    "e2": 0.1,
    "center_x": beta_ra,
    "center_y": beta_dec,
}
source_model_list = ["SERSIC_ELLIPSE"]

kwargs_source = [kwargs_sersic]
source_model_class = LightModel(source_model_list)


# Sersic parameters for the lens (deflector) light, centred on the lens
kwargs_sersic_lens = {
    "amp": 16.0,
    "R_sersic": 3.6,
    "n_sersic": 2.0,
    "e1": -0.1,
    "e2": 0.1,
    "center_x": delta_ra_at_xy_centralmax,
    "center_y": delta_dec_at_xy_centralmax,
}

lens_light_model_list = ["SERSIC_ELLIPSE"]
kwargs_lens_light = [kwargs_sersic_lens]

lens_light_model_class = LightModel(lens_light_model_list)

# Lensed point source with image positions taken from the measured images
point_source_type_list = ["LENSED_POSITION"]
kwargs_lensed_position = {
    "ra_image": all_delta_ra_at_xy_s,
    "dec_image": all_delta_dec_at_xy_s,
    "source_amp": 2,
}
kwargs_ps = [kwargs_lensed_position]
point_source_class = PointSource(
    point_source_type_list, lens_model_class, fixed_magnification_list=[True]
)

# Snapshot of the truth parameters for the final comparison table. It is a deep
# copy so that later in-place edits (e.g. rounding for display) cannot alter the
# kwargs actually used to simulate the mock image.
kwargs_truth = copy.deepcopy(
    {
        "kwargs_lens": kwargs_lens,
        "kwargs_source": kwargs_source,
        "kwargs_lens_light": kwargs_lens_light,
        "kwargs_ps": kwargs_ps,
    }
)

## Simulating a mock lens
In the blocks below we simulate a mock lens to generate an image. We only require the kwargs_data and kwargs_psf arguments to perform the modeling. If you have real data, you can leave out the image simulation and directly read in the data, PSF and noise properties into the keyword argument list. Make sure the units are correct. Further information on the settings are available in the ImageData() and PSF() classes in the lenstronomy.Data module.

## Defining a Prior Distribution
Now that we have our data image, our goal is to sample through the parameter space to recover the true parameters of the lens that generated this image. First, we define a prior distribution to pass into our samplers.

In [None]:
# deflector priors
# Each component gets five parallel lists: initial values, proposal widths
# (sigma), fixed parameters, lower bounds and upper bounds.
fixed_lens = []
kwargs_lens_init = []
kwargs_lens_sigma = []
kwargs_lower_lens = []
kwargs_upper_lens = []

# Since our deflector model consists of ["EPL", "SHEAR"], we first define priors for the EPL
fixed_lens.append({})
kwargs_lens_init.append(
    {
        "theta_E": 1.1,
        "gamma": 2.0,
        "e1": 0.05,
        "e2": -0.05,
        "center_x": delta_ra_at_xy_centralmax,
        "center_y": delta_dec_at_xy_centralmax,
    }
)
kwargs_lens_sigma.append(
    {
        "theta_E": 0.2,
        "gamma": 0.1,
        "e1": 0.01,
        "e2": 0.01,
        "center_x": 0.1,
        "center_y": 0.1,
    }
)
kwargs_lower_lens.append(
    {
        "theta_E": 0.01,
        "gamma": 1.5,
        "e1": -0.5,
        "e2": -0.5,
        "center_x": -10.0,
        "center_y": -10.0,
    }
)
kwargs_upper_lens.append(
    {
        "theta_E": 10.0,
        "gamma": 2.5,
        "e1": 0.5,
        "e2": 0.5,
        "center_x": 10.0,
        "center_y": 10.0,
    }
)

# An SIE has no free slope: drop "gamma" from every prior list.
if lens_model_list[0] == "SIE":
    kwargs_lens_init[0].pop("gamma")
    kwargs_lens_sigma[0].pop("gamma")
    kwargs_lower_lens[0].pop("gamma")
    kwargs_upper_lens[0].pop("gamma")

# Now we define priors for the SHEAR (its reference point is held fixed at the origin)
fixed_lens.append({"ra_0": 0, "dec_0": 0})
kwargs_lens_init.append({"gamma1": 0.1, "gamma2": 0.04})
kwargs_lens_sigma.append({"gamma1": 0.01, "gamma2": 0.01})
kwargs_lower_lens.append({"gamma1": -0.2, "gamma2": -0.2})
kwargs_upper_lens.append({"gamma1": 0.2, "gamma2": 0.2})

lens_params = [
    kwargs_lens_init,
    kwargs_lens_sigma,
    fixed_lens,
    kwargs_lower_lens,
    kwargs_upper_lens,
]


# source priors
fixed_source = []
kwargs_source_init = []
kwargs_source_sigma = []
kwargs_lower_source = []
kwargs_upper_source = []

fixed_source.append({})
kwargs_source_init.append(
    {
        "amp": 16,
        "e1": -0.1,
        "e2": 0.1,
        "R_sersic": 1.2,
        "n_sersic": 1.5,
        "center_x":  beta_ra,
        "center_y":  beta_dec,
    }
)
kwargs_source_sigma.append(
    {
        "amp": 1,
        "e1": 0.01,
        "e2": 0.01,
        "R_sersic": 0.01,
        "n_sersic": 0.1,
        "center_x": 0.1,
        "center_y": 0.1,
    }
)
kwargs_lower_source.append(
    {
        "amp": 0,
        "e1": -0.5,
        "e2": -0.5,
        "R_sersic": 0.001,
        "n_sersic": 0.5,
        "center_x": -10.0,
        "center_y": -10.0,
    }
)
kwargs_upper_source.append(
    {
        "amp": 100,
        "e1": 0.5,
        "e2": 0.5,
        "R_sersic": 10,
        "n_sersic": 5.0,
        "center_x": 10,
        "center_y": 10,
    }
)

source_params = [
    kwargs_source_init,
    kwargs_source_sigma,
    fixed_source,
    kwargs_lower_source,
    kwargs_upper_source,
]


# Lens light priors
fixed_lens_light = []
kwargs_lens_light_init = []
kwargs_lens_light_sigma = []
kwargs_lower_lens_light = []
kwargs_upper_lens_light = []

fixed_lens_light.append({})
kwargs_lens_light_init.append(
    {
        "amp": 10,
        "e1": -0.05,
        "e2": 0.05,
        "R_sersic": 2.5,
        "n_sersic": 2.0,
        "center_x": delta_ra_at_xy_centralmax,
        "center_y": delta_dec_at_xy_centralmax,
    }
)
kwargs_lens_light_sigma.append(
    {
        "amp": 1,
        "e1": 0.01,
        "e2": 0.01,
        "R_sersic": 0.03,
        "n_sersic": 0.01,
        "center_x": 0.01,
        "center_y": 0.01,
    }
)
kwargs_lower_lens_light.append(
    {
        "amp": 0,
        "e1": -0.5,
        "e2": -0.5,
        "R_sersic": 0.001,
        "n_sersic": 0.5,
        "center_x": -10.0,
        "center_y": -10,
    }
)
# Upper ellipticity bounds are +0.5 like every other component. They were
# previously copy-pasted from the init values (-0.05 / 0.05), which put the
# init on the boundary and excluded the simulated truth (e2 = 0.1).
kwargs_upper_lens_light.append(
    {
        "amp": 100,
        "e1": 0.5,
        "e2": 0.5,
        "R_sersic": 10.0,
        "n_sersic": 5.0,
        "center_x": 10.0,
        "center_y": 10,
    }
)

lens_light_params = [
    kwargs_lens_light_init,
    kwargs_lens_light_sigma,
    fixed_lens_light,
    kwargs_lower_lens_light,
    kwargs_upper_lens_light,
]

# Point source priors
fixed_ps = [{}]
kwargs_ps_init = []
kwargs_ps_sigma = []
kwargs_lower_ps = []
kwargs_upper_ps = []

#init_ra = np.array([0.69920543, -0.69893762, -2.44616562, 2.42290681])
#init_dec = np.array([-2.78394548, 2.72912308, -0.41755671, 0.55430711])

# Start the image positions at the measured positions of the lensed images.
init_ra = all_delta_ra_at_xy_s
init_dec = all_delta_dec_at_xy_s

kwargs_ps_init.append(
    {
        "ra_image": init_ra,
        "dec_image": init_dec,
        "source_amp": 1,
    }
)
kwargs_ps_sigma.append(
    {
        "ra_image":  np.full_like(init_ra, 0.1),
        "dec_image": np.full_like(init_dec, 0.1),
        "source_amp": 0.1,
    }
)
# NOTE(review): the position bounds are asymmetric (-0.65 / +0.55 arcsec around
# the measured positions) — confirm this is intended.
kwargs_lower_ps.append(
    {
        "ra_image": init_ra - 0.65,
        "dec_image": init_dec - 0.65,
        "source_amp": 0.5,
    }
)
kwargs_upper_ps.append(
    {
        "ra_image": init_ra + 0.55,
        "dec_image": init_dec + 0.55,
        "source_amp": 10,
    }
)

point_source_params = [
    kwargs_ps_init,
    kwargs_ps_sigma,
    fixed_ps,
    kwargs_lower_ps,
    kwargs_upper_ps,
]

kwargs_params = {
    "lens_model": lens_params,
    "source_model": source_params,
    "lens_light_model": lens_light_params,
    "point_source_model": point_source_params,
}

In [None]:
# Assemble the image model from the truth-level classes defined above, then
# render the noise-free mock image from the truth keyword arguments.
imageModel = ImageModel(
    data_class=data_class,
    psf_class=psf_class,
    lens_model_class=lens_model_class,
    source_model_class=source_model_class,
    lens_light_model_class=lens_light_model_class,
    point_source_class=point_source_class,
    kwargs_numerics=kwargs_numerics,
)
image_model = imageModel.image(
    kwargs_lens=kwargs_lens,
    kwargs_source=kwargs_source,
    kwargs_lens_light=kwargs_lens_light,
    kwargs_ps=kwargs_ps,
)

# Draw the noise realisations (Poisson first, then Gaussian background), using
# background_rms / exp_time from earlier cells.
#background_rms = 0.001  # background rms value
poisson = image_util.add_poisson(image_model, exp_time=exp_time)
bkg = image_util.add_background(image_model, sigma_bkd=background_rms)
total_noise = poisson + bkg
image_model = image_model + total_noise

# The noisy mock becomes the "observed" data for the fit.
kwargs_data["image_data"] = image_model
data_class.update_data(image_model)


In [None]:
# Show the noisy mock image twice: with the normalisation returned by
# getsimple_norm (left) and with a plain linear stretch (right).
norm_model = getsimple_norm(image_model)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# Use the normalisation computed just above; the original referenced the
# unrelated (stale or undefined) global `norm`.
ax1.imshow(image_model, origin="lower", norm=norm_model, cmap="hot")
ax2.imshow(image_model, origin="lower", cmap="hot")
plt.show()

## FittingSequence
This class holds all of the image generation choices, the data image, the likelihood calculation choices, the choices for which methods are used to sample through the parameter space, and the current distribution/best fit parameters (updated after each sampler finishes running), which are currently set to the prior distributions defined above.

We compare the performances of the jaxtronomy modeling pipeline to that of lenstronomy.

In [None]:
# Likelihood settings shared by the jaxtronomy and the lenstronomy (reference)
# fitting sequences below.
kwargs_likelihood = {
    "check_bounds": True,  # Checks if sampler goes out of bounds during sampling
    "image_likelihood": True,
    "image_position_likelihood": False,
    "source_position_likelihood": True,
    "astrometric_likelihood": False,
    "source_position_sigma": 0.001,
    "check_positive_flux": True,
    # The following likelihoods are not supported in jaxtronomy yet
    "time_delay_likelihood": False,
    "tracer_likelihood": False,
    "flux_ratio_likelihood": False,
    "kinematic_2d_likelihood": False,
}
# Same model lists as used to simulate the mock image above.
kwargs_model = {
    "lens_model_list": lens_model_list,
    "source_light_model_list": source_model_list,
    "lens_light_model_list": lens_light_model_list,
    "point_source_model_list": point_source_type_list,
    "fixed_magnification_list": [True],
}

# One band: [data, PSF, numerics].
multi_band_list = [[kwargs_data, kwargs_psf_gaussian, kwargs_numerics]]
kwargs_data_joint = {
    "multi_band_list": multi_band_list,
    # Only single-band is supported in jaxtronomy
    "multi_band_type": "single-band",
}

# With the linear solver enabled, amplitude parameters are solved linearly
# rather than sampled (the results table below skips them for that reason).
linear_solver = True
kwargs_constraints = {
    "linear_solver": linear_solver,
    "num_point_source_list": [Nsource_images],  # number of images of the LENSED_POSITION point source
}

# jaxtronomy fitting sequence
fitting_seq = FittingSequence(
    kwargs_data_joint,
    kwargs_model,
    kwargs_constraints,
    kwargs_likelihood,
    kwargs_params,
)
# lenstronomy fitting sequence with identical settings, for comparison
fitting_seq_ref = FittingSequence_ref(
    kwargs_data_joint,
    kwargs_model,
    kwargs_constraints,
    kwargs_likelihood,
    kwargs_params,
)

# Store these for later comparison
kwargs_initial = copy.deepcopy(fitting_seq._updateManager.parameter_state)


# Despite the name, only the PSO step is active; the MCMC step is commented out.
fitting_kwargs_list_mcmc = [
    ["PSO", {"sigma_scale": 1.0, "n_particles": 50, "n_iterations": 100}],
    # ["MCMC", {"n_burn": 50, "n_run": 50, "n_walkers": 70, "sigma_scale": 1.0}],
]

In [None]:
# Run the lenstronomy (reference) fitting sequence. Only the PSO step is active
# in fitting_kwargs_list_mcmc; its MCMC entry is commented out.
chain_list = fitting_seq_ref.fit_sequence(fitting_kwargs_list_mcmc)

In [None]:
# Second run of the lenstronomy fitting sequence (PSO only).
# NOTE(review): this duplicates the previous cell and overwrites chain_list. The
# fitting sequence keeps its best-fit state between runs, so this run presumably
# starts from the first run's result — confirm the repetition is intended
# (e.g. for timing).
chain_list = fitting_seq_ref.fit_sequence(fitting_kwargs_list_mcmc)

In [None]:
# Run the jaxtronomy fitting sequence. Only the PSO step is active in
# fitting_kwargs_list_mcmc; its MCMC entry is commented out.
chain_list = fitting_seq.fit_sequence(fitting_kwargs_list_mcmc)
kwargs_result = fitting_seq.best_fit()
# Drop the tracer-source entry (if any) so kwargs_result can be unpacked into
# imageModel.image(**kwargs_result) later on.
kwargs_result.pop("kwargs_tracer_source", None)

All of the speedup we got was simply from JAX compiling the functions and optimizing the way the code is run. In this case, we see a 2x speedup using the EPL profile on CPU, which was originally slow due to having to evaluate hyp2f1 for deflection angle calculations. For a simple SIE profile which already runs quickly, the performance may have negligible improvement or become worse, depending on compile time. For GPU, the speedup varies.

## Minimizer
JAX's autodifferentiation feature allows us to make use of a gradient-based minimizer. Here we use SciPy's BFGS minimization method with 5 chains run sequentially. The initial parameters for each chain are randomly drawn from the prior distribution, which helps avoid getting stuck in local minima. For posterior sampling, the PSO step can be replaced with this step for a significant speedup.

Support for parallelization is planned.

In [None]:
# Reset the initial parameters by building a fresh jaxtronomy FittingSequence
# (the previous one carries the PSO result as its current state).
fitting_seq = FittingSequence(
    kwargs_data_joint,
    kwargs_model,
    kwargs_constraints,
    kwargs_likelihood,
    kwargs_params,
)

# Create a copy of the initial set of parameters for later comparison
kwargs_initial = copy.deepcopy(fitting_seq._updateManager.parameter_state)

# options are BFGS and TNC
# Other options such as Nelder-Mead, Powell, CG, Newton-CG, L-BFGS-B, COBYLA, SLSQP, trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov have not been tested
fitting_kwargs_list_jaxopt = [
    [
        "Jaxopt",
        {
            "method": "BFGS",
            "maxiter": 500,  # presumably max iterations per chain — confirm in jaxtronomy docs
            "rng_int": 420,  # presumably the seed for drawing chain starting points from the prior
            "num_chains": 5,
            "tolerance": 5500,  # We can stop running chains once we get a "good enough" fit
        },
    ],
]

chain_list = fitting_seq.fit_sequence(fitting_kwargs_list_jaxopt)

kwargs_result = fitting_seq.best_fit()
# Drop the tracer-source entry (if any) before using kwargs_result below.
kwargs_result.pop("kwargs_tracer_source", None)

## Analyzing minimizer results

In [None]:
# Print a comparison table: initial guess vs. best fit vs. truth, per parameter.
col1_spacing = 13
col2_spacing = 26
col3_spacing = 26
col4_spacing = 26
_col_widths = (col1_spacing, col2_spacing, col3_spacing, col4_spacing)

# Amplitudes are solved linearly when linear_solver is on, so they are not
# meaningful sampled parameters and are skipped.
_LINEAR_AMP_PARAMS = ("amp", "source_amp", "point_amp")


def _format_row(cells, fill=" "):
    """Return one '|'-delimited table row, each cell centred in its column."""
    return (
        "|"
        + "|".join(str(cell).center(width, fill) for cell, width in zip(cells, _col_widths))
        + "|"
    )


print(_format_row(["parameter", "kwargs_initial", "kwargs_result", "kwargs_truth"]))

# Iterate through lens/source/ps. Values are rounded for display only: writing
# the rounded values back (as before) corrupted kwargs_initial / kwargs_result,
# which are used afterwards to render the initial and fitted images.
for key, truth_components in kwargs_truth.items():
    print(_format_row(["-"] * 4, fill="-"))
    for i, truth_component in enumerate(truth_components):
        for parameter in truth_component:
            if parameter in _LINEAR_AMP_PARAMS and linear_solver is True:
                continue
            shown = [
                np.round(kwargs_initial[key][i][parameter], 2),
                np.round(kwargs_result[key][i][parameter], 2),
                np.round(truth_component[parameter], 2),
            ]
            print(_format_row([parameter, *shown]))

In [None]:
# Best-fit image: with the linear solver, amplitudes are re-solved for
# kwargs_result and returned alongside the image.
if linear_solver:
    fitted_image, _, _, amps = (
        fitting_seq.likelihoodModule.image_likelihood.imSim.image_linear_solve(
            **kwargs_result
        )
    )
    print("linearly solved amplitudes:", amps)
else:
    fitted_image = imageModel.image(**kwargs_result)

kwargs_initial.pop("kwargs_tracer_source", None)
initial_image = imageModel.image(**kwargs_initial)

# Residual normalised by the model noise estimated from the fitted image.
reduced_residual = (fitted_image - image_model) / np.sqrt(
    imageModel.Data.C_D_model(fitted_image)
)

images = [initial_image, fitted_image, image_model, reduced_residual]
titles = ["initial image", "fitted image", "truth image", "reduced residual"]

# Symmetric colour limits so the diverging "seismic" colormap is white at zero
# residual (the previous -6 / +5 put the white point at -0.5).
residual_lim = 6

fig, axs = plt.subplots(2, 2, figsize=(16, 16), layout="constrained")
ax = axs.flatten()
for i in range(4):
    ax[i].set_title(titles[i])
    if i in [0, 1, 2]:
        # Linear stretch for the images (a getsimple_norm stretch could be
        # passed via norm= instead).
        im = ax[i].imshow(images[i], origin="lower", cmap="hot")
    else:
        im = ax[i].imshow(
            images[i], origin="lower", cmap="seismic", vmin=-residual_lim, vmax=residual_lim
        )
    plt.colorbar(im, ax=ax[i])

fig.set_figwidth(25)
plt.show()

In [None]:
# Convergence plot of the last fitting step: element 2 of its chain_list entry
# holds the log-likelihood history, plotted here with its sign flipped.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
neg_logl = -np.asarray(chain_list[-1][2])
iterations = np.arange(len(neg_logl))
ax.plot(iterations, neg_logl)
ax.set_xlabel("Iteration", fontsize=14)
ax.set_ylabel("negative logL", fontsize=14)
ax.set(yscale="log")
ax.grid()
plt.show()