In [None]:
#INSTRUCTIONS:

#1. Run first cell to load imports
#2. Run 2nd cell to define the vectorized lense flow and plotting / printing code
#3. 3rd cell loads original version of logpdf into memory which took f_lensed as input
#4. 4th cell loads the current version of logpdf (logpdf_v2) into memory, which takes f as input and lenses it itself
#5. 5th cell loads the data exported from Julia into this notebook and evaluates logpdf (v1) and logpdf_v2; the following cell compares both against the Julia ground truth
#6. Last cell plots the difference in the fourier space lensed field calculated in python versus the output from julia. 
#   The difference here seems to be the source of error for the gradients.  

In [None]:
#------------------------------------------------------------------------------------------------------------
#------------------------------------------------- Imports --------------------------------------------------
#------------------------------------------------------------------------------------------------------------

import numpy as np
import math as mt
import matplotlib as plt # type: ignore
import jax # type: ignore
import jax.numpy as jnp # type: ignore
import matplotlib.pyplot as plt # type: ignore
import jax.numpy.fft as jfft # type: ignore
import numpy as np
import matplotlib.pyplot as plt # type: ignore
import time
import math as mt
from numba import njit # type: ignore
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt # type: ignore
from jax.scipy.sparse.linalg import cg

In [None]:
#DEFINE LENSEFLOW VECTORIZED FOR JAX AND OTHER HELPER FUNCTIONS

#------------------------------------------------------------------------------------------------------------
#----------------------------------------  Function Definitions ---------------------------------------------
#------------------------------------------------------------------------------------------------------------

#reusable plotting code to standardize plots
def plot_heat_map(heatmap, title, x_label, y_label, color_bar_label, clim = None):
    """Display a 2D array as a 'coolwarm' heat map with a labelled colour bar, then show it.

    Parameters
    ----------
    heatmap : 2D array-like
        Values to draw; row 0 is placed at the bottom (origin='lower').
    title, x_label, y_label : str
        Plot title and axis labels.
    color_bar_label : str
        Label attached to the colour bar.
    clim : (vmin, vmax) or None
        Colour limits forwarded to imshow; None leaves the colour scaling to matplotlib.
    """
    # `is None` rather than `== None`: an array-valued clim (e.g. np.array([0, 1])) would make
    # `== None` an elementwise comparison whose truth value is ambiguous and raises.
    if clim is None:
        plt.imshow(heatmap, cmap='coolwarm', origin='lower')
    else:
        plt.imshow(heatmap, cmap='coolwarm', origin='lower', clim = clim)
    plt.colorbar(label=color_bar_label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.show()

#reusable function for computing standard data science metrics
def print_metrics_to_screen(ground_truth, prediction):
    """Print MSE, normalized MSE, normalized RMSE and R^2 of `prediction` against `ground_truth`.

    Residuals enter through their squared magnitude |prediction - ground_truth|**2. For real inputs
    this equals the plain square. For complex inputs (e.g. Fourier-space coefficients) it gives the
    real-valued error rather than a meaningless complex number.

    Note: if `ground_truth` is constant, its variance is zero and the normalized metrics divide by
    zero (numpy returns inf/nan with a RuntimeWarning).
    """
    ground_truth = np.asarray(ground_truth)
    prediction = np.asarray(prediction)

    #squared magnitude of each residual (identical to residual**2 for real data)
    squared_error = np.abs(prediction - ground_truth)**2

    #Mean Squared Error
    mse = np.mean(squared_error)
    print("Mean Squared Error = "+str(mse))

    #normalized mean squared error
    nmse = mse / np.var(ground_truth)
    print("Normalized Mean Squared Error = "+str(nmse))

    #normalized root mean squared error (RMSE is just sqrt of the MSE computed above)
    nrmse = np.sqrt(mse) / np.std(ground_truth)
    print("Normalized Root Mean Squared Error = "+str(nrmse))

    #R^2 value
    numerator = np.sum(squared_error)
    denominator = np.sum(np.abs(ground_truth - np.mean(ground_truth))**2)
    r2 = 1 - (numerator / denominator)
    print("R^2 value = "+str(r2))

#---------------------------------------------------------------------------------------------------
#------------------------------------- Utility Function Definitions --------------------------------
#---------------------------------------------------------------------------------------------------


# -----------------------------------
# vectorized FFT derivatives (JAX)
# ----------------------------------
def get_spatial_derivatives(f, pix_width):
    """Spectral (FFT-based) first and second partial derivatives of a periodic 2D field.

    Axis 0 is treated as x and axis 1 as y, both sampled with spacing `pix_width`.
    Each derivative is obtained by multiplying the 2D FFT of `f` by the matching
    power of (i k) and transforming back; only the real part is kept.

    Returns:
        (df/dx, df/dy, d2f/dx2, d2f/dxdy, d2f/dy2), each a real array shaped like `f`.
    """
    n_x, n_y = f.shape
    f_hat = jnp.fft.fft2(f)

    # angular wavenumbers laid out along axis 0 (x) and axis 1 (y) so they broadcast against f_hat
    k_x = (2.0 * jnp.pi * jnp.fft.fftfreq(n_x, d=pix_width))[:, None]
    k_y = (2.0 * jnp.pi * jnp.fft.fftfreq(n_y, d=pix_width))[None, :]

    def apply_symbol(symbol):
        # multiply by the Fourier-space symbol, then return to real space
        return jnp.fft.ifft2(symbol * f_hat).real

    d_x = apply_symbol(1j * k_x)
    d_y = apply_symbol(1j * k_y)
    d_xx = apply_symbol(-(k_x ** 2))
    d_xy = apply_symbol(-(k_x * k_y))
    d_yy = apply_symbol(-(k_y ** 2))
    return d_x, d_y, d_xx, d_xy, d_yy

# ----------------------------------------------------------------------------------------------------
# Given a symmetric matrix of the form [[a, b], [b, d]] return the components of its inverse.
# Should use this instead of jnp.linalg.pinv since the jax tracer has trouble with the former...
# ----------------------------------------------------------------------------------------------------
def get_inverse_matrix_components(a, b, d, eps=1e-12):
    """Inverse of the symmetric 2x2 matrix [[a, b], [b, d]], returned component-wise.

    Works elementwise on scalars or broadcast-compatible arrays, using the explicit 2x2 formula
    rather than jnp.linalg.pinv.

    Parameters
    ----------
    a, b, d : scalar or array
        Entries of the symmetric matrix.
    eps : float
        Smallest allowed |determinant|. Determinants closer to zero are replaced by +/-eps (keeping
        their sign) so the division stays finite; larger determinants are used exactly.

    Returns
    -------
    (inv11, inv12, inv22) : components of the inverse; inv21 equals inv12 by symmetry.
    """
    det = a * d - b * b
    # Clamp only near-singular determinants and preserve their sign. Adding eps unconditionally
    # would bias every well-conditioned inverse and give a zero denominator when det == -eps.
    det_sign = jnp.where(det < 0, -1.0, 1.0)
    safe_det = jnp.where(jnp.abs(det) < eps, det_sign * eps, det)
    inv_det = 1.0 / safe_det
    inv11 =  d * inv_det
    inv12 = -b * inv_det
    inv22 =  a * inv_det
    return inv11, inv12, inv22

# -----------------------------------------------------------------
# Define the dynamics df/dt = f(t, y, args) for a single time step
# -----------------------------------------------------------------
def single_lense_flow_step(t, y, args):
    """Right-hand side df/dt of the LenseFlow ODE at time t (diffrax `ODETerm` signature).

    With the symmetric matrix M(t) = I + t * Hessian(phi) and p = M^{-1} grad(phi), the rate is
        forward : df/dt = p . grad(f)
        adjoint : df/dt = p . grad(f) + div(p) * f      (i.e. div(p f))

    Args:
      t: current time (scalar).
      y: flattened field of length rows*cols.
      args: tuple (grad_phi_x, grad_phi_y, d2_phi_dx2, d2_phi_dxdy, d2_phi_dy2,
                   pix_width, rows, cols, adjoint) as packed by LenseFlow.

    Returns:
      The flattened rate, same length as y.
    """
    #unpack args
    grad_phi_x, grad_phi_y, d2_phi_dx2, d2_phi_dxdy, d2_phi_dy2, pix_width, rows, cols, adjoint = args
    #reshape y into 2D field
    f = y.reshape((rows, cols))

    #first derivatives of the current field (its second derivatives are not needed)
    grad_fx, grad_fy, _, _, _ = get_spatial_derivatives(f, pix_width)

    #components of M(t) = I + t * Hessian(phi) and of its inverse
    a = 1.0 + t * d2_phi_dx2
    d = 1.0 + t * d2_phi_dy2
    b = t * d2_phi_dxdy
    inv11, inv12, inv22 = get_inverse_matrix_components(a, b, d)

    #v = M^{-1} * grad_f; since M^{-1} is symmetric, grad_phi . v == p . grad_f
    vx = inv11 * grad_fx + inv12 * grad_fy
    vy = inv12 * grad_fx + inv22 * grad_fy
    rate = grad_phi_x * vx + grad_phi_y * vy

    #the adjoint (transposed) flow adds div(p) * f; the term is only built when requested
    #instead of allocating and adding an array of zeros on every right-hand-side evaluation
    if adjoint:
        grad_inv11_x, _, _, _, _ = get_spatial_derivatives(inv11, pix_width)
        grad_inv12_x, grad_inv12_y, _, _, _ = get_spatial_derivatives(inv12, pix_width)
        _, grad_inv22_y, _, _, _ = get_spatial_derivatives(inv22, pix_width)
        #d(p_x)/dx and d(p_y)/dy via the product rule, with p_x = inv11*phi_x + inv12*phi_y, p_y = inv12*phi_x + inv22*phi_y
        grad_p_x = d2_phi_dx2*inv11 + d2_phi_dxdy*inv12 + grad_phi_x*grad_inv11_x + grad_phi_y*grad_inv12_x
        grad_p_y = d2_phi_dy2*inv22 + d2_phi_dxdy*inv12 + grad_phi_x*grad_inv12_y + grad_phi_y*grad_inv22_y
        rate = rate + (grad_p_x + grad_p_y)*f

    return rate.ravel() #flattened vector

# ----------------------------------------------
# JAX implementation of LenseFlow using diffrax
# ----------------------------------------------
def LenseFlow(f, phi, pix_width, n=7, direction=1, adjoint=False, solver=None):
    """Lens a 2D real-space map by integrating the LenseFlow ODE with diffrax.

    Args:
      f: real-space map to transform (2D; converted to a JAX array).
      phi: real-space lensing potential, same shape as f.
      pix_width: pixel spacing used for the spectral derivatives (same on both axes).
      n: number of integration steps; dt0 = +/-1/n and no stepsize controller is passed, so
         diffrax takes n constant steps.
      direction: 1 integrates t: 0 -> 1 (lensing); -1 integrates t: 1 -> 0 (inverse lensing).
      adjoint: if True, integrate the transposed dynamics (see single_lense_flow_step) in the
         direction opposite to `direction` (adjoint of lensing: 1 -> 0; adjoint of inverse
         lensing: 0 -> 1).
      solver: optional diffrax solver instance; defaults to Tsit5().

    Returns:
      2D JAX array with the same shape as f.
    """
    f = jnp.asarray(f) #cast to jax array just to be safe since diffrax requires jax inputs
    phi = jnp.asarray(phi)
    rows, cols = f.shape

    #precompute phi partials once; they are constant during the flow
    grad_phi_x, grad_phi_y, d2_phi_dx2, d2_phi_dxdy, d2_phi_dy2 = get_spatial_derivatives(phi, pix_width)

    #time interval: the adjoint of a flow is the transposed flow run backwards in time, so `adjoint`
    #flips whichever direction was requested
    adjoint = bool(adjoint)
    if (direction == -1) != adjoint:
        t0, t1 = 1.0, 0.0
    else:
        t0, t1 = 0.0, 1.0
    dt0 = (t1 - t0) / n #signed step matching the direction of integration

    #ravel up 2D array into 1D array since this is required for the diffrax ode solver
    y0 = f.ravel()
    #extra arguments forwarded to single_lense_flow_step
    args = (grad_phi_x, grad_phi_y, d2_phi_dx2, d2_phi_dxdy, d2_phi_dy2, pix_width, rows, cols, adjoint)
    single_step_dynamics = ODETerm(single_lense_flow_step)
    #Tsit5 is an explicit 5(4)-order Runge-Kutta method for non-stiff problems; with no stepsize
    #controller it simply takes the n fixed steps above.
    #NOTE(review): if the Julia reference uses a different fixed-step scheme, the two results differ by
    #the schemes' truncation errors -- pass another `solver` or increase n to check.
    ode_solver_method = Tsit5() if solver is None else solver

    #call the ode solver
    sol = diffeqsolve(
        single_step_dynamics,
        ode_solver_method,
        t0=t0, #initial time
        t1=t1, #final time
        dt0=dt0, #step size
        y0=y0, #initial conditions
        args=args #extra arguments
    )

    #no SaveAt is given, so sol.ys holds only the state at t1; reshape back into a 2D jax array
    return jnp.asarray(sol.ys).reshape((rows, cols))

In [None]:
#DEFINE LOGPDFV1 - This version calculates gradients wrong but gets the right scalar value

#GOAL 1: compute logpdf and compare to Julia. Python implementation of logpdf
def logpdf(f_array, phi_array, data_array, f_lensed_array,
           cf_diagonal, cphi_diagonal, cn_diagonal, 
           f_lambda_array, phi_lambda_array, data_lambda_array, 
           cf_lambda_array, cphi_lambda_array, cn_lambda_array, num_pixels):
    """Log density from precomputed Fourier-space arrays (NumPy version).

    Evaluates
        -1/2 * [ f^H Cf^{-1} f + phi^H Cphi^{-1} phi + (d - f_lensed)^H Cn^{-1} (d - f_lensed)
                 + s_f*L_f + s_phi*L_phi + s_n*L_n ]
    where each covariance is diagonal and given as a flat vector. The vector is reshaped in Fortran
    order to the (map_width//2 + 1, map_width) grid, with map_width = sqrt(num_pixels).
    Each quadratic form is weighted per row by its `*_lambda_array` and divided by num_pixels.
    L_x is the row-weighted (`c*_lambda_array`) sum of log|entries|, and s_x is the product of the
    entry signs. Zero covariance entries are skipped everywhere.

    `f_lensed_array` is supplied precomputed, so autodiff gradients w.r.t. f through this function
    miss the dependence of the lensed field on f.
    """
    #assuming a square map in real space, the rfft grid is (map_width//2 + 1, map_width)
    map_width = int(np.sqrt(num_pixels))
    fourier_shape = (map_width // 2 + 1, map_width)

    #weighted log|det|, determinant sign and elementwise inverse of each diagonal covariance
    log_det_f_value, log_det_f_sign, f_covar_inv = \
        _covariance_log_det_and_inverse_np(cf_diagonal, cf_lambda_array, fourier_shape)
    log_det_phi_value, log_det_phi_sign, phi_covar_inv = \
        _covariance_log_det_and_inverse_np(cphi_diagonal, cphi_lambda_array, fourier_shape)
    log_det_noise_value, log_det_noise_sign, noise_covar_inv = \
        _covariance_log_det_and_inverse_np(cn_diagonal, cn_lambda_array, fourier_shape)

    #quadratic forms; [:, np.newaxis] broadcasts the per-row Fourier weights across columns
    f_contribution = np.sum(np.real(np.conj(f_array) * f_covar_inv * f_array) \
                            * f_lambda_array[:, np.newaxis] * (1/num_pixels))
    phi_contribution = np.sum(np.real(np.conj(phi_array) * phi_covar_inv * phi_array) \
                              * phi_lambda_array[:, np.newaxis] * (1/num_pixels))
    residual = data_array - f_lensed_array
    noise_contribution = np.sum(np.real(np.conj(residual) * noise_covar_inv * residual) \
                                * data_lambda_array[:, np.newaxis] * (1/num_pixels))

    result = -1*(f_contribution + phi_contribution + noise_contribution \
                 + log_det_f_value * log_det_f_sign + log_det_phi_value * log_det_phi_sign \
                 + log_det_noise_value * log_det_noise_sign)/2

    return result


def _covariance_log_det_and_inverse_np(diagonal, lambda_array, fourier_shape):
    """Row-weighted log|det|, determinant sign and elementwise inverse of a diagonal covariance (NumPy).

    Zero entries are excluded: they add 0 to the log-determinant, count as +1 in the sign and
    get 0 in the inverse.
    """
    #Fortran-order reshape to stay consistent with Julia's column-major layout
    matrix = diagonal.reshape(fourier_shape, order='F')
    nonzero = matrix != 0
    #swap zeros for 1 *before* log / reciprocal so no divide-by-zero warnings are raised
    safe = np.where(nonzero, matrix, 1)
    log_det_value = np.sum(np.log(np.abs(safe)) * lambda_array[:, np.newaxis])
    log_det_sign = np.prod(np.sign(safe))
    inverse = np.where(nonzero, 1.0 / safe, 0.0)
    return log_det_value, log_det_sign, inverse

In [None]:
#DEFINE LOGPDFV2 - This version differs from the ground truth in the ones digit, which is likely a substantial difference

#GOAL 1: compute logpdf and compare to Julia. Python implementation of logpdf
def logpdf_v2(f_array, phi_array, data_array,
              b_diagonal, m_diagonal, 
              b_lambda_array, m_lambda_array, #not currently using these
              cf_diagonal, cphi_diagonal, cn_diagonal, 
              f_lambda_array, phi_lambda_array, data_lambda_array, 
              cf_lambda_array, cphi_lambda_array, cn_lambda_array, num_pixels,
              pix_width=0.00058177643, num_steps=10):
    """Log density that lenses f itself, so autodiff gradients see the lensing (JAX version).

    Same expression as the NumPy `logpdf`, except that the lensed field is computed here as
        m * b * rfft2(LenseFlow(irfft2(f), irfft2(phi)))
    instead of being passed in precomputed.

    Args (beyond those shared with `logpdf`):
      b_diagonal, m_diagonal: diagonals multiplied onto the lensed Fourier field. They may be flat
        (reshaped in Fortran order to the rfft grid) or already 2D (used as-is).
      b_lambda_array, m_lambda_array: accepted for interface compatibility; not used.
      pix_width: real-space pixel width passed to LenseFlow (default taken from the Julia field's
        delta_x metadata).
      num_steps: number of LenseFlow integration steps (default matches the Julia data).
    """
    import math

    #assuming a square map in real space: real grid (w, w), rfft grid (w//2 + 1, w)
    map_width = math.isqrt(int(num_pixels))
    fourier_shape = (map_width // 2 + 1, map_width)
    real_shape = (map_width, map_width)

    #weighted log|det|, determinant sign and elementwise inverse of each diagonal covariance
    log_det_f_value, log_det_f_sign, f_covar_inv = \
        _covariance_log_det_and_inverse_jnp(cf_diagonal, cf_lambda_array, fourier_shape)
    log_det_phi_value, log_det_phi_sign, phi_covar_inv = \
        _covariance_log_det_and_inverse_jnp(cphi_diagonal, cphi_lambda_array, fourier_shape)
    log_det_noise_value, log_det_noise_sign, noise_covar_inv = \
        _covariance_log_det_and_inverse_jnp(cn_diagonal, cn_lambda_array, fourier_shape)

    #quadratic forms; [:, jnp.newaxis] broadcasts the per-row Fourier weights across columns
    f_contribution = jnp.sum(jnp.real(jnp.conj(f_array) * f_covar_inv * f_array) \
                            * f_lambda_array[:, jnp.newaxis] * (1/num_pixels))
    phi_contribution = jnp.sum(jnp.real(jnp.conj(phi_array) * phi_covar_inv * phi_array) \
                              * phi_lambda_array[:, jnp.newaxis] * (1/num_pixels))

    #compute f_lensed explicitly so the gradient of logpdf w.r.t. f includes the lensing
    #when computed by AD (auto diff).
    #NOTE(review): irfft2/rfft2 follow numpy's convention (unnormalized forward, 1/N inverse). f's scale
    #convention cancels in the round trip, but phi's does not: if the Julia Fourier coefficients are
    #normalized differently, the real-space phi (and hence the deflection) is scaled wrongly -- confirm.
    f_map = jfft.irfft2(f_array.T, s=real_shape).T
    phi_map = jfft.irfft2(phi_array.T, s=real_shape).T
    f_lensed_map = LenseFlow(f_map, phi_map, pix_width, num_steps)
    m_matrix = _as_fourier_grid_jnp(m_diagonal, fourier_shape)
    b_matrix = _as_fourier_grid_jnp(b_diagonal, fourier_shape)
    f_lensed_array = m_matrix * b_matrix * jfft.rfft2(f_lensed_map.T).T

    residual = data_array - f_lensed_array
    noise_contribution = jnp.sum(jnp.real(jnp.conj(residual) * noise_covar_inv * residual) \
                                 * data_lambda_array[:, jnp.newaxis] * (1/num_pixels))

    result = -1*(f_contribution + phi_contribution + noise_contribution \
                 + log_det_f_value * log_det_f_sign + log_det_phi_value * log_det_phi_sign \
                 + log_det_noise_value * log_det_noise_sign)/2

    return result


def _covariance_log_det_and_inverse_jnp(diagonal, lambda_array, fourier_shape):
    """Row-weighted log|det|, determinant sign and elementwise inverse of a diagonal covariance (JAX).

    Zero entries are excluded: they add 0 to the log-determinant, count as +1 in the sign and
    get 0 in the inverse.
    """
    #Fortran-order reshape to stay consistent with Julia's column-major layout
    matrix = jnp.reshape(diagonal, fourier_shape, order='F')
    nonzero = matrix != 0
    #swap zeros for 1 *before* log / reciprocal, so neither inf nor NaN (values or gradients) can leak
    #through the jnp.where masks below
    safe = jnp.where(nonzero, matrix, 1)
    log_det_value = jnp.sum(jnp.log(jnp.abs(safe)) * lambda_array[:, jnp.newaxis])
    log_det_sign = jnp.prod(jnp.sign(safe))
    inverse = jnp.where(nonzero, 1.0 / safe, 0.0)
    return log_det_value, log_det_sign, inverse


def _as_fourier_grid_jnp(diagonal, fourier_shape):
    """Reshape a flat diagonal (Fortran order) to the rfft grid; arrays that are not 1D are returned unchanged."""
    if jnp.ndim(diagonal) == 1:
        return jnp.reshape(diagonal, fourier_shape, order='F')
    return diagonal

In [None]:
#COMPUTE LOGPDF AND LOGPDFV2 VALUES FOR GIVEN DATA - I got these using npzwrite in julia
import os

#JAX defaults to 32-bit types. With x64 disabled, jnp.load converts float64/complex128 arrays to
#float32/complex64, and every later JAX computation (FFTs, LenseFlow, logpdf_v2) runs in single
#precision. That is too coarse to compare against the float64 Julia ground truth (~1.16e6, where
#float32 values are spaced 0.125 apart).
#NOTE: this is a global switch; it is best moved to the imports cell, before any JAX array is created.
jax.config.update("jax_enable_x64", True)

#directory holding the arrays exported from Julia (replace with your own directory...)
DATA_DIR = "/home/zane-blood/Desktop"

def _julia_export_path(name):
    """Full path of `<name>.npz` inside DATA_DIR."""
    return os.path.join(DATA_DIR, name + ".npz")

#get the field data array file paths
f_array_file_path = _julia_export_path("f_array")
phi_array_file_path = _julia_export_path("phi_array")
f_lensed_array_file_path = _julia_export_path("f_lensed_array")
f_lensed_map_array_file_path = _julia_export_path("f_lensed_map_array")
data_array_file_path = _julia_export_path("data_array")

#get the covariance array file paths
cf_diagonal_file_path = _julia_export_path("cf_diagonal")
cphi_diagonal_file_path = _julia_export_path("cphi_diagonal")
cn_diagonal_file_path = _julia_export_path("cn_diagonal")
m_diagonal_file_path = _julia_export_path("m_diagonal")
b_diagonal_file_path = _julia_export_path("b_diagonal")

#get the fourier coefficient array file paths
f_lambda_array_file_path = _julia_export_path("f_lambda_array")
phi_lambda_array_file_path = _julia_export_path("phi_lambda_array")
data_lambda_array_file_path = _julia_export_path("data_lambda_array")
cf_lambda_array_file_path = _julia_export_path("cf_lambda_array")
cf_lensed_lambda_array_file_path = _julia_export_path("cf_lensed_lambda_array")
cphi_lambda_array_file_path = _julia_export_path("cphi_lambda_array")
cn_lambda_array_file_path = _julia_export_path("cn_lambda_array")
b_lambda_array_file_path = _julia_export_path("b_lambda_array")
m_lambda_array_file_path = _julia_export_path("m_lambda_array")

#load the data from these file paths into memory
f_array = jnp.load(f_array_file_path)
phi_array = jnp.load(phi_array_file_path)
f_lensed_array = jnp.load(f_lensed_array_file_path)
f_lensed_map_array = jnp.load(f_lensed_map_array_file_path)
data_array = jnp.load(data_array_file_path)

cf_diagonal = jnp.load(cf_diagonal_file_path)
cphi_diagonal = jnp.load(cphi_diagonal_file_path)
cn_diagonal = jnp.load(cn_diagonal_file_path)
b_diagonal = jnp.load(b_diagonal_file_path)
m_diagonal = jnp.load(m_diagonal_file_path)

f_lambda_array = np.load(f_lambda_array_file_path)
phi_lambda_array = np.load(phi_lambda_array_file_path)
data_lambda_array = np.load(data_lambda_array_file_path)

cf_lambda_array = jnp.load(cf_lambda_array_file_path)
cf_lensed_lambda_array = jnp.load(cf_lensed_lambda_array_file_path)
cphi_lambda_array = jnp.load(cphi_lambda_array_file_path)
cn_lambda_array = jnp.load(cn_lambda_array_file_path)
b_lambda_array = jnp.load(b_lambda_array_file_path)
m_lambda_array = jnp.load(m_lambda_array_file_path)

#we will need the number of pixels Nx * Ny of the data in order to properly normalize the logpdf
num_pixels = 256*256 #in the case for the data loaded above this was done on a 256 x 256 grid in real / map space...

#call the logpdf and store its value
log_pdf = logpdf(np.array(f_array), np.array(phi_array), np.array(data_array), np.array(f_lensed_array),
                np.array(cf_diagonal), np.array(cphi_diagonal), np.array(cn_diagonal), 
                np.array(f_lambda_array), np.array(phi_lambda_array), np.array(data_lambda_array), 
                np.array(cf_lambda_array), np.array(cphi_lambda_array), np.array(cn_lambda_array), num_pixels)


log_pdf_v2 = logpdf_v2(f_array, phi_array, data_array,
                        b_diagonal, m_diagonal, 
                        b_lambda_array, m_lambda_array,
                        cf_diagonal, cphi_diagonal, cn_diagonal, 
                        f_lambda_array, phi_lambda_array, data_lambda_array, 
                        cf_lambda_array, cphi_lambda_array, cn_lambda_array, num_pixels)

In [None]:
#COMPARE LOGPDF (V1) AND LOGPDF_V2 ON THE SAME INPUTS AGAINST THE JULIA GROUND TRUTH
ground_truth = 1.1621151553273401e6 #from julia logpdf() function... Need to update if inputs changes
print("ground truth log pdf = " + str(ground_truth))

#report each version's value followed by its signed percent error relative to the ground truth
for version_label, predicted in (("v1", log_pdf), ("v2", log_pdf_v2)):
    print("predicted log pdf " + version_label + " = " + str(predicted))
    relative_error_percent = 100 * (predicted - ground_truth) / ground_truth
    print("percent error " + version_label + " = " + str(relative_error_percent) + " %")

In [None]:
#compare the lensed fields in fourier space - julia ground truth versus python computed values
pix_width = 0.00058177643 #from julia field object metadata delta_x property
num_steps = 10 #to match Julia data
map_width = int(np.sqrt(num_pixels)) #square map; num_pixels is set in the data-loading cell
real_shape = (map_width, map_width)

#same pipeline as logpdf_v2: back to real space, lens with LenseFlow, forward rfft, apply the m and b diagonals
f_map = jfft.irfft2(f_array.T, s=real_shape).T
phi_map = jfft.irfft2(phi_array.T, s=real_shape).T
f_lensed_map_predict = LenseFlow(f_map, phi_map, pix_width, num_steps)
f_lensed_array_predict = m_diagonal * b_diagonal * jfft.rfft2(f_lensed_map_predict.T).T

#these arrays live on the Fourier grid, so the heat-map axes are Fourier-grid indices, not arcminutes
fourier_x_label = 'Fourier index (axis 1)'
fourier_y_label = 'Fourier index (axis 0)'
for heatmap, title in (
    (np.abs(f_lensed_array), 'F Lensed Ground - Fourier Space'),
    (np.abs(f_lensed_array_predict), 'F Lensed Predict - Fourier Space'),
    (np.abs(f_lensed_array_predict - f_lensed_array), 'F Lensed Diff - Fourier Space'),
):
    plot_heat_map(heatmap, title, fourier_x_label, fourier_y_label, 'Intensity')