# Oxford IIIT image segmentation with SwinUNET

In [None]:
g_n = 6  # index (into the physical GPU list) of the GPU this notebook uses

import tensorflow as tf

# Multi-GPU alternative, currently disabled:
# strategy = tf.distribute.MirroredStrategy()
# print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # An out-of-range g_n raises IndexError, which the RuntimeError handler
    # below would not catch; report it with a clear message instead.
    try:
        selected_gpu = gpus[g_n]
    except IndexError:
        raise ValueError(
            'g_n = {} but only {} GPU(s) were found'.format(g_n, len(gpus))) from None
    try:
        tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visibility / memory growth must be set before GPUs have been initialized
        print(e)


In [None]:
import numpy as np
from glob import glob

import tensorflow as tf
from tensorflow import keras

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, concatenate

In [None]:
import sys
sys.path.append('../')

from keras_vision_transformer import swin_layers
from keras_vision_transformer import transformer_layers
from keras_vision_transformer import utils

# Data and problem statement

This example uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) (Parkhi et al. 2012). The dataset contains images of pets and their pixel-wise trimap masks, which label each pixel as (1) belonging to the pet, (2) surrounding (background), or (3) bordering the pet.

The task is semantic segmentation: the model takes an image as input and predicts, for every pixel, the probability of each of the three mask classes.

In [None]:
# Flag marking a fresh run (not referenced elsewhere in the visible cells).
first_time_running = False

# Dataset root. Must contain images/ and annotations/trimaps/ and must end in
# a slash, because file patterns are built by plain string concatenation.
filepath = '/mnt/hdd_2A/segment_ai_ml_project/datasets/oxford_pets_clean/'

# UNET

In [None]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
import tensorflow as tf

# Number of segmentation classes (pet, surrounding, border) = output channels.
n_labels = 3

def conv_block(inputs, num_filters):
    """Apply two rounds of 3x3 'same' Conv2D -> BatchNormalization -> ReLU.

    Args:
        inputs: input feature tensor.
        num_filters: number of filters for both convolutions.

    Returns:
        The activated output tensor (spatial size unchanged).
    """
    out = inputs
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def decoder_block(inputs, skip, num_filters):
    """Upsample 2x, merge with the encoder skip tensor, then refine with conv_block.

    Args:
        inputs: lower-resolution decoder tensor.
        skip: encoder tensor at the target resolution (concatenated on channels).
        num_filters: filters for the transposed conv and the conv_block.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

def encoder_block(inputs, num_filters):
    """Halve the spatial size with 2x2 max pooling, then apply conv_block."""
    pooled = MaxPool2D(pool_size=(2,2))(inputs)
    return conv_block(pooled, num_filters)

def build_effienet_unet(input_shape, num_filters, n_classes=None):
    """Build a convolutional UNet with a per-pixel softmax head.

    Despite the name, no EfficientNet backbone is involved: the encoder and
    decoder are plain conv_block / encoder_block / decoder_block stacks with
    four 2x down-samplings, so input height and width should be divisible by
    16 for the skip connections to line up.

    Args:
        input_shape: shape of one input image, e.g. (128, 128, 3).
        num_filters: filter counts per level, shallowest first; the first five
            entries are used (four encoder levels plus the bottleneck).
        n_classes: number of output channels. Defaults to the module-level
            ``n_labels`` so existing calls behave exactly as before.

    Returns:
        A keras ``Model`` named "UNet" mapping images to per-pixel class
        probabilities.

    Raises:
        ValueError: if fewer than five filter counts are supplied.
    """
    if len(num_filters) < 5:
        raise ValueError(
            'num_filters needs at least 5 entries, got {}'.format(len(num_filters)))
    if n_classes is None:
        n_classes = n_labels

    # Input
    IN = Input(input_shape)
    e1 = conv_block(IN, num_filters[0])

    # Encoder: each level halves the spatial resolution
    e2 = encoder_block(e1, num_filters[1])
    e3 = encoder_block(e2, num_filters[2])
    e4 = encoder_block(e3, num_filters[3])
    b1 = encoder_block(e4, num_filters[4])

    # Decoder: upsample and merge with the matching encoder output
    d1 = decoder_block(b1, e4, num_filters[3])
    d2 = decoder_block(d1, e3, num_filters[2])
    d3 = decoder_block(d2, e2, num_filters[1])
    d4 = decoder_block(d3, e1, num_filters[0])

    # Output: 1x1 conv to per-pixel class probabilities
    OUT = Conv2D(n_classes, kernel_size=1, use_bias=False, activation='softmax')(d4)

    model = Model(inputs=[IN,], outputs=[OUT,], name="UNet")
    return model

# Instantiate the UNet for 128x128 RGB inputs, five filter levels.
num_filters = [32, 64, 128, 256, 512]
input_shape = (128, 128, 3)
model = build_effienet_unet(input_shape=input_shape, num_filters=num_filters)

## Hyperparameters

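The model trained in this notebook is the convolutional UNet built above (not a Swin-UNET). Its hyperparameters are the input shape `(128, 128, 3)` and the per-level filter counts `[32, 64, 128, 256, 512]` set in that cell; the loss weights and optimizer settings are configured below.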

## Model configuration

In [None]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras import backend as K

## IOU in pure numpy
def numpy_iou(y_true, y_pred, n_class=3):
    """Mean per-class IoU over a batch of integer label maps.

    For class c, IoU = TP / (TP + FP + FN) = |true_c & pred_c| / |true_c | pred_c|.
    Each sample's score is the mean IoU over the classes that occur in either
    its ground truth or its prediction. A class absent from both has an
    undefined (0/0) IoU and is skipped rather than scored as 0, so a correct
    "class not present" prediction is not penalised. If no class in
    ``range(n_class)`` occurs at all, the sample scores 1.0.

    Args:
        y_true: integer class labels; the first axis is the batch.
        y_pred: predicted integer class labels, same shape as ``y_true``.
        n_class: number of classes; labels are compared with 0..n_class-1.

    Returns:
        np.float64: mean of the per-sample scores.
    """
    def _sample_iou(truth, pred, n_class):
        # IoU per class present in truth or prediction, then averaged.
        ious = []
        for c in range(n_class):
            true_c = truth == c
            pred_c = pred == c
            union = np.sum(true_c | pred_c)  # TP + FP + FN
            if union == 0:
                continue
            ious.append(np.sum(true_c & pred_c) / union)
        # float64 keeps the dtype expected by tf.numpy_function(..., tf.float64)
        return np.mean(ious) if ious else np.float64(1.0)

    batch = y_true.shape[0]
    y_true = np.reshape(y_true, (batch, -1))
    y_pred = np.reshape(y_pred, (batch, -1))

    score = [_sample_iou(y_true[idx], y_pred[idx], n_class) for idx in range(batch)]
    return np.mean(score)


## Mean IoU between one-hot targets and predicted class probabilities.
## Built on tf.numpy_function, so it can also be used as a Keras metric.
def numpy_mean_iou(y_true, y_pred):
    """Mean IoU between one-hot ground truth and predicted class probabilities.

    Both inputs are collapsed to integer label maps with argmax over the last
    (class) axis before being handed to ``numpy_iou``, which compares values
    with the class indices 0..n_class-1. Thresholding the one-hot arrays
    (``y_pred > 0.5``) and passing them through unchanged would leave only 0/1
    values, so class 2 would always be scored as empty.

    Args:
        y_true: one-hot targets with the class axis last (as produced by
            ``target_data_process``).
        y_pred: predicted class probabilities with the same shape.

    Returns:
        Scalar float64 tensor holding the batch mean IoU.
    """
    y_true_labels = tf.argmax(y_true, axis=-1, output_type=tf.int32)
    y_pred_labels = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    score = tf.numpy_function(numpy_iou, [y_true_labels, y_pred_labels], tf.float64)
    # numpy_function drops static shape information; numpy_iou returns a scalar.
    return tf.reshape(score, [])

## Loss Functions

In [None]:
from tensorflow.keras import backend as K
import segmentation_models as sm

# segmentation_models loss objects. dice_loss, focal_loss and binary_CE_loss
# feed the weighted objective below; jaccard_loss and binary_focal_loss are
# not referenced in the visible cells.
dice_loss = sm.losses.DiceLoss()
jaccard_loss = sm.losses.JaccardLoss()
focal_loss = sm.losses.CategoricalFocalLoss()
binary_CE_loss = sm.losses.BinaryCELoss()
binary_focal_loss = sm.losses.BinaryFocalLoss()

In [None]:
# Weights of the three terms in the combined loss
alpha = 0.1   # binary cross-entropy term
beta = 0.9    # Dice term
gamma = 0     # categorical focal term (currently switched off)

In [None]:
# Combined objective: weighted sum of the BCE, Dice and focal loss objects.
bce_term = alpha * binary_CE_loss
dice_term = beta * dice_loss
focal_term = gamma * focal_loss
total_loss = bce_term + dice_term + focal_term

In [None]:
# Optimization: Adam with gradient clipping by value.
# <---- !!! gradient clipping is important
opt = keras.optimizers.Adam(clipvalue=0.5, learning_rate=1e-4)
model.compile(optimizer=opt, loss=total_loss)
model.summary()

## Data pre-processing

The input RGB images are resized to 128-by-128 using nearest-neighbour interpolation and then normalized to the interval [0, 1]. The pixel-wise target masks are resized in the same way and then one-hot encoded into three classes.

A random split (with a fixed seed) assigns 80%, 10%, and 10% of the samples to training, validation, and testing, respectively.

In [None]:
def input_data_process(input_array):
    '''Scale raw 8-bit pixel values into the [0, 1] range.'''
    max_pixel = 255.
    return input_array / max_pixel

def target_data_process(target_array, num_classes=3):
    '''Convert a tri-mask with labels {1, 2, 3} into one-hot categories.

    Labels are shifted to {0, 1, 2} and one-hot encoded with a fixed number of
    classes. Without an explicit ``num_classes``, ``to_categorical`` infers the
    width from the largest label present, so a batch that happens to lack
    label 3 would produce 2 channels instead of the model's 3.

    Args:
        target_array: integer trimap array with values in {1, 2, 3}.
        num_classes: width of the one-hot axis (default 3).

    Returns:
        One-hot array with a trailing class axis of size ``num_classes``.
    '''
    return keras.utils.to_categorical(target_array - 1, num_classes=num_classes)

In [None]:
import os

sample_names = np.array(sorted(glob(filepath+'images/*.jpg')))
label_names = np.array(sorted(glob(filepath+'annotations/trimaps/*.png')))

# Images and trimaps are paired purely by sorted position, so verify the
# pairing before splitting: one missing or extra file would silently shift
# every mask onto the wrong image, and an empty glob means a wrong filepath.
if len(sample_names) == 0:
    raise FileNotFoundError('No images found under {}images/'.format(filepath))
if len(sample_names) != len(label_names):
    raise ValueError('Found {} images but {} trimaps'.format(len(sample_names), len(label_names)))


def _file_stem(name):
    """Return the file name without its directory and extension."""
    return os.path.splitext(os.path.basename(name))[0]


for img_name, lbl_name in zip(sample_names, label_names):
    if _file_stem(img_name) != _file_stem(lbl_name):
        raise ValueError('Image/trimap mismatch: {} vs {}'.format(img_name, lbl_name))

# Make a Constant Seed
np.random.seed(42)

L = len(sample_names)
ind_all = np.arange(L)
np.random.shuffle(ind_all)

print(L)
print(ind_all)

L_train = int(0.8*L)
L_valid = int(0.1*L)
L_test = L - L_train - L_valid
ind_train = ind_all[:L_train]
ind_valid = ind_all[L_train:L_train+L_valid]
ind_test = ind_all[L_train+L_valid:]
print("Training:validation:testing = {}:{}:{}".format(L_train, L_valid, L_test))

In [None]:
# Validation set: load at 128x128, scale the images, one-hot encode the masks.
valid_input = input_data_process(
    utils.image_to_array(sample_names[ind_valid], size=128, channel=3)
)
valid_target = target_data_process(
    utils.image_to_array(label_names[ind_valid], size=128, channel=1)
)

In [None]:
# Test set: same loading and preprocessing as the validation set.
test_input = input_data_process(
    utils.image_to_array(sample_names[ind_test], size=128, channel=3)
)
test_target = target_data_process(
    utils.image_to_array(label_names[ind_test], size=128, channel=1)
)

## Training

The segmentation model is trained for up to 120 epochs, with an early-stopping patience counter on the validation loss. Each epoch contains 100 batches, and each batch contains 32 samples.

*The training process here is far from systematic, and is provided for illustration purposes only.*

In [None]:
N_epoch = 120 # number of epochs
N_batch = 100 # number of batches per epoch
N_sample = 32 # number of samples per batch

tol = 0 # current early stopping patience
max_tol = 120 # the max-allowed early stopping patience
min_del = 0 # the lowest acceptable loss value reduction

# epoch index -> (last training batch loss, validation loss, validation mIoU)
logs = {}


def _validation_loss(y_true, y_pred):
    """Mean of the alpha/beta/gamma-weighted BCE + Dice + focal loss."""
    return np.mean(alpha * binary_CE_loss(y_true, y_pred)
                   + beta * dice_loss(y_true, y_pred)
                   + gamma * focal_loss(y_true, y_pred))


# loop over epochs
for epoch in range(N_epoch):

    # initial loss record: baseline for the first improvement check
    if epoch == 0:
        y_pred = model.predict([valid_input])
        record = _validation_loss(valid_target, y_pred)
        loss_ = 0

        print("Epoch: Initial")
        print('Initial Validation loss = {}'.format(record))
        valid_miou = np.mean(numpy_mean_iou(valid_target, y_pred))
        print("mIoU: ", valid_miou)

    # loop over batches
    for step in range(N_batch):
        # selecting samples for the current batch
        ind_train_shuffle = utils.shuffle_ind(L_train)[:N_sample]

        # batch data formation
        ## augmentation is not applied
        train_input = input_data_process(
            utils.image_to_array(sample_names[ind_train][ind_train_shuffle], size=128, channel=3))
        train_target = target_data_process(
            utils.image_to_array(label_names[ind_train][ind_train_shuffle], size=128, channel=1))

        # train on batch
        loss_ = model.train_on_batch([train_input,], [train_target,])

    # epoch-end validation
    y_pred = model.predict([valid_input])
    record_temp = _validation_loss(valid_target, y_pred)

    print("Epoch: ", (epoch+1))
    if record - record_temp > min_del:
        # loss is reduced
        print('Validation performance is improved from {} to {}'.format(record, record_temp))
        record = record_temp # update the loss record
        tol = 0 # refresh early stopping patience
        # ** model checkpoint is not stored ** #
    else:
        # loss not reduced
        print('Validation performance {} is NOT improved'.format(record_temp))
        tol += 1

    valid_miou = np.mean(numpy_mean_iou(valid_target, y_pred))
    print("Validation mIoU: ", valid_miou)

    # Log before the early-stopping check so the final epoch is not lost.
    logs[epoch] = (loss_, record_temp, valid_miou)

    if tol >= max_tol:
        print('Early stopping')
        break
    

In [None]:
# Make new path
import datetime
import os

# Output directory for this run, named after the current local time.
path = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# exist_ok replaces the exists()/makedirs() pair, which could race with
# another process creating the same directory in between.
os.makedirs(path, exist_ok=True)

print(path)

In [None]:
# For logging
itemlist = logs.items()

f = open(path + "/log.txt", 'w')
data = "epoch" + "\t" + "Train Loss" + "\t\t\t\t" + "Val Loss" + "\t\t" + "Val mIoU\n"
f.write(data)
 
for item in itemlist:
    print(item)
    data = str(item[0]) + "\t\t" + str(item[1][0]) + "\t\t" + str(item[1][1]) + "\t\t" + str(item[1][2]) + "\t\t\n"
    f.write(data)

f.close()

## Evaluation

The testing set performance is evaluated.

In [None]:
# Per-pixel class probabilities for every test image.
y_pred = model.predict([test_input])

In [None]:

# Evaluate each loss once on the whole test set. The weighted total reuses
# these means (the mean is linear) instead of recomputing all three losses.
d = np.mean(dice_loss(test_target, y_pred))
f = np.mean(focal_loss(test_target, y_pred))
ce = np.mean(binary_CE_loss(test_target, y_pred))
tot = alpha * ce + beta * d + gamma * f

print("TEST Dice Loss : ", d)
print("TEST Focal Loss : ", f)
print("TEST BCE Loss : ", ce)
print("TEST Total Loss : ", tot)


In [None]:
# Sanity check: predictions and one-hot targets should have matching shapes.
(y_pred.shape, test_target.shape)

In [None]:
# Mean IoU of the test-set predictions against the test targets.
test_iou = numpy_mean_iou(test_target, y_pred)
r = np.mean(test_iou)
print("TEST mIoU : ", r)

**Example of outputs**

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def ax_decorate_box(ax):
    """Hide all spines, ticks and tick labels of a matplotlib Axes.

    Args:
        ax: the Axes to modify in place.

    Returns:
        The same Axes, for call chaining.
    """
    # Plain loop: only the side effect is wanted, no list needs to be built.
    for spine in ax.spines.values():
        spine.set_linewidth(0)
    ax.tick_params(axis="both", which="both", bottom=False, top=False,
                   labelbottom=False, left=False, right=False, labelleft=False)
    return ax

In [None]:
# Show & Save Test Images
num_sample = 30

# Never index past the end of the test set.
n_plot = min(num_sample, len(y_pred))

for i in range(n_plot):
    fig, AX = plt.subplots(1, 4, figsize=(13, (13-0.2)/4))
    plt.subplots_adjust(0, 0, 1, 1, hspace=0.1, wspace=0.1)

    for ax in AX:
        ax_decorate_box(ax)

    # test_target is one-hot with a trailing class axis; collapse it to a 2-D
    # label map (0/1/2) so pcolormesh gets scalar data and the colormap applies.
    AX[0].pcolormesh(np.argmax(test_target[i], axis=-1), cmap=plt.cm.jet)
    # Predicted probability of each class, one panel per output channel.
    AX[1].pcolormesh(y_pred[i, ..., 0], cmap=plt.cm.jet)
    AX[2].pcolormesh(y_pred[i, ..., 1], cmap=plt.cm.jet)
    AX[3].pcolormesh(y_pred[i, ..., 2], cmap=plt.cm.jet)

    # Panel 0 shows the target mask, not the original image.
    AX[0].set_title("Ground truth", fontsize=14)
    AX[1].set_title("Pixels belong to the object", fontsize=14)
    AX[2].set_title("Surrounding pixels", fontsize=14)
    AX[3].set_title("Bordering pixels", fontsize=14)

    fig.savefig(os.path.join(path, 'test_' + str(i) + '.png'))
    # Display now, then release the figure; otherwise every figure stays open
    # until the cell ends (matplotlib warns once more than 20 are open).
    plt.show()
    plt.close(fig)
    