# Complete

In [None]:
# Execution environment; the value 'google_colab' enables the Drive mount and
# dataset copy performed right after the imports.
run_from = 'google_colab'
# Key of the dataset feature used as the segmentation target (read in
# load_image_test).
mask = 'spiral_mask'
# Side length, in pixels, that images and masks are resized to (read in resize).
size = 128
# Number of examples per batch of the test pipeline.
BATCH_SIZE = 32

# Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tensorflow.python.keras.models import load_model
from sklearn.metrics import confusion_matrix, jaccard_score
import warnings
warnings.filterwarnings('ignore')


if run_from == 'google_colab':

    from google.colab import drive

    drive.mount('/content/drive')
    !mkdir /root/tensorflow_datasets
    !cp -r /content/drive/MyDrive/galaxy-segmentation-project/tensorflow_dataset/galaxy_zoo3d /root/tensorflow_datasets/.

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
mkdir: cannot create directory ‘/root/tensorflow_datasets’: File exists


# Useful functions

In [None]:
def resize(input_image, input_mask):
    """Resize an image and its mask to ``size`` x ``size`` pixels.

    Both tensors are resampled with nearest-neighbour interpolation. The
    target side length is read from the module-level ``size``.

    Returns:
        The ``(image, mask)`` pair after resizing.
    """
    target_shape = (size, size)
    resized_image = tf.image.resize(input_image, target_shape, method="nearest")
    resized_mask = tf.image.resize(input_mask, target_shape, method="nearest")
    return resized_image, resized_mask


def normalize(input_image):
    """Cast the image to float32 and divide it by 255.

    NOTE(review): this yields values in [0, 1] only if the input holds 8-bit
    intensities (0-255) -- confirm against the dataset.
    """
    as_float = tf.cast(input_image, tf.float32)
    return as_float / 255.0


def binary_mask(input_mask, th):
    """Binarize a mask against the threshold ``th``.

    Elements strictly below ``th`` become 0 and all others become 1. The
    result has the same shape and dtype as ``input_mask``.
    """
    below_threshold = input_mask < th
    background = tf.zeros_like(input_mask)
    foreground = tf.ones_like(input_mask)
    return tf.where(below_threshold, background, foreground)


def load_image_test(datapoint):
    """Turn one dataset example into the tuple used for evaluation.

    The image and the mask selected by the module-level ``mask`` key are
    resized, the image is normalized, and the mask is binarized at the
    thresholds 1, 2, 3 and 4.

    Args:
        datapoint: mapping with at least the keys ``'image'``, ``'mangaid'``
            and the key stored in ``mask``.

    Returns:
        ``(manga_id, image, mask_th1, mask_th2, mask_th3, mask_th4)``.
    """
    manga_id = datapoint['mangaid']
    input_image, input_mask = resize(datapoint['image'], datapoint[mask])
    input_image = normalize(input_image)

    # One binary mask per threshold, in increasing threshold order.
    thresholded = [binary_mask(input_mask, th=level) for level in (1, 2, 3, 4)]

    return (manga_id, input_image, *thresholded)


def create_mask(pred_mask):
  """Collapse the last axis of a prediction to the index of its maximum.

  The argmax over the last axis is followed by re-adding a trailing axis of
  length 1, so the result has the same rank as the input.
  """
  class_index = tf.argmax(pred_mask, axis=-1)

  return class_index[..., tf.newaxis]


def predict_unet(model, dataset_batch):
  """Run ``model.predict`` on a batch and convert the output with create_mask."""
  raw_output = model.predict(dataset_batch)

  return create_mask(raw_output)

# Data

## Dataset

In [None]:
# Last quarter of the 'train' split is used as the test set. Passing the split
# as a single string makes tfds.load return the dataset directly.
ds_test = tfds.load('galaxy_zoo3d', split='train[75%:]')

# Preprocess every example with load_image_test, then group into batches.
test_batches = (
    ds_test
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
)

## Models

In [None]:
# Names of the trained models to evaluate. Each name is both the model's
# directory and the prefix of its '<name>_best.h5' file (see the loading cell).
# The names encode the training set ('only' / 'all'), the mask, the image size
# and the threshold ('th:N').
model_1 = '2022_08_26-22:05:28_only_spiral-mask_epochs:150_size:128_th:1_patience:5'
model_2 = '2022_08_25-13:16:28_all_spiral-mask_epochs:150_size:128_th:1_patience:5'
model_3 = '2022_08_16-15:53:44_only_spiral-mask_epochs:150_size:128_th:2'
model_4 = '2022_08_24-22:03:29_all_spiral-mask_epochs:150_size:128_th:2_patience:6'
model_5 = '2022_09_09-12:38:50_only_spiral-mask_epochs:50_size:128_th:3_patience:50'
model_6 = '2022_08_17-13:52:32_all_spiral-mask_epochs:150_size:128_th:3_patience:10'
model_7 = '2022_08_16-14:21:51_only_spiral-mask_epochs:150_size:128_th:4'
model_8 = '2022_08_26-18:39:35_all_spiral-mask_epochs:150_size:128_th:4_patience:5'

# Order matters: consecutive pairs share a threshold (th = 1, 2, 3, 4), with the
# 'only' model first and the 'all' model second. The scoring loop and
# plot_saver both address the predictions by this position.
models = [model_1, model_2, model_3, model_4,
          model_5, model_6, model_7, model_8]

In [None]:
# Directory on Drive that holds one sub-directory per trained model.
path = '/content/drive/MyDrive/galaxy-segmentation-project/Modelos/'

# Load '<path><model>/<model>_best.h5' for every entry of `models`, keeping
# the same order.
unet_models = list(map(load_model, (f'{path}{model}/{model}_best.h5' for model in models)))

# Images and scores saver

This cell takes the galaxy images in `test_batches` (~7450), predicts their masks for different thresholds (minimum number of votes per pixel), and saves a plot for each galaxy in .jpg format together with its scores in a .csv file. Each galaxy is identified by its MaNGA ID.

## Plot saver function

In [None]:
def plot_saver(gal_batch, gal, path):
    """Save a 3x5 comparison figure for one galaxy of a batch.

    Top row: the input image followed by the four masks stored in
    ``gal_batch[2]`` .. ``gal_batch[5]`` (titled th = 1 .. th = 4).
    Rows 1 and 2: the eight predicted masks, two per threshold column, labelled
    'Trained with only spirals' (row 1) and 'Trained with all images' (row 2).
    The two unused panels in column 0 are removed.

    Note: besides its arguments, this function reads the module-level
    ``pred_masks`` (indexed 0..7) and ``gal_ids``; both must describe the same
    batch as ``gal_batch`` when it is called.

    Args:
        gal_batch: batch tuple; elements 1..5 are indexed with ``gal``.
        gal: position of the galaxy inside the batch.
        path: destination handed to ``fig.savefig``.
    """
    to_img = tf.keras.utils.array_to_img
    bold = {'fontsize': 20, 'fontweight': 'bold'}

    fig, ax = plt.subplots(3, 5, figsize=(25, 15), sharey=True)

    # Top row: image, then the true masks for th = 1..4.
    for col in range(5):
        ax[0, col].imshow(to_img(gal_batch[col + 1][gal]))

    # Predictions: index k goes to row 1 (even k) or row 2 (odd k) of
    # column 1 + k // 2, i.e. one column per threshold.
    for k in range(8):
        ax[1 + k % 2, 1 + k // 2].imshow(to_img(pred_masks[k][gal]))

    column_titles = ('Image', 'th = 1', 'th = 2', 'th = 3', 'th = 4')
    for col, title in enumerate(column_titles):
        ax[0, col].set_title(title, pad=25, **bold)

    ax[0, 0].set_xlabel('ID: ' + gal_ids[gal], labelpad=10, **bold)
    ax[0, 4].set_ylabel('True masks', y=0.49, labelpad=40, rotation=270, **bold)
    row_labels = (
        (1, 'Trained with \n\n only spirals'),
        (2, 'Trained with \n\n all images'),
    )
    for row, label in row_labels:
        ax[row, 1].set_ylabel(label, x=0.13, y=0.37, labelpad=95, rotation=0, **bold)
    ax[0, 4].yaxis.set_label_position("right")

    # Strip tick labels and tick marks from every panel.
    panels = ax.ravel()
    for panel in panels:
        panel.set_yticklabels([])
    for panel in panels:
        panel.set_xticklabels([])
    for panel in panels:
        panel.tick_params(axis='both', which='both', length=0)

    # Column 0 has no prediction panels.
    for row in (1, 2):
        fig.delaxes(ax[row, 0])

    fig.suptitle('Threshold', x=0.585, y=1.05, fontsize=30, fontweight='bold')
    fig.text(0.04, 0.32, 'Predictions', va='center', rotation='vertical', fontsize=30, fontweight='bold')
    fig.tight_layout(pad=0.4, w_pad=-15, h_pad=0)
    fig.savefig(path)
    # Close explicitly so figures do not accumulate in memory across galaxies.
    plt.close(fig)

## Plots and scores saving

In [None]:
# Column layout of scores.csv: 'Galaxy ID' followed by one column per
# (metric, threshold, training set) combination, metric varying slowest.
iterables = [['Accuracy', 'Precision', 'Sensitivity', 'Specificity', 'Jaccard'], ['th=1', 'th=2', 'th=3', 'th=4'], ['Only spirals', 'All images']]
headers = pd.MultiIndex.from_product(iterables)


def _new_scores_frame():
    """Return an empty scores table: 'Galaxy ID' plus the ``headers`` columns."""
    frame = pd.DataFrame(columns=headers)
    frame.insert(0, 'Galaxy ID', 0)
    return frame


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator != 0 else np.nan


def _pixel_scores(pred_flat, true_flat):
    """Score one flattened binary prediction against its flattened true mask.

    Returns:
        ``(accuracy, precision, sensitivity, specificity, jaccard)``. Ratios
        with a zero denominator are NaN; the Jaccard score is NaN when neither
        mask contains a positive pixel.
    """
    # With (y_true, y_pred) and labels=[0, 1], ravel() yields tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(true_flat, pred_flat, labels=[0, 1]).ravel()

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = _safe_ratio(tp, tp + fp)
    sensitivity = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)

    if np.logical_or(pred_flat == 1, true_flat == 1).any():
        jaccard = jaccard_score(true_flat, pred_flat)
    else:
        jaccard = np.nan

    return accuracy, precision, sensitivity, specificity, jaccard


df_scores = _new_scores_frame()
# When starting from scratch, write the header row once before the loop:
# df_scores.to_csv('/content/drive/MyDrive/galaxy-segmentation-project/scores.csv', header=df_scores.columns, index=False)

# `count` numbers the batches processed in this run, i.e. relative to the
# resume point below, not to the start of test_batches.
count = 0
t_i = time.time()

# The first 223 batches are skipped.
# NOTE(review): presumably they were processed in an earlier run -- confirm
# before changing or removing the skip.
for gal_batch in test_batches.skip(223):

    # Fresh table per batch: only this batch's rows are appended to the CSV.
    df_scores = _new_scores_frame()

    # NOTE(review): if 'mangaid' is a string feature, .numpy() yields bytes and
    # str() produces "b'...'" in the CSV and file names. Left unchanged to stay
    # consistent with rows and figures already written -- confirm.
    gal_ids = [str(raw_id) for raw_id in gal_batch[0].numpy()]

    # `gal_ids` and `pred_masks` are also read as globals by plot_saver.
    pred_masks = [predict_unet(unet_model, gal_batch[1]) for unet_model in unet_models]

    # Convert to NumPy once per batch instead of once per galaxy and metric.
    pred_arrays = [pred_mask.numpy() for pred_mask in pred_masks]
    # Each true mask (th = 1..4) is listed twice so that it lines up with the
    # (only spirals, all images) model pair of the same threshold.
    true_arrays = [
        true_array
        for true_array in (gal_batch[i].numpy() for i in range(2, 6))
        for _ in range(2)
    ]

    for gal, gal_id in enumerate(gal_ids):

        per_model = [
            _pixel_scores(pred_array[gal].reshape(-1), true_array[gal].reshape(-1))
            for pred_array, true_array in zip(pred_arrays, true_arrays)
        ]

        # Row layout: ID, then all accuracies, all precisions, all
        # sensitivities, all specificities and all Jaccard scores, each in
        # model order.
        scores = [gal_id]
        for metric_values in zip(*per_model):
            scores.extend(metric_values)

        df_scores.loc[len(df_scores.index)] = scores

        # Separate name so the model directory stored in `path` is not clobbered.
        fig_path = f'/content/drive/MyDrive/galaxy-segmentation-project/Binary masks figures/{gal_id}.jpg'

        plot_saver(gal_batch, gal, fig_path)

    # Append after every batch so an interrupted run keeps its finished batches.
    df_scores.to_csv('/content/drive/MyDrive/galaxy-segmentation-project/scores.csv', mode='a', header=False, index=False)

    t_f = time.time()
    print(f'Batch #{count}: {round((t_f-t_i)/60, 2)} mins')
    count += 1
    t_i = t_f