In [None]:
import os

import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam
from scipy.stats import norm

from qubic.lib.Instrument.Qacquisition import QubicAcquisition
from qubic.lib.Instrument.Qinstrument import QubicInstrument, compute_freq
from qubic.lib.MapMaking.FrequencyMapMaking.Qspectra_component import CMBModel
from qubic.lib.MapMaking.NeuralNetworkMapMaking.operators.forward_ops import ForwardOps
from qubic.lib.Qdictionary import qubicDict
from qubic.lib.Qsamplings import equ2gal, get_pointing
from qubic.lib.Qscene import QubicScene

In [None]:
# Render matplotlib figures inline in the notebook (default in modern Jupyter; kept explicit).
%matplotlib inline

# QUBIC Parameters

In [None]:
# Demo pipeline configuration; the path is relative to the current working directory.
dictfilename = "qubic/qubic/dicts/pipeline_demo.dict"

d = qubicDict()
d.read_from_file(dictfilename)

# Centre of the observed field, converted from equatorial (RA, Dec) to galactic coordinates.
center = equ2gal(
    d["RA_center"],
    d["DEC_center"],
)

In [None]:
# Override the demo configuration: single reconstructed band, no multiband,
# noiseless (no photon noise), analytic synthesized beam, small simulation size.
_config_overrides = {
    "nf_recon": 1,
    "MultiBand": False,
    "synthbeam_kmax": 1,
    "synthbeam_fraction": 1,
    "noiseless": True,
    "photon_noise": False,
    "use_synthbeam_fits_file": False,
    "npointings": 1000,
    "nside": 128,
}
for _override_key, _override_value in _config_overrides.items():
    d[_override_key] = _override_value

# Number of frequency sub-bands, read straight from the configuration (not overridden above).
nf_sub = d["nf_sub"]

# Build Sky

In [None]:
seed = 3
cl_cmb = CMBModel(None).give_cl_cmb(r=0, Alens=1)

# hp.synfast draws from NumPy's global RNG, so seed it right before the draw to
# make the simulated CMB realization reproducible (previously `seed` was unused).
np.random.seed(seed)

# One realization, transposed to (npix, n_maps). The nf_sub-fold replication is
# collapsed immediately by [0], so the result is a single contiguous map array.
sky_map = np.array(d["nf_sub"] * [hp.synfast(cl_cmb, d["nside"], new=True, verbose=False).T])[0]
print(sky_map.shape)

In [None]:
# Mollweide view of the first column of sky_map (Stokes I if synfast returned I, Q, U — confirm via the shape printed above).
hp.mollview(sky_map[..., 0], title="Input CMB sky map (component 0)")

# QUBIC Instance

In [None]:
# Instrument model: pointing strategy, sky scene, instrument, and the
# acquisition object combining them. Construction order is kept as-is.
p = get_pointing(d)
s = QubicScene(d)
q = QubicInstrument(d)
acquisition = QubicAcquisition(q, p, s, d)

# Sub-band frequencies. filter_nu is divided by 1e9 — presumably Hz -> GHz.
filter_nu_ghz = d["filter_nu"] / 1e9
freq_outputs = compute_freq(filter_nu_ghz, d["nf_sub"], d["filter_relative_bandwidth"])
_, nus_edge, nus, _, _, _ = freq_outputs

In [None]:
# Full acquisition operator (its factors are used individually further down).
H = acquisition.get_operator()

# Convolve the input sky with the instrument peak beam.
convolution = acquisition.get_convolution_peak_operator()
convolved_maps = convolution(sky_map)

# Coverage of the scanning strategy (presumably a per-pixel hit map — not used below).
cov = acquisition.get_coverage()

In [None]:
# Factory for the individual instrument operators (unit conversion, aperture and
# detector integration, filter, HWP, polarizer, transmission, bolometer response)
# applied one at a time below. Presumably captures the current state of `q`;
# it is rebuilt later after `q` is modified.
forward_ops = ForwardOps(q, acquisition, s)

# Sequential operator combinations for training

In [None]:
# Sky-side operator chain, one operator at a time so each stage can be inspected.
# Names appear to spell the operators applied so far, right to left
# (e.g. FATUs = F(A(T(U(sky))))).
unit_conversion_op = forward_ops.op_unit_conversion()
Us = unit_conversion_op(convolved_maps)

# No operator is applied for the "T" stage here; TUs is simply an alias of Us.
TUs = Us

aperture_op = forward_ops.op_aperture_integration()
ATUs = aperture_op(TUs)

filter_op = forward_ops.op_filter()
FATUs = filter_op(ATUs)
print(FATUs.shape)

In [None]:
# Last operand of the composed acquisition operator H — presumably the
# pointing/projection operator P (the first one applied). TODO confirm.
projection_op = H.operands[-1]
PFATUs = projection_op(FATUs)
print(PFATUs.shape)

hwp_op = forward_ops.op_hwp()
HPFATUs = hwp_op(PFATUs)
polarizer_op = forward_ops.op_polarizer()
PHPFATUs = polarizer_op(HPFATUs)
print(PHPFATUs.shape)

# Detector-side operators; APHPFATUs is the input used for the transmission study below.
detector_integration_op = forward_ops.op_detector_integration()
APHPFATUs = detector_integration_op(PHPFATUs)
transmission_op = forward_ops.op_transmission()
TAPHPFATUs = transmission_op(APHPFATUs)
bolometer_response_op = forward_ops.op_bolometer_response()
RTAPHPFATUs = bolometer_response_op(TAPHPFATUs)

## Apply the transmission operator to a TOD to build paired datasets (before and after transmission)

In [None]:
transmission_operator = forward_ops.op_transmission()

# Pre-transmission TOD as a float64 tensor, then the same data passed through
# the transmission operator (which works on NumPy arrays) and wrapped back.
# NOTE: the next cell overwrites all three names.
original_tod = torch.tensor(APHPFATUs, dtype=torch.float64)
tod_as_numpy = original_tod.detach().cpu().numpy()
tod_after_transmission = torch.tensor(
    transmission_operator(tod_as_numpy),
    dtype=torch.float64,
)

In [None]:
tod_before_transmission_list = []
tod_after_transmission_list = []

# Ground-truth throughput injected into the simulation.
true_eta = 0.5
true_T = 1
# Number of (identical, noiseless) realizations stacked along the batch axis.
n_realizations = 5

# NOTE: this mutates the shared instrument `q`; cells above that build operators
# from `q` will see these values if they are re-run after this cell.
q.detector.efficiency = np.ones(q.detector.efficiency.shape, dtype=np.float64) * true_eta
q.optics.components['transmission'] = np.ones(q.optics.components['transmission'].shape, dtype=np.float64) * true_T

# Rebuild the forward operators so they (presumably) pick up the new values.
forward_ops = ForwardOps(q, acquisition, s)
# `q` does not change inside the loop, so one operator serves every realization.
transmission_operator = forward_ops.op_transmission()
original_tod = APHPFATUs
for realization in range(n_realizations):
    tod_before_transmission_list.append(original_tod)
    tod_after_transmission = transmission_operator(original_tod)
    tod_after_transmission_list.append(tod_after_transmission)

# To add realizations with a different efficiency: set q.detector.efficiency,
# rebuild ForwardOps, and repeat the loop above before stacking.

# Stack into single (N, D, Nt) arrays first: torch.tensor on a Python list of
# ndarrays is very slow (and emits a UserWarning) compared with one array.
tod_before_transmission_torch = torch.tensor(np.stack(tod_before_transmission_list), dtype=torch.float64)
tod_after_transmission_torch = torch.tensor(np.stack(tod_after_transmission_list), dtype=torch.float64)

# Monochromatic Pyro inference of the transmission (detector efficiency and optical transmission)

In [None]:
pyro.set_rng_seed(0)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Detector-side (after transmission) and sky-side (before transmission) TODs, both (N, D, Nt).
det_tod = tod_after_transmission_torch.to(device=device, dtype=torch.float32)
sky_tod = tod_before_transmission_torch.to(device=device, dtype=torch.float32)
N, D, Nt = det_tod.shape

# Prior centre for the optical transmission: product of all optical-component
# transmissions. These are the true simulated values, so the prior is centred on the truth.
trans_components = np.asarray(q.optics.components['transmission'], dtype=float)
t0 = float(trans_components.prod())

sigma_noise = 1e-18  # white-noise std on the TOD
sigma_eta = 0.20     # log-scale prior width on eta (~20%)
sigma_topt = 0.20    # log-scale prior width on the optics transmission (~20%)

class InvThroughput(torch.nn.Module):
    """Undo a per-batch-element throughput: ``tod_det / (t_optics * eta)``.

    ``eta`` and ``t_optics`` are 1-D (one value per leading index of
    ``tod_det``) and are broadcast over the two trailing axes.
    """

    def forward(self, tod_det, eta, t_optics):
        optics_b = t_optics[:, None, None]
        eta_b = eta[:, None, None]
        throughput = optics_b * eta_b
        return tod_det / throughput

layer = InvThroughput().to(device)


def model(det, sky):
    """Pyro model: per realization, ``sky = det / (eta * t_optics)`` plus white noise.

    Priors are log-normal: ``log_eta ~ Normal(0, sigma_eta)`` (median eta = 1) and
    ``log_topt ~ Normal(log(t0), sigma_topt)``.

    Args:
        det: detector-side TOD tensor of shape (N, D, Nt).
        sky: observed sky-side TOD of the same shape, or ``None`` to sample it
            (as done by ``Predictive``).

    Note:
        Only the product ``eta * t_optics`` enters the likelihood. The data
        therefore constrain ``log_eta + log_topt``, and how that sum is split
        between the two parameters is set by the priors alone.
    """
    # Build the prior parameters as tensors on det's device and dtype. Python-float
    # parameters become CPU tensors, so on CUDA the sampled eta/t_optics (and the
    # guide parameters initialised from them) would not be on det's device, and
    # the division inside `layer` would fail.
    loc_eta = det.new_tensor(0.0)
    scale_eta = det.new_tensor(sigma_eta)
    loc_topt = det.new_tensor(float(np.log(t0)))
    scale_topt = det.new_tensor(sigma_topt)

    with pyro.plate("batch", det.size(0)):
        log_eta = pyro.sample("log_eta", dist.Normal(loc_eta, scale_eta))
        eta = torch.exp(log_eta)

        log_topt = pyro.sample("log_topt", dist.Normal(loc_topt, scale_topt))
        t_optics = torch.exp(log_topt)

        sky_hat = layer(det, eta, t_optics)
        # The trailing (D, Nt) axes are event dims; the plate supplies the batch axis.
        pyro.sample("obs", dist.Normal(sky_hat, sigma_noise).to_event(2), obs=sky)

guide = pyro.infer.autoguide.AutoNormal(model)
optim = Adam({"lr": 3e-3})
svi = SVI(model, guide, optim, Trace_ELBO())

# Checkpoint settings. The checkpoint is not tied to the data: delete ckpt_file
# after changing the simulation, otherwise training resumes from stale parameters.
ckpt_file = "inv_eta_topt.ckpt"
start = 0
save_every = 10
n_steps = 200


def _save_checkpoint(step_idx):
    """Write the Pyro param store and optimizer state, tagged with ``step_idx``, to ckpt_file."""
    torch.save(
        {
            "step": step_idx,
            "param_store": pyro.get_param_store().get_state(),
            "optim_state": optim.get_state(),
        },
        ckpt_file,
    )


if os.path.exists(ckpt_file):
    # weights_only=False unpickles arbitrary objects (needed for the param store):
    # only load checkpoints you produced yourself.
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
    pyro.get_param_store().set_state(ckpt["param_store"])
    optim.set_state(ckpt["optim_state"])
    start = ckpt["step"] + 1
    print(f"✓ Resuming from step {start}")

last_step = start + n_steps - 1
for step in range(start, start + n_steps):
    # svi.step returns the loss estimate, i.e. the negative ELBO.
    loss = svi.step(det_tod, sky_tod)

    # Also save on the final step, otherwise up to save_every - 1 steps of
    # training are lost and repeated on the next resume.
    if step % save_every == 0 or step == last_step:
        _save_checkpoint(step)
        print(f"step {step:5d} | loss (-ELBO) {loss:8.3g}  ➜ checkpoint saved")

print("Done ✔")


# Draw 50 posterior samples of the two latent sites from the fitted guide.
predictive = Predictive(
    model,
    guide=guide,
    num_samples=50,
    return_sites=["log_eta", "log_topt"],
)
post = predictive(det_tod, sky=None)

# Back to linear scale; each is expected to be (num_samples, N).
eta_samples = post["log_eta"].exp()
topt_samples = post["log_topt"].exp()

# Per-realization posterior mean and std, moved to CPU for printing/plotting.
eta_mean, eta_std = eta_samples.mean(0).cpu(), eta_samples.std(0).cpu()
topt_mean, topt_std = topt_samples.mean(0).cpu(), topt_samples.std(0).cpu()

In [None]:
# Posterior means of eta and of the optics transmission, one entry per realization.
print(eta_mean, topt_mean, sep="\n")