In [None]:
import numpy as np
import cupy as cp
import xara
import xaosim
from tqdm import  tqdm
import scipy
import matplotlib.pyplot as plt
from time import sleep
from time import time
import scipy.sparse as sparse

from xara import fft, ifft, shift
import astropy.units as u
from scipy.ndimage import median_filter
from scipy.signal import correlate2d
#from sneks import FigSaver
#fsave = FigSaver("figsave.json")
import saro

In [None]:
%run /home/rlaugier/Documents/kernel/adk/detection_maps_cleaned.py

In [None]:
def shifter(im0,vect, buildmask = True, sg_rad=40.0, verbose=False, nbit=10):
    """Translate an image by ``vect`` pixels and window it with a super-Gaussian.

    The image is zero-padded into a square power-of-two array, median
    subtracted, rolled by the integer part of the shift, windowed, and then
    translated by the remaining sub-pixel fraction with a Fourier phase ramp.
    The result is windowed a second time and rescaled so that its sum equals
    the sum of the windowed image before the sub-pixel translation.

    Parameters
    ----------
    im0 : 2D array
        Input image, indexed (row, column).
    vect : sequence of two numbers
        Shift in pixels: ``vect[0]`` along axis 0 and ``vect[1]`` along
        axis 1. The integer part is applied as ``np.roll(im, -int(shift))``.
    buildmask : bool
        Must be True. When False no window is available: the function prints
        an error message and returns None.
    sg_rad : float
        Radius handed to ``xara.super_gauss`` for the window.
    verbose : bool
        Forwarded to ``xara.find_psf_center``.
    nbit : int
        Unused; kept so existing calls remain valid.

    Returns
    -------
    2D array of shape (sz, sz), where sz is the smallest value among
    64, 128, ..., 2048 that is >= the largest dimension of ``im0``;
    or None when ``buildmask`` is False.

    Raises
    ------
    ValueError
        If the largest dimension of ``im0`` exceeds 2048.
    """
    szh = im0.shape[1] # horiz
    szv = im0.shape[0] # vertic

    largest = np.max(im0.shape) # max dimension of image

    # Smallest supported power-of-two frame that can hold the image.
    for sz in [64, 128, 256, 512, 1024, 2048]:
        if sz >= largest:
            break
    else:
        # Without this guard the padding step below fails with an opaque
        # shape-mismatch error.
        raise ValueError("image dimension %d exceeds the largest supported size (2048)" % largest)

    dz = sz//2.           # image half-size (float)
    if buildmask:
        # NOTE(review): the centre is measured on the unpadded im0 while the
        # window is built on the padded sz x sz grid — confirm both use the
        # same frame when im0 is smaller than sz.
        imcenter = xara.find_psf_center(im0, verbose=verbose)
        sgmask = xara.super_gauss(sz, sz, imcenter[1], imcenter[0], sg_rad)
    else:
        print("ERROR: HERE you should build the relevant mask")
        return

    # Linear phase ramps spanning [-pi, pi) across the frame in each direction.
    x,y = np.meshgrid(np.arange(sz)-dz, np.arange(sz)-dz)
    wedge_x, wedge_y = x*np.pi/dz, y*np.pi/dz
    offset = np.zeros((sz, sz), dtype=complex) # to Fourier-center array

    # insert image in zero-padded array (dim. power of two)
    im = np.zeros((sz, sz))
    orih, oriv = (sz-szh)//2, (sz-szv)//2
    im[oriv:oriv+szv,orih:orih+szh] = im0

    (x0, y0) = (vect[1], vect[0])

    im -= np.median(im)

    # Whole-pixel part of the shift: a plain roll of the array.
    dx, dy = x0, y0
    im = np.roll(np.roll(im, -int(dx), axis=1), -int(dy), axis=0)

    # Keep only the fractional part for the Fourier translation.
    # (Built-in int replaces np.int, which was removed in NumPy 1.24.)
    dx -= int(dx)
    dy -= int(dy)

    windowed = im * sgmask
    mynorm = windowed.sum()

    # array for Fourier-translation
    ramp = shift(dx * wedge_x + dy * wedge_y)
    offset.real, offset.imag = np.cos(ramp), np.sin(ramp)
    shifted = np.abs(shift(ifft(offset * fft(shift(windowed)))))

    # image masking, and set integral to right value
    shifted *= sgmask

    return (shifted * mynorm / shifted.sum())

def intro_companion(image, params, pscale):
    """Add a shifted, attenuated copy of ``image`` to itself.

    Parameters
    ----------
    image : 2D array
        Image that is both the base frame and the template for the copy.
    params : indexable
        ``params[0]`` is the separation, ``params[1]`` the angle in degrees
        and ``params[2]`` the factor the shifted copy is divided by.
    pscale : float
        Plate scale; the separation is divided by it to obtain pixels.

    Returns
    -------
    ``image`` plus the output of ``shifter`` divided by ``params[2]``.
    """
    separation, angle_deg, contrast = params[0], params[1], params[2]
    angle_rad = np.deg2rad(angle_deg)

    # Offset of the copy in pixels along each image axis.
    shift_x = (- separation * np.sin(angle_rad)) / pscale
    shift_y = ( separation * np.cos(angle_rad) ) / pscale

    # shifter expects (axis-0 shift, axis-1 shift) and rolls by the negative
    # of what it is given, hence the sign flip here.
    companion = shifter(image, np.array([-shift_y, -shift_x]))
    return image + companion / contrast

In [None]:
# Release any instrument left over from a previous run of this cell before
# creating a new one. On a fresh kernel `myinstrument` does not exist yet, so
# the attempt fails and only the message is printed.
try:
    myinstrument.start()
    myinstrument.stop()
    myinstrument.close()
except Exception:
    # Exception (not a bare `except:`) so that KeyboardInterrupt and
    # SystemExit still propagate.
    print("no Instrument was found")
myinstrument = xaosim.instrument()

# Keep a handle on the pupil array exposed by the simulated camera.
thepupil = myinstrument.cam.pupil[:,:]
#thepupil = myinstrument.cam.pupil[:,:]

wl = 1.6e-6  # wavelength; assumed to be in metres — TODO confirm
ron = 0.05  # read-out noise; NOTE(review): reassigned to 0. in the simulation-parameter cell
centerpup = thepupil.shape[0]//2  # centre index of the pupil array along axis 0
resolpup = 206 #Careful: this relies on the instrument defaults
pscale = myinstrument.cam.pscale  # camera plate scale (presumably mas/pixel — confirm)
radius = resolpup // 2
padpup = 5  # extra margin kept around the pupil when cropping
resolker = 36
teldiam = 7.92  # telescope diameter, presumably in metres — confirm
binary = True  # NOTE(review): not read by any visible cell (the model is built with binary=False)
ppscale = teldiam / resolpup  # telescope diameter divided by the pupil resolution
# Square crop of the pupil centred on centerpup, half-width radius + padpup.
croppedpup = thepupil[centerpup - radius-padpup:centerpup + radius+padpup,
                      centerpup - radius-padpup:centerpup +radius+padpup]

imsize = 256  # side of the square window built in the simulation-parameter cell

In [None]:
# Radius of the super-Gaussian window: wl / (2*0.23) converted from radians to
# mas, then divided by the plate scale (so in pixels if pscale is mas/pixel —
# TODO confirm). NOTE(review): the printed label calls this value "r0".
p = wl / (2.*0.23) * u.rad.to(u.mas) / pscale
print("A good value for r0 is %.1f"%(p))
# imsize x imsize super-Gaussian window centred on the middle of the frame.
mask = xara.super_gauss(imsize, imsize, imsize/2, imsize/2, p)
wferror = 50.  #RMS wavefront error in nm
Nphot = 1e4  # passed to add_noise, which scales the brightest pixel to this value
Nreal = 10000  # number of phase-screen realizations
ron = 0.  # read-out noise disabled; replaces the 0.05 set in the instrument cell
apod = False  # in the visible cells this only tags output file names

# Hand the pupil back to the camera.
myinstrument.cam.pupil = thepupil

# Definition of the parameters for simulation

In [None]:
## Creation of phase screens
# Draws Nreal phase screens from the instrument's atmosphere model and stores
# them as one float32 array.
sz = myinstrument.atmo.rndarr.shape[0]
myinstrument.atmo.update_rms(wferror)
phs = []
myinstrument.atmo.update_screen()
for i in tqdm(range(Nreal)):
    # Replace the random array with a fresh uniform draw, then regenerate the
    # screen from it.
    # NOTE(review): np.random is not seeded, so screens differ on every run.
    myinstrument.atmo.rndarr = np.random.rand(sz,sz)
    myinstrument.atmo.update_screen()
    phs.append(myinstrument.atmo.shm_phs.get_data().astype(np.float32))
#    plt.figure()
#    plt.imshow(phs[i])
#    plt.colorbar()
#    print("RMS = %.1f"%(np.std(phs[i])))
phs = np.array(phs, dtype=np.float32)
# np.save needs the test_data/ directory to exist already.
np.save("test_data/phs_%d_%.0f_2.npy"%(Nreal,wferror), phs)

#loading the phase screen

#phs2 = np.load("phs_100_50.npy")
#phs = np.load("phs_100_10.npy") + phs2[3]
#plotim(phs2[0], name="Phase screen1 (nm)")
#plotim(phs2[20], name="Phase screen2 (nm)")

In [None]:
# Store the first two phase screens as a small test fixture.
# NOTE(review): the test-data cell further down reads phs[0] through phs[5],
# but only two screens are saved here — confirm this is sufficient.
np.save("test_data/test_case_phs", phs[:2])

In [None]:
# Build the discrete pupil model from the camera pupil and wrap it in a
# kernel-phase object.
model = xara.create_discrete_model(apert=myinstrument.cam.pupil,step=0.35, ppscale=ppscale, binary=False, tmin=0.2)

print("Building original reduced pupil")
thekpo = xara.KPO(array=model)#,bmax=7.5
a = plt.figure()
plt.imshow(myinstrument.cam.pupil*1)
plt.show()
a = thekpo.kpi.plot_pupil_and_uv(cmap="viridis")
# Scale factor handed to extract_cvis_from_img. It uses `wl` rather than a
# second hard-coded 1.6e-6 so it cannot drift from thekpo.CWAVEL below.
# 128 is presumably the side of the cropped frames (2 * hw) — TODO confirm.
# pscale supposed 16.7
m2pix = xara.mas2rad(myinstrument.cam.pscale * 128 / wl)
thekpo.CWAVEL = wl

np.savetxt("test_data/pupil_model.txt", model)

In [None]:
# Reload the phase screens from the file written by the generation cell,
# which saves under test_data/ (the original path omitted the directory).
phs = np.load("test_data/phs_%d_%.0f_2.npy"%(Nreal,wferror))

## Loading an apodizing phase plate
The apodizing phase plates are built using a modified Gerchberg-Saxton algorithm (see [Pupil phase apodization for achromatic imaging of extra-solar planets](https://www.researchgate.net/publication/241486004_Pupil_phase_apodization_for_achromatic_imaging_of_extra-_solar_planets)).

In [None]:
def add_noise(image, Nphot=1e6, ron=1.):
    """Return a noisy version of ``image`` with read-out and photon noise.

    Parameters
    ----------
    image : array
        Noise-free image. It is rescaled so its maximum equals ``Nphot``.
    Nphot : float
        Value assigned to the brightest pixel before drawing Poisson counts.
    ron : float
        Standard deviation of the zero-mean Gaussian read-out noise.

    Returns
    -------
    Array with the shape of ``image``: Gaussian noise plus Poisson counts,
    with negative values clipped to zero.
    """
    # The Gaussian frame is drawn before the Poisson frame; this order fixes
    # how the global random stream is consumed.
    read_noise = np.random.normal(loc=0, scale=ron, size=image.shape)

    peak = np.max(image)
    expected_counts = image * Nphot / peak
    photon_counts = np.random.poisson(lam=expected_counts)

    return np.clip(read_noise + photon_counts, 0, np.inf)

In [None]:

# Noise-only Monte Carlo: for every phase screen, take a simulated snapshot,
# add detector noise and extract the complex visibilities.
# NOTE(review): `cropapp` is not defined in any visible cell; it must come
# from elsewhere (presumably the apodizing phase plate) — confirm.
# NOTE(review): an identical cell appears again further down the notebook.
cleanscreen = np.zeros_like(phs[0])
x0, y0, hw = 320//2, 256//2, 128//2
cleanimages = []
parimages = []
cvissclean = []
cvisspar = []
cvissnosig = []
for i in tqdm(range(Nreal)):
    #myinstrument.atmo.shm_phs.set_data((cleanscreen + cropapp).astype(np.float64))
    #therawclean = myinstrument.snap()[:,x0-128:x0+128]
    #compagim = intro_companion(therawclean, params=params[i],pscale=pscale)
    #cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
    #cleanimages.append(cropped)
    #cvissclean.append(thekpo.extract_cvis_from_img(cropped,m2pix,method="LDFT1"))
    
    # Load screen i (plus cropapp) into the simulator, then snap a frame.
    myinstrument.atmo.shm_phs.set_data((phs[i] + cropapp).astype(np.float64))
    # Keep 256 columns centred on x0.
    therawpar = myinstrument.snap()[:,x0-128:x0+128]
    # 2*hw square crop; y0 is used as the centre on both axes.
    cropped = therawpar[y0-hw:y0+hw,y0-hw:y0+hw]
    aberrated = add_noise(cropped, Nphot=Nphot, ron=ron)
    cvissnosig.append(thekpo.extract_cvis_from_img(aberrated,m2pix,method="LDFT1"))
    #compagim = intro_companion(therawpar, params=params[i],pscale=pscale)
    #cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
    #parimages.append(cropped)
    #cvisspar.append(thekpo.extract_cvis_from_img(cropped, m2pix, method="LDFT1"))
# Only cvissnosig is filled by the loop; the other four lists stay empty
# because the code that fed them is commented out.
parimages = np.array(parimages)
cleanimages = np.array(cleanimages)
cvissclean = np.array(cvissclean)
cvisspar = np.array(cvisspar)
cvissnosig = np.array(cvissnosig)

In [None]:
# Apply the kernel matrix KPM to the visibility phases, then estimate the
# covariance of the result (assumes one row of cvissnosig per realization —
# TODO confirm).
kerns = thekpo.kpi.KPM.dot(np.angle(cvissnosig).T).T
covk = np.cov(kerns.T)
a = plt.figure()
plt.imshow(covk)
plt.colorbar()
plt.show()
# The file name records the run settings (Nreal, wferror, apod).
np.save("test_data/covk_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,apod),covk)

In [None]:
# Whitening matrix: matrix square root of the inverse kernel covariance.
W = scipy.linalg.sqrtm(np.linalg.inv(covk))
# NOTE(review): `Mp` is not defined in any visible cell; on a fresh kernel
# this raises NameError unless the %run script provides it — confirm.
thekpo.Mp = Mp

# Generating the test data

In [None]:
# Build the test images: one calibrator without companion, one target with a
# single injected companion, and a series of targets whose companion angle is
# offset by each value in detpas.
# NOTE(review): `cropapp` is not defined in any visible cell — confirm source.
Ndata=2  # NOTE(review): not read by any visible cell
cleanscreen = np.zeros_like(phs[0])
x0, y0, hw = 320//2, 256//2, 128//2

# (separation, angle in degrees, contrast divisor), as read by intro_companion.
params_single = np.array([75., 127., 10])

# Angle offsets added to the base angle of 127 for the series.
detpas = np.array([0., 20., 32., 40.])
params_series = np.array([np.array([75., 127.+detpa, 10]) for detpa in detpas])

# Calibrator: phase screen 0, no companion.
myinstrument.atmo.shm_phs.set_data((phs[0] + cropapp).astype(np.float64))
therawpar = myinstrument.snap()[:,x0-128:x0+128]
cropped = therawpar[y0-hw:y0+hw,y0-hw:y0+hw]
calibrator = add_noise(cropped, Nphot=Nphot, ron=ron)
# Single target: phase screen 1, companion injected before cropping.
myinstrument.atmo.shm_phs.set_data((phs[1] + cropapp).astype(np.float64))
therawpar = myinstrument.snap()[:,x0-128:x0+128]
compagim = intro_companion(therawpar, params=params_single,pscale=pscale)
cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
tarimage_single = add_noise(cropped, Nphot=Nphot, ron=ron)

# Series targets: phase screens 2, 3, ... — one per entry of detpas.
tarimage_series =[]
for i in range(detpas.shape[0]):
    myinstrument.atmo.shm_phs.set_data((phs[2+i] + cropapp).astype(np.float64))
    therawpar = myinstrument.snap()[:,x0-128:x0+128]
    compagim = intro_companion(therawpar, params=params_series[i],pscale=pscale)
    cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
    tarimage_series.append(add_noise(cropped, Nphot=Nphot, ron=ron))
tarimage_series = np.array(tarimage_series)


In [None]:
# Write the generated test images and the parameters used to create them.
# np.save appends ".npy" to each name; test_data/ must already exist.
np.save("test_data/calibrator",calibrator)
np.save("test_data/tarimage_single",tarimage_single)
np.save("test_data/tarimage_series",tarimage_series)
np.save("test_data/params_single", params_single)
np.save("test_data/params_series", params_series)
np.save("test_data/detpas", detpas)


In [None]:
# Noise-only Monte Carlo, second copy: the statements match an identical cell
# earlier in the notebook. Running it redraws the noise and overwrites
# cvissnosig and the other result arrays.
# NOTE(review): `cropapp` is not defined in any visible cell — confirm source.
cleanscreen = np.zeros_like(phs[0])
x0, y0, hw = 320//2, 256//2, 128//2
cleanimages = []
parimages = []
cvissclean = []
cvisspar = []
cvissnosig = []
for i in tqdm(range(Nreal)):
    #myinstrument.atmo.shm_phs.set_data((cleanscreen + cropapp).astype(np.float64))
    #therawclean = myinstrument.snap()[:,x0-128:x0+128]
    #compagim = intro_companion(therawclean, params=params[i],pscale=pscale)
    #cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
    #cleanimages.append(cropped)
    #cvissclean.append(thekpo.extract_cvis_from_img(cropped,m2pix,method="LDFT1"))
    
    # Load screen i (plus cropapp) into the simulator, then snap a frame.
    myinstrument.atmo.shm_phs.set_data((phs[i] + cropapp).astype(np.float64))
    # Keep 256 columns centred on x0.
    therawpar = myinstrument.snap()[:,x0-128:x0+128]
    # 2*hw square crop; y0 is used as the centre on both axes.
    cropped = therawpar[y0-hw:y0+hw,y0-hw:y0+hw]
    aberrated = add_noise(cropped, Nphot=Nphot, ron=ron)
    cvissnosig.append(thekpo.extract_cvis_from_img(aberrated,m2pix,method="LDFT1"))
    #compagim = intro_companion(therawpar, params=params[i],pscale=pscale)
    #cropped = compagim[y0-hw:y0+hw,y0-hw:y0+hw]
    #parimages.append(cropped)
    #cvisspar.append(thekpo.extract_cvis_from_img(cropped, m2pix, method="LDFT1"))
# Only cvissnosig is filled by the loop; the other four lists stay empty.
parimages = np.array(parimages)
cleanimages = np.array(cleanimages)
cvissclean = np.array(cvissclean)
cvisspar = np.array(cvisspar)
cvissnosig = np.array(cvissnosig)

In [None]:
# Compute a resol x resol contrast map and store it together with its log10.
resol = 80
step = 10.
#totalrot.append(dtheta)
# build_grid is not defined in the visible cells; presumably it comes from the
# script loaded with %run near the top of the notebook — confirm.
pgrid = build_grid(resol, step)
# NOTE(review): W=None is passed although a whitening matrix W is computed in
# an earlier cell — confirm this is intended.
conts = thekpo.get_sadk_contrast_cpu(pgrid,0.01, 0.95, wl, W=None, dthetas=[0]).reshape((resol, resol))
np.save("test_data/detection_map_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,apod),conts)
logcontrast = np.log10(conts)

In [None]:
%run "~/Documents/kernel/tools/colormaps/kernel_cm.py"
# Compare the apodized and classical detection maps on a common colour scale.
# The maps are read from test_data/, where the detection-map cell writes them
# (the original paths omitted the directory). Both files must exist, i.e. the
# detection-map cell must have been run once with apod=True and once with
# apod=False.
logcontrast_apod = np.log10(np.load("test_data/detection_map_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,True)))
logcontrast_classical = np.log10(np.load("test_data/detection_map_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,False)))
amin = 1.
amax = np.nanmax([logcontrast_apod, logcontrast_classical])
a = plt.figure(figsize=(12,6))
plt.subplot(121)
plt.imshow(logcontrast_apod, cmap="inferno", vmin=amin, vmax=amax)
plt.colorbar()
plt.title("Apodized logcontrast")
plt.subplot(122)
plt.imshow(logcontrast_classical, cmap="inferno", vmin=amin, vmax=amax)
plt.colorbar()
plt.title("Normal logcontrast")
plt.show()
# NOTE(review): `fsave` is only created in commented-out lines of the import
# cell, so this call raises NameError on a fresh kernel — restore the
# FigSaver setup or replace the call.
fsave.save(a,"Both_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,apod))
a = plt.figure()
# `bbr` is presumably a colormap defined by the kernel_cm.py script run at the
# top of this cell — confirm.
plt.imshow(logcontrast_apod - logcontrast_classical, cmap=bbr, vmin=-1., vmax=1.)
plt.colorbar()
plt.title("Compared log-contrast normalized by peak count")
plt.show()
fsave.save(a,"Compared_%d_%.0f_2_apod_%r.npy"%(Nreal,wferror,apod))

# NOTE(review): cvissclean and cvisspar are empty arrays in the visible cells
# (the code that filled them is commented out), and `params` is not defined in
# any visible cell — confirm before relying on these files.
np.save("phasapod_cvissclean_50", cvissclean)
np.save("phasapod_cvisspar_50", cvisspar)
np.save("phasapod_cvissnosig", cvissnosig)
np.save("phasapod_params", params)
#cvissclean = np.load("phasapod_cvissclean_25.npy")
#cvisspar = np.load("phasapod_cvisspar_25.npy")
#cvissnosig = np.load("phasapod_cvissnosig.npy")
#params = np.load("phasapod_params.npy")