In [None]:
import healpy
from cora.util import hputil
import os
os.environ["OMP_NUM_THREADS"] = "8"
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import ticker
from ch_util import ephemeris as ephem, andata, tools
from caput.time import unix_to_skyfield_time
import h5py
import time
from datetime import datetime
from glob import glob
from ch_pipeline.core import telescope
from scipy.optimize import leastsq

%load_ext autoreload
%autoreload 2

from continuum_beam import *
from beam_utils import model_beam

%matplotlib inline
plt.rcParams.update({'figure.figsize': (16, 12), 'font.size': 20})

## Read simulated visibilities

In [None]:
# Load the simulated visibilities (with cross-talk), which are split over two
# files, and concatenate them along the last (time) axis.
xtalk_files = sorted(glob("/home/tristpm/def-krs/pipeline_dev/xtalk_vis[01].h5"))
# Alternative input without sky signal:
#xtalk_files = sorted(glob("/home/tristpm/def-krs/pipeline_dev/xtalk_vis_nosky[01].h5"))
if not xtalk_files:
    # np.concatenate([]) would otherwise fail with an unhelpful message.
    raise FileNotFoundError("no files matched xtalk_vis[01].h5")
times = []
vis = []
for f in xtalk_files:
    # Open read-only explicitly: older h5py defaults to mode 'a' (read/write/create).
    with h5py.File(f, 'r') as fh:
        times.append(fh['index_map/time'][:])
        vis.append(fh['vis'][:])
times = np.concatenate(times)
vis = np.concatenate(vis, axis=-1)

In [None]:
# Frequency and product index maps, read from the first file (assumed identical
# across the files -- TODO confirm). NS baselines are computed later from the
# feed positions, so they are not derived from the product map here.
with h5py.File(xtalk_files[0], 'r') as fh:
    freq = fh['index_map/freq']['centre'][:]
    prod = fh['index_map/prod'][:]

In [None]:
# Pathfinder correlator inputs at the first timestamp, and the feed positions
# and polarisations derived from them.
start_time = times[0]
start_dt = ephem.unix_to_datetime(start_time)
inputs = tools.get_correlator_inputs(start_dt, correlator='pathfinder')
pos, pol = tools.get_feed_positions(inputs), tools.get_feed_polarisations(inputs)

## Set up the model and select products

In [None]:
# Index into `freq` of the channel used for the single-frequency fit (408 MHz).
f_ind = 58

In [None]:
# Sky/visibility model at the chosen frequency, with smoothing enabled.
# An alternative sky map can be supplied via `sky_map=`, e.g. a foreground map
# converted with `hputil.coord_c2g(...)`.
sim_model = ModelVis(smooth=True, freq=freq[f_ind])

In [None]:
# Choose which correlation products enter the fit.
pol_sel = 'S'  # keep only pol_sel x pol_sel products

ns_baselines = (pos[prod[:,0],1] - pos[prod[:,1],1])
ew_baselines = (pos[prod[:,0],0] - pos[prod[:,1],0])

pol_pair = np.empty(prod.shape[0], dtype=[('pol_a', '<U1'), ('pol_b', '<U1')])
pol_pair['pol_a'] = pol[prod[:,0]]
pol_pair['pol_b'] = pol[prod[:,1]]

# Each mask flags products to drop; a product is kept only if no mask flags it.
# bad channels (non-finite feed positions)
bad_pos = ~np.isfinite(ew_baselines + ns_baselines)
# inter-cylinder baselines
intercyl = np.abs(ew_baselines) > 10.
# autos / zero-length baselines: require BOTH components to vanish
# (`ns + ew == 0` would also drop any baseline with ns == -ew)
autos = (ns_baselines == 0) & (ew_baselines == 0)
# everything except pol_sel x pol_sel
wrong_pol = ~((pol_pair['pol_a'] == pol_sel) & (pol_pair['pol_b'] == pol_sel))
# NS baselines longer than 0.5 * wl / res of the sky model
too_long = np.abs(ns_baselines) > 0.5 * sim_model.wl / sim_model._res()

excl_mask = bad_pos | intercyl | autos | wrong_pol | too_long
prod_excl = set(np.flatnonzero(excl_mask))
# indices of kept products, ascending
prod_sel = np.flatnonzero(~excl_mask)

In [None]:
# Restrict the visibilities (at channel f_ind) and the NS baselines to the
# selected products. The baselines are recomputed from `pos` rather than
# sliced in place, so re-running this cell does not index an already-sliced array.
vis_sel = vis[f_ind, prod_sel, :]
ns_baselines = pos[prod[prod_sel, 0], 1] - pos[prod[prod_sel, 1], 1]

In [None]:
# Pass the NS baseline lengths of the selected products to the model.
sim_model.set_baselines(ns_baselines)

## Exclude point sources

In [None]:
# Transit RA (degrees) of each sample and the RA step between samples
# (assumed uniform -- only the first two samples are used).
sim_ra = ephem.transit_RA(times)
# Wrap into [0, 360) so that a 360 -> 0 crossing between the first two samples
# does not give a negative step (which would break the cuts below).
ra_res = (sim_ra[1] - sim_ra[0]) % 360.

In [None]:
# exclude point source transits
# Exclude samples within about +/-10 deg of RA of the bright point-source transits.
excl_ind = []
transit_cut = int(10. / ra_res)
n_samp = len(sim_ra)
for src in (ephem.CasA, ephem.CygA, ephem.TauA, ephem.VirA):
    src_ra = np.degrees(src.ra.radians)
    # RA separation accounting for the 0/360 wrap
    src_ind = np.argmin(np.abs((sim_ra - src_ra + 180.) % 360. - 180.))
    excl_ind += range(max(0, src_ind - transit_cut), min(src_ind + transit_cut, n_samp))
# CasA also shows up over the pole, 180 deg away in RA; the window extends a
# further 10 samples before it.
casa_pole_ra = (np.degrees(ephem.CasA.ra.radians) + 180.) % 360.
src_ind = np.argmin(np.abs((sim_ra - casa_pole_ra + 180.) % 360. - 180.))
excl_ind += range(max(0, src_ind - transit_cut - 10), min(src_ind + transit_cut, n_samp))
# try excluding region chosen by eye
#excl_ind += range(np.argmin(np.abs(sim_ra - 250.)), np.argmin(np.abs(sim_ra - 300.)))

In [None]:
# Amplitude vs RA of the selected products with |NS baseline| > 10, with the
# retained (non-excluded) samples marked at y=1 and the CasA RA as a vertical line.
# The retained indices are computed here: `time_slice` is only defined in a later
# cell, so referencing it here fails on a fresh kernel.
keep_ind = np.setdiff1d(np.arange(len(sim_ra)), excl_ind)
for p in range(len(prod_sel)):
    if abs(ns_baselines[p]) > 10.:
        plt.plot(sim_ra, np.abs(vis_sel[p,:]))
plt.plot(sim_ra[keep_ind], np.ones_like(keep_ind), 'o')
plt.vlines(np.degrees(ephem.CasA.ra.radians), *plt.ylim())

## Set up the beam-fit grid

In [None]:
# Pixel grid for the beam fit, uniform in sin(za) over +/- sin(max_za).
max_za = 89.
max_sinza = np.sin(np.radians(max_za))
# approx resolution (deg) for smoothed Haslam: the coarser of 0.5 / max|NS baseline|
# and the model resolution.
# NOTE(review): 0.5 / B carries no wavelength factor (cf. 0.5 * wl / res in the
# product cut) -- confirm the intended units.
approx_res = np.degrees(max(0.5 / np.abs(ns_baselines).max(), sim_model._res()))
num_pix = int(max_za / approx_res)
sinza = np.linspace(-max_sinza, max_sinza, num_pix)
za = np.arcsin(sinza)
# Samples not excluded above, ascending. np.setdiff1d avoids the
# O(n_time * n_excl) scan of `i not in <list>`.
time_slice = np.setdiff1d(np.arange(len(sim_ra)), excl_ind)

In [None]:
# Keep every 8th retained sample to speed up the fit. This is rebuilt from
# `excl_ind` instead of slicing `time_slice` in place, so re-running the cell
# does not thin the samples again.
time_slice = np.setdiff1d(np.arange(len(sim_ra)), excl_ind)[::8]

In [None]:
# Fit the beam on the retained samples with xtalk_iter=3. The third argument is
# all ones with the same shape and dtype as the data (presumably the weights).
vis_fit = vis_sel[:, time_slice]
beam_sol = sim_model.fit_beam(
    times[time_slice],
    vis_fit,
    np.ones_like(vis_fit),
    num_pix,
    max_za=max_za,
    xtalk_iter=3,
    resume=True,
)

In [None]:
# Least-squares fit of the remaining model_beam parameter(s) (starting from 1.)
# to the peak-normalised recovered beam, then overlay the two curves.
# 0.7 is presumably the fwhm fudge factor, cf. model_beam(za, fwhm_fudge=0.7) below.
beam_norm = beam_sol / beam_sol.max()

def _beam_resid(params):
    """Residual between the normalised recovered beam and cos(za) * model_beam."""
    return beam_norm - np.cos(za) * model_beam(za, 0.7, *params)

amp_fit = leastsq(_beam_resid, (1.,))
best_params = amp_fit[0]

plt.plot(za, beam_norm, label='fit')
plt.plot(za, np.cos(za) * model_beam(za, 0.7, *best_params), label='input')
plt.legend()
plt.xlabel(r"$\theta$")
#plt.savefig("./xtalk_beam_sol_NS.pdf", dpi=300)

In [None]:
# Same fit/overlay as the previous cell (variant saved under the "sim" filename).
# The curves are now labelled, with a legend and y-axis label, so the figure
# stands on its own.
amp_fit = leastsq(lambda t: beam_sol / beam_sol.max() - np.cos(za) * model_beam(za, 0.7, *t), (1.,))

plt.plot(za, beam_sol / beam_sol.max(), label='fit')
plt.plot(za, np.cos(za)*model_beam(za, 0.7, *amp_fit[0]), label='input')
plt.legend()
plt.xlabel(r"$\theta$")
plt.ylabel("normalised beam")
#plt.savefig("./sim_beam_sol_NS.pdf", dpi=300)

In [None]:
# Wavelength at channel f_ind (3e2 / freq -- presumably m, with freq in MHz).
wl = 3e2 / freq[f_ind]

# Kernel exp(-2 pi i b sin(za) / wl): rows are NS baselines, columns are sin(za)
# pixels. The real part of its product with the visibilities gives a map vs sin(za).
fringe = np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl)
rmap = vis_sel[:,time_slice].T.dot(fringe).real
rmap_xtalk = sim_model.xtalk.T.dot(fringe).real

In [None]:
# Cross-talk-only map against pixel index.
ax = plt.gca()
ax.plot(rmap_xtalk)

In [None]:
# Map of the input visibilities on the retained samples. The columns of `rmap`
# are `time_slice` samples, which are thinned and have the point-source windows
# cut out, so they are not uniform in time. Plot them at their actual times
# rather than stretching them over the whole observation with imshow's extent.
hours = (times[time_slice] - times[0]) / 3600.
plt.pcolormesh(hours, sinza, rmap.T, vmax=5000, vmin=-1000)
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(times[0])))
plt.colorbar()
plt.title('input map')

#plt.savefig("xtalk_input_map.png", dpi=300, bbox_inches='tight')

In [None]:
# Cross-talk-only map against zenith angle (za is in radians, from arcsin).
plt.plot(za, rmap_xtalk)
# "$...$" is needed for matplotlib to render theta as mathtext.
plt.xlabel(r"$\theta$")
plt.savefig("xtalk_slice.pdf", dpi=300, bbox_inches='tight')

In [None]:
# Visibilities from the model on the retained samples, using the recovered beam.
# Alternative: pass np.cos(za) * model_beam(za, 0.7, *amp_fit[0]) as the beam.
model_vis = sim_model.get_vis(
    times[time_slice],
    vis_sel[:, time_slice],
    num_pix,
    max_za,
    sim_model.beam_sol,
)

In [None]:
# Maps of the model visibilities and of the input visibilities with the
# recovered cross-talk subtracted, using the same sin(za) kernel as above.
fringe = np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl)
rmap_model = np.dot(model_vis.T, fringe).real
vis_clean = vis_sel[:,time_slice].T - sim_model.xtalk
rmap_xtalk_sub = np.dot(vis_clean, fringe).real

In [None]:
# Model map on the retained `time_slice` samples, plotted at their actual times
# (the samples are thinned and gappy, so a uniform imshow extent mislabels time).
hours = (times[time_slice] - times[0]) / 3600.
plt.pcolormesh(hours, sinza, rmap_model.T, vmax=5000, vmin=-1000)
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(times[0])))
plt.colorbar()
plt.title('recovered beam and haslam map')

#plt.savefig("xtalk_recov_map.png", dpi=300, bbox_inches='tight')

In [None]:
# Input map minus the recovered cross-talk on the retained `time_slice` samples,
# plotted at their actual times (the samples are thinned and gappy).
hours = (times[time_slice] - times[0]) / 3600.
plt.pcolormesh(hours, sinza, rmap_xtalk_sub.T, vmax=5000, vmin=-1000)
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(times[0])))
plt.colorbar()
plt.title('input map minus recovered crosstalk')

#plt.savefig("xtalk_cleaned_map.png", dpi=300, bbox_inches='tight')

In [None]:
# Residual: cross-talk-subtracted input map minus model map, on the retained
# `time_slice` samples at their actual times.
hours = (times[time_slice] - times[0]) / 3600.
plt.pcolormesh(hours, sinza, rmap_xtalk_sub.T - rmap_model.T)
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(times[0])))
plt.colorbar()
plt.title('cleaned input map minus model map')

In [None]:
# Model visibilities for the input beam, cos(za) * model_beam(za, fwhm_fudge=0.7)
# scaled to the recovered beam's peak. They are evaluated on the same retained
# samples as the other maps, so that the difference map below has matching shapes.
vis_input_model = sim_model.get_vis(times[time_slice], vis_sel[:, time_slice],
                                    num_pix, max_za,
                                    sim_model.beam_sol.max()*np.cos(za)*model_beam(za, fwhm_fudge=0.7))
# Same kernel as the other maps -- including the 2*pi in the phase.
rmap_input_model = np.dot(vis_input_model.T,
                          np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl)).real

In [None]:
# |input-beam model map - cross-talk-subtracted map| on the retained
# `time_slice` samples, plotted at their actual times.
hours = (times[time_slice] - times[0]) / 3600.
plt.pcolormesh(hours, sinza, np.abs(rmap_input_model.T - rmap_xtalk_sub.T), vmax=2000)
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(times[0])))
plt.colorbar()
plt.title('|input-beam model - cleaned input|')

In [None]:
# Amplitudes of the input and model visibilities for one selected product (row 200).
i = 200
input_amp = np.abs(vis_sel[i, time_slice])
plt.plot(input_amp)
plt.plot(np.abs(model_vis[i, :]))

In [None]:
# Amplitude of the recovered cross-talk.
ax = plt.gca()
ax.plot(np.abs(sim_model.xtalk))

## Try for many frequencies

In [None]:
# Shape of one channel's visibilities after product and time selection (the
# same for every channel). Channel 0 is used: `ff` is only defined by the loop below.
vis[0, prod_sel, :][:, time_slice].shape

In [None]:
# Fit beam and cross-talk separately at every 8th frequency channel, with fresh
# models (resume=False, xtalk_iter=4), keeping copies of each solution.
beam_freq, xtalk_freq = [], []
for ff in range(0, len(freq), 8):
    a_model = ModelVis(freq=freq[ff])
    a_model.set_baselines(ns_baselines)
    vis_ff = vis[ff, prod_sel, :][:, time_slice]
    weights_ff = np.ones((prod_sel.shape[0], len(time_slice)))
    sol_ff = a_model.fit_beam(times[time_slice], vis_ff, weights_ff,
                              num_pix, max_za=max_za, xtalk_iter=4, resume=False)
    beam_freq.append(sol_ff.copy())
    xtalk_freq.append(a_model.xtalk.copy())

In [None]:
# Peak-normalised recovered beam at each fitted channel (every 8th), against the
# input beam cos(za) * model_beam(za, 0.7). The input curve is plotted before
# legend() so that it appears in the legend.
for ff, b in enumerate(beam_freq):
    plt.plot(za, b / b.max(), label=freq[ff*8])
plt.plot(za, np.cos(za)*model_beam(za, 0.7), 'k--', linewidth=2, label='input')
plt.legend()
plt.xlabel(r"$\theta$")

plt.savefig("beam_multi_freq.pdf", dpi=300, bbox_inches='tight')

In [None]:
# Cross-talk map at each fitted channel (every 8th). Each map uses that
# channel's own wavelength (not the global `wl` of channel f_ind) and the full
# 2*pi phase, matching the kernel used for the single-frequency maps.
for ff, b in enumerate(xtalk_freq):
    wl_ff = 3e2 / freq[ff*8]
    an_rmap = np.dot(b.T, np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl_ff)).real
    plt.plot(za, an_rmap, label=freq[ff*8])
plt.legend()
plt.xlabel(r"$\theta$")

plt.savefig("xtalk_multi_freq.pdf", dpi=300, bbox_inches='tight')