In [None]:
import tensorflow as tf
tf.keras.backend.set_floatx('float32')
AUTOTUNE = tf.data.experimental.AUTOTUNE

from tf_fits.image import image_decode_fits
from tf_fits.bintable import bintable_decode_fits

from tensorflow_addons.image import rotate as tfa_image_rotate
from math import pi

import numpy as np
import glob
import os

In [None]:
# Enable memory growth on every visible GPU so TensorFlow allocates device
# memory on demand (workaround for cuDNN initialisation problems).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth must be configured on all GPUs before the runtime is
        # initialised, so finish this loop before querying logical devices.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Listing logical devices initialises the GPUs; do it once, afterwards,
        # and report the summary a single time.
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised if memory growth is changed after the GPUs were initialised.
        print(e)
else:
    print('No GPU')

In [None]:
# Data locations: one sub-directory per class below each root.
train_path = './train/'
valid_path = './valid/'
extn = '.fits'

# Class names; must match the sub-directory names.
CLASS_NAMES = ['merger', 'nonmerger']
NO_CLASS = len(CLASS_NAMES)

# Count the files on disk (glob is called without recursive=True).
train_images = glob.glob('{}**/*{}'.format(train_path, extn))
valid_images = glob.glob('{}**/*{}'.format(valid_path, extn))
train_image_count, valid_image_count = len(train_images), len(valid_images)

# Training schedule.
EPOCHS = 200
BATCH_SIZE = 64

# Number of batches needed to see every training image once.
STEPS_PER_EPOCH = np.ceil(train_image_count / BATCH_SIZE).astype(int)
STEPS_PER_VALID_EPOCH = 1

# Image geometry: images are 128 pixels wide on disk; edge_cut pixels are
# trimmed from each side to reach the target size.
IMG_HEIGHT = IMG_WIDTH = 128
edge_cut = (128 - IMG_HEIGHT) // 2
CROP_FRAC = IMG_HEIGHT / (2 * edge_cut + IMG_HEIGHT)

print(train_image_count, STEPS_PER_EPOCH)
print(valid_image_count, STEPS_PER_VALID_EPOCH)

In [None]:
@tf.function
def get_columns(table):
    """Pick the 15 entries of ``table`` at the fixed indices listed below.

    Args:
        table: tensor indexed along its first axis (needs at least 33 entries,
            since index 32 is read).

    Returns:
        The gathered entries, in the order of the index list.
    """
    wanted = [2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 30, 31, 32]
    return tf.gather(table, wanted)

@tf.function
def _normalise_condition(i, numcol, tf_MIN_MAX, table, new_table):
    """Loop predicate for ``tf.while_loop``: keep going while ``i < numcol``.

    Only ``i`` and ``numcol`` are read; the remaining parameters exist so the
    signature matches the loop variables.
    """
    keep_going = tf.math.less(i, numcol)
    return keep_going

@tf.function
def _normalise_body(i, numcol, tf_MIN_MAX, table, new_table):
    """Loop body for ``tf.while_loop``: min-max scale entry ``i`` of ``table``.

    Reads the (min, max) pair in row ``i`` of ``tf_MIN_MAX``, rescales
    ``table[i]`` as ``(value - min) / (max - min)`` and writes the result to
    slot ``i`` of the ``new_table`` TensorArray.

    Returns:
        The loop variables with ``i`` incremented and ``new_table`` updated.
    """
    lower = tf.gather_nd(tf_MIN_MAX, [i,0], name='MIN_slice')
    upper = tf.gather_nd(tf_MIN_MAX, [i,1], name='MAX_slice')
    span = tf.subtract(upper, lower)

    shifted = tf.gather_nd(table, [i], name='table_gather') - lower
    new_table = new_table.write(i, shifted / span)

    return i + 1, numcol, tf_MIN_MAX, table, new_table

@tf.function
def normalise_columns(table):
    """Min-max scale the first 15 entries of ``table`` with fixed bounds.

    Entry ``k`` is mapped to ``(table[k] - min_k) / (max_k - min_k)`` using the
    (min, max) pair in row ``k`` of the table below. The scaling is done as a
    single broadcast expression rather than an element-by-element loop.

    Args:
        table: float32 tensor whose first 15 entries along axis 0 are scalars
            (shape ``(N,)`` or ``(N, 1)`` with ``N >= 15``).

    Returns:
        float32 tensor of shape ``(15,)`` with the scaled values.
    """
    bounds = [[0.0, 6.0],    #concentration
              [0.0, 3.0],    #deviation
              [0.0, 1.0],    #ellipticity_asymmetry
              [0.0, 1.0],    #ellipticity_centroid
              [0.0, 155.0],  #elongation_asymmetry
              [0.0, 155.0],  #elongation_centroid
              [0.0, 1.0],    #gini
              [-3.0, 3.0],   #gini_m20_bulge
              [-1.0, 1.0],   #gini_m20_merger
              [0.0, 1.0],    #intensity
              [-4.0, 0.0],   #m20
              [0.0, 1.0],    #multimode
              [0.0, 200.0],  #sersic_amplitude
              [-6.0, 3.0],   #sersic_ellip
              [0.0, 500.0]]  #sersic_n
    # Derive the column count from the bounds so the two cannot drift apart.
    numcol = len(bounds)
    tf_MIN_MAX = tf.constant(bounds, dtype=tf.float32)

    lower = tf_MIN_MAX[:, 0]
    span = tf_MIN_MAX[:, 1] - lower

    # Use only the first `numcol` entries and flatten them, so inputs of shape
    # (N,) and (N, 1) both produce a (numcol,) vector.
    values = tf.reshape(table[:numcol], (numcol,))

    return (values - lower) / span

In [None]:
#@tf.function
def get_label(file_path):
    """Derive the label of a file from the name of its parent directory.

    The path is split on ``os.path.sep``; the second-to-last component (the
    class directory) is compared against ``CLASS_NAMES``.

    Returns:
        The result of that comparison (presumably one boolean per class name,
        i.e. a one-hot style vector — confirm against the printed batch shape).
    """
    path_parts = tf.strings.split(file_path, os.path.sep)
    class_dir = path_parts[-2]
    return class_dir == CLASS_NAMES

#@tf.function
def decode_image(byte_data):
    """Decode the raw bytes of a FITS file into a ``(128, 128, 1)`` tensor.

    The second argument to ``image_decode_fits`` is 0 (presumably the index of
    the HDU to read — confirm against the tf_fits documentation).
    """
    pixels = image_decode_fits(byte_data, 0)
    return tf.reshape(pixels, (128, 128, 1))

def process_path(file_path):
    """Load one example: read the file at ``file_path`` and build its label.

    Returns:
        ``(img, label)`` where ``img`` comes from ``decode_image`` and
        ``label`` from ``get_label``.
    """
    raw_bytes = tf.io.read_file(file_path)
    return decode_image(raw_bytes), get_label(file_path)

In [None]:
from time import time
# Shared random generator with a fixed seed (1); augment_img draws its rotation
# count from it. The trailing comment keeps the alternative time-based seed.
g = tf.random.Generator.from_seed(1)#int(time()))

#@tf.function
def augment_img(img, label):
    """Apply random augmentation to ``img``; ``label`` is passed through.

    The image is rotated by a random number (0-3, drawn from ``g``) of
    quarter turns, then randomly flipped left/right and up/down.
    """
    quarter_turns = g.uniform([], 0, 4, dtype=tf.int32)
    rotated = tf.image.rot90(img, k=quarter_turns)
    mirrored = tf.image.random_flip_left_right(rotated)
    flipped = tf.image.random_flip_up_down(mirrored)

    return flipped, label

#@tf.function
def crop_img(img, label):
    """Crop ``img`` and standardise it; ``label`` is passed through.

    The crop starts ``edge_cut`` pixels in from the top-left corner and keeps
    an ``IMG_HEIGHT`` x ``IMG_HEIGHT`` x 1 window (``IMG_HEIGHT`` is used for
    both spatial extents). The result goes through
    ``tf.image.per_image_standardization``.
    """
    start = [edge_cut, edge_cut, 0]
    extent = [IMG_HEIGHT, IMG_HEIGHT, 1]
    cropped = tf.slice(img, start, extent)

    return tf.image.per_image_standardization(cropped), label

In [None]:
#@tf.function
def prepare_dataset(ds, batch_size, shuffle_buffer_size=1000, training=False):
    """Turn a dataset of file paths into an endless stream of batches.

    Stages, in order: load (``process_path``) -> cache -> shuffle ->
    augment (``augment_img``, only when ``training``) -> crop/standardise
    (``crop_img``) -> batch -> repeat -> prefetch.

    Args:
        ds: dataset yielding file paths.
        batch_size: number of examples per batch.
        shuffle_buffer_size: buffer size passed to ``shuffle``.
        training: when true, random augmentation is applied.

    Returns:
        The prepared dataset of ``(image, label)`` batches; it repeats forever.
    """
    # Decoded images are cached before shuffling/augmentation, so the random
    # stages are re-applied on every pass over the cached data.
    pipeline = (
        ds.map(process_path, num_parallel_calls=AUTOTUNE)
          .cache()
          .shuffle(buffer_size=shuffle_buffer_size)
    )

    if training:
        pipeline = pipeline.map(augment_img, num_parallel_calls=AUTOTUNE)

    # `prefetch` lets batches be prepared in the background during training.
    return (
        pipeline.map(crop_img, num_parallel_calls=AUTOTUNE)
                .batch(batch_size)
                .repeat()
                .prefetch(buffer_size=AUTOTUNE)
    )

In [None]:
# Training set: batches of BATCH_SIZE, shuffle buffer spanning the whole set,
# augmentation enabled.
list_train_ds = tf.data.Dataset.list_files('{}*/*{}'.format(train_path, extn))
train_ds = prepare_dataset(list_train_ds,
                           batch_size=BATCH_SIZE,
                           shuffle_buffer_size=train_image_count,
                           training=True)

# Validation set: a single batch holds every validation image; no augmentation.
list_valid_ds = tf.data.Dataset.list_files('{}*/*{}'.format(valid_path, extn))
valid_ds = prepare_dataset(list_valid_ds,
                           batch_size=valid_image_count,
                           shuffle_buffer_size=valid_image_count)

In [None]:
# Smoke test: pull one batch from each pipeline to check that loading works.
print('train')
i_batch, y_batch = next(iter(train_ds))
print(i_batch.shape)
print(y_batch.shape)

print('valid')
_ = next(iter(valid_ds))

print('test load complete')

In [None]:
class image_model(tf.keras.Model):
    """Convolutional feature extractor for single-channel images.

    Four convolutions (32, 64, 128, 128 filters) grouped into three stages,
    each stage ending in a 2x2 max-pool with stride 2, followed by two dense
    layers (2048 and 128 units). Every convolution/dense layer is followed by
    batch normalisation, ReLU and dropout (rate ``drop_rate``). The model
    returns the activations of the last dense block; it has no output head.
    """

    def __init__(self):
        super().__init__()
        self.drop_rate = 0.2
        
        self.conv1 = tf.keras.layers.Conv2D(32, 6, strides=1, padding='same')
        self.batn1 = tf.keras.layers.BatchNormalization()
        self.drop1 = tf.keras.layers.Dropout(self.drop_rate)
        self.pool1 = tf.keras.layers.MaxPool2D(2, 2, padding='same')
        
        self.conv2 = tf.keras.layers.Conv2D(64, 5, strides=1, padding='same')
        self.batn2 = tf.keras.layers.BatchNormalization()
        self.drop2 = tf.keras.layers.Dropout(self.drop_rate)
        self.pool2 = tf.keras.layers.MaxPool2D(2, 2, padding='same')
        
        self.conv3 = tf.keras.layers.Conv2D(128, 3, strides=1, padding='same')
        self.batn3 = tf.keras.layers.BatchNormalization()
        self.drop3 = tf.keras.layers.Dropout(self.drop_rate)
        self.conv4 = tf.keras.layers.Conv2D(128, 3, strides=1, padding='same')
        self.batn4 = tf.keras.layers.BatchNormalization()
        self.drop4 = tf.keras.layers.Dropout(self.drop_rate)
        self.pool3 = tf.keras.layers.MaxPool2D(2, 2, padding='same')
        
        self.flatten = tf.keras.layers.Flatten()
        
        self.fc1 = tf.keras.layers.Dense(2048)
        self.batn5 = tf.keras.layers.BatchNormalization()
        self.drop5 = tf.keras.layers.Dropout(self.drop_rate)
        self.fc2 = tf.keras.layers.Dense(128)
        self.batn6 = tf.keras.layers.BatchNormalization()
        self.drop6 = tf.keras.layers.Dropout(self.drop_rate)

    def _apply_block(self, x, layer, norm, drop, training):
        """Run ``layer`` -> batch norm -> ReLU -> dropout on ``x``."""
        x = layer(x)
        # `training` is passed explicitly so batch normalisation switches
        # between batch and moving statistics together with dropout.
        x = norm(x, training=training)
        x = tf.keras.activations.relu(x)
        return drop(x, training=training)

    def call(self, inputs, training=True):
        """Compute the feature vector for a batch of images.

        Args:
            inputs: batch of images.
            training: whether dropout and batch normalisation run in training
                mode.

        Returns:
            Tensor with 128 features per example.
        """
        # Stage 1
        x = self._apply_block(inputs, self.conv1, self.batn1, self.drop1, training)
        x = self.pool1(x)
        
        # Stage 2
        x = self._apply_block(x, self.conv2, self.batn2, self.drop2, training)
        x = self.pool2(x)
        
        # Stage 3: two convolutions before pooling
        x = self._apply_block(x, self.conv3, self.batn3, self.drop3, training)
        x = self._apply_block(x, self.conv4, self.batn4, self.drop4, training)
        x = self.pool3(x)
        
        x = self.flatten(x)
        
        # Dense head
        x = self._apply_block(x, self.fc1, self.batn5, self.drop5, training)
        x = self._apply_block(x, self.fc2, self.batn6, self.drop6, training)
        
        return x

class image_wrapper(tf.keras.Model):
    """Classifier: ``image_model`` features followed by a softmax output layer.

    The output layer has ``NO_CLASS`` units. The feature extractor is kept as
    the ``image_model`` attribute.
    """

    def __init__(self):
        super().__init__()
        self.y_out = tf.keras.layers.Dense(NO_CLASS, activation='softmax')
        
        self.image_model = image_model()
        
    def call(self, image, training=True):
        """Return the class probabilities for a batch of images."""
        features = self.image_model(image, training)
        probabilities = self.y_out(features)
        return probabilities

In [None]:
#@tf.function
def train_step(images, labels):
    '''Run one optimisation step on a batch; labels should be one_hot.

    Uses the notebook-level ``model``, ``total_loss`` and ``optimizer`` and
    updates the ``train_loss`` / ``train_accuracy`` metrics.
    '''
    # Forward pass under the tape so gradients can be taken afterwards.
    with tf.GradientTape() as tape:
        pred = model(images)
        mean_loss = tf.reduce_mean(total_loss(labels, pred))

    # Weights are looked up after the forward pass, as in a freshly created
    # model they may not exist before the first call.
    variables = model.trainable_weights
    gradients = tape.gradient(mean_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    
    # Running statistics for this epoch
    train_loss(mean_loss)
    train_accuracy(labels, pred)

#@tf.function
def val_step(images, labels):
    '''Evaluate one batch in inference mode; labels should be one_hot.

    Updates the ``val_loss`` / ``val_accuracy`` metrics and returns the model
    predictions for the batch.
    '''
    pred = model(images, training=False)
    mean_v_loss = tf.reduce_mean(total_loss(labels, pred))

    # Running statistics for this epoch
    val_loss(mean_v_loss)
    val_accuracy(labels, pred)
    return pred

In [None]:
# Full classifier: image_model feature extractor plus softmax output layer.
model = image_wrapper()

# Adam optimiser with a learning rate of 5e-6.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-6)   
# Binary cross-entropy on probabilities (from_logits=False): the model output
# has already been passed through softmax.
total_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
# Running means of the per-batch loss, reset at the start of every epoch.
train_loss, val_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'val_loss')
)

# Categorical accuracy of the predictions against the labels.
train_accuracy, val_accuracy = (
    tf.keras.metrics.CategoricalAccuracy(name=metric_name)
    for metric_name in ('train_accuracy', 'val_accuracy')
)

In [None]:
# Per-epoch history of training / validation loss and accuracy.
t_los, t_acc = [], []
v_los, v_acc = [], []

# Best epoch so far as [epoch number, validation accuracy, validation loss].
peak = [0, 0, 100]

# Number of completed runs of the training cell; offsets the printed epoch.
run = 0

In [None]:
print('To train on', train_image_count)
print('To validate on', valid_image_count)

template = 'Epoch {}\nTrain Loss: {:.3g}, Train Accuracy: {:.3g}\nValid Loss: {:.3g}, Valid Accuracy: {:.3g}'


def _fraction(part, whole):
    """Return part/whole rounded to 3 decimals, or the string 'inf' if whole is 0."""
    if whole == 0:
        return 'inf'
    return round(part/whole, 3)


# Both datasets repeat forever, so one iterator per dataset is created up
# front and advanced with next(). Building a new iterator for every step would
# restart the input pipeline (and refill its shuffle buffer) for each batch.
train_iter = iter(train_ds)
valid_iter = iter(valid_ds)

for epoch in range(0, EPOCHS):
    # Epoch number used for reporting; offset by the number of previous runs.
    epoch_no = epoch+1+(EPOCHS*run)
    
    train_loss.reset_states()
    train_accuracy.reset_states()
    val_loss.reset_states()
    val_accuracy.reset_states()
    
    #Train
    print('Epoch', epoch_no, 'training')
    for step in range(0, STEPS_PER_EPOCH):
        i_batch, y_batch = next(train_iter)
        train_step(i_batch, y_batch)
    
    #Validate, collecting labels and predictions for the per-class report
    print('Epoch', epoch_no, 'validation')
    y_val_all = None
    val_pred = None
    for step in range(0, STEPS_PER_VALID_EPOCH):
        i_val, y_val = next(valid_iter)
        if y_val_all is None:
            y_val_all = y_val
        else:
            y_val_all = np.vstack((y_val_all, y_val))
        pred = val_step(i_val, y_val)
        if val_pred is None:
            val_pred = pred
        else:
            val_pred = np.vstack((val_pred, pred))
    
    # Read each metric once per epoch.
    epoch_train_loss = train_loss.result()
    epoch_train_acc = train_accuracy.result()
    epoch_val_loss = val_loss.result()
    epoch_val_acc = val_accuracy.result()
    
    print(template.format(epoch_no,
                          epoch_train_loss, epoch_train_acc,
                          epoch_val_loss, epoch_val_acc))
    
    t_los.append(epoch_train_loss)
    t_acc.append(epoch_train_acc)
    v_los.append(epoch_val_loss)
    v_acc.append(epoch_val_acc)
    
    # Checkpoint when the validation loss strictly improves, or when it ties
    # the best loss and the accuracy is at least as good.
    if epoch_val_loss < peak[2] or\
      (epoch_val_loss == peak[2] and epoch_val_acc >= peak[1]):
        peak[0] = epoch_no
        peak[1] = epoch_val_acc
        peak[2] = epoch_val_loss
        model.image_model.save_weights('./saved_image_model_HSC_SSP-sex/checkpoint')
        print('Saved')
    
    # Per-class breakdown once the validation accuracy exceeds 0.9.
    if epoch_val_acc > 0.9:
        y_val = np.argmax(y_val_all, axis=1)
        y_out_val_agm = np.argmax(val_pred, axis=1)
        for j in range(0, NO_CLASS):
            testing = np.where(y_out_val_agm == j)       # predicted as class j
            correct = np.where(np.logical_and(y_out_val_agm == j, y_val == j))
            validating = np.where(y_val == j)            # labelled as class j
     
            # Fraction of predictions of class j that are right, and fraction
            # of true class-j examples that were found.
            cor_tes = _fraction(len(correct[0]), len(testing[0]))
            cor_val = _fraction(len(correct[0]), len(validating[0]))
                
            print('Val \t Are', CLASS_NAMES[j], ' classed ', CLASS_NAMES[j], ':', cor_val, 
                  '(',len(correct[0]), 'of', len(validating[0]),')')
            print('Val \t Classed', CLASS_NAMES[j], ' are ', CLASS_NAMES[j], ':', cor_tes, 
                  '(',len(correct[0]), 'of', len(testing[0]),')')
        
    print()
print('Peaks at Epoch', peak[0], 'with accuracy', np.round(peak[1],3), 'and loss', np.round(peak[2],3))