In [None]:
import tensorflow as tf
tf.keras.backend.set_floatx('float32')

from astropy.table import Table
from astropy.cosmology import Planck15 as cosmo
from math import pi

import numpy as np

In [None]:
# Check for GPUs. If any are present, switch them to on-demand memory
# allocation (workaround for cuDNN start-up failures).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth has to be configured on every GPU before the runtime
        # is initialised, so nothing else may touch the devices inside this loop.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Listing logical devices initialises the GPUs; do it once, after all
        # of them have been configured, and report the counts a single time.
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when memory growth is changed after the GPUs were initialised.
        print(e)
else:
    print('No GPU')

In [None]:
# FITS tables holding the training and validation samples.
train_file = 'HSC_SSP-morphology-sex-train.fits'
valid_file = 'HSC_SSP-morphology-sex-valid.fits'

# Per-column [lower, upper] bounds. prepare_fits_dataset rescales every listed
# column as (value - lower) / (upper - lower), and the key order below is the
# order in which the columns are stacked into the feature matrix.
MIN_MAX = {
    'asymmetry': [-4.0, 4.0],
    'concentration': [0.0, 6.0],
    'deviation': [0.0, 3.0],
    'ellipticity_asymmetry': [0.0, 1.0],
    'ellipticity_centroid': [0.0, 1.0],
    'elongation_asymmetry': [1.0, 8.0],
    'elongation_centroid': [1.0, 8.0],
    'gini': [0.0, 1.0],
    'gini_m20_bulge': [-3, 3],
    'gini_m20_merger': [-1.0, 1.0],
    'intensity': [0.0, 1.0],
    'm20': [-4.0, 0.0],
    'multimode': [0.0, 1.0],
    'sersic_amplitude': [0.0, 200.0],
    'sersic_ellip': [-6.0, 3.0],
    'sersic_n': [0.0, 50.0],
    'smoothness': [-0.4, 0.4],
}

# Classification target: the label column is one-hot encoded over NO_CLASS classes.
CLASS_COLUMN = 'class'
CLASS_NAMES = ['merger', 'nonmerger']
NO_CLASS = len(CLASS_NAMES)

# Training schedule.
EPOCHS = 5000
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 10000

# Placeholders: prepare_fits_dataset overwrites these module-level values
# when it loads the training / validation tables.
STEPS_PER_EPOCH = 1
STEPS_PER_VALID_EPOCH = 1
train_photo_count = 0
valid_photo_count = 0

In [None]:
def prepare_fits_dataset(file_path, class_column, class_number, min_max, use_columns=None):
    """Read a FITS table and return it as a shuffled, batched ``tf.data.Dataset``.

    Side effects on module-level state:
      * if ``'train'`` occurs in ``file_path``: ``train_photo_count`` and
        ``STEPS_PER_EPOCH`` are updated and batches hold ``BATCH_SIZE`` rows;
      * otherwise, if ``'valid'`` occurs in ``file_path``:
        ``valid_photo_count`` is updated, ``STEPS_PER_VALID_EPOCH`` is set to
        1 and the whole table forms a single batch;
      * for any other path no global is touched and ``BATCH_SIZE`` is used.

    Args:
        file_path: Path handed to ``astropy.table.Table.read``.
        class_column: Name of the column holding the integer class index.
        class_number: Depth of the one-hot label encoding.
        min_max: Mapping ``column -> [lower, upper]``. Each listed column is
            rescaled as ``(value - lower) / (upper - lower)``; the mapping's
            iteration order is the feature order.
        use_columns: Optional iterable of extra column names that are scaled
            and appended after the ``min_max`` features.

    Returns:
        A dataset of ``(features, one_hot_labels)`` batches, shuffled with a
        buffer of ``SHUFFLE_BUFFER_SIZE`` elements.
    """
    global train_photo_count, STEPS_PER_EPOCH
    global valid_photo_count, STEPS_PER_VALID_EPOCH

    table = Table.read(file_path)
    row_count = len(table)

    # Default batching, so that a path naming neither split still works
    # instead of failing later on an unbound local.
    batch_size = BATCH_SIZE
    if 'train' in file_path:
        train_photo_count = row_count
        STEPS_PER_EPOCH = int(np.ceil(row_count / BATCH_SIZE))
    elif 'valid' in file_path:
        valid_photo_count = row_count
        STEPS_PER_VALID_EPOCH = 1
        # One batch for the whole table; batch() rejects a size of zero.
        batch_size = max(row_count, 1)

    #labels
    labels = tf.one_hot(table[class_column].data, class_number)

    #data
    scaled_columns = []
    missing_masks = []
    for column, (lower, upper) in min_max.items():
        values = table[column].data.astype(np.float32)  # astype copies, safe to edit in place
        # -99 marks a missing measurement; remember where before rescaling.
        missing_masks.append(values == -99)
        values -= lower
        values /= (upper - lower)
        scaled_columns.append(values)
    data = np.array(scaled_columns).T
    missing = np.array(missing_masks).T

    # Flag only the missing cells. Indexing with the per-column row numbers
    # alone would overwrite every feature of an affected row.
    # NOTE(review): confirm per-cell (not whole-row) flagging is the intent.
    if missing.any():
        data[missing] = -1.0

    if use_columns is not None and len(use_columns) > 0:
        #photometric data
        photometry = np.array([table[column].data for column in use_columns]).T

        invalid = ~np.isfinite(photometry)

        # NOTE(review): the scaling runs along each row (one object across the
        # selected columns), not along each column - confirm this is intended.
        photometry -= np.nanmin(photometry, axis=1, keepdims=True)
        photometry /= np.nanmax(photometry, axis=1, keepdims=True)

        # Non-finite inputs get the same -1.0 flag as missing morphology values.
        photometry[invalid] = -1.0

        data = np.hstack((data, photometry))

    ds = tf.data.Dataset.from_tensor_slices((data, labels))
    return ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size)

In [None]:
# Build the training and validation pipelines. Loading each file also
# refreshes the module-level sample counts and steps-per-epoch values.
train_ds = prepare_fits_dataset(file_path=train_file,
                                class_column=CLASS_COLUMN,
                                class_number=NO_CLASS,
                                min_max=MIN_MAX)
valid_ds = prepare_fits_dataset(file_path=valid_file,
                                class_column=CLASS_COLUMN,
                                class_number=NO_CLASS,
                                min_max=MIN_MAX)

In [None]:
class morph_model(tf.keras.Model):
    """Feature extractor: two stages of Dense(128) -> BatchNorm -> ReLU -> Dropout."""

    def __init__(self):
        super().__init__()
        self.drop_rate = 0.2

        # NOTE(review): weights saved with save_weights are restored through
        # these attribute names, so renaming them would presumably orphan
        # existing checkpoints - keep them stable.
        self.fuco1 = tf.keras.layers.Dense(128)
        self.batn1 = tf.keras.layers.BatchNormalization()
        self.drop1 = tf.keras.layers.Dropout(self.drop_rate)

        self.fuco4 = tf.keras.layers.Dense(128)
        self.batn4 = tf.keras.layers.BatchNormalization()
        self.drop4 = tf.keras.layers.Dropout(self.drop_rate)

    def call(self, x, training=True):
        """Apply both stages to ``x``.

        Args:
            x: Batch of input features.
            training: Forwarded to the batch-normalisation and dropout layers
                so both follow the same mode.

        Returns:
            The 128-unit output of the second stage.
        """
        stages = (
            (self.fuco1, self.batn1, self.drop1),
            (self.fuco4, self.batn4, self.drop4),
        )
        for dense, batch_norm, dropout in stages:
            x = dense(x)
            # Pass the flag explicitly, as for dropout, instead of relying on
            # Keras to infer it for the normalisation layer.
            x = batch_norm(x, training=training)
            x = tf.keras.activations.relu(x)
            x = dropout(x, training=training)

        return x
    
class morph_wrapper(tf.keras.Model):
    """Classifier: a ``morph_model`` feature extractor followed by a softmax head."""

    def __init__(self):
        super().__init__()
        # Output layer with one softmax unit per class.
        self.y_out = tf.keras.layers.Dense(NO_CLASS, activation='softmax')

        # Feature extractor, kept as its own sub-model.
        self.morph_model = morph_model()

    def call(self, morph, training=True):
        """Return class probabilities for the batch ``morph``."""
        features = self.morph_model(morph, training=training)
        probabilities = self.y_out(features)
        return probabilities

In [None]:
# Running means of the per-batch loss, reset by the training loop every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_loss = tf.keras.metrics.Mean(name='val_loss')

# Accuracy against one-hot labels, tracked separately for each split.
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')

@tf.function
def train_step(data, labels):
    '''Run one optimisation step on a batch and update the training metrics.

    labels should be one_hot
    '''
    with tf.GradientTape() as tape:
        # State the mode explicitly (val_step passes training=False) rather
        # than depending on the default of the model's call signature.
        pred = model(data, training=True)
        loss = total_loss(labels, pred)
        mean_loss = tf.reduce_mean(loss)

    #Update gradients and optimize
    grads = tape.gradient(mean_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    #tf statistics tracking
    train_loss(mean_loss)
    train_accuracy(labels, pred)

@tf.function
def val_step(data, labels):
    '''Evaluate one batch in inference mode and update the validation metrics.

    labels should be one_hot. Returns the model predictions for the batch.
    '''
    predictions = model(data, training=False)

    # Record the batch-mean loss and the accuracy.
    batch_loss = tf.reduce_mean(total_loss(labels, predictions))
    val_loss(batch_loss)
    val_accuracy(labels, predictions)

    return predictions

In [None]:
model = morph_wrapper()

# The network ends in a softmax over NO_CLASS one-hot targets, so categorical
# cross-entropy is the matching loss. For two classes it equals the binary
# cross-entropy averaged over both outputs, and it stays correct if
# CLASS_NAMES grows.
total_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

In [None]:
print('To train on', train_photo_count)
print('To validate on', valid_photo_count)

# Best validation result so far, stored as [epoch, accuracy, loss]. Starting
# from an infinite loss means the first finite validation loss is always kept.
peak = [0, 0, float('inf')]

# Per-epoch metric history.
t_los = []
t_acc = []
v_los = []
v_acc = []

template = 'Epoch {}\nTrain Loss: {:.3g}, Train Accuracy: {:.3g}\nValid Loss: {:.3g}, Valid Accuracy: {:.3g}'

for epoch in range(0, EPOCHS):

    train_loss.reset_states()
    train_accuracy.reset_states()
    val_loss.reset_states()
    val_accuracy.reset_states()

    #Train
    # Iterate the dataset itself. Creating a new iterator for every step
    # (next(iter(ds))) refills the shuffle buffer each time and only ever
    # yields the first batch of a fresh shuffle, never a pass over the data.
    for x_batch, y_batch in train_ds.take(STEPS_PER_EPOCH):
        train_step(x_batch, y_batch)

    #Validate
    val_label_batches = []
    val_pred_batches = []
    for x_val, y_val in valid_ds.take(STEPS_PER_VALID_EPOCH):
        val_label_batches.append(y_val)
        val_pred_batches.append(val_step(x_val, y_val))
    # Stack once after the loop; None when there was nothing to validate on.
    y_val_all = np.vstack(val_label_batches) if val_label_batches else None
    val_pred = np.vstack(val_pred_batches) if val_pred_batches else None

    epoch_train_loss = train_loss.result()
    epoch_train_acc = train_accuracy.result()
    epoch_val_loss = val_loss.result()
    epoch_val_acc = val_accuracy.result()

    if epoch%100 == 0:
        print(template.format(epoch+1,
                              epoch_train_loss, epoch_train_acc,
                              epoch_val_loss, epoch_val_acc))

    t_los.append(epoch_train_loss)
    t_acc.append(epoch_train_acc)
    v_los.append(epoch_val_loss)
    v_acc.append(epoch_val_acc)

    # Keep the weights when the validation loss strictly improves, or when it
    # ties and the accuracy is at least as good. (With "<=" in the first
    # clause the tie-break could never influence the outcome.)
    if epoch_val_loss < peak[2] or\
      (epoch_val_loss == peak[2] and epoch_val_acc >= peak[1]):
        peak[0] = epoch+1
        peak[1] = epoch_val_acc
        peak[2] = epoch_val_loss
        model.morph_model.save_weights('./saved_morph_model_HSC_SSP_wAS-sex/checkpoint')

    # Per-class diagnostics. Accuracy is a fraction and cannot exceed 1, so
    # with a threshold of 1.1 this report is switched off; lower the threshold
    # to enable it.
    if epoch_val_acc > 1.1:
        true_class = np.argmax(y_val_all, axis=1)
        pred_class = np.argmax(val_pred, axis=1)
        for j in range(0, NO_CLASS):
            testing = np.where(pred_class == j)       # predicted as class j
            correct = np.where(np.logical_and(pred_class == j, true_class == j))
            validating = np.where(true_class == j)    # truly class j

            # Fraction of objects predicted as class j that really are class j.
            if len(testing[0]) == 0:
                cor_tes = 'inf'
            elif len(correct[0]) == 0:
                cor_tes = 0.0
            else:
                cor_tes = len(correct[0])/len(testing[0])
                cor_tes = round(cor_tes, 3)

            # Fraction of true class-j objects that were predicted as class j.
            if len(validating[0]) == 0:
                cor_val = 'inf'
            elif len(correct[0]) == 0:
                cor_val = 0.0
            else:
                cor_val = len(correct[0])/len(validating[0])
                cor_val = round(cor_val, 3)

            print('Val \t Are', CLASS_NAMES[j], ' classed ', CLASS_NAMES[j], ':', cor_val, 
                  '(',len(correct[0]), 'of', len(validating[0]),')')
            print('Val \t Classed', CLASS_NAMES[j], ' are ', CLASS_NAMES[j], ':', cor_tes, 
                  '(',len(correct[0]), 'of', len(testing[0]),')')
        print()

    if epoch%100 == 0:
        print()
print('Peaks at Epoch', peak[0], 'with accuracy', np.round(peak[1],3), 'and loss', np.round(peak[2],3))