<a href="https://colab.research.google.com/github/svakeczw/Semantic-Image-Segmentation-UNet/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
# Oxford-IIIT Pet (any 3.x.x release) together with its DatasetInfo metadata.
dataset, info = tfds.load(name='oxford_iiit_pet:3.*.*', with_info=True)

# Image preprocessing and augmentation

In [None]:
# Data augmentation
def random_flip(input_image, input_mask):
  """Mirror the image and its mask left-to-right together, with probability 0.5.

  A single random draw decides for both tensors, so the mask stays aligned
  with the image.
  """
  should_flip = tf.random.uniform(shape=(), minval=0, maxval=1) > 0.5
  return tf.cond(
      should_flip,
      lambda: (tf.image.flip_left_right(image=input_image),
               tf.image.flip_left_right(image=input_mask)),
      lambda: (input_image, input_mask),
  )

def normalize(input_image, input_mask):
  """Scale image pixels to [0, 1] and shift mask labels down by one.

  The mask pixel values arrive as 1..3 and are returned as 0..2.
  """
  scaled_image = tf.cast(input_image, tf.float32) / 255.0
  shifted_mask = input_mask - 1
  return scaled_image, shifted_mask

def load_image_train(datapoint):
  """Preprocess one training example: resize to 128x128, random flip, normalize.

  Args:
    datapoint: one element of the tfds split, a dict holding 'image' and
      'segmentation_mask'.

  Returns:
    (image, mask): the image scaled to [0, 1] and the mask with its labels
    shifted from 1..3 to 0..2 by `normalize`.
  """
  input_image = tf.image.resize(datapoint['image'], size=(128, 128))
  # Nearest-neighbour keeps the mask's discrete class labels intact; the
  # default bilinear method would blend neighbouring labels into fractional
  # values along every region border.
  input_mask = tf.image.resize(
      datapoint['segmentation_mask'], size=(128, 128), method='nearest')
  input_image, input_mask = random_flip(input_image, input_mask)
  input_image, input_mask = normalize(input_image, input_mask)
  return input_image, input_mask

def load_image_test(datapoint):
  """Preprocess one evaluation example: resize to 128x128 and normalize (no augmentation).

  Args:
    datapoint: one element of the tfds split, a dict holding 'image' and
      'segmentation_mask'.

  Returns:
    (image, mask): the image scaled to [0, 1] and the mask with its labels
    shifted from 1..3 to 0..2 by `normalize`.
  """
  input_image = tf.image.resize(datapoint['image'], size=(128, 128))
  # Nearest-neighbour keeps the mask's discrete class labels intact; the
  # default bilinear method would blend neighbouring labels into fractional
  # values along every region border.
  input_mask = tf.image.resize(
      datapoint['segmentation_mask'], size=(128, 128), method='nearest')
  input_image, input_mask = normalize(input_image, input_mask)
  return input_image, input_mask

In [None]:
batch_size = 64

# Unbatched, augmented view of the training split, used by show_dataset() for
# a quick visual check.
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Training pipeline. Only the deterministic resize + normalize step
# (load_image_test, which is load_image_train without the flip) is cached;
# random_flip runs AFTER the cache. Caching the output of load_image_train
# would freeze the first epoch's random flips, so every later epoch would see
# exactly the same images. Flipping after normalizing gives the same result as
# flip-then-normalize because normalize is purely element-wise.
train_dataset = (
    dataset['train']
    .map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(batch_size)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

test = dataset['test'].map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(batch_size)

In [None]:
def display_image(image_list, titles=None):
  """Show the given images side by side in one row, without axis ticks.

  Args:
    image_list: sequence of images accepted by
      tf.keras.preprocessing.image.array_to_img.
    titles: optional per-image titles. Images without a matching entry get
      no title; omitting titles entirely is allowed.
  """
  # None instead of a mutable [] default; copy so callers' lists are untouched.
  titles = [] if titles is None else list(titles)
  plt.figure(figsize=(15, 15))
  for i, img in enumerate(image_list):
    plt.subplot(1, len(image_list), i + 1)
    if i < len(titles):
      plt.title(titles[i])
    plt.xticks([])
    plt.yticks([])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img))
  plt.show()

def show_dataset(dataset):
  """Display the first (image, mask) pair of an unbatched dataset, if any."""
  first = next(iter(dataset.take(1)), None)
  if first is None:
    return
  image, mask = first
  display_image([image, mask], titles=['image', 'mask'])

In [None]:
# Preview one (image, mask) pair from each preprocessed split.
for preview_split in (train, test):
  show_dataset(preview_split)

# UNet model

In [None]:
# Encoder part
def conv2d_block(input_tensor, num_filters, kernel_size=3):
  """Apply two (Conv2D with 'same' padding -> ReLU) stages.

  Args:
    input_tensor: Keras tensor to transform.
    num_filters: filter count of both convolutions.
    kernel_size: side length of the square convolution kernel.

  Returns:
    The transformed tensor; 'same' padding keeps the spatial size unchanged.
  """
  kernel = (kernel_size, kernel_size)
  out = input_tensor
  for _ in range(2):
    conv = tf.keras.layers.Conv2D(filters=num_filters, kernel_size=kernel, padding='same')
    out = conv(out)
    out = tf.keras.layers.Activation('relu')(out)
  return out

def encoder_block(input_tensor, num_filters, pooling_size=(2,2), stride_size=2):
  """One contracting step of the UNet encoder.

  Returns:
    (features, pooled): `features` is the conv2d_block output, later used as
    a skip connection by the decoder; `pooled` is the max-pooled tensor after
    30% dropout, which feeds the next step.
  """
  features = conv2d_block(input_tensor, num_filters=num_filters)
  pool = tf.keras.layers.MaxPool2D(pool_size=pooling_size, strides=stride_size)
  pooled = tf.keras.layers.Dropout(0.3)(pool(features))
  return features, pooled

def encoder(input_tensor):
  """Run four encoder_block steps with 64, 128, 256 and 512 filters.

  Returns:
    (output, (f1, f2, f3, f4)): the last pooled tensor and the four
    pre-pooling feature maps, shallowest first.
  """
  skip_features = []
  current = input_tensor
  for filters in (64, 128, 256, 512):
    features, current = encoder_block(current, num_filters=filters)
    skip_features.append(features)
  return current, tuple(skip_features)

def bottom_layer(input_tensor):
  """Bottleneck of the UNet: a 1024-filter conv2d_block."""
  return conv2d_block(input_tensor=input_tensor, num_filters=1024)

In [None]:
# Decoder part
def decoder_block(input_tensor, conv_output, num_filters):
  """One expanding step: upsample x2, concatenate the skip features, refine.

  Args:
    input_tensor: tensor from the previous (deeper) stage.
    conv_output: matching encoder feature map (skip connection).
    num_filters: filters for the transposed conv and the conv2d_block.
  """
  upsample = tf.keras.layers.Conv2DTranspose(filters=num_filters, kernel_size=2, padding='same', strides=2)
  merged = tf.keras.layers.concatenate([conv_output, upsample(input_tensor)])
  merged = tf.keras.layers.Dropout(0.3)(merged)
  return conv2d_block(merged, num_filters=num_filters, kernel_size=3)

def decoder(input_tensor, encoder_out_f, output_channels=3):
  """Expand back to input resolution and emit per-pixel class probabilities.

  Args:
    input_tensor: bottleneck output.
    encoder_out_f: the four encoder feature maps (f1, f2, f3, f4), shallowest
      first; they are consumed deepest first.
    output_channels: number of classes in the final 1x1 softmax conv.
  """
  f1, f2, f3, f4 = encoder_out_f
  out = input_tensor
  for skip, filters in zip((f4, f3, f2, f1), (512, 256, 128, 64)):
    out = decoder_block(out, conv_output=skip, num_filters=filters)
  head = tf.keras.layers.Conv2D(filters=output_channels, kernel_size=(1,1), activation='softmax')
  return head(out)

In [None]:
def unet(input_shape=(128, 128, 3), num_classes=3):
  """Build the (uncompiled) UNet segmentation model.

  Args:
    input_shape: (height, width, channels) of the input images. The default
      matches the 128x128 resize done in load_image_train/load_image_test.
      Height and width must be divisible by 16: the encoder halves them four
      times and the decoder doubles them four times, concatenating with the
      encoder feature map of the same level each time.
    num_classes: number of softmax output channels (default 3).

  Returns:
    A tf.keras.Model mapping images to per-pixel class probabilities.

  Raises:
    ValueError: if a known height/width is not divisible by 16.
  """
  for dim in input_shape[:2]:
    if dim is not None and dim % 16:
      raise ValueError(f'height and width must be divisible by 16, got {input_shape}')

  inputs = tf.keras.layers.Input(shape=input_shape)
  encoder_outputs, encoder_out_f = encoder(inputs)
  bottom_outputs = bottom_layer(encoder_outputs)
  outputs = decoder(bottom_outputs, encoder_out_f, output_channels=num_classes)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the model with the default input shape and print a per-layer summary.
unet_model = unet()
unet_model.summary()

In [None]:
# Render the model's layer graph, annotating each layer with its tensor shapes.
tf.keras.utils.plot_model(model=unet_model, show_shapes=True)

In [None]:
unet_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
epochs = 50
num_train = info.splits['train'].num_examples
num_test = info.splits['test'].num_examples
# Validate on roughly 20% of the test split each epoch. Keras documents
# steps_per_epoch / validation_steps as integers, but num_test * 0.2 is a
# float, so convert explicitly; keep at least one validation step.
num_valid = int(num_test * 0.2)

model_history = unet_model.fit(
    train_dataset, epochs=epochs, steps_per_epoch=num_train // batch_size,
    validation_steps=max(1, num_valid // batch_size), validation_data=test_dataset,
)

In [None]:
def plot_metrics(metric_name, title):
  """Plot the training and validation curves of one metric per epoch.

  Reads the global `model_history` (the History returned by fit).

  Args:
    metric_name: key in model_history.history, e.g. 'loss' or 'accuracy';
      the validation curve is read from 'val_' + metric_name.
    title: figure title.
  """
  train_values = model_history.history[metric_name]
  val_values = model_history.history['val_' + metric_name]
  plt.figure()
  plt.title(title)
  # Keep [0, 1] for bounded metrics like accuracy, but extend the top so
  # losses above 1 (common early on with cross-entropy) are not clipped.
  plt.ylim(bottom=0, top=max([1.0, *train_values, *val_values]))
  plt.plot(train_values, color='blue', label=metric_name)
  plt.plot(val_values, color='green', label='val_' + metric_name)
  plt.xlabel('epoch')
  plt.legend()  # without this the labels above are never shown
  plt.show()

In [None]:
# One figure per tracked metric: training vs. validation.
for metric, plot_title in (('loss', 'Train-Valid loss'), ('accuracy', 'Train-Valid accuracy')):
  plot_metrics(metric_name=metric, title=plot_title)

# Prediction

In [None]:
def create_mask(pred_mask):
  """Convert per-class scores into the label map of the first batch element.

  Takes the argmax over the last axis, keeps element 0 and re-adds a trailing
  channel axis. Presumably pred_mask is (batch, H, W, classes), giving an
  (H, W, 1) NumPy array — TODO confirm against callers.
  """
  labels = tf.argmax(pred_mask, axis=-1)[0]
  return labels[..., tf.newaxis].numpy()

def make_prediction(image):
  """Predict the label map of a single image given without a batch axis."""
  batched = image[tf.newaxis, :]
  return create_mask(unet_model.predict(batched))

In [None]:
def class_wise_metrics(y_true, y_pred, num_classes=3):
  """Compute per-class IoU and Dice score between two label maps.

  Args:
    y_true: array-like of ground-truth class ids.
    y_pred: array-like of predicted class ids, broadcast-compatible with
      y_true.
    num_classes: classes 0..num_classes-1 are scored (default 3).

  Returns:
    (iou_list, dice_score_list), one entry per class. A class absent from
    both maps scores 1.0 on both metrics.
  """
  y_true = np.asarray(y_true)
  y_pred = np.asarray(y_pred)
  iou_list = []
  dice_score_list = []
  smoothening_factor = 1e-10  # avoids 0/0 when a class is absent from both maps

  for i in range(num_classes):
    true_i = y_true == i
    pred_i = y_pred == i
    intersection = np.sum(true_i & pred_i)
    combined_area = np.sum(true_i) + np.sum(pred_i)

    iou = (intersection + smoothening_factor) / (combined_area - intersection + smoothening_factor)
    iou_list.append(iou)

    # Smoothing goes outside the factor 2 so the empty/empty case yields 1.0
    # (previously 2 * (s / s) = 2.0, an impossible Dice value).
    dice_score = (2 * intersection + smoothening_factor) / (combined_area + smoothening_factor)
    dice_score_list.append(dice_score)

  return iou_list, dice_score_list

In [None]:
def unpack_test_dataset():
  """Load the whole test split as one (images, masks) pair.

  Both outputs are cut to the largest multiple of `batch_size` that does not
  exceed the split size, so images[k] and masks[k] always describe the same
  example.

  Returns:
    (images, y_true_mask): images stays a tf.Tensor (slices of it are passed
    on to over_lap, which calls `.numpy()` on them); y_true_mask is a NumPy
    array.

  Raises:
    ValueError: if test_dataset yields no examples.
  """
  num_test = info.splits['test'].num_examples
  keep = num_test - (num_test % batch_size)
  # Re-batch into a single batch that holds every test example.
  for image, mask in test_dataset.unbatch().batch(num_test).take(1):
    return image[:keep], mask.numpy()[:keep]
  raise ValueError('test_dataset is empty')

In [None]:
def get_predict_by_idx(idx=0):
  """Run the model on one test example and score it per class.

  Args:
    idx: position of the example in test_dataset order; negative values
      count from the end of the test split.

  Returns:
    (image, y_mask, y_pred, iou_list, dice_score_list): image is the input
    tensor with a leading batch axis of 1, y_mask the NumPy ground-truth
    mask, y_pred the NumPy argmax label map with a batch axis and a trailing
    channel axis, followed by the per-class IoU and Dice lists.

  Raises:
    IndexError: if idx is outside the test split.
  """
  num_test = info.splits['test'].num_examples
  position = idx + num_test if idx < 0 else idx
  if not 0 <= position < num_test:
    raise IndexError(f'idx {idx} out of range for {num_test} test examples')

  # Stream to the requested example instead of materializing the entire test
  # split in memory just to read a single sample.
  sample = next(iter(test_dataset.unbatch().skip(position).take(1)), None)
  if sample is None:
    raise IndexError(f'idx {idx} out of range: test_dataset ended early')
  sample_image, sample_mask = sample

  image = sample_image[tf.newaxis, :]
  y_mask = sample_mask.numpy()
  y_pred = np.argmax(unet_model.predict(image), axis=3)
  y_pred = y_pred[..., tf.newaxis]
  iou_list, dice_score_list = class_wise_metrics(y_true=y_mask, y_pred=y_pred)
  return image, y_mask, y_pred, iou_list, dice_score_list

In [None]:
def show_prediction_with_metric(image, y_mask, y_pred, iou_list, dice_score_list):
  """Plot image, true mask and predicted mask side by side with per-class scores.

  `image` is squeezed and `y_pred[0]` is taken to drop their batch axes. The
  per-class IoU/Dice text is drawn under the middle (true mask) panel.
  """
  class_name = ['Pet', 'Background', 'Outline']
  panels = [tf.squeeze(image), y_mask, y_pred[0]]
  headings = ['Image', 'True mask', 'Pred mask']
  metric_string = '\n\n'.join(
      f'{name}: iou: {iou} : dice: {dice}'
      for name, iou, dice in zip(class_name, iou_list, dice_score_list)
  )
  plt.figure(figsize=(15,15))
  for position, (panel, heading) in enumerate(zip(panels, headings), start=1):
    plt.subplot(1, 3, position)
    plt.title(heading)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
    if position == 2:
      plt.xlabel(metric_string, fontsize=15)
  plt.show()

In [None]:
# Predict test example #16, then plot it with its per-class scores.
prediction = get_predict_by_idx(idx=16)
image, y_mask, y_pred, iou_list, dice_score_list = prediction
show_prediction_with_metric(*prediction)

In [None]:
# Show overlapped image
def over_lap(image, y_pred):
  """Blend the predicted label map over the input image and draw it.

  Every channel of the overlay starts as the predicted class id; pixels
  predicted as class 0 are then tinted (0.1, 0.1, 2.0). The result is
  0.6 * image + 0.2 * overlay + 0.2, clipped to [0, 1].
  """
  base = image[0].numpy()
  label_map = np.array(y_pred[0], dtype=np.float32)
  overlay = np.concatenate([label_map] * 3, axis=-1)
  for channel, tint in enumerate((0.1, 0.1, 2.0)):
    plane = overlay[:, :, channel]
    overlay[:, :, channel] = np.where(plane == 0, plane + tint, plane)
  blended = cv2.addWeighted(src1=base, alpha=0.6, src2=overlay, beta=0.2, gamma=0.2)
  plt.imshow(tf.clip_by_value(blended, 0, 1))

In [None]:
# Overlay the prediction for the example selected above on its input image.
over_lap(image=image, y_pred=y_pred)