# Recentred algorithm with W=7, x_0=0.25, 42 w planes

## References

Ye, H. (2019). Accurate image reconstruction in radio interferometry (Doctoral thesis). https://doi.org/10.17863/CAM.39448

Ye, H., Gull, S. F., Tan, S. M., & Nikolic, B. (2020). Optimal gridding and degridding in radio interferometry imaging. Monthly Notices of the Royal Astronomical Society, 491(1), 1146–1159. https://doi.org/10.1093/mnras/stz2970

Github: https://github.com/zoeye859/Imaging-Tutorial

In [1]:
%matplotlib notebook
import numpy as np
from scipy.optimize import leastsq, brent
from scipy.linalg import solve_triangular
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from time import process_time
from numpy.linalg import inv
np.set_printoptions(precision=16)
from Imaging_core_new import *
from Gridding_core import *
import pickle
with open("min_misfit_gridding_14.pkl", "rb") as pp:
    opt_funcs = pickle.load(pp)

### 1. Read in the data

In [2]:
#########  Read in visibilities ##########
data = np.genfromtxt('simul3d.csv', delimiter = ',')
jj = complex(0,1)
u_original = data.T[2][1:]
v_original = data.T[3][1:]
w_original = data.T[4][1:]
V_original = data.T[5][1:] + jj*data.T[6][1:]
n_uv = len(u_original)
uv_max = max(np.sqrt(u_original**2+v_original**2))
V,u,v,w = Visibility_minusw(V_original,u_original,v_original,w_original)

#### Determine the pixel size ####
X_size = 300 # image size on x-axis (after cropping)
Y_size = 300 # image size on y-axis (after cropping)
# The following are actually the minimum and maximum values of the direction cosines
#  at the edges of the cropped map. They are not the same as the normalized map coordinates
#  used in the theory section.
X_min = -1.15/2 #You can change X_min and X_max in order to change the pixel size.
X_max = 1.15/2
X = np.linspace(X_min, X_max, num=X_size+1)[0:X_size]
Y_min = -1.15/2 #You can change Y_min and Y_max in order to change the pixel size.
Y_max = 1.15/2
Y = np.linspace(Y_min,Y_max,num=Y_size+1)[0:Y_size]
pixel_resol_x = 180. * 60. * 60. * (X_max - X_min) / np.pi / X_size
pixel_resol_y = 180. * 60. * 60. * (Y_max - Y_min) / np.pi / Y_size
print ("The pixel size on x-axis is ", pixel_resol_x, " arcsec") 

The pixel size on x-axis is  790.681757280536  arcsec


# Theory
The dirty map is given by
\begin{equation}\label{eqn:dirty-map}
\boldsymbol{d}(l,m)=\sum_k\boldsymbol{W}_k\boldsymbol{v}_k\exp\left[i2\pi\left(u_k l + v_k m + w_k(n-1)\right)\right]
\end{equation}
In this equation, the visibility plane coordinates are measured in wavelengths and $(l,m,n)$ are direction cosines of the 
observation point which satisfy $n=\sqrt{1-l^2-m^2}$. The full region to be mapped (before image cropping) is chosen to be $l_{\rm{min}}\le l\le l_{\rm{max}}$ and $m_{\rm{min}}\le m\le m_{\rm{max}}$ where without loss of generality we may assume that 
$l_{\rm{min}}=-l_{\rm{max}}$ and $m_{\rm{min}}=-m_{\rm{max}}$. We then write the map limits as $|l|\le l_{\rm{range}}/2$ and
$|m|\le m_{\rm{range}}/2$ where the ranges are the differences between the maximum and minimum values.

We introduce normalized map coordinates $x=l/l_{\rm{range}}$ and $y=m/m_{\rm{range}}$ which each lie within $[-0.5,0.5]$. The conjugate visibility plane coordinates to these are $u'=u l_{\rm{range}}$ and $v'=v m_{\rm{range}}$. In the $w$ direction, we introduce $z=(n-n_0)/n_{\rm{range}}$ and $w'=w n_{\rm{range}}$ where the quantities $n_0$ and $n_{\rm{range}}$ will be defined later. With this notation, the dirty map can be rewritten as
$$
\boldsymbol{d}(x,y)=\sum_k\boldsymbol{W}_k\boldsymbol{v}_k\exp\left[i2\pi\left(u'_k x + v'_k y + w'_k z + w_k (n_0-1)\right)\right]
$$
If $n_0=1$, the plane $z=0$ is tangential to the celestial sphere whereas a value of $n_0<1$ may be used to offset the plane. It is useful to associate the last part of the phase factor with the visibility:
$$
\boldsymbol{d}(x,y)=\sum_k\boldsymbol{W}_k\boldsymbol{v}_k\exp[i2\pi w_k (n_0-1)] \exp\left[i2\pi\left(u'_k x + v'_k y + w'_k z \right)\right]
$$
Although this resembles a three dimensional Fourier transform, it is only evaluated on $(x,y)$ while $z$ is calculated from $x$ and $y$ according to 
\begin{equation}\label{eq:z-on-sphere}
z = \frac{\sqrt{1-l^2-m^2}-n_0}{n_{\rm range}} = \frac{\sqrt{1-(x l_{\rm{range}})^2-(y m_{\rm{range}})^2}-n_0}{n_{\rm{range}}}
\end{equation}
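
The change of variables can be checked numerically. The sketch below uses hypothetical values for the map ranges, $n_0$, $n_{\rm{range}}$ and a single visibility sample (none of them come from the data set) and confirms that the rescaled phase, including the $w_k(n_0-1)$ term, equals the original phase $u l + v m + w(n-1)$.

```python
import numpy as np

# Hypothetical values, chosen only for illustration
l_range, m_range = 2.3, 2.3      # full (uncropped) map extent in direction cosines
n0, n_range = 0.9, 0.4           # arbitrary offset and scale for the z coordinate
u, v, w = 12.3, -4.5, 6.7        # one visibility sample, in wavelengths
l, m = 0.31, -0.22               # a direction inside the map

n = np.sqrt(1.0 - l**2 - m**2)

# Normalized map coordinates and their conjugate visibility coordinates
x, y, z = l / l_range, m / m_range, (n - n0) / n_range
u_p, v_p, w_p = u * l_range, v * m_range, w * n_range

# Phase (in units of 2*pi) before and after the change of variables
phase_original = u * l + v * m + w * (n - 1.0)
phase_rescaled = u_p * x + v_p * y + w_p * z + w * (n0 - 1.0)
print(phase_original, phase_rescaled)
```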

According to the theory of convolutional gridding, we can find a convolutional gridding function $C(u')$ and an associated grid correction function $h(x)$ so that the dirty map is well approximated by:

\begin{equation}\label{eq:gridded-dirty-map}
\boldsymbol{d}(x,y)\approx
h(x)h(y)h(z)\sum_{t\in\mathbb{Z}} \left[\sum_{r\in\mathbb{Z}}\sum_{s\in\mathbb{Z}}
G_{rst}\mathrm{e}^{i2\pi r x} 
\mathrm{e}^{i2\pi s y}
\right]\mathrm{e}^{i2\pi t z}
\end{equation}
where
$$G_{rst}=
\sum_k \boldsymbol{W}_k\boldsymbol{v}_k\exp[i2\pi w_k (n_0-1)] C(r-u'_k)  C(s-v'_k) C(t-w'_k).$$
This indicates that each weighted visibility should be multiplied by $\exp[i2\pi w_k (n_0-1)]$ before being convolved in all three directions using the separable product $C(u')C(v')C(w')$ of gridding functions. Two of the inverse Fourier transforms, those along the $x$ and $y$ directions, can be performed using the FFT algorithm, since they need to be evaluated on uniformly spaced grids. The final inverse transform (involving $t$) is not evaluated on a uniformly spaced grid of $z$ points and is not computed using the FFT.

Instead, the equation relating $z$ to $x$ and $y$ is used to find the phase factors $\exp(i2\pi t z(x,y))$ which are multiplied by the two-dimensional FFTs while they are summed up.  

The values of $l_{\rm{range}}$ and $m_{\rm{range}}$ are set by the region of the sky to be mapped. Since the convolutional gridding function $C$ is designed so that high accuracy is achieved within $|x|\leq x_0$ and $|y|\leq x_0$, the map is finally cropped to $|l|\leq x_0 l_{\rm{range}}$ and $|m|\leq x_0 m_{\rm{range}}$. In practice, one would choose $l_{\rm{range}}$ and $m_{\rm{range}}$ to give the desired final map size after cropping.
The numbers of points $N_x\times N_y$ for the two-dimensional FFTs are chosen so that they cover the non-zero terms in the sums over $r$ and $s$ in the equation for the gridded dirty map. This requires that $N_x\geq (u_{\rm{max}} - u_{\rm{min}})l_{\rm{range}}+W$ and $N_y\geq (v_{\rm{max}} - v_{\rm{min}})m_{\rm{range}}+W$ where $u$ and $v$ are measured in wavelengths and $W$ defines the support of the gridding functions ($C$ vanishes outside the interval $[-W/2, W/2]$).
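
As a rough illustration of the bound on $N_x$, the sketch below computes the smallest admissible grid length for a handful of made-up $u$ values. The helper name `min_fft_size` and all the numbers are assumptions for this example only; in practice one would usually round the result up to a size that is convenient for the FFT.

```python
import numpy as np

def min_fft_size(coords, map_range, W):
    """Smallest grid length N with N >= (max - min) * map_range + W."""
    span = (np.max(coords) - np.min(coords)) * map_range
    return int(np.ceil(span + W))

# Hypothetical u coordinates (in wavelengths) and full map range
u = np.array([-120.0, -35.5, 0.0, 80.25, 130.2])
l_range = 2.0
W = 7

N_x = min_fft_size(u, l_range, W)
print("Minimum N_x:", N_x)
```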

It finally remains to determine the values of $n_0$ and $n_{\rm{range}}$, which will also determine the number of sheets in the $w$-stack (i.e., the number of non-zero terms in the sum over $t$). If we use the same gridding convolution and correction functions in all three directions, the results will be accurate within $|x|\leq x_0$, $|y|\leq x_0$ and $|z|\leq x_0$. From the equation $z(x,y)$ for the celestial sphere, we see that $z$ is a maximum at the map center where $x=y=0$ and is a minimum at the corners of the cropped map where $|x|=|y|=x_0$. In particular, 
$$z_{\rm{max}} = \frac{1-n_0}{n_{\rm{range}}}\quad\text{and}\quad
z_{\rm{min}} = \frac{\sqrt{1-(x_0 l_{\rm{range}})^2-(x_0 m_{\rm{range}})^2}-n_0}{n_{\rm{range}}}$$

In order to best use the region of accuracy of the gridding convolution function, we choose $z_{\rm{min}}=-x_0$ and $z_{\rm{max}}=x_0$. This leads to
$$n_{\rm{range}} = \frac{1-\sqrt{1-(x_0 l_{\rm{range}})^2-(x_0 m_{\rm{range}})^2}}{2x_0}\quad\text{and}\quad
n_0=1-x_0 n_{\rm{range}}$$
which are the promised expressions for the quantities used in the above.
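
These expressions can be verified with the map limits used in this notebook. In the sketch below, `X_max` and `Y_max` play the role of $x_0 l_{\rm{range}}$ and $x_0 m_{\rm{range}}$ (the direction cosines at the edge of the cropped map); the variable names `n_corner`, `z_min` and `z_max` are introduced here for illustration.

```python
import numpy as np

x0 = 0.25
X_max, Y_max = 1.15 / 2, 1.15 / 2   # x0*l_range and x0*m_range, as set in the notebook

# Value of n at a corner of the cropped map
n_corner = np.sqrt(1.0 - X_max**2 - Y_max**2)

# Symmetric choice: z runs from -x0 (corner) to +x0 (centre)
n_range = (1.0 - n_corner) / (2.0 * x0)
n0 = 1.0 - x0 * n_range

z_max = (1.0 - n0) / n_range        # map centre
z_min = (n_corner - n0) / n_range   # map corner
print("n_range =", n_range, " n0 =", n0)
print("z_min =", z_min, " z_max =", z_max)
```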

In the previous version of this code, we chose $z_{\rm{min}}=-x_0$ and $z_{\rm{max}}=0$ which places $z=0$ at the center of the map. With this choice
$$n_{\rm{range}} = \frac{1-\sqrt{1-(x_0 l_{\rm{range}})^2-(x_0 m_{\rm{range}})^2}}{x_0}\quad\text{and}\quad
n_0=1$$

The separation between the layers in the $w$-stack is $1/n_{\rm{range}}$ wavelengths and so the minimum number of layers required is
\begin{equation}\label{eq:N_{w'}+W}
N_{w'} = N_z \geq n_{\rm{range}}(w_{\rm{max}}-w_{\rm{min}}) + W
\end{equation}
where the additional $W$ layers are necessary to allow for the support of the gridding function.
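
To see the effect of the two choices of $n_0$ on the size of the $w$-stack, the sketch below evaluates the bound for both values of $n_{\rm{range}}$. The $w$ range is hypothetical (it is not taken from `simul3d.csv`), and `n_layers` is a helper defined only for this example. Because the recentred choice halves $n_{\rm{range}}$, it roughly halves the number of layers needed to span the $w$ range, while the additional $W$ layers are unchanged.

```python
import numpy as np

def n_layers(w_min, w_max, n_range, W):
    """Minimum number of w-layers: ceil(n_range * (w_max - w_min)) + W."""
    return int(np.ceil(n_range * (w_max - w_min))) + W

x0, W = 0.25, 7
X_max = Y_max = 1.15 / 2
drop = 1.0 - np.sqrt(1.0 - X_max**2 - Y_max**2)   # 1 - n at the map corner

n_range_symm = drop / (2.0 * x0)   # recentred: z in [-x0, x0]
n_range_tangent = drop / x0        # previous version: z in [-x0, 0], n0 = 1

w_min, w_max = -20.0, 20.0         # hypothetical w range in wavelengths
layers_symm = n_layers(w_min, w_max, n_range_symm, W)
layers_tangent = n_layers(w_min, w_max, n_range_tangent, W)
print(layers_symm, layers_tangent)
```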


### 2. Determine stack of w values to use

In [3]:
# Choose the support of the optimal gridding function to use. 
#  (These have been precomputed)
W = 7
M, x0, h = opt_funcs[W].M, opt_funcs[W].x0, opt_funcs[W].h

def calcWgrid(W, X_max, Y_max, w, x0=0.25, symm=True):
    """
    Calculate the layers of the w-stack
    Args:
        W (int): size of gridding convolution function
        X_max (float): Maximum direction cosine in L direction in final map
        Y_max (float): Maximum direction cosine in M direction in final map
        w (float array): w values (in wavelengths) of visibilities
        x0: portion of map to be retained
        symm: If False, place z=0 on celestial sphere, if True optimize position of z=0
               to minimize the number of layers
    Return:
        n0: value of direction cosine in N direction at which to optimize error
        w_values: w values of stack onto which visibilities are gridded
        dw: Separation between w_values
    """
    if symm:
        n_range = (1-np.sqrt(1-(X_max)**2-(Y_max)**2))/(2*x0)
        n0 = 1.0 - x0*n_range
    else:
        n_range = (1-np.sqrt(1-(X_max)**2-(Y_max)**2))/x0
        n0 = 1.0        
    dw = 1.0/n_range
    nlayers = int(np.ceil((np.max(w) - np.min(w))/dw) + W)
    wmid = 0.5*(np.max(w) + np.min(w))
    wrange = (nlayers - 1) * dw
    w_values = np.linspace(wmid-0.5*wrange, wmid+0.5*wrange, nlayers)
    print ("We will have", len(w_values), "w-planes")
    return n0, w_values, dw

n0, w_values, dw = calcWgrid(W, X_max, Y_max, w, x0, symm=True)

We will have 42 w-planes


### 3. 3D Gridding + Imaging + Correcting

For more background on gridding, see https://github.com/zoeye859/Imaging-Tutorial

#### Calculating the gridding values along the w axis

In [4]:
Nfft = 600
im_size = 600

ind = find_nearestw(w_values, w)
C_w = cal_grid_w(w, w_values, ind, dw, W, h, M)

Elapsed time during the w gridding value calculation in seconds: 10.0625


#### Gridding on w-axis

In [5]:
def grid_w(V, u, v, w, C_w, w_values, W, Nw_2R, idx, n0=1.0):
    """
    Grid the visibilities along the w-axis
    Args:
        V (np.ndarray): visibility data
        u (np.ndarray): u of the (u,v,w) coordinates
        v (np.ndarray): v of the (u,v,w) coordinates
        w (np.ndarray): w of the (u,v,w) coordinates
        C_w (list): gridding weights along w for each visibility
        w_values (np.ndarray): w values of all the w-planes to be formed
        W (int): support width of the gridding function
        Nw_2R (int): number of w-planes used
        idx (list): index of the nearest w-plane for each visibility
        n0 (float): direction cosine in the N direction at which z=0
    Returns:
        V_wgrid, u_wgrid, v_wgrid, beam_wgrid (lists): for each w-plane, the weighted
            visibilities, their u and v coordinates, and the weighted beam samples
    """
    n_uv = len(V)
    bEAM = np.ones(n_uv)
    # The built-in complex type is used because np.complex_ was removed in NumPy 2.0
    V_wgrid = np.zeros((Nw_2R,1),dtype = complex).tolist()
    beam_wgrid = np.zeros((Nw_2R,1),dtype = complex).tolist()
    u_wgrid = np.zeros((Nw_2R,1)).tolist()
    v_wgrid = np.zeros((Nw_2R,1)).tolist()
    t_start = process_time() 
    idx_floor = find_floorw(w_values, w)

    for k in range(n_uv):
        C_wk = C_w[k]
        if W % 2 == 1:
            w_plane = idx[k]
        else:
            w_plane = idx_floor[k]
        j = 0
        for n in range(-W//2+1,-W//2+1+W):
            #print (k, w_plane+n, C_wk[j,0], V[k])
            V_wgrid[w_plane+n] += [C_wk[j,0] * V[k] * np.exp(2j*np.pi*w[k]*(n0-1.0))]
            u_wgrid[w_plane+n] += [u[k]]
            v_wgrid[w_plane+n] += [v[k]]
            beam_wgrid[w_plane+n] += [C_wk[j,0] * bEAM[k] * np.exp(2j*np.pi*w[k]*(n0-1.0))]
            j+=1

    for i in range(Nw_2R):
        del(V_wgrid[i][0])
        del(u_wgrid[i][0])
        del(v_wgrid[i][0])
        del(beam_wgrid[i][0])

    t_stop = process_time()   
    print("Elapsed time during the w-gridding calculation in seconds:", t_stop-t_start)   
    return V_wgrid, u_wgrid, v_wgrid, beam_wgrid
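
The offsets used in the inner loop of `grid_w` decide which $W$ planes each visibility contributes to. The small sketch below reproduces only that index arithmetic; `support_planes` is a hypothetical helper, not part of `Gridding_core`. For odd $W$ the offsets are symmetric about the nearest plane, while for even $W$ they are taken relative to the plane just below the visibility's $w$ value.

```python
def support_planes(centre, W):
    """Plane indices touched by one visibility, using the same offsets as grid_w."""
    return [centre + n for n in range(-W // 2 + 1, -W // 2 + 1 + W)]

# Odd W: centre is the index of the nearest plane
print(support_planes(10, 7))
# Even W: centre is the index of the plane at or below w
print(support_planes(10, 6))
```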


In [6]:
V_wgrid, u_wgrid, v_wgrid, beam_wgrid = grid_w(V, u, v, w, C_w, w_values, W, len(w_values), ind, n0)

Elapsed time during the w-gridding calculation in seconds: 1.25


#### Imaging

In [7]:
def FFTnPShift(V_grid, ww, X, Y, im_size, x0=0.25, n0=1.0):
    """
    FFT the gridded visibilities on one w-plane and apply the w phase shift
    Args:
        V_grid (np.ndarray): gridded visibilities on a certain w-plane
        ww (float): the w value of the w-plane being processed
        X (np.ndarray): direction cosines l of the cropped image pixels
        Y (np.ndarray): direction cosines m of the cropped image pixels
        im_size (int): the image size before cropping
        x0 (float): the central 2*x0*100% of the image is retained
        n0 (float): direction cosine in the N direction at which z=0
    Returns:
        I_FFTnPShift (np.ndarray): the cropped, phase-shifted image
    """
    print ('FFTing...')
    I = np.fft.ifftshift(np.fft.ifftn(np.fft.ifftshift(V_grid)))
    I_cropped = image_crop(I, im_size)
    I_size = int(im_size*2*x0)
    I_FFTnPShift = np.zeros((I_size,I_size),dtype = complex)
    print ('Phaseshifting...')
    for l_i in range(0,I_size):
        for m_i in range(0,I_size):
            ll = X[l_i]
            mm = Y[m_i]
            nn = np.sqrt(1 - ll**2 - mm**2)
            I_FFTnPShift[l_i,m_i] = np.exp(2j*np.pi*ww*(nn-n0))*I_cropped[l_i,m_i]
    return I_FFTnPShift
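
The double loop in `FFTnPShift` evaluates the phase factor $\exp(i2\pi w (n-n_0))$ one pixel at a time. A possible vectorized alternative is sketched below; `phase_screen` is a hypothetical helper and the small grid is made up, so this is an illustration of the idea rather than a drop-in replacement. It is checked against an explicit loop with the same `[l_i, m_i]` index order as the notebook.

```python
import numpy as np

def phase_screen(X, Y, ww, n0=1.0):
    """exp(i*2*pi*ww*(n - n0)) on the (X, Y) grid, indexed [l_i, m_i]."""
    ll, mm = np.meshgrid(X, Y, indexing="ij")
    nn = np.sqrt(1.0 - ll**2 - mm**2)
    return np.exp(2j * np.pi * ww * (nn - n0))

# Small hypothetical grid, only to compare against the explicit double loop
X = np.linspace(-0.3, 0.3, 5)
Y = np.linspace(-0.2, 0.2, 5)
ww, n0 = 3.7, 0.95
screen = phase_screen(X, Y, ww, n0)

reference = np.zeros((5, 5), dtype=complex)
for l_i in range(5):
    for m_i in range(5):
        nn = np.sqrt(1 - X[l_i]**2 - Y[m_i]**2)
        reference[l_i, m_i] = np.exp(2j * np.pi * ww * (nn - n0))

print(np.allclose(screen, reference))
```

With such a helper, the phase-shifted image for one plane would be the elementwise product of the cropped image and the screen.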


In [8]:
I_size = int(im_size*2*x0)
I_image = np.zeros((I_size,I_size),dtype = complex)
B_image = np.zeros((I_size,I_size),dtype = complex)

t2_start = process_time() 
for w_ind in range(len(w_values)):
    print ('Gridding the ', w_ind, 'th level facet out of ', len(w_values),' w facets.\n')
    V_update = np.asarray(V_wgrid[w_ind])
    u_update = np.asarray(u_wgrid[w_ind])
    v_update = np.asarray(v_wgrid[w_ind])
    beam_update = np.asarray(beam_wgrid[w_ind])
    V_grid, B_grid = grid_uv(V_update, u_update, v_update, beam_update, W, im_size, X_max, X_min, Y_max, Y_min, h, M)
    print ('FFT the ', w_ind, 'th level facet out of ', len(w_values),' w facets.\n')
    I_image += FFTnPShift(V_grid, w_values[w_ind], X, Y, im_size, x0, n0)
    B_image += FFTnPShift(B_grid, w_values[w_ind], X, Y, im_size, x0, n0)
    B_grid = np.zeros((im_size,im_size),dtype = complex) 
    V_grid = np.zeros((im_size,im_size),dtype = complex)
    
t2_stop = process_time()   
print("Elapsed time during imaging in seconds:", t2_stop-t2_start)  

Gridding the  0 th level facet out of  42  w facets.

Elapsed time during the u/v gridding value calculation in seconds: 0.28125
Elapsed time during the u/v gridding value calculation in seconds: 0.28125
FFT the  0 th level facet out of  42  w facets.

FFTing...
Phaseshifting...
FFTing...
Phaseshifting...
Gridding the  1 th level facet out of  42  w facets.

Elapsed time during the u/v gridding value calculation in seconds: 1.21875
Elapsed time during the u/v gridding value calculation in seconds: 1.125
FFT the  1 th level facet out of  42  w facets.

FFTing...
Phaseshifting...
FFTing...
Phaseshifting...
Gridding the  2 th level facet out of  42  w facets.

Elapsed time during the u/v gridding value calculation in seconds: 1.671875
Elapsed time during the u/v gridding value calculation in seconds: 2.0625
FFT the  2 th level facet out of  42  w facets.

FFTing...
Phaseshifting...
FFTing...
Phaseshifting...
Gridding the  3 th level facet out of  42  w facets.

Elapsed time during the u/v

#### Rescale and display uncorrected map

In [9]:
I_image_now = image_rescale(I_image,im_size, n_uv)
B_image_now = image_rescale(B_image,im_size, n_uv)
plt.figure()
plt.imshow(np.rot90(I_image_now.real,1), origin = 'lower')
plt.xlabel('Image Coordinates X')
plt.ylabel('Image Coordinates Y')
plt.title('Dirty map before any gridding correction')
B_image_now[150,150]

(0.3835329999321081-9.108364306113661e-10j)

#### Apply correcting functions h(x)h(y) along x and y axis

In [10]:
Nfft = 600
# Use these for calculating gridding correction on the FFT grid
M = 32
I_xycorrected = xy_correct(I_image_now, opt_funcs[W], im_size, x0=0.25)
B_xycorrected = xy_correct(B_image_now, opt_funcs[W], im_size, x0=0.25)

In [11]:
plt.figure()
plt.imshow(np.rot90(I_xycorrected.real/I_image_now.real,1), origin = 'lower')
plt.colorbar()
plt.title("Gridding Correction in x and y directions")

Text(0.5, 1.0, 'Gridding Correction in x and y directions')

In [12]:
plt.figure()
plt.imshow(np.rot90(I_xycorrected.real,1), origin = 'lower')
plt.xlabel('Image Coordinates X')
plt.ylabel('Image Coordinates Y')
plt.title('Dirty map after gridding correction in x and y directions')
B_xycorrected[150,150]

(0.38353299993210943-9.108364306113694e-10j)

#### Correcting function on z axis

In [13]:
class LookupTable:
    """
    Create lookup table for polynomial interpolation of specified degree. The
    function to be interpolated is evaluated at points in [xstart+k*dx] for 
    k=0,1,...,N-1 and the values are in f[0], f[1],...,f[N-1]
    """
    def __init__(self, xstart, dx, fvals, degree):
        self.xstart = xstart
        self.dx = dx
        self.degree = degree
        fcopy = np.array(fvals, dtype=float)
        self.table = [fcopy]
        for d in range(degree):
            fcopy = np.diff(fcopy, 1)
            self.table.append(fcopy)
        
    def interp(self, x):
        loc = (x-self.xstart)/self.dx
        pt = np.asarray(np.floor(loc), dtype=int)  # np.int was removed in NumPy 1.24
        if np.any((pt<0) | (pt>=len(self.table[self.degree]))):
            raise ValueError("Outside range of lookup table")
        ft = loc - pt
        # Perform polynomial interpolation
        weights = self.table[0][pt].copy()
        factor = 1
        for k in range(self.degree):
            factor *= (ft - k) / (k + 1)
            weights += self.table[k + 1][pt] * factor
        return weights

def setup_lookup_table(opt_func, Nfine, degree):
    xfine = np.linspace(0.0, x0*(1 + degree/Nfine), Nfine)
    hfine = get_grid_correction(opt_func, xfine)
    lut = LookupTable(xfine.min(), xfine[1]-xfine[0], hfine, degree)
    return lut

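def z_correct_cal(lut, X_min, X_max, Y_min, Y_max, dw, h, im_size, W, M, x0, n0=1):
    """
    Calculate the gridding correction along the z-axis
    Args:
        lut (LookupTable): lookup table of the grid correction values
        X_min, X_max, Y_min, Y_max (float): direction cosine limits of the cropped map
        dw (float): separation between neighbouring w-planes (1/n_range)
        h, W, M: parameters of the optimal gridding function (as loaded from opt_funcs)
        im_size (int): the image size before cropping
        x0 (float): the central 2*x0*100% of the image is retained
        n0 (float): direction cosine in the N direction at which z=0
    Return:
        Cor_gridz (np.ndarray): correcting function on z-axis
    """ 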
    I_size = int(im_size*2*x0)
    nu, x = make_evaluation_grids(W, M, I_size)
    gridder = calc_gridder(h, x0, nu, W, M)
    grid_correction = gridder_to_grid_correction(gridder, nu, x, W)
    h_map = np.zeros(im_size, dtype=float)
    h_map[I_size:] = grid_correction[:I_size]
    h_map[:I_size] = grid_correction[:0:-1]
    xrange = X_max - X_min
    yrange = Y_max - Y_min
    ny = im_size
    nx = im_size
    l_map = np.linspace(X_min, X_max, nx+1)[:nx]/(2*x0)
    m_map = np.linspace(Y_min, Y_max, ny+1)[:ny]/(2*x0)
    ll, mm = np.meshgrid(l_map, m_map)
    # Do not allow NaN or values outside the x0 for the optimal function
    z = abs(dw*(np.sqrt(np.maximum(0.0, 1. - ll**2 - mm**2))-n0))
    z[z > x0] = x0 

    fmap = lut.interp(z)
    Cor_gridz = image_crop(fmap, im_size, x0)
    return Cor_gridz

lut = setup_lookup_table(opt_funcs[W], Nfine=256, degree=7)
Cor_gridz = z_correct_cal(lut, X_min, X_max, Y_min, Y_max, dw, h, im_size, W, M, x0, n0)
I_zcorrected = z_correct(I_xycorrected, Cor_gridz, im_size, x0=0.25)
B_zcorrected = z_correct(B_xycorrected, Cor_gridz, im_size, x0=0.25)
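
The `LookupTable` class above interpolates equally spaced samples using a table of forward differences (Newton's forward-difference formula). The standalone sketch below applies the same scheme to $\cos x$, which stands in for the grid correction function purely so that the error can be checked; the function name `forward_difference_interp` and the sampling parameters are assumptions for this example.

```python
import numpy as np

def forward_difference_interp(f, xstart, dx, x, degree):
    """Interpolate equally spaced samples f at points x with Newton forward differences."""
    # table[d] holds the d-th forward differences of the samples
    table = [np.asarray(f, dtype=float)]
    for _ in range(degree):
        table.append(np.diff(table[-1]))
    loc = (np.asarray(x, dtype=float) - xstart) / dx
    pt = np.floor(loc).astype(int)   # index of the sample at or below x
    ft = loc - pt                    # fractional position within the interval
    result = table[0][pt].copy()
    factor = np.ones_like(ft)
    for k in range(degree):
        factor = factor * (ft - k) / (k + 1)   # generalized binomial coefficient C(ft, k+1)
        result += table[k + 1][pt] * factor
    return result

xs = np.linspace(0.0, 1.0, 65)          # 64 intervals of width 1/64
samples = np.cos(xs)
xq = np.array([0.1234, 0.3333, 0.5])    # query points well inside the table
approx = forward_difference_interp(samples, xs[0], xs[1] - xs[0], xq, degree=3)
max_err = np.max(np.abs(approx - np.cos(xq)))
print("Maximum interpolation error:", max_err)
```

As in `LookupTable.interp`, the query points must leave enough samples to the right of `pt` for the highest-order difference, which is why `setup_lookup_table` extends the table slightly beyond `x0`.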

In [14]:
plt.figure()
plt.imshow(np.rot90(Cor_gridz,1), origin = 'lower')
plt.colorbar()
plt.title("Gridding Correction in z direction")

Text(0.5, 1.0, 'Gridding Correction in z direction')

In [15]:
plt.figure()
plt.imshow(np.rot90(I_zcorrected.real,1), origin = 'lower')
plt.xlabel('Image Coordinates X')
plt.ylabel('Image Coordinates Y')
plt.title('Dirty map after gridding correction in x, y and z directions')
B_zcorrected[150,150]

(1.0000000010746588-2.3748580480726253e-09j)

### 4. DFT and FFT dirty image difference

In [16]:
I_DFT = np.loadtxt('I_DFT_simul300.csv', delimiter = ',')

In [17]:
I_dif = I_DFT - I_zcorrected.real
rms = RMS(I_dif, im_size, 0.5, x0=0.25)
plt.figure()
plt.imshow(np.rot90(I_dif,1), origin = 'lower')
plt.colorbar()
plt.xlabel('Image Coordinates X')
plt.ylabel('Image Coordinates Y')
plt.title('Difference from DFT map')
plt.show()

In [18]:
print (rms)

2.5136074547845503e-08
