In [None]:
import healpy
from cora.util import hputil
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import ticker
from ch_util import ephemeris as ephem, andata, tools
from caput.time import unix_to_skyfield_time
import h5py
import time
from datetime import datetime
from glob import glob
from ch_pipeline.core import telescope
from scipy.optimize import leastsq

# Re-import edited modules (e.g. continuum_beam) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# NOTE(review): star import — presumably this is where ModelVis (used below and not
# imported elsewhere in this notebook) comes from. Prefer an explicit
# `from continuum_beam import ModelVis` so the namespace is traceable.
# NOTE(review): the notebook uses Python 2 print statements further down.
from continuum_beam import *

%matplotlib inline
plt.rcParams.update({'figure.figsize': (16, 12), 'font.size': 20})

# Validate fitting scheme on simulated visibilities

### Load simulated visibilities

In [None]:
sim_file = "/home/tristpinsm/scratch-fast/continuum_beam_data/sstreamgroup_0.h5"

In [None]:
sim_vis = h5py.File(sim_file, 'r')

In [None]:
sim_vis.keys()

In [None]:
freq = sim_vis['index_map']['freq']['centre']

In [None]:
start_time = time.time()
sim_time = np.array([ ephem.transit_times(r, start_time) for r in sim_vis['index_map']['ra'] ])

In [None]:
plt.plot(sim_time)

In [None]:
inputs = tools.get_correlator_inputs(datetime.now(), correlator='pathfinder')

In [None]:
pos = tools.get_feed_positions(inputs)
pol = tools.get_feed_polarisations(inputs)

### Compare to input beam

In [None]:
def model_beam(za, fwhm_fudge=0.7, amp=1.):
    """Analytic beam model used for comparison with the fitted beam.

    Evaluates ``amp * exp(-alpha * tan(za)**2)``, i.e. a Gaussian in ``tan(za)``.
    The nominal full width is ``fwhm = fwhm_fudge * 2*pi/3`` radians (120 deg for
    ``fwhm_fudge == 1``), and ``alpha`` is chosen so that the model drops to
    ``amp / sqrt(2)`` at ``za = fwhm / 2``.

    Parameters
    ----------
    za : float or array_like
        Zenith angle(s) in radians.
    fwhm_fudge : float, optional
        Multiplicative scale applied to the nominal 120 deg width.
    amp : float, optional
        Value of the model at ``za = 0``.

    Returns
    -------
    float or ndarray
        Beam model evaluated at ``za`` (same shape as ``za``).
    """
    full_width = 2.0 * np.pi / 3.0 * fwhm_fudge
    tan_half_width = np.tan(full_width / 2.0)
    # exp(-alpha * tan_half_width**2) == exp(-ln(2) / 2) == 1 / sqrt(2)
    alpha = np.log(2.0) / (2 * tan_half_width**2)
    return amp * np.exp(-alpha * np.tan(za)**2)

In [None]:
f_ind = 58  # 408MHz

In [None]:
sim_model = ModelVis(freq=freq[f_ind])#, fname="./lambda_haslam408_dsds.fits")

In [None]:
prod_excl = []

ns_baselines = (pos[sim_vis['index_map/prod'][:,0],1]
                 - pos[sim_vis['index_map/prod'][:,1],1])
ew_baselines = (pos[sim_vis['index_map/prod'][:,0],0]
                 - pos[sim_vis['index_map/prod'][:,1],0])

pol_pair = np.empty(sim_vis['index_map/prod'].shape[0], dtype=[('pol_a', '<U1'), ('pol_b', '<U1')])
pol_pair['pol_a'] = pol[sim_vis['index_map/prod'][:,0]]
pol_pair['pol_b'] = pol[sim_vis['index_map/prod'][:,1]]

# exclude bad channels
prod_excl += list(np.where(np.logical_not(np.isfinite(ew_baselines + ns_baselines)))[0])
# exclude intercyl
prod_excl += list(np.where(np.abs(ew_baselines) > 10.)[0])
# exclude autos
prod_excl += list(np.where(ns_baselines + ew_baselines == 0)[0])
# exclude all but SS pol
prod_excl += list(np.where(np.logical_not(np.logical_and(pol_pair['pol_a'] == "E",
                                                         pol_pair['pol_b'] == "E")))[0])
# exclude longer baselines
prod_excl += list(np.where(np.abs(ns_baselines) > 0.5 * sim_model.wl / sim_model._res())[0])
# get unique values
prod_excl = set(prod_excl)

prod_sel = np.array([ p for p in range(sim_vis['index_map/prod'].shape[0]) if not p in prod_excl ])

In [None]:
vis = sim_vis['vis'][f_ind, prod_sel, :]
ns_baselines = ns_baselines[prod_sel]

In [None]:
sim_model.set_baselines(ns_baselines)

In [None]:
# |vis| against RA for the longer NS baselines (|b| > 15), with CasA's RA marked.
# sim_ra is computed here too because this cell runs before the one that defines it;
# the selected-time overlay is omitted because time_slice is only built further down.
sim_ra = ephem.transit_RA(sim_time)
long_ns = np.abs(ns_baselines) > 15.
plt.plot(np.ravel(sim_ra), np.abs(vis[long_ns, :]).T)
plt.vlines(np.degrees(ephem.CasA.ra.radians), *plt.ylim())

In [None]:
sim_ra = ephem.transit_RA(sim_time)
ra_res = (sim_ra[1] - sim_ra[0])[0]

In [None]:
for src in (ephem.CasA, ephem.CygA, ephem.TauA, ephem.VirA):
    print src.names, src.dec, np.degrees(src.ra.radians)

In [None]:
# Exclude times near bright point-source transits (plus CasA's transit below the
# pole and a hand-picked RA range). RA offsets are measured on the circle, so a
# window that straddles RA = 0/360 deg wraps around instead of being clipped at the
# ends of the array.
sim_ra_deg = np.ravel(sim_ra)
transit_cut = int(10. / ra_res)  # half-width of each window, in RA samples (~10 deg)
cut_deg = abs(transit_cut * ra_res)  # the same half-width in degrees
extra_deg = 10 * abs(ra_res)  # CasA-below-pole window: 10 extra samples on the low-RA side


def _ra_offset(ra, ra0):
    """Signed RA offset ``ra - ra0`` in degrees, wrapped into [-180, 180)."""
    return (ra - ra0 + 180.) % 360. - 180.


time_excl_mask = np.zeros(sim_ra_deg.shape, dtype=bool)
for src in (ephem.CasA, ephem.CygA, ephem.TauA, ephem.VirA):
    offset = _ra_offset(sim_ra_deg, np.degrees(src.ra.radians))
    time_excl_mask |= (offset >= -cut_deg) & (offset < cut_deg)
# casA also shows up over the pole, 180 deg away in RA
offset = _ra_offset(sim_ra_deg, np.degrees(ephem.CasA.ra.radians) - 180.)
time_excl_mask |= (offset >= -cut_deg - extra_deg) & (offset < cut_deg)

excl_ind = set(np.nonzero(time_excl_mask)[0])
# try excluding region chosen by eye
excl_ind.update(range(np.argmin(np.abs(sim_ra_deg - 250.)),
                      np.argmin(np.abs(sim_ra_deg - 300.))))

In [None]:
max_za = 89.  # degrees either side of zenith covered by the beam grid
# approx resolution for smoothed Haslam: the coarser of the fringe spacing of the
# longest selected baseline, 0.5 * wl / |B| (radians), and the sky model's _res().
# wl converts B to wavelengths, matching the baseline cut used to select products
# (assumes pos and sim_model.wl share units).
approx_res = np.degrees(max(0.5 * sim_model.wl / np.abs(ns_baselines).max(),
                            sim_model._res()))
num_pix = int(2 * max_za / approx_res)
za = np.radians(np.linspace(-max_za, max_za, num_pix))
# indices of the time samples that survived the transit exclusions
time_slice = np.array([i for i in range(len(sim_ra)) if i not in excl_ind])

In [None]:
beam_sol = sim_model.fit_beam(sim_time[time_slice], vis[:,time_slice],
                               np.ones_like(vis[:,time_slice]),
                               num_pix, max_za=max_za, rcond=1e-3)

In [None]:
plt.plot(za, beam_sol / beam_sol.max())
#plt.plot(za, np.cos(za)*model_beam(za, fwhm_fudge=0.7))
plt.plot(za, model_beam(za, fwhm_fudge=1.2))
plt.xlabel(r"$\theta$")
#plt.savefig("./sim_beam_sol_NS.pdf", dpi=300)

In [None]:
plt.imshow(np.log10(np.abs(sim_model.cov)))
plt.colorbar()

In [None]:
U, S, V = np.linalg.svd(sim_model.M)

In [None]:
plt.imshow(V.T, aspect='auto')
plt.colorbar()

In [None]:
i = 0

In [None]:
plt.plot(V[i,:])
i += 1

In [None]:
plt.plot(np.log10(S / S[0]))

In [None]:
np.linalg.norm(sim_model.M)

In [None]:
"{:.3}".format(np.median(np.abs(sim_model.v)))

In [None]:
np.linalg.det(sim_model.M / sim_model.M.max())

In [None]:
np.allclose(np.dot(sim_model.M, Minv), np.dot(Minv, sim_model.M),)

In [None]:
sim_model.M.shape

In [None]:
Minv = np.linalg.inv(sim_model.M)

In [None]:
plt.plot(np.log10(np.linalg.eig(sim_model.M)[0]))

In [None]:
plt.imshow(np.dot(sim_model.M, Minv) - np.dot(Minv, sim_model.M))
plt.colorbar()

In [None]:
plt.imshow(np.log10(np.abs(np.dot(sim_model.M, Minv))),
          extent=(-max_za,max_za,-max_za,max_za))
plt.colorbar()

In [None]:
plt.imshow(np.log10(np.abs(sim_model.M)))
plt.colorbar()

In [None]:
sim_model.ns_baselines[1]

In [None]:
plt.plot(sim_model._basis[2,0,:])

#### Fit for FWHM

In [None]:
fwhm_fit = leastsq(lambda t: beam_sol / beam_sol.max() - model_beam(za, *t), (1.,1.))

In [None]:
fwhm_fit[0]

In [None]:
plt.plot(za, beam_sol / beam_sol.max())
plt.plot(za, model_beam(za, *fwhm_fit[0]))

#### Look at map slices

In [None]:
test_ind = 300

In [None]:
sim_model._gen_basis(sim_time[test_ind:test_ind+1], vis[:,test_ind:test_ind+1],
                     num_pix, max_za=max_za)
model_basis = sim_model._basis.copy()
model_vis = np.sum(model_basis * model_beam(za, fwhm_fudge=0.7) * np.cos(za), axis=2)

In [None]:
#model_vis = sim_model.get_vis(sim_time[time_slice], vis[:,time_slice],
#                              num_pix, max_za=max_za)
model_map = np.dot(model_vis[:,0], np.exp(-2j * np.pi * ns_baselines[:,np.newaxis]
                                          / sim_model.wl * np.sin(za)[np.newaxis,:]))
vis_map = np.dot(vis[:,test_ind], np.exp(-2j * np.pi * ns_baselines[:,np.newaxis]
                                    / sim_model.wl * np.sin(za)[np.newaxis,:]))

In [None]:
plt.subplot(2,1,1)
plt.plot(za/np.pi, model_map.real, label="Haslam")
plt.gca().yaxis.set_ticklabels([])
plt.legend()

plt.subplot(2,1,2)
plt.plot(za/np.pi, vis_map.real / vis_map.real.max(), label="sim")
plt.plot(za/np.pi, model_beam(za))
plt.gca().yaxis.set_ticklabels([])
plt.legend()

plt.xlabel(r"$\theta_k/\pi$")
#yfmt = ticker.ScalarFormatter()
#yfmt.set_powerlimits((-2,2))
#plt.gca().yaxis.set_major_formatter(yfmt)

In [None]:
plt.plot(za/np.pi, model_map.real / model_map.real.max(), label="Haslam")
plt.plot(za/np.pi, vis_map.real / vis_map.real.max(), label="sim")
plt.gca().yaxis.set_ticklabels([])
plt.legend()

#### Look at model visibilities

In [None]:
model_vis_all = sim_model.get_vis(sim_time, vis, num_pix,
                                  model_beam=lambda x: model_beam(np.radians(x)), max_za=max_za)

In [None]:
for p in range(len(prod_sel)):
    if abs(ns_baselines[p]) > 18.:
        plt.plot(sim_ra, np.abs(model_vis_all[p,:]))
plt.yscale('log')
plt.plot(sim_ra[time_slice], np.ones_like(time_slice), 'o')
plt.vlines(np.degrees(ephem.VirA.ra.radians), *plt.ylim())

In [None]:
for p in range(len(prod_sel)):
    plt.plot(sim_ra[time_slice], np.abs(vis[p,time_slice]), '.')
#plt.plot(sim_ra[time_slice], np.ones_like(time_slice), 'o')
plt.vlines(np.degrees(ephem.VirA.ra.radians), *plt.ylim())

In [None]:
for p in range(len(prod_sel)):
    plt.plot(sim_ra, np.abs(vis[p,:]))
plt.plot(sim_ra[time_slice], np.ones_like(time_slice), 'o')
plt.vlines(np.degrees(ephem.VirA.ra.radians), *plt.ylim())

## SVD of the design matrix

In [None]:
svd_model = ModelVis(freq=freq[f_ind])

In [None]:
svd_prod_excl = []

svd_ns_baselines = (pos[sim_vis['index_map/prod'][:,0],1]
                    - pos[sim_vis['index_map/prod'][:,1],1])
svd_ew_baselines = (pos[sim_vis['index_map/prod'][:,0],0]
                    - pos[sim_vis['index_map/prod'][:,1],0])

# exclude bad channels
svd_prod_excl += list(np.where(np.logical_not(np.isfinite(svd_ew_baselines + svd_ns_baselines)))[0])
# exclude intercyl
svd_prod_excl += list(np.where(np.abs(svd_ew_baselines) > 10.)[0])
# exclude autos
svd_prod_excl += list(np.where(svd_ns_baselines + svd_ew_baselines == 0)[0])
# exclude all but SS pol
svd_prod_excl += list(np.where(np.logical_not(np.logical_and(pol_pair['pol_a'] == "S",
                                                             pol_pair['pol_b'] == "S")))[0])
# exclude longer baselines
svd_prod_excl += list(np.where(np.abs(ns_baselines) > 14.)[0])
# get unique values
svd_prod_excl = set(svd_prod_excl)

svd_prod_sel = np.array([ p for p in range(sim_vis['index_map/prod'].shape[0]) if not p in svd_prod_excl ])

In [None]:
svd_model.set_baselines(svd_ns_baselines[svd_prod_sel])

In [None]:
svd_vis = sim_vis['vis'][f_ind, svd_prod_sel, :]

In [None]:
svd_max_za = 89.
# approx resolution for smoothed Haslam
svd_fwhm_smoothing = np.degrees(2 * np.sqrt(2*np.log(2)) * svd_model._res())
svd_num_pix = int(2. * svd_max_za / svd_fwhm_smoothing)
svd_za = np.radians(np.linspace(-svd_max_za, svd_max_za, svd_num_pix))

svd_time_slice = time_slice[::30]
svd_time_slice_nghb = time_slice[:10]

svd_time_slice = time_slice[200:201]

svd_model._gen_basis(sim_time[svd_time_slice], svd_vis[:,svd_time_slice], svd_num_pix, svd_max_za)
svd_basis = svd_model._basis.copy()

svd_model._gen_basis(sim_time[svd_time_slice_nghb], svd_vis[:,svd_time_slice_nghb], svd_num_pix, svd_max_za)
svd_basis_nghb = svd_model._basis.copy()

In [None]:
svd_basis = svd_basis.reshape(svd_vis.shape[0]*len(svd_time_slice), svd_num_pix)
svd_basis = np.vstack((svd_basis, svd_basis.conj()))

svd_basis_nghb = svd_basis_nghb.reshape(svd_vis.shape[0]*len(svd_time_slice_nghb), svd_num_pix)
svd_basis_nghb = np.vstack((svd_basis_nghb, svd_basis_nghb.conj()))

In [None]:
trash = svd_model.fit_beam(sim_time[svd_time_slice], svd_vis[:,svd_time_slice],
                               np.ones_like(svd_vis[:,svd_time_slice]),
                               svd_num_pix, max_za=svd_max_za)
del trash

In [None]:
# Thin SVDs of svd_basis, svd_basis_nghb and svd_model.M. numpy returns the right
# factor already conjugate-transposed: svd_basis = U @ diag(S) @ V.
U, S, V = np.linalg.svd(svd_basis, full_matrices=False)
Un, Sn, Vn = np.linalg.svd(svd_basis_nghb, full_matrices=False)
Um, Sm, Vm = np.linalg.svd(svd_model.M, full_matrices=False)

In [None]:
# Relative RMS difference between Re(V.T diag(S^2) V) and M.
# NOTE(review): omits the conjugate on V and the factor 1/2 used in the check two
# cells below; that later check is the one to trust.
np.sqrt(np.mean((np.dot(V.T.astype(np.complex128), np.dot(np.diag(S**2),V.astype(np.complex128))).real
            - svd_model.M)**2)) / np.mean(svd_model.M)

In [None]:
# U has orthonormal columns, so U^H U should be the identity.
plt.imshow(np.log10(np.abs(np.dot(U.T.conj(), U))), aspect='auto')
plt.colorbar()

In [None]:
# Relative RMS difference between M and Re(V^H diag(S^2) V) / 2, i.e. half the Gram
# matrix svd_basis^H svd_basis. (Python 2 print statement.)
print np.sqrt(np.mean((svd_model.M - np.dot(V.astype(np.complex128).T.conj(), np.dot(np.diag(S**2),V)).real
                       / 2.)**2)) / np.mean(svd_model.M)

In [None]:
# Overall scale of M versus that of Re(V^H diag(S^2) V) / 2.
print np.mean(svd_model.M)
print np.mean(np.dot(V.astype(np.complex128).T.conj(), np.dot(np.diag(S**2),V)).real / 2.)

In [None]:
plt.subplot(1,2,1)
plt.imshow(np.abs(V.T[:,:]), aspect='auto')
plt.title("22 times spaced by 30 indices")
plt.colorbar()

plt.subplot(1,2,2)
plt.imshow(np.abs(Vm.T[:,:]), aspect='auto')
plt.title("M")
plt.colorbar()

In [None]:
plt.plot(S / S[0], label="10 times spaced by 60 indices")
plt.plot(Sn / Sn[0], label="10 contiguous times")
plt.plot(np.sqrt(Sm / Sm[0]), label="M")
plt.yscale('log')
plt.legend()

In [None]:
plt.subplot(1,2,1)
plt.imshow(np.abs(U[:100,:]), aspect='auto')
plt.title("10 times spaced by 60 indices")

plt.subplot(1,2,2)
plt.imshow(np.abs(Un[:100,:]), aspect='auto')
plt.title("10 contiguous times")
#plt.colorbar()

In [None]:
plt.subplot(1,2,1)
plt.imshow(np.abs(V), aspect='auto')
plt.title("10 times spaced by 60 indices")

plt.subplot(1,2,2)
plt.imshow(np.abs(Vn), aspect='auto')
plt.title("10 contiguous times")
#plt.colorbar()

In [None]:
plt.subplot(1,2,1)
svd_cov_wgt = np.diag(1./S**2)
svd_cov_wgt[np.where(S/S[0] < 5e-2)] = 0.
plt.imshow(np.log10(np.abs(np.dot(V, np.dot(svd_cov_wgt, V.T)).real)), aspect='auto')
plt.title('covariance')
plt.colorbar()

plt.subplot(1,2,2)
plt.imshow(np.log10(np.abs(svd_model.M)), aspect='auto')
plt.title('M')
plt.colorbar()

In [None]:
for i in range(1):
    plt.plot(Vn[:,i].real)

In [None]:
svd_weight = 1. / S
svd_weight[np.where(S / S[0] < 5e-2)] = 0.
svd_inv = np.dot(V.T, np.dot(np.diag(svd_weight), U.T.conj()))

In [None]:
plt.plot(svd_weight)

In [None]:
svd_vis_full = svd_vis[:,svd_time_slice].reshape(svd_vis.shape[0]*len(svd_time_slice))
svd_vis_full = np.hstack((svd_vis_full, svd_vis_full.conj()))
svd_sol = np.dot(svd_inv, svd_vis_full)

In [None]:
plt.plot(svd_sol)