In [None]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

# 1. Create example data
## 1.1 Set up a 3T kernel-nuller

In [None]:
import sympy as sp

# Combiner matrix for a 3T kernel nuller (from github.com/rlaugier/kernuller).
# The matrix is kept as sympy source text and evaluated numerically below.
mat_3T_txt = """
Matrix([
[sqrt(3)/3,                sqrt(3)/3,                sqrt(3)/3],
[sqrt(3)/3,  sqrt(3)*exp(2*I*pi/3)/3, sqrt(3)*exp(-2*I*pi/3)/3],
[sqrt(3)/3, sqrt(3)*exp(-2*I*pi/3)/3,  sqrt(3)*exp(2*I*pi/3)/3]])
"""
combiner_s = sp.sympify(mat_3T_txt)
combiner = np.array(sp.N(combiner_s), dtype=np.complex128)

# Kernel matrix: a single row that differences the 2nd and 3rd combiner outputs
# (presumably the two dark outputs -- TODO confirm against the nifits convention).
kmat = np.array([[0.0, 1.0, -1.0],])

# Array baseline used to place the collectors
baseline = 15  # in meter
# Collector diameter
telescope_diam = 3.0  # in meter

# Rotation angles over the observation (both 0 and 2*pi are sampled)
n_sample_time = 100
rotation_angle = np.linspace(0., 2*np.pi, n_sample_time)  # in rad

# Initial collector positions, shape (2, n_telescopes): first row x, second row y.
collector_positions_init = np.array(((-baseline/2, baseline/2, 0),
                                     (0, 0, baseline/2)))

# Stack of 2x2 rotation matrices, shape (2, 2, n_sample_time)
rotation_matrix = np.array(((np.cos(rotation_angle), -np.sin(rotation_angle)),
                            (np.sin(rotation_angle), np.cos(rotation_angle))))

# Swapping the first and last axes gives shape (n_sample_time, 2, 2); note this
# also transposes each 2x2 matrix. The result has shape
# (n_sample_time, 2, n_telescopes).
collector_position = np.dot(np.swapaxes(rotation_matrix, -1, 0), collector_positions_init)

# Observing wavelengths
n_wl_bin = 5
wl_bins = np.linspace(4.0e-6, 18.0e-6, n_wl_bin)  # in meter

# Collector area
scaled_area = 1  # in meter^2

## 1.2 Compute planet signal

## 1.3 Add random noise and plot the signal

In [None]:
# NOTE(review): this whole cell is commented out. `planet_signal` is not defined
# anywhere in the visible part of this notebook (section 1.2 has no code cell),
# so un-commenting these lines as-is would raise a NameError.
# # star_signal = np.random.normal([np.mean(planet_signal)*10, np.mean(planet_signal)*1000], [np.sqrt(np.mean(planet_signal)*10), np.sqrt(np.mean(planet_signal)*1000)], planet_signal.shape)
# star_signal = np.zeros_like(planet_signal)
# signal = planet_signal + star_signal

In [None]:
# NOTE(review): this whole cell is commented out. It plots `signal`, which is
# not defined anywhere in the visible part of this notebook, so it cannot be
# re-enabled until that array exists.
# plt.figure()
# for i, awl in enumerate(wl_bins):
#     plt.plot(signal[i, 1, :], linestyle="--", label=f"Bright, wl {wl_bins[i]:.2e}")
#     plt.plot(signal[i, 0, :], linestyle="-", label=f"Dark, wl {wl_bins[i]:.2e}")
# plt.legend(fontsize=7)
# plt.ylabel('Signal in a.u.')
# plt.xlabel('Time in a.u.')
# plt.show()

# 2. Initialize a nifits object
## 2.1 Showcasing a list of NIFITS extensions

In [None]:

import nifits.io.oifits as io
# Print the docstring of every extension class registered in the io module,
# framed by separator lines.
for aclass in io.NIFITS_EXTENSIONS:
    a = io.getclass(aclass)
    print(
        "",
        f"{aclass}  :",
        "---------------",
        a.__doc__,
        "==============================================================",
        sep="\n",
    )

In [None]:

# Replicate the combiner matrix along a new leading axis with one entry per
# wavelength bin, giving shape (n_wl_bin, n_outputs, n_telescopes).
per_bin_ones = np.ones_like(wl_bins)[:, None, None]
ni_catm = io.NI_CATM(data_array=combiner[None, :, :] * per_bin_ones)

In [None]:
# Wrap the kernel matrix defined in the setup cell into its NIFITS extension object.
mykmat = io.NI_KMAT(data_array=kmat)

In [None]:
from copy import copy

# Work on a copy so the library's default FOV header is left untouched, then
# override the telescope diameter and its unit.
my_FOV_header = copy(io.NI_FOV_DEFAULT_HEADER)
for keyword, value in (("FOV_TELDIAM", telescope_diam),
                       ("FOV_TELDIAM_UNIT", "m")):
    my_FOV_header[keyword] = value

# Build the field-of-view extension for the observing wavelengths and the
# number of time samples.
ni_fov = io.NI_FOV.simple_from_header(
    header=my_FOV_header,
    lamb=wl_bins,
    n=n_sample_time,
)

In [None]:
# Create an empty target table and register a single test target.
oi_target = io.OI_TARGET.from_scratch()
target_spec = dict(target='Test Target', raep0=14.3, decep0=-60.4)
oi_target.add_target(**target_spec)

In [None]:
from astropy.table import Table, Column
from astropy.time import Time

# --- Raw per-frame quantities: one row per time sample -----------------------
n_telescopes = combiner.shape[1]
total_obs_time = 10*3600      # s
times_relative = np.linspace(0, total_obs_time, n_sample_time)
dateobs = Time("2035-06-23T00:00:00.000") + times_relative*u.s
mjds = dateobs.to_value("mjd")
# Elapsed time since the first frame, in seconds
seconds = (dateobs - dateobs[0]).to_value("s")
# Same aperture indices (0 .. n_telescopes-1) repeated for every frame
app_index = np.arange(n_telescopes)[None, :]*np.ones(n_sample_time)[:, None]
target_ids = 0 * np.ones(n_sample_time)
# Integration time of each frame, taken as the local spacing between time stamps
int_times = np.gradient(seconds)
mod_phas = np.ones((n_sample_time, n_wl_bin, n_telescopes), dtype=complex)
# (n_sample_time, 2, n_telescopes) -> (n_sample_time, n_telescopes, 2)
appxy = collector_position.transpose((0, 2, 1))
# Collecting area of a circular aperture of diameter `telescope_diam`
arrcol = np.ones((n_sample_time, n_telescopes)) * np.pi*telescope_diam**2 / 4
fov_index = np.ones(n_sample_time)

# --- Wrap each quantity into a named, typed table column ---------------------
app_index         = Column(data=app_index, name="APP_INDEX",
                   unit=None, dtype=int)
target_id         = Column(data=target_ids, name="TARGET_ID",
                   unit=None, dtype=int)
times_relative    = Column(data=seconds, name="TIME",
                   unit="", dtype=float)
mjds              = Column(data=mjds, name="MJD",
                   unit="day", dtype=float)
# Fix: INT_TIME must carry the per-frame integration times computed above,
# not the cumulative elapsed time `seconds`.
int_times         = Column(data=int_times, name="INT_TIME",
                   unit="s", dtype=float)
mod_phas          = Column(data=mod_phas, name="MOD_PHAS",
                   unit="rad", dtype=complex)
appxy             = Column(data=appxy, name="APPXY",
                   unit="m", dtype=float)
arrcol            = Column(data=arrcol, name="ARRCOL",
                   unit="m^2", dtype=float)
fov_index         = Column(data=fov_index, name="FOV_INDEX",
                   unit=None, dtype=int)

# --- Assemble the table and build the NI_MOD extension -----------------------
mymod_table = Table()
mymod_table.add_columns((app_index, target_id, times_relative, mjds,
                        int_times, mod_phas, appxy, arrcol, fov_index))
mynimod = io.NI_MOD(mymod_table)

## 2.2 Creating the NIFITS parent object

In [None]:

from astropy.io import fits

# Wavelength table: bin centres (EFF_WAVE) and bin widths (EFF_BAND), the
# latter estimated from the spacing between neighbouring centres.
wl_table = Table(
    data=np.column_stack((wl_bins, np.gradient(wl_bins))),
    names=("EFF_WAVE", "EFF_BAND"),
    dtype=(float, float),
)
oi_wavelength = io.OI_WAVELENGTH(data_table=wl_table)

# Gather every extension built above into the parent NIFITS object.
myheader = fits.Header()
extensions = dict(
    ni_catm=ni_catm,
    ni_fov=ni_fov,
    oi_target=oi_target,
    oi_wavelength=oi_wavelength,
    ni_mod=mynimod,
    ni_kmat=mykmat,
)
mynifit = io.nifits(header=myheader, **extensions)

# Show which attributes the parent object carries.
vars(mynifit).keys()

## 2.3 Saving and opening

In [None]:
import os

# Make sure the output directory exists so the write below does not depend on
# a pre-existing `log/` folder in a fresh checkout.
os.makedirs("log", exist_ok=True)

# Serialise the NIFITS object to disk; any existing file is overwritten.
myhdu = mynifit.to_nifits(filename="log/testfits.nifits",
                          static_only=False,
                          writefile=True,
                          overwrite=True)
# Display the primary header of the returned HDU list.
myhdu[0].header

In [None]:
# Round trip: reopen the file written above and rebuild a NIFITS object from it.
# The context manager closes the file once the object has been built.
with fits.open("log/testfits.nifits") as reopened_hdul:
    newfits = io.nifits.from_nifits(reopened_hdul)

# Display the header of the reloaded object.
newfits.header

|  Column      |  format                   |  unit            | Empty |
|:------------:|:------------------------- |:---------------- | ---- | 
|  `APP_INDEX` |  $n_a \times$ int         |  NA              |     |
|  `TARGET_ID` |  int                      |  NA              |     |
|  `TIME`      |  float                    |  s               |     |
|  `MJD`       |  float                    |  day             |     |
|  `INT_TIME`  |  float                    |  s               |     |
|  `MOD_PHAS`  |  $n_{\lambda}, n_a $ cpx  |                  |     |
|  `APPXY`     |  $n_a, 2 $ float          | m               |      |
|  `ARRCOL`    |  $n_a $ float             |  $\mathrm{m}^2$  |     |
|  `FOV_INDEX` |  $n_a $ int               |  NA              |     |

# 3. Testing the back end

In [None]:
import nifits.backend as be

In [None]:
# Build a backend directly from the NIFITS object. `mybe` is not referenced
# again in the visible part of this notebook.
mybe = be.NI_Backend(mynifit)
# Alternative construction: start from an empty backend and attach the NIFITS
# object afterwards. `abe` is the backend used by the following cells.
abe = be.NI_Backend()
abe.add_instrument_definition(mynifit)
# abe.add_observation_data(mynifit)

In [None]:
# Prepare the backend's field-of-view functions before evaluating them.
abe.create_fov_function_all()

# Square sky grid: 100 x 100 points spanning +/- halfrange (mas) on both axes.
halfrange = 1000
halfrange_rad = halfrange*u.mas.to(u.rad)
map_extent = [-halfrange, halfrange] * 2
xs = np.linspace(-halfrange_rad, halfrange_rad, 100)
xx, yy = np.meshgrid(xs, xs)

# Evaluate the FOV phasor at every grid point (coordinates passed in radians).
flat_x = xx.flatten()
flat_y = yy.flatten()
map_fov = abe.nifits.ni_fov.xy2phasor(flat_x, flat_y)

In [None]:
# Plot the FOV amplitude on the sky grid for the first and last entries along
# axis 1 of `map_fov` (presumably the shortest and longest wavelength bins --
# TODO confirm the axis order returned by `xy2phasor`).
for axis1_index in (0, -1):
    fov_amplitude = np.abs(map_fov[0, axis1_index, :].reshape(xx.shape))
    plt.figure(dpi=100)
    # origin="lower": meshgrid rows run with increasing y, and `contour`
    # places Z[0, 0] at the lower-left corner. Without it the image would be
    # flipped vertically relative to the contour.
    plt.imshow(fov_amplitude, extent=map_extent, origin="lower")
    plt.colorbar()
    # Half-maximum contour of the amplitude
    plt.contour(fov_amplitude, levels=(0.5,), extent=map_extent)
    plt.xlabel("x [mas]")
    plt.ylabel("y [mas]")
    plt.show()

In [None]:
# Diffraction scale lambda/D for each wavelength bin, converted from rad to mas.
(wl_bins/telescope_diam)*u.rad.to(u.mas)

In [None]:
xys_mas = np.random.uniform(low=-500, high=+500, size=(2,10000)) 
xys = xys_mas * u.mas.to(u.rad)
# xysm = xys[:,:]
%time z = abe.get_all_outs(xys[0,:], xys[1,:], kernels=False)

In [None]:
# Values along the first axis of `z` for one fixed sky point (index 1000),
# at index 1 on both intermediate axes.
single_point_trace = z[:, 1, 1, 1000]

plt.figure()
plt.plot(single_point_trace)
plt.show()

### Note the shape of the output.

In [None]:
# Print the docstring of `get_all_outs` for reference on its output.
print(be.NI_Backend.get_all_outs.__doc__)

In [None]:
# Kernel outputs and FOV phasor at the same random sky positions as `z`.
kz = abe.get_all_outs(xys[0,:], xys[1,:], kernels=True)
x_inj = abe.nifits.ni_fov.xy2phasor(xys[0,:], xys[1,:])

# One scatter map per quantity, coloured by its value at each sky position:
# a raw output, a kernel output, and the FOV phasor amplitude.
panels = (
    (z[0,-1,1,:], "viridis"),
    (kz[0,-1,0,:], "coolwarm"),
    (np.abs(x_inj[0,-1,:]), "viridis"),
)
for values, colormap in panels:
    plt.figure()
    plt.scatter(xys_mas[0,:], xys_mas[1,:], c=values, cmap=colormap, s=6)
    plt.colorbar()
    plt.show()

print(np.max(z))

In [None]:
# Scratch check: the outer product of two random vectors, shape (10, 5).
a = np.random.normal(size=10)
b = np.random.normal(size=5)
np.outer(a, b)

In [None]:
# Shape of the APPXY data held by the backend's NI_MOD extension
# (presumably (n_sample_time, n_telescopes, 2) as built above -- TODO confirm).
abe.nifits.ni_mod.appxy.shape

# Handling point collections

Further down the line, point collections can be "summed" together with the `+` operator to create arbitrarily sampled maps.

In [None]:
# Print the class docstring of `PointCollection` for reference.
print(be.PointCollection.__doc__)

### A plain Cartesian grid using `PointCollection.from_centered_square_grid`

In [None]:
%%time
acollec = be.PointCollection.from_centered_square_grid(600., 100, md=np)
z = abe.get_all_outs(*acollec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*acollec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=5)
plt.colorbar()
plt.show()

x_inj = abe.nifits.ni_fov.xy2phasor(*acollec.coords_rad)
plt.figure()
plt.scatter(*acollec.coords, c=np.abs(x_inj[0,-1,:]), cmap="viridis", s=5)
plt.colorbar()
plt.show()
print(np.max(z))


### A point-sampled disk using `PointCollection.from_uniform_disk`
N.B. This merges well with `scipy.interpolate.griddata` 

In [None]:
%%time
acollec = be.PointCollection.from_uniform_disk(600., 600)
z = abe.get_all_outs(*acollec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*acollec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=80)
plt.colorbar()
plt.show()

x_inj = abe.nifits.ni_fov.xy2phasor(*acollec.coords_rad)
plt.figure()
plt.scatter(*acollec.coords, c=np.abs(x_inj[0,-1,:]), cmap="viridis", s=80)
plt.colorbar()
plt.show()
print(np.max(z))

print("Resampling with griddata:")

from scipy.interpolate import griddata
agrid = be.PointCollection.from_centered_square_grid(600., 512, md=np)
interped = griddata(acollec.coords, z[0,-1,0,:], agrid.coords_shaped, method="nearest")
plt.figure()
plt.imshow(interped, cmap="coolwarm")
plt.colorbar()
plt.show()

### A rectangular grid built from explicit axes using `PointCollection.from_grid`

In [None]:
%%time
acollec = be.PointCollection.from_grid(np.linspace(0, 600, 100), np.linspace(-200,500, 100))
z = abe.get_all_outs(*acollec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*acollec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=5)
plt.colorbar()
plt.xlim(-600, 600)
plt.ylim(-600, 600)
plt.show()

x_inj = abe.nifits.ni_fov.xy2phasor(*acollec.coords_rad)
plt.figure()
plt.scatter(*acollec.coords, c=np.abs(x_inj[0,-1,:]), cmap="viridis", s= 5)
plt.colorbar()
plt.xlim(-600, 600)
plt.ylim(-600, 600)
plt.show()
print(np.max(z))