# Imports and setup

## Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import healpy as hp
import yaml
 
from qubic.lib.MapMaking.Qatmosphere_2d import AtmosphereMaps
from qubic.lib.MapMaking.FrequencyMapMaking.Qspectra_component import CMBModel
from qubic.lib.Qsamplings import get_pointing, equ2gal, QubicSampling
from qubic.lib.Instrument.Qacquisition import QubicDualBand

from qubic.lib.MapMaking.Qcg_test_for_atm import PCGAlgorithm
from pyoperators.iterative.core import AbnormalStopIteration

from pyoperators import MPI, BlockDiagonalOperator, BlockRowOperator, DiagonalOperator, ReshapeOperator

from pysimulators.interfaces.healpy import Spherical2HealpixOperator, HealpixConvolutionGaussianOperator

comm = MPI.COMM_WORLD
rank = comm.Get_rank() 

%matplotlib inline

## Load the parameters and build the atmosphere model

In [None]:
### Load the simulation parameters from the YAML configuration file
PARAMS_FILE = 'params.yml'
with open(PARAMS_FILE, mode='r') as params_file:
    params = yaml.safe_load(params_file)

In [None]:
### Instantiate the atmosphere model from the parameters.
### Its QUBIC configuration dictionary is reused for the acquisitions below.
atm = AtmosphereMaps(params)
qubic_dict = atm.qubic_dict

# Scanning Strategy

## Galactic Coordinates - Scan following the center of the QUBIC patch (random pointing is enabled below; switch to `sweeping_pointing` for a sweeping scan)

In [None]:
### Pointing strategy: random pointing (sweeping and repeated pointing disabled)
qubic_dict['random_pointing'] = True
qubic_dict['sweeping_pointing'] = False
qubic_dict['repeat_pointing'] = False

### Leave the azimuth free for the galactic-coordinate scan
### (the local-coordinate sampling further down sets fix_az = True on its own object)
qubic_dict['fix_azimuth']['apply'] = False

### Sweeping parameters (presumably only read when sweeping_pointing is True - TODO confirm)
qubic_dict['angspeed'] = 0.4
qubic_dict['delta_az'] = 20
qubic_dict['nsweeps_per_elevation'] = 1

### Sampling period and observation duration
# npointings = 3600 * duration / period  (duration presumably in hours, period in seconds)
qubic_dict['period'] = 1
qubic_dict['duration'] = 1

In [None]:
### Sampling built from the configuration dictionary (galactic-coordinate scan)
q_sampling_gal = get_pointing(qubic_dict)

### Patch centre: qubic_patch is converted with equ2gal to galactic coordinates;
### the local centre is the mean (azimuth, elevation) of the scan
qubic_patch = np.array([0, -57])
center_gal = equ2gal(*qubic_patch)
center_local = np.array([
    np.mean(q_sampling_gal.azimuth),
    np.mean(q_sampling_gal.elevation),
])
print(q_sampling_gal)

In [None]:
### Plot the scanning strategy: time streams, horizontal track and sky coordinates
az, el = q_sampling_gal.azimuth, q_sampling_gal.elevation
ra_wrapped = (q_sampling_gal.equatorial[:, 0] + 180) % 360 - 180  # RA folded into [-180, 180)

# (plot arguments, title, x label, y label) for each panel, left to right
panels = [
    ((az,), "Azimuth", "Time samples", "Angles (degrees)"),
    ((el,), "Elevation", "Time samples", "Angles (degrees)"),
    ((az, el, '.'), "Scanning strategy", "Azimuth (degrees)", "Elevation (degrees)"),
    ((ra_wrapped, q_sampling_gal.equatorial[:, 1], '.'),
     "Equatorial coordinates", "Right ascension (degrees)", "Declination (degrees)"),
    ((q_sampling_gal.galactic[:, 0], q_sampling_gal.galactic[:, 1], '.'),
     "Galactic coordinates", "Longitude (degrees)", "Latitude (degrees)"),
]

fig, axs = plt.subplots(1, len(panels), figsize=(25, 5))
for ax, (plot_args, panel_title, x_label, y_label) in zip(axs, panels):
    ax.plot(*plot_args)
    ax.set_title(panel_title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

fig.suptitle("Qubic Sampling")
plt.tight_layout()
plt.show()

## Local Coordinates - Same scan as above (following the center of the QUBIC patch), with a fixed azimuth

In [None]:
### Local-coordinate sampling: reuse the azimuth, elevation, pitch and HWP angle of the
### galactic scan, but with fix_az = True.
### The number of samples is taken from the galactic sampling itself, so the copied
### arrays always match its length. params['npointings'] is not guaranteed to equal
### the number of pointings actually generated by get_pointing.
q_sampling_local = QubicSampling(len(q_sampling_gal.azimuth),
                                 date_obs=qubic_dict['date_obs'],
                                 period=qubic_dict['period'],
                                 latitude=qubic_dict['latitude'],
                                 longitude=qubic_dict['longitude'])

q_sampling_local.azimuth = q_sampling_gal.azimuth
q_sampling_local.elevation = q_sampling_gal.elevation
q_sampling_local.pitch = q_sampling_gal.pitch
q_sampling_local.angle_hwp = q_sampling_gal.angle_hwp

q_sampling_local.fix_az = True

# Input Maps

## CMB

In [None]:
### Draw a single CMB realisation (r = 0, Alens = 1) at the simulation resolution.
### synfast stacks the Stokes maps on its first axis; transposing gives cmb_map[pixel, stokes].
cl_cmb = CMBModel(None).give_cl_cmb(r=0, Alens=1)
cmb_map = np.transpose(hp.synfast(cl_cmb, params['nside'], new=True, verbose=False))

cmb_I = cmb_map[:, 0]
hp.mollview(cmb_I, cmap='jet', title='CMB map', unit=r'$µK_{CMB}$')

In [None]:
### Input CMB maps: the same CMB map copied into every input sub-band
cmb_maps_input = np.zeros((params['nsub_in'], hp.nside2npix(params['nside']), 3))
cmb_maps_input[:] = cmb_map  # broadcast along the sub-band axis

In [None]:
### Expected (reference) CMB maps: the same CMB map for every reconstructed band
cmb_maps = np.zeros((params['nrec'], hp.nside2npix(params['nside']), 3))
cmb_maps[:] = cmb_map  # broadcast along the reconstructed-band axis

## Atmosphere

In [None]:
### Atmospheric temperature fluctuation maps (first axis indexed like `frequencies` below)
atm_maps_2d = atm.get_temp_maps(atm.rho_map)
print(atm_maps_2d.shape)

size_atm = params['size_atm']
atm_extent = [-size_atm, size_atm, -size_atm, size_atm]
plt.imshow(atm_maps_2d[0], cmap='jet', extent=atm_extent)
plt.title('Temperature fluctuations')
plt.xlabel('m')
plt.ylabel('m')
plt.colorbar(label=r'$\mu K_{CMB}$')

In [None]:
### Atmospheric integrated absorption spectrum and the frequencies it is sampled at
integrated_abs_spectrum, frequencies = atm.integrated_absorption_spectrum()

fig_abs, ax_abs = plt.subplots()
ax_abs.plot(frequencies, integrated_abs_spectrum, '.')
ax_abs.set_xlabel('Frequency (GHz)')
ax_abs.set_ylabel(r'Integrated absorption spectrum ($m^{2}/g$)')
ax_abs.set_title('Integrated absorption spectrum')

In [None]:
### Subtract each frequency's spatial mean so the atmosphere maps fluctuate around zero
n_freq = len(frequencies)
for i_nu in range(n_freq):
    atm_maps_2d[i_nu] -= np.mean(atm_maps_2d[i_nu])

index_nu = 0
size_atm = params['size_atm']
plt.imshow(atm_maps_2d[index_nu], cmap='jet', extent=[-size_atm, size_atm, -size_atm, size_atm])
plt.colorbar(label=r'$µK_{CMB}$')
plt.xlabel('m')
plt.ylabel('m')
plt.title('Atmosphere temperature map at {:.2f} GHz'.format(frequencies[index_nu]))

In [None]:
### Build input maps: atmosphere (intensity only; Q and U stay zero)
atm_maps_input = np.zeros((len(frequencies), hp.nside2npix(params['nside']), 3))

healpy_atm_maps = atm.get_healpy_atm_maps_2d(atm_maps_2d,
                                             np.mean(q_sampling_local.azimuth),
                                             np.mean(q_sampling_local.elevation))

# Pixels touched by the projected atmosphere, selected on the first frequency.
# np.flatnonzero returns a plain 1-D index array, so both sides of the assignment have
# shape (nfreq, npix_hit). The tuple returned by np.where, used as an index here,
# silently inserted an extra length-1 axis.
index = np.flatnonzero(healpy_atm_maps[0, :])
atm_maps_input[:, index, 0] += healpy_atm_maps[:, index]

index_nu = 0
hp.mollview(atm_maps_input[index_nu, :, 0] ,cmap='jet', unit='µK_CMB', title='Input map {:.2f} GHz'.format(frequencies[index_nu]))

In [None]:
### Expected atmosphere maps: average the input sub-band maps (and their frequencies)
### inside each reconstructed band
nsub_in, nrec = params['nsub_in'], params['nrec']
if nsub_in % nrec != 0:
    # Otherwise the trailing sub-bands would be silently left out of every band
    raise ValueError(f"nsub_in ({nsub_in}) must be a multiple of nrec ({nrec})")
fsub = nsub_in // nrec  # number of input sub-bands per reconstructed band

atm_maps = np.zeros((nrec, hp.nside2npix(params['nside']), 3))
true_frequencies = np.zeros(nrec)
for i in range(nrec):
    band = slice(i * fsub, (i + 1) * fsub)
    atm_maps[i] = np.mean(atm_maps_input[band], axis=0)
    true_frequencies[i] = np.mean(frequencies[band])

print("Atm maps shape is :", atm_maps.shape)
print("Reconstructed frequencies are : ", true_frequencies)
plt.plot(true_frequencies, np.mean(atm_maps, axis=1)[..., 0], '.')
plt.ylabel(r'Mean temperature ($\mu K_{CMB}$)')
plt.xlabel('Frequency (GHz)')
plt.xlim(130, 250)  # range spanning the 150 GHz and 220 GHz bands

# Build QUBIC Instances

In [None]:
### QubicDualBand acquisitions with identical band settings, one per sampling
### (galactic-coordinate scan and fixed-azimuth local scan)
acq_kwargs = dict(nsub=params['nsub_in'], nrec=params['nrec'])
Qacq_gal = QubicDualBand(qubic_dict, sampling=q_sampling_gal, **acq_kwargs)
Qacq_local = QubicDualBand(qubic_dict, sampling=q_sampling_local, **acq_kwargs)

# Build QUBIC Operators

In [None]:
### Acquisition operator H for each frame.
### For nrec == 2 only operands[1] of the returned operator is kept.
acquisitions = (Qacq_gal, Qacq_local)
if params['nrec'] == 2:
    H_gal, H_local = (acq.get_operator().operands[1] for acq in acquisitions)
else:
    H_gal, H_local = (acq.get_operator() for acq in acquisitions)

### Inverse noise covariance operators
invN_gal, invN_local = (acq.get_invntt_operator(False, False) for acq in acquisitions)

### Reshape H's output to the input shape of the noise operator (same R for both frames)
R = ReshapeOperator(H_gal.shapeout, invN_gal.shapein)
H_gal, H_local = R * H_gal, R * H_local

In [None]:
### Simulated TODs: project the input maps with operators that keep every sub-band
### (nrec = nsub_in). The galactic frame sees the CMB, the local frame sees the atmosphere.
tod_kwargs = dict(nsub=params['nsub_in'], nrec=params['nsub_in'])
H_gal_tod = QubicDualBand(qubic_dict, sampling=q_sampling_gal, **tod_kwargs).get_operator()
H_local_tod = QubicDualBand(qubic_dict, sampling=q_sampling_local, **tod_kwargs).get_operator()

tod_gal = H_gal_tod(cmb_maps_input)
tod_local = H_local_tod(atm_maps_input)
d_gal = tod_gal.ravel()
d_local = tod_local.ravel()

# Release the full-resolution TOD operators
del H_gal_tod, H_local_tod, tod_gal, tod_local

In [None]:
### Coverage maps normalised to their maximum; pixels above the cut count as "seen"
coverage_gal = Qacq_gal.coverage
covnorm_gal = coverage_gal / coverage_gal.max()
seenpix_gal = covnorm_gal > params['coverage_cut']

coverage_local = Qacq_local.coverage
covnorm_local = coverage_local / coverage_local.max()
seenpix_local = covnorm_local > params['coverage_cut']

### One mask per reconstructed map: nrec galactic (CMB) masks, then nrec local (atm) masks
seenpix = np.array([seenpix_gal] * params['nrec'] + [seenpix_local] * params['nrec'])

In [None]:
# Visual check of the coverage in both frames
for coverage_map, frame_name in ((coverage_gal, 'Galactic'), (coverage_local, 'Local')):
    hp.mollview(coverage_map, title=f'Coverage {frame_name} Map')

In [None]:
for op_name, op in (("H_gal", H_gal), ("invN_gal", invN_gal)):
    print(op_name, op.shapein, op.shapeout)
print("d_gal", d_gal.shape)

In [None]:
for op_name, op in (("H_local", H_local), ("invN_local", invN_local)):
    print(op_name, op.shapein, op.shapeout)
print("d_local", d_local.shape)

In [None]:
### Full map-making: the unknowns are x = [CMB maps (galactic frame), atm maps (local frame)],
### so H x = H_gal x_cmb + H_local x_atm, matching the summed TOD d = d_gal + d_local.
H = BlockRowOperator([H_gal, H_local], axisin=0)

# The local sampling copies the galactic pointing timeline and both acquisitions share
# qubic_dict, so a single noise operator is used for the summed TOD.
invN = invN_gal # BlockDiagonalOperator([invN_gal, invN_local], axisout=0)

d = d_gal + d_local

# Reference solution, stacked in the same order as the inputs of H
true_maps = np.append(cmb_maps, atm_maps, axis=0)

In [None]:
for op_name, op in (("H", H), ("invN", invN)):
    print(op_name, op.shapein, op.shapeout)
for arr_name, arr in (("d", d), ("true_maps", true_maps)):
    print(arr_name, arr.shape)

# Map-Making

In [None]:
# Normal equations (H^T N^-1 H) x = H^T N^-1 d, to be solved by PCG.
# R_tod reshapes the flat TOD to the input shape of the noise operator.
R_tod = ReshapeOperator(d.shape, invN.shapein)
A = H.T * invN * H
b = H.T * invN * R_tod(d)

# Starting point: the true maps scaled by X0_SCALE (a perturbed version of the solution,
# not a swap of the CMB and atmosphere maps).
X0_SCALE = 0.9
x0 = true_maps * X0_SCALE

In [None]:
print("A", A.shapein, A.shapeout)
print("Ax", A(x0).shape)
for arr_name, arr in (("b", b), ("true_maps", true_maps)):
    print(arr_name, arr.shape)

In [None]:
### H dimension: nesting of the composite operator H, relied on by the preconditioner cell

# H.operands[0] = H_gal (galactic frame, CMB) / H.operands[1] = H_local (local frame, atmosphere)
# H.operands[i].operands[0] = ReshapeOperator (the R applied when H_gal / H_local were built)
# H.operands[i].operands[1].operands = the Nrec per-band operators of H_i
# H.operands[i].operands[1].operands[j] = the Nsub sub-band operators of H_i for the band j

In [None]:
# Build the diagonal preconditioner M, one map per (frame, reconstructed band).
# For each pair, the operator of the central sub-band is used and, for every pixel p,
#     dptdp[p] = sum_det D_det^2 * sum_{samples of det hitting p} w^2
# is accumulated from the r11 weights of the pointing matrix. M holds 1 / dptdp,
# which approximates the inverse of the diagonal of H^T H (noise weighting ignored).
no_det = 992  # number of detectors in the TOD; must match D.data (TODO derive from the operator)
npix = 12 * params['nside']**2
isub_mid = int(params['nsub_in'] / params['nrec'] / 2)  # central sub-band inside each band


def _dptdp_inverse(h, no_det, npix):
    """Per-pixel inverse of sum_det D_det^2 * sum_samples w^2 for one sub-band operator.

    Parameters
    ----------
    h : composite operator of one sub-band; h.operands[1] is used as the detector
        diagonal operator D and h.operands[-1] as the projection operator P.
    no_det : number of detectors. The rows of P's sparse matrix are taken to be grouped
        detector by detector, with the same number of rows per detector.
    npix : number of HEALPix pixels of the output map.

    Returns
    -------
    ndarray of shape (npix,), set to 0 where no sample contributes.

    Raises
    ------
    ValueError if the number of pointing rows is not a multiple of no_det.
    """
    D = h.operands[1]
    P = h.operands[-1]
    indices = P.matrix.data.index  # (rows, ncol) pixel indices
    weights = P.matrix.data.r11    # (rows, ncol) weights (presumably intensity - TODO confirm)

    nrows = indices.shape[0]
    if nrows % no_det != 0:
        raise ValueError(f"{nrows} pointing rows cannot be split evenly between {no_det} detectors")
    point_per_det = nrows // no_det

    # D_det^2 for the detector owning each row (D.data may hold one value or one per detector)
    D_sq = np.broadcast_to(np.asarray(D.data) ** 2, (no_det,))
    row_det = np.repeat(np.arange(no_det), point_per_det)
    sample_w2 = D_sq[row_det][:, np.newaxis] * weights ** 2

    # Accumulate straight into one map. This avoids a (no_det, npix) intermediate of
    # no_det * npix * 8 bytes. Negative indices mark empty sparse entries and are skipped;
    # otherwise they would wrap around onto the last pixel.
    valid = indices >= 0
    dptdp = np.zeros(npix)
    np.add.at(dptdp, indices[valid], sample_w2[valid])

    with np.errstate(divide='ignore'):
        dptdp_inv = 1 / dptdp
    dptdp_inv[np.isinf(dptdp_inv)] = 0.  # unseen pixels
    return dptdp_inv


stacked_dptdp_inv = np.zeros((true_maps.shape[0], npix))

cpt = 0
### Loop on frames (0: galactic/CMB, 1: local/atmosphere), then on reconstructed bands
for i_maps in range(2):
    for j_freq in range(params['nrec']):
        h = H.operands[i_maps].operands[1].operands[j_freq].operands[isub_mid]
        stacked_dptdp_inv[cpt] = _dptdp_inverse(h, no_det, npix)
        cpt += 1

M = BlockDiagonalOperator(
    [DiagonalOperator(ci, broadcast='rightward') for ci in stacked_dptdp_inv],
    new_axisin=0)

In [None]:
# Run the preconditioned conjugate gradient. If it stops abnormally, finalize()
# still returns the current state of the solver.
pcg_options = dict(
    x0=x0,
    tol=1e-10,
    maxiter=100,
    disp=True,
    M=M,
    center=[0, -57],
    reso=15,
    seenpix=seenpix,
    input=true_maps,
)
algo = PCGAlgorithm(A, b, comm, **pcg_options)
try:
    result = algo.run()
except AbnormalStopIteration as e:
    result = algo.finalize()
    success, message = False, str(e)
else:
    success, message = True, 'Success'

In [None]:
fig_conv, ax_conv = plt.subplots()
ax_conv.plot(result['convergence'])
ax_conv.set_yscale('log')
ax_conv.set_xlabel('Iteration')
ax_conv.set_ylabel('Convergence')

In [None]:
# Work on copies: a later cell writes hp.UNSEEN into `input` and `output` in place,
# which would otherwise silently corrupt `true_maps` and the solver result `result['x']`.
# NOTE(review): `input` shadows the builtin; kept because later cells use this name.
input = true_maps.copy()
output = result['x'].copy()
residual = output - input

In [None]:
print(np.shape(input))

In [None]:
plt.figure(figsize=(15, 12))
k=1

istk = 0
stk = ['I', 'Q', 'U']
# Labels for the nrec = 2 layout [CMB 150, CMB 220, Atm 150, Atm 220].
# NOTE(review): `map` shadows the builtin; kept because later cells use this name.
map = ['CMB', 'CMB', 'Atm', 'Atm']
freq = ['150', '220', '150', '220']

nmaps = input.shape[0]
for imap in range(nmaps):
    # Shared colour scale for the three panels: range of the input over the seen pixels
    seen_input = input[imap, seenpix[imap], istk]
    vmin, vmax = np.min(seen_input), np.max(seen_input)
    hp.mollview(input[imap, :, istk], min=vmin, max=vmax, cmap='jet', sub=(nmaps, 3, k), title=f'{stk[istk]} - Input - {map[imap]} - {freq[imap]} GHz', notext=True)
    hp.mollview(output[imap, :, istk], min=vmin, max=vmax, cmap='jet', sub=(nmaps, 3, k+1), title=f'{stk[istk]} - Output - {map[imap]} - {freq[imap]} GHz', notext=True)
    hp.mollview(output[imap, :, istk] - input[imap, :, istk], min=vmin, max=vmax, cmap='jet', sub=(nmaps, 3, k+2), title=f'{stk[istk]} - {freq[imap]} GHz - Residual - {map[imap]}', notext=True)
    k+=3

plt.tight_layout()

In [None]:
plt.figure(figsize=(15, 10))
k = 1

istk = 0

reso = 20
# Projection centre per map: galactic centre for the CMB maps, local centre for the atm maps
center = [center_gal, center_gal, center_local, center_local]

nmaps = input.shape[0]
for imap in range(nmaps):
    seen_vals = input[imap, seenpix[imap], istk]
    view_opts = dict(min=np.min(seen_vals), max=np.max(seen_vals), rot=center[imap],
                     reso=reso, cmap='jet', notext=True)
    views = (
        ('Input', input[imap, :, istk]),
        ('Output', output[imap, :, istk]),
        ('Residual', output[imap, :, istk] - input[imap, :, istk]),
    )
    for offset, (kind_label, healpix_map) in enumerate(views):
        hp.gnomview(healpix_map, sub=(nmaps, 3, k + offset),
                    title=f'{stk[istk]} - {kind_label} - {map[imap]}', **view_opts)
    k += 3

plt.tight_layout()

In [None]:
plt.figure(figsize=(15, 10))
k=1

istk = 0

# Blank out unseen pixels in place so gnomview greys them out. seenpix holds one mask
# per map (nrec galactic maps, then nrec local maps). Indexing it replaces the hard-coded
# "maps 0/1 -> galactic, 2/3 -> local" assignment and works for any nrec.
for imap in range(input.shape[0]):
    unseen = ~seenpix[imap]
    input[imap, unseen, :] = hp.UNSEEN
    output[imap, unseen, :] = hp.UNSEEN

reso = 20

nmaps = input.shape[0]
for imap in range(nmaps):
    # Keep the residual masked as well: UNSEEN - UNSEEN would otherwise be drawn as 0
    residual_map = np.where(seenpix[imap], output[imap, :, istk] - input[imap, :, istk], hp.UNSEEN)
    hp.gnomview(input[imap, :, istk], rot=center[imap], reso=reso, cmap='jet', sub=(nmaps, 3, k), title=f'{stk[istk]} - Input - {map[imap]}', notext=True)
    hp.gnomview(output[imap, :, istk], rot=center[imap], reso=reso, cmap='jet', sub=(nmaps, 3, k+1), title=f'{stk[istk]} - Output - {map[imap]}', notext=True)
    hp.gnomview(residual_map, rot=center[imap], reso=reso, cmap='jet', sub=(nmaps, 3, k+2), title=f'{stk[istk]} - Residual - {map[imap]}', notext=True)
    k+=3

plt.tight_layout()

In [None]:
plt.figure(figsize=(15, 10))
k = 1

istk = 0
stk = ['I', 'Q', 'U']

nmaps = input.shape[0]
for imap in range(nmaps):
    views = (
        ('Input', input[imap, :, istk]),
        ('Output', output[imap, :, istk]),
        ('Residual', output[imap, :, istk] - input[imap, :, istk]),
    )
    for offset, (kind_label, healpix_map) in enumerate(views):
        hp.mollview(healpix_map, cmap='jet', sub=(nmaps, 3, k + offset),
                    title=f'{stk[istk]} - {kind_label} - {map[imap]}', notext=True)
    k += 3

plt.tight_layout()