In [None]:
import os

import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam
from scipy.stats import norm

from qubic.lib.Instrument.Qacquisition import QubicAcquisition
from qubic.lib.Instrument.Qinstrument import QubicInstrument, compute_freq
from qubic.lib.MapMaking.FrequencyMapMaking.Qspectra_component import CMBModel
from qubic.lib.MapMaking.NeuralNetworkMapMaking.operators.forward_ops import ForwardOps
from qubic.lib.Qdictionary import qubicDict
from qubic.lib.Qsamplings import equ2gal, get_pointing
from qubic.lib.Qscene import QubicScene

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# QUBIC Parameters

In [None]:
# Load the demo pipeline dictionary; the path is relative to the current working directory.
dictfilename = "qubic/qubic/dicts/pipeline_demo.dict"
d = qubicDict()
d.read_from_file(dictfilename)

# Field centre converted from (RA, Dec) with equ2gal (presumably to galactic coordinates).
center = equ2gal(d["RA_center"], d["DEC_center"])

In [None]:
# Simulation settings that override the values read from the demo dictionary.
d["nf_recon"] = 1  # presumably a single reconstructed band
d["MultiBand"] = False
d["synthbeam_kmax"] = 1
nf_sub = d["nf_sub"]  # number of sub-bands, read unchanged from the dictionary
d["synthbeam_fraction"] = 1
d["noiseless"] = True
d["photon_noise"] = False
d["use_synthbeam_fits_file"] = False
d["npointings"] = 1000
d["nside"] = 128  # HEALPix resolution used for the simulated sky below

# Build Sky

In [None]:
seed = 3
# hp.synfast has no seed argument and draws from NumPy's global RNG, so seed it here;
# otherwise `seed` is unused and every run produces a different sky.
np.random.seed(seed)
cl_cmb = CMBModel(None).give_cl_cmb(r=0, Alens=1)
# One CMB realisation; synfast returns (3, npix), transposed here to a contiguous (npix, 3).
sky_map = np.ascontiguousarray(hp.synfast(cl_cmb, d["nside"], new=True, verbose=False).T)
print(sky_map.shape)

In [None]:
# First column of the (npix, 3) sky map, i.e. the intensity (I) component.
hp.mollview(sky_map[..., 0], title="Input CMB sky (I)")

# QUBIC Instance

In [None]:
# Build the QUBIC model: pointing, sky scene, instrument and the acquisition combining them.
# NOTE: `s` is the QubicScene; later cells must not reuse this name (ForwardOps needs it).
p = get_pointing(d)
s = QubicScene(d)
q = QubicInstrument(d)
acquisition = QubicAcquisition(q, p, s, d)

# Sub-band frequencies from the filter centre (converted from Hz by / 1e9) and relative
# bandwidth; presumably `nus_edge` are band edges and `nus` band centres — TODO confirm.
_, nus_edge, nus, _, _, _ = compute_freq(d["filter_nu"] / 1e9, d["nf_sub"], d["filter_relative_bandwidth"])

In [None]:
# Full acquisition operator H and the peak-beam convolution applied to the input sky.
H = acquisition.get_operator()
convolution = acquisition.get_convolution_peak_operator()
convolved_maps = convolution(sky_map)
# Coverage map of the scanning strategy (not used later in the visible notebook).
cov = acquisition.get_coverage()

In [None]:
# ForwardOps exposes the individual forward-model operators (op_unit_conversion,
# op_filter, op_transmission, ...) that are applied one at a time below.
forward_ops = ForwardOps(q, acquisition, s)

# Sequential operator combinations for training

In [None]:
# Apply the sky-side operators step by step. Variable names prepend one letter per
# operator: U = unit conversion, A = aperture integration, F = filter.
Us = forward_ops.op_unit_conversion()(convolved_maps)

# No operator is applied for the "T" step here (identity placeholder) — TODO confirm intent.
TUs = Us

ATUs = forward_ops.op_aperture_integration()(TUs)

FATUs = forward_ops.op_filter()(ATUs)
print(FATUs.shape)

In [None]:
# Last operand of H — presumably the pointing/projection from maps to TODs; TODO confirm.
PFATUs = H.operands[-1](FATUs)
print(PFATUs.shape)

# Instrument-side operators applied to the projected data (op_hwp: half-wave plate, by name).
HPFATUs = forward_ops.op_hwp()(PFATUs)
PHPFATUs = forward_ops.op_polarizer()(HPFATUs)
print(PHPFATUs.shape)

# APHPFATUs (after detector integration) is the input used for the transmission study below.
APHPFATUs = forward_ops.op_detector_integration()(PHPFATUs)
TAPHPFATUs = forward_ops.op_transmission()(APHPFATUs)
RTAPHPFATUs = forward_ops.op_bolometer_response()(TAPHPFATUs)

## Apply the transmission operator to a TOD to get both datasets (before and after transmission)

In [None]:
transmission_operator = forward_ops.op_transmission()

# TOD just before the transmission step (output of detector integration), as float64.
original_tod = torch.tensor(APHPFATUs, dtype=torch.float64)

# The operator is applied to a NumPy array, so round-trip through NumPy and back to a tensor.
tod_after_transmission = torch.tensor(transmission_operator(original_tod.detach().cpu().numpy()), dtype=torch.float64)

In [None]:
# Build a batch of training pairs: TOD before / after the transmission operator.
# Each entry is (detector efficiency, number of identical TODs to generate with it);
# append e.g. (0.4, 10) to mix in a second population with a different efficiency.
efficiency_batches = [(0.8, 5)]

tod_before_transmission_list = []
tod_after_transmission_list = []
for efficiency, n_copies in efficiency_batches:
    # Set a uniform efficiency on every detector, then rebuild ForwardOps so the
    # transmission operator is created from the updated instrument.
    q.detector.efficiency = np.full(q.detector.efficiency.shape, efficiency, dtype=np.float64)
    forward_ops = ForwardOps(q, acquisition, s)
    # Operator and input are identical for every copy, so apply it once and replicate.
    transmission_operator = forward_ops.op_transmission()
    original_tod = APHPFATUs
    tod_after_transmission = transmission_operator(original_tod)
    tod_before_transmission_list.extend([original_tod] * n_copies)
    tod_after_transmission_list.extend([tod_after_transmission] * n_copies)

# Stack into (N, D, Nt) arrays before converting: torch.tensor on a list of ndarrays is very slow.
tod_before_transmission_torch = torch.as_tensor(np.stack(tod_before_transmission_list), dtype=torch.float64)
tod_after_transmission_torch = torch.as_tensor(np.stack(tod_after_transmission_list), dtype=torch.float64)

# Monochromatic Pyro example: inferring the transmission (detector efficiency)

In [None]:
pyro.set_rng_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training tensors in single precision on the chosen device, shape (N, D, Nt):
# N = batch of TODs, D = detectors (992 for this configuration), Nt = time samples.
det_tod = tod_after_transmission_torch.to(device=device, dtype=torch.float32)  # attenuated TODs
sky_tod = tod_before_transmission_torch.to(device=device, dtype=torch.float32)  # targets, same shape
N, D, Nt = det_tod.shape

# Total optical transmission: product over the optical components.
T_optics = np.prod(q.optics.components["transmission"])


class InvEtaPerSample(PyroModule):
    """Undo the optical transmission and the detector efficiency on a batch of TODs.

    The module holds a latent log-efficiency ``log_eta`` with prior
    ``Normal(0, rel_sigma)``. This is a log-normal prior on ``eta`` centred on 1.
    Inside a ``pyro.plate`` over the batch axis, one value is drawn per TOD.

    Args:
        rel_sigma: prior standard deviation of ``log_eta``, which is roughly the
            relative width of the prior on ``eta``.
    """

    def __init__(self, rel_sigma=0.05):
        super().__init__()
        self.log_eta = PyroSample(dist.Normal(0.0, rel_sigma))

    def forward(self, tod_det, transmission):
        """Return ``tod_det / (prod(transmission) * eta)``.

        Args:
            tod_det: attenuated TODs of shape (N, D, Nt).
            transmission: scalar or array of optical transmissions; their product is used.

        Returns:
            Tensor with the same shape (and dtype/device) as ``tod_det``.
        """
        eta = torch.exp(self.log_eta)  # (N,): one draw per batch element inside the plate
        eta = eta[:, None, None]  # (N, 1, 1): broadcast over detectors and time
        # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
        # Build the scalar on the TOD's dtype/device so no CPU->GPU mixing is needed.
        total_transmission = torch.as_tensor(np.prod(transmission), dtype=tod_det.dtype, device=tod_det.device)
        return tod_det / (total_transmission * eta)


# Shared model layer; rel_sigma=0.20 gives a prior on eta roughly 20 % wide.
layer = InvEtaPerSample(rel_sigma=0.20).to(device)  # 20 % width prior

# Assumed white-noise scale of the Gaussian likelihood on the reconstructed sky TOD.
# NOTE(review): must be small relative to the TOD amplitude — confirm for the TOD units.
sigma_noise = 1e-18  # assumed white noise add a small value


def model(det, trans, sky):
    """Pyro model: reconstruct the sky TOD from the detector TOD with one eta per sample.

    Args:
        det: detector TODs, shape (N, D, Nt).
        trans: optical transmission(s) forwarded to ``layer``.
        sky: observed sky TODs of the same shape, or None to leave ``obs`` unobserved.
    """
    n_samples = det.size(0)
    with pyro.plate("batch", n_samples):
        predicted_sky = layer(det, trans)
        # (D, Nt) are event dims, so each batch element contributes one joint log-prob.
        likelihood = dist.Normal(predicted_sky, sigma_noise).to_event(2)
        pyro.sample("obs", likelihood, obs=sky)


## Optimizer, guide, and SVI

In [None]:
# Mean-field Normal guide over every latent site of `model` (here: log_eta).
guide = pyro.infer.autoguide.AutoNormal(model)

# Adam with a fixed learning rate, optimising the standard ELBO.
optim = Adam(dict(lr=3e-3))
svi = SVI(model, guide, optim, loss=Trace_ELBO())

In [None]:
ckpt_file = "invT_svi_ckpt.pt"  # checkpoint path, relative to the working directory
save_every = 20  # checkpoint every `save_every` steps
n_steps = 200  # number of additional steps to run per execution of the training cell

In [None]:
# First step index for a fresh run; overwritten by the training cell when it resumes
# from a checkpoint (start = saved step + 1).
start = 0

In [None]:
# Train with periodic checkpoints so a long run can be split across sessions (Pyro can be
# memory-hungry). If a checkpoint exists, resume from it; otherwise start at `start`.
if os.path.exists(ckpt_file):
    # weights_only=False unpickles arbitrary objects (presumably required for the Pyro
    # param-store / optimizer state). Only load checkpoint files you produced yourself.
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
    pyro.get_param_store().set_state(ckpt["param_store"])
    optim.set_state(ckpt["optim_state"])
    start = ckpt["step"] + 1
    print(f"✓ Resuming from step {start}")


last_step = start + n_steps - 1
for step in range(start, start + n_steps):
    loss = svi.step(det_tod, T_optics, sky_tod)

    # Always checkpoint the final step as well, otherwise the trailing steps are lost (and
    # redone on resume) whenever n_steps is not a multiple of save_every.
    if step % save_every == 0 or step == last_step:
        torch.save(
            {
                "step": step,
                "param_store": pyro.get_param_store().get_state(),
                "optim_state": optim.get_state(),
            },
            ckpt_file,
        )
        print(f"step {step:5d} | ELBO {loss:8.3g}  ➜ checkpoint saved")

print("Done ✔")

In [None]:
# Optional quick SVI run without checkpointing; set the flag to True to use it.
# It uses its own step counter so the `n_steps` of the checkpointed run is not overwritten.
run_quick_svi = False
if run_quick_svi:
    quick_n_steps = 10
    quick_print_every = 5  # must not exceed quick_n_steps, otherwise nothing is printed
    for quick_step in range(1, quick_n_steps + 1):
        loss = svi.step(det_tod, T_optics, sky_tod)
        if quick_step % quick_print_every == 0:
            print(f"SVI step {quick_step:4d}  ELBO = {loss:.3g}")

In [None]:
# Draw 50 posterior samples of log_eta from the fitted guide. With sky=None the "obs"
# site is left unobserved; only "log_eta" is returned.
predictive = Predictive(model, guide=guide, return_sites=["log_eta"], num_samples=50)
post = predictive(det_tod, T_optics, sky=None)
eta_samples = post["log_eta"].exp()  # presumably (50, N): one row per posterior draw

# Per-TOD posterior mean and std, moved to CPU for printing/plotting.
eta_mean, eta_std = (stat.cpu() for stat in (eta_samples.mean(0), eta_samples.std(0)))

print(eta_mean)

In [None]:
# Loop variables must not be named `s`: that name holds the QubicScene used by ForwardOps,
# and shadowing it breaks any earlier cell that is re-run afterwards.
for sample_idx, (eta_m, eta_s) in enumerate(zip(eta_mean, eta_std)):
    print(f"SAMPLE {sample_idx:2d}  η = {eta_m:.3f} ± {eta_s:.3f}")

fig, ax = plt.subplots()
ax.errorbar(range(N), eta_mean, yerr=eta_std, fmt="o", capsize=3, color="tab:blue")
ax.axhline(0.8, ls="--", c="k", lw=1)  # detector efficiency used to simulate the training TODs
ax.set_ylabel("mean detector efficiency  η")
ax.set_xlabel("TOD sample ID")
ax.set_title("Posterior mean ±1σ for each training TOD")
plt.show()

This is enough to get the mean efficiency of each training TOD and its uncertainty.
To inspect the full posterior and interpret it, we need to draw samples from it.

# Drawing samples

In [None]:
def posterior_mean_std_and_samples(
    model,
    guide,
    det_tod,
    trans,
    n_draw=60,
    chunk=10,
    sample_site="log_eta",
    keep_draws=False,
    average_over_batch=True,
):
    """Posterior mean and std of ``exp(sample_site)``, drawn in memory-friendly chunks.

    Posterior draws are generated ``chunk`` at a time with ``Predictive``, so only
    ``chunk`` samples are held at once. Mean and variance are accumulated with
    Welford's streaming update.

    Args:
        model, guide: Pyro model and fitted guide passed to ``Predictive``.
        det_tod: detector TODs, forwarded to ``model`` as its first argument.
        trans: transmission argument forwarded to ``model``.
        n_draw: total number of posterior draws. It must be a multiple of ``chunk``
            and at least 2.
        chunk: number of draws per ``Predictive`` call; must be positive.
        sample_site: latent site to read (the log-efficiency by default).
        keep_draws: if True, also return every draw. This may be large, depending on
            ``n_draw``.
        average_over_batch: if True (default), each draw is averaged over axis 1 (the
            batch/TOD axis for the model in this notebook). This gives one scalar per
            draw. If False, statistics are computed element-wise per TOD.

    Returns:
        ``(mean, std)`` on CPU, or ``(mean, std, draws)`` when ``keep_draws`` is True.
        Here ``draws`` has shape ``(n_draw, *per_draw_shape)``. ``std`` uses the
        unbiased (n - 1) estimator.

    Raises:
        ValueError: if ``chunk`` <= 0, ``n_draw`` is not a multiple of ``chunk``, or
            ``n_draw`` < 2 (the std would be undefined).
    """
    if chunk <= 0:
        raise ValueError(f"chunk must be positive, got {chunk}")
    if n_draw % chunk != 0:
        raise ValueError(f"n_draw ({n_draw}) must be a multiple of chunk ({chunk})")
    if n_draw < 2:
        raise ValueError(f"n_draw must be at least 2 to estimate a standard deviation, got {n_draw}")

    pred = Predictive(model, guide=guide, num_samples=chunk, return_sites=[sample_site])

    mean = m2 = None
    n_seen = 0
    all_draws = [] if keep_draws else None

    for _ in range(n_draw // chunk):
        post = pred(det_tod, trans, sky=None)
        draws = torch.exp(post[sample_site])  # (chunk, N, ...); presumably (chunk, N) here
        if average_over_batch:
            draws = draws.mean(1)  # (chunk, ...): mean efficiency over the TOD batch

        if keep_draws:
            all_draws.append(draws.cpu())

        # Welford's online update, element-wise over the per-draw shape.
        for x in draws:
            n_seen += 1
            if mean is None:
                mean = torch.zeros_like(x)
                m2 = torch.zeros_like(x)
            delta = x - mean
            mean += delta / n_seen
            m2 += delta * (x - mean)

        del post, draws

    std = torch.sqrt(m2 / (n_seen - 1))
    if keep_draws:
        return mean.cpu(), std.cpu(), torch.cat(all_draws, dim=0)  # draws: (n_draw, ...)
    return mean.cpu(), std.cpu()

In [None]:
# 300 posterior draws in chunks of 10, keeping every draw for the histogram below.
eta_mean, eta_std, eta_samples = posterior_mean_std_and_samples(
    model,
    guide,
    det_tod,
    trans=T_optics,
    n_draw=300,
    chunk=10,
    keep_draws=True,
)

In [None]:
# Shape of the stored draws; with the default batch averaging presumably (n_draw,) — TODO confirm.
eta_samples.shape

In [None]:
det_id = 0  # detector shown in the TOD panel

# Summary of the batch-averaged efficiency draws (one value per posterior draw).
# The names mu_det / sig_det are kept for compatibility; these are not per-detector values.
mu_det = eta_samples.mean().item()
sig_det = eta_samples.std().item()


tod_det = det_tod[0, det_id].cpu()  # shape (Nt,): attenuated TOD of the first training sample
# Point-estimate correction consistent with InvEtaPerSample: sky = det / (T_optics * eta).
# (Multiplying by eta would attenuate the TOD further instead of correcting it.)
tod_corr = tod_det / (float(T_optics) * mu_det)
tod_true = sky_tod[0, det_id].cpu()  # target: the same TOD before transmission

Nt = tod_det.numel()
t = np.arange(Nt)  # time-sample index


fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].hist(eta_samples[:].numpy(), bins=30, alpha=0.2, color="steelblue", density=True, label="posterior samples")

xmin, xmax = ax[0].get_xlim()
xs = np.linspace(xmin, xmax, 200)
ax[0].plot(xs, norm.pdf(xs, mu_det, sig_det), "k--", label=r"$\mathcal{N}(\mu,\sigma)$ fit")
ax[0].axvline(mu_det, color="k")
ax[0].axvspan(mu_det - sig_det, mu_det + sig_det, alpha=0.2, color="orange", label=r"$\pm1\sigma$")
ax[0].set_xlabel(r"mean detector efficiency  η")
ax[0].set_ylabel("density")
ax[0].set_title("Detector posterior")
ax[0].legend()

ax[1].plot(t, tod_det, lw=0.8, label="det TOD (attenuated)", alpha=0.4)
ax[1].plot(t, tod_corr, lw=0.8, label="corrected (mean η)", color="C3", alpha=0.4)
ax[1].plot(t, tod_true, "--", lw=1.0, label="sky TOD (target)", color="k", alpha=0.4)
ax[1].set_xlabel("time sample")
ax[1].set_ylabel("power")
ax[1].set_title("Example TOD – before / after correction")
ax[1].legend()

plt.tight_layout()
plt.show()
