# GIT_Multiclass_Segmentation
-------------------
https://github.com/juanpb27/GIT_Multiclass_Segmentation

# Helpful libraries and functions
-----------

In [None]:
# System operations
import re
import os
import glob
import shutil
import splitfolders #!pip install split-folders
import patoolib #!pip install patool

# Handling data
import random
import numpy as np
import pandas as pd

# Computer Vision and plotting
import cv2
from PIL import Image
import seaborn as sns
from skimage import io
import matplotlib.pyplot as plt

# Machine and Deep Learning
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from sklearn.metrics import fbeta_score
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.layers import Conv2DTranspose, BatchNormalization, Dropout, Lambda
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

# Data paths
-----------

In [None]:
# Raw challenge data: scan images and the annotation table
PROJECT_PATH = '../'
INPUT_PATH = f'{PROJECT_PATH}input/'
CHALLENGE_PATH = f'{INPUT_PATH}uw-madison-gi-tract-image-segmentation/'
IMAGES = f'{CHALLENGE_PATH}train'
LABELS = f'{CHALLENGE_PATH}train.csv'

# Already-split dataset (images and masks) used to train the CNN
TRAIN_IMAGES_PATH = f'{INPUT_PATH}uwmgi-dataset-splitted/train/images/'
TRAIN_MASKS_PATH = f'{INPUT_PATH}uwmgi-dataset-splitted/train/masks/'
VALID_IMAGES_PATH = f'{INPUT_PATH}uwmgi-dataset-splitted/valid/images/'
VALID_MASKS_PATH = f'{INPUT_PATH}uwmgi-dataset-splitted/valid/masks/'

# Other Helper Functions
------------

In [None]:
# Obtaining images metadata
def get_metadata(path):
    """Extract the size and decimal fields encoded in a scan file name.

    The function searches *path* for a token of the form
    ``DDD_DDD_D.DD_D.DD`` (for example ``266_266_1.50_1.50`` in
    ``slice_0001_266_266_1.50_1.50.png``).

    Args:
        path: File path or file name containing the metadata token.

    Returns:
        Tuple ``(height, width, m_height, m_width)`` of strings. ``width`` and
        ``height`` are the first and second 3-digit fields; ``m_width`` and
        ``m_height`` are the first and second decimal fields (presumably the
        pixel spacing - confirm against the dataset description).

    Raises:
        ValueError: If *path* contains no metadata token.
    """
    # Raw string: '\d' stays a regex token instead of an invalid string escape.
    # The dots are escaped so they only match a literal decimal point.
    metadata = re.search(r'(\d{3})_(\d{3})_(\d\.\d{2})_(\d\.\d{2})', path)
    if metadata is None:
        raise ValueError(f'No size/spacing metadata found in path: {path!r}')

    width, height, m_width, m_height = metadata.groups()

    return height, width, m_height, m_width

In [None]:
# Decoding RLE to masks
# Ref: https://www.kaggle.com/stainsby/fast-tested-rle
def rle2mask(mask_rle, label, shape):
    """Decode a run-length-encoded annotation into a 2-D mask.

    Args:
        mask_rle: Whitespace-separated ``start length`` pairs. Starts are
            1-based positions in the flattened (row-major) image. A string
            with fewer than two tokens, such as the ``'0'`` placeholder used
            for missing annotations, decodes to an empty mask.
        label: Value written into every pixel covered by a run.
        shape: ``(height, width)`` of the array to return.

    Returns:
        ``np.uint8`` array of the given shape holding ``label`` inside the
        runs and 0 elsewhere.

    Raises:
        ValueError: If the encoding has an odd number of tokens (more than
            one), i.e. a start without a length.
    """
    tokens = mask_rle.split()
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # No complete (start, length) pair: nothing to draw.
    if len(tokens) < 2:
        return img.reshape(shape)
    if len(tokens) % 2:
        raise ValueError(
            f'RLE must contain (start, length) pairs, got {len(tokens)} tokens')

    starts = np.array([int(t) for t in tokens[0::2]]) - 1  # to 0-based
    lengths = np.array([int(t) for t in tokens[1::2]])
    for lo, hi in zip(starts, starts + lengths):
        img[lo:hi] = label

    # Row-major reshape back to (height, width).
    return img.reshape(shape)

In [None]:
# Prepare a batch for the CNN
def PrepareData(X, Y, n_classes):
    """Normalize images and one-hot encode masks.

    Args:
        X: Image data; it is divided by 65535 (the uint16 maximum), since the
            source images are uint16.
        Y: Integer label masks, passed to ``to_categorical``.
        n_classes: Number of classes, i.e. channels of the encoded mask.

    Returns:
        Tuple ``(X, Y)`` with the scaled images and the one-hot masks.
    """
    uint16_max = 2 ** 16 - 1.
    scaled_images = X / uint16_max
    # One binary channel per class
    one_hot_masks = to_categorical(Y, n_classes)

    return scaled_images, one_hot_masks

In [None]:
# Create a tensor to feed the DataGenerator and the CNN
def CreateSet(images_path, masks_path, n_classes, seed_sample, sample_divisor=10):
    """Load a contiguous sample of images and their masks into two arrays.

    Both folders are listed and sorted by file name, and the same window
    ``[seed_sample, seed_sample + sample)`` is taken from each listing, where
    ``sample = number_of_images // sample_divisor``. Pairing an image with its
    mask therefore relies on both folders sorting into the same order.

    Args:
        images_path: Folder containing the image files.
        masks_path: Folder containing the mask files.
        n_classes: Unused; kept so existing callers keep working.
        seed_sample: Index of the first file of the window in the sorted
            listing.
        sample_divisor: The sample holds ``1 / sample_divisor`` of the images
            (10 -> 10 %, 5 -> 20 %, 3 -> 33.3 %).

    Returns:
        Tuple ``(X, Y)`` of arrays with a trailing channel axis appended.
        Files are read unchanged; no normalization or one-hot encoding is
        applied here.

    Raises:
        ValueError: If the window holds different numbers of images and masks.
        IOError: If OpenCV cannot read one of the files.
    """

    def _read_stack(folder, names):
        """Read every named file of *folder* unchanged and return a list."""
        arrays = []
        for name in names:
            file_path = os.path.join(folder, name)
            array = cv2.imread(file_path, -1)  # -1: keep depth and channels
            if array is None:
                # imread signals failure by returning None, not by raising
                raise IOError(f'Could not read image file: {file_path}')
            arrays.append(array)
        return arrays

    image_names = sorted(os.listdir(images_path))
    mask_names = sorted(os.listdir(masks_path))

    sample = len(image_names) // sample_divisor
    window = slice(seed_sample, seed_sample + sample)
    image_names = image_names[window]
    mask_names = mask_names[window]

    if len(image_names) != len(mask_names):
        raise ValueError(
            f'Sample has {len(image_names)} images but {len(mask_names)} masks')

    X = np.array(_read_stack(images_path, image_names))
    X = X[..., None]  # Expand dims (channel axis)

    Y = np.array(_read_stack(masks_path, mask_names))
    Y = Y[..., None]

    return X, Y

In [None]:
# Endless generation of augmented (image, mask) batches
def Generator(X, Y, n_classes):
    """Yield augmented, prepared ``(images, masks)`` batches.

    The batch size and the random seed are read from the module-level
    ``batch_size`` and ``seed`` variables at call time.

    Args:
        X: Image array fed to the image stream.
        Y: Mask array fed to the mask stream.
        n_classes: Number of classes forwarded to ``PrepareData``.

    Yields:
        Tuples ``(img, mask)`` as returned by ``PrepareData``.
    """
    augmentation = {
        'horizontal_flip': True,
        'vertical_flip': True,
        'fill_mode': 'reflect',
    }

    image_datagen = ImageDataGenerator(**augmentation)
    mask_datagen = ImageDataGenerator(**augmentation)

    # Same configuration and same seed for both streams, so that images and
    # masks are meant to receive the same random flips.
    image_batches = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_batches = mask_datagen.flow(Y, batch_size=batch_size, seed=seed)

    for img_batch, mask_batch in zip(image_batches, mask_batches):
        img_batch, mask_batch = PrepareData(img_batch, mask_batch, n_classes)
        yield (img_batch, mask_batch)

In [None]:
# Used architecture (U-Net)

# https://github.com/bnsreenu/python_for_microscopists/blob/master/208-simple_multi_unet_model.py
################################################################
def multi_unet_model(n_classes=4, IMG_HEIGHT=256, IMG_WIDTH=256, IMG_CHANNELS=1):
    """Build a standard U-Net for multiclass segmentation.

    The model is not compiled here; compiling is done externally to make it
    easy to test various loss functions and optimizers. No input scaling layer
    is included, so inputs are expected to be normalized beforehand.

    Args:
        n_classes: Number of output channels of the final softmax layer.
        IMG_HEIGHT: Input height.
        IMG_WIDTH: Input width.
        IMG_CHANNELS: Number of input channels.

    Returns:
        The uncompiled Keras ``Model``.
    """

    def double_conv(tensor, filters, drop_rate):
        # Conv -> Dropout -> Conv, the unit repeated at every U-Net level
        tensor = Conv2D(filters, (3, 3), activation='relu',
                        kernel_initializer='he_normal', padding='same')(tensor)
        tensor = Dropout(drop_rate)(tensor)
        return Conv2D(filters, (3, 3), activation='relu',
                      kernel_initializer='he_normal', padding='same')(tensor)

    def up_block(tensor, skip, filters, drop_rate, **concat_kwargs):
        # Upsample, merge with the skip connection, then convolve
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(tensor)
        merged = concatenate([upsampled, skip], **concat_kwargs)
        return double_conv(merged, filters, drop_rate)

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Contraction path
    c1 = double_conv(inputs, 16, 0.1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = double_conv(p1, 32, 0.1)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = double_conv(p2, 64, 0.2)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = double_conv(p3, 128, 0.2)
    p4 = MaxPooling2D(pool_size=(2, 2))(c4)

    # Bottleneck
    c5 = double_conv(p4, 256, 0.3)

    # Expansive path
    c6 = up_block(c5, c4, 128, 0.2)
    c7 = up_block(c6, c3, 64, 0.2)
    c8 = up_block(c7, c2, 32, 0.1)
    c9 = up_block(c8, c1, 16, 0.1, axis=3)

    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(c9)

    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
def get_model():
    """Build the U-Net from the notebook-level configuration.

    Reads the module-level ``n_classes``, ``IMG_HEIGHT``, ``IMG_WIDTH`` and
    ``IMG_CHANNELS`` at call time, so they must be defined before calling.
    """
    config = dict(
        n_classes=n_classes,
        IMG_HEIGHT=IMG_HEIGHT,
        IMG_WIDTH=IMG_WIDTH,
        IMG_CHANNELS=IMG_CHANNELS,
    )
    return multi_unet_model(**config)

In [None]:
# Loss function and metric
def jacard_coef(y_true, y_pred):
    """Smoothed Jaccard coefficient (IoU) over all elements of both tensors.

    Both tensors are flattened, so the score is computed jointly over every
    element rather than per class. A constant 1.0 is added to numerator and
    denominator.
    """
    smooth = 1.0
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + smooth) / (union + smooth)

def jacard_coef_loss(y_true, y_pred):
    """Negated Jaccard coefficient, so that minimizing it maximizes the IoU."""
    score = jacard_coef(y_true, y_pred)
    return -score

# Data Extraction
---------

In [None]:
# Read the CSV of annotations (labels); later cells rely on its
# 'id', 'class' and 'segmentation' columns.
labels_df = pd.read_csv(LABELS)
labels_df

In [None]:
# Collect the path of every scan slice, ordered by case, day and slice number
file_list = []
image_list = []  # not used in this cell
n_cases = 156
n_days = 39
n_slices = 144

for case in range(n_cases + 1):
    case_dir = IMAGES + '/case' + str(case)
    # Most case/day combinations do not exist; skipping them here avoids
    # issuing one glob per slice number against a missing folder.
    if not os.path.isdir(case_dir):
        continue

    for day in range(n_days + 1):
        scans_dir = case_dir + '/case' + str(case) + '_day' + str(day) + '/scans'
        if not os.path.isdir(scans_dir):
            continue

        for slices in range(n_slices + 1):
            str_slices = str(slices).zfill(4)  # slice numbers are zero-padded to 4 digits

            # To link each class to the same image (each path is later
            # repeated once per class)
            file_list.extend(glob.glob(scans_dir + '/slice_' + str_slices + '*.png'))

In [None]:
# Class names repeated for every file, in the same order as the file rows
n_classes = 3
classes = ['large_bowel', 'small_bowel', 'stomach']
class_column = []
for _ in file_list:
    class_column.extend(classes)

# Mask-Image Matching
--------------

In [None]:
# Matching a filename for each segmentation by using regular expressions
filename_column = []
case_column, day_column, slice_column = [], [], []
height_column, width_column, m_height_column, m_width_column = [], [], [], []
id_column = []

for f in file_list:

    # Raw strings keep '\d' a regex token instead of an invalid string escape
    c = re.search(r'case(\d+)_day', f)
    d = re.search(r'_day(\d+)/scans', f)
    s = re.search(r'/slice_(\d+)_', f)
    # Same id format as the 'id' column used for the merge with the labels
    i = 'case' + c.group(1) + '_day' + d.group(1) + '_slice_' + s.group(1)
    height, width, m_height, m_width = get_metadata(f)

    # One identical row per class, so the columns line up with class_column
    for n in range(n_classes):

        filename_column.append(f)

        case_column.append(int(c.group(1)))
        day_column.append(int(d.group(1)))
        slice_column.append(int(s.group(1)))

        height_column.append(int(height))
        width_column.append(int(width))
        m_height_column.append(float(m_height))
        m_width_column.append(float(m_width))

        id_column.append(i)

# Initial Dataset Construction
--------------

In [None]:
# Assemble the per-(slice, class) metadata table; zip() pairs the parallel
# column lists row by row.
images_df = pd.DataFrame(list(zip(id_column, class_column, case_column, day_column, slice_column, filename_column,
                                height_column, width_column, m_height_column, m_width_column)),
                       columns = ['id', 'class', 'case', 'day', 'slice', 'filename', 'height', 'width',
                                  'm_height', 'm_width'])
images_df

In [None]:
# Attach the file metadata to the annotation rows; the inner join keeps only
# the (id, class) pairs present in both tables.
complete_data = pd.merge(labels_df, images_df, how='inner', on = ['id', 'class'])
complete_data

In [None]:
# Save the merged table to disk, without the index column.
complete_data.to_csv(r'./complete_data.csv', index=False)

# Data Analysis
---------

In [None]:
# Print a summary (columns, dtypes, non-null counts) of the merged table.
complete_data.info()

In [None]:
# Number of rows for each (height, width) combination.
size_df = complete_data.groupby(["height",'width'], as_index=False)['filename'].count()
size_df

In [None]:
# Number of non-null segmentations per class (count() skips missing values).
classes_df = complete_data.groupby("class", as_index=False)['segmentation'].count()
classes_df

In [None]:
# Number of rows per case.
cases_df = complete_data.groupby("case", as_index=False)['filename'].count()
cases_df

In [None]:
# Same per-case table as above; max() shows the column-wise maxima
# (highest case number and largest per-case row count).
cases_df = complete_data.groupby("case", as_index=False)['filename'].count()
cases_df.max()

In [None]:
# Bar plot of the number of rows per case.
cases_graph = complete_data.groupby("case").size()
sns.barplot(x = cases_graph.index, y = cases_graph.values)

In [None]:
# Number of rows per day.
days_df = complete_data.groupby("day", as_index=False)['filename'].count()
days_df # 91% of the images are of the day 0

In [None]:
# Bar plot of the rows per day, saved to 'day_description.png'.
days_graph = complete_data.groupby(["day"]).size()

# Divided by 3, presumably to count images rather than rows: every image
# contributes one row per class.
sns.barplot(x = days_graph.index, y = days_graph.values/3)
plt.savefig('day_description.png')

# Class Encoding
--------------
0. Unlabeled background
1. Large bowel
2. Small bowel
3. Stomach

In [None]:
# Encode the class names as integer labels (0 is left for the background).
complete_data['class'] = complete_data['class'].replace('large_bowel', 1)
complete_data['class'] = complete_data['class'].replace('small_bowel', 2)
complete_data['class'] = complete_data['class'].replace('stomach', 3)

complete_data

# Segmentation decoding
----------

In [None]:
# Rows without annotation hold NaN; replace it with the placeholder string '0'
# so rle2mask always receives a string (it decodes '0' to an all-zero mask).
complete_data['segmentation'] = complete_data['segmentation'].replace(np.nan, '0')     
complete_data

In [None]:
# Decode every RLE string into a binary (0/1) mask of the row's own size.
# label is 1 for all classes here; the class value is assigned later, when the
# three masks of an image are merged.
complete_data["segmentation"] = complete_data.apply(lambda x: rle2mask(mask_rle = x["segmentation"],
                                                                       label = 1,
                                                                       shape = (x["height"],x["width"])), axis=1)
complete_data

# Column drop
---------

In [None]:
# Drop the metadata columns that the dataset-building cells do not read.
# NOTE(review): the Visualization cells further down read
# complete_data['height'] / ['width'] - they need to run before this drop.
complete_data = complete_data.drop(['id',
                                    'case',
                                    'day',
                                    'slice',
                                   'height',
                                   'width',
                                   'm_height',
                                   'm_width'],
                                   axis=1)
complete_data

# Mask Unification
----------

In [None]:
# Create the output folders of the unified dataset. makedirs also creates the
# './dataset' parent, and exist_ok=True keeps the cell safe to re-run.
os.makedirs('./dataset/images', exist_ok=True)
os.makedirs('./dataset/masks', exist_ok=True)

In [None]:
# Creation of general masks and dataset
# NOTE: this cell reads ``img_shape`` (target size), which must be defined
# before it is run.
def _fit_bottom_right(array, target_h, target_w):
    """Crop or zero-pad *array* to (target_h, target_w), bottom-right aligned.

    Larger arrays keep their bottom-right region. Smaller arrays are padded
    with zeros on the top/left instead of being sliced with a negative offset,
    which would keep only a few trailing rows/columns.
    """
    row_start = max(array.shape[0] - target_h, 0)
    col_start = max(array.shape[1] - target_w, 0)
    array = array[row_start:, col_start:]

    pad_top = target_h - array.shape[0]
    pad_left = target_w - array.shape[1]
    if pad_top or pad_left:
        # Extra axes (e.g. channels), if any, are left unpadded
        pad_width = ((pad_top, 0), (pad_left, 0)) + ((0, 0),) * (array.ndim - 2)
        array = np.pad(array, pad_width, mode='constant')
    return array


# Rows come in groups of three per image: large bowel, small bowel, stomach
for x in range(0, len(complete_data), 3):

    filename = complete_data['filename'].iloc[x]
    image = cv2.imread(filename, -1)
    if image is None:
        # imread signals failure by returning None, not by raising
        raise IOError(f'Could not read image file: {filename}')
    image = _fit_bottom_right(image, img_shape[0], img_shape[1])

    mask_lb = complete_data["segmentation"].iloc[x]
    mask_sb = complete_data["segmentation"].iloc[x + 1]
    mask_st = complete_data["segmentation"].iloc[x + 2]

    # uint8 is the 8-bit depth PNG stores; labels are 0..3.
    # Where masks overlap, the later assignment wins.
    mask = np.zeros(mask_lb.shape, dtype=np.uint8)
    mask[mask_lb == 1] = 1 # Large bowel
    mask[mask_sb == 1] = 2 # Small bowel
    mask[mask_st == 1] = 3 # Stomach

    mask = _fit_bottom_right(mask, img_shape[0], img_shape[1])

    cv2.imwrite('./dataset/images/image_' + str(x // 3) + '.png', image)
    cv2.imwrite('./dataset/masks/mask_' + str(x // 3) + '.png', mask)

In [None]:
# Zip the contents of ./dataset/ into an archive named 'data' (zip format).
shutil.make_archive('data', 'zip', './dataset/')

<a href='./data.zip'> Download Dataset </a>

# Dataset Splitting
----------

In [None]:
# Unpack the archive into the current folder
patoolib.extract_archive('./data.zip', outdir='./')

# NOTE(review): the split reads from './data' - confirm the archive really
# extracts into that folder.
input_folder = './data'

# 80/20 split with a fixed seed, written to the 'dataset' folder
splitfolders.ratio(input_folder, output="dataset", seed=1337, ratio=(.8, .2), group_prefix=None) # train and valid

# Zip the split dataset into an archive named 'dataset'
shutil.make_archive('dataset', 'zip', './dataset')

# Visualization
------------

In [None]:
# Row used for the visual check. 28110 is a multiple of 3, i.e. the first of
# an image's three class rows.
n = 28110
img = cv2.imread(complete_data['filename'].iloc[n], -1)  # -1: read unchanged
plt.imshow(img, cmap='gray')

In [None]:
# Mask of row n.
# NOTE(review): rle2mask expects an RLE string and this cell reads the
# 'height'/'width' columns, so it has to run before the segmentation column is
# decoded and those columns are dropped.
mask1 = rle2mask(complete_data['segmentation'].iloc[n],
                1,
                (complete_data['height'].iloc[n],complete_data['width'].iloc[n]))
plt.imshow(mask1, cmap='gray')

In [None]:
# Mask of row n+1 (same requirements as for row n: the segmentation must still
# be an RLE string and 'height'/'width' must still exist).
mask2 = rle2mask(complete_data['segmentation'].iloc[n+1],
                1,
                (complete_data['height'].iloc[n+1],complete_data['width'].iloc[n+1]))
plt.imshow(mask2, cmap='gray')

In [None]:
# Mask of row n+2 (same requirements as for row n: the segmentation must still
# be an RLE string and 'height'/'width' must still exist).
mask3 = rle2mask(complete_data['segmentation'].iloc[n+2],
                1,
                (complete_data['height'].iloc[n+2],complete_data['width'].iloc[n+2]))
plt.imshow(mask3, cmap='gray')

In [None]:
# 2x3 grid: the image in the top-middle cell, the three class masks in the
# bottom row. Saved to 'classes.png'.
fig = plt.figure(figsize=(12, 8))
plt.subplot(2, 3, 2); plt.imshow(img, cmap='bone')  ;  plt.title('Image')
plt.subplot(2, 3, 4); plt.imshow(mask1, cmap='bone');  plt.title('Large bowel')
plt.subplot(2, 3, 5); plt.imshow(mask2, cmap='bone');  plt.title('Small bowel')
plt.subplot(2, 3, 6); plt.imshow(mask3, cmap='bone');  plt.title('Stomach')
fig.savefig('classes.png')

In [None]:
# Merge the three binary masks into one label mask (0 = background).
# Where masks overlap, the later assignment wins.
mask = np.zeros(img.shape, dtype=np.int8)
    
mask[mask1 == 1] = 1 # Large bowel
mask[mask2 == 1] = 2 # Small bowel
mask[mask3 == 1] = 3 # Stomach

In [None]:
# Side by side: image, label mask, and the mask overlaid on the image
# (alpha 0.4). Saved to 'segmentation.png'.
fig = plt.figure(figsize=(12, 7))
plt.subplot(1, 3, 1); plt.imshow(img, cmap='bone');
plt.axis('OFF'); plt.title('image')
plt.subplot(1, 3, 2); plt.imshow(mask*255, cmap='hot'); plt.axis('OFF'); plt.title('mask')
plt.subplot(1, 3, 3); plt.imshow(img, cmap='gray'); plt.imshow(mask*255, alpha=0.4);
plt.axis('OFF'); plt.title('overlay')
plt.tight_layout()
plt.savefig('segmentation.png')
plt.show()

# Image Preprocessing and Dataset Creation
-------------

In [None]:
# (height, width, channels); the dataset-building cell reads entries 0 and 1
# as the crop size.
img_shape = (256, 256, 1) # Required shape of the images

In [None]:
# Load a sample of each split. seed_sample is the index of the first file
# taken from the sorted folder listing; CreateSet does not use n_classes.
X_train, Y_train = CreateSet(TRAIN_IMAGES_PATH, TRAIN_MASKS_PATH, n_classes = 4, seed_sample = 15000)
X_valid, Y_valid = CreateSet(VALID_IMAGES_PATH, VALID_MASKS_PATH, n_classes = 4, seed_sample = 3000)

In [None]:
# Report the array shapes of both splits.
print(f' Train Dataset shape:     X_train --> {X_train.shape}    Y_train --> {Y_train.shape} \n')
print(f' Valid Dataset shape:     X_valid --> {X_valid.shape}     Y_valid --> {Y_valid.shape}')

In [None]:
# Index (within the loaded sample) of the training image to inspect.
num = 27
plt.imshow(X_train[num], cmap='gray')

In [None]:
# Read the matching mask straight from disk: index 15027 is the training
# sample offset (seed_sample = 15000) plus num (27).
list_mask = os.listdir(TRAIN_MASKS_PATH)
list_mask.sort()
msk = cv2.imread(TRAIN_MASKS_PATH + list_mask[15027],-1)
plt.imshow(msk, cmap='hot')
np.unique(msk)  # label values present in this mask

In [None]:
# CreateSet returns masks with a single trailing channel (labels are not
# one-hot encoded at this point), so the only valid channel index is 0.
plt.imshow(Y_train[num, :, :, 0], cmap='hot')

In [None]:
# Shape of a single mask of the sample.
Y_train[num].shape

# Avoiding Overfitting
-----------

In [None]:
# Generator() reads these two module-level values when it starts iterating.
seed=24
batch_size= 16

# Augmented batch generators for training and validation
train_img_gen = Generator(X_train, Y_train, n_classes = 4)

val_img_gen = Generator(X_valid, Y_valid, n_classes = 4)

In [None]:
# Another option to set weights to the classes
class_weights_manual = {0: 0.01,
                1: 3.0,
                2: 3.5,
                3: 4.5}

# Evaluation Measures
------------

## Configuration 1 (Accuracy)

In [None]:
# get_model() reads these module-level values. They are set here so that the
# cell also works on a fresh top-to-bottom run: n_classes was last set to 3 for
# the CSV matching, and the image constants are otherwise defined only in the
# Training section.
n_classes = 4
IMG_HEIGHT = 256
IMG_WIDTH  = 256
IMG_CHANNELS = 1

# Cross-entropy loss, tracked with plain accuracy
model = get_model()
model.compile(optimizer= Adam(learning_rate = 1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

## Configuration 2 (IoU metric)

In [None]:
# Same cross-entropy loss, but tracked with the Jaccard coefficient (IoU) as
# the metric and a lower learning rate.
# get_model() reads n_classes, IMG_HEIGHT, IMG_WIDTH and IMG_CHANNELS, which
# must be defined before this cell runs.
model = get_model()
model.compile(optimizer= Adam(learning_rate = 1e-4), loss='categorical_crossentropy', metrics=[jacard_coef])
model.summary()

# Training
----------

In [None]:
# Model configuration read by get_model()
n_classes = 4  # labels 0..3
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 256, 256, 1

In [None]:
# Checkpoint: write the weights to ./unet.h5, keeping only the best val_loss.
# Early stopping: patience of 15 epochs on val_loss, restoring the best weights.
callbacks = [ModelCheckpoint('./unet.h5', verbose=1, save_best_only=True, save_weights_only=True, monitor='val_loss'),
            EarlyStopping(monitor="val_loss", patience=15, verbose=2, mode="auto", restore_best_weights=True)]

In [None]:
# Number of whole batches in the training sample; passed to fit() because the
# generators never stop on their own.
num_train_imgs = X_train.shape[0]
steps_per_epoch = num_train_imgs // batch_size

In [None]:
# The validation generator is endless, so its step count has to be given
# explicitly. Derive it from the validation set rather than reusing the
# training step count, so each epoch covers the validation sample about once.
validation_steps = max(1, X_valid.shape[0] // batch_size)

history = model.fit(train_img_gen,
                    verbose=1,
                    epochs=200,
                    validation_data=val_img_gen,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    callbacks=callbacks)
# class_weight (class_weights_manual) is left out: per the original note it is
# not supported for 3D targets.

model.save('second_train_200epochs_batchsize16_lr4_dataaug_weights30percentimages.hdf5')

In [None]:
# Plot the training and validation loss per epoch and save the figure
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)  # 1-based epoch numbers for the x axis
fig = plt.figure()
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('30percent_Train_val_loss_200epochs_batchsize8_lr4_dataaug.png')
plt.show()

In [None]:
# Plot the training and validation metric.
# The 'jacard_coef' history keys exist only when the model was compiled with
# jacard_coef as a metric.
plt.plot(history.history['jacard_coef'])
plt.plot(history.history['val_jacard_coef'])
plt.title('model jacard_coef')
plt.ylabel('jacard_coef')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.savefig('PREDICEJACCARDCOEFTrain_val_loss_200epochs_batchsize8_lr4_dataaug.png')
plt.show()

In [None]:
# Summarize history for the metric (first cell of a 2x3 subplot grid).
# The 'jacard_coef' history keys exist only when the model was compiled with
# jacard_coef as a metric.
plt.figure(figsize=(16, 8))
plt.subplot(231)
plt.plot(history.history['jacard_coef'])
plt.plot(history.history['val_jacard_coef'])
plt.title('model jacard_coef')
plt.ylabel('jacard_coef')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')

# Summarize history for the loss (second cell of the grid), then save both
plt.subplot(232)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.savefig('30percent_Train_val_loss_200epochs_batchsize16_lr4_dataaug.png')
plt.show()

# Prediction
-------

In [None]:
# Pick one image from the sorted training folder to test the prediction on
n = 30000 #77
list_images = os.listdir(TRAIN_IMAGES_PATH)
list_images.sort()
img_prueba = cv2.imread(TRAIN_IMAGES_PATH + list_images[n], -1)  # -1: read unchanged
plt.imshow(img_prueba, cmap='gray')

In [None]:
# Load the mask at the same index of the sorted mask folder
list_masks = os.listdir(TRAIN_MASKS_PATH)
list_masks.sort()
mask_prueba = cv2.imread(TRAIN_MASKS_PATH + list_masks[n], -1)
plt.imshow(mask_prueba, cmap='hot')

In [None]:
# Turn the test image into a normalized batch of one
image = Image.fromarray(img_prueba)
P = np.array(image)[None, ..., None]  # add the batch and channel axes
P = P / (2 ** 16 - 1.) # Normalization is performed by 2^16 because the image is uint16

In [None]:
# Predict the class scores, then take the highest-scoring class per pixel of
# the first (only) image of the batch.
predict = model.predict(P)
mask_predict = predict[0].argmax(-1)

In [None]:
# Side by side: input image, mask read from disk, predicted mask.
# Saved to 'Prediction.png'.
fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(10,10))
ax1.imshow(img_prueba, cmap='gray')
ax2.imshow(mask_prueba, cmap='hot')
ax3.imshow(mask_predict, cmap='hot')
fig.savefig('Prediction.png')