In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import healpy as hp

from qubic.lib.MapMaking.Qatmosphere_2d import AtmosphereMaps

import yaml

import qubic
from qubic.lib.Instrument.Qacquisition import QubicInstrumentType

from qubic.lib.MapMaking.Qcg import PCGAlgorithm
from pyoperators.iterative.core import AbnormalStopIteration

from pyoperators import MPI, ReshapeOperator, BlockDiagonalOperator, DiagonalOperator, IdentityOperator


# MPI communicator used throughout the notebook; it is handed to the PCG solver further down.
comm = MPI.COMM_WORLD
# Rank of the current process inside that communicator.
rank = comm.Get_rank()

%matplotlib inline

In [None]:
# Read the simulation settings from the YAML file located in the working directory.
with open("params.yml", mode="r") as file:
    params = yaml.safe_load(file)

In [None]:
# Instantiate the atmosphere simulator. Its `qubic_dict` attribute is the
# instrument dictionary later passed to the acquisition operators.
atm = AtmosphereMaps(params)
qubic_dict = atm.qubic_dict

# Patch centre converted with equ2gal; the result is used as the `rot`
# argument of the gnomonic views at the end of the notebook.
# NOTE(review): presumably equatorial coordinates in degrees — confirm.
center = np.array([0, -57])
qubic_patch = qubic.lib.Qsamplings.equ2gal(*center)

In [None]:
# Override two entries of the instrument dictionary. `qubic_dict` is the same
# object as `atm.qubic_dict`, so the acquisition operators built later from
# `atm.qubic_dict` see these values.
qubic_dict["instrument_type"] = "DB"  # presumably "dual band" — TODO confirm
qubic_dict["interp_projection"] = False 

In [None]:
# Atmospheric absorption spectrum, plotted against the integration frequency grid.
abs_spectrum = atm.absorption_spectrum()

plt.plot(atm.integration_frequencies, abs_spectrum)
plt.ylabel(r"Absorption ($m^{2}/g$)")
plt.xlabel("Frequency (GHz)")
plt.ylim(0, 2e-4)
plt.title("Atmospheric Absorption Spectrum")

In [None]:
# Atmosphere cube indexed as (frequency, pixel, Stokes). Only the first Stokes
# column (intensity) is filled; the two other columns stay at zero.
map_shape = (len(atm.frequencies), hp.nside2npix(params["nside"]), 3)
atm_maps = np.zeros(map_shape)
atm_maps[:, :, 0] = atm.get_temp_maps(atm.delta_rho_map)

# Full-sky view of the intensity map of one input frequency.
index_nu = 0
hp.mollview(
    atm_maps[index_nu, :, 0],
    title=f"Atmosphere map {atm.frequencies[index_nu]:.2f} GHz",
    unit="µK_CMB",
    cmap="jet",
)


# Band-integrated absorption spectrum with its frequencies and bandwidths.
integrated_abs_spectrum, frequencies, bandwidth = atm.integrated_absorption_spectrum()

In [None]:
# Mean intensity of every atmosphere map.
# Only the I column (index 0) is averaged: the Q and U columns of `atm_maps`
# are identically zero, so averaging the whole (npix, 3) slab would divide the
# mean temperature by three.
mean_atm_maps = [np.mean(atm_maps[i, :, 0]) for i in range(atm_maps.shape[0])]

plt.figure()
plt.plot(frequencies, mean_atm_maps, '.')
plt.title('Atmosphere maps spectrum')
plt.xlabel('Frequency (GHz)')
plt.ylabel(r'Mean temperature ($\mu K_{CMB}$)')

# Band-integrated absorption spectrum, shown for comparison with the map means.
plt.figure()
plt.plot(frequencies, integrated_abs_spectrum, '.')
plt.xlabel('Frequency (GHz)')
plt.ylabel(r'Integrated absorption spectrum ($m^{2}/g$)')
plt.title('Integrated absorption spectrum')

In [None]:
# Reference maps at the reconstruction resolution: each reconstructed band is
# the average of `fsub` consecutive input maps (intensity column only).
fsub = params["nsub_in"] // params["nrec"]
true_maps = np.zeros((params["nrec"], hp.nside2npix(params["nside"]), 3))
for i in range(params["nrec"]):
    band = slice(i * fsub, (i + 1) * fsub)
    true_maps[i, :, 0] = np.mean(atm_maps[band, :, 0], axis=0)

# Map-making

In [None]:
# Acquisition used to simulate the data: one reconstructed band per input
# sub-band (nrec == nsub), so it can be applied directly to `atm_maps`.
tod_acquisition = QubicInstrumentType(
    atm.qubic_dict,
    nsub=params["nsub_in"],
    nrec=params["nsub_in"],
)
H_tod = tod_acquisition.get_operator()

# Simulated time-ordered data, flattened to one dimension.
tod = H_tod(atm_maps).ravel()

# Acquisition used for the reconstruction, with `nrec` reconstructed bands.
Qacq = QubicInstrumentType(
    atm.qubic_dict,
    nsub=params["nsub_in"],
    nrec=params["nrec"],
)

In [None]:
# TODO: try to use a diagonal operator or remove the QUBIC noise.
# The inverse noise covariance is the identity for now.
invN = IdentityOperator()

# Reconstruction operator. When exactly two bands are reconstructed, only the
# second operand of the acquisition operator is kept.
H_rec = Qacq.get_operator()
if params["nrec"] == 2:
    H_rec = H_rec.operands[1]

In [None]:
# Coverage of the reconstruction acquisition, normalised to its maximum.
coverage = Qacq.coverage
covnorm = coverage / np.max(coverage)

# Boolean mask of the pixels whose normalised coverage exceeds the configured cut.
seenpix = np.greater(covnorm, params["coverage_cut"])

In [None]:
# Sanity check: input/output shapes of the operators and length of the TOD.
print(f"H_rec {H_rec.shapein} {H_rec.shapeout}")
print(f"invN {invN.shapein} {invN.shapeout}")
print(f"TOD {tod.shape}")

In [None]:
# Ingredients of the linear system A x = b solved by the PCG.
# R reshapes the flat TOD to the output shape of the reconstruction operator.
R = ReshapeOperator(tod.shape, H_rec.shapeout)

# Shared left factor H^T N^-1 of both sides of the normal equations.
Ht_invN = H_rec.T * invN
A = Ht_invN * H_rec
b = Ht_invN * R(tod)

# Start the solver from an all-zero map with the shape of the reference maps.
x0 = 0.0 * true_maps

In [None]:
# Sanity check: shapes of the system operator, right-hand side and initial guess.
print(f"A {A.shapein} {A.shapeout}")
print(f"b {b.shape}")
print(f"x0 {x0.shape}")

In [None]:
# Report how the reconstruction operator is nested (bands / reconstructed
# sub-bands / input sub-bands).
#
# The operator inspected in the two-band case is bound to a local name instead
# of overwriting `H_rec`: `A` and `b` were already built from `H_rec`, and
# rebinding it here made the cell descend one level further on every re-run.
if params['nrec'] != 2:
    # operands[0] = 150 GHz / operands[1] = 220 GHz
    print("Number of Physical Bands :", len(H_rec.operands))
    print("Number of Reconstructed Sub-Bands within each physical Bands :", len(H_rec.operands[0].operands))
    print("Number of Sub-Bands within each reconstructed bands :", len(H_rec.operands[0].operands[0].operands))
else:
    H_band = H_rec.operands[1]
    print("Number of Reconstructed Sub-Bands within each physical Bands :", len(H_band.operands))
    print("Number of Sub-Bands within each reconstructed bands :", len(H_band.operands[0].operands))

Note about the preconditioner: `stacked_dptdp_inv` should have the shape (Nrec, Npix). However, we compute it from H, which contains Nsub acquisition operators. In the next cell, I am using only the first Nrec operators rather than all Nsub, because I don't know how to reduce them.
I tried to compute it with another H that had exactly Nrec sub-operators, but it didn't work.
We still need to find a solution to this problem.

In [None]:
# preconditioner
#
# Builds a per-band, per-pixel diagonal operator from the diagonal of
# (D P)^T (D P) for every sub-band operator, and stores it in `preconditioner`.
# NOTE(review): the PCG call in this notebook passes M=None, so this operator
# is built but not used by the solver — confirm whether that is intended.

nrec = params["nrec"]
nside = params["nside"]
npix = 12 * nside**2  # number of HEALPix pixels for this nside
nsub = params["nsub_in"]  # not used in this cell
no_det = 992  # NOTE(review): hard-coded detector count; presumably must match len(D.data) — confirm

# One inverse-diagonal map per reconstructed band.
stacked_dptdp_inv = np.empty((nrec, npix))

q_acq = Qacq  # alias, not used in this cell


# Indexed below with irec * fsub + j_fsub, i.e. one entry per input sub-band.
H_qubic = Qacq.operator

# Scratch buffer holding the inverse diagonal of each sub-band of the current
# reconstructed band; `fsub` comes from the cell that builds `true_maps`.
stacked_dptdp_inv_nsub = np.empty((fsub, npix))

for irec in range(nrec):
    for j_fsub in range(fsub):
        H_single = H_qubic[irec * fsub + j_fsub]

        # NOTE(review): operands[1] is presumably a diagonal detector-weight
        # operator and operands[4] the pointing/projection operator — the
        # positions depend on how the acquisition operator is composed; confirm.
        D = H_single.operands[1]
        P = H_single.operands[4]
        sh = P.matrix.data.index.shape

        # Rows of the projection matrix are split evenly between detectors.
        point_per_det = sh[0] // no_det
        mapPtP_perdet_seq = np.empty((no_det, npix))

        for det in range(no_det):
            start, end = det * point_per_det, (det + 1) * point_per_det
            indices = P.matrix.data.index[start:end, :]
            weights = P.matrix.data.r11[start:end, :]
            flat_indices = indices.ravel()
            flat_weights = weights.ravel()

            # Accumulate the squared weights falling into each pixel index.
            # NOTE(review): np.bincount rejects negative indices; assumes the
            # projection matrix holds none — confirm.
            mapPitPi = np.bincount(flat_indices, weights=flat_weights**2, minlength=npix)
            mapPtP_perdet_seq[det, :] = mapPitPi

        # Scale each detector's map by its squared weight and sum over detectors.
        D_sq = D.data**2
        mapPtP_seq_scaled = D_sq[:, np.newaxis] * mapPtP_perdet_seq
        dptdp = mapPtP_seq_scaled.sum(axis=0)

        # Safe inversion
        # Pixels with a zero diagonal keep a zero inverse instead of dividing by zero.
        dptdp_inv = np.zeros_like(dptdp)
        nonzero = dptdp != 0
        dptdp_inv[nonzero] = 1.0 / dptdp[nonzero]
        stacked_dptdp_inv_nsub[j_fsub] = dptdp_inv

    # NOTE(review): this sums the inverses of the sub-band diagonals, which
    # differs from inverting the summed diagonal — confirm which is wanted.
    stacked_dptdp_inv[irec] = stacked_dptdp_inv_nsub.sum(axis=0)

# One DiagonalOperator per reconstructed band, stacked along a new leading axis.
preconditioner = BlockDiagonalOperator([DiagonalOperator(ci, broadcast="rightward") for ci in stacked_dptdp_inv], new_axisin=0)

In [None]:
# Configure the conjugate-gradient solver for A x = b.
# No preconditioner is supplied (M=None).
algo = PCGAlgorithm(
    A,
    b,
    comm,
    x0=x0,
    tol=1e-10,
    maxiter=200,
    disp=True,
    M=None,
    center=[0, -57],
    reso=15,
    seenpix=seenpix,
    input=true_maps,
)

# Run the solver; if it stops abnormally, keep the last state it reached and
# record the reason in `message`.
try:
    output = algo.run()
except AbnormalStopIteration as e:
    output = algo.finalize()
    success = False
    message = str(e)
else:
    success = True
    message = 'Success'

In [None]:
# Convergence criterion recorded by the solver, on a logarithmic scale.
plt.plot(output["convergence"])
plt.yscale("log")
plt.title("Polychromatic")
plt.xlabel("Iteration")
plt.ylabel("Convergence")

In [None]:
# Input / output / residual intensity maps, one row per reconstructed band.
plt.figure(figsize=(12, 12), dpi=200)
k=1

stk = ['I', 'Q', 'U']
istk = 0
n_sig = 3
reso = 15

nbands = output['x'].shape[0]
for inu in range(nbands):
    # Reconstructed band `inu` is the average of `fsub` input sub-bands, so it
    # is labelled with the mean of their frequencies rather than with the
    # inu-th input frequency.
    band_freq = np.mean(atm.frequencies[inu * fsub:(inu + 1) * fsub])

    # Input and output share the colour range of the input over the seen pixels.
    vmin = np.min(true_maps[inu, seenpix, istk])
    vmax = np.max(true_maps[inu, seenpix, istk])

    hp.gnomview(
        true_maps[inu, :, istk], min=vmin, max=vmax, cmap='jet', rot=qubic_patch,
        title='{} - Input - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k), notext=True,
    )
    hp.gnomview(
        output['x'][inu, :, istk], min=vmin, max=vmax, cmap='jet', rot=qubic_patch,
        title='{} - Output - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k+1), notext=True,
    )
    hp.gnomview(
        output['x'][inu, :, istk] - true_maps[inu, :, istk], cmap='jet', rot=qubic_patch,
        title='{} - Residual - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k+2), notext=True,
    )
    k+=3

In [None]:
# Same three-column view as the previous figure, with unseen pixels masked.
plt.figure(figsize=(12, 12), dpi=200)
k=1

# NOTE: this masks `true_maps` and `output['x']` in place; every later cell
# sees the masked arrays.
true_maps[:, ~seenpix, :] = hp.UNSEEN
output['x'][:, ~seenpix, :] = hp.UNSEEN

stk = ['I', 'Q', 'U']
istk = 0
n_sig = 3
reso = 15

nbands = output['x'].shape[0]
for inu in range(nbands):
    # Reconstructed band `inu` is the average of `fsub` input sub-bands, so it
    # is labelled with the mean of their frequencies rather than with the
    # inu-th input frequency.
    band_freq = np.mean(atm.frequencies[inu * fsub:(inu + 1) * fsub])

    # Input and output share the colour range of the input over the seen pixels.
    vmin = np.min(true_maps[inu, seenpix, istk])
    vmax = np.max(true_maps[inu, seenpix, istk])

    hp.gnomview(
        true_maps[inu, :, istk], min=vmin, max=vmax, cmap='jet', rot=qubic_patch,
        title='{} - Input - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k), notext=True,
    )
    hp.gnomview(
        output['x'][inu, :, istk], min=vmin, max=vmax, cmap='jet', rot=qubic_patch,
        title='{} - Output - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k+1), notext=True,
    )
    # The residual uses the same Stokes index on both sides (the reference map
    # was previously hard-coded to column 0).
    hp.gnomview(
        output['x'][inu, :, istk] - true_maps[inu, :, istk], cmap='jet', rot=qubic_patch,
        title='{} - Residual - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=reso, sub=(nbands, 3, k+2), notext=True,
    )
    k+=3

In [None]:
# Full-sky residual (output - input) of every reconstructed band, for the
# Stokes component selected by `istk` in an earlier cell.
for inu in range(output['x'].shape[0]):
    # Reconstructed band `inu` is the average of `fsub` input sub-bands, so it
    # is labelled with the mean of their frequencies rather than with the
    # inu-th input frequency.
    band_freq = np.mean(atm.frequencies[inu * fsub:(inu + 1) * fsub])
    hp.mollview(
        output['x'][inu, :, istk] - true_maps[inu, :, istk],
        cmap='jet',
        title='{} - Residual - {:.2f} GHz'.format(stk[istk], band_freq),
    )

In [None]:
# Input / output / residual maps for the second Stokes component (Q), one row
# per reconstructed band.
plt.figure(figsize=(12, 12), dpi=200)

k=1

istk = 1
n_sig = 2

nbands = output['x'].shape[0]
for inu in range(nbands):
    # Reconstructed band `inu` is the average of `fsub` input sub-bands, so it
    # is labelled with the mean of their frequencies rather than with the
    # inu-th input frequency.
    band_freq = np.mean(atm.frequencies[inu * fsub:(inu + 1) * fsub])

    hp.gnomview(
        true_maps[inu, :, istk], cmap='jet', rot=qubic_patch,
        title='{} - Input - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=15, sub=(nbands, 3, k), notext=True,
    )
    hp.gnomview(
        output['x'][inu, :, istk], cmap='jet', rot=qubic_patch,
        title='{} - Output - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=15, sub=(nbands, 3, k+1), notext=True,
    )
    hp.gnomview(
        output['x'][inu, :, istk] - true_maps[inu, :, istk], cmap='jet', rot=qubic_patch,
        title='{} - Residual - {:.2f} GHz'.format(stk[istk], band_freq),
        reso=15, sub=(nbands, 3, k+2), notext=True,
    )
    k+=3