In [None]:
# OMP_NUM_THREADS is read when the OpenMP runtime is first loaded, so it has
# to be exported before importing any compiled numerical package (healpy,
# cora, numpy, ...); setting it afterwards may have no effect.
import os
os.environ["OMP_NUM_THREADS"] = "8"

import healpy
from cora.util import hputil
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import ticker
from ch_util import ephemeris as ephem, andata, tools, data_index
from caput.time import unix_to_skyfield_time
import h5py
import time
from datetime import datetime
from glob import glob
from ch_pipeline.core import telescope
from scipy.optimize import leastsq

# Pick up edits to local modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Local analysis code. ModelVis (used below) is not imported anywhere else,
# so it comes in through this star import.
# NOTE(review): the star import hides which names are provided — consider
# importing them explicitly.
from continuum_beam import *
from beam_utils import model_beam

# Inline figures with large default size and font.
%matplotlib inline
plt.rcParams.update({'figure.figsize': (16, 12), 'font.size': 20})

In [None]:
# Archive query object used by the following cells.
# NOTE(review): node_spoof presumably maps the 'cedar_archive' storage node to
# this root directory — confirm against data_index.Finder. The path is
# absolute and site-specific.
f = data_index.Finder(node_spoof={'cedar_archive': '/project/rpp-krs/chime/chime_archive/'})

In [None]:
# One 24 h window (2018-11-18 08:00 to 2018-11-19 08:00), picked because the
# site seemed quiet that day.
day_start = datetime(2018, 11, 18, 8)
day_end = datetime(2018, 11, 19, 8)
f.set_time_range(day_start, day_end)

# Cut out the transits of the bright point sources.
bright_sources = (ephem.CasA, ephem.CygA, ephem.TauA, ephem.VirA)
for src in bright_sources:
    f.exclude_transits(src, time_delta=600)

# CasA is seen a second time over the pole, half a turn in RA away from its
# transit; cut an RA interval equivalent to 600 s centred there.
src_ra = np.degrees(ephem.CasA.ra.radians + np.pi) % 360.
ra_delta = np.degrees(600. / (24*3600) * 2 * np.pi)
half_width = ra_delta / 2.
f.exclude_RA_interval(src_ra - half_width, src_ra + half_width)

# Night-time data only.
f.exclude_daytime()

# Do not let global flags remove anything.
f.accept_all_global_flags()

# Restrict the query to the stacked-data instrument.
f.filter_acqs(data_index.ArchiveInst.name == 'chimestack')

In [None]:
# Run the query configured above and print a summary of the matches.
# The loading cell below reads each entry as (file list, (start, end) times).
acqs = f.get_results()
f.print_results_summary()

In [None]:
# Frequency index to analyse. Kept as a tuple because it is used directly as
# the first-axis index when reading 'vis' and 'flags/vis_weight' from file.
freq_sel = (900,)

In [None]:
# Load the selected frequency from every file the finder returned, trimming
# each file to the time interval accepted for its acquisition, then join the
# pieces along the time axis.
freq = None
vis = []
weight = []
times = []
for acq in acqs:
    acq_files = acq[0]
    t_start, t_end = acq[1][0], acq[1][1]
    # The loop variable is deliberately not called `f`: that name holds the
    # Finder, and overwriting it breaks any re-run of the query cells.
    for fname in acq_files:
        with h5py.File(fname, 'r') as fh:
            if freq is None:
                # Axis definitions are taken from the first file only.
                freq = fh['index_map/freq']['centre']
                prod = fh['index_map/prod'][:]
                stack = fh['index_map/stack'][:]
            some_times = fh['index_map/time']['ctime']
            # When an interval edge falls inside this file, cut at the sample
            # closest to it; otherwise keep everything on that side.
            if some_times[0] < t_start:
                t_min = np.argmin(np.abs(some_times - t_start))
            else:
                t_min = 0
            if some_times[-1] > t_end:
                t_max = np.argmin(np.abs(some_times - t_end))
            else:
                # An open-ended stop keeps the final sample of the file; a
                # stop of -1 would silently drop it at every file boundary.
                t_max = None
            vis.append(fh['vis'][freq_sel,:,t_min:t_max])
            weight.append(fh['flags/vis_weight'][freq_sel,:,t_min:t_max])
            times.append(some_times[t_min:t_max])

if not vis:
    raise ValueError("no files matched the finder query; nothing to load")

vis = np.concatenate(vis, axis=-1)
weight = np.concatenate(weight, axis=-1)
times = np.concatenate(times)

In [None]:
# Correlator input layout valid at the first loaded sample, and the feed
# positions / polarisations derived from it.
start_time = times[0]
start_dt = ephem.unix_to_datetime(start_time)
inputs = tools.get_correlator_inputs(start_dt, correlator='chime')
pos, pol = tools.get_feed_positions(inputs), tools.get_feed_polarisations(inputs)

In [None]:
# Transit RA for every loaded sample (used as the x axis of a plot below).
# NOTE(review): presumably in degrees — confirm against ephem.transit_RA.
ra = ephem.transit_RA(times)

In [None]:
# Quick look at column 1 of the feed positions (the coordinate used for the
# N-S baselines below) as a function of input index.
plt.plot(pos[:,1])

In [None]:
# Display the centre frequency of the selected channel.
freq[freq_sel[0]]

In [None]:
# Visibility model at the selected frequency (ModelVis comes from the
# continuum_beam star import). The trailing comment is a disabled option for
# passing an explicit sky-map file.
vis_model = ModelVis(freq=freq[freq_sel[0]], smooth=False, harm_basis=True)#, fname="./lambda_haslam408_dsds.fits")

In [None]:
# Choose which stacked products go into the fit. Each cut below marks
# products to drop; prod_sel ends up holding the indices that survive all of
# them, in ascending order.
pol_sel = 'S'

# The pair of inputs that represents each stacked product.
stack_prod = prod[stack['prod']]
input_a = stack_prod['input_a']
input_b = stack_prod['input_b']

# Baseline components: position column 1 is N-S, column 0 is E-W.
ns_baselines = pos[input_a, 1] - pos[input_b, 1]
ew_baselines = pos[input_a, 0] - pos[input_b, 0]

pol_pair = np.empty(stack.shape[0], dtype=[('pol_a', '<U1'), ('pol_b', '<U1')])
pol_pair['pol_a'] = pol[input_a]
pol_pair['pol_b'] = pol[input_b]

exclude = np.zeros(stack.shape[0], dtype=bool)

# exclude bad channels (non-finite feed positions)
exclude |= ~np.isfinite(ew_baselines + ns_baselines)
# exclude intercyl
exclude |= np.abs(ew_baselines) > 10.
# exclude autos: both components must vanish. Testing the sum against zero
# would also discard any baseline with ns == -ew.
exclude |= (ns_baselines == 0) & (ew_baselines == 0)
# exclude all but SS pol
exclude |= ~((pol_pair['pol_a'] == pol_sel) & (pol_pair['pol_b'] == pol_sel))
# exclude longer baselines
exclude |= np.abs(ns_baselines) > 0.5 * vis_model.wl / vis_model._res()
# exclude shorter baselines (disabled)
#exclude |= np.abs(ns_baselines) < 3.
# exclude all but cylinder A
# NOTE(review): assumes cylinder A is inputs 0-511 (512 inputs per cylinder),
# so input 512 itself has to be rejected as well — confirm against the layout.
exclude |= (input_a >= 512) | (input_b >= 512)

# Unique indices of the rejected products, and the ones that are kept.
prod_excl = set(np.flatnonzero(exclude))
prod_sel = np.flatnonzero(~exclude)

In [None]:
# Keep only the selected products (the single loaded frequency is index 0).
vis_sel, weight_sel = (cube[0, prod_sel, :] for cube in (vis, weight))
# NOTE: the baseline arrays are overwritten in place, so this cell must not be
# re-run without first re-running the selection cell above.
ns_baselines, ew_baselines = ns_baselines[prod_sel], ew_baselines[prod_sel]

In [None]:
# Give the model the baselines of the selected products.
vis_model.set_baselines(ns_baselines, ew_baselines)

In [None]:
# Build a grid that is uniform in sin(zenith angle) between -max_za and
# +max_za degrees; za holds the corresponding angles in radians.
max_za = 89.
max_sinza = np.sin(np.radians(max_za))
# approx resolution for smoothed Haslam
# NOTE(review): the first term has no wavelength factor, whereas the long
# baseline cut in the selection cell uses 0.5 * wl / res — confirm whether
# 0.5 * wl / max|b| was intended here.
approx_res = np.degrees(max(0.5 / np.abs(ns_baselines).max(), vis_model._res()))
# Number of grid points (and the num_pix handed to the fit below).
num_pix = int(max_za / approx_res)
sinza = np.linspace(-max_sinza, max_sinza, num_pix)
za = np.arcsin(sinza)

In [None]:
# Every 4th sample, with the stop trimmed down to a multiple of 4 * t_stride.
# NOTE(review): the step is the literal 4, not t_stride — confirm intent.
# NOTE: the next cell overwrites time_slice, so this value is not used when
# the notebook is run top to bottom.
t_stride = 8
time_slice = slice(0, vis.shape[-1] - (vis.shape[-1] % (4 * t_stride)), 4)

In [None]:
# Hand-picked sample indices: every 8th sample in [0, 2600) and in
# [2800, 3600), skipping the stretch in between. Despite the name this is an
# index array, not a slice object.
time_slice = np.concatenate((np.arange(0, 2600, 8), np.arange(2800, 3600, 8)))

In [None]:
# Number of time samples picked out by time_slice.
ntime = times[time_slice].shape[0]

In [None]:
# Amplitude of every selected product against sample index ...
for amplitude in np.abs(vis_sel):
    plt.plot(amplitude)
# ... with a row of markers at height 2000 showing which samples are used.
sample_index = np.arange(times.shape[0])
plt.plot(sample_index[time_slice], np.full(ntime, 2000.), 'o')

In [None]:
# Amplitude of every selected product against transit RA ...
for amplitude in np.abs(vis_sel):
    plt.plot(ra, amplitude)
# ... with a row of markers at height 2000 showing which samples are used.
plt.plot(ra[time_slice], np.full(ntime, 2000.), 'o')

In [None]:
# Fit the beam to the selected samples and products.
# The weights passed in are all ones; the commented-out line is the
# alternative call that would use the recorded weights (weight_sel) instead.
# NOTE(review): the meaning of rcond, chain_len, xtalk_iter and resume is
# defined by ModelVis.fit_beam, which is not visible here.
beam_sol = vis_model.fit_beam(times[time_slice], vis_sel[:,time_slice],
                              np.ones_like(vis_sel[:,time_slice], dtype=float), num_pix, rcond=1e-6,
                              #weight_sel[:,time_slice], num_pix, rcond=1e-6,
                              chain_len=1000, max_za=max_za, xtalk_iter=3, resume=True)

In [None]:
# Evaluate the fitted beam on the za grid: a sum over sine harmonics
# k = 1..num_pix of sin(k * (za + pi/2)) weighted by beam_sol.
harmonics = np.arange(1, num_pix + 1)
phase = harmonics[np.newaxis, :] * (za + np.pi/2)[:, np.newaxis]
beam_shape = (np.sin(phase) * beam_sol).sum(axis=1)


def _beam_residual(params):
    """Difference between the fitted beam and cos(za) times the simple model."""
    return beam_shape - np.cos(za) * model_beam(za, 0.7, *params)


# Least-squares fit of the free model parameter(s), starting from 1.
amp_fit = leastsq(_beam_residual, (1.,))
simple_model = np.cos(za) * model_beam(za, 0.7, *amp_fit[0])

plt.plot(za, beam_shape, label='fit')
plt.plot(za, simple_model, label='simple model')
plt.legend()
plt.xlabel(r"$\theta$")
plt.title('Fit to CHIME data -- after 150 iterations')

plt.savefig("./chime_xtalk_iter_beam.png", dpi=300, bbox_inches='tight')

In [None]:
wl = vis_model.c / freq[freq_sel[0]]

# Fourier kernel taking N-S baselines to the sin(za) grid: (product, pixel).
fringe = np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl)

# Image the selected data and the crosstalk estimate held by the model.
rmap = np.dot(vis_sel[:,time_slice].T, fringe).real
rmap_xtalk = np.dot(vis_model.xtalk.T, fringe).real

In [None]:
# Crosstalk estimate imaged onto the sin(za) grid, against pixel index.
plt.plot(rmap_xtalk)

In [None]:
# The image columns are the samples in time_slice, so the time axis has to
# span those samples rather than the full loaded range.
# NOTE: imshow assumes evenly spaced columns; any gap in time_slice is not
# represented on the axis.
t_sel = times[time_slice]
plt.imshow(rmap.T, aspect='auto', origin='lower', vmax=20000, vmin=-10000,
           extent=(0, (t_sel[-1] - t_sel[0]) / 3600, sinza[0], sinza[-1]))
plt.colorbar()
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(t_sel[0])))
plt.title("CHIME map ({:.1f} MHz)".format(freq[freq_sel[0]]))

plt.savefig("./chime_xtalk_map.png", dpi=300, bbox_inches='tight')

In [None]:
# Model visibilities for the fitted beam on the selected samples.
model_vis = vis_model.get_vis(times[time_slice], vis_sel[:,time_slice], num_pix, max_za,
                              model_beam=beam_sol, skip_basis=True)

# Same baseline -> sin(za) Fourier kernel as used for the data map.
fringe = np.exp(-2j * np.pi * sinza[np.newaxis,:] * ns_baselines[:,np.newaxis] / wl)

# Image the model, and the data with the crosstalk estimate removed.
rmap_model = np.dot(model_vis.T, fringe).real
xtalk_removed = vis_sel[:,time_slice].T - vis_model.xtalk
rmap_xtalk_sub = np.dot(xtalk_removed, fringe).real

In [None]:
# The image columns are the samples in time_slice, so the time axis has to
# span those samples rather than the full loaded range.
# NOTE: imshow assumes evenly spaced columns; any gap in time_slice is not
# represented on the axis.
t_sel = times[time_slice]
plt.imshow(rmap_model.T, aspect='auto', origin='lower', vmax=20000, vmin=-10000,
           extent=(0, (t_sel[-1] - t_sel[0]) / 3600, sinza[0], sinza[-1]))
plt.colorbar()
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(t_sel[0])))
plt.title("Recovered beam and Haslam (408 MHz)")

plt.savefig("./chime_recov_map.png", dpi=300, bbox_inches='tight')

In [None]:
# The image columns are the samples in time_slice, so the time axis has to
# span those samples rather than the full loaded range.
# NOTE: imshow assumes evenly spaced columns; any gap in time_slice is not
# represented on the axis.
t_sel = times[time_slice]
plt.imshow(rmap_xtalk_sub.T, aspect='auto', origin='lower', vmax=20000, vmin=-10000,
           extent=(0, (t_sel[-1] - t_sel[0]) / 3600, sinza[0], sinza[-1]))
plt.colorbar()
plt.ylabel(r"$\sin \theta$")
plt.xlabel("hours since {}".format(ephem.unix_to_datetime(t_sel[0])))
plt.title("input map minus recovered crosstalk ({:.1f} MHz)".format(freq[freq_sel[0]]))

plt.savefig("./chime_cleaned_map.png", dpi=300, bbox_inches='tight')

In [None]:
# Data map minus model map, on a symmetric colour scale.
# (Disabled alternative: the crosstalk-subtracted residual as a fraction of
# the data map.)
#plt.imshow((rmap_xtalk_sub - rmap_model).T / rmap.T, aspect='auto', origin='lower', vmax=10, vmin=-10)
residual_map = rmap - rmap_model
plt.imshow(residual_map.T, aspect='auto', origin='lower', vmax=5000, vmin=-5000)
plt.colorbar()

In [None]:
# Real part of the first selected product over the first 2500 samples.
plt.plot(vis_sel[0,:2500].real, '.')

In [None]:
# Product index for the comparison cell below, which advances it on each run.
i = 0

In [None]:
# Data vs. model amplitude for product i on the selected samples.
plt.plot(np.abs(vis_sel[i, time_slice]))
plt.plot(np.abs(model_vis[i,:]))
# Re-running this cell steps to the next product (deliberately not idempotent;
# re-run the `i = 0` cell to start over).
i += 1

In [None]:
# Data minus model for the first selected product: what is left over is the
# per-sample crosstalk estimate for that product.
xtalk_est = vis_sel[0, time_slice] - model_vis[0, :]
for component in (xtalk_est.real, xtalk_est.imag):
    plt.plot(component)

In [None]:
# Real and imaginary parts of the crosstalk term held by the model.
for component in (vis_model.xtalk.real, vis_model.xtalk.imag):
    plt.plot(component)