In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import healpy as hp

from qubic.lib.MapMaking.Qatmosphere_2d import AtmosphereMaps

import sys
import yaml

import qubic
from qubic.lib.Instrument.Qacquisition import QubicDualBand

from qubic.lib.MapMaking.Qmaps import InputMaps
from qubic.lib.MapMaking.Qcg import PCGAlgorithm
from pyoperators.iterative.core import AbnormalStopIteration

from pyoperators import *



# MPI communicator for the run and the rank of this process inside it.
# `comm` is handed to PCGAlgorithm in the solver cell further down.
# NOTE(review): `MPI` is not imported explicitly in this cell; presumably it is
# provided by `from pyoperators import *` — confirm.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

%matplotlib inline

In [None]:
# Load the simulation parameters from the YAML configuration file.
# The encoding is passed explicitly so that parsing does not depend on the
# platform's locale default.
with open('params.yml', 'r', encoding='utf-8') as file:
    params = yaml.safe_load(file)

In [None]:
# Instantiate the class that builds the atmosphere maps from the simulation parameters
atm = AtmosphereMaps(params)

In [None]:
# Display the dictionary that is passed to QubicDualBand in the map-making cells below
atm.qubic_dict

In [None]:
# Atmospheric absorption spectrum, plotted against the integration frequencies
abs_spectrum = atm.absorption_spectrum()

plt.plot(atm.integration_frequencies, abs_spectrum)
plt.xlabel('Frequency (GHz)')
plt.ylabel(r'Absorption ($m^{2}/g$)')
plt.ylim(0, 0.0002)
plt.title("Atmospheric Absorption Spectrum")

In [None]:
# Atmosphere temperature maps computed from `atm.rho_map`
atm_maps = atm.get_temp_maps(atm.rho_map)
print(atm_maps.shape)

# Show the first map (index 0 along axis 0) on a square extent of +/- size_atm
half_size = params['size_atm']
plt.imshow(atm_maps[0], cmap='jet', extent=[-half_size, half_size, -half_size, half_size])
plt.xlabel('m')
plt.ylabel('m')
plt.title('Temperature fluctuations')
plt.colorbar(label=r'$\mu K_{CMB}$')

# Integrated absorption spectrum and the frequencies it is sampled on
integrated_abs_spectrum, frequencies = atm.integrated_absorption_spectrum()

In [None]:
# Mean of every map along the first axis of `atm_maps` (one value per frequency)
mean_atm_maps = [np.mean(freq_map) for freq_map in atm_maps]

# Figure 1: mean map temperature versus frequency
plt.figure()
plt.plot(frequencies, mean_atm_maps, '.')
plt.xlabel('Frequency (GHz)')
plt.ylabel(r'Mean temperature ($\mu K_{CMB}$)')
plt.title('Atmosphere maps spectrum')

# Figure 2: integrated absorption spectrum versus frequency
plt.figure()
plt.plot(frequencies, integrated_abs_spectrum, '.')
plt.ylabel(r'Integrated absorption spectrum ($m^{2}/g$)')
plt.xlabel('Frequency (GHz)')
plt.title('Integrated absorption spectrum')

In [None]:
# Subtract its own mean from every frequency map (in place), keeping only the fluctuations
for i in range(len(frequencies)):
    atm_maps[i] = atm_maps[i] - atm_maps[i].mean()

# Display the mean-subtracted map of the selected frequency
index_nu = 0
map_extent = [-params['size_atm'], params['size_atm'], -params['size_atm'], params['size_atm']]
plt.imshow(atm_maps[index_nu], cmap='jet', extent=map_extent)
plt.colorbar(label=r'$µK_{CMB}$')
plt.ylabel('m')
plt.xlabel('m')
plt.title('Atmosphere temperature map at {:.2f} GHz'.format(frequencies[index_nu]))

In [None]:
# Display the shape of the atmosphere maps once projected with get_healpy_atm_maps_2d
# around (RA_center, DEC_center) taken from atm.qubic_dict
atm.get_healpy_atm_maps_2d(atm_maps, atm.qubic_dict['RA_center'], atm.qubic_dict['DEC_center']).shape

In [None]:
# Build the input maps: atmosphere only, stored in the I component (index 0 of the
# last axis); the two other components stay at zero.
# The mean of each frequency map was subtracted from `atm_maps` in an earlier cell,
# so only the temperature fluctuations are kept.
input_maps = np.zeros((len(frequencies), hp.nside2npix(params['nside']), 3))

# Project the 2-D atmosphere maps once and reuse the result, instead of
# recomputing the projection for the pixel selection and again for the copy.
healpy_atm_maps = atm.get_healpy_atm_maps_2d(atm_maps, atm.qubic_dict['RA_center'], atm.qubic_dict['DEC_center'])

# Pixels that are non-zero in the first frequency map select where values are copied.
# `index` is the tuple returned by np.where; its single entry is the 1-D array of pixel indices.
index = np.where(healpy_atm_maps[0, :] != 0)
input_maps[:, index[0], 0] += healpy_atm_maps[:, index[0]]

hp.mollview(input_maps[index_nu, :, 0] ,cmap='jet', unit='µK_CMB', title='Input map {:.2f} GHz'.format(frequencies[index_nu]))

In [None]:
# Reduce the Nsub input maps to Nrec "true" maps
fsub = int(params['nsub_in'] / params['nrec'])
true_maps = np.zeros((params['nrec'], 12*params['nside']**2, 3))
true_frequencies = np.zeros(params['nrec'])

# Each reconstructed band is the average of `fsub` consecutive input maps;
# its frequency is the average of the corresponding input frequencies
for i in range(params['nrec']):
    band = slice(i * fsub, (i + 1) * fsub)
    true_maps[i] = input_maps[band].mean(axis=0)
    true_frequencies[i] = np.mean(frequencies[band])

print(true_maps.shape)
print(true_frequencies)

# Full-map mean of the I component for every reconstructed band
plt.plot(true_frequencies, np.mean(true_maps, axis=1)[..., 0], '.')
plt.xlabel('Frequency (GHz)')
plt.ylabel(r'Mean temperature ($\mu K_{CMB}$)')
plt.xlim(130, 250)

In [None]:
# Mollweide view of the I component of every band-averaged map
for inu, band_freq in enumerate(true_frequencies):
    hp.mollview(true_maps[inu, :, 0], cmap='jet', title='True - {:.2f} GHz'.format(band_freq))

# Map-making

In [None]:
# Build the QUBIC operators
# Simulate the TOD with an operator built with nrec = nsub_in, applied to the
# full set of input maps; the result is flattened to 1-D.
H_tod = QubicDualBand(atm.qubic_dict, nsub=params['nsub_in'], nrec=params['nsub_in']).get_operator()
tod = H_tod(input_maps).ravel()
# The simulation operator is only needed to produce `tod`; drop the name afterwards.
del H_tod

# Acquisition used for the reconstruction: same nsub, but nrec reconstructed bands.
Qacq = QubicDualBand(atm.qubic_dict, nsub=params['nsub_in'], nrec=params['nrec'])

In [None]:
# TODO: try to use a diagonal operator, or remove the QUBIC noise
invN = Qacq.get_invntt_operator()

# Reconstruction operator; for nrec == 2 only the second operand of the
# acquisition operator is kept
H_rec = Qacq.get_operator()
if params['nrec'] == 2:
    H_rec = H_rec.operands[1]

# Reshape the operator output so that it matches the input shape of invN
R = ReshapeOperator(H_rec.shapeout, invN.shapein)
H_rec = R * H_rec

In [None]:
# Boolean mask of the pixels whose normalised coverage exceeds the configured cut
coverage = Qacq.coverage
covnorm = coverage / np.max(coverage)
seenpix = covnorm > params['coverage_cut']

# Patch centre converted with equ2gal; used as `rot` in the gnomview plots below
center = np.array([0, -57])
qubic_patch = qubic.lib.Qsamplings.equ2gal(*center)

In [None]:
# Report the input/output shapes of the operators and the shape of the TOD
for label, op in (('H_rec', H_rec), ("invN", invN)):
    print(label, op.shapein, op.shapeout)
print("TOD", tod.shape)

In [None]:
# Build PCG
# Set up the linear system A x = b that is handed to PCGAlgorithm in the solver cell:
#   A = H^T invN H   and   b = H^T invN d   (d being the reshaped TOD)
R = ReshapeOperator(tod.shape, invN.shapein)  # reshapes the flat TOD to invN's input shape
A = H_rec.T * invN * H_rec
b = H_rec.T * invN * R(tod)
x0 = true_maps*0  # starting point: zeros with the same shape as true_maps

In [None]:
# Print how many operands H_rec holds at each nesting level.
# NOTE(review): when nrec == 2 this cell rebinds H_rec to H_rec.operands[1], and the
# preconditioner cell below indexes H_rec.operands[i_fp].operands[jsub] in that case,
# so it relies on this rebinding. Re-running this cell descends one more level each
# time; run it exactly once per kernel session. A and b were built before the
# rebinding and keep referring to the previous operator.
if params['nrec'] != 2:
    print("Number of Physical Bands :", len(H_rec.operands)) # operands[0] = 150 GHz / operands[1] = 220 GHz
    print("Number of Reconstructed Sub-Bands within each physical Bands :", len(H_rec.operands[0].operands))
    print("Number of Sub-Bands within each reconstructed bands :", len(H_rec.operands[0].operands[0].operands)) # operands[0] = 150 GHz / operands[1] = 220 GHz
else:
    H_rec = H_rec.operands[1]
    print("Number of Reconstructed Sub-Bands within each physical Bands :", len(H_rec.operands))
    print("Number of Sub-Bands within each reconstructed bands :", len(H_rec.operands[0].operands))

**Note about the preconditioner:** `stacked_dptdp_inv` should have the shape (Nrec, Npix), but we compute it from H, which contains Nsub acquisition operators. In the next cell, I use only the first Nrec operators rather than all Nsub, because I don't know how to reduce them.
I tried to compute it with another H that had exactly Nrec sub-operators, but it didn't work.
We need to find a solution to this problem.

In [None]:
# Build the preconditioner
# For every focal plane, a per-pixel map sum_det D_det^2 * sum_samples r11^2 is
# accumulated for each of the first `fsub` sub-operators, inverted pixel by pixel
# (0 where the map is 0), and averaged over those sub-operators. The resulting
# rows are used as diagonal blocks of the operator M passed to the PCG.
fsub = int(params['nsub_in'] / params['nrec'])
no_det = 992  # NOTE(review): presumably the number of detectors per focal plane — confirm
npix = 12 * params['nside']**2

# NOTE(review): only rows 0 and 1 are filled by the loop below (i_fp in range(2));
# for nrec > 2 the remaining rows stay at zero — see the note above this cell.
stacked_dptdp_inv = np.zeros((params['nrec'], npix))

### Loop on Focal Plane
for i_fp in range(2):
    # One row per sub-band actually processed below. Allocating nsub_in rows
    # (as before) left nsub_in - fsub all-zero rows that diluted the mean.
    stacked_dptdp_inv_fsub = np.zeros((fsub, npix))
    ### Loop on Bands
    for jsub in range(fsub):
        print("Focal plane :", i_fp, "Nsub band :", jsub)

        ### Extract Operators
        # For nrec == 2 this relies on H_rec having been rebound to its second
        # operand in the cell that prints the operand counts.
        if params['nrec'] == 2:
            H_single = H_rec.operands[i_fp].operands[jsub]
        else:
            H_single = H_rec.operands[0].operands[i_fp].operands[jsub]
        D = H_single.operands[1]
        P = H_single.operands[-1]
        sh = P.matrix.data.index.shape

        ### Compute the map P^t P
        # Rows of the pointing matrix are split into `no_det` contiguous,
        # equally sized chunks, one per detector.
        point_per_det = int(sh[0] / no_det)
        mapPtP_perdet_seq = np.zeros((no_det, npix))
        sample_ranges = [(det * point_per_det, (det + 1) * point_per_det) for det in range(no_det)]
        for det, (start, end) in enumerate(sample_ranges):
            indices = P.matrix.data.index[start:end, :]
            weights = P.matrix.data.r11[start:end, :]
            flat_indices = indices.ravel()
            flat_weights = weights.ravel()

            # np.add.at accumulates correctly when a pixel index is repeated
            mapPitPi = np.zeros(npix)
            np.add.at(mapPitPi, flat_indices, flat_weights**2)

            mapPtP_perdet_seq[det, :] = mapPitPi

        # Weight every detector map by the square of its entry in D, then sum over detectors
        D_elements = D.data
        D_sq = D_elements**2
        mapPtP_seq_scaled = D_sq[:, np.newaxis] * mapPtP_perdet_seq
        dptdp = mapPtP_seq_scaled.sum(axis = 0)

        # Pixels with dptdp == 0 give inf, which is replaced by 0 just below;
        # the divide-by-zero warning is therefore expected and silenced.
        with np.errstate(divide='ignore'):
            dptdp_inv = 1 / dptdp
        dptdp_inv[np.isinf(dptdp_inv)] = 0.
        stacked_dptdp_inv_fsub[jsub] = dptdp_inv

    stacked_dptdp_inv[i_fp] = np.mean(stacked_dptdp_inv_fsub, axis = 0)

# One diagonal block per row of stacked_dptdp_inv
M = BlockDiagonalOperator(
    [DiagonalOperator(ci, broadcast='rightward') for ci in stacked_dptdp_inv],
    new_axisin=0)

In [None]:
# Configure and run the preconditioned conjugate gradient
pcg_options = dict(
    x0=x0,
    tol=1e-10,
    maxiter=150,
    disp=True,
    M=M,
    center=[0, -57],
    reso=15,
    seenpix=seenpix,
    input=true_maps,
)
algo = PCGAlgorithm(A, b, comm, **pcg_options)

# If the solver stops abnormally, keep whatever finalize() returns
# and remember the reason instead of failing the cell
try:
    output = algo.run()
except AbnormalStopIteration as e:
    output = algo.finalize()
    success = False
    message = str(e)
else:
    success = True
    message = 'Success'

In [None]:
# Convergence history returned by the PCG, on a logarithmic y axis
plt.plot(output['convergence'])
plt.yscale('log')
plt.title("Polychromatic")
plt.xlabel('Iteration')
plt.ylabel('Convergence')

In [None]:
# Input / output / residual for the selected Stokes component, one row per band.
# The maps are not masked outside `seenpix` in this cell.
plt.figure(figsize=(12, 12), dpi=200)
k=1

stk = ['I', 'Q', 'U']
istk = 0
n_sig = 3
reso = 15

n_rows = output['x'].shape[0]
view_kwargs = dict(cmap='jet', rot=qubic_patch, reso=reso, notext=True)

for inu in range(n_rows):
    # Colour range of the input and output panels comes from the input map inside the seen pixels
    seen_values = true_maps[inu, seenpix, istk]
    sigma = np.std(seen_values)
    vmin = np.min(seen_values)
    vmax = np.max(seen_values)

    hp.gnomview(true_maps[inu, :, istk], min=vmin, max=vmax,
                title='{} - Input - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k), **view_kwargs)
    hp.gnomview(output['x'][inu, :, istk], min=vmin, max=vmax,
                title='{} - Output - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k+1), **view_kwargs)
    hp.gnomview(output['x'][inu, :, istk] - true_maps[inu, :, istk],
                title='{} - Residual - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k+2), **view_kwargs)
    k += 3

In [None]:
# Same input / output / residual comparison as the previous cell, but with the
# pixels outside `seenpix` flagged as hp.UNSEEN first.
# NOTE: the masking is done in place, so `true_maps` and `output['x']` stay masked
# for every cell executed after this one.
plt.figure(figsize=(12, 12), dpi=200)
k=1
true_maps[:, ~seenpix, :] = hp.UNSEEN
output['x'][:, ~seenpix, :] = hp.UNSEEN

stk = ['I', 'Q', 'U']
istk = 0
n_sig = 3
reso = 15

for inu in range(output['x'].shape[0]):
    sigma = np.std(true_maps[inu, seenpix, istk])
    # Input and output share the colour range of the input map inside the seen pixels
    hp.gnomview(true_maps[inu, :, istk], min=np.min(true_maps[inu, seenpix, istk]), max=np.max(true_maps[inu, seenpix, istk]), cmap='jet', rot=qubic_patch,title='{} - Input - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]), reso=reso, sub=(output['x'].shape[0], 3, k), notext=True)
    hp.gnomview(output['x'][inu, :, istk], min=np.min(true_maps[inu, seenpix, istk]), max=np.max(true_maps[inu, seenpix, istk]), cmap='jet', rot=qubic_patch,title='{} - Output - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]), reso=reso, sub=(output['x'].shape[0], 3, k+1), notext=True)
    # The residual uses the same Stokes index as the other panels and the title
    # (it was hard-coded to 0, which only matched while istk == 0)
    hp.gnomview(output['x'][inu, :, istk] - true_maps[inu, :, istk], cmap='jet', rot=qubic_patch,title='{} - Residual - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]), reso=reso, sub=(output['x'].shape[0], 3, k+2), notext=True)
    k+=3

In [None]:
# Mollweide view of the residual (output - input) for each band,
# using the Stokes component `istk` selected in the previous cell
for inu in range(output['x'].shape[0]):
    sigma = np.std(true_maps[inu, seenpix, istk])
    residual = output['x'][inu, :, istk] - true_maps[inu, :, istk]
    hp.mollview(residual, cmap='jet',
                title='{} - Residual - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]))

In [None]:
# Input / output / residual for Stokes index 1, one row per band
plt.figure(figsize=(12, 12), dpi=200)

k=1

istk = 1
n_sig = 2

n_rows = output['x'].shape[0]
view_kwargs = dict(cmap='jet', rot=qubic_patch, reso=15, notext=True)

for inu in range(n_rows):
    sigma = np.std(true_maps[inu, seenpix, istk])
    hp.gnomview(true_maps[inu, :, istk],
                title='{} - Input - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k), **view_kwargs)
    hp.gnomview(output['x'][inu, :, istk],
                title='{} - Output - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k+1), **view_kwargs)
    hp.gnomview(output['x'][inu, :, istk] - true_maps[inu, :, istk],
                title='{} - Residual - {:.2f} GHz'.format(stk[istk], true_frequencies[inu]),
                sub=(n_rows, 3, k+2), **view_kwargs)
    k += 3