In [1]:
import numpy as np
from numpy import pi
from numpy.fft import fftfreq, rfftfreq, fftshift, fft2, ifftshift, ifft2
import scipy.optimize as optimize
from scipy.signal import convolve2d
from scipy.interpolate import interpn
import numpy.linalg as linalg
import casatools
from numba import njit, prange

# --- physical / unit constants ---------------------------------------------
cc = 2.99792458e10              # speed of light [cm s^-1]
arcsec = pi / (180. * 3600)     # one arcsecond in radians (= 1/206265 rad)

# --- matplotlib setup for pretty plots -------------------------------------
import matplotlib.pylab as pl
from matplotlib.pyplot import *
import matplotlib.colors as col
from matplotlib.ticker import NullFormatter
from matplotlib.font_manager import FontProperties
from matplotlib import rc

font = FontProperties()
font.set_family("sans-serif")

# style.use() overwrites every rcParam defined by the stylesheet, and the
# 'classic' sheet is a complete rc set (it includes mathtext.default and
# axes.linewidth). Apply it FIRST so the customisations below are not undone.
pl.style.use('classic')

rcParams['mathtext.default'] = 'regular'
rcParams['axes.linewidth'] = 1.0
pl.rcParams['hatch.linewidth'] = 0.5
pl.rc('text', usetex=True)      # needs a working LaTeX installation
rcParams['font.size'] = 17
rcParams['font.family'] = 'serif'
rcParams['font.weight'] = 'light'
rcParams['mathtext.bf'] = 'serif:normal'
pl.rcParams['xtick.major.pad'] = '2'
pl.rcParams['ytick.major.pad'] = '2'

In [2]:
@njit(fastmath=True)
def grid_wgts(gwgts, uu, vv, du, dv, npix, wgts):
    """Accumulate visibility weights, and their Hermitian conjugates, onto a uv grid.

    Visibility i adds ``wgts[i]`` to the cell nearest (uu[i], vv[i]) and to the
    cell nearest (-uu[i], -vv[i]). The first grid axis indexes u and the second
    indexes v. Cell k along an axis sits at (k - npix/2) * du.

    A visibility whose cell or conjugate cell falls outside the npix x npix
    grid is skipped. numba does not bounds-check array indexing by default, so
    without this check such points would write past the end of the buffer.

    Parameters
    ----------
    gwgts : 2-D float array, shape (npix, npix)
        Accumulator. It is updated in place and also returned.
    uu, vv : 1-D arrays
        uv coordinates, in the same units as du and dv.
    du, dv : float
        uv cell sizes.
    npix : int
        Grid size.
    wgts : 1-D array
        One weight per visibility.

    Returns
    -------
    gwgts : the same array that was passed in.
    """
    half = npix / 2
    for i in range(uu.shape[0]):
        # Fractional cell positions. The +0.5 makes int() round to nearest
        # for non-negative positions.
        pu = half + uu[i] / du + 0.5
        pv = half + vv[i] / dv + 0.5
        cu = half - uu[i] / du + 0.5
        cv = half - vv[i] / dv + 0.5
        if (pu < 0.0 or pu >= npix or pv < 0.0 or pv >= npix
                or cu < 0.0 or cu >= npix or cv < 0.0 or cv >= npix):
            continue  # off-grid: skip both the point and its conjugate
        gwgts[int(pu), int(pv)] += wgts[i]
        gwgts[int(cu), int(cv)] += wgts[i]
    return gwgts


@njit
def ungrid_wgts(gwgts, uu, vv, du, dv, npix):
    """Look up, for each visibility, the value of the grid cell it falls in.

    The cell is the one nearest (uu[i], vv[i]), using the same index convention
    as the gridding: cell k sits at (k - npix/2) * du.

    Visibilities whose cell or conjugate cell lies outside the npix x npix grid
    get 0 (treated as excluded). numba does not bounds-check by default, so an
    unchecked lookup would read past the end of the buffer.

    Parameters
    ----------
    gwgts : 2-D array, shape (npix, npix)
        Gridded quantity to sample.
    uu, vv : 1-D arrays
        uv coordinates, in the same units as du and dv.
    du, dv : float
        uv cell sizes.
    npix : int
        Grid size.

    Returns
    -------
    1-D float array with one sampled value per visibility.
    """
    half = npix / 2
    ugwgts = np.zeros(uu.shape[0])
    for i in range(uu.shape[0]):
        pu = half + uu[i] / du + 0.5
        pv = half + vv[i] / dv + 0.5
        cu = half - uu[i] / du + 0.5
        cv = half - vv[i] / dv + 0.5
        if (pu < 0.0 or pu >= npix or pv < 0.0 or pv >= npix
                or cu < 0.0 or cu >= npix or cv < 0.0 or cv >= npix):
            continue  # off-grid: leave the value at 0
        ugwgts[i] = gwgts[int(pu), int(pv)]
    return ugwgts

In [3]:
def grid_ms_singlechan(base_ms, npix, cell_size, chan=0):
    """Grid the natural visibility weights of one channel of a measurement set.

    Steps:
    - Read the first spectral window's channel frequencies with casatools.
    - Read the main table's FLAG, UVW, WEIGHT, ANTENNA1 and ANTENNA2 columns.
    - Convert baselines to wavelengths.
    - Drop autocorrelations and fully flagged channels.
    - Grid channel ``chan``'s weights (plus conjugates) onto an npix x npix uv
      grid matched to an image with ``cell_size``-arcsec pixels.

    Parameters
    ----------
    base_ms : str
        Path to the measurement set.
    npix : int
        Image / uv-grid size in pixels.
    cell_size : float
        Image pixel size in arcsec.
    chan : int, optional
        Channel to grid. This indexes the channels that remain after fully
        flagged channels are removed, not the raw channel number.

    Returns
    -------
    ungrid_info : list
        ``[uu, vv, du, dv, npix, wgts]`` for the selected channel, with uv in
        wavelengths. This is the layout that ``weight_data`` unpacks.
    gwgts_init : ndarray, shape (npix, npix)
        Gridded natural weights.
    gwgts_init_sq : ndarray, shape (npix, npix)
        ``gwgts_init**2``.
    corr_fac : ndarray, shape (npix, npix)
        Fractional bandwidth times the uv distance in cells, floored at 1.
        Presumably this is the number of cells a visibility spans across the
        band (TODO confirm).
    """
    tb = casatools.table()

    # Channel frequencies (column 0 = first spectral window) and reference frequency.
    tb.open(base_ms + "/SPECTRAL_WINDOW")
    try:
        chan_freqs = tb.getcol("CHAN_FREQ")
        rfreq = tb.getcol("REF_FREQUENCY")
    finally:
        tb.close()

    # The main table is only read, so open it read-only (the default) instead
    # of taking a write lock. Always close it, even if a getcol fails.
    tb.open(base_ms)
    try:
        flag = tb.getcol("FLAG")
        uvw = tb.getcol("UVW")
        weight = tb.getcol("WEIGHT")
        ant1 = tb.getcol("ANTENNA1")
        ant2 = tb.getcol("ANTENNA2")
    finally:
        tb.close()

    # Keep channels that have at least one unflagged sample. The product over
    # axes 0 and 2 is 1 only when every sample in the channel is flagged.
    # NOTE(review): per-visibility flags are otherwise not applied.
    chan_keep = np.logical_not(np.prod(flag, axis=(0, 2)))

    # u, v for every (row, channel), converted from metres to wavelengths.
    freqs = chan_freqs[:, 0]
    uu = uvw[0, :][:, np.newaxis] * freqs / (cc / 100)
    vv = uvw[1, :][:, np.newaxis] * freqs / (cc / 100)

    # Drop autocorrelations. Row weight = sum of the first two polarisations.
    xc = np.where(ant1 != ant2)[0]
    wgts = weight[0, :] + weight[1, :]

    uu_xc = uu[xc][:, chan_keep]
    vv_xc = vv[xc][:, chan_keep]
    wgts_xc = wgts[xc]

    # uv cell size matching an npix x npix image of cell_size-arcsec pixels.
    dl = cell_size * arcsec
    dm = cell_size * arcsec
    du = 1. / (npix * dl)
    dv = 1. / (npix * dm)

    # uv distance of every cell; cell k sits at (k - npix/2) * du. Build the
    # axes from an integer arange so each has exactly npix points: a float-step
    # np.arange can return npix + 1 points through round-off.
    u_axis = (np.arange(npix) - npix / 2.) * du
    v_axis = (np.arange(npix) - npix / 2.) * dv
    uvdist_grid = np.sqrt(np.add.outer(u_axis**2, v_axis**2))

    # Use the fractional bandwidth of the same spectral window as uu/vv, so
    # the result is a scalar and broadcasts against the grid for any number of spws.
    frac_bw = (np.max(freqs) - np.min(freqs)) / rfreq[0]
    corr_fac = frac_bw * uvdist_grid / du
    corr_fac[corr_fac < 1] = 1.

    # Grid the weights (with complex conjugates) for the selected channel.
    uu_chan = uu_xc[:, chan]
    vv_chan = vv_xc[:, chan]
    gwgts_init = grid_wgts(np.zeros((npix, npix)), uu_chan, vv_chan, du, dv, npix, wgts_xc)

    return [uu_chan, vv_chan, du, dv, npix, wgts_xc], gwgts_init, gwgts_init**2, corr_fac

In [7]:
def weight_data(ungrid_info, gwgts_init, gwgts_init_sq, corr_fac, robust, taper):
    """Apply Briggs robust weighting to the visibilities and re-grid the result.

    Parameters
    ----------
    ungrid_info : list
        ``[uu, vv, du, dv, npix, wgts]``. uu and vv are 1-D arrays for a
        single channel, and wgts are the natural weights.
    gwgts_init, gwgts_init_sq : ndarray, shape (npix, npix)
        Gridded natural weights and their squares.
    corr_fac : ndarray, shape (npix, npix)
        Per-cell divisor applied to the gridded weights in the robust formula.
    robust : float or int
        Briggs robust parameter R. Integer values, including numpy integers,
        are accepted.
    taper : any
        Currently unused.

    Returns
    -------
    gwgts_final : ndarray, shape (npix, npix)
        Gridded robust weights, from which the dirty beam is made.
    wgts_robust : ndarray
        Per-visibility robust weights ``w * r``.
    wgts_robust_sq : ndarray
        ``w * r**2``, used by the noise estimate.
    """
    uu_xc, vv_xc, du, dv, npix, wgts_xc = ungrid_info

    # f^2 = (5 * 10^-R)^2 / (sum W_k^2 / sum w_i). The weight sum is doubled
    # because every visibility was gridded twice (cell + conjugate cell).
    # 10.0 (a float) avoids numpy's "integers to negative powers" error.
    f_sq = (5 * 10.0**(-robust))**2 / (np.sum(gwgts_init_sq) / (np.sum(wgts_xc) * 2))
    gr_wgts = 1 / (1 + gwgts_init / corr_fac * f_sq)

    # Multiply to get the per-visibility robust weights.
    indexed_gr_wgts = ungrid_wgts(gr_wgts, uu_xc, vv_xc, du, dv, npix)
    wgts_robust = wgts_xc * indexed_gr_wgts
    wgts_robust_sq = wgts_xc * indexed_gr_wgts**2

    # Grid the robust weights (to make the dirty beam). uu_xc and vv_xc are
    # already single-channel 1-D arrays, so they are passed as-is.
    gwgts_final = grid_wgts(np.zeros((npix, npix)), uu_xc, vv_xc, du, dv, npix, wgts_robust)

    return gwgts_final, wgts_robust, wgts_robust_sq

In [8]:
def predict_beam(gwgts_final, cell_size):
    """Dirty beam from gridded weights, normalised to unit peak, plus its fitted parameters.

    The grid is shifted so its centre cell moves to index 0, Fourier
    transformed, and the result is shifted back so the beam peak sits at the
    array centre. Only the real part is kept.

    Parameters
    ----------
    gwgts_final : ndarray, shape (npix, npix)
        Gridded (robust) weights.
    cell_size : float
        Image pixel size, forwarded to the beam fitter.

    Returns
    -------
    (beam, beam_params)
        The unit-peak beam image and whatever ``fit_beam_CASA`` returns.
        NOTE(review): fit_beam_CASA is not defined in this part of the
        notebook; confirm where it comes from.
    """
    spectrum = fft2(fftshift(gwgts_final))
    beam = fftshift(spectrum).real
    beam = beam / beam.max()
    params = fit_beam_CASA(beam, cell_size)
    return beam, params

In [9]:
def predict_rms(wgts_robust, wgts_robust_sq):
    """Predicted image noise for a set of robust weights (Briggs et al. 1995).

    Evaluates ``2 * C * sqrt(sum(wgts_robust_sq))`` with
    ``C = 1 / (2 * sum(wgts_robust))``. This is equivalent to
    ``sqrt(sum(w * r**2)) / sum(w * r)``.

    If WEIGHT holds 1/sigma^2, the result is in units of sigma; this is an
    assumption to be confirmed against the MS.
    """
    total_weight = np.sum(wgts_robust)
    norm = 1 / (2 * total_weight)
    return 2 * norm * np.sqrt(np.sum(wgts_robust_sq))

In [None]:
def gaussian2D(params, nrow):
    """Unit-peak elliptical 2-D Gaussian on a (2*nrow+1) x (2*nrow+1) pixel grid.

    Parameters
    ----------
    params : sequence of 3 numbers
        ``(width_x, width_y, rotation)``. The widths act as the Gaussian sigma
        along the rotated axes, in pixels. ``rotation`` is in degrees and is
        converted to the angle ``90 - rotation``; presumably this turns a
        position angle into an angle from the x axis (TODO confirm convention).
    nrow : int
        Half-size of the grid. The peak sits at index [nrow, nrow].

    Returns
    -------
    ndarray, shape (2*nrow+1, 2*nrow+1)
    """
    sigma_a, sigma_b, angle_deg = params
    theta = np.deg2rad(90 - angle_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    offsets = np.arange(-nrow, nrow + 1)
    x, y = np.meshgrid(offsets, offsets, indexing='ij')

    a = x * cos_t - y * sin_t
    b = x * sin_t + y * cos_t
    return np.exp(-((a / sigma_a)**2 + (b / sigma_b)**2) / 2.)

In [13]:
def make_fake_restored_image(clean_model_cube, robust_beam, beam_params, rms, chan_axis=2, rng=None):
    """Simulate a restored cube: per channel, model * Gaussian kernel + beam-correlated noise.

    Each channel plane is convolved with ``gaussian2D(beam_params, 8)``. Gaussian
    white noise (std ``rms``) is then convolved with the central 17x17 pixels of
    ``robust_beam`` and added. Both convolutions use ``mode='same'``, so every
    output plane has the same shape as its input plane.

    Parameters
    ----------
    clean_model_cube : ndarray
        Model cube. Each slice along ``chan_axis`` must be a 2-D image.
    robust_beam : 2-D ndarray
        Dirty beam. Assumed to peak at index (n0//2, n1//2), as produced by
        fftshift; TODO confirm for other inputs.
    beam_params : sequence
        Passed to ``gaussian2D`` as ``(width_x, width_y, rotation)``.
    rms : float
        Standard deviation of the white noise before beam convolution.
        NOTE(review): convolving with the beam scales the noise std by
        sqrt(sum(small_beam**2)). If the output should have noise ``rms``,
        divide by that factor (TODO confirm intent).
    chan_axis : int, optional
        Axis of ``clean_model_cube`` that indexes channels. Default 2 matches
        the axis the original loop iterated over. Pass 0 for (nchan, ny, nx).
    rng : numpy.random.Generator or RandomState, optional
        Source of noise for reproducible runs. Defaults to the global
        ``np.random`` state.

    Returns
    -------
    ndarray with the same shape as ``clean_model_cube``.
    """
    half = 8  # kernel half-width in pixels; kernels are (2*half+1) x (2*half+1)
    cy, cx = robust_beam.shape[0] // 2, robust_beam.shape[1] // 2
    # An odd-sized window centred on the beam peak, so that mode='same'
    # introduces no half-pixel shift and the window matches gauss_kernel's size.
    small_beam = robust_beam[cy - half:cy + half + 1, cx - half:cx + half + 1]
    gauss_kernel = gaussian2D(beam_params, half)  # loop-invariant
    normal = np.random.normal if rng is None else rng.normal

    model = np.moveaxis(clean_model_cube, chan_axis, 0)
    fake_restored_image = np.zeros(model.shape)
    for chan in range(model.shape[0]):
        channel = model[chan]
        noise = normal(scale=rms, size=channel.shape)
        convolved_clean_model = convolve2d(channel, gauss_kernel, mode='same')
        convolved_noise = convolve2d(noise, small_beam, mode='same')
        fake_restored_image[chan] = convolved_clean_model + convolved_noise

    return np.moveaxis(fake_restored_image, 0, chan_axis)

In [14]:
def make_fake_restored_image_single(clean_model_channel, robust_beam, beam_params, rms, rng=None):
    """Simulate one restored channel: model * Gaussian kernel + beam-correlated noise.

    The model plane is convolved with ``gaussian2D(beam_params, 8)``. Gaussian
    white noise (std ``rms``), with the same shape as the model plane, is
    convolved with the central 17x17 pixels of ``robust_beam`` and added. Both
    convolutions use ``mode='same'``, so the output has the model's shape.
    Previously the two 'full' outputs differed in size and could not be added.

    Parameters
    ----------
    clean_model_channel : 2-D ndarray
        Model image for one channel.
    robust_beam : 2-D ndarray
        Dirty beam. Assumed to peak at index (n0//2, n1//2), as produced by
        fftshift; TODO confirm for other inputs.
    beam_params : sequence
        Passed to ``gaussian2D`` as ``(width_x, width_y, rotation)``.
    rms : float
        Standard deviation of the white noise before beam convolution.
        NOTE(review): the beam convolution scales it by sqrt(sum(small_beam**2)).
    rng : numpy.random.Generator or RandomState, optional
        Source of noise for reproducible runs. Defaults to the global
        ``np.random`` state.

    Returns
    -------
    2-D ndarray with the same shape as ``clean_model_channel``.
    """
    half = 8  # kernel half-width in pixels; kernels are (2*half+1) x (2*half+1)
    cy, cx = robust_beam.shape[0] // 2, robust_beam.shape[1] // 2
    # An odd-sized window centred on the beam peak, so that mode='same' does not shift the noise.
    small_beam = robust_beam[cy - half:cy + half + 1, cx - half:cx + half + 1]
    gauss_kernel = gaussian2D(beam_params, half)

    normal = np.random.normal if rng is None else rng.normal
    noise = normal(scale=rms, size=np.shape(clean_model_channel))

    convolved_clean_model = convolve2d(clean_model_channel, gauss_kernel, mode='same')
    convolved_noise = convolve2d(noise, small_beam, mode='same')
    return convolved_clean_model + convolved_noise