In [None]:
import tensorflow as tf
tf.keras.backend.set_floatx('float32')
AUTOTUNE = tf.data.experimental.AUTOTUNE

from tf_fits.image import image_decode_fits
from tf_fits.bintable import bintable_decode_fits

import numpy as np
import glob
import os

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Check for GPUs. If there are any, switch them to on-demand memory allocation
# (the original workaround for cuDNN initialisation bugs).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth has to be configured on every physical GPU before the
        # devices are initialised, so finish this loop before touching anything else.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Query the logical devices only after all GPUs are configured (as in the
        # TensorFlow GPU guide); doing it inside the loop can make the remaining
        # set_memory_growth calls raise RuntimeError on multi-GPU machines.
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when the devices were already initialised; report and carry on.
        print(e)
else:
    print('No GPU')

In [None]:
# Data locations: one sub-directory per class below each split directory.
valid_path = './valid/'
test_path = './test/'

extn = '.fits'

# Class labels; the position in this list is the class index used everywhere else.
CLASS_NAMES = ['merger', 'nonmerger']
NO_CLASS = len(CLASS_NAMES)

# Collect the files of each split (same glob tail for both).
_glob_tail = '**/*' + extn
valid_images = glob.glob(valid_path + _glob_tail)
valid_image_count = len(valid_images)
test_images = glob.glob(test_path + _glob_tail)
test_image_count = len(test_images)

# Evaluation schedule: one pass, each split delivered as a single batch.
EPOCHS = 1
VALID_BATCH_SIZE, TEST_BATCH_SIZE = valid_image_count, test_image_count

STEPS_PER_VALID_EPOCH = STEPS_PER_TEST_EPOCH = 1

# Image geometry: edge_cut pixels are trimmed from each side of a 128-pixel frame.
IMG_HEIGHT = IMG_WIDTH = 128
edge_cut = (128 - IMG_HEIGHT) // 2
CROP_FRAC = IMG_HEIGHT / (2 * edge_cut + IMG_HEIGHT)

print(valid_image_count, STEPS_PER_VALID_EPOCH)
print(test_image_count, STEPS_PER_TEST_EPOCH)

In [None]:
@tf.function
def get_columns(table):
    """Pick the 17 entries of `table` that are fed to the network.

    The indices are positions along the first axis of `table`; indices 51 and 52
    were previously considered and are currently left out.
    """
    wanted = [1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 30, 31, 32, 41]  # ,51,52
    return tf.gather(table, wanted)

@tf.function
def _normalise_condition(i, numcol, tf_MIN_MAX, table, new_table):
    """Loop predicate for tf.while_loop: keep going while `i` is below `numcol`.

    The remaining arguments are loop variables that only need to be accepted.
    """
    return i < numcol

@tf.function
def _normalise_body(i, numcol, tf_MIN_MAX, table, new_table):
    """Loop body for tf.while_loop: min-max scale entry `i` of `table`.

    Reads the bounds from row `i` of `tf_MIN_MAX` (column 0 = MIN, column 1 = MAX),
    writes (table[i] - MIN) / (MAX - MIN) into slot `i` of the TensorArray
    `new_table`, and returns the loop variables with `i` advanced by one.
    """
    lower = tf.gather_nd(tf_MIN_MAX, [i, 0], name='MIN_slice')
    upper = tf.gather_nd(tf_MIN_MAX, [i, 1], name='MAX_slice')

    value = tf.gather_nd(table, [i], name='table_gather')
    scaled = (value - lower) / tf.subtract(upper, lower)

    # Slot index is the pre-increment `i`; the counter handed back is i + 1.
    return i + 1, numcol, tf_MIN_MAX, table, new_table.write(i, scaled)

@tf.function
def normalise_columns(table):
    """Min-max scale the selected morphology values with fixed per-column bounds.

    Every entry is mapped with (value - MIN) / (MAX - MIN), where row k of the
    bounds table below belongs to entry k of `table`. Values outside their bounds
    are not clipped, so results may fall outside [0, 1].

    Args:
        table: float32 tensor holding one value per row of the bounds table
            (17 values, as produced by get_columns).

    Returns:
        float32 tensor of shape (17,) with the scaled values.
    """
    tf_MIN_MAX = tf.constant([[-4.0, 4.0],   #asymmetry
                              [0.0, 6.0],    #concentration
                              [0.0, 3.0],    #deviation
                              [0.0, 1.0],    #ellipticity_asymmetry
                              [0.0, 1.0],    #ellipticity_centroid
                              [1.0, 8.0],    #elongation_asymmetry
                              [1.0, 8.0],    #elongation_centroid
                              [0.0, 1.0],    #gini
                              [-3.0, 3.0],   #gini_m20_bulge
                              [-1.0, 1.0],   #gini_m20_merger
                              [0.0, 1.0],    #intensity
                              [-4.0, 0.0],   #m20
                              [0.0, 1.0],    #multimode
                              [0.0, 200.0],  #sersic_amplitude
                              [-6.0, 3.0],   #sersic_ellip
                              [0.0, 50.0],   #sersic_n
                              [-0.4, 0.4]],  #smoothness
                             dtype=tf.float32)

    # Number of columns comes from the bounds table instead of a separate constant.
    numcol = tf_MIN_MAX.shape[0]

    lower = tf_MIN_MAX[:, 0]
    span = tf_MIN_MAX[:, 1] - lower

    # Fixing the shape up front keeps the static (17,) shape for batching and makes
    # a wrongly sized input fail here rather than broadcast silently.
    values = tf.reshape(table, (numcol,))

    # One element-wise expression replaces the per-column tf.while_loop.
    new_table = (values - lower) / span

    return new_table

In [None]:
#@tf.function
def get_label(file_path):
    """Derive the label from the directory that contains the file.

    The second-to-last path component (the class directory) is compared
    element-wise against CLASS_NAMES, giving one boolean per class.
    """
    # Split on the OS path separator and take the parent directory name.
    class_dir = tf.strings.split(file_path, os.path.sep)[-2]
    return class_dir == CLASS_NAMES

#@tf.function
def decode_image(byte_data):
    """Decode the image stored in the FITS byte string as a (128, 128, 1) tensor.

    The second argument of image_decode_fits (0) presumably selects the HDU to
    read -- TODO confirm against tf_fits.
    """
    pixels = image_decode_fits(byte_data, 0)
    return tf.reshape(pixels, (128, 128, 1))

def decode_bintable(byte_data):
    """Decode the morphology table from the FITS byte string and normalise it.

    The decoded table is flattened to 53 values, reduced to the columns chosen by
    get_columns and scaled by normalise_columns. The second argument of
    bintable_decode_fits (3) presumably selects the HDU -- TODO confirm.
    """
    raw_table = bintable_decode_fits(byte_data, 3)
    flat_table = tf.reshape(raw_table, (53,))
    return normalise_columns(get_columns(flat_table))

def process_path(file_path):
    """Turn one file path into an (image, morphology, label) triple.

    The file is read once and the same bytes are decoded twice: as the image and
    as the morphology table. The label comes from the path alone.
    """
    label = get_label(file_path)
    raw_bytes = tf.io.read_file(file_path)
    return decode_image(raw_bytes), decode_bintable(raw_bytes), label

In [None]:
from time import time
# Random generator used by augment_img for the rotation draw. The seed is fixed
# at 1; swap in int(time()) for a different sequence on every run.
g = tf.random.Generator.from_seed(seed=1)

#@tf.function
def augment_img(img, morph, label):
    """Apply a random 90-degree-step rotation and random flips to the image.

    The number of quarter turns (0-3) is drawn from the generator `g`; the two
    flips use tf.image's own random ops. `morph` and `label` pass through untouched.
    """
    quarter_turns = g.uniform([], 0, 4, dtype=tf.int32)
    rotated = tf.image.rot90(img, k=quarter_turns)
    mirrored = tf.image.random_flip_left_right(rotated)
    flipped = tf.image.random_flip_up_down(mirrored)

    return flipped, morph, label

#@tf.function
def crop_img(img, morph, label):
    """Cut the central window out of the image and standardise it.

    The window starts `edge_cut` pixels in along both spatial axes and keeps the
    single channel; the result is passed through per_image_standardization.
    `morph` and `label` pass through untouched.
    """
    begin = [edge_cut, edge_cut, 0]
    # NOTE(review): IMG_HEIGHT is used for both spatial axes (IMG_WIDTH is not read).
    size = [IMG_HEIGHT, IMG_HEIGHT, 1]
    window = tf.slice(img, begin, size)

    return tf.image.per_image_standardization(window), morph, label

In [None]:
#@tf.function
def prepare_dataset(ds, batch_size, shuffle_buffer_size=1000, training=False):
    """Build the input pipeline on top of a dataset of file paths.

    Stages, in order: decode (process_path) -> cache -> shuffle -> optional
    augmentation -> crop/standardise -> batch -> repeat forever -> prefetch.

    Args:
        ds: dataset yielding file paths.
        batch_size: number of samples per batch.
        shuffle_buffer_size: buffer size handed to Dataset.shuffle.
        training: when True, augment_img is applied before cropping.

    Returns:
        An endlessly repeating dataset of (image, morph, label) batches.
    """
    # Decode once and cache the decoded samples; shuffling happens after the
    # cache, augmentation (if any) after the shuffle.
    pipeline = (
        ds.map(process_path, num_parallel_calls=AUTOTUNE)
        .cache()
        .shuffle(buffer_size=shuffle_buffer_size)
    )

    if training:
        pipeline = pipeline.map(augment_img, num_parallel_calls=AUTOTUNE)

    # Crop, batch, repeat, and let tf.data prepare batches in the background.
    return (
        pipeline.map(crop_img, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .repeat()
        .prefetch(buffer_size=AUTOTUNE)
    )

In [None]:
# Each split is served as one batch holding every file, with a shuffle buffer of
# the same size; files are expected at <split>/<class>/<name>.fits.
list_valid_ds = tf.data.Dataset.list_files(valid_path + '*/*' + extn)
valid_ds = prepare_dataset(list_valid_ds,
                           batch_size=valid_image_count,
                           shuffle_buffer_size=valid_image_count)

list_test_ds = tf.data.Dataset.list_files(test_path + '*/*' + extn)
test_ds = prepare_dataset(list_test_ds,
                          batch_size=test_image_count,
                          shuffle_buffer_size=test_image_count)

In [None]:
class morph_model(tf.keras.Model):
    """Fully connected branch for the morphology vector.

    Two identical stages of Dense(128) -> BatchNormalization -> ReLU -> Dropout(0.2);
    the output is a 128-wide feature vector.
    """

    def __init__(self):
        super().__init__()
        self.drop_rate = 0.2

        layers = tf.keras.layers

        # Stage 1
        self.fuco1 = layers.Dense(128)
        self.batn1 = layers.BatchNormalization()
        self.drop1 = layers.Dropout(self.drop_rate)

        # Stage 2 (attribute names keep the original "4" suffix)
        self.fuco4 = layers.Dense(128)
        self.batn4 = layers.BatchNormalization()
        self.drop4 = layers.Dropout(self.drop_rate)

    def call(self, x, training=True):
        """Run both stages; `training` is forwarded to the Dropout layers."""
        relu = tf.keras.activations.relu
        stages = (
            (self.fuco1, self.batn1, self.drop1),
            (self.fuco4, self.batn4, self.drop4),
        )

        for dense, norm, drop in stages:
            x = drop(relu(norm(dense(x))), training=training)

        return x

In [None]:
class image_model(tf.keras.Model):
    """Convolutional branch for the image input.

    Four Conv2D -> BatchNormalization -> ReLU -> Dropout(0.2) stages (32, 64, 128
    and 128 filters) with 2x2 max pooling after the first, second and fourth,
    followed by Flatten and two dense stages (2048 and 128 units). The output is
    a 128-wide feature vector.
    """

    def __init__(self):
        super().__init__()
        self.drop_rate = 0.2

        layers = tf.keras.layers

        # Attribute names and assignment order are left exactly as they were so
        # that the model structure seen when loading saved weights is unchanged.
        self.conv1 = layers.Conv2D(32, 6, strides=1, padding='same')
        self.batn1 = layers.BatchNormalization()
        self.drop1 = layers.Dropout(self.drop_rate)
        self.pool1 = layers.MaxPool2D(2, 2, padding='same')

        self.conv2 = layers.Conv2D(64, 5, strides=1, padding='same')
        self.batn2 = layers.BatchNormalization()
        self.drop2 = layers.Dropout(self.drop_rate)
        self.pool2 = layers.MaxPool2D(2, 2, padding='same')

        self.conv3 = layers.Conv2D(128, 3, strides=1, padding='same')
        self.batn3 = layers.BatchNormalization()
        self.drop3 = layers.Dropout(self.drop_rate)
        self.conv4 = layers.Conv2D(128, 3, strides=1, padding='same')
        self.batn4 = layers.BatchNormalization()
        self.drop4 = layers.Dropout(self.drop_rate)
        self.pool3 = layers.MaxPool2D(2, 2, padding='same')

        self.flatten = layers.Flatten()

        self.fc1 = layers.Dense(2048) #2048
        self.batn5 = layers.BatchNormalization()
        self.drop5 = layers.Dropout(self.drop_rate)
        # NOTE(review): second assignment of self.flatten; it replaces the one
        # above and looks redundant, but is kept to avoid altering layer order.
        self.flatten = layers.Flatten()
        self.fc2 = layers.Dense(128)
        self.batn6 = layers.BatchNormalization()
        self.drop6 = layers.Dropout(self.drop_rate)

    def call(self, inputs, training=True):
        """Run the convolutional and dense stages; `training` goes to the Dropouts."""
        relu = tf.keras.activations.relu

        # (conv, batch-norm, dropout, pooling-or-None) per convolutional stage
        conv_stages = (
            (self.conv1, self.batn1, self.drop1, self.pool1),
            (self.conv2, self.batn2, self.drop2, self.pool2),
            (self.conv3, self.batn3, self.drop3, None),
            (self.conv4, self.batn4, self.drop4, self.pool3),
        )
        dense_stages = (
            (self.fc1, self.batn5, self.drop5),
            (self.fc2, self.batn6, self.drop6),
        )

        x = inputs
        for conv, norm, drop, pool in conv_stages:
            x = drop(relu(norm(conv(x))), training=training)
            if pool is not None:
                x = pool(x)

        x = self.flatten(x)

        for dense, norm, drop in dense_stages:
            x = drop(relu(norm(dense(x))), training=training)

        return x

In [None]:
class image_morph_model(tf.keras.Model):
    """Two-branch classifier combining an image and its morphology vector.

    The image goes through image_model and the morphology vector through
    morph_model; the two feature vectors are concatenated along axis 1, passed
    through Dense(256) -> BatchNormalization -> ReLU -> Dropout(0.2), and mapped
    to NO_CLASS softmax probabilities.
    """

    def __init__(self):
        super(image_morph_model, self).__init__()

        # Feature extractors for the two inputs
        self.image_model = image_model()
        self.morph_model = morph_model()

        self.drop_rate = 0.2

        # Head applied to the concatenated features
        self.im_fuco1 = tf.keras.layers.Dense(256)
        self.im_batn1 = tf.keras.layers.BatchNormalization()
        self.im_drop1 = tf.keras.layers.Dropout(self.drop_rate)

        self.im_y_out = tf.keras.layers.Dense(NO_CLASS, activation='softmax')

    def call(self, image, morph, training=True):
        """Return class probabilities for a batch.

        Args:
            image: batch of images for the image branch.
            morph: batch of morphology vectors for the morphology branch.
            training: forwarded to both branches and to the head's
                BatchNormalization and Dropout layers.

        Returns:
            Tensor of softmax probabilities with NO_CLASS columns.
        """
        img_latent = self.image_model(image, training=training)
        mph_latent = self.morph_model(morph, training=training)

        x = tf.concat([img_latent, mph_latent], 1)
        x = self.im_fuco1(x)
        # The flag is passed explicitly so the head behaves like the other
        # Dropout calls in this file and does not depend on Keras inferring the
        # mode from the calling context.
        x = self.im_batn1(x, training=training)
        x = tf.keras.activations.relu(x)
        x = self.im_drop1(x, training=training)

        return self.im_y_out(x)

In [None]:
#@tf.function
def val_step(images, morphs, labels):
    '''Evaluate one validation batch (labels should be one_hot).

    Runs the global `model` in inference mode, feeds the batch loss into the
    `val_loss` metric and the predictions into `val_accuracy`, and returns the
    predictions.
    '''
    pred = model(images, morphs, training=False)

    # Update the running validation statistics
    val_loss(tf.reduce_mean(total_loss(labels, pred)))
    val_accuracy(labels, pred)
    return pred

#@tf.function
def tst_step(images, morphs, labels):
    '''Evaluate one test batch (labels should be one_hot).

    Runs the global `model` in inference mode, feeds the batch loss into the
    `tst_loss` metric and the predictions into `tst_accuracy`, and returns the
    predictions.
    '''
    pred = model(images, morphs, training=False)

    # Update the running test statistics
    tst_loss(tf.reduce_mean(total_loss(labels, pred)))
    tst_accuracy(labels, pred)
    return pred

In [None]:
# Build the combined network and restore the saved weights.
_checkpoint_path = './saved_model_HSC_SSP_wAS_128-sex-saved/checkpoint'
model = image_morph_model()
model.load_weights(_checkpoint_path)

# The optimizer is created here but not used by the evaluation steps above.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

# Loss applied to the softmax probabilities (hence from_logits=False).
total_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
# Running statistics, one (mean loss, categorical accuracy) pair per split.
_metrics = tf.keras.metrics

val_loss = _metrics.Mean(name='val_loss')
val_accuracy = _metrics.CategoricalAccuracy(name='val_accuracy')

tst_loss = _metrics.Mean(name='tst_loss')
tst_accuracy = _metrics.CategoricalAccuracy(name='tst_accuracy')

# Validation

In [None]:
print('To validate on', valid_image_count)

template = 'Epoch {}\nValid Loss: {:.3g}, Valid Accuracy: {:.3g}'

# Create the iterator once. valid_ds repeats forever, and building a new iterator
# for every step would restart the (shuffled) pipeline each time, so consecutive
# steps could hand back overlapping samples.
valid_iter = iter(valid_ds)

for epoch in range(0, EPOCHS):

    # Start each epoch with empty running metrics
    val_loss.reset_states()
    val_accuracy.reset_states()

    #Validate
    print('Epoch', epoch+1, 'validation')
    label_batches = []
    pred_batches = []
    for step in range(0, STEPS_PER_VALID_EPOCH):
        i_val, m_val, y_val = next(valid_iter)
        pred = val_step(i_val, m_val, y_val)
        label_batches.append(y_val)
        pred_batches.append(pred)

    # Join the batches with TensorFlow so y_val_all / val_pred are tensors no
    # matter how many steps ran; the later cells call .numpy() on them.
    y_val_all = tf.concat(label_batches, axis=0)
    val_pred = tf.concat(pred_batches, axis=0)

    print(template.format(epoch+1, val_loss.result(), val_accuracy.result()))

    # Per-class summary based on the arg-max class of labels and predictions
    y_val = np.argmax(y_val_all, axis=1)
    y_out_val_agm = np.argmax(val_pred, axis=1)
    for j in range(0, NO_CLASS):
        testing = np.where(y_out_val_agm == j)        # predicted as class j
        correct = np.where(np.logical_and(y_out_val_agm == j, y_val == j))
        validating = np.where(y_val == j)             # truly class j

        # Fraction of class-j predictions that are right; 'inf' marks "no predictions"
        if len(testing[0]) == 0:
            cor_tes = 'inf'
        elif len(correct[0]) == 0:
            cor_tes = 0.0
        else:
            cor_tes = len(correct[0])/len(testing[0])
            cor_tes = round(cor_tes, 3)

        # Fraction of true class-j samples that were found; 'inf' marks "no samples"
        if len(validating[0]) == 0:
            cor_val = 'inf'
        elif len(correct[0]) == 0:
            cor_val = 0.0
        else:
            cor_val = len(correct[0])/len(validating[0])
            cor_val = round(cor_val, 3)

        print('Val \t Are', CLASS_NAMES[j], ' classed ', CLASS_NAMES[j], ':', cor_val, 
              '(',len(correct[0]), 'of', len(validating[0]),')')
        print('Val \t Classed', CLASS_NAMES[j], ' are ', CLASS_NAMES[j], ':', cor_tes, 
              '(',len(correct[0]), 'of', len(testing[0]),')')
        
    print()


In [None]:
# Probability thresholds scanned for the ROC curve (0.00, 0.01, ..., 1.00)
cuts = np.arange(0.00, 1.01, 0.01)


def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or NaN when the denominator is zero."""
    return numerator/denominator if denominator else float('nan')


val_recall = []
val_fall_out = []
val_dist = []

# Convert once instead of on every threshold. Column 0 is the 'merger' class.
_val_truth = y_val_all.numpy()[:,0]
_val_score = val_pred.numpy()[:,0]

t_mgr = np.where(_val_truth == 1)[0]
t_nmg = np.where(_val_truth == 0)[0]
for cut in cuts:
    mgr = np.where(_val_score > cut)[0]
    nmg = np.where(_val_score <= cut)[0]
    
    tp = np.intersect1d(t_mgr, mgr)
    fp = np.intersect1d(t_nmg, mgr)
    tn = np.intersect1d(t_nmg, nmg)
    fn = np.intersect1d(t_mgr, nmg)
    
    TP = len(tp)
    FP = len(fp)
    TN = len(tn)
    FN = len(fn)
    
    # Report the default 0.5 threshold: accuracy, recall, precision, specificity,
    # NPV. isclose avoids relying on exact float equality; the guarded ratios
    # print nan instead of raising when nothing falls on one side of the cut.
    if np.isclose(cut, 0.5):
        print(round(_safe_ratio(TP+TN, TP+TN+FP+FN),3))
        print(round(_safe_ratio(TP, TP+FN),3), TP, (TP+FN))
        print(round(_safe_ratio(TP, TP+FP),3), TP, (TP+FP))
        print(round(_safe_ratio(TN, TN+FP),3), TN, (TN+FP))
        print(round(_safe_ratio(TN, TN+FN),3), TN, (TN+FN))
    
    # The denominators here are the per-class sample counts, which do not depend
    # on the cut; a zero means the split lacks a class, so let that raise.
    val_recall.append(TP/(TP+FN))
    val_fall_out.append(FP/(FP+TN))
    
    # Distance of this ROC point from the ideal corner (fall-out 0, recall 1)
    val_dist.append( np.sqrt(np.square(1-val_recall[-1])+np.square(val_fall_out[-1])) )
    
best_cut_idx = np.argmin(val_dist)
best_cut = cuts[best_cut_idx]

plt.plot(val_fall_out, val_recall)
plt.plot(val_fall_out[best_cut_idx], val_recall[best_cut_idx], 'o')
plt.xlabel('Fall Out')
plt.ylabel('Recall')
plt.show()

print('Best at', best_cut)

# Confusion counts at the selected threshold
mgr = np.where(_val_score > best_cut)[0]
nmg = np.where(_val_score <= best_cut)[0]

tp = np.intersect1d(t_mgr, mgr)
fp = np.intersect1d(t_nmg, mgr)
tn = np.intersect1d(t_nmg, nmg)
fn = np.intersect1d(t_mgr, nmg)

TP = len(tp)
FP = len(fp)
TN = len(tn)
FN = len(fn)

print('Recall:', round(_safe_ratio(TP, TP+FN),3))
print('Precision:', round(_safe_ratio(TP, TP+FP),3))
print('Specificity:', round(_safe_ratio(TN, TN+FP),3))
print('NPV:', round(_safe_ratio(TN, TN+FN),3))
print('Accuracy:', round(_safe_ratio(TP+TN, TP+FP+TN+FN),3))

# Testing

In [None]:
print('To test on', test_image_count)

template = 'Epoch {}\nTest Loss: {:.3g}, Test Accuracy: {:.3g}'

# Create the iterator once. test_ds repeats forever, and building a new iterator
# for every step would restart the (shuffled) pipeline each time, so consecutive
# steps could hand back overlapping samples.
test_iter = iter(test_ds)

for epoch in range(0, EPOCHS):

    # Reset the *test* metrics (tst_step updates tst_loss / tst_accuracy; the
    # validation metrics were being reset here by mistake).
    tst_loss.reset_states()
    tst_accuracy.reset_states()

    #Test
    print('Epoch', epoch+1, 'test')
    label_batches = []
    pred_batches = []
    for step in range(0, STEPS_PER_TEST_EPOCH):
        i_tst, m_tst, y_tst = next(test_iter)
        pred = tst_step(i_tst, m_tst, y_tst)
        label_batches.append(y_tst)
        pred_batches.append(pred)

    # Join the batches with TensorFlow so y_tst_all / tst_pred are tensors no
    # matter how many steps ran; the later cells call .numpy() on them.
    y_tst_all = tf.concat(label_batches, axis=0)
    tst_pred = tf.concat(pred_batches, axis=0)

    print(template.format(epoch+1, tst_loss.result(), tst_accuracy.result()))

    # Per-class summary based on the arg-max class of labels and predictions
    y_tst = np.argmax(y_tst_all, axis=1)
    y_out_tst_agm = np.argmax(tst_pred, axis=1)
    for j in range(0, NO_CLASS):
        testing = np.where(y_out_tst_agm == j)        # predicted as class j
        correct = np.where(np.logical_and(y_out_tst_agm == j, y_tst == j))
        validating = np.where(y_tst == j)             # truly class j

        # Fraction of class-j predictions that are right; 'inf' marks "no predictions"
        if len(testing[0]) == 0:
            cor_tes = 'inf'
        elif len(correct[0]) == 0:
            cor_tes = 0.0
        else:
            cor_tes = len(correct[0])/len(testing[0])
            cor_tes = round(cor_tes, 3)

        # Fraction of true class-j samples that were found; 'inf' marks "no samples"
        if len(validating[0]) == 0:
            cor_val = 'inf'
        elif len(correct[0]) == 0:
            cor_val = 0.0
        else:
            cor_val = len(correct[0])/len(validating[0])
            cor_val = round(cor_val, 3)

        print('Tst \t Are', CLASS_NAMES[j], ' classed ', CLASS_NAMES[j], ':', cor_val, 
              '(',len(correct[0]), 'of', len(validating[0]),')')
        print('Tst \t Classed', CLASS_NAMES[j], ' are ', CLASS_NAMES[j], ':', cor_tes, 
              '(',len(correct[0]), 'of', len(testing[0]),')')
        
    print()

In [None]:
# Apply the threshold chosen on the validation split to the test split.
print('Best at', best_cut)


def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or NaN when the denominator is zero."""
    return numerator/denominator if denominator else float('nan')


# Convert once. Column 0 is the 'merger' class.
_tst_truth = y_tst_all.numpy()[:,0]
_tst_score = tst_pred.numpy()[:,0]

t_mgr = np.where(_tst_truth == 1)[0]
t_nmg = np.where(_tst_truth == 0)[0]

mgr = np.where(_tst_score > best_cut)[0]
nmg = np.where(_tst_score <= best_cut)[0]

tp = np.intersect1d(t_mgr, mgr)
fp = np.intersect1d(t_nmg, mgr)
tn = np.intersect1d(t_nmg, nmg)
fn = np.intersect1d(t_mgr, nmg)

TP = len(tp)
FP = len(fp)
TN = len(tn)
FN = len(fn)

# Guarded ratios print nan instead of raising ZeroDivisionError when nothing
# falls on one side of the threshold.
print('Recall:', round(_safe_ratio(TP, TP+FN),3))
print('Precision:', round(_safe_ratio(TP, TP+FP),3))
print('Specificity:', round(_safe_ratio(TN, TN+FP),3))
print('NPV:', round(_safe_ratio(TN, TN+FN),3))
print('Accuracy:', round(_safe_ratio(TP+TN, TP+FP+TN+FN),3))

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or NaN when the denominator is zero."""
    return numerator/denominator if denominator else float('nan')


tst_recall = []
tst_fall_out = []
tst_dist = []

# Convert once instead of on every threshold. Column 0 is the 'merger' class.
_tst_truth = y_tst_all.numpy()[:,0]
_tst_score = tst_pred.numpy()[:,0]

t_mgr = np.where(_tst_truth == 1)[0]
t_nmg = np.where(_tst_truth == 0)[0]
for cut in cuts:
    mgr = np.where(_tst_score > cut)[0]
    nmg = np.where(_tst_score <= cut)[0]
    
    tp = np.intersect1d(t_mgr, mgr)
    fp = np.intersect1d(t_nmg, mgr)
    tn = np.intersect1d(t_nmg, nmg)
    fn = np.intersect1d(t_mgr, nmg)
    
    TP = len(tp)
    FP = len(fp)
    TN = len(tn)
    FN = len(fn)
    
    # Report the default 0.5 threshold: accuracy, recall, precision, specificity,
    # NPV. isclose avoids relying on exact float equality; the guarded ratios
    # print nan instead of raising when nothing falls on one side of the cut.
    if np.isclose(cut, 0.5):
        print(round(_safe_ratio(TP+TN, TP+TN+FP+FN),3))
        print(round(_safe_ratio(TP, TP+FN),3), TP, (TP+FN))
        print(round(_safe_ratio(TP, TP+FP),3), TP, (TP+FP))
        print(round(_safe_ratio(TN, TN+FP),3), TN, (TN+FP))
        print(round(_safe_ratio(TN, TN+FN),3), TN, (TN+FN))
    
    # The denominators here are the per-class sample counts, which do not depend
    # on the cut; a zero means the split lacks a class, so let that raise.
    tst_recall.append(TP/(TP+FN))
    tst_fall_out.append(FP/(FP+TN))
    
    # Distance of this ROC point from the ideal corner (fall-out 0, recall 1)
    tst_dist.append( np.sqrt(np.square(1-tst_recall[-1])+np.square(tst_fall_out[-1])) )
    
# Threshold that would have been optimal on the test split itself (for comparison
# with the validation-derived best_cut).
better_cut_idx = np.argmin(tst_dist)
better_cut = cuts[better_cut_idx]

plt.plot(tst_fall_out, tst_recall)
plt.plot(tst_fall_out[better_cut_idx], tst_recall[better_cut_idx], 'o')
plt.xlabel('Fall Out')
plt.ylabel('Recall')
plt.show()

print('Better at', better_cut)

# Confusion counts at that threshold
mgr = np.where(_tst_score > better_cut)[0]
nmg = np.where(_tst_score <= better_cut)[0]

tp = np.intersect1d(t_mgr, mgr)
fp = np.intersect1d(t_nmg, mgr)
tn = np.intersect1d(t_nmg, nmg)
fn = np.intersect1d(t_mgr, nmg)

TP = len(tp)
FP = len(fp)
TN = len(tn)
FN = len(fn)

print('Recall:', round(_safe_ratio(TP, TP+FN),3))
print('Precision:', round(_safe_ratio(TP, TP+FP),3))
print('Specificity:', round(_safe_ratio(TN, TN+FP),3))
print('NPV:', round(_safe_ratio(TN, TN+FN),3))
print('Accuracy:', round(_safe_ratio(TP+TN, TP+FP+TN+FN),3))

In [None]:
# Area under each ROC curve: sort the points by fall-out, then integrate recall
# over fall-out with the trapezoidal rule.
order = np.argsort(val_fall_out)
val_fall_out_sort, val_recall_sort = np.array(val_fall_out)[order], np.array(val_recall)[order]
val_auc = np.trapz(val_recall_sort, val_fall_out_sort)

order = np.argsort(tst_fall_out)
tst_fall_out_sort, tst_recall_sort = np.array(tst_fall_out)[order], np.array(tst_recall)[order]
tst_auc = np.trapz(tst_recall_sort, tst_fall_out_sort)

# Combined ROC figure: chance diagonal plus the validation and test curves.
fig = plt.figure(figsize=(5.5, 1*(5.5*2)/3.0))
ax = plt.axes([1,1,1,1])
ax.plot([0.0, 1.0],[0.0, 1.0], ':', label='Random Network', c='r', lw=2)
for xs, ys, curve_label, colour in ((val_fall_out, val_recall, 'Validation', 'b'),
                                    (tst_fall_out, tst_recall, 'Testing', 'g')):
    ax.plot(xs, ys, label=curve_label, c=colour, lw=2)

# AUC annotations in data coordinates
ax.text(0.7, 0.35, ' Val AUC: '+str(round(val_auc,3)), fontsize=14)
ax.text(0.7, 0.27, 'Test AUC: '+str(round(tst_auc,3)), fontsize=14)

# Inward ticks on all four sides
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='x', which='major', direction='in', top=True)
ax.tick_params(axis='y', which='major', direction='in', right=True)
ax.set_xlabel('Fall Out', fontsize=16)
ax.set_ylabel('Recall', fontsize=16)

ax.legend(loc=0, frameon=False, fontsize=14)

plt.show()