In [None]:
%matplotlib inline
# %config InlineBackend.figure_format = 'retina'

import tensorflow as tf
import keras

import os
from glob import glob
from typing import List, Tuple, Union
import enum

import matplotlib.pyplot as plt
import numpy as np

# Dataset

In [None]:
# Network input geometry, number of segmentation classes and batch size
IMG_HEIGHT = 256  # 512
IMG_WIDTH = 256  # 512
IMG_CHANNELS = 3
NUM_CLASSES = 5
BATCH_SIZE = 4

# Dataset location and number of (non-augmented) samples in each split
DATA_DIR = "dataset"
NUM_TRAIN_IMAGES = 254
NUM_VAL_IMAGES = 46
NUM_TEST_IMAGES = 46

# Class encoding of the mask pixel values (kept from the disabled enum draft):
#   0 = background, 1 = sky, 2 = sun, 3 = thick_cloud, 4 = thin_cloud, 8 = other

# Collect the file paths; images and masks are paired by their sorted position
images_path = sorted(glob(os.path.join(DATA_DIR, "images/*")))
masks_path = sorted(glob(os.path.join(DATA_DIR, "masks/*")))

if len(masks_path) != len(images_path):
    raise RuntimeError("There must be the same number of images and masks!")

# Shuffle both lists with one shared, seeded permutation so pairs stay aligned
np.random.seed(14092000)
perm = np.random.permutation(len(images_path))
images_path = [images_path[idx] for idx in perm]
masks_path = [masks_path[idx] for idx in perm]

# Consecutive slices of the shuffled lists: train | val | test
_train_end = NUM_TRAIN_IMAGES
_val_end = _train_end + NUM_VAL_IMAGES
_test_end = _val_end + NUM_TEST_IMAGES

train_images, train_masks = images_path[:_train_end], masks_path[:_train_end]
val_images, val_masks = images_path[_train_end:_val_end], masks_path[_train_end:_val_end]
test_images, test_masks = images_path[_val_end:_test_end], masks_path[_val_end:_test_end]

# Put the pre-generated augmented samples in front of the train and val splits
train_images = [*sorted(glob(os.path.join(DATA_DIR, 'augmented_images_train', '*'))), *train_images]
train_masks = [*sorted(glob(os.path.join(DATA_DIR, 'augmented_masks_train', '*'))), *train_masks]
val_images = [*sorted(glob(os.path.join(DATA_DIR, 'augmented_images_val', '*'))), *val_images]
val_masks = [*sorted(glob(os.path.join(DATA_DIR, 'augmented_masks_val', '*'))), *val_masks]

# Load into tf.data.Dataset
def read_image(image_path:str, isMask:bool=False, num_classes=NUM_CLASSES) -> tf.Tensor:
    '''
    Read an image or a segmentation mask from its path and resize it to
    (IMG_HEIGHT, IMG_WIDTH).

    Args:
        image_path: path of a PNG file (Python string or scalar string tensor).
        isMask: when True the file is decoded as a single-channel map of
            class indices instead of an IMG_CHANNELS-channel image.
        num_classes: depth of the one-hot encoding applied to a mask. With
            num_classes <= 1 the mask is returned as raw class indices.

    Returns:
        - image: uint8 tensor of shape (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS);
        - mask, num_classes > 1: float32 one-hot tensor of shape
          (IMG_HEIGHT, IMG_WIDTH, num_classes);
        - mask, num_classes <= 1: uint8 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 1).
    '''
    raw_bytes = tf.io.read_file(image_path)
    # tf.image.resize expects the target size as [new_height, new_width].
    target_size = [IMG_HEIGHT, IMG_WIDTH]

    if isMask:
        mask = tf.image.decode_png(raw_bytes, channels=1)
        mask.set_shape([None, None, 1])
        # Nearest-neighbour keeps every pixel a valid class index; bilinear
        # interpolation would blend neighbouring indices into unrelated classes.
        mask = tf.image.resize(images=mask, size=target_size, method='nearest')
        mask = tf.cast(mask, dtype=tf.uint8)
        if num_classes > 1:
            # tf.one_hot works on symbolic tensors inside Dataset.map and maps
            # indices outside [0, num_classes) to an all-zero vector
            # (e.g. "other" encoded as 8 - TODO confirm that is the wanted handling).
            mask = tf.one_hot(tf.squeeze(mask, axis=-1), depth=num_classes)
        return mask

    image = tf.image.decode_png(raw_bytes, channels=IMG_CHANNELS)
    image.set_shape([None, None, IMG_CHANNELS])
    image = tf.image.resize(images=image, size=target_size)
    image = tf.cast(image, dtype=tf.uint8)
    return image


def load_data(image_list:List[str], mask_list:List[str]) -> Tuple[tf.Tensor,tf.Tensor]:
    '''
    Load one (image, mask) pair.

    data_generator maps this function over the individual elements of the
    dataset, so despite their names each argument is a single path: the first
    one is read as an image, the second one as a mask.
    '''
    return read_image(image_list), read_image(mask_list, isMask=True)


def data_generator(image_list:List[str], mask_list:List[str],batch_size:int=BATCH_SIZE, shuffle:bool=False) -> tf.data.Dataset:
    '''
    Build a batched dataset of (image, mask) pairs from two lists of paths.

    Args:
        image_list: paths of the input images.
        mask_list: paths of the masks, aligned by position with image_list.
        batch_size: number of samples per batch; an incomplete last batch is
            dropped (drop_remainder=True).
        shuffle: when True the samples are reshuffled at every iteration
            (i.e. every epoch). Shuffling is done on the paths, before any
            file is decoded, so the buffer can cover the whole list cheaply.

    Returns:
        A tf.data.Dataset yielding (image_batch, mask_batch) tuples.
    '''
    dataset = tf.data.Dataset.from_tensor_slices((image_list, mask_list))
    if shuffle and len(image_list) > 1:
        dataset = dataset.shuffle(buffer_size=len(image_list), reshuffle_each_iteration=True)
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

# Keras ignores fit(shuffle=True) for tf.data.Dataset inputs, so the training
# samples have to be shuffled by the dataset itself.
train_dataset = data_generator(train_images, train_masks, shuffle=True)
val_dataset = data_generator(val_images, val_masks)
# test_dataset = data_generator(test_images, test_masks)


# Adding class weights: one loss weight per class index
class_weights = tf.constant([1,1,5,1,0.8])
def map_weights(image, label):
    '''
    Append a sample-weight tensor to an (image, label) element.

    The label is multiplied by class_weights along its last axis and summed
    over that axis; for a one-hot label this selects the weight of the class
    each position belongs to.
    '''
    weighted_label = tf.math.multiply(label, class_weights)
    return image, label, tf.math.reduce_sum(weighted_label, axis=-1)

# Only the training dataset carries the per-class sample weights
train_dataset = train_dataset.map(map_weights)



# Show the element specification of each dataset
for _caption, _dataset in (("Train Dataset:  ", train_dataset), ("Val Dataset:  ", val_dataset)):
    print(_caption, _dataset)
# print("Test Dataset:  ", test_dataset)

# Model

In [None]:
# Kernel, stride and pooling sizes shared by every stage of the network
conv_side = 3
conv_trans_side = 3
conv_trans_strides_side = 2
pool_side = 2


def _double_conv(x, filters:int, dropout_rate:float):
    '''Conv2D -> Dropout -> Conv2D with ReLU, he_normal init and 'same' padding.'''
    x = keras.layers.Conv2D(filters, (conv_side, conv_side), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.Conv2D(filters, (conv_side, conv_side), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    return x


def _encoder_stage(x, filters:int, dropout_rate:float):
    '''One contraction stage; returns (skip tensor before pooling, pooled + batch-normalised tensor).'''
    skip = _double_conv(x, filters, dropout_rate)
    pooled = keras.layers.MaxPooling2D((pool_side, pool_side))(skip)
    normalized = keras.layers.BatchNormalization(synchronized=True)(pooled)
    return skip, normalized


def _decoder_stage(x, skip, filters:int, dropout_rate:float, normalize:bool=True):
    '''One expansion stage: transposed-conv upsampling, skip concatenation, double conv, optional batch norm.'''
    upsampled = keras.layers.Conv2DTranspose(
        filters,
        (conv_trans_side, conv_trans_side),
        strides=(conv_trans_strides_side, conv_trans_strides_side),
        padding='same',
    )(x)
    merged = keras.layers.concatenate([upsampled, skip])
    out = _double_conv(merged, filters, dropout_rate)
    if normalize:
        out = keras.layers.BatchNormalization(synchronized=True)(out)
    return out


def multi_unet_model(n_classes:int=NUM_CLASSES, img_height:int=IMG_HEIGHT, img_width:int=IMG_WIDTH, img_channels:int=IMG_CHANNELS) -> keras.models.Model:
    '''
    Build the U-Net segmentation model.

    Four contraction stages (16, 32, 64, 128 filters), a 256-filter
    bottleneck and four expansion stages that reuse the un-pooled output of
    the matching contraction stage as skip connection.

    Args:
        n_classes: number of output channels. More than one class yields a
            per-pixel softmax; a single class yields a sigmoid (a softmax
            over one channel would be constantly 1).
        img_height, img_width: input size. Both should be divisible by
            pool_side**4 (16 with the current settings) so the upsampled
            tensors match their skip connections.
        img_channels: number of input channels.

    Returns:
        The uncompiled keras Model. Inputs are expected in the 0-255 range;
        they are rescaled to 0-1 inside the model.
    '''
    inputs = keras.layers.Input((img_height, img_width, img_channels))
    # Rescaling replaces a Lambda wrapping a Python lambda, which cannot be
    # (de)serialised reliably when the model is saved.
    x = keras.layers.Rescaling(1.0 / 255)(inputs)

    # Contraction path
    skips = []
    for filters, dropout_rate in ((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)):
        skip, x = _encoder_stage(x, filters, dropout_rate)
        skips.append(skip)

    # Bottleneck
    x = _double_conv(x, 256, 0.3)

    # Expansive path; the last stage is not followed by batch normalisation
    decoder_config = ((128, 0.2, True), (64, 0.2, True), (32, 0.1, True), (16, 0.1, False))
    for (filters, dropout_rate, normalize), skip in zip(decoder_config, reversed(skips)):
        x = _decoder_stage(x, skip, filters, dropout_rate, normalize)

    final_activation = 'softmax' if n_classes > 1 else 'sigmoid'
    outputs = keras.layers.Conv2D(n_classes, (1, 1), activation=final_activation)(x)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])

    return model
 

In [None]:
# Build the network with the default sizes and configure training
model = multi_unet_model()
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=[
        # IoU over every class, computed from one-hot labels
        tf.keras.metrics.OneHotIoU(
            num_classes=NUM_CLASSES,
            target_class_ids=list(range(NUM_CLASSES)),
            # False: predictions are dense vectors, reduced with tf.argmax
            sparse_y_pred=False,
        ),
    ],
    # loss_weights, weighted_metrics, run_eagerly, steps_per_execution and
    # jit_compile are left at their defaults
)


# Train

In [None]:
callbacks = [
    # Stop training after 10 epochs without val_loss improvement
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
    # Divide the learning rate by 10 after 5 epochs without val_loss improvement
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        mode='auto',
        factor=0.1,
        patience=5,
        min_delta=0.0001,
    ),
    # Write TensorBoard logs to ./logs
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
    # Record the per-epoch metrics
    tf.keras.callbacks.History(),
]

In [None]:
# Train for at most 300 epochs; the callbacks may stop the run earlier
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=300,
    initial_epoch=0,
    callbacks=callbacks,
    # NOTE: Keras ignores `shuffle` when x is a tf.data.Dataset
    shuffle=True,
    # one line of output per epoch
    verbose=2,
)


In [None]:
# Show which metric names were recorded during training
history.history.keys()

In [None]:
# Keras appends a counter to auto-generated metric names (e.g. 'one_hot_io_u_4'),
# and that counter depends on how many times the metric was created in the
# current kernel. Look the key up instead of hard-coding it.
iou_key = next(
    (
        key for key in history.history
        if not key.startswith('val_') and 'iou' in key.replace('_', '').lower()
    ),
    None,
)
if iou_key is None:
    raise KeyError('No IoU metric found in history: {}'.format(list(history.history)))

plt.plot(history.history['val_loss'],'r-', label='Validation Loss')
plt.plot(history.history['loss'],'r--', label='Loss')
plt.plot(history.history[iou_key],'b--',label='IoU')
plt.plot(history.history['val_' + iou_key],'b-',label='Validation IOU')

plt.xlabel('Epochs')
# Anchor both axes at zero
plt.xlim(0)
plt.ylim(0)
plt.legend()

In [None]:
def color_map(mask: "Union[tf.Tensor, np.ndarray]") -> np.ndarray:
    '''
    Turn a 2-D map of class indices into an RGB image.

    Args:
        mask: array or eager tensor whose first two dimensions are the image
            height and width and whose values are class indices.

    Returns:
        uint8 array of shape (height, width, 3). Indices 0-4 are painted as
        background, clear sky, sun, thick cloud and thin cloud; any other
        value is left black.
    '''
    palette = (
        (0, 0, 0),        # 0: background - black
        (92, 179, 255),   # 1: clear sky - clear blue
        (255, 242, 0),    # 2: sun - yellow
        (113, 125, 150),  # 3: thick cloud - deep gray
        (255, 243, 245),  # 4: thin cloud - shiny gray
    )

    class_ids = np.asarray(mask)
    # Start from zeros (not np.empty) so pixels whose class is not in the
    # palette get a defined colour instead of uninitialised memory.
    colored_image = np.zeros((class_ids.shape[0], class_ids.shape[1], 3), dtype=np.uint8)

    for class_index, color in enumerate(palette):
        hits = np.where(class_ids == class_index)
        colored_image[hits[0], hits[1], :] = color
    return colored_image

def multipredict(img_list:List[str],mask_list:List[str],model:keras.Model=model) -> None:
    '''
    Plot, for every (image, mask) pair, the input image, the model
    prediction and the ground-truth mask side by side (one row per pair).

    Args:
        img_list: paths of the images to predict.
        mask_list: paths of the matching ground-truth masks.
        model: model used for inference; the default is the global `model`
            as it was bound when this function was defined.

    Raises:
        ValueError: if the lists differ in length or are empty.
    '''
    if len(img_list) != len(mask_list):
        raise ValueError('There must be the same number os masks and images')
    if len(img_list) == 0:
        raise ValueError('There must be at least one image')

    n_rows = len(img_list)
    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # otherwise axes[row, col] fails when only one image is given.
    _, axes = plt.subplots(n_rows, 3, figsize=(12, 4*n_rows+3), squeeze=False)

    for row, (img_path, mask_path) in enumerate(zip(img_list, mask_list)):
        img = read_image(img_path)
        pred = model.predict( tf.expand_dims(img,0) )
        predicted_classes = np.argmax(pred[0], axis=-1)
        # Decode the ground truth through the mask path (single channel, raw
        # class indices) rather than as a 3-channel image.
        true_classes = read_image(mask_path, isMask=True, num_classes=1)[:,:,0]

        panels = (img, color_map(predicted_classes), color_map(true_classes))
        for axis, panel in zip(axes[row], panels):
            axis.imshow(panel)
            axis.set_xticks([])
            axis.set_yticks([])

    for axis, title in zip(axes[0], ("Image", "Inference", "Ground truth")):
        axis.set_title(title)

In [None]:
# Plot image / inference / ground truth for every sample of the test split
multipredict(test_images,test_masks)
# Save the figure that is current after the call above
plt.savefig("TestPrediction3.jpg")