In [None]:
import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import torch
from qubic.lib.MapMaking.NN.graphs.healpix_graph import (
    # get_G_masked_by_cov,
    # get_G_masked_by_cov_multifeature,
    get_high_coverage_indexes,
    # get_nside_from_graph,
    healpix_graph,
    healpix_graph_multifeature,
    # healpix_weightmatrix,
    # healpix_weightmatrix_multifeature,
    plot_sky_3d,
    plot_sky_3d_multifeature,
)
from qubic.lib.MapMaking.NN.operators.forward_ops import ForwardOps
from qubic.lib.Instrument.Qacquisition import QubicInstrumentType
from qubic.lib.Instrument.Qinstrument import QubicMultibandInstrument, compute_freq
from qubic.lib.MapMaking.FrequencyMapMaking.Qspectra_component import CMBModel
from qubic.lib.Qdictionary import qubicDict
from qubic.lib.Qsamplings import equ2gal, get_pointing
from qubic.lib.Qscene import QubicScene

In [None]:
# Render matplotlib figures inline below each cell (the default in recent Jupyter; kept explicit).
%matplotlib inline

# QUBIC Parameters

In [None]:
# Load the demo pipeline parameters. The path is relative to the notebook's
# current working directory.
dictfilename = "qubic/qubic/dicts/pipeline_demo.dict"
d = qubicDict()
d.read_from_file(dictfilename)

# Field centre, passed through equ2gal (equatorial -> galactic).
ra_center, dec_center = d["RA_center"], d["DEC_center"]
center = equ2gal(ra_center, dec_center)

In [None]:
# Demo-specific overrides of the loaded configuration. Keys are assigned one at
# a time with `d[key] = value`, in this order.
param_overrides = {
    "nf_recon": 4,
    "nf_sub": 4,
    "MultiBand": True,
    "nside": 128,
    "synthbeam_kmax": 1,
    "synthbeam_fraction": 1,
    "use_synthbeam_fits_file": False,
    "noiseless": True,
    "photon_noise": False,
    "npointings": 100,
    "instrument_type": "UWB",
}
for param_name, param_value in param_overrides.items():
    d[param_name] = param_value

# Short aliases for the band counts used throughout the notebook.
nf_sub = d["nf_sub"]
nf_recon = d["nf_recon"]

# Build Sky

In [None]:
# Synthesize one CMB realisation from the CMBModel spectra (r=0, Alens=1) and
# replicate it identically for every sub-band.
# `seed` is now actually applied: hp.synfast draws from NumPy's global RNG, so
# without seeding it the sky (and every downstream result) changed on each run.
seed = 3
sky_config = {"cmb": seed}
cl_cmb = CMBModel(None).give_cl_cmb(r=0, Alens=1)

np.random.seed(seed)
# .T turns synfast's per-component maps into (npix, ncomp) columns
# (presumably I, Q, U) — later cells index sky maps as [band, pixel, component].
cmb_map = hp.synfast(cl_cmb, d["nside"], new=True, verbose=False).T
sky_map = np.array(d["nf_sub"] * [cmb_map])
print(sky_map.shape)

# QUBIC Instances

In [None]:
# Pointing, scene and instrument objects built from the configuration above.
p = get_pointing(d)
s = QubicScene(d)
q = QubicMultibandInstrument(d)

# Use the configured band counts rather than a hard-coded 4, so that changing
# d["nf_sub"] / d["nf_recon"] keeps the acquisition consistent with sky_map.
multiacquisition = QubicInstrumentType(d, nsub=nf_sub, nrec=nf_recon)

In [None]:
# Sub-band frequency edges / centres (presumably) for the configured filter.
# d["filter_nu"] is divided by 1e9 here, presumably Hz -> GHz.
freq_outputs = compute_freq(d["filter_nu"] / 1e9, d["nf_sub"], d["filter_relative_bandwidth"])
_, nus_edge, nus, _, _, _ = freq_outputs
nus

# Build TOD

In [None]:
# Convolve sky

# Each sub-band map goes through the peak convolution operator of its own
# sub-acquisition; results land in a freshly allocated array of the same shape.
convolved_sky_map = np.zeros(sky_map.shape)
for band_idx, band_map in enumerate(sky_map):
    peak_convolution = multiacquisition.subacqs[band_idx].get_convolution_peak_operator()
    convolved_sky_map[band_idx] = peak_convolution(band_map)

In [None]:
# Build TOD

# Project each convolved sub-band map through its sub-acquisition operator.
# The result is stacked from the operator outputs instead of being preallocated
# as (nf_recon, 992, npointings). That removes the magic detector count, and it
# removes the nf_recon/number-of-sub-acquisitions mismatch: the loop produces one
# entry per sub-acquisition, not per reconstructed band.
TOD_total = np.array(
    [
        np.asarray(subacq.get_operator()(convolved_sky_map[band_idx]))
        for band_idx, subacq in enumerate(multiacquisition.subacqs)
    ],
    dtype=float,
)

# Plot Sky

In [None]:
# High-coverage pixel indexes, converted from RING to NESTED numbering
# (they are later used to index the nest=True sky graph).
coverage_map = multiacquisition._get_coverage()
seen_indexes = hp.ring2nest(d["nside"], get_high_coverage_indexes(coverage_map))

In [None]:
# Graph restricted to the high-coverage pixels.
# NOTE(review): seen_indexes are NESTED, but no `nest=` is passed here (unlike
# G_sky below) — confirm healpix_graph's default ordering matches.
G = healpix_graph(nside=d["nside"], indexes=seen_indexes)

In [None]:
# Plot the high-coverage graph `G` with plot_sky_3d.
plot_sky_3d(G)

In [None]:
# Mollweide view of sub-band 0, component 0 (presumably I) of the convolved sky.
# The map is still in RING ordering here (it is reordered r2n in the next cell).
hp.mollview(convolved_sky_map[0, :, 0], title="Convolved sky map (sub-band 0, component 0)")

In [None]:
# Full-sky graph in NESTED ordering, carrying sub-band 0 / component 0 of the
# convolved map. The map is reordered RING -> NESTED to match the graph.
G_sky = healpix_graph(nside=d["nside"], nest=True)
input_map_reord = hp.reorder(convolved_sky_map[0, :, 0], r2n=True)
G_sky.signal = input_map_reord

In [None]:
# Plot the full-sky graph together with its attached signal.
plot_sky_3d(G_sky)

In [None]:
# Restrict the full-sky graph to the seen (NESTED) pixels; the matching signal
# values are copied over explicitly from the full-sky signal.
G_sky_partial = G_sky.subgraph(seen_indexes)
G_sky_partial.signal = G_sky.signal[seen_indexes]

In [None]:
# Copy the coordinates of the seen pixels onto the subgraph as well (needed for
# plotting it with plot_sky_3d).
G_sky_partial.coords = G_sky.coords[seen_indexes]

In [None]:
# Inspect the copied coordinates (consider `.shape` or a slice if this output is large).
G_sky_partial.coords

In [None]:
# Plot only the observed part of the sky graph.
plot_sky_3d(G_sky_partial)

In [None]:
# Full-sky multi-feature graph in NESTED ordering; its signal is attached in a later cell.
G_sky_multifeature = healpix_graph_multifeature(nside=d["nside"], nest=True)

In [None]:
# Zero-filled buffer with the same shape and dtype as the convolved sky map,
# to receive the NESTED-ordered copy.
input_map_reord_multifeature = np.zeros_like(convolved_sky_map)

In [None]:
# Reorder every component column (presumably I, Q, U) of sub-band 0 from RING
# to NESTED. The column count is taken from the data instead of using three
# copy-pasted lines.
for component_idx in range(convolved_sky_map.shape[-1]):
    input_map_reord_multifeature[0, :, component_idx] = hp.reorder(
        convolved_sky_map[0, :, component_idx], r2n=True
    )

In [None]:
# Attach sub-band 0 (all component columns, NESTED order) as the multi-feature signal.
G_sky_multifeature.signal = input_map_reord_multifeature[0]

In [None]:
# Plot feature column 1 of the multi-feature signal (presumably Stokes Q).
plot_sky_3d_multifeature(G_sky_multifeature, feature_index=1)

# Test

## Forward

In [None]:
# Forward-model operator factory built from `q[0]` (presumably the first
# sub-band instrument), the multiband acquisition and the scene.
forward_ops = ForwardOps(q[0], multiacquisition, s)

In [None]:
# Individual forward-model operators, listed in the order they are applied in
# the forward chain (unit conversion first, bolometer response last).
# NOTE(review): `filter` shadows the Python builtin for the rest of the notebook.
unit, aperture, filter = (
    forward_ops.op_unit_conversion(),
    forward_ops.op_aperture_integration(),
    forward_ops.op_filter(),
)
projection, hwp, pol = (
    forward_ops.op_projection(),
    forward_ops.op_hwp(),
    forward_ops.op_polarizer(),
)
det_inte, transmission, det_resp = (
    forward_ops.op_detector_integration(),
    forward_ops.op_transmission(),
    forward_ops.op_bolometer_response(),
)

In [None]:
# Push sub-band 0 of the (unconvolved) sky through the forward model, stage by
# stage; the innermost call of the nested form comes first.
forward_chain = (unit, aperture, filter, projection[0], hwp, pol, det_inte, transmission, det_resp)
stage_signal = sky_map[0]
for stage in forward_chain:
    stage_signal = stage(stage_signal)
tod_test = stage_signal

In [None]:
# Forward-model TOD for sub-band 0. The title and plt.show() make the figure
# self-describing and suppress the Line2D repr, like the other plotting cells.
plt.plot(tod_test)
plt.title("Forward-model TOD (sub-band 0)")
plt.show()

## Inverse

In [None]:
# Convert the forward-model TOD to a torch tensor for the inverse operators
# (torch.as_tensor may share memory with the source NumPy array).
tod_test = torch.as_tensor(tod_test)

### Transmission

In [None]:
from qubic.lib.AnalyticalSolution.operators.inverse_ops import InverseTransmissionDeterministic, InverseTransmissionTrainable

#### Deterministic

In [None]:
# Invert the transmission with the deterministic inverse built from `q[0]`,
# keep element [0] of its output, then show its shape and plot it.
# (Variable name spelling fixed: previously `inv_trans_deterministric`.)
inv_trans_deterministic = InverseTransmissionDeterministic(q[0])

tod_trans = inv_trans_deterministic(tod_test)[0]
print(tod_trans.shape)

plt.plot(tod_trans)
plt.title("TOD after deterministic inverse transmission")
plt.show()

#### Trainable

In [None]:
# Toy recovery test for the trainable inverse transmission:
# 1. draw synthetic "before" TODs;
# 2. scale them by (optics transmission product x mean detector efficiency);
# 3. fit the model starting from a perturbed efficiency and compare the learned
#    eta with eta_true.
#
# A separate name is used for the single sub-band instrument. Rebinding `q` here
# would clobber the multiband instrument that earlier cells index as `q[0]`, so
# re-running them after this cell would fail.
q_single = QubicMultibandInstrument(d)[0]

T_optics = np.prod(q_single.optics.components["transmission"]).copy()
eta_true = np.mean(q_single.detector.efficiency).copy()
D, Nt = 10, 1000  # synthetic TOD shape: D rows x Nt samples

# Shift the detector efficiency (on this freshly built instrument) after eta_true
# was recorded. NOTE(review): presumably this makes the model start away from
# the true eta — confirm how InverseTransmissionTrainable initialises eta.
ETA_OFFSET = 0.3
q_single.detector.efficiency += ETA_OFFSET

torch.manual_seed(0)
tod_before = torch.randn(D, Nt, dtype=torch.float32)
tod_after = tod_before * (T_optics * eta_true)

model = InverseTransmissionTrainable(qubic_instrument=q_single, mode="global_eta", dtype=torch.float32, device=None)
model = model.to(torch.float32)

print("Before training: eta param =", model.eta.item())
print("Loss before:", torch.mean((model.forward(tod_after) - tod_before)**2).item())

res = model.fit(tod_after, tod_before, lr=5e-2, epochs=500, print_every=50)
print("After training: eta param =", model.eta.item())
print("Expected eta_true:", eta_true)
