In [50]:
import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import corrcal
import hera_sim
import healpy
import os
import vis_cpu
from astropy import constants, units
from astropy.time import Time
from astropy.coordinates import Longitude, Latitude, EarthLocation
from scipy.optimize import minimize
from scipy import special
from pyuvdata import UVBeam
from pyradiosky import SkyModel
from pyuvsim import AnalyticBeam
%matplotlib inline

In [2]:
beam_file = "NF_HERA_Dipole_efield_beam.fits"
# hpx_beam_file = "NF_HERA_dipole_linpol_power_healpix128.fits"
# gleam_file = "gleam-120.02-127.34MHz-nf-76-pld.skyh5"
# egsm_file = "egsm_pred_50_200mhz.h5.npz"
# array_file = "array_layout.csv"
# haslam_file = "haslam408_dsds_Remazeilles2014.fits"
# for f in (beam_file, hpx_beam_file, gleam_file, egsm_file, array_file, haslam_file):
#     assert os.path.exists(f)

In [3]:
sep = 14.6
diameter = 10
beam = AnalyticBeam("airy", diameter=diameter)
# beam = UVBeam.from_file(beam_file)
# beam.interpolation_function = "az_za_simple"
# beam.freq_interp_kind = "cubic"
# setattr(beam, "type", "NF_Dipole")

def square_array(side_len=1, sep=1):
    ant_ind = 0
    antpos = {}
    for i in range(side_len):
        for j in range(side_len):
            antpos[ant_ind] = np.array([i, j, 0], dtype=float) * sep
            ant_ind += 1
    return antpos
    
array_layout = square_array(side_len=8, sep=10)
add_jitter = False
if add_jitter:
    jitter_amp = 0.05
    radii = np.random.normal(size=len(array_layout), loc=0, scale=jitter_amp*sep)
    angles = np.random.uniform(low=0, high=2*np.pi, size=len(array_layout))
    dx = radii * np.exp(1j*angles)
    for ant, pos in array_layout.items():
        array_layout[ant] = pos + np.array([dx[ant].real, dx[ant].imag, 0])

n_freqs = 1
start_freq = 150e6
channel_width = 100e3
n_times = 1
start_time = 2458099.28
integration_time = 60 * 15
uvdata = hera_sim.io.empty_uvdata(
    array_layout=array_layout,
    Nfreqs=n_freqs,
    start_freq=start_freq,
    channel_width=channel_width,
    Ntimes=n_times,
    start_time=start_time,
    integration_time=integration_time,
#     polarization_array=np.array(['xx','yy']),
)

Cannot check consistency of a string-mode BeamList! Set force=True to force consistency checking.
Fixing phases using antenna positions.


In [4]:
lst = uvdata.lst_array[0]
n_side = 128
n_pix = healpy.nside2npix(n_side)
lat, lon, _ = uvdata.telescope_location_lat_lon_alt
stokes = np.zeros((4,1,n_pix), dtype=float)
stokes[0,0] = np.random.normal(size=n_pix)
stokes[0,0] += np.abs(stokes[0,0].min())

In [5]:
sky_model = SkyModel(
    stokes=stokes*units.Jy/units.sr,
    spectral_type="flat",
    component_type="healpix",
    nside=n_side,
    hpx_inds=np.arange(n_pix),
)

In [6]:
beam_ids = [0,] * len(array_layout)

In [7]:
data_model = hera_sim.visibilities.ModelData(
    uvdata=uvdata,
    sky_model=sky_model,
    beam_ids=beam_ids,
    beams=[beam,],
)

In [8]:
simulation = hera_sim.visibilities.VisibilitySimulation(
    data_model=data_model,
    simulator=hera_sim.visibilities.VisCPU(),
)

In [9]:
simulation.simulate();

In [10]:
lst = uvdata.lst_array[0]
n_src = 100
lat, lon, _ = uvdata.telescope_location_lat_lon_alt
dra = 10 * units.deg.to("rad")
ddec = 10 * units.deg.to("rad")
ra = np.random.normal(size=n_src, loc=lst, scale=dra)
dec = np.random.normal(size=n_src, loc=lat, scale=ddec)
src_stokes = np.zeros((4,1,n_src), dtype=float)
src_stokes[0,0] = 10**np.random.uniform(low=-1,high=1,size=n_src)

In [11]:
sky_model = SkyModel(
    name=np.arange(n_src).astype(str),
    ra=Longitude(ra*units.rad),
    dec=Latitude(dec*units.rad),
    stokes=src_stokes*units.Jy,
    spectral_type="flat",
    component_type="point",
)

In [12]:
src_uvdata = hera_sim.io.empty_uvdata(
    array_layout=array_layout,
    Nfreqs=n_freqs,
    start_freq=start_freq,
    channel_width=channel_width,
    Ntimes=n_times,
    start_time=start_time,
    integration_time=integration_time,
#     polarization_array=np.array(['xx','yy']),
)

Cannot check consistency of a string-mode BeamList! Set force=True to force consistency checking.
Fixing phases using antenna positions.


In [13]:
data_model = hera_sim.visibilities.ModelData(
    uvdata=src_uvdata,
    sky_model=sky_model,
    beam_ids=beam_ids,
    beams=[beam,],
)

In [14]:
simulation = hera_sim.visibilities.VisibilitySimulation(
    data_model=data_model,
    simulator=hera_sim.visibilities.VisCPU(),
)

In [15]:
simulation.simulate();

In [16]:
uvdata.data_array += src_uvdata.data_array

In [17]:
reds, _, lens, conj = uvdata.get_redundancies(include_conjugates=True)
conj = set(conj)

In [18]:
data = []
ant_1_array = []
ant_2_array = []
edges = [0,]
idx = 0
min_length = np.sqrt(2) * diameter
for group, length in zip(reds, lens):
    if (length <= min_length) or (len(group) < 5):
        continue
    for bl in group:
        ai, aj = uvdata.baseline_to_antnums(bl)
        if bl in conj:
            ai, aj = aj, ai
        data.append(np.atleast_1d(uvdata.get_data(ai,aj,"xx").squeeze()))
        ant_1_array.append(ai)
        ant_2_array.append(aj)
        idx += 1
    edges.append(idx)
data = np.asarray(data).squeeze()
ant_1_array = np.asarray(ant_1_array)
ant_2_array = np.asarray(ant_2_array)
edges = np.asarray(edges)

In [19]:
pos, ants = uvdata.get_ENU_antpos()

In [20]:
snr = 1000
noise_amp = np.abs(data).max() / snr
noise = np.eye(uvdata.Nbls, dtype=complex) * noise_amp**2
noise_diag = np.diag(noise)

n_ants = ants.size
err = 0.05
re_gain = np.random.normal(loc=1, size=n_ants, scale=err)
im_gain = np.random.normal(loc=0, size=n_ants, scale=err)
split_gains = np.zeros(2*n_ants, dtype=float)
split_gains[::2] = re_gain
split_gains[1::2] = im_gain

In [21]:
lat, lon, _ = uvdata.telescope_location_lat_lon_alt
sky_alms = healpy.map2alm(0.5*stokes[0,0], use_pixel_weights=True)
healpy.rotate_alm(sky_alms, lon, np.pi/2-lat)
sky_power = healpy.anafast(healpy.alm2map(sky_alms, n_side), use_pixel_weights=True)

In [22]:
power_beam = AnalyticBeam("airy", diameter=diameter)
power_beam.efield_to_power()

In [29]:
ecef_antpos = uvdata.antenna_positions
enu_antpos = uvdata.get_ENU_antpos()[0]
uvws = uvdata.freq_array[0,0] * (
    enu_antpos[ant_2_array] - enu_antpos[ant_1_array]
) / constants.c.value
za, az = healpy.pix2ang(n_side, np.arange(healpy.nside2npix(n_side)))
az = Longitude((az - np.pi/2)*units.rad).value
beam_vals = power_beam.interp(az, za, np.atleast_1d(uvdata.freq_array[0,0]))[0][0,0,-1,0]
beam_vals[za>np.pi/2] = 0
local_crd = np.array(
    [np.sin(za)*np.cos(az), np.sin(za)*np.sin(az), np.cos(za)]
)
# Calculating the fringe this way is suuuuper memory intensive
# fringe = np.exp(-2j * np.pi * uvws @ local_crd)
diff_mat =  np.zeros((edges[-1], edges.size - 1), dtype=complex)
ells = np.arange(sky_power.size)
scaling = 2 * ells + 1
for grp in range(diff_mat.shape[1]):
    start, stop = edges[grp:grp+2]
    fringe = np.exp(-2j * np.pi * uvws[start:start+1] @ local_crd)
    fringed_beam = beam_vals * fringe
    beam_spectrum = healpy.anafast(fringed_beam, use_pixel_weights=True)
    cov_amp = np.sum(sky_power * scaling * beam_spectrum)
    diff_mat[start:stop,grp] = np.sqrt(cov_amp)

Casting complex values to real discards the imaginary part


In [30]:
eq2tops = vis_cpu.conversions.eci_to_enu_matrix(lst, lat)
src_crd_eq = vis_cpu.conversions.point_source_crd_eq(
    sky_model.ra.value, sky_model.dec.value
)
local_crd = eq2tops @ src_crd_eq
az, za = vis_cpu.conversions.enu_to_az_za(
    *local_crd[:2],
    orientation="uvbeam",
)

In [32]:
uvws = uvdata.freq_array[0,0] * (pos[ant_2_array] - pos[ant_1_array]) / constants.c.value
phases = 2 * np.pi * uvws @ local_crd
beam_vals = beam.interp(az, za, uvdata.freq_array[0,0])[0].squeeze().transpose(2,0,1)
beam_vals = beam_vals @ beam_vals.transpose(0,2,1).conj()
beam_vals = beam_vals[:,0,0]
above_horizon = za < np.pi / 2
src_fluxes = 0.5 * sky_model.stokes[0,0,:].value * beam_vals
src_fluxes = src_fluxes[above_horizon]
src_mat = src_fluxes[None,:] * np.exp(1j * phases)
pct_flux_to_keep = 95
flux_cut = 1e-2 * (100-pct_flux_to_keep) * src_fluxes.max()
select = src_fluxes >= flux_cut
src_mat = src_mat[:,select]

In [33]:
class Cov:
    def __init__(self, noise, diff_mat, src_mat, gains, edges, ant_1_inds, ant_2_inds):
        self.noise = noise
        self.diff_mat = diff_mat
        self.src_mat = src_mat
        self.gains = gains
        self.edges = edges
        self.ant_1_inds = ant_1_inds
        self.ant_2_inds = ant_2_inds
        self.n_bls = diff_mat.shape[0]
        self.n_grp = len(edges) - 1
        self.n_eig = diff_mat.shape[1] // self.n_grp
        self.n_src = src_mat.shape[1]
    
    def inv(self, dense=False, return_det=False):
        return self._dense_inv(return_det) if dense else self._sparse_inv(return_det)
    
    def _dense_inv(self, return_det=False):
        gain_mat = self.build_gain_mat()
        sky_cov = self.src_mat@self.src_mat.T.conj() + self.diff_mat@self.diff_mat.T.conj()
        full_cov = self.noise + gain_mat[:,None]*sky_cov*gain_mat[None,:].conj()
        inv = np.linalg.inv(full_cov)
        if return_det:
            try:
                logdet = 2*np.sum(np.log(np.diag(np.linalg.cholesky(full_cov))))
            except np.linalg.LinAlgError:
                logdet = np.inf
            return inv, np.real(logdet)
        return inv
    
    def _sparse_inv(self, return_det=False):
        if return_det:
            logdet = 0
            
        gain_mat = self.build_gain_mat()
        Cinv = np.zeros((self.n_bls, self.n_bls), dtype=complex)
        GD = gain_mat[:,None] * self.diff_mat
        for grp, (start, stop) in enumerate(zip(self.edges, self.edges[1:])):
            left = grp * self.n_eig
            right = left + self.n_eig
            
            block = GD[start:stop, left:right].copy()
            block = self.noise[start:stop,start:stop] + block@block.T.conj()
            if return_det:
                logdet += 2*np.sum(np.log(np.diag(np.linalg.cholesky(block))))
            Cinv[start:stop,start:stop] = np.linalg.inv(block)
            
        GS = gain_mat[:,None] * self.src_mat
        CGS = Cinv @ GS
        tmp = np.eye(self.n_src) + GS.T.conj()@CGS
        Cinv -= CGS @ np.linalg.inv(tmp) @ CGS.T.conj()
        
        if return_det:
            logdet += 2 * np.sum(np.log(np.diag(np.linalg.cholesky(tmp))))
            return Cinv, np.real(logdet)
        return Cinv
       
    def build_gain_mat(self):
        gains = self.gains[::2] + 1j*self.gains[1::2]
        return gains[self.ant_1_inds] * gains[self.ant_2_inds].conj()

In [34]:
def nll(gains, cov, data, scale=1, phs_norm_fac=1):
    """Negative log-likelihood."""
    cov.gains = gains / scale
    cinv, logdet = cov.inv(dense=False, return_det=True)
    chisq = data.conj() @ cinv @ data
    # Use a Gaussian prior that the phase should be nearly zero for all antennas
    phases = np.arctan2(cov.gains[1::2], cov.gains[::2])
    phs_norm = phases.sum()**2 / phs_norm_fac**2
#     phs_norm = cov.gains[1::2].sum()**2 / phs_norm_fac**2
    return np.real(chisq) + logdet + phs_norm

In [80]:
@nb.njit
def accumulate_grad(gains, gain_mat, src_mat, diff_mat, ant_1_inds, ant_2_inds, cinv, cinv_data, prefac):
    # surprise, turns out the dump implementation (this function)
    # is slower than what I had already done
    grad_nll = np.zeros(2*gains.size, dtype=float)
    for k in nb.prange(2*gains.size):
        grad_chisq = 0
        grad_gains = np.zeros_like(gain_mat)
        if k%2 == 0:  # derivative wrt real gain
            grad_gains += np.where(ant_1_inds == k//2, gains[ant_2_inds].conj(), 0)
            grad_gains += np.where(ant_2_inds == k//2, gains[ant_1_inds], 0)
            grad_phs_norm = -prefac[k//2] * gains[k//2].imag / gains[k//2].real
        else:
            grad_gains += np.where(ant_1_inds == k//2, 1j*gains[ant_2_inds].conj(), 0)
            grad_gains += np.where(ant_2_inds == k//2, -1j*gains[ant_1_inds], 0)
            grad_phs_norm = prefac[k//2]
        tmp1 = grad_gains.reshape(-1,1) * src_mat
        tmp2 = src_mat.T.conj() * gain_mat.reshape(1,-1).conj()
        grad_cov = tmp1 @ tmp2
        tmp1 = cinv_data.conj() @ tmp1
        tmp2 = tmp2 @ cinv_data
        grad_chisq += tmp1 @ tmp2
        tmp1 = grad_gains.reshape(-1,1) * diff_mat
        tmp2 = diff_mat.T.conj() * gain_mat.reshape(1,-1).conj()
        grad_cov += tmp1 @ tmp2
        tmp1 = cinv_data.conj() @ tmp1
        tmp2 = tmp2 @ cinv_data
        grad_chisq += tmp1 @ tmp2
        grad_cov = grad_cov + grad_cov.T.conj()
        grad_logdet = np.sum(cinv * grad_cov.T)
        # actually calculated -grad_chisq/2, so this is right
        grad_nll[k] = np.real(grad_logdet - 2*grad_chisq + grad_phs_norm)
    return grad_nll

In [83]:
def grad_nll(gains, cov, data, scale=1, phs_norm_fac=1):
    """Gradient of negative log-likelihood."""
    cov.gains = gains / scale
    complex_gains = (gains[::2] + 1j*gains[1::2]) / scale
    gain_mat = cov.build_gain_mat()
    cinv = cov.inv(dense=False, return_det=False)
    cinv_data = cinv @ data
    grad_nll = np.zeros_like(gains)
    tan_phs = cov.gains[1::2] / cov.gains[::2]
    phases = np.arctan2(cov.gains[1::2], cov.gains[::2])
    grad_phs_prefac = 2 * phases.sum() / (phs_norm_fac**2 * cov.gains[::2] * (1+tan_phs**2))
#     grad_nll = accumulate_grad(
#         complex_gains,
#         gain_mat,
#         cov.src_mat,
#         cov.diff_mat,
#         cov.ant_1_inds,
#         cov.ant_2_inds,
#         cinv,
#         cinv_data,
#         grad_phs_prefac,
#     )
    for k in range(gains.size):
        grad_chisq = 0
        grad_gains = np.zeros_like(gain_mat)
        if k%2 == 0:  # derivative wrt real gain
            grad_gains += np.where(cov.ant_1_inds == k//2, complex_gains[cov.ant_2_inds].conj(), 0)
            grad_gains += np.where(cov.ant_2_inds == k//2, complex_gains[cov.ant_1_inds], 0)
            grad_phs_norm = -grad_phs_prefac[k//2] * cov.gains[k+1] / cov.gains[k]
#             grad_phs_norm = 0
        else:
            grad_gains += np.where(cov.ant_1_inds == k//2, 1j*complex_gains[cov.ant_2_inds].conj(), 0)
            grad_gains += np.where(cov.ant_2_inds == k//2, -1j*complex_gains[cov.ant_1_inds], 0)
            grad_phs_norm = grad_phs_prefac[k//2]
#             grad_phs_norm = 2 * cov.gains[1::2].sum() / phs_norm_fac**2
        tmp1 = grad_gains[:,None] * cov.src_mat
        tmp2 = cov.src_mat.T.conj() * gain_mat[None,:].conj()  # don't need to do every iter
        grad_cov = tmp1 @ tmp2
        tmp1 = cinv_data.conj() @ tmp1
        tmp2 = tmp2 @ cinv_data
        grad_chisq += tmp1 @ tmp2
        tmp1 = grad_gains[:,None] * cov.diff_mat
        tmp2 = cov.diff_mat.T.conj() * gain_mat[None,:].conj()  # same here
        grad_cov += tmp1 @ tmp2
        tmp1 = cinv_data.conj() @ tmp1
        tmp2 = tmp2 @ cinv_data
        grad_chisq += tmp1 @ tmp2
        grad_cov = grad_cov + grad_cov.T.conj()
        grad_logdet = np.sum(cinv * grad_cov.T)
        grad_nll[k] = np.real(grad_logdet - 2*grad_chisq + grad_phs_norm)  # actually calculated -grad_chisq/2, so this is right
    return grad_nll / scale

In [84]:
cov = Cov(
    noise=noise,
    diff_mat=diff_mat,
    src_mat=src_mat,
    gains=split_gains,
    edges=edges,
    ant_1_inds=ant_1_array,
    ant_2_inds=ant_2_array,
)

In [85]:
scale = 1
phs_norm_fac = 1e-1

In [40]:
%%timeit
nll(scale*split_gains, cov, data, scale, phs_norm_fac)

180 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [86]:
%%timeit
grad_nll(scale*split_gains, cov, data, scale, phs_norm_fac)

21.7 s ± 209 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
cov.diff_mat.shape

(1858, 94)

In [45]:
cov.src_mat.shape

(1858, 23)

Non-optimized results for 64-element uniform square array:
<br>
likelihood eval: ~180 ms $\pm$ 13.7 ms
<br>
gradient eval: ~21.8 s $\pm$ 305 ms
<br>
typical number of evaluations: few hundred each? 
<br>
-> 1.5 hours per calibration solution
<br>
-> obscenely long to calibrate full band and 1 sidereal day

In [96]:
result = minimize(
    nll,
    scale*split_gains,
    args=(cov, data, scale, phs_norm_fac),
    method="CG",
    jac=grad_nll,
)

KeyboardInterrupt: 

In [None]:
result.nit, result.nfev, result.njev

In [None]:
fit = (result.x[::2] + 1j*result.x[1::2]) / scale
def err(val, exp):
    return np.sqrt(np.mean((val - exp)**2))
avg_err_before = err(split_gains[::2], 1) + 1j*err(split_gains[1::2], 0)
avg_err = err(fit.real, 1) + 1j*err(fit.imag, 0)
fit_std = np.std(fit.real) + 1j*np.std(fit.imag)
lim = 0.1
plt.figure(figsize=(8,5),dpi=150, facecolor='white')
title = f"Avg Error Before: {avg_err_before:.3f}\n"
title += f"Avg Error: {avg_err:.3f}\n"
title += f"Residual Standard Deviation: {fit_std:.3f}\n"
title += f"Average Fit Amplitude: {np.mean(np.abs(fit)):.3f}\n"
title += f"Average Fit Phase: {np.mean(np.angle(fit)):.3f}"
plt.title(title)
plt.plot(split_gains[::2] - 1, 'kx', label='Guess, Real')
plt.plot(fit.real - 1, 'k+', label='Fit, Real')
plt.plot(split_gains[1::2], 'rx', label='Guess, Imag')
plt.plot(fit.imag, 'r+', label='Fit, Imag')
plt.ylim(-lim,lim)
plt.axhline(0, color='gray', ls=':')
plt.legend(ncol=2)
plt.xlabel("Antenna Index")
plt.ylabel("Gain Residual")
plt.xticks(np.arange(n_ants));
basename = f"gain_solutions_diffuse_sky_with_point_sources_{beam.type}_beam"
# plt.savefig(f"{basename}_without_noise.png", dpi=150, bbox_inches='tight')
# plt.savefig(f"{basename}_with_noise.png", dpi=150, bbox_inches="tight")