In [None]:
from pylab import *
import gc
import pickle5 as pickle
import inspect

import healpy as hp
import numpy as np

from qubicpack.utilities import Qubic_DataDir
from pysimulators import FitsArray
import qubic
import qubic.lib.QskySim as qss
from qubic import SpectroImLib as si
from qubic import camb_interface as qc
from pysm3 import models
import pysm3

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import ChebConv

from tqdm import tqdm
from pygsp import graphs
import torch_geometric.nn as gnn
from torch_geometric.data import Data, DataLoader


import torch.distributed as dist
import torch.multiprocessing as mp
from itertools import combinations

In [None]:
from pyoperators import MPIDistributionIdentityOperator, HomothetyOperator, DiagonalOperator, IdentityOperator, Rotation2dOperator, Rotation3dOperator, ReshapeOperator,BlockDiagonalOperator, DenseBlockDiagonalOperator, CompositionOperator
from pysimulators import ConvolutionTruncatedExponentialOperator, ProjectionOperator
from pysimulators.interfaces.healpy import Cartesian2HealpixOperator, HealpixConvolutionGaussianOperator
from qubic.ripples import ConvolutionRippledGaussianOperator

In [None]:
from pygsp import graphs, filters
from scipy import sparse

In [None]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [None]:
# Load the QUBIC parameter dictionary and convert the field centre from
# equatorial (RA, DEC) to Galactic coordinates.
# NOTE(review): relative path — assumes the notebook runs from the directory
# containing the `qubic` checkout; confirm.
dictfilename = 'qubic/qubic/dicts/pipeline_demo.dict'
d = qubic.lib.qubicDict()
d.read_from_file(dictfilename)

center = qubic.lib.equ2gal(d['RA_center'], d['DEC_center'])

In [None]:
# One reconstructed band, no multi-band mode.
d['nf_recon'] = 1
d['MultiBand'] = False

# Synthesized-beam settings ('synthbeam_*' keys); nf_sub is read but not used
# again below (compute_freq reads d['nf_sub'] directly).
d['synthbeam_kmax'] = 1
nf_sub = d['nf_sub']
d['synthbeam_fraction'] = 1

d['use_synthbeam_fits_file'] = False

In [None]:
# Noise-free simulation: no photon noise.
d['noiseless'] = True
d['photon_noise'] = False

In [None]:
# Number of pointings of the simulated scanning strategy.
d['npointings'] = 4000

In [None]:
# HEALPix resolution of the sky maps.
d['nside'] = 128

In [None]:
# Input sky: CMB realisation drawn with `seed` plus dust model 'd0'
# (presumably the PySM3 'd0' preset — confirm).
seed = 42
sky_config = {'cmb': seed, 'dust': 'd0'}
Qubic_sky = qss.Qubic_sky(sky_config, d)
sky_map = Qubic_sky.get_simple_sky_map()

In [None]:
# Pointing, scene and instrument objects built from the same dictionary.
p = qubic.lib.get_pointing(d)
s = qubic.lib.QubicScene(d)
q = qubic.lib.QubicInstrument(d)

In [None]:
# Sub-band edges and centres; d['filter_nu'] is divided by 1e9 (presumably Hz -> GHz).
_, nus_edge, nus, _, _, _ = qubic.lib.compute_freq(d['filter_nu'] / 1e9, d['nf_sub'], d['filter_relative_bandwidth'])

In [None]:
acquisition = qubic.lib.QubicAcquisition(q, p, s, d)

In [None]:
# Full acquisition operator; its operands are reused below (H.operands[1], H.operands[-1]).
H = acquisition.get_operator()

In [None]:
# Convolve the first sub-band map with the instrument's peak-convolution operator.
convolution = acquisition.get_convolution_peak_operator()
convolved_maps = convolution(sky_map[0])

In [None]:
# Sky coverage of the scanning strategy.
cov = acquisition.get_coverage()

$U_{\text{unit}}^{-1}$ (Unit conversion):

Description: Converts sky temperature into W/m$^2$/Hz using Planck's law. Inversion: Reversing the unit conversion is relatively straightforward, although it involves non-linear functions (such as exponentials).

$T_{\text{atm}}^{-1}$ (Atmosphere transmission):

Description: Accounts for atmospheric transmission, which reduces the signal. Inversion: Inverting this operator amounts to dividing by the transmission coefficient, which is easy if the transmission is uniform and known. If the transmission varies (e.g., with frequency or position), the inversion is more complex.

$A_{\text{ap}}^{-1}$ (Aperture integration):

Description: Converts W/m$^2$/Hz into W/Hz by multiplying by a scalar that depends on the number of open horns and their effective size. Inversion: This is a simple scalar multiplication, so the inverse is a division by the same scalar.

$F_{\text{filter}}^{-1}$ (Filter):

Description: Takes the bandwidth into account to convert W/Hz into W. Inversion: Inverting the filter operator is straightforward (divide by the bandwidth).

$P_{\text{proj}}^{-1}$ (Projection — not implemented!):

Description: Maps the sky onto the time-ordered data (TOD). Inversion: This operator is difficult to invert because it transforms a 2D sky map into a 1D TOD.

$H_{\text{HWP}}^{-1}$ (Half-wave plate):

Description: Modulates the signal according to the half-wave plate (HWP) angles. Inversion: This operator is a simple rotation matrix, which is straightforward to invert mathematically by applying the inverse rotation.

$P_{\text{pol}}^{-1}$ (Polarizer):

Description: Projects the sky map according to the polarization. Inversion: Inverting this operator can be complicated, depending on the polarization model used.

$I_{\text{det}}^{-1}$ (Flux density integration):

$I_{\text{det}}$ is an operator that integrates the flux density over the detector solid angle, taking into account the primary and secondary beams.

$T_{\text{det}}^{-1}$ (Transmission):

Description: Takes into account the transmission efficiency of the detectors. Inversion: As with the atmospheric transmission, this can be inverted if the transmission is known and uniform; variations in the transmission make the inversion more complex.

$R_{\text{det}}^{-1}$ (Bolometer operator):

Description: Applies the bolometer time constant. Inversion: This is related to the time response of the detector; inverting it requires a deconvolution (to undo the exponential decay).

$$
s = U_{\text{unit}}^{-1}\, T_{\text{atm}}^{-1}\, A_{\text{ap}}^{-1}\, F_{\text{filter}}^{-1}\, P_{\text{proj}}^{-1}\, H_{\text{HWP}}^{-1}\, P_{\text{pol}}^{-1}\, I_{\text{det}}^{-1}\, T_{\text{det}}^{-1}\, R_{\text{det}}^{-1}\, d
$$

# Bolometer response (M)

In [None]:
def op_bolometer_response(qubic_instrument, qubic_acquisition, tau=None):
    """
    Build the operator modelling the bolometer time response.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        Instrument providing the detectors (and their default time constant).
    qubic_acquisition : QubicAcquisition
        Acquisition providing the sampling (its length and period).
    tau : float, optional
        Bolometer time constant. Defaults to ``qubic_instrument.detector.tau``.

    Returns
    -------
    Operator
        An identity on (n_detectors, n_samples) when the sampling period is
        zero, otherwise a truncated-exponential convolution with time constant
        expressed in units of the sampling period.
    """
    time_constant = qubic_instrument.detector.tau if tau is None else tau
    period = qubic_acquisition.sampling.period
    tod_shape = (len(qubic_instrument), len(qubic_acquisition.sampling))

    # A zero sampling period leaves nothing to convolve: pass the TOD through.
    if period == 0:
        return IdentityOperator(tod_shape)

    return ConvolutionTruncatedExponentialOperator(
        time_constant / period, shapein=tod_shape)

In [None]:
# Display the operator (last expression of the cell) as a quick sanity check.
op_bolometer_response(q, acquisition)

# Transmission (F)

In [None]:
def op_transmission(qubic_instrument):
    """
    Return the operator that multiplies by the cumulative instrumental transmission.

    The transmission is the product of the transmissions of all optical
    components, multiplied by the detector efficiency. The value is computed
    when the operator is built, so changing ``qubic_instrument.detector.efficiency``
    afterwards requires building a new operator.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.

    Returns
    -------
    DiagonalOperator
        The transmission operator, broadcast 'rightward' over the trailing axes
        of its input (presumably the time axis of a (detectors, samples) TOD).
    """
    # np.prod instead of np.product: the latter is deprecated since NumPy 1.25
    # and removed in NumPy 2.0.
    optics_transmission = np.prod(qubic_instrument.optics.components['transmission'])
    transmission = optics_transmission * qubic_instrument.detector.efficiency
    return DiagonalOperator(transmission, broadcast='rightward')

In [None]:
# Display the operator (last expression of the cell) as a quick sanity check.
op_transmission(q)

# Flux Density Integration (F)

In [None]:
def op_detector_integration(qubic_instrument):
    """
    Weight each detector by its solid angle and the secondary-beam response.

    Every detector gets the factor
    ``sr_det / secondary_beam.solid_angle * secondary_beam(theta, phi)``,
    where (theta, phi) are the spherical angles of the detector centre and
    ``sr_det`` is its solid angle.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.

    Returns
    -------
    DiagonalOperator
        Diagonal operator holding one weight per detector, broadcast 'rightward'.
    """
    detector = qubic_instrument.detector
    centres = detector.center
    pixel_area = detector.area
    beam = qubic_instrument.secondary_beam

    # Spherical angles of the detector centres.
    radial = np.sqrt(np.sum(centres[..., :2] ** 2, axis=-1))
    theta = np.arctan2(radial, centres[..., 2])
    phi = np.arctan2(centres[..., 1], centres[..., 0])

    # Detector solid angle: area / z**2 with a cos(theta)**3 obliquity factor.
    # NOTE(review): the minus sign presumably compensates a negative z of the
    # focal plane — confirm.
    solid_angle_det = -pixel_area / centres[..., 2] ** 2 * np.cos(theta) ** 3
    weights = solid_angle_det / beam.solid_angle * beam(theta, phi)
    return DiagonalOperator(weights, broadcast='rightward')

In [None]:
# Display the operator (last expression of the cell) as a quick sanity check.
op_detector_integration(q)

In [None]:
# Take the operand at index 1 of the full acquisition operator H.
# NOTE(review): presumably equivalent to op_detector_integration(q) (the
# commented alternative) — confirm.
D = H.operands[1]
#D = op_detector_integration(q)

# Polarizer (M)

In [None]:
def op_polarizer(qubic_instrument, qubic_acquisition, qubic_scene):
    """
    Build the operator for the polarizer grid.

    Intensity-only scenes: with the grid the signal is halved; without it,
    detectors of the first focal plane get a transmission of 1 and those of
    the other focal plane get 0. Polarized scenes require the grid and split
    the signal onto the two focal planes.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.
    qubic_acquisition : QubicAcquisition
        An instance of the QubicAcquisition class.
    qubic_scene : QubicScene
        The observed scene; its ``kind`` selects the branch.

    Returns
    -------
    Operator
        The polarizer operator.

    Raises
    ------
    NotImplementedError
        For a polarized scene when the instrument has no polarizer grid.
    """
    n_det = len(qubic_instrument)
    n_time = len(qubic_acquisition.sampling)
    # Focal-plane index per detector: 0 for quadrants 1-4, 1 for quadrants 5-8
    # (assuming 1-based quadrant numbers).
    focal_plane = (qubic_instrument.detector.quadrant - 1) // 4
    has_polarizer = qubic_instrument.optics.polarizer

    if qubic_scene.kind == 'I':
        if has_polarizer:
            return HomothetyOperator(1 / 2)
        # Without the grid only the first focal plane receives the signal.
        return DiagonalOperator(1 - focal_plane, shapein=(n_det, n_time), broadcast='rightward')

    if not has_polarizer:
        raise NotImplementedError('Polarized input is not handled without the polarizer grid.')

    # One 1x3 weight row per detector acting on the last (Stokes, presumably
    # I/Q/U) axis: 0.5, 0.5 - focal_plane, 0. Shape after indexing: (n_det, 1, 1, 3).
    zeros = np.zeros(n_det)
    weights = np.array([zeros + 0.5, 0.5 - focal_plane, zeros]).T[:, None, None, :]
    drop_stokes_axis = ReshapeOperator((n_det, n_time, 1), (n_det, n_time))
    return drop_stokes_axis * DenseBlockDiagonalOperator(weights, shapein=(n_det, n_time, 3))

In [None]:
# Development aid: print the pyoperators source of the operators used in op_polarizer.
print(inspect.getsource(ReshapeOperator))

In [None]:
print(inspect.getsource(DenseBlockDiagonalOperator))

# Half-wave plate (F)

In [None]:
def op_hwp(qubic_instrument, qubic_acquisition, qubic_scene):
    """
    Build the half-wave-plate rotation operator.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.
    qubic_acquisition : QubicAcquisition
        Provides the sampling and its ``angle_hwp`` values (in degrees).
    qubic_scene : QubicScene
        The observed scene; its ``kind`` selects the operator.

    Returns
    -------
    Operator
        Identity for 'I' scenes; a 2D rotation of the trailing 2-vector for
        'QU' scenes; otherwise a 3D rotation about axis 'X' of the trailing
        3-vector. The rotation angle is -4 times the HWP angle.
    """
    tod_shape = (len(qubic_instrument), len(qubic_acquisition.sampling))
    kind = qubic_scene.kind

    if kind == 'I':
        return IdentityOperator(shapein=tod_shape)

    angle = -4 * qubic_acquisition.sampling.angle_hwp
    if kind == 'QU':
        return Rotation2dOperator(angle, degrees=True, shapein=tod_shape + (2,))
    return Rotation3dOperator('X', angle, degrees=True, shapein=tod_shape + (3,))

In [None]:
# Display the HWP operator (last expression of the cell).
op_hwp(q, acquisition, s)

In [None]:
#print(inspect.getsource(Rotation3dOperator))

# Half-wave plate + Polarizer

In [None]:
# Polarizer applied after the HWP.
# NOTE(review): calling a pyoperators operator on another operator is expected
# to return their composition (like `op_polarizer(...) * op_hwp(...)`) — confirm.
HWPol = op_polarizer(q, acquisition, s)(op_hwp(q, acquisition, s))

In [None]:
op_polarizer(q, acquisition, s)

In [None]:
HWPol

# Projection (M)

In [None]:
# Last operand of the full acquisition operator H, used in this notebook as
# the projection (sky -> TOD) operator — presumably the pointing matrix; confirm.
P = H.operands[-1]

In [None]:
#print(inspect.getsource(ProjectionOperator))

# Filter (F)

In [None]:
def op_filter(qubic_instrument):
    """
    Convert units from W/Hz to W by multiplying by the filter bandwidth.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.

    Returns
    -------
    Operator
        ``HomothetyOperator(bandwidth)``, or an identity when the bandwidth is 0.
    """
    bandwidth = qubic_instrument.filter.bandwidth
    if bandwidth != 0:
        return HomothetyOperator(bandwidth)
    # Zero bandwidth: leave the signal unchanged.
    return IdentityOperator()

In [None]:
# Display the filter operator (last expression of the cell).
op_filter(q)

# Aperture Integration (F)

In [None]:
def op_aperture_integration(qubic_instrument):
    """
    Integrate the flux density over the telescope aperture (W / m^2 / Hz -> W / Hz).

    The aperture is the number of open horns times the area of a disc of
    radius ``horn.radeff``.

    Parameters
    ----------
    qubic_instrument : QubicInstrument
        An instance of the QubicInstrument class.

    Returns
    -------
    HomothetyOperator
        Multiplication by the total open-horn area.
    """
    horns = qubic_instrument.horn
    n_open = np.sum(horns.open)
    return HomothetyOperator(n_open * np.pi * horns.radeff ** 2)

In [None]:
# Display the aperture-integration operator (last expression of the cell).
op_aperture_integration(q)

# Atmosphere transmission (F)

In [None]:
def op_atmosphere(qubic_acquisition):
    """
    Return the atmospheric transmission of the acquisition's scene.

    Note: this is the raw ``scene.atmosphere.transmission`` value, not wrapped
    in an operator (unlike the other ``op_*`` helpers).

    Parameters
    ----------
    qubic_acquisition : QubicAcquisition
        Acquisition whose ``scene.atmosphere`` is read.
    """
    atmosphere = qubic_acquisition.scene.atmosphere
    return atmosphere.transmission

In [None]:
# Display the atmospheric transmission (last expression of the cell).
op_atmosphere(acquisition)

# Unit conversion (F)

In [None]:
def op_unit_conversion(qubic_instrument, qubic_scene):
    """
    Convert sky temperature into W / m^2 / Hz at the filter frequency.

    Delegates to ``qubic_scene.get_unit_conversion_operator`` with
    ``qubic_instrument.filter.nu``. If the scene was initialised with the
    'absolute' keyword, it includes the CMB background plus fluctuations (in
    Kelvin) and the conversion follows the non-linear Planck law; otherwise
    it only holds fluctuations (in microKelvin) and the conversion is linear.
    """
    return qubic_scene.get_unit_conversion_operator(qubic_instrument.filter.nu)

In [None]:
# Display the unit-conversion operator (last expression of the cell).
op_unit_conversion(q,s)

## Sequential operator combinations for training

In [None]:
# Apply the forward operators one by one to the convolved sky map; each new
# name prefixes the letter of the operator just applied.
Us = op_unit_conversion(q,s)(convolved_maps)

#TUs = op_atmosphere(acquisition)(Us)
# not implemented, so nothing changes at this step
TUs = Us

ATUs = op_aperture_integration(q)(TUs)
#ATUs.shape

FATUs = op_filter(q)(ATUs)
print(FATUs.shape)

In [None]:
# Projection onto the TOD with the last operand of H (the operator bound to P
# in the projection section), then the HWP rotation.
PFATUs = H.operands[-1](FATUs)
print(PFATUs.shape)

HPFATUs = op_hwp(q, acquisition, s)(PFATUs)
#HPFATUs.shape

In [None]:
# Polarizer grid; the shape is displayed as the last expression.
PHPFATUs = op_polarizer(q, acquisition, s)(HPFATUs)
PHPFATUs.shape

In [None]:
# Detector integration, then transmission, then bolometer response.
# APHPFATUs (before transmission) is the "sky" input of the training set below.
APHPFATUs = op_detector_integration(q)(PHPFATUs)
#APHPFATUs.shape

TAPHPFATUs = op_transmission(q)(APHPFATUs)

RTAPHPFATUs = op_bolometer_response(q, acquisition)(TAPHPFATUs)

## Apply the transmission operator to a TOD to get both datasets (before-after application)

In [None]:
# Single TOD before/after transmission. The operator uses the detector
# efficiency stored in q at the moment op_transmission(q) is called.
transmission_operator = op_transmission(q)

#original_tod = torch.randn(num_detectors, num_pointings, dtype=torch.float32)
original_tod = torch.tensor(APHPFATUs, dtype=torch.float64)

tod_after_transmission = torch.tensor(transmission_operator(original_tod.detach().cpu().numpy()), dtype=torch.float64)

In [None]:
# Build a small training set: the same sky-side TOD (APHPFATUs) attenuated
# with two detector efficiencies, `copies_per_efficiency` copies each.
# op_transmission computes the efficiency-dependent array when it is called,
# so the operator is rebuilt after every efficiency change; reusing an operator
# built earlier would silently apply a stale efficiency.
efficiencies = (0.8, 0.7)
copies_per_efficiency = 5

tod_before_transmission_list = []
tod_after_transmission_list = []
for efficiency in efficiencies:
    q.detector.efficiency = np.full(q.detector.efficiency.shape, efficiency, dtype=np.float64)
    transmission_operator = op_transmission(q)
    for _ in range(copies_per_efficiency):
        original_tod = APHPFATUs
        tod_before_transmission_list.append(original_tod)
        tod_after_transmission_list.append(transmission_operator(original_tod))

# Stack into single arrays first: torch.tensor on a list of ndarrays is slow.
# Resulting shape: (len(efficiencies) * copies_per_efficiency, D, Nt).
tod_before_transmission_torch = torch.tensor(np.stack(tod_before_transmission_list), dtype=torch.float64)
tod_after_transmission_torch = torch.tensor(np.stack(tod_after_transmission_list), dtype=torch.float64)

    
    

# Monochromatic Pyro example for transmission:

In [None]:

import torch, pyro, pyro.distributions as dist
from pyro.nn   import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
import matplotlib.pyplot as plt
import numpy as np, math
pyro.set_rng_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

# N is the batch of TODs, D the detectors (992), Nt the time samples
det_tod = tod_after_transmission_torch.to(torch.float32).to(device)   # shape (N, D, Nt)
sky_tod = tod_before_transmission_torch.to(torch.float32).to(device)  # same shape
N, D, Nt = det_tod.shape


class InvEtaPerSample(PyroModule):
    """Divide each TOD by a throughput eta, with one log-normal eta per sample.

    The detected TOD was produced by op_transmission, so eta estimates the
    optics transmission times the detector efficiency.
    """

    def __init__(self, rel_sigma=0.05):
        super().__init__()
        # Prior parameters are buffers so `.to(device)` moves them; a prior
        # built from Python floats stays on the CPU and its samples cannot be
        # combined with CUDA TODs.
        self.register_buffer("prior_loc", torch.tensor(0.0))
        self.register_buffer("prior_scale", torch.tensor(float(rel_sigma)))
        self.log_eta = PyroSample(lambda self: dist.Normal(self.prior_loc, self.prior_scale))

    def forward(self, tod_det):          # tod_det : (N, D, Nt)
        eta = torch.exp(self.log_eta)    # (N,) -- one value per sample via the enclosing plate
        eta = eta[:, None, None]         # (N, 1, 1), broadcast over detectors and time
        return tod_det / eta             # undo the throughput


layer = InvEtaPerSample(rel_sigma=0.20).to(device)   # 20 % width prior

sigma_noise = 1e-18   # assumed white-noise level (a small value)

def model(det, sky):
    with pyro.plate("batch", det.size(0)):           # one eta per sample
        sky_hat = layer(det)
        pyro.sample("obs",
                    dist.Normal(sky_hat, sigma_noise).to_event(2),
                    obs=sky)




Optimizer, guide, and SVI

In [None]:
# Automatic Normal guide (AutoNormal) over the latent sites of `model`.
guide = pyro.infer.autoguide.AutoNormal(model)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module imported
# at the top; later cells use it as the Pyro optimizer (get_state/set_state).
optim  = Adam({"lr": 3e-3})         

svi   = SVI(model, guide, optim, Trace_ELBO())

In [None]:
# Checkpoint file and training-loop settings.
ckpt_file   = "invT_svi_ckpt.pt"
save_every  = 20          # how often to checkpoint
n_steps     = 400        # total extra steps to run

In [None]:
# First step index used when no checkpoint exists; a found checkpoint
# overrides it with its saved step + 1.
start = 200

In [None]:
import os

In [None]:
# Optional checkpointing: lets training be split over several sessions, since
# Pyro can be memory-hungry. Resumes from ckpt_file when it exists.
if os.path.exists(ckpt_file):
    # The param-store state contains constraint objects (not only tensors), so
    # it cannot be read with weights_only=True, the torch.load default since
    # PyTorch 2.6. Only load checkpoints you created yourself.
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
    pyro.get_param_store().set_state(ckpt["param_store"])
    optim.set_state(ckpt["optim_state"])
    start = ckpt["step"] + 1
    print(f"✓ Resuming from step {start}")


def _save_checkpoint(step):
    """Write the step index, param-store state and optimizer state to ckpt_file."""
    torch.save(
        {
            "step":        step,
            "param_store": pyro.get_param_store().get_state(),
            "optim_state": optim.get_state(),
        },
        ckpt_file,
    )


last_step = None
for step in range(start, start + n_steps):
    loss = svi.step(det_tod, sky_tod)   # svi.step returns the loss, i.e. -ELBO
    last_step = step

    if step % save_every == 0:
        _save_checkpoint(step)
        print(f"step {step:5d} | loss (-ELBO) {loss:8.3g}  ➜ checkpoint saved")

# Always save the final state, otherwise steps after the last multiple of
# save_every would be lost when resuming.
if last_step is not None and last_step % save_every != 0:
    _save_checkpoint(last_step)

print("Done ✔")


In [None]:
# A few extra (non-checkpointed) SVI steps on top of the run above. Uses its
# own step count so the global n_steps used by the checkpoint cell is untouched.
n_extra_steps = 10
print_every = 5
for extra_step in range(1, n_extra_steps + 1):
    loss = svi.step(det_tod, sky_tod)
    if extra_step % print_every == 0 or extra_step == n_extra_steps:
        print(f"SVI step {extra_step:4d}  loss (-ELBO) = {loss:.3g}")

In [None]:
# 30 posterior draws of log_eta (obs is resampled since sky=None).
predictive = Predictive(
    model,
    guide=guide,
    num_samples=30,
    return_sites=["log_eta"],
)
post = predictive(det_tod, sky=None)
eta_samples = post["log_eta"].exp()        # (30, N)

eta_mean = eta_samples.mean(dim=0).cpu()   # (N,)
eta_std = eta_samples.std(dim=0).cpu()     # (N,)

In [None]:
# Per-TOD posterior of eta. The model divides the detected TOD by eta, and that
# TOD was produced by op_transmission (optics transmission x efficiency), so
# eta estimates the full throughput; the reference lines are scaled accordingly.
# Loop variables do not reuse `s`, which holds the QubicScene.
optics_transmission = float(np.prod(q.optics.components['transmission']))

for sample_idx, (eta_m, eta_s) in enumerate(zip(eta_mean, eta_std)):
    print(f"SAMPLE {sample_idx:2d}  η = {eta_m:.3f} ± {eta_s:.3f}")

plt.errorbar(range(N), eta_mean, yerr=eta_std,
             fmt='o', capsize=3, color='tab:blue')
# Expected throughput for the efficiencies used to build the training set.
for true_efficiency in (0.8, 0.7):
    plt.axhline(true_efficiency * optics_transmission, ls='--', c='k', lw=1)
plt.ylabel("throughput  η  (optics × detector efficiency)")
plt.xlabel("TOD sample ID")
plt.title("Posterior mean ±1σ for each training TOD"); plt.show()

This is enough to obtain the posterior mean of η for each TOD and its uncertainty.
To look at the full posterior and interpret it, we need to draw samples.

# Drawing samples

In [None]:
def posterior_mean_std_and_samples(model, guide, det_tod,
                                   n_draw=60, chunk=10,
                                   sample_site="log_eta",
                                   keep_draws=False):
    """
    Stream posterior draws in chunks and accumulate their mean and std.

    Each posterior draw of ``sample_site`` is exponentiated and averaged over
    dim 1 (the batch plate: one entry per TOD in ``det_tod``), giving one value
    per draw. Mean and sample standard deviation across draws are accumulated
    with Welford's online algorithm, so only ``chunk`` draws are held at once.

    Parameters
    ----------
    model, guide : callable
        Pyro model and fitted guide; called as ``model(det_tod, sky=None)``.
    det_tod : torch.Tensor
        Detected TODs passed to the model.
    n_draw : int
        Total number of posterior draws (at least 2, needed for a std).
    chunk : int
        Draws per Predictive call. If ``n_draw`` is not a multiple of
        ``chunk``, the last call uses the remainder.
    sample_site : str
        Latent site to sample; assumed to hold a logarithm.
    keep_draws : bool
        If True, also return every per-draw value (on the CPU). This may be
        large, depending on n_draw.

    Returns
    -------
    (mean, std) or (mean, std, draws)
        CPU tensors (0-dim when the site has shape (draws, N)); ``draws`` is
        the concatenation of all per-draw values, shape (n_draw, ...).

    Raises
    ------
    ValueError
        If ``chunk < 1`` or ``n_draw < 2``.
    """
    if chunk < 1:
        raise ValueError(f"chunk must be >= 1, got {chunk}")
    if n_draw < 2:
        raise ValueError(f"n_draw must be >= 2 to estimate a std, got {n_draw}")

    mean = m2 = None
    n_seen = 0
    all_draws = [] if keep_draws else None

    remaining = n_draw
    while remaining > 0:
        n_this = min(chunk, remaining)
        pred = Predictive(model, guide=guide, num_samples=n_this,
                          return_sites=[sample_site])
        post = pred(det_tod, sky=None)
        # exp of the log-site, averaged over the batch dimension -> (n_this, ...)
        draws = torch.exp(post[sample_site]).mean(1)

        if keep_draws:
            all_draws.append(draws.cpu())

        # Welford update of the running mean and sum of squared deviations.
        for x in draws:
            n_seen += 1
            if mean is None:
                mean = torch.zeros_like(x)
                m2 = torch.zeros_like(x)
            delta = x - mean
            mean += delta / n_seen
            m2 += delta * (x - mean)

        del post, draws
        remaining -= n_this

    std = torch.sqrt(m2 / (n_seen - 1))
    if keep_draws:
        all_draws = torch.cat(all_draws, dim=0)   # (n_draw, ...)
        return mean.cpu(), std.cpu(), all_draws
    return mean.cpu(), std.cpu()


In [None]:
# 300 posterior draws in chunks of 10, keeping every per-draw (batch-averaged) eta.
eta_mean, eta_std, eta_samples = posterior_mean_std_and_samples(
        model, guide, det_tod,
        n_draw=300, chunk=10, keep_draws=True)

In [None]:
# Disabled (`if False`): single-call Predictive alternative to the chunked
# sampling above, kept for reference only.
if False:
    predictive  = Predictive(model, guide=guide, num_samples=50,
                            return_sites=["log_eta"])
    post         = predictive(det_tod, sky=None)
    eta_samples  = torch.exp(post["log_eta"])            # (50, N)

    eta_mean = eta_samples.mean(0).cpu()                 # (N,)
    eta_std  = eta_samples.std (0).cpu()

In [None]:
from scipy.stats import norm

In [None]:

det_id = 0

# eta_samples comes from posterior_mean_std_and_samples(..., keep_draws=True):
# one batch-averaged eta per posterior draw, i.e. global statistics shared by
# all detectors, not per-detector ones.
mu_det  = eta_samples.mean().item()
sig_det = eta_samples.std().item()

tod_det  = det_tod[0, det_id].cpu()              # shape (Nt,)
# The model reconstructs the sky as det / eta, so the point-estimate correction
# divides by the posterior mean (multiplying would attenuate the TOD again).
# mu_det is averaged over all TODs, so this is approximate for a TOD whose own
# eta differs from the batch mean.
tod_corr = tod_det / mu_det
tod_true = sky_tod[0, det_id].cpu()

Nt = tod_det.numel()
t  = np.arange(Nt)                                # over time samples


fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].hist(eta_samples.numpy(), bins=40, alpha=0.2,
           color='steelblue', density=True, label='posterior samples')

xmin, xmax = ax[0].get_xlim()
xs = np.linspace(xmin, xmax, 200)
ax[0].plot(xs, norm.pdf(xs, mu_det, sig_det), 'k--',
           label=r'$\mathcal{N}(\mu,\sigma)$ fit')
ax[0].axvline(mu_det, color='k')
ax[0].axvspan(mu_det - sig_det, mu_det + sig_det,
              alpha=0.2, color='orange', label=r'$\pm1\sigma$')
ax[0].set_xlabel('batch-mean η')
ax[0].set_ylabel('density')
ax[0].set_title('Posterior of η')
ax[0].legend()

ax[1].plot(t, tod_det,  lw=0.8, label='det TOD (attenuated)', alpha=0.4)
ax[1].plot(t, tod_corr, lw=0.8, label='corrected (det / mean η)', color='C3', alpha=0.4)
ax[1].plot(t, tod_true, '--', lw=1.0, label='sky TOD (target)', color='k', alpha=0.4)
ax[1].set_xlabel('time sample')
ax[1].set_ylabel('power')
ax[1].set_title('Example TOD – before / after correction')
ax[1].legend()

plt.tight_layout()
plt.show()


In [None]:

import torch, pyro, pyro.distributions as dist
from pyro.nn   import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

det_tod = tod_after_transmission_torch.to(torch.float32).to(device)
sky_tod = tod_before_transmission_torch.to(torch.float32).to(device)

# Expected (N, D, Nt); an earlier note said (10, 992, 8000) — TODO confirm Nt
# against d['npointings'].
N, D, Nt = det_tod.shape


class InvT_perSample(PyroModule):
    """Undo the optics transmission and a per-sample inverse efficiency invT.

    forward returns det * invT / T_optics. Since the detected TOD is
    sky * T_optics * efficiency, invT estimates 1 / efficiency.
    """

    def __init__(self, T_optics, rel_sigma=0.05):
        super().__init__()
        # Prior on log(invT), centred at log(1) = 0. Stored as float32 buffers so
        # `.to(device)` moves them and samples match the float32 TODs (np.log
        # returns a float64, which would otherwise promote the output).
        self.register_buffer("prior_loc", torch.tensor(float(np.log(1.0)), dtype=torch.float32))
        self.register_buffer("prior_scale", torch.tensor(float(rel_sigma), dtype=torch.float32))
        self.log_invT = PyroSample(lambda self: dist.Normal(self.prior_loc, self.prior_scale))

        self.register_buffer("T_optics", torch.tensor(T_optics, dtype=torch.float32))

    def forward(self, tod_det):
        invT = torch.exp(self.log_invT)          # (N,), one per sample via the enclosing plate
        # broadcast: (N, 1, 1) * (N, D, Nt)
        return tod_det * invT[:, None, None] / self.T_optics

T_optics = float(np.prod(q.optics.components['transmission']))
layer    = InvT_perSample(T_optics, rel_sigma=0.10).to(device)

sigma_noise = 1e-18   # assumed white-noise level (a small value)

def model(det, sky):
    with pyro.plate("batch", det.size(0)):
        sky_hat = layer(det)
        pyro.sample("obs", dist.Normal(sky_hat, sigma_noise).to_event(2), obs=sky)


# Start from an empty param store so parameters from the previous (eta) fit or
# a resumed checkpoint do not leak into this one.
pyro.clear_param_store()
# The guide is built from the model so its latent sites match the model's.
guide = pyro.infer.autoguide.AutoNormal(model)
optim = Adam({"lr": 2e-3})
svi   = SVI(model, guide, optim, Trace_ELBO())

n_steps = 100
for step in range(n_steps):
    loss = svi.step(det_tod, sky_tod)   # loss = -ELBO estimate
    if step % 20 == 0:
        print(f"SVI step {step:4d} | loss (-ELBO) {loss:.4g}")



In [None]:
# 150 posterior draws for the first TOD only (batch of size 1); obs is
# resampled because sky=None. Prints the names of the returned sites.
predictive = Predictive(model, guide=guide, num_samples=150)
post       = predictive(det_tod[:1], sky=None)               

print("available sites:", list(post.keys()))

In [None]:
# exp of the sampled log_invT: one inverse efficiency per TOD in the batch
# passed to Predictive, for every posterior draw.
invT_samples = torch.exp(post["log_invT"])                    # (Ndraw, N_batch)

# invT is sampled once per TOD sample (plate "batch"), not per bolometer, so
# summarise along the batch axis. reshape keeps a 1-D result even for a batch
# of one (indexing [0] gave a 0-d tensor, on which len() fails).
invT_flat = invT_samples.reshape(invT_samples.shape[0], -1)   # (Ndraw, N_batch)
invT_mean = invT_flat.mean(0).cpu()                           # (N_batch,)
invT_std  = invT_flat.std(0).cpu()

plt.errorbar(range(len(invT_mean)), invT_mean, yerr=invT_std, fmt='.', alpha=0.6)
plt.xlabel("TOD sample"); plt.ylabel("inverse efficiency invT")
plt.title("Posterior mean ±1σ per TOD sample")
plt.show()

In [None]:
# Collapse the batch dimension: invT is sampled once per TOD sample (not per
# detector), so average over that axis to get one value per posterior draw.
# reshape (instead of squeeze + mean(dim=1), which fails on the squeezed 1-D
# tensor) works for any batch size and leaves invT_samples unchanged, so the
# cell can be re-run.
invT_per_draw = invT_samples.reshape(invT_samples.shape[0], -1)   # (Ndraw, N_batch)
global_samples = invT_per_draw.mean(dim=1)                       # (Ndraw,)

# posterior mean and std of the shared inverse efficiency
mean = global_samples.mean().item()
std  = global_samples.std().item()

print(f"Global inverse efficiency  = {mean:.4f} ± {std:.4f}")
print("Estimated efficiency η     = 1 / invT  (the layer already divides by T_optics)")

import matplotlib.pyplot as plt
plt.hist(global_samples.cpu().numpy(), bins=40, alpha=0.7, color='steelblue')
plt.axvline(mean, color='k')
plt.axvspan(mean-std, mean+std, alpha=0.2, color='orange')
plt.xlabel("inverse efficiency invT (shared)"); plt.ylabel("posterior draws")
plt.title("Posterior of the global inverse efficiency")
plt.show()

In [None]:
# Global (batch-averaged) detector efficiency. InvT_perSample already divides
# the TOD by T_optics, so invT estimates 1/η and each draw gives η = 1 / invT.
# Multiplying by T_optics again would count the optics transmission twice.
eta_samples = 1.0 / global_samples

eta_mean = eta_samples.mean().item()
eta_std  = eta_samples.std().item()

print(f"Posterior efficiency η  = {eta_mean:.3f} ± {eta_std:.3f}")

import matplotlib.pyplot as plt
plt.hist(eta_samples.cpu().numpy(), bins=40, alpha=0.7, color='seagreen')
plt.axvline(eta_mean, color='k', lw=2)
plt.axvspan(eta_mean-eta_std, eta_mean+eta_std,
            alpha=0.2, color='orange', label="±1 σ")
plt.xlabel("detector efficiency η"); plt.ylabel("posterior draws")
plt.title("Posterior of global detector efficiency")
plt.legend(); plt.show()
