In [1]:
from forklens import train
from forklens.dataset import ShapeDataset
from torch.utils.data import DataLoader

import numpy as np
# import matplotlib.pyplot as plt
from astropy.io import fits

Here we train the CNN on 100,000 mock galaxy and PSF image pairs.


In [2]:
# Load data catalog: only the data array of the first HDU is used.
with fits.open('../../data/csst_snr_catalog.fits') as f:
    cat = f[0].data

# Row selection: keep entries whose column 3 exceeds 0.1 and whose column 4
# is below 25. Column 3 is used as the galaxy size and column 4 as the
# galaxy magnitude.
cut_idx = np.flatnonzero((cat[:, 3] > 0.1) & (cat[:, 4] < 25))
size_cat, mag_cat = cat[cut_idx, 3], cat[cut_idx, 4]

In [3]:
# Generate data
def dataframe(num, seed=12345):
    """Build the galaxy and PSF parameter catalogs fed to the simulation.

    Only half of the requested rows are drawn from the catalog; that half is
    then repeated, so rows ``i`` and ``i + int(num/2)`` share the same size,
    magnitude and PSF index. Ellipticities are drawn independently for every
    row.

    Parameters
    ----------
    num : int
        Requested number of rows. It is halved with ``int(num/2)`` and then
        doubled, so an odd value produces ``num - 1`` rows.
    seed : int, optional
        Base seed. ``seed``, ``seed + 1`` and ``seed + 2`` seed three
        separate ``RandomState`` streams (catalog rows, shapes, PSF index).

    Returns
    -------
    gal_pars : dict
        Arrays under the keys ``"e1"``, ``"e2"``, ``"hlr_disk"``, ``"mag_i"``.
    psf_pars : dict
        Array of integers in ``[0, 10000)`` under the key ``'randint'``.

    Notes
    -----
    Reads the module-level arrays ``size_cat`` and ``mag_cat``.
    """
    half = int(num/2)
    total = half*2

    def twice(arr):
        # Append the array to itself so both halves carry identical values.
        return np.concatenate((arr, arr), axis=0)

    # Stream 1: which catalog rows supply the size and magnitude.
    picks = np.random.RandomState(seed).randint(0, mag_cat.shape[0], size=half)

    # Stream 2: position angle first, then ellipticity modulus. The draw
    # order must stay as is to reproduce the same numbers for a given seed.
    shape_rng = np.random.RandomState(seed+1)
    angle = shape_rng.random(total)*2*np.pi
    # Square root of a uniform draw, scaled by 0.999 so |e| stays below 1.
    modulus = np.sqrt(shape_rng.random(total))*0.999

    # Stream 3: PSF index, one per drawn row.
    psf_idx = np.random.RandomState(seed+2).randint(0, 10000, size=half)

    gal_pars = {
        # Note: e1 takes the sine and e2 the cosine of twice the angle.
        "e1": modulus*np.sin(2*angle),
        "e2": modulus*np.cos(2*angle),
        "hlr_disk": twice(size_cat[picks]),
        "mag_i": twice(mag_cat[picks]),
    }
    psf_pars = {'randint': twice(psf_idx)}

    return gal_pars, psf_pars

The function above generates an input catalog for the training dataset (covering both the training and validation samples).

Note that two catalogs are generated: one for galaxies and one for PSFs. The PSF catalog is simply a set of random indices, which makes no difference here because the simulated PSF image in this code is kept fixed for simplicity.

Users may want to write their own version of simulation.py and customize dataset.py accordingly.

In [4]:
# Build the dataset object from freshly generated galaxy and PSF catalogs.
nSims = 100_000  # number of catalog rows requested from dataframe()
GalCat, PSFCat = dataframe(num=nSims, seed=12345)
train_ds = ShapeDataset(GalCat, PSFCat)

(4720952,)


Some training hyperparameters can be adjusted in ./config.py.

In [5]:
# Train the network
# Train() is constructed without arguments; per the note in this notebook,
# training hyperparameters are taken from ./config.py.
tr = train.Train()
# show_log=True prints the per-epoch [TRAIN]/[VALID] loss lines seen in the
# cell output. The trailing ';' is meant to keep the notebook from echoing
# run()'s return value.
tr.run(train_ds, show_log=True);

Train_dl: 360 Validation_dl: 40
Begin training ...
[TRAIN] Epoch: 1 Loss: 0.27965382910751724 Time: 1:14
[VALID] Epoch: 1 Loss: 0.2313229511027656 Time: 0:11
[TRAIN] Epoch: 2 Loss: 0.22804690249232595 Time: 1:16
[VALID] Epoch: 2 Loss: 0.22015504617240078 Time: 0:11
[TRAIN] Epoch: 3 Loss: 0.22308680501612704 Time: 1:18
[VALID] Epoch: 3 Loss: 0.22393668344096215 Time: 0:11
[TRAIN] Epoch: 4 Loss: 0.21994440234572382 Time: 1:17
[VALID] Epoch: 4 Loss: 0.21574010761141812 Time: 0:11
[TRAIN] Epoch: 5 Loss: 0.21813957438734216 Time: 1:18
[VALID] Epoch: 5 Loss: 0.21390902532403008 Time: 0:11
[TRAIN] Epoch: 6 Loss: 0.21584955123358815 Time: 1:17
[VALID] Epoch: 6 Loss: 0.21409474290142302 Time: 0:12
[TRAIN] Epoch: 7 Loss: 0.21499654557658984 Time: 1:17
[VALID] Epoch: 7 Loss: 0.21471007284832788 Time: 0:11
[TRAIN] Epoch: 8 Loss: 0.21382891788195516 Time: 1:18
[VALID] Epoch: 8 Loss: 0.2141304829640659 Time: 0:12
[TRAIN] Epoch: 9 Loss: 0.2116843066521367 Time: 1:18
[VALID] Epoch: 9 Loss: 0.211507336

([0.27965382910751724,
  0.22804690249232595,
  0.22308680501612704,
  0.21994440234572382,
  0.21813957438734216,
  0.21584955123358815,
  0.21499654557658984,
  0.21382891788195516,
  0.2116843066521367,
  0.21153840155606904,
  0.21095573511796697,
  0.21010204380159883,
  0.21024271915627457,
  0.2091244131206056,
  0.20977675711085964,
  0.20950613003866464,
  0.2076683395381882,
  0.20798420486048072,
  0.20730620890115942,
  0.20838954996560743,
  0.2070062047126409,
  0.2066266612769532,
  0.20707193907971794,
  0.2064464807652757,
  0.20600506055348064,
  0.20630110037791866,
  0.20653259309149521,
  0.20553451123523778,
  0.20547243510280913,
  0.20570724895928633,
  0.20569670225876957,
  0.20522272934446936,
  0.20487745442883043,
  0.2056912681900787,
  0.20582109155035067,
  0.20547539407042795,
  0.20540634325798418,
  0.20415691320301468,
  0.20437575065805952,
  0.20484396660495288,
  0.20410320227074302,
  0.2036458313160235,
  0.204918595536757,
  0.20451567301026397