# Librerías

In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import random
import matplotlib as mpl
from sklearn.metrics import confusion_matrix
from datetime import datetime
from keras.callbacks import CSVLogger
from google.colab import drive

drive.mount('/content/drive')
!mkdir /root/tensorflow_datasets
!cp -r /content/drive/MyDrive/tensorflow_dataset/galaxy_zoo3d /root/tensorflow_datasets/.

Mounted at /content/drive


# Complete

In [2]:
# Run configuration: everything a reader may want to tweak before "Run All".

# Training schedule.
NUM_EPOCHS = 150
BATCH_SIZE = 32
BUFFER_SIZE = 300  # shuffle buffer (in examples) for the training pipeline

# Images and masks are resized to size x size pixels.
size = 128

# Which TFDS mask feature to segment, and the mask value (presumably a vote
# count) from which a pixel is labelled positive.
mask = 'spiral_mask'
threshold = 4

# Data

In [3]:
# Split the single TFDS "train" split 75/25 into training and held-out data.
ds, info = tfds.load('galaxy_zoo3d',
                     split=['train[:75%]', 'train[75%:]'],
                     with_info=True)
ds_train, ds_test = ds

# Keep only galaxies whose mask reaches at least this many votes somewhere.
min_vote = 3


def _reaches_min_vote(mask_key):
    """Return a filter predicate: True when max(example[mask_key]) >= min_vote."""
    return lambda example: tf.reduce_max(example[mask_key]) >= min_vote


ds_train_spirals = ds_train.filter(_reaches_min_vote('spiral_mask'))
ds_train_bars = ds_train.filter(_reaches_min_vote('bar_mask'))

ds_test_spirals = ds_test.filter(_reaches_min_vote('spiral_mask'))
ds_test_bars = ds_test.filter(_reaches_min_vote('bar_mask'))

In [4]:
def resize(input_image, input_mask):
    """Resize an image and its mask to (size, size) with nearest-neighbour sampling.

    Nearest-neighbour never blends neighbouring pixels, so the mask keeps only
    values that were already present in it.

    Returns:
      (resized_image, resized_mask)
    """
    target_shape = (size, size)
    resized_image, resized_mask = (
        tf.image.resize(tensor, target_shape, method="nearest")
        for tensor in (input_image, input_mask)
    )
    return resized_image, resized_mask


def augment(input_image, input_mask):
    """Mirror image and mask together (left-right) with probability 0.5.

    Both tensors are flipped by the same decision so they stay aligned.
    """
    do_flip = tf.random.uniform(()) > 0.5
    if do_flip:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    return input_image, input_mask


def normalize(input_image):
    """Cast to float32 and divide by 255 (maps 8-bit pixels into [0, 1]; assumes uint8-range input)."""
    return tf.cast(input_image, tf.float32) / 255.0


def binary_mask(input_mask, th=None):
    """Binarize a mask: 1 where the value is >= th, 0 elsewhere.

    Args:
      input_mask: numeric mask tensor (values are compared against vote
        thresholds elsewhere in this notebook).
      th: threshold. Defaults to the global ``threshold``, read at call time,
        so existing callers behave exactly as before.

    Returns:
      Tensor with the same shape and dtype as ``input_mask`` containing 0s and 1s.
    """
    if th is None:
        th = threshold
    return tf.where(input_mask < th, tf.zeros_like(input_mask), tf.ones_like(input_mask))
    
    
def load_image_train(datapoint):
    """Preprocess one training example.

    Steps: resize image and selected mask, random left-right flip, scale the
    image to float32 in [0, 1], binarize the mask with the global threshold.

    Returns:
      (image, mask)
    """
    image, target = resize(datapoint['image'], datapoint[mask])
    image, target = augment(image, target)
    return normalize(image), binary_mask(target)


def load_image_test(datapoint):
    """Preprocess one evaluation example (deterministic, no augmentation).

    Steps: resize image and selected mask, scale the image to float32 in
    [0, 1], binarize the mask with the global threshold.

    Returns:
      (image, mask)
    """
    image, target = resize(datapoint['image'], datapoint[mask])
    return normalize(image), binary_mask(target)

In [5]:
# Filtered datasets for each supported mask, and how the filtered held-out
# split is divided into (validation, test) examples.
# Brazos (spiral_mask): 1100, 539. Barras (bar_mask): 800, 453
_FILTERED_SOURCES = {
    'spiral_mask': (ds_train_spirals, ds_test_spirals),
    'bar_mask': (ds_train_bars, ds_test_bars),
}
_VAL_TEST_SIZES = {'spiral_mask': (1100, 539), 'bar_mask': (800, 453)}
if mask not in _FILTERED_SOURCES:
    raise ValueError(f"Unsupported mask {mask!r}; expected one of {sorted(_FILTERED_SOURCES)}")
train_source, test_source = _FILTERED_SOURCES[mask]
val_size, test_size = _VAL_TEST_SIZES[mask]

# load_image_test performs exactly load_image_train's steps minus `augment`.
# The deterministic preprocessing runs before cache() so it happens once; the
# random flip must run AFTER cache(), otherwise the first epoch's flips would be
# cached and replayed unchanged in every later epoch.
train_dataset = train_source.map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_source.map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

train_batches = (
    train_dataset
    .cache()
    .shuffle(BUFFER_SIZE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Disjoint validation / test subsets of the held-out split (not shuffled).
validation_batches = test_dataset.take(val_size).batch(BATCH_SIZE)
test_batches = test_dataset.skip(val_size).take(test_size).batch(BATCH_SIZE)

# Visualización

In [6]:
def display(display_list):
  """Show the given images side by side in one row, without axes.

  Each entry is converted with `tf.keras.utils.array_to_img` before plotting.
  """
  n_panels = len(display_list)
  fig = plt.figure(figsize=(15, 15))
  for position, item in enumerate(display_list, start=1):
    ax = fig.add_subplot(1, n_panels, position)
    ax.imshow(tf.keras.utils.array_to_img(item))
    ax.axis("off")
  plt.show()

# U-Net

In [7]:
def double_conv_block(x, n_filters):
    """Apply two 3x3, same-padded, ReLU convolutions (He-normal init) in sequence."""
    for _ in range(2):
        x = layers.Conv2D(
            n_filters,
            3,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(x)
    return x


def downsample_block(x, n_filters):
    """Encoder stage: double conv, then 2x2 max-pool and 30% dropout.

    Returns:
      (skip, pooled): the pre-pooling features (for the skip connection) and
      the downsampled output fed to the next stage.
    """
    skip = double_conv_block(x, n_filters)
    pooled = layers.MaxPool2D(2)(skip)
    return skip, layers.Dropout(0.3)(pooled)


def upsample_block(x, conv_features, n_filters):
    """Decoder stage: stride-2 transposed conv, concat with the skip features,
    30% dropout, then a double conv."""
    upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    merged = layers.concatenate([upsampled, conv_features])
    return double_conv_block(layers.Dropout(0.3)(merged), n_filters)

In [8]:
def build_unet_model(input_size=None, base_filters=None, n_classes=2):
    """Build a 4-level U-Net that outputs per-pixel class probabilities.

    Args:
      input_size: side length of the square RGB input. Defaults to the global
        ``size`` (read at call time). Must be a multiple of 16: the encoder
        halves the resolution four times and every decoder stage must match
        its skip connection's spatial shape again.
      base_filters: filters in the first encoder stage; each deeper stage
        doubles it and the bottleneck uses 16 * base_filters. Defaults to
        ``input_size // 2`` (64 -> 1024 filters for 128 px inputs).
      n_classes: number of channels of the final softmax layer.

    Returns:
      An uncompiled ``tf.keras.Model`` named "U-Net".

    Raises:
      ValueError: if ``input_size`` is not a multiple of 16.
    """
    if input_size is None:
        input_size = size
    if input_size % 16 != 0:
        raise ValueError(f"input_size must be a multiple of 16, got {input_size}")
    if base_filters is None:
        # Integer division keeps the Conv2D filter counts integral.
        base_filters = input_size // 2

    depth = 4
    inputs = layers.Input(shape=(input_size, input_size, 3))

    # Encoder: remember each stage's pre-pooling features for the skips.
    skips = []
    x = inputs
    for level in range(depth):
        features, x = downsample_block(x, base_filters * 2 ** level)
        skips.append(features)

    x = double_conv_block(x, base_filters * 2 ** depth)  # bottleneck

    # Decoder: mirror the encoder, deepest skip connection first.
    for level in reversed(range(depth)):
        x = upsample_block(x, skips[level], base_filters * 2 ** level)

    outputs = layers.Conv2D(n_classes, 1, padding="same", activation="softmax")(x)

    return tf.keras.Model(inputs, outputs, name="U-Net")

In [9]:
# Instantiate the U-Net (input size taken from the global `size`); it is
# compiled in the next cell.
unet_model = build_unet_model()

In [None]:
# Compile, train and save the U-Net, logging per-epoch metrics to CSV on Drive.
unet_model.compile(optimizer=tf.keras.optimizers.Adam(),
                   loss="sparse_categorical_crossentropy",
                   # A list is required by Keras 3 (a bare string is rejected);
                   # Keras 2 accepts it too and the history key stays "accuracy".
                   metrics=["accuracy"]
                  )

path = '/content/drive/MyDrive/Galaxy Segmentation Project/Modelos/'
date = datetime.now().strftime("%Y_%m_%d-%H:%M:%S")
# Common file prefix for every artifact of this run (CSV log, checkpoint, model).
run_name = f'{path}{date}_{mask}_epochs:{NUM_EPOCHS}_size:{size}_th:{threshold}'
csv_log = CSVLogger(f'{run_name}.csv')
# Save the weights at the end of every epoch: the full model is only written
# after fit() returns, so an interrupted session would otherwise lose the run.
checkpoint = tf.keras.callbacks.ModelCheckpoint(f'{run_name}.weights.h5',
                                                save_weights_only=True)

# Filtered example counts per mask: (train split, held-out split).
# Brazos (spiral_mask): 4883, 1639. Barras (bar_mask): 3783, 1253.
SPLIT_LENGTHS = {'spiral_mask': (4883, 1639), 'bar_mask': (3783, 1253)}
if mask not in SPLIT_LENGTHS:
    raise ValueError(f"No split lengths recorded for mask {mask!r}")
TRAIN_LENGTH, TEST_LENGTH = SPLIT_LENGTHS[mask]
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Validate each epoch on about 1/VAL_SUBSPLITS of the held-out examples.
VAL_SUBSPLITS = 5
VALIDATION_STEPS = TEST_LENGTH // BATCH_SIZE // VAL_SUBSPLITS

model_history = unet_model.fit(train_batches,
                               epochs=NUM_EPOCHS,
                               steps_per_epoch=STEPS_PER_EPOCH,
                               validation_steps=VALIDATION_STEPS,
                               validation_data=validation_batches,
                               callbacks=[csv_log, checkpoint])

unet_model.save(f'{run_name}.h5')

Epoch 1/150
Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
Epoch 27/150
Epoch 28/150
Epoch 29/150
Epoch 30/150
Epoch 31/150
Epoch 32/150
Epoch 33/150
Epoch 34/150
Epoch 35/150
Epoch 36/150
Epoch 37/150
Epoch 38/150
Epoch 39/150
Epoch 40/150
Epoch 41/150
Epoch 42/150
Epoch 43/150
Epoch 44/150
Epoch 45/150
Epoch 46/150
Epoch 47/150
Epoch 48/150
Epoch 49/150
Epoch 50/150
Epoch 51/150
Epoch 52/150
Epoch 53/150
Epoch 54/150
Epoch 55/150
Epoch 56/150
Epoch 57/150
Epoch 58/150
Epoch 59/150
Epoch 60/150
Epoch 61/150
Epoch 62/150
Epoch 63/150
Epoch 64/150
Epoch 65/150
Epoch 66/150
Epoch 67/150
Epoch 68/150
Epoch 69/150
Epoch 70/150
Epoch 71/150
Epoch 72/150
Epoch 73/150
Epoch 74/150
Epoch 75/150
Epoch 76/150
Epoch 77/150
Epoch 78

# Curvas de aprendizaje

In [None]:
def display_learning_curves(history):
    """Plot per-epoch train/validation accuracy (left) and loss (right).

    Args:
      history: the ``History`` object returned by ``Model.fit`` (anything whose
        ``.history`` dict holds "accuracy", "val_accuracy", "loss" and
        "val_loss" lists of equal length).
    """
    acc = history.history["accuracy"]
    val_acc = history.history["val_accuracy"]

    loss = history.history["loss"]
    val_loss = history.history["val_loss"]

    # Derive the x-axis from the recorded history (not NUM_EPOCHS) so shorter
    # runs still plot; 1-based to match the "Epoch k/N" lines printed by fit().
    epochs_range = range(1, len(acc) + 1)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].plot(epochs_range, acc, label="train accuracy")
    ax[0].plot(epochs_range, val_acc, label="validation accuracy")
    ax[0].set_title("Accuracy")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Accuracy")
    ax[0].legend(loc="lower right")

    ax[1].plot(epochs_range, loss, label="train loss")
    ax[1].plot(epochs_range, val_loss, label="validation loss")
    ax[1].set_title("Loss")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Loss")
    # The first epoch is left out of the upper y-limit (presumably because its
    # loss dwarfs later epochs); fall back to all epochs when there is only one.
    upper_candidates = loss[1:] + val_loss[1:] or loss + val_loss
    ax[1].set_ylim(min(loss + val_loss), max(upper_candidates))
    ax[1].legend(loc="upper right")

    fig.tight_layout()
    plt.show()


display_learning_curves(model_history)

# Predicciones

In [None]:
def create_mask(pred_mask):
  """Turn a batch of per-class scores into the label map of its first element.

  Takes the argmax over the last axis, restores a trailing channel axis of
  size 1, and returns only index 0 of the batch.
  """
  labels = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(labels, axis=-1)[0]


def show_predictions(dataset=None, num=1):
  """For the first example of each of `num` batches, show image / true mask /
  predicted mask and print their 2x2 confusion matrix.

  Args:
    dataset: batched dataset of (image, mask) pairs. Nothing happens when None.
    num: number of batches to take from `dataset`.
  """
  if dataset is None:
    return
  for images, true_masks in dataset.take(num):
    predicted = create_mask(unet_model.predict(images))
    display([images[0], true_masks[0], predicted])
    print('Matriz de confusión:')
    # sklearn's order is (y_true, y_pred): rows = true class, columns = predicted.
    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from the image.
    print(confusion_matrix(true_masks[0].numpy().reshape(-1),
                           predicted.numpy().reshape(-1),
                           labels=[0, 1]))


show_predictions(test_batches.skip(3), 15)