# Adler project

Work on EAGLE lens models: source reconstructions, synthetic imaging, and comparisons with the true convergence (kappa) maps.

### Imports

In [None]:
import sys
sys.path.append('..')
import os
import re
import time
import numpy as np
from scipy import interpolate, ndimage, sparse
from astropy.io import fits
import matplotlib.pyplot as plt

from gleam.skyf import SkyF
from gleam.lensobject import LensObject
from gleam.multilens import MultiLens
from gleam.reconsrc import ReconSrc, synth_filter, synth_filter_mp
from gleam.glass_interface import glass_renv, filter_env, export_state
# Set up the GLASS runtime environment through gleam's glass_interface
# (presumably required before GLASS state files can be loaded — TODO confirm).
glass = glass_renv()

# Load the IPython extension that presumably provides the `%%skip <flag>`
# cell magic used below to disable the test cells.
%load_ext skip_kernel_extension


### Reading data

Read all relevant files into dictionaries with model ids as keys and lists of filenames as values.

In [None]:
# Root directory of the project data. Set the ADLER_DIR environment variable
# to point elsewhere instead of editing this user-specific default.
rdir = os.environ.get("ADLER_DIR", "/Users/phdenzel/adler")
jsondir = rdir+"/json/"        # lens descriptions read by MultiLens.from_json
statedir = rdir+"/states/v2/"  # *.state files handed to ReconSrc
kappadir = rdir+"/kappa/"      # *.kappa.fits maps
# Model identifiers; each one is matched as a substring of the file paths
keys = ["H1S0A0B90G0", "H1S1A0B90G0", "H2S1A0B90G0", "H2S2A0B90G0", "H2S7A0B90G0",
        "H3S0A0B90G0", "H3S1A0B90G0", "H4S3A0B0G90", "H10S0A0B90G0", "H13S0A0B90G0",
        "H23S0A0B90G0", "H30S0A0B90G0", "H36S0A0B90G0", "H160S0A90B0G0",
        "H234S0A0B90G0"]


In [None]:
def an_sorted(data):
    """
    Perform an alpha-numeric, natural sort

    Runs of ASCII digits are compared numerically, all other text
    case-insensitively, e.g. ['a10', 'a2', 'A1'] -> ['A1', 'a2', 'a10'].

    Args:
        data <list> - list of strings

    Kwargs:
        None

    Return:
        sorted <list> - the alpha-numerically, naturally sorted list of strings
    """
    digit_runs = re.compile(r'([0-9]+)')

    def an_key(key):
        # re.split with a capturing group alternates text / digit run / text ...,
        # so odd positions are always ASCII digit runs. Classifying by position
        # (rather than str.isdigit) keeps the key types aligned between keys and
        # avoids int() failing on Unicode digit characters such as '²'.
        return [int(part) if pos % 2 else part.lower()
                for pos, part in enumerate(digit_runs.split(key))]
    return sorted(data, key=an_key)


def _an_sorted_listing(directory, suffix):
    """Return the naturally sorted full paths of the files in *directory* ending in *suffix*."""
    matches = [os.path.join(directory, name)
               for name in os.listdir(directory)
               if name.endswith(suffix)]
    return an_sorted(matches)


ls_jsons = _an_sorted_listing(jsondir, '.json')
ls_states = _an_sorted_listing(statedir, '.state')
ls_kappas = _an_sorted_listing(kappadir, '.kappa.fits')


In [None]:
# Group the file listings by model key (each key is matched as a substring
# of the path). ls_states is narrowed step by step: the more specific
# '_filtered.state' and '_filtered_synthf*.state' files are split off first,
# because the latter also end in '_synthf*.state' and would otherwise land in
# the plain synthf groups below.
jsons = {k: [f for f in ls_jsons if k in f] for k in keys}
filtered_states = {k: [f for f in ls_states
                       if k in f and f.endswith('_filtered.state')] for k in keys}
ls_states = [f for f in ls_states if not f.endswith('_filtered.state')]

# states that were filtered first and then percentile-filtered (10/25/50)
prefiltered_fsynth10_states = {k: [f for f in ls_states
                                   if k in f and f.endswith('_filtered_synthf10.state')] for k in keys}
prefiltered_fsynth25_states = {k: [f for f in ls_states
                                   if k in f and f.endswith('_filtered_synthf25.state')] for k in keys}
prefiltered_fsynth50_states = {k: [f for f in ls_states
                                   if k in f and f.endswith('_filtered_synthf50.state')] for k in keys}
ls_states = [f for f in ls_states if not (f.endswith('_filtered_synthf10.state')
                                          or f.endswith('_filtered_synthf25.state')
                                          or f.endswith('_filtered_synthf50.state'))]

# percentile-filtered states (the naming used by the export cells below)
synthf10_states = {k: [f for f in ls_states
                       if k in f and f.endswith('_synthf10.state')] for k in keys}
synthf25_states = {k: [f for f in ls_states
                       if k in f and f.endswith('_synthf25.state')] for k in keys}
synthf50_states = {k: [f for f in ls_states
                       if k in f and f.endswith('_synthf50.state')] for k in keys}
ls_states = [f for f in ls_states if not (f.endswith('_synthf10.state')
                                          or f.endswith('_synthf25.state')
                                          or f.endswith('_synthf50.state'))]

# whatever remains: the unfiltered state files
states = {k: [f for f in ls_states if k in f] for k in keys}

kappa_map_files = {k: [f for f in ls_kappas if k in f] for k in keys}

psf_file = os.path.join(rdir, 'lenses', 'psf.fits')


### Loading objects

The model ensemble of a single state file is used.

In [None]:
# Select a single file
print("\n# Selected file")
# model key to analyse (indices tried so far: 5 8 12)
key = keys[5]  # 5 8 12
# which of the key's state files to use
idx = 0
# first JSON for the key; raises IndexError if no JSON matched the key.
# NOTE(review): `json` shadows the stdlib module name; the next cell reads it
# as a path, so rename both together if changing it.
json = jsons[key][0]
# state = filtered_states[key][idx]
state = states[key][idx]
print(key)
print(json)
print(state)


In [None]:
# gleamobject
print("\n# gleamobject")
# Build the MultiLens container from the selected JSON description;
# ml[0] is used as the lens object in later cells.
with open(json) as f:
    ml = MultiLens.from_json(f)
# presumably a verbose summary of the loaded lens object(s) — see gleam
print(ml.__v__)

In [None]:
# recon_src
print("\n# recon_src")
# Source-reconstruction object combining the lens data (ml) with the model
# ensemble of the selected state file. M is read back later as recon_src.M
# when printing the source-plane pixel resolution.
args = (ml, state)
# NOTE: `kwargs` remains a global holding ReconSrc arguments, not plot options.
kwargs = dict(M=40, verbose=1)
recon_src = ReconSrc(*args, **kwargs)
#recon_src.chmdl(10)  # move to a single ensemble model


In [None]:
# Compute the PSF operator from psf_file; the PSF test cell reads it back
# via recon_src.psf (presumably set by this call — confirm in gleam).
recon_src.calc_psf(psf_file, window_size=6, verbose=True)


### Inspection

In [None]:
# Estimating the gain
# flatfield() yields signal and variance samples of the lens map (size=0.2 is
# passed through; its unit is not visible here), from which gain() estimates
# the gain used to scale the noise model below.
signals, variances = ml[0].flatfield(recon_src.lens_map(), size=0.2)
gain, _ = ml[0].gain(signals=signals, variances=variances)
print(gain)


In [None]:
# Generate some noise
# noise scale factor and a bias of 1% of the scaled peak flux
f = 1./(10*gain)
bias = 0.01*np.max(f * recon_src.lensobject.data)
# per-pixel noise variance map of the lens data
sgma2 = recon_src.lensobject.sigma2(f=f, add_bias=bias)
# diagonal inverse-variance matrix ordered by flat pixel index
# (idx2yx maps a flat index to the corresponding position in sgma2)
sgmaM2 = np.array([1./sgma2[recon_src.lensobject.idx2yx(i)] for i in range(sgma2.size)])
sgmaM2 = sparse.diags(sgmaM2)
# Gaussian noise realisation with per-pixel standard deviation sqrt(sgma2)
dta_noise = np.random.normal(0, 1, size=recon_src.lensobject.data.shape)
dta_noise = dta_noise * np.sqrt(sgma2)


In [None]:
%%skip True
# Test PSF matrix construction
if recon_src.psf is None:
    P_kl = recon_src.calc_psf(psf_file, window_size=6, verbose=True)
else:
    P_kl = recon_src.psf

print(P_kl.shape)
print(np.count_nonzero(P_kl.diagonal()))

N, N = recon_src.lensobject.data.shape
test_data = np.zeros((N, N))
test_data[3*N//4, 1*N//4] = 1
test_data = test_data.reshape(N*N)

test_res = test_data * P_kl
test_res = test_res.reshape((N, N))

plt.imshow(test_res, origin='Lower')
# plt.xlim(left=70, right=90)
# plt.ylim(bottom=70, top=90)
print(np.sum(test_res))
test_center = [X//2 for X in test_res.shape]


In [None]:
%%skip True
# Testing matrix multiplications
sgma_i = 1./recon_src.lensobject.sigma2(f=f, add_bias=bias, flat=True)
sgma_i = sparse.diags(sgma_i)
M_gamma_i = recon_src.proj_matrix()

print(M_gamma_i.shape, sgma_i.T.shape, M_gamma_i.T.shape)

ti = time.time()
A = M_gamma_i * sgma_i * M_gamma_i.T
b = recon_src.d_ij() * sgma_i * M_gamma_i.T
x = sparse.linalg.lsqr(A, b)[0]
# x = sparse.linlag.lsmr(A, b)[0]
tf = time.time()
print("Timing: {}".format(tf-ti))

# x = x.reshape((recon_src.N_AA, recon_src.N_AA))
# plt.imshow(x)
# plt.colorbar()
# plt.axis('off')


In [None]:
%%skip True
# Test matrix multiplications including PSF
P_kl = recon_src.psf
sgma_i = 1./recon_src.lensobject.sigma2(f=f, add_bias=bias, flat=True)
sgma_i = sparse.diags(sgma_i)
M_gamma_i = recon_src.proj_matrix()
M_i_gamma = M_gamma_i.T.tocsc()
print(type(P_kl))
print(type(M_gamma_i))
print(type(sgma_i))
print(type(M_i_gamma))

M_gamma_i = M_gamma_i * P_kl
A = M_gamma_i * sgma_i * M_i_gamma
b = recon_src.d_ij() * sgma_i * M_i_gamma

print("")
print(type(A))
print(type(b))

ti = time.time()
# x = sparse.linalg.lsqr(A, b)[0]
x = sparse.linalg.lsmr(A, b)[0]
# x = sparse.linalg.cg(A, b)[0]
# x = sparse.linalg.cgs(A, b)[0]
# x = sparse.linalg.lgmres(A, b, atol=1e-05)[0]
# x = sparse.linalg.minres(A, b)[0]
# x = sparse.linalg.qmr(A, b)[0]

tf = time.time()
print("Timing: {}".format(tf-ti))

print("M_gamma_i, sigma_i, M_i_gamma", M_gamma_i.shape, sgma_i.T.shape, M_gamma_i.T.shape)
print("A, b", A.shape, b.shape)
print("x", x.shape)

x = x.reshape((recon_src.N, recon_src.N))
plt.imshow(x, cmap='Spectral_r', origin='Lower')
plt.colorbar()
plt.axis('off')


In [None]:
# srcgrid_mapping; testing antialiasing
# r_fullres: source-grid size (presumably a radius) from the deflections of all
# image pixels, i.e. without a mask ...
_, r_fullres = recon_src.srcgrid_deflections(pixrad=None, mask=None)
print("r_fullres: {}".format(r_fullres))
# ... and r_max: the same restricted to the pixels inside the image mask
_, r_max = recon_src.srcgrid_deflections(pixrad=None, mask=recon_src.image_mask())
print("r_max: {}".format(r_max))
# their ratio is reported as the antialiasing factor
print("f_AA: {}".format(r_fullres/r_max))
print("Src plane pixel resolution: {}".format(r_max/recon_src.M))
print("Img plane pixel resolution: {}".format(recon_src.lensobject.px2arcsec[0]))


In [None]:
# %%skip True
# inverse projection matrix
print("\n# inverse projection matrix")
# only the type and shape of the matrix are inspected here
# ti = time.time()
Mij_p = recon_src.inv_proj_matrix()
# tf = time.time()
# print("Timing: {}".format(tf-ti))
print(type(Mij_p))
print(Mij_p.shape)


In [None]:
# %%skip True
# (inverse of the inverse) projection matrix; TODO: inverse only in an ideal case
print("\n# projection matrix")
# the same matrix feeds the solver tests and the DOF count; only its type and
# shape are inspected here
Mp_ij = recon_src.proj_matrix()
print(type(Mp_ij))
print(Mp_ij.shape)


In [None]:
# %%skip True
# image plane data arrays
print("\n# image plane data arrays")
# NOTE: the global `data` is rebound to a 2D map by the data-plot cell
data = recon_src.d_ij()  # 1d lens plane data
print(type(data))
print(data.shape)
lmap = recon_src.lens_map()  # 2d lens plane data
print(type(lmap))
print(lmap.shape)


In [None]:
# %%skip True
# source plane data arrays
print("\n# source plane data arrays")
# both requested with antialias=True
rsrc = recon_src.d_p(antialias=True)  # 1d source plane data
print(type(rsrc))
print(rsrc.shape)
rsrc_map = recon_src.plane_map(antialias=True)  # 2d source plane data
print(type(rsrc_map))
print(rsrc_map.shape)


In [None]:
# %%skip True
# synthetic image
print("\n# synthetic image")
# lens-plane image re-projected from the reconstructed source (default options)
reproj = recon_src.reproj_map()
print(type(reproj))
print(reproj.shape)


#### Actual data plot

In [None]:
# %%skip True

# Observed lens-plane data as a 2D map. This rebinds the global `data`, which
# the residual cell uses; append `+ dta_noise` to view the noisy version.
data = recon_src.d_ij(flat=False) # + dta_noise
# matplotlib only accepts lowercase 'upper'/'lower' for origin
kw = dict(vmax=data.max(), vmin=data.min(), cmap='Spectral_r', origin='lower')
plt.imshow(data, **kw)
plt.colorbar()
plt.axis('off')
plt.show()


#### Reconstructed source plot

In [None]:
# %%skip True
# recon_src.psf = sparse.diags(np.ones(recon_src.lensobject.data.size))
# switch recon_src to ensemble model 80 (this persists into later cells)
recon_src.chmdl(80)
# recon_src.flush_cache()
kw = dict(method='lsmr', use_psf=True, cached=True, sigma2=sgma2, sigmaM2=sgmaM2)
s = recon_src.plane_map(**kw)  # reconstructed source-plane map
# matplotlib only accepts lowercase 'upper'/'lower' for origin
plt.imshow(s, cmap='Spectral_r', origin='lower')
plt.colorbar()
plt.axis('off')
plt.show()


#### Synthetic image plot

In [None]:
# %%skip True


kw = dict(flat=False, method='lsmr', use_psf=True, sigma2=sgma2, sigmaM2=sgmaM2)
# Synthetic lens image of the reconstructed source; rebinds the global `i`,
# which the residual cell uses.
i = recon_src.reproj_map(**kw)
# matplotlib only accepts lowercase 'upper'/'lower' for origin
plt.imshow(i, cmap='Spectral_r', origin='lower') #, vmax=data.max(), vmin=data.min())
plt.colorbar()
plt.axis('off')
plt.show()


#### Masked data plot

In [None]:
%%skip True

plt.imshow(recon_src.lens_map(mask=True), **kwargs)
plt.colorbar()
plt.axis('off')
plt.show()


#### Residual map plot

In [None]:
# %%skip True
# Residual map: data minus synthetic image plus the simulated noise.
# `data` and `i` are the 2D maps from the data and synthetic-image plot cells.
res = data-i+dta_noise
# Symmetric limits around zero so white marks zero residual in 'bwr'. Use the
# largest absolute residual; res.max() alone clips (or inverts) the scale when
# the residuals are mostly negative. origin='lower' matches the other maps.
res_lim = np.abs(res).max()
plt.imshow(res, cmap="bwr", vmin=-res_lim, vmax=res_lim, origin='lower')
plt.colorbar()
plt.axis('off')
plt.show()


#### Arrival time surface

In [None]:
# %%skip True
# Images (img_plot) overlaid with arrival-time contours of the currently
# selected ensemble model.
model = recon_src.gls.models[recon_src.model_index]
recon_src.gls.img_plot(obj_index=0, color='#fe4365')
recon_src.gls.arrival_plot(model, obj_index=0, only_contours=True, clevels=75, colors=['#603dd0'])


In [None]:
# Convergence (kappa) map of the currently selected model, without contours
model = recon_src.gls.models[recon_src.model_index]
recon_src.gls.kappa_plot(model, obj_index=0, with_contours=False, clevels=20)


#### Residual statistics

In [None]:
# Restore a ReconSrc previously written to reconsrc.pkl by the pickle.dump
# cell below; this replaces the current recon_src.
# NOTE(review): pickle.load can execute arbitrary code — only load files you
# created yourself.
import pickle
if os.path.exists('reconsrc.pkl'):
    with open('reconsrc.pkl', 'rb') as f:
        recon_src = pickle.load(f)
    print("Loaded reconsrc.pkl")


In [None]:
# Rebuild the noise model for the (possibly reloaded) recon_src. sgma2, its
# inverse-variance matrix sgmaM2 and the noise realisation are all rebuilt,
# so the chi2 cells below never mix in an sgmaM2 that is stale relative to
# sgma2 (or undefined, if the earlier noise cell was skipped).
signals, variances = ml[0].flatfield(recon_src.lens_map(), size=0.2)
gain, _ = ml[0].gain(signals=signals, variances=variances)
f = 1./(10*gain)
bias = 0.01*np.max(f * recon_src.lensobject.data)
sgma2 = recon_src.lensobject.sigma2(f=f, add_bias=bias)
# diagonal inverse-variance matrix ordered by flat pixel index
sgmaM2 = sparse.diags(np.array([1./sgma2[recon_src.lensobject.idx2yx(i)]
                                for i in range(sgma2.size)]))
dta_noise = np.random.normal(0, 1, size=recon_src.lensobject.data.shape)
dta_noise = dta_noise * np.sqrt(sgma2)


In [None]:
# chi2 test on the ensemble average
# NOTE(review): chmdl(20) appears to switch to a single ensemble model (cf. the
# commented chmdl call in the recon_src cell), so the "ensemble avg" label may
# be stale — confirm.
print("Sum of squared residuals (ensemble avg)")
recon_src.chmdl(20)
recon_src.flush_cache()
ti = time.time()
resid = recon_src.reproj_chi2(reduced=False, method='lsmr', use_psf=True, cached=True,
                              from_cache=True, save_to_cache=True,
                              noise=dta_noise, sigma2=sgma2, sigmaM2=sgmaM2)
tf = time.time()
# Degrees of freedom: projection-matrix columns minus N_nil. Computed once
# rather than rebuilding proj_matrix() for each print.
dof = recon_src.proj_matrix().shape[1] - recon_src.N_nil
print("Chi2", resid)
print("Time", tf-ti)
print("DOF", dof)
print("Red. chi2", resid/dof)


In [None]:
# %%skip True
print("Sum of squared residuals (for all ensemble models)")
# Per-model chi2 of the synthetic images; N_models presumably caps how many
# ensemble models are evaluated.
_, _, residuals = synth_filter(
    reconsrc=recon_src, percentiles=[],
    reduced=False, nonzero_only=True, method='lsmr',
    from_cache=True, cached=True, save_to_cache=True,
    noise=dta_noise, sigma2=sgma2, sigmaM2=sgmaM2,
    N_models=10,
    save=False, verbose=True,
)


print("0th, 10th, 25th, and 50th percentile values")
# all three cuts in one call; 'higher' returns actual sample values
rhi10, rhi25, rhi50 = np.percentile(residuals, [10, 25, 50], interpolation='higher')
rlo = 0  # fixed lower cut, not a computed percentile
print(rlo, rhi10, rhi25, rhi50)


In [None]:
# Same per-model chi2 computation via synth_filter_mp (presumably run in
# nproc worker processes) for N_models=100; overwrites `residuals`.
mp_options = dict(nproc=2,
                  reduced=False, nonzero_only=True, method='lsmr',
                  from_cache=True, cached=True, save_to_cache=True,
                  noise=dta_noise, sigma2=sgma2, sigmaM2=sgmaM2,
                  N_models=100,
                  save=False, verbose=True)
_, _, residuals = synth_filter_mp(reconsrc=recon_src, percentiles=[], **mp_options)


print("0th, 10th, 25th, and 50th percentile values")
# all three cuts in one call; 'higher' returns actual sample values
rhi10, rhi25, rhi50 = np.percentile(residuals, [10, 25, 50], interpolation='higher')
rlo = 0  # fixed lower cut, not a computed percentile
print(rlo, rhi10, rhi25, rhi50)


In [None]:
# %%skip True
# Save the current recon_src to reconsrc.pkl so the loader cell in the
# residual-statistics section can restore it.
import pickle

with open('reconsrc.pkl', 'wb') as f:
    pickle.dump(recon_src, f)


#### Histogram of an ensemble's residual distribution

In [None]:
# look at noisified data
# lens map plus the simulated noise realisation; origin is not set here, so
# matplotlib's default orientation is used (unlike the other data maps)
d = recon_src.lens_map() + dta_noise
plt.imshow(d)
plt.colorbar()
plt.show()



In [None]:
#%%skip True

# distribution of per-model chi2 values (`residuals` from the last filter run)
plt.hist(residuals, bins=50)
# plt.axvline(rhi10)
# plt.axvline(rhi25)
# plt.axvline(rhi50)
plt.show()


In [None]:
# Locate the worst-fitting (max chi2) and best-fitting (min chi2) models
ichi2max, ichi2min = np.argmax(residuals), np.argmin(residuals)
print("max chi2: {} @ {}".format(residuals[ichi2max], ichi2max))
print("min chi2: {} @ {}".format(residuals[ichi2min], ichi2min))


In [None]:
# look at a selected reconstruction
# switch to the model with the highest chi2 and show its synthetic image
# NOTE(review): assumes the residuals index equals the ensemble model index.
recon_src.chmdl(ichi2max)
d = recon_src.reproj_map()
plt.imshow(d)
plt.colorbar()
plt.axis('off')
plt.show()


In [None]:
# Arrival-time contours of the lowest-chi2 model (the previous cell shows the
# highest-chi2 one).
# NOTE(review): assumes the residuals index equals the index into gls.models.
m = recon_src.gls.models[ichi2min]
recon_src.gls.img_plot(color='#fe4365')
recon_src.gls.arrival_plot(m, only_contours=True, clevels=50, colors=['#603dd0'])
plt.show()


In [None]:
# Pick the 25 lowest- and 25 highest-chi2 models.
# resid_p / invresid_p are sampling weights (proportional to chi2**12 and its
# inverse) used only by the commented-out random-sampling alternative.
# NOTE(review): residuals**12 can overflow to inf for large chi2, and
# 1/resid_p divides by zero if any chi2 is 0.
resid_p = np.asarray(residuals)**12
resid_p /= np.sum(resid_p)
invresid_p = 1/resid_p
invresid_p /= np.sum(invresid_p)
# subsetA =  list(np.random.choice(range(len(residuals)), 25, p=invresid_p))
# subsetB = list((np.random.choice(range(len(residuals)), 25, p=resid_p)))
sortedchi2 = sorted(range(len(residuals)), key=lambda k: residuals[k])  # indices by ascending chi2
subsetA = sortedchi2[:25]  # lowest chi2
subsetB = sortedchi2[-25:]  # highest chi2
print(subsetA)
print(subsetB)


In [None]:
# %%skip True

# Export the lowest-chi2 (A) and highest-chi2 (B) model subsets as state files
chi2A = filter_env(recon_src.gls, subsetA)
chi2B = filter_env(recon_src.gls, subsetB)
export_state(chi2A, name='chi2Asubset.state')
export_state(chi2B, name='chi2Bsubset.state')



#### Filtering and exporting the single state

In [None]:
%%skip True

# Filtering 10, 25, 50 percent
# Indices of models whose chi2 lies strictly between rlo and the cut value;
# the strict '<' excludes the model sitting exactly at the 'higher' percentile.
select10 = [i for i, r in enumerate(residuals) if rhi10 > r > rlo]
select25 = [i for i, r in enumerate(residuals) if rhi25 > r > rlo]
select50 = [i for i, r in enumerate(residuals) if rhi50 > r > rlo]
print("Number of selected models in 10th, 25th and 50th percentile")
print(len(select10))
print(len(select25))
print(len(select50))

In [None]:
%%skip True

dirname = os.path.dirname(state)
basename = ".".join(os.path.basename(state).split('.')[:-1])
save10 = dirname + '/' + basename + '_synthf10.state'
save25 = dirname + '/' + basename + '_synthf25.state'
save50 = dirname + '/' + basename + '_synthf50.state'
print("Names of filtered states...")
print(save10)
print(save25)
print(save50)


In [None]:
%%skip True

filtered_10 = filter_env(recon_src.gls, select10)
filtered_25 = filter_env(recon_src.gls, select25)
filtered_50 = filter_env(recon_src.gls, select50)
export_state(filtered_25, name=save25)
export_state(filtered_50, name=save50)


## Various test snippets

In [None]:
# Noise estimation

# Manual estimate: the largest absolute difference between a pixel and any of
# its four direct neighbours (np.roll wraps around the image borders).
dta = ml[0].data*1
l, r = np.roll(dta, -1, axis=0), np.roll(dta, 1, axis=0)
u, d = np.roll(dta, -1, axis=1), np.roll(dta, 1, axis=1)
snr = max(np.max(np.abs(dta - neighbour)) for neighbour in (l, r, u, d))
print("Manual noise estimation {:2.4f}".format(snr))

# Automated estimate: maximum of the finder's threshold map (sigma=5)
threshold = ml[0].finder.threshold_estimate(ml[0].data, sigma=5)
snr = threshold.max()
print("Autom. noise estimation {:2.4f}".format(snr))
print("Autom. noise estimation x3 {:2.4f}".format(3*snr))


In [None]:
# %%skip True

# Automated masking and edge detection on the lens image.
# Fresh copy of the data, so this cell does not depend on the noise-estimation
# cell having run.
dta = ml[0].data*1
# (threshold-based alternative; needs `threshold` from the noise cell)
# mask = np.abs(dta) >= 0.3*threshold
mask = recon_src.image_mask(f=0.5, n_sigma=5)
# binarise the image with the mask
dta[~mask] = 0
dta[mask] = 1
# Outer boundary of the mask: pixels switched on by one binary dilation.
# binary_dilation is exposed directly in scipy.ndimage; the
# ndimage.morphology namespace is deprecated.
edge_mask = np.abs(dta - ndimage.binary_dilation(dta))
#edge_mask = dta
xsobel = ndimage.sobel(dta, 0)
ysobel = ndimage.sobel(dta, 1)
# edge_mask = np.sign(xsobel**2 + ysobel**2)

print(set(edge_mask.flatten().tolist()))
edges = np.array(np.where(edge_mask)).T  # (n_edge_pixels, 2) coordinates

# Sort along the edge line: not implemented yet. No placeholder loop here,
# since a `for i in ...` loop would rebind the global `i` (the synthetic image
# used by the residual cell).
groups = []
ordering = [0]
current = edges[0] if len(edges) else None  # no IndexError on an empty edge set

plt.imshow(edge_mask, interpolation='none')
# plt.savefig('test1.png')
plt.show()
plt.imshow(xsobel)
# plt.savefig('test2.png')
plt.show()
plt.imshow(ysobel)
# plt.savefig('test3.png')
plt.show()
plt.imshow(xsobel**2+ysobel**2)
# plt.savefig('test4.png')
plt.show()
plt.imshow(dta*2+edge_mask)
# plt.savefig('test5.png')
plt.show()

