In [None]:
# import libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# ---- physical constants (SI units) ----
c    = 299792458        # speed of light [m/s]
G    = 6.67259e-11      # gravitational constant [m^3 kg^-1 s^-2]
Msun = 1.98892e30       # solar mass [kg]
pc   = 3.0857e16        # parsec [m]
Mpc  = 1e6*pc           # megaparsec [m]

# ---- cosmological parameters ----
h           = 0.71      # dimensionless Hubble parameter
OmegaC      = 0.222     # cold dark matter density parameter
OmegaLambda = 0.734     # dark energy density parameter
OmegaBaryon = 0.0449    # baryon density parameter
H           = 100*h     # Hubble constant [km/s/Mpc]

# ---- unit helpers ----
arcsec_units = np.pi/180/3600   # one arcsecond expressed in radians

In [None]:
def AngularDiameter(z2,z1):
    """
    Computes the angular diameter distance between two redshifts.
    See https://en.wikipedia.org/wiki/Distance_measures_(cosmology) for details.

    The line-of-sight comoving distance is integrated over the scale factor
    a = 1/(1+z) with the midpoint rule (in units of the Hubble distance c/H0):

        d_C = integral of da / (a**2 * E(a))
        E(a) = sqrt(omega_R*a**-4 + omega_M*a**-3 + omega_K*a**-2 + omega_V)

    d_C is then corrected for spatial curvature (d_M) and multiplied by the
    scale factor of the more distant redshift (d_A = a_far * d_M).

    Uses the module-level constants c, h, H, OmegaC, OmegaBaryon and OmegaLambda.

    Input:  z2, z1 - the two redshifts [float]. The order does not matter: the
            smaller one is taken as the observer (0 for an observer today) and
            the larger one as the emitter.

    Output: d_A the angular diameter distance in parsecs [float]
    """

    omega_M = OmegaC+OmegaBaryon
    omega_V = OmegaLambda
    omega_R = 4.165E-5/h**2
    omega_K = 1 - omega_M - omega_R - omega_V

    # scale factors of the nearer and of the farther redshift
    a_near = 1/(1+min(z1,z2))
    a_far = 1/(1+max(z1,z2))

    # midpoint rule: sample the integrand at the centre of nPoints equal sub-intervals
    nPoints = 10000 # number of points in integral
    edges = np.linspace(a_far,a_near,nPoints+1)
    a = 0.5*(edges[:-1]+edges[1:])
    E = np.sqrt(omega_R*a**-4 + omega_M*a**-3 + omega_K*a**-2 + omega_V)

    # comoving line-of-sight distance in units of c/H0
    d_C = (a_near-a_far) * (1/(a**2*E)).mean()

    # transverse comoving distance d_M; abs() keeps the square root real for
    # negative omega_K (closed universe), where sinh turns into sin
    sqrt_K = np.sqrt(abs(omega_K))
    if omega_K > 0:
        d_M = np.sinh(sqrt_K*d_C)/sqrt_K
    elif omega_K == 0:
        d_M = d_C
    else:
        d_M = np.sin(sqrt_K*d_C)/sqrt_K

    # angular diameter distance: scale by the scale factor at emission
    d_A = a_far*d_M
    d_A *= (c/1000)/H * 1e6 # c/H0 in Mpc (c in km/s, H in km/s/Mpc), then Mpc -> pc

    return d_A

def distances(z_source,z_lens):
    """
    Angular diameter distances of a lens system, in Mpc.

    Input:  z_source - redshift of source [float]
            z_lens - redshift of lens [float]

    Output: Ds - distance between observer and source [float]
            Dd - distance between observer and lens [float]
            Dds - distance between lens and source [float]
    """
    pc_to_Mpc = 10**-6  # AngularDiameter reports parsecs

    # (nearer redshift, farther redshift) for Ds, Dd and Dds, in that order
    redshift_pairs = ((0, z_source), (0, z_lens), (z_lens, z_source))
    Ds, Dd, Dds = [AngularDiameter(z_a, z_b) * pc_to_Mpc for z_a, z_b in redshift_pairs]

    return Ds, Dd, Dds

In [None]:
def cart2pol(x, y):
    """
    Convert Cartesian coordinates to polar coordinates.

    Input:  x, y - Cartesian coordinates [float or array]

    Output: (rho, phi) - radius and angle; phi is np.arctan2(y, x), in radians
    """
    return np.sqrt(x**2 + y**2), np.arctan2(y, x)

def pol2cart(rho, phi):
    """
    Convert polar coordinates to Cartesian coordinates.

    Input:  rho - radius [float or array]
            phi - angle in radians [float or array]

    Output: (x, y) - Cartesian coordinates
    """
    return rho * np.cos(phi), rho * np.sin(phi)

def raytrace_SIE(ximage,yimage,theta_E,elpSIS,Ext_Shear,Shear_angle,zLENS,zSOURCE,theta,offsets):
    """
    Ray-trace an image-plane grid (ximage,yimage) through a singular isothermal
    ellipsoid (SIE) lens with external shear and return the matching source-plane
    grid (xsource,ysource).

    SIE: https://ui.adsabs.harvard.edu/abs/1994A%26A...284..285K/abstract
    einstein R (p10): https://arxiv.org/pdf/1003.5567.pdf

    Input:  ximage, yimage - image-plane coordinate grids in radians [array]
            theta_E - Einstein radius in arcsec
            elpSIS - lens ellipticity; the axis ratio is f = 1 - elpSIS
                     (elpSIS = 0 is the circular SIS limit)
            Ext_Shear - magnitude of the external shear
            Shear_angle - orientation of the external shear in degrees
            zLENS - redshift of the lens [float]
            zSOURCE - redshift of the source [float]
            theta - orientation of the lens in degrees
            offsets - [x, y] position of the lens centre in arcsec

    Output: xsource, ysource - source-plane coordinates in radians [array]
            sigma_v - velocity dispersion of the lens, in the units of c (m/s)
            Minterior - mass enclosed within the Einstein radius in solar masses
    """

    # compute angular distances in Mpc; distances() returns them as (Ds, Dd, Dds)
    Ds, Dd, Dds = distances(zSOURCE,zLENS)

    # move the origin to the lens centre (offsets are given in arcsec)
    xoffset = -offsets[0]*arcsec_units
    yoffset = -offsets[1]*arcsec_units

    ximage = ximage+xoffset
    yimage = yimage+yoffset

    # convert to rad
    theta = theta*(np.pi/180)
    Shear_angle = Shear_angle*(np.pi/180)

    # compute quantities for ray-tracing
    sigma_v = (c**2/(4*np.pi) * theta_E*arcsec_units * Ds/Dds)**0.5
    R_ein = 4*np.pi * (sigma_v/c)**2 *(Dds/Ds) # Einstein radius in rad
    XI_0 = 4*np.pi * (sigma_v/c)**2 *(Dd*Dds/Ds) # Einstein radius as a length in the lens plane

    # rotate the grid by -theta so the lens axes line up with the coordinate axes
    ximage_rho, yimage_theta = cart2pol(ximage,yimage)
    ximage, yimage = pol2cart(ximage_rho, yimage_theta-theta)

    # external shear matrix, expressed in the rotated frame
    phi_shear = Shear_angle-theta
    g1 = Ext_Shear*(-np.cos(2*phi_shear))
    g2 = Ext_Shear*(-np.sin(2*phi_shear))
    g3 = Ext_Shear*(-np.sin(2*phi_shear))
    g4 = Ext_Shear*( np.cos(2*phi_shear))

    f = 1-np.asarray(elpSIS, dtype=float) # axis ratio
    fp = np.sqrt(1-f**2)

    phi = np.arctan2(yimage,ximage)

    # SIE deflection in units of the Einstein radius. For a circular lens fp is 0
    # and the SIE expressions become 0/0; their limit is the SIS deflection
    # (cos(phi), sin(phi)), which is used there instead of returning NaN.
    circular = fp == 0
    with np.errstate(divide='ignore', invalid='ignore'):
        alpha_x = np.where(circular, np.cos(phi), f**0.5/fp*np.arcsinh(np.cos(phi)*fp/f))
        alpha_y = np.where(circular, np.sin(phi), f**0.5/fp*np.arcsin(np.sin(phi)*fp))

    # lens equation: source = image - deflection - shear
    xsource = ximage - alpha_x*XI_0/Dd - g1*ximage - g2*yimage
    ysource = yimage - alpha_y*XI_0/Dd - g3*ximage - g4*yimage

    # rotate the source grid back by +theta and undo the lens-centre shift
    xsource_rho, ysource_theta = cart2pol(xsource,ysource)
    xsource, ysource = pol2cart(xsource_rho, ysource_theta+theta)

    xsource = xsource-xoffset
    ysource = ysource-yoffset

    # mass inside R = R_ein*Dd; Mpc converts Dd to metres, Msun converts kg to solar masses
    Minterior = (np.pi*(sigma_v**2)*R_ein*Dd*Mpc)/G/Msun

    return xsource, ysource, sigma_v, Minterior

In [None]:
# image-plane grid on which the lensed image is evaluated
nPixels_image = 500           # pixels per side (the source image is 500x500)
imside = nPixels_image / 40   # side length of the grid in arcsec

# the same axis is used in x and y; coordinates are stored in radians
image_axis = np.linspace(-imside/2, imside/2, nPixels_image) * arcsec_units
XIM, YIM = np.meshgrid(image_axis, image_axis)

In [None]:
# set redshift for source and lens planes
zSOURCE = 2
zLENS = 0.5

# upper bounds for the randomly drawn lens parameters
max_theta_ein = 3.0  # Einstein radius [arcsec]
max_elp = 0.9        # ellipticity

# generate some random parameters
theta_ein = 0.1 + np.random.rand(1) * (max_theta_ein-0.1)  # Einstein radius in [0.1, max_theta_ein) arcsec
elp = np.random.rand(1) * max_elp                          # lens ellipticity
angle = np.random.rand(1)*360                              # lens orientation [deg]
XLENS = (np.random.rand(1)-0.5) * 0.1                      # lens x position [arcsec]
YLENS = (np.random.rand(1)-0.5) * 0.1                      # lens y position [arcsec]
shear =  np.random.rand(1) * 0.01                          # external shear magnitude
# NOTE(review): raytrace_SIE reads the shear angle in degrees, so this only spans
# [0, 0.01) deg - possibly meant to be * 360 like `angle`; confirm before changing.
shear_angle =  np.random.rand(1) * 0.01

# ray trace and visualize the deviation
XS ,YS, sigma_v, MSIS = raytrace_SIE(XIM,YIM,theta_ein,elp,shear,shear_angle,zLENS,zSOURCE,angle,[XLENS,YLENS])
plt.scatter(XIM,YIM)
plt.scatter(XS,YS,s=0.1); plt.show()

In [None]:
# read the source picture (500x500 px dog pic) and keep only its first channel
source_image = plt.imread('image.png')[:,:,0]

# source-plane grid: coordinates from -1 to 1 arcsec on both axes, stored in radians
source_axis = np.linspace(-1,1,nPixels_image)*arcsec_units
x_src,y_src = np.meshgrid(source_axis, source_axis)

# radial taper: close to 1 at the centre, 1/2 at a radius of 0.6 arcsec,
# then falling off steeply; rescaled so that its maximum is exactly 1
rfilt = np.sqrt(x_src**2+y_src**2)
taper = 1/(1+(rfilt/(0.6*arcsec_units))**6)
taper = taper/np.max(taper)

# show the picture as loaded
print('No taper')
plt.imshow(source_image); plt.show()

# normalize pixel values to a maximum of 1, then apply the taper (both in place)
source_image /= np.max(source_image)
source_image *= taper

# show the picture after normalization and tapering
print('With taper')
plt.imshow(source_image); plt.show()

In [None]:
# build the observed (lensed) image by looking up the source image at the
# ray-traced source-plane coordinates, then display it

# source pixel positions; the x coordinates are stretched by 10% to see what happens
xpoints = x_src + x_src*0.1
ypoints = y_src

# one (x, y) row per source pixel, with the pixel values in the same order
points = np.column_stack((xpoints.ravel(), ypoints.ravel()))
values = source_image.ravel()

# each ray-traced position (XS,YS) takes the value of the nearest source pixel
observed_image = griddata(points, values, (XS,YS), method='nearest')

plt.imshow(observed_image); plt.show()