# Train new U-Net model

In [None]:
import tensorflow as tf

# Environment check: TensorFlow version on the first line, then the list of
# GPU devices TensorFlow can see (an empty list means no GPU is visible).
print(tf.__version__, tf.config.list_physical_devices('GPU'), sep='\n')

In [None]:
# # Set directory in Colab -- uncomment this block if you want to run on Colab
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/My\ Drive/Colab\ Notebooks/camvid_unet_semantic_segmentation

## 1. Load data to memory

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from utils import load_data_from_dir

In [None]:
# Project root is the notebook's current working directory; the split listings
# are read from its `data` sub-directory.
base_dir = os.getcwd()
data_dir = os.path.join(base_dir, 'data')

# Each listing is a header-less, space-separated text file whose two columns
# are read as 'image' and 'mask'.
# Paths are built with os.path.join (as in the model-saving cell) rather than
# by concatenating a hard-coded '/' separator.
split_columns = ['image', 'mask']
train_image_df = pd.read_csv(os.path.join(data_dir, 'camvid_train.txt'), header=None, sep=' ', names=split_columns)
val_image_df = pd.read_csv(os.path.join(data_dir, 'camvid_val.txt'), header=None, sep=' ', names=split_columns)
test_image_df = pd.read_csv(os.path.join(data_dir, 'camvid_test.txt'), header=None, sep=' ', names=split_columns)

# Class names; 'Void' is deliberately kept as the last entry.
classes = [
    'Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree',
    'SignSymbol', 'Fence', 'Car', 'Pedestrian', 'Bicyclist',
    'Void'
    ]

# Number of classes the model predicts: every entry except the trailing 'Void'.
# Derived from the list so the count cannot drift from the names (evaluates to 11).
n_class = len(classes) - 1  # ignore 'Void' (background) class

In [None]:
%%time
image_dim = (224, 224)  # same as vgg-16
train_images, train_masks = load_data_from_dir(train_image_df, True, image_dim)
val_images, val_masks = load_data_from_dir(val_image_df, True, image_dim)
test_images, test_masks = load_data_from_dir(test_image_df, True, image_dim)

In [None]:
# Sanity check: one shape per line, in the order
# train images/masks, val images/masks, test images/masks.
print(
    *(
        array.shape
        for array in (
            train_images, train_masks,
            val_images, val_masks,
            test_images, test_masks,
        )
    ),
    sep='\n',
)

## 2. Build U-Net model

### 2.1 Use vanilla U-Net architecture

In [None]:
from build_model import build_unet

# Vanilla U-Net.
# - input_shape: shape of the loaded training images with the first axis
#   dropped (presumably the sample axis -- confirm against load_data_from_dir).
# - num_classes: reuses n_class from the data cell instead of repeating the
#   literal 11, so the model head follows the class list.
unet_model = build_unet(
    input_shape=train_images.shape[1:],
    num_classes=n_class,
    num_filters=64,
    kernel_size=3,
)

unet_model.summary()

### 2.2 Alternatively, use either a pre-trained ResNet50V2 or MobileNetV2 as the encoder

In [None]:
# Use resnet50v2
# Optional alternative to the vanilla U-Net cell: an ImageNet-pretrained
# ResNet50V2 (without its classification top) is handed to
# build_unet_resnet50v2 as the encoder.
# The code is disabled by wrapping it in a triple-quoted string, so running
# this cell builds nothing; remove the surrounding quotes to enable it, which
# then rebinds `unet_model`.
'''
from build_model import build_unet_resnet50v2

resnet50v2 = tf.keras.applications.ResNet50V2(
    include_top=False,
    weights="imagenet",
    input_shape=train_images.shape[1:],
    pooling=None,
)

# Build model
unet_model = build_unet_resnet50v2(encoder = resnet50v2, 
                                   num_classes = 11, 
                                   num_filters = 64, 
                                   kernel_size = 3)

unet_model.summary()
'''

In [None]:
# Use mobilenetv2
# Optional alternative to the vanilla U-Net cell: an ImageNet-pretrained
# MobileNetV2 (without its classification top) is handed to
# build_unet_mobilenetv2 as the encoder.
# The code is disabled by wrapping it in a triple-quoted string; remove the
# surrounding quotes to enable it, which then rebinds `unet_model`.
# The import inside the block names build_unet_mobilenetv2, the builder that is
# actually called (it previously imported build_unet_resnet50v2, which would
# raise NameError once the block was enabled).
'''
from build_model import build_unet_mobilenetv2

mobilenetv2 = tf.keras.applications.MobileNetV2(
    include_top=False,
    weights="imagenet",
    input_shape=train_images.shape[1:],
    pooling=None,
)

unet_model = build_unet_mobilenetv2(encoder = mobilenetv2, 
                                    num_classes = 11, 
                                    num_filters = 32, 
                                    kernel_size = 3)

unet_model.summary()
'''

In [None]:
# Set loss and compile model

# NOTE(review): unet_model was already built in an earlier cell, so this reset
# of Keras' global state happens after model construction -- confirm that this
# ordering is intended.
tf.keras.backend.clear_session()

# Use SCCE loss to save memory: the sparse variant takes integer class masks
# directly instead of one-hot encoded targets.
# ignore_class=255 excludes pixels labelled 255 from the loss.
# NOTE(review): assumes the loader encodes 'Void' pixels as 255 and that the
# model outputs probabilities (from_logits is left at its default of False) --
# confirm against utils.load_data_from_dir and build_model.
SCCE = tf.keras.losses.SparseCategoricalCrossentropy(ignore_class=255) # ignore void class in loss calculation

# `metrics` is passed as a list, the documented form; a bare string is
# rejected by newer Keras versions.
unet_model.compile(optimizer="adam", loss=SCCE, metrics=['accuracy'])

In [None]:
# Plot network architecture
# plot_model is imported from the public tensorflow.keras.utils namespace
# (same style as the EarlyStopping import in the training cell). The private
# keras.utils.vis_utils module path is not part of the public API and is
# missing from newer Keras releases.
# Rendering needs the optional pydot and graphviz dependencies.
from tensorflow.keras.utils import plot_model
plot_model(unet_model, show_shapes=True, show_layer_names=True)

## 3. Train model

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Run-mode switches. With both left False this cell trains nothing; if both
# are True, train_val takes precedence because it is tested first.
train_val = False   # Use train & valid sets to optimize hyperparameters
train_test = False  # Use all train and valid set data to train a final model and then evaluate on the test set

# Maximum number of training epochs. The variable is named `epochs` to match
# the `epochs=epochs` argument in both fit() calls below (it was previously
# defined as `epoch`, which made either mode fail with NameError).
epochs = 100

if train_val:

    # Stop once val_loss has not improved for 20 epochs and roll the model
    # back to the best weights seen.
    earlystopping = EarlyStopping(monitor="val_loss", patience = 20, restore_best_weights=True)
    callbacks_list = [earlystopping]
    
    history = unet_model.fit(train_images, train_masks, epochs=epochs, callbacks = callbacks_list,
                             validation_data=(val_images, val_masks), verbose=2)
    
    # Report final loss/accuracy on the training and validation sets.
    unet_model.evaluate(train_images, train_masks)
    unet_model.evaluate(val_images, val_masks)
    
elif train_test:

    # Final model: train on train + validation data stacked along axis 0.
    # No validation data or early stopping here, so all `epochs` are run.
    history = unet_model.fit(np.concatenate((train_images, val_images), axis=0), 
                             np.concatenate((train_masks, val_masks), axis=0), 
                             epochs=epochs, 
                             verbose=2)
    
    # Report final loss/accuracy on the original training split and on the
    # held-out test set.
    unet_model.evaluate(train_images, train_masks)
    unet_model.evaluate(test_images, test_masks)

### Plot history

In [None]:
# Plot the training curves for whichever mode was run. As in the training
# cell, train_val wins when both flags are set; nothing is drawn when neither
# is set.
if train_val or train_test:
    # (legend label, prefix of the key in history.history) for each curve.
    # Validation curves exist only in train_val mode.
    curves = [('train', '')]
    if train_val:
        curves.append(('valid', 'val_'))

    # Upper accuracy-axis limit differs between the two modes.
    accuracy_top = 0.95 if train_val else 0.98

    fig, ax = plt.subplots(1,2, figsize=(12, 5))
    for panel, metric, title in zip(ax, ('loss', 'accuracy'), ('Loss', 'Accuracy')):
        for label, prefix in curves:
            panel.plot(history.history[prefix + metric], label=label)
        panel.set_title(title)
        panel.legend()
        panel.grid()

    # Left panel: loss on a log scale. Right panel: accuracy zoomed to the
    # range of interest.
    ax[0].set_yscale("log")
    ax[1].set_ylim(0.8, accuracy_top)
    plt.show()

In [None]:
# Save model
# Opt-in switch: set save_model to True to write the trained model to
# <base_dir>/models/<model_name>.
save_model = False
model_name = 'new_unet_model.h5'

if save_model:
    model_path = os.path.join(base_dir, 'models', model_name)
    unet_model.save(model_path)