# Complete

In [None]:
# --- Run configuration -------------------------------------------------------
# Execution environment; the value 'google_colab' enables the Drive setup below.
run_from = 'google_colab'

# Image geometry and batching used by the data pipeline.
size = 128          # images and masks are resized to (size, size)
BATCH_SIZE = 32

# Which mask key is read from each datapoint, and the training-set tag.
train_with = 'only'
mask = 'spiral_mask'

# Training-related settings (not read by the evaluation cells visible here).
NUM_EPOCHS = 150
min_vote = 3
threshold = 1
patience = 5

# Libraries

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import tensorflow_datasets as tfds
from tensorflow.python.keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import os
from datetime import datetime
from sklearn.metrics import confusion_matrix, jaccard_score
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
import time
import warnings
warnings.filterwarnings('ignore')


if run_from == 'google_colab':

    from google.colab import drive

    drive.mount('/content/drive')
    !mkdir /root/tensorflow_datasets
    !cp -r /content/drive/MyDrive/tensorflow_dataset/galaxy_zoo3d /root/tensorflow_datasets/.

Mounted at /content/drive


# Data

## Useful functions

In [None]:
def resize(input_image, input_mask):
    """Resize an image and its mask to (size, size) with nearest-neighbour sampling.

    `size` is the module-level constant. Returns the pair
    (resized_image, resized_mask).
    """
    target_shape = (size, size)
    resized_image = tf.image.resize(input_image, target_shape, method="nearest")
    resized_mask = tf.image.resize(input_mask, target_shape, method="nearest")
    return resized_image, resized_mask


def augment(input_image, input_mask):
    """Randomly mirror an image and its mask left-to-right, always together.

    A single uniform draw decides the flip; when it exceeds 0.5 both inputs
    are flipped, otherwise both are returned untouched.
    """
    flip_both = tf.random.uniform(()) > 0.5
    if flip_both:
        input_image, input_mask = (
            tf.image.flip_left_right(input_image),
            tf.image.flip_left_right(input_mask),
        )

    return input_image, input_mask


def normalize(input_image):
    """Cast the image to float32 and divide every value by 255.0."""
    as_float = tf.cast(input_image, tf.float32)
    return as_float / 255.0


def binary_mask(input_mask, th):
    """Binarise a mask: values below `th` become 0, all other values become 1.

    The result has the same shape and dtype as `input_mask`.
    """
    below_threshold = input_mask < th
    background = tf.zeros_like(input_mask)
    foreground = tf.ones_like(input_mask)
    return tf.where(below_threshold, background, foreground)


def load_image_test(datapoint):
    """Prepare one test datapoint for evaluation.

    Reads the 'image', the mask stored under the module-level `mask` key and
    the 'mangaid' from `datapoint`, resizes image and mask, normalises the
    image, and binarises the mask at thresholds 1, 2, 3 and 4.

    Returns the 6-tuple
    (manga_id, image, mask_th1, mask_th2, mask_th3, mask_th4).
    """
    manga_id = datapoint['mangaid']
    image, raw_mask = resize(datapoint['image'], datapoint[mask])
    image = normalize(image)

    # One binary mask per threshold, in increasing threshold order.
    thresholded = [binary_mask(raw_mask, th=level) for level in (1, 2, 3, 4)]

    return manga_id, image, thresholded[0], thresholded[1], thresholded[2], thresholded[3]


def display(display_list):
  """Show the given arrays side by side in one row of a 15x15 figure.

  Panels are titled 'Image', 'Mask' and 'Prediction' in that order, so
  `display_list` may hold at most three entries.
  """
  titles = ['Image', 'Mask', 'Prediction']
  n_panels = len(display_list)
  plt.figure(figsize=(15, 15))
  for position, item in enumerate(display_list, start=1):
    plt.subplot(1, n_panels, position)
    plt.imshow(tf.keras.utils.array_to_img(item))
    plt.title(titles[position - 1])
    plt.axis("off")
  plt.show()


def create_mask(pred_mask):
  """Take the argmax over the last axis and re-append a trailing axis of length 1."""
  winning_class = tf.argmax(pred_mask, axis=-1)
  return winning_class[..., tf.newaxis]

## Data loader

In [None]:
# Run names of the trained models, ordered by threshold (th:1 .. th:4) and,
# within each threshold, 'only' before 'all'. The scoring loop relies on this
# order when pairing predictions with true masks.
# NOTE(review): the fifth entry is a 'bar-mask' run while every other entry is
# a 'spiral-mask' run -- confirm this is the intended model for th:3 / 'only'.
models = [
    '2022_08_26-22:05:28_only_spiral-mask_epochs:150_size:128_th:1_patience:5',
    '2022_08_25-13:16:28_all_spiral-mask_epochs:150_size:128_th:1_patience:5',
    '2022_08_16-15:53:44_only_spiral-mask_epochs:150_size:128_th:2',
    '2022_08_24-22:03:29_all_spiral-mask_epochs:150_size:128_th:2_patience:6',
    '2022_08_23-13:52:10_only_bar-mask_epochs:150_size:128_th:3_patience:5',
    '2022_08_17-13:52:32_all_spiral-mask_epochs:150_size:128_th:3_patience:10',
    '2022_08_16-14:21:51_only_spiral-mask_epochs:150_size:128_th:4',
    '2022_08_26-18:39:35_all_spiral-mask_epochs:150_size:128_th:4_patience:5',
]

# Individual names kept for cells that refer to a single model.
(model_1, model_2, model_3, model_4,
 model_5, model_6, model_7, model_8) = models

In [None]:
# Each run keeps its best checkpoint at <path><run>/<run>_best.h5.
path = '/content/drive/MyDrive/Galaxy Segmentation Project/Modelos/'

checkpoint_files = [f'{path}{run_name}/{run_name}_best.h5' for run_name in models]
# Loaded in the same order as `models`.
unet_models = list(map(load_model, checkpoint_files))

In [None]:
# The last quarter of the 'train' split is used as the test set.
ds_test = tfds.load('galaxy_zoo3d', split='train[75%:]')

# Preprocess every datapoint in parallel, then group into fixed-size batches.
test_batches = (
    ds_test
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
)

In [None]:
def predict_unet(model, dataset_batch):
  """Run `model.predict` on the batch and convert the output with `create_mask`."""
  raw_prediction = model.predict(dataset_batch)
  return create_mask(raw_prediction)

In [None]:
# Score table layout: one 'Galaxy ID' column followed by
# metric x mask threshold x training set (5 * 4 * 2 = 40 score columns).
iterables = [['Accuracy', 'Precision', 'Sensitivity', 'Specificity', 'Jaccard'], ['th=1', 'th=2', 'th=3', 'th=4'], ['Only spirals', 'All images']]
headers = pd.MultiIndex.from_product(iterables)


def _empty_score_table():
  """Return an empty score table with 'Galaxy ID' as its first column."""
  table = pd.DataFrame(columns=headers)
  table.insert(0, 'Galaxy ID', 0)
  return table


def _segmentation_scores(true_flat, pred_flat):
  """Score one predicted binary mask against one true binary mask.

  Both arguments are flat arrays of 0/1 labels of equal length.
  Returns (accuracy, precision, sensitivity, specificity, jaccard); a ratio
  whose denominator is zero is reported as NaN, and so is the Jaccard score
  when neither mask contains a positive pixel.
  """
  # labels=[0, 1] forces a 2x2 matrix even when one class is absent, so the
  # flattened matrix is always (TN, FP, FN, TP) for (y_true, y_pred) inputs.
  TN, FP, FN, TP = confusion_matrix(true_flat, pred_flat, labels=[0, 1]).ravel()

  accuracy = (TP+TN)/(TP+TN+FP+FN)
  precision = TP/(TP+FP) if TP+FP != 0 else np.nan
  sensitivity = TP/(TP+FN) if TP+FN != 0 else np.nan
  specificity = TN/(TN+FP) if TN+FP != 0 else np.nan

  if np.any(np.logical_or(pred_flat == 1, true_flat == 1)):
    jaccard = jaccard_score(pred_flat, true_flat)
  else:
    jaccard = np.nan

  return accuracy, precision, sensitivity, specificity, jaccard


df_scores = _empty_score_table()

count = 0
t_i = time.time()

# NOTE: only batches 34..53 are processed here, and the CSV below is opened in
# append mode without a header, so re-running this cell adds the same galaxies
# to the file again.
for gal_batch in test_batches.skip(34).take(20):

  df_scores = _empty_score_table()

  # The ids arrive as bytes; decode them so that the CSV and the figure file
  # names contain '1-185076' rather than the literal "b'1-185076'".
  # Rows written by earlier runs of this cell still carry the "b'...'" form.
  gal_ids = [raw_id.decode('utf-8') for raw_id in gal_batch[0].numpy()]

  # Each true mask (th=1..4) is listed twice so that it lines up with the
  # 'only' and 'all' model trained for that threshold.
  true_masks = [gal_batch[i] for i in range(2,6) for _ in range(2)]
  pred_masks = [predict_unet(unet_model, gal_batch[1]) for unet_model in unet_models]

  # Convert to NumPy once per batch instead of once per metric and galaxy.
  true_arrays = [true_mask.numpy() for true_mask in true_masks]
  pred_arrays = [pred_mask.numpy() for pred_mask in pred_masks]

  for gal in range(len(gal_batch[0])):

    gal_id = gal_ids[gal]

    # One (accuracy, precision, sensitivity, specificity, jaccard) tuple per model.
    per_model = [_segmentation_scores(true_array[gal].reshape(-1), pred_array[gal].reshape(-1))
                 for pred_array, true_array in zip(pred_arrays, true_arrays)]

    # Row layout is metric-major: all accuracies, then all precisions, ...
    scores = [gal_id]
    for metric_values in zip(*per_model):
      scores.extend(metric_values)

    df_scores.loc[len(df_scores.index)] = scores

    fig, ax = plt.subplots(3, 5, figsize=(25, 15), sharey=True)

    # Top-left panel: the input image, labelled with the galaxy id.
    ax[0,0].imshow(tf.keras.utils.array_to_img(gal_batch[1][gal]))
    ax[0,0].set_title('Image', fontsize=20, fontweight='bold', pad=25)
    ax[0,0].set_xlabel('ID: '+gal_id, fontsize=20, fontweight='bold', labelpad=10)

    # Columns 1..4 correspond to thresholds 1..4: true mask on the first row,
    # 'only spirals' prediction on the second, 'all images' on the third.
    for th in range(1, 5):
      ax[0,th].imshow(tf.keras.utils.array_to_img(gal_batch[th+1][gal]))
      ax[0,th].set_title(f'th = {th}', fontsize=20, fontweight='bold', pad=25)
      ax[1,th].imshow(tf.keras.utils.array_to_img(pred_masks[2*(th-1)][gal]))
      ax[2,th].imshow(tf.keras.utils.array_to_img(pred_masks[2*(th-1)+1][gal]))

    ax[0,4].set_ylabel('True masks', y=0.49, fontsize=20, fontweight='bold', labelpad=40, rotation=270)
    ax[0,4].yaxis.set_label_position("right")
    ax[1,1].set_ylabel('Trained with \n\n only spirals', x=0.13, y=0.37, fontsize=20, fontweight='bold', labelpad=95, rotation=0)
    ax[2,1].set_ylabel('Trained with \n\n all images', x=0.13, y=0.37, fontsize=20, fontweight='bold', labelpad=95, rotation=0)

    # Hide tick labels and tick marks on every panel.
    for axi in ax.ravel():
      axi.set_yticklabels([])
      axi.set_xticklabels([])
      axi.tick_params(axis='both', which='both',length=0)

    # The first column only holds the input image; drop its two empty panels.
    fig.delaxes(ax[1,0])
    fig.delaxes(ax[2,0])
    fig.suptitle('Threshold', x=0.585, y=1.05, fontsize=30, fontweight='bold')
    fig.text(0.04, 0.32, 'Predictions', va='center', rotation='vertical', fontsize=30, fontweight='bold')
    fig.tight_layout(pad=0.4, w_pad=-15, h_pad=0)
    fig.savefig(f'/content/drive/MyDrive/Galaxy Segmentation Project/Binary masks figures/{gal_id}.jpg')
    fig.show(warn=False)
    # Close explicitly: one figure is created per galaxy.
    plt.close(fig)

  df_scores.to_csv('/content/drive/MyDrive/galaxy-segmentation-project/scores.csv', mode='a', header=False, index=False)

  t_f = time.time()
  print(f'Batch #{count}: {round((t_f-t_i)/60, 2)} mins')
  count += 1
  t_i = t_f

Batch #0: 1.38 mins
Batch #1: 1.25 mins
Batch #2: 0.77 mins
Batch #3: 0.82 mins


In [None]:
import pandas as pd
# Count how often each value of the first column (the galaxy id) occurs in
# the accumulated score file, and show the 100 most frequent.
scores_csv = '/content/drive/MyDrive/galaxy-segmentation-project/scores.csv'
df = pd.read_csv(scores_csv)
id_column = df.columns[0]
df[id_column].value_counts()[:100]

b'1-185076'    14
b'1-151418'    14
b'1-619112'    14
b'1-494298'    14
b'1-432161'    14
               ..
b'1-32070'     12
b'1-277161'    11
b'1-212813'    11
b'1-284631'    11
b'1-287487'    11
Name: ('Galaxy ID', '', ''), Length: 100, dtype: int64

In [None]:
# Walk the whole test set and collect the id of every galaxy as a string,
# printing the running batch index as progress.
gal_ids = []
count = 0
for gal_batch in test_batches:
  gal_batch_ids = list(map(str, gal_batch[0].numpy()))
  gal_ids += gal_batch_ids
  print(count)
  count += 1

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232


In [None]:
# Scratch check: `pred` is rebuilt on every pass, so only the list from the
# final iteration is left to print.
for i in range(10):
  pred = list(range(5))
print(pred)

[0, 1, 2, 3, 4]
