In [None]:
import os

import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam
from scipy.stats import norm

from qubic.lib.Instrument.Qacquisition import QubicAcquisition
from qubic.lib.Instrument.Qinstrument import QubicInstrument, compute_freq
from qubic.lib.MapMaking.FrequencyMapMaking.Qspectra_component import CMBModel
from qubic.lib.MapMaking.NeuralNetworkMapMaking.operators.forward_ops import ForwardOps
from qubic.lib.Qdictionary import qubicDict
from qubic.lib.Qsamplings import equ2gal, get_pointing
from qubic.lib.Qscene import QubicScene

%matplotlib inline

# QUBIC Parameters

In [None]:
# Start from the demo pipeline configuration; selected keys are overridden in the next cell.
d = qubicDict()
dictfilename = "qubic/qubic/dicts/pipeline_demo.dict"  # relative to the notebook's working directory
d.read_from_file(dictfilename)

In [None]:
# Overrides applied on top of the demo dictionary for this single-band, noiseless test.
config_overrides = {
    "nf_recon": 1,
    "MultiBand": False,
    "synthbeam_kmax": 1,
    "synthbeam_fraction": 1,
    "noiseless": True,
    "photon_noise": False,
    "use_synthbeam_fits_file": False,
    "npointings": 1000,
    "nside": 128,
}
for key, value in config_overrides.items():
    d[key] = value

# Number of frequency sub-bands, read from the (unchanged) dictionary value.
nf_sub = d["nf_sub"]

# Build Sky

In [None]:
seed = 3
# healpy.synfast draws its Gaussian a_lm from NumPy's global RNG, so seed it
# explicitly; previously `seed` was defined but never used (non-reproducible sky).
np.random.seed(seed)

cl_cmb = CMBModel(None).give_cl_cmb(r=0, Alens=1)
# Previously np.array(nf_sub * [map])[0] built nf_sub identical copies and kept the first;
# a single realisation is all that is used. The transpose puts the Stokes axis last
# (presumably (npix, 3) for I/Q/U — confirm with the printed shape).
sky_map = hp.synfast(cl_cmb, d["nside"], new=True, verbose=False).T
print(sky_map.shape)

hp.mollview(sky_map[..., 0], title="Input CMB sky (first Stokes component)")

# QUBIC Instance

In [None]:
# Pointing, scene and instrument built from the configured dictionary, then combined
# into the acquisition model.
p = get_pointing(d)
s = QubicScene(d)
q = QubicInstrument(d)
acquisition = QubicAcquisition(q, p, s, d)

# compute_freq expects the band centre in GHz; only the edges and centres are kept.
filter_nu_ghz = d["filter_nu"] / 1e9
freq_info = compute_freq(filter_nu_ghz, d["nf_sub"], d["filter_relative_bandwidth"])
_, nus_edge, nus, _, _, _ = freq_info

In [None]:
# Full acquisition operator; its last operand is reused later in the manual operator chain.
H = acquisition.get_operator()

# Convolve the input sky with the operator returned by get_convolution_peak_operator
# (presumably the instrument's peak beam) before pushing it through the chain.
convolution = acquisition.get_convolution_peak_operator()
convolved_maps = convolution(sky_map)

cov = acquisition.get_coverage()

In [None]:
# Object providing the individual forward-model operators used in the cells below
# (op_unit_conversion, op_aperture_integration, op_filter, op_hwp, op_polarizer,
# op_detector_integration, op_transmission, op_bolometer_response).
forward_ops = ForwardOps(q, acquisition, s)

# Sequential operator combinations for training

In [None]:
# Map-domain stages. Variable names prepend one letter per operator applied.
unit_op = forward_ops.op_unit_conversion()
Us = unit_op(convolved_maps)

# No transmission operator is applied at this stage: TUs is just an alias of Us.
TUs = Us

aperture_op = forward_ops.op_aperture_integration()
ATUs = aperture_op(TUs)

filter_op = forward_ops.op_filter()
FATUs = filter_op(ATUs)
print(FATUs.shape)

In [None]:
# Last operand of the full acquisition operator (presumably the pointing projection).
projection_op = H.operands[-1]
PFATUs = projection_op(FATUs)
print(PFATUs.shape)

# Half-wave plate followed by the polarizer.
hwp_op = forward_ops.op_hwp()
HPFATUs = hwp_op(PFATUs)
polarizer_op = forward_ops.op_polarizer()
PHPFATUs = polarizer_op(HPFATUs)
print(PHPFATUs.shape)

# Detector chain: detector integration, transmission, bolometer response.
det_int_op = forward_ops.op_detector_integration()
APHPFATUs = det_int_op(PHPFATUs)
transmission_op = forward_ops.op_transmission()
TAPHPFATUs = transmission_op(APHPFATUs)
response_op = forward_ops.op_bolometer_response()
RTAPHPFATUs = response_op(TAPHPFATUs)

# Apply the detector-integration operator to a TOD to build paired datasets (before and after application)

In [None]:
det_integration_operator = forward_ops.op_detector_integration()

# The TOD *before* detector integration is the polarizer output PHPFATUs.
# APHPFATUs already has op_detector_integration applied (operator-chain cell),
# so using it here integrated twice.
original_tod = torch.tensor(PHPFATUs, dtype=torch.float64)

# Apply the operator to the NumPy TOD directly (no torch -> numpy round trip needed).
tod_after_det_integration = torch.tensor(det_integration_operator(PHPFATUs), dtype=torch.float64)

In [None]:
n_per_group = 10  # TOD realisations generated per solid-angle hypothesis (true, then wrong)

solid_angle_true = q.secondary_beam.solid_angle
print("True Solid Angle :", solid_angle_true)
solid_angle_wrong = 0.01
print("False Solid Angle :", solid_angle_wrong)

# The input to detector integration is the polarizer output PHPFATUs;
# APHPFATUs already had op_detector_integration applied in the operator-chain cell.
tod_before = PHPFATUs
tod_after_true = det_integration_operator(tod_before)

# det_integration_operator is built from the instrument's own solid angle. Previously
# solid_angle_wrong was printed but never used, so both groups were identical.
# NOTE(review): assumes the integrated TOD scales as 1 / solid_angle (the relation
# InvSolidAnglePerSample inverts) — confirm against ForwardOps.
tod_after_wrong = tod_after_true * (solid_angle_true / solid_angle_wrong)

tod_before_det_integration_list = [tod_before] * (2 * n_per_group)
tod_after_det_integration_list = [tod_after_true] * n_per_group + [tod_after_wrong] * n_per_group
tod_after_det_integration = tod_after_det_integration_list[-1]

# Stack in NumPy first: creating a tensor from a list of ndarrays is very slow in torch.
tod_before_det_integration_torch = torch.as_tensor(np.stack(tod_before_det_integration_list), dtype=torch.float64)
tod_after_det_integration_torch = torch.as_tensor(np.stack(tod_after_det_integration_list), dtype=torch.float64)

# Monochromatic Pyro model for detector integration

In [None]:
pyro.set_rng_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Inputs are the detector-integrated TODs; targets are the TODs before integration.
det_tod, sky_tod = (
    t.to(device=device, dtype=torch.float32)
    for t in (tod_after_det_integration_torch, tod_before_det_integration_torch)
)
N, D, Nt = det_tod.shape

# Instrument geometry consumed by the inverse layer.
pos, area, sec_beam = q.detector.center, q.detector.area, q.secondary_beam


class InvSolidAnglePerSample(PyroModule):
    """Undo detector integration with a learnable (log-parametrised) solid angle.

    Assumed forward model: det_tod = sky_tod * sr_det * gain / solid_angle, where
    sr_det and gain are per-detector geometric factors. The layer returns
    sky_hat = det_tod * solid_angle / (sr_det * gain).

    Args:
        rel_sigma: std of the Normal prior on log(solid_angle).
        log_loc: mean of that prior (default 0.0, i.e. solid_angle ~ 1, as before).
    """

    def __init__(self, rel_sigma=0.05, log_loc=0.0):
        super().__init__()
        self.log_solid_angle = PyroSample(dist.Normal(log_loc, rel_sigma))

    @staticmethod
    def _geometric_gain(pos, area, sec_beam):
        """Per-detector factor sr_det * secondary-beam gain, computed in NumPy."""
        theta = np.arctan2(np.sqrt((pos[..., :2] ** 2).sum(-1)), pos[..., 2])
        phi = np.arctan2(pos[..., 1], pos[..., 0])
        sr_det = -area / pos[..., 2] ** 2 * np.cos(theta) ** 3
        return sr_det * sec_beam(theta, phi)

    def forward(self, det_tod, pos, area, sec_beam):
        """Map det_tod (assumed (batch, detector, time)) back to the pre-integration TOD."""
        # Keep the autograd graph: the original torch.tensor(...) wrapper detached the
        # sampled value, so no gradient reached log_solid_angle. .to(det_tod) also moves
        # the (CPU) prior sample to det_tod's device and dtype.
        solid_angle = torch.exp(self.log_solid_angle).to(det_tod)
        # Scalar outside a plate, (batch,) inside pyro.plate: add detector/time axes.
        solid_angle = solid_angle[..., None, None]

        geom = torch.as_tensor(
            self._geometric_gain(pos, area, sec_beam), dtype=det_tod.dtype, device=det_tod.device
        )
        # geom is one value per detector: broadcast it over the time axis, not against it.
        return det_tod * solid_angle / geom[..., None]


sigma_noise = 1e-18  # std of the Gaussian likelihood on the reconstructed sky TOD
layer = InvSolidAnglePerSample(rel_sigma=0.20).to(device)


def model(det_tod, sky_tod, pos, area, sec_beam):
    """Pyro model: each batch element gets its own solid angle (via the global `layer`).

    The reconstructed sky TOD is compared with `sky_tod` under a Gaussian likelihood
    of scale `sigma_noise`; the last two dimensions are treated as one event.
    Pass sky_tod=None to sample "obs" instead of conditioning on it.
    """
    n_batch = det_tod.size(0)
    with pyro.plate("batch", n_batch):
        sky_hat = layer(det_tod, pos, area, sec_beam)
        likelihood = dist.Normal(sky_hat, sigma_noise).to_event(2)
        pyro.sample("obs", likelihood, obs=sky_tod)

## Optimiser, guide, and SVI

In [None]:
# Mean-field Normal guide over all latent sites of `model`, optimised with Adam.
learning_rate = 3e-3
optim = Adam({"lr": learning_rate})
guide = pyro.infer.autoguide.AutoNormal(model)

svi = SVI(model, guide, optim, loss=Trace_ELBO())

In [None]:
# Settings for the checkpointed SVI loop below.
file = "invI_svi.pt"  # checkpoint path: param store + optimiser state + last step index
save_every = 20  # checkpoint every `save_every` steps
n_steps = 200  # SVI steps performed per execution of the training cell
# First step index on a fresh run (was 200, which mislabelled a fresh run as steps
# 200-399); replaced by the saved step + 1 when resuming.
start = 0

In [None]:
if os.path.exists(file):
    # SECURITY: weights_only=False unpickles arbitrary objects (needed for Pyro's
    # param-store state). Only load checkpoints you created yourself.
    ckpt = torch.load(file, map_location=device, weights_only=False)
    pyro.get_param_store().set_state(ckpt["param_store"])
    optim.set_state(ckpt["optim_state"])
    start = ckpt["step"] + 1
    print(f"✓ Resuming from step {start}")

stop = start + n_steps
for step in range(start, stop):
    loss = svi.step(det_tod, sky_tod, pos, area, sec_beam)

    # Also checkpoint the last step, so steps after the last multiple of
    # save_every are not lost (and redone) on resume.
    if step % save_every == 0 or step == stop - 1:
        torch.save({"step": step, "param_store": pyro.get_param_store().get_state(), "optim_state": optim.get_state()}, file)
        print(f"step {step:5d} | ELBO {loss:8.3g}  ➜ checkpoint saved")

print("Done ✔")

In [None]:
# A few extra (non-checkpointed) SVI steps. Uses its own counter instead of
# overwriting the global `n_steps` that the checkpointed training cell relies on.
n_extra_steps = 10
log_every = 5  # the previous `step % 20` never fired with only 10 steps
for step in range(1, n_extra_steps + 1):
    loss = svi.step(det_tod, sky_tod, pos, area, sec_beam)
    if step % log_every == 0 or step == n_extra_steps:
        print(f"SVI step {step:4d} ELBO = {loss:.3f}")

In [None]:
predictive = Predictive(model, guide=guide, num_samples=50, return_sites=["log_solid_angle"])
# Arguments must follow model(det_tod, sky_tod, pos, area, sec_beam). The previous call
# passed `tod_sky=None`, which is not a parameter of model, and omitted area/sec_beam.
# sky_tod=None leaves "obs" unconditioned.
post = predictive(det_tod, None, pos, area, sec_beam)
# Detach so downstream NumPy/matplotlib conversions work on guide-derived samples.
solid_angle_samples = torch.exp(post["log_solid_angle"]).detach()

# Reduce over the posterior-sample axis (dim 0), leaving one value per TOD sample.
solid_angle_mean = solid_angle_samples.mean(0).cpu()
solid_angle_std = solid_angle_samples.std(0).cpu()

print(solid_angle_mean)

In [None]:
# Loop variables renamed: the previous `s` overwrote the global QubicScene `s`
# used by earlier cells (hidden-state bug on re-run).
means = solid_angle_mean.detach().cpu()
stds = solid_angle_std.detach().cpu()
for i, (mean_i, std_i) in enumerate(zip(means.tolist(), stds.tolist())):
    print(f"SAMPLE {i:2d}  η = {mean_i:.3f} ± {std_i:.3f}")

fig, ax = plt.subplots()
ax.errorbar(range(len(means)), means.numpy(), yerr=stds.numpy(), fmt="o", capsize=3, color="tab:blue", label="posterior mean ± 1σ")
# Reference levels used to generate the two groups of TODs.
ax.axhline(solid_angle_true, color="tab:green", ls="--", label="true solid angle")
ax.axhline(solid_angle_wrong, color="tab:red", ls=":", label="wrong solid angle")
ax.set(xlabel="TOD sample index", ylabel="solid angle", title="Posterior solid angle per TOD sample")
ax.legend()

plt.show()

# Drawing samples