In [None]:
import config
from forklens import train
from forklens.dataset import ShapeDataset
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import time

In `./cali_code`, I simply list the main code we used in the shear calibration procedure, including the training-data generation and the training itself, which uses the Tenbilac code (Tewes et al. 2019, http://cdsarc.u-strasbg.fr/viz-bin/qcat?J/A+A/621/A36).

You may want to install Tenbilac and learn how to use it before moving on to this part.

In [None]:
# Restore the trained CNN: `tr` is the forklens trainer object that is reused
# below for inference, `model` is the network loaded from `model_file`.
model_file = "../cnn_tests/model/test_model149"
tr = train.Train()
model = tr.load_model(
    path=model_file,
    strict=True,
)

In [None]:
# Load data catalog (primary HDU of the example FITS file)
with fits.open('../../data/csst_catalog_example.fits') as f:
    cat = f[0].data

# Keep only rows with column 0 > 0.1 and column 1 < 25.
# Column 0 is used below as the galaxy size and column 1 as the magnitude.
cut_idx = np.where((cat[:, 0] > 0.1) & (cat[:, 1] < 25))[0]
size_cat, mag_cat = cat[cut_idx, 0], cat[cut_idx, 1]

## Type II data

In [None]:
from forklens.dataset import ShearDataset

from scipy.optimize import curve_fit

In [None]:
# Type II test set: `case_num` shear cases, each with `real_num` galaxies whose
# size/magnitude, orientation and axis ratio are drawn at random.
case_num = 200
real_num = 100000

seed = 33333
rng = np.random.RandomState(seed)
# One (g1, g2) pair per case, each component uniform in (-0.1, 0.1].
Gal_Shear   = rng.random((case_num,2))*(-0.1-0.1)+0.1

# One row per case is collected here and stacked once after the loop.
# Stacking inside the loop re-copies everything accumulated so far on every
# iteration, and leaves the arrays 1-D when case_num == 1.
e1_rows, e2_rows = [], []
hlr_rows, mag_rows = [], []
randint_rows = []

for i in range(case_num):

    # NOTE(review): the per-case seeds overlap between adjacent cases
    # (rng2 of case i+1 and rng3 of case i are both seeded with seed+i+3, so
    # Gal_Phi of case i+1 and Gal_AxRatio of case i are built from the same
    # uniform draws). Kept as-is so the generated data do not change; consider
    # e.g. `sub_seed = seed + 3*i` if independent cases are required.
    sub_seed = seed + i

    # Size and magnitude are taken from the same randomly chosen catalog row.
    rng1 = np.random.RandomState(sub_seed+1)
    idx = rng1.randint(0,mag_cat.shape[0],size=int(real_num))
    Gal_Hlr     = size_cat[idx]
    Gal_Mag     = mag_cat[idx]

    # Orientation angle, uniform in (-pi/2, pi/2].
    rng2 = np.random.RandomState(sub_seed+2)
    Gal_Phi     = rng2.random(size=int(real_num))*(-np.pi/2-np.pi/2)+np.pi/2

    # Axis ratio q, uniform in (0.5, 1].
    rng3 = np.random.RandomState(sub_seed+3)
    Gal_AxRatio = rng3.random(size=int(real_num))*(0.5-1)+1

    # Ellipticity components: |e| = (1-q)/(1+q), rotated by 2*phi.
    Gal_E1 = (1-Gal_AxRatio)/(1+Gal_AxRatio)*np.cos(Gal_Phi*2)
    Gal_E2 = (1-Gal_AxRatio)/(1+Gal_AxRatio)*np.sin(Gal_Phi*2)

    # Random integer in [0, 10000) per galaxy, drawn from the main generator;
    # stored under the 'randint' key (presumably a PSF selector -- confirm in
    # ShearDataset).
    PSF_randint = rng.randint(0,high=10000,size=real_num)

    e1_rows.append(Gal_E1)
    e2_rows.append(Gal_E2)
    hlr_rows.append(Gal_Hlr)
    mag_rows.append(Gal_Mag)
    randint_rows.append(PSF_randint)

# Every entry has shape (case_num, real_num).
gal_pars = {}
gal_pars["e1"] = np.vstack(e1_rows)
gal_pars["e2"] = np.vstack(e2_rows)
gal_pars["hlr_disk"] = np.vstack(hlr_rows)
gal_pars["mag_i"] = np.vstack(mag_rows)
gal_pars['randint'] = np.vstack(randint_rows)

shear_pars = {}
shear_pars['shear'] = Gal_Shear

In [None]:
# Preview grid: 5 rows x 6 columns, where row i shows the items at flat indices
# i*real_num + j (presumably the first six realizations of case i -- confirm
# against ShearDataset's indexing).
show_ds = ShearDataset(shear_pars, gal_pars)

plt.figure(figsize=(10,10))

for i in range(5):
    for j in range(6):
        plt.subplot(5,6,i*6+j+1)
        # Fetch the item once: this halves the work per panel and guarantees
        # that the SNR label and the image belong to the same sample even if
        # the dataset generates its items stochastically.
        sample = show_ds[i*real_num+j]
        plt.text(10,30,'%d'%sample['snr'],color='white',fontsize=18)
        gal_im = sample['gal_image'][0]
        plt.imshow(gal_im,cmap='gray')
        plt.xticks([])
        plt.yticks([])

plt.subplots_adjust(wspace=0., hspace=-0.3)

In [None]:
# This might take hours depending on the data volume.
start = time.time()

# Run the loaded model over the whole Type II set in batches.
shear_ds = ShearDataset(shear_pars, gal_pars)
shear_dl = DataLoader(shear_ds, batch_size=250, num_workers=20)

# Only the first two return values (predictions and labels) are used here.
pred, true, _, _ = tr._predictFunc(shear_dl, model)
diff = pred - true

# Elapsed wall-clock time in minutes (displayed as the cell output).
(time.time() - start) / 60

In [None]:
# Save (optional): uncomment the lines below to write the first five prediction
# columns (primary HDU) and the first shear component of each case (image HDU)
# to a FITS file.
# hdu0 = fits.PrimaryHDU(pred[:,0:5])
# hdu1 = fits.ImageHDU(Gal_Shear[:,0])
# hdul = fits.HDUList([hdu0,hdu1])
# hdul.writeto('TypeII_200case_100000real.fits')

In [None]:
# Drop the first prediction column and regroup the remaining four outputs as
# (case, realization, output).
results = pred[:,1:].reshape(case_num,real_num,4)

# Per-case average of the first two outputs (the measured g1 and g2).
g = np.zeros((case_num,2))
for i in range(case_num):
    per_case = results[i]
    g[i,0], g[i,1] = np.mean(per_case[:,0]), np.mean(per_case[:,1])

# Straight-line fit of the g1 residual against the true g1:
#     <g1> - g1_true = k * g1_true + b
g1_true = shear_pars['shear'][:,0]
coeffs, cov = np.polyfit(g1_true, g[:,0]-g1_true, 1, cov=True)

# Best-fit slope/intercept and their 1-sigma errors from the covariance diagonal
k, b = coeffs
std_k, std_b = np.sqrt(cov[0, 0]), np.sqrt(cov[1, 1])

# m, m error, c, c error
k,std_k,b,std_b

In [None]:
# Residual of the mean measured g1 against the true g1 for every case, with the
# zero line (dashed black) and the linear fit k*x + b (gold) over-plotted.
fig, ax = plt.subplots(figsize=(8,6))

g1_in = shear_pars['shear'][:,0]
ax.scatter(g1_in, g[:,0]-g1_in, s=3, color='maroon', alpha=0.7)

ax.axhline(0, color='k', linestyle='--', linewidth=4)
x = np.linspace(-0.1,0.1)
y = k*x+b
ax.plot(x, y, color='gold', linewidth=4)

ax.set_ylim(-0.1,0.1)
ax.tick_params(axis='both', which='major', labelsize=13)
ax.set_xlabel(r'$g^{true}_1$', fontsize=18)
ax.set_ylabel(r'$\left<g_1\right>-g^{true}_1$', fontsize=18)
ax.set_title('CNN shear measurement', fontsize=18)

## Type I data

In [None]:
from forklens.dataset import CaliDataset

In [None]:
# Type I calibration set: `case_num` cases, each with one catalog galaxy, one
# axis ratio and one shear, replicated over `real_num` fixed orientations.
case_num = 5000
real_num = 2000
# The orientation grid below is built from two halves of real_num//2 angles, so
# an odd real_num would silently yield real_num-1 orientations per case.
assert real_num % 2 == 0, "real_num must be even"

seed = 12345
rng = np.random.RandomState(seed)
# One (g1, g2) pair per case, each component uniform in (-0.1, 0.1].
Gal_Shear   = rng.random((case_num,2))*(-0.1-0.1)+0.1

# Size and magnitude come from the same randomly chosen catalog row; the
# trailing axis of length 1 lets them broadcast against the orientation grid.
rng1 = np.random.RandomState(seed + 1)
idx = rng1.randint(0,mag_cat.shape[0],size=(case_num,1))
Gal_Hlr     = size_cat[idx]
Gal_Mag     = mag_cat[idx]

# Fixed orientations shared by all cases: an evenly spaced grid on
# [-pi/2, pi/2] plus the same grid rotated by 90 degrees, so every angle has a
# partner with opposite-sign e1 and e2.
Gal_Phi     = np.linspace(-np.pi/2,np.pi/2,real_num//2)
Gal_Phi     = np.concatenate((Gal_Phi, Gal_Phi+np.pi/2),axis=0)

# Axis ratio q per case, uniform in (0.1, 1].
# NOTE(review): this draw continues the `rng1` stream and the 'randint' draw
# below uses `rng`. The generators `rng2`/`rng3` that used to be created here
# were never used and have been removed; the random numbers are unchanged.
Gal_AxRatio = rng1.random(size=(case_num,1))*(0.1-1)+1

# Ellipticity components, broadcast to shape (case_num, real_num).
Gal_E1 = (1-Gal_AxRatio)/(1+Gal_AxRatio)*np.cos(Gal_Phi*2)
Gal_E2 = (1-Gal_AxRatio)/(1+Gal_AxRatio)*np.sin(Gal_Phi*2)

# One random integer in [0, 10000) per case, stored under the 'randint' key
# (presumably a PSF selector -- confirm in CaliDataset).
PSF_randint = rng.randint(0,high=10000,size=(case_num,1))

gal_pars = {}
gal_pars["e1"] = Gal_E1
gal_pars["e2"] = Gal_E2
gal_pars["hlr_disk"] = Gal_Hlr
gal_pars["mag_i"] = Gal_Mag
gal_pars['randint'] = PSF_randint

shear_pars = {}
shear_pars['shear'] = Gal_Shear

In [None]:
# Preview grid: 5 rows x 6 columns, where row i shows the items at flat indices
# i*real_num + j (presumably the first six orientations of case i -- confirm
# against CaliDataset's indexing). The SNR is printed on the first panel of
# each row only.
show_ds = CaliDataset(shear_pars, gal_pars)

plt.figure(figsize=(10,10))

for i in range(5):
    for j in range(6):
        plt.subplot(5,6,i*6+j+1)
        # Fetch the item once: this avoids generating it twice and guarantees
        # that the SNR label and the image belong to the same sample even if
        # the dataset generates its items stochastically.
        sample = show_ds[i*real_num+j]
        if j == 0:
            plt.text(10,30,'%d'%sample['snr'],color='white',fontsize=18)
        gal_im = sample['gal_image'][0]
        plt.imshow(gal_im,cmap='gray')
        plt.xticks([])
        plt.yticks([])

plt.subplots_adjust(wspace=0., hspace=-0.3)

In [None]:
# This might take hours depending on the data volume.
start = time.time()

# Run the loaded model over the whole Type I set in batches.
cali_ds = CaliDataset(shear_pars, gal_pars)
cali_dl = DataLoader(cali_ds, batch_size=250, num_workers=20)

cali_pred, cali_true, cali_snr, loss = tr._predictFunc(cali_dl, model)

# Elapsed wall-clock time in minutes (displayed as the cell output).
(time.time() - start) / 60

In [None]:
# Assemble the calibration table, one row per (case, realization):
#   columns 0-4 : model predictions
#   columns 5-6 : true (g1, g2) of the case the row belongs to
#   column  7   : second column of `cali_snr`
dataset = np.zeros((case_num*real_num,8))
dataset[:,0:5] = cali_pred

# Row i belongs to case i // real_num, so each shear pair is repeated real_num
# times in a row. np.repeat does this in one vectorized call; the previous
# Python loop ran case_num*real_num iterations and also overwrote the global
# `idx` defined in the parameter cell.
box = np.repeat(Gal_Shear, real_num, axis=0)

dataset[:,5:7] = box
dataset[:,7] = cali_snr[:,1]

In [None]:
# Save (optional): uncomment the lines below to write the assembled `dataset`
# table to a FITS file.
# hdu = fits.PrimaryHDU(dataset)
# hdul = fits.HDUList([hdu])
# hdul.writeto('TypeI_5000case_2000real.fits')