### Requirements
- keras >= 2.2.0 or tensorflow >= 1.13
- segmentation-models==1.0.*
- albumentations==0.3.0

In [None]:
# # Install required libs

# ### please update Albumentations to version>=0.3.0 for `Lambda` transform support
# !pip install -U albumentations>=0.3.0 --user 
# !pip install -U --pre segmentation-models --user

In [None]:
# Print the version of the `python` executable found on PATH (IPython shell escape;
# it may differ from the interpreter running this kernel).
!python --version

In [None]:
# !pip list

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import cv2
import keras
import shutil
import numpy as np
import matplotlib.pyplot as plt
import glob
import pandas as pd
import random

from tqdm import tqdm

In [None]:
import tensorflow as tf

# A notebook cell only echoes its LAST expression, so these values were never
# shown before; print them explicitly.
print('TensorFlow version:', tf.__version__)
print('GPU available:', tf.test.is_gpu_available())
# Set to False to skip the `if training:` model.fit_generator cells below.
training = True

In [None]:
# INPUT
# TODO: set the image size (side length of the square crops fed to the model)
# set_size = 1024
set_size = 512
# Size images are zero-padded to before being cropped back to set_size (1.25x set_size).
set_padding = int(set_size*1.25)
# TODO: set the input image / label mask folders
x_train_dir = '../data/1_o_image_512/'
y_train_dir = '../data/label_CEJline_mask_385_512/'
# TODO: path of a predefined train/valid/test split CSV; set to '' to generate a
# random split instead and write it out as ./split_df.csv
# cut_df_path = '' 
cut_df_path = 'split_df.csv'

In [None]:
# Output
# TODO: name prefix for the output weight files
data_folder_name = '385_CEJline'
# Used to tell weight files apart: it is formatted into the ./weight/... file names.

In [None]:
def remove_check_point(path):
    """Delete checkpoint folders (e.g. Jupyter's ``.ipynb_checkpoints``) directly inside *path*.

    Only immediate children whose name contains ``'checkpoints'`` are
    considered. Matching entries that are plain files are left alone, because
    ``shutil.rmtree`` can only remove directories and would raise on them.

    Args:
        path (str): directory to clean.
    """
    for file_name in os.listdir(path):
        full_path = os.path.join(path, file_name)
        if 'checkpoints' in file_name and os.path.isdir(full_path):
            shutil.rmtree(full_path)
# Remove stray checkpoint folders from the image and mask folders.
remove_check_point(x_train_dir)
remove_check_point(y_train_dir)

# Dataloader and utility functions 

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword becomes a subplot title: underscores are replaced with spaces
    and the result is title-cased (``tooth_mask`` -> ``Tooth Mask``). Axis
    ticks are hidden. Panels appear in keyword order.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    position = 1
    for title_key, picture in images.items():
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(title_key.replace('_', ' ').title())
        plt.imshow(picture)
        position += 1
    plt.show()
    
# helper function for data visualization    
def denormalize(x):
    """Scale an image to the 0..1 range for display.

    The 2nd and 98th percentiles are used as the black / white points, so a few
    outlier pixels do not wash out the contrast. The result is clipped to [0, 1].

    Args:
        x (np.ndarray): image array of any numeric dtype.

    Returns:
        np.ndarray: values in [0, 1]. A flat image (zero percentile range)
        maps to all zeros instead of NaN.
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    span = x_max - x_min
    if span == 0:
        # Avoid 0/0 -> NaN for constant images (e.g. an all-black mask).
        return np.zeros(np.shape(x))
    x = (x - x_min) / span
    return x.clip(0, 1)
    

# classes for data loading and preprocessing
class Dataset:
    """Tooth segmentation dataset. Read images, apply augmentation and preprocessing transformations.

    Adapted from the segmentation_models CamVid example.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; unless ``set_ids``
            is given, the sample ids are the ``*.PNG`` file names found here
        classes (list of str): class names (keys of ``CLASSES_DICT``,
            case-insensitive) to extract from the mask; ``None`` means all classes
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        set_ids (iterable of str): explicit file names to use (e.g. one split of
            a train/valid/test partition); images and masks share these names

    Items are ``(image, mask)`` pairs. ``image`` is read as RGB; ``mask`` has one
    float channel per requested class, plus an extra background channel when
    more than one class is requested. Both then pass through the optional
    augmentation and preprocessing pipelines.
    """

    # Grayscale pixel value that marks each class in the mask files.
    CLASSES_DICT = {'background': 0,
                    'tooth': 255}

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            set_ids=None,
    ):
        if set_ids is not None:
            # `is not None` rather than `!= None`, so numpy arrays / pandas
            # Series of ids work too (elementwise `!=` is ambiguous in `if`).
            self.ids = list(set_ids)
        else:
            self.ids = [os.path.basename(x) for x in glob.glob(os.path.join(masks_dir, "*.PNG"))]
        # Image and mask files share a file name. Ids come from the mask folder
        # (original note: the mask folder has fewer files than the image folder).
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        if classes is None:
            classes = list(self.CLASSES_DICT)
        self.class_values = [self.CLASSES_DICT[cls.lower()] for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # read data (cv2.imread returns None instead of raising on unreadable files)
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError('Cannot read image: {}'.format(self.images_fps[i]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError('Cannot read mask: {}'.format(self.masks_fps[i]))
        # Masks must contain exactly two gray levels (background and tooth).
        levels = np.unique(mask)
        if len(levels) != 2:
            raise ValueError('Mask {} is not two-valued: {}'.format(self.masks_fps[i], levels))

        # extract one boolean channel per requested class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)
    
    
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle the sample order each epoch.

    The trailing partial batch (``len(dataset) % batch_size`` samples) is dropped.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))
        self.on_epoch_end()

    def __getitem__(self, i):
        if i >= len(self):
            raise IndexError('batch index {} out of range'.format(i))

        # collect batch data; go through self.indexes so the per-epoch
        # permutation from on_epoch_end actually changes the sample order
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = [self.dataset[j] for j in self.indexes[start:stop]]

        # transpose list of (image, mask) pairs into [images, masks]
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

In [None]:
# Sanity check: load one raw (un-augmented) sample and show both mask channels.
sample_idx = 60
dataset = Dataset(x_train_dir, y_train_dir, classes=['background', 'tooth'])
print(dataset.class_values)
print(dataset.images_fps[sample_idx])
image, mask = dataset[sample_idx]

channels = {
    'background': mask[..., 0].squeeze(),
    'tooth': mask[..., 1].squeeze(),
}
visualize(image=image, **channels)
print(image.shape)

In [None]:
# plt.figure(figsize=(12,12))
# plt.imshow(image)
# plt.show()

# #前牙 label
# #

### Augmentations

Data augmentation is a powerful technique to increase the amount of your data and prevent model overfitting.  
If you are not familiar with this technique, read some of these articles:
 - [The Effectiveness of Data Augmentation in Image Classification using Deep Learning](http://cs231n.stanford.edu/reports/2017/pdfs/300.pdf)
 - [Data Augmentation | How to use Deep Learning when you have Limited Data](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced)
 - [Data Augmentation Experimentation](https://towardsdatascience.com/data-augmentation-experimentation-3e274504f04b)

Since our dataset is very small we will apply a large number of different augmentations:
 - horizontal flip
 - affine transforms
 - perspective transforms
 - brightness/contrast/colors manipulations
 - image blurring and sharpening
 - gaussian noise
 - random crops

All these transforms can be easily applied with [**Albumentations**](https://github.com/albu/albumentations/), a fast augmentation library.
For a detailed explanation of image transformations, see the [kaggle salt segmentation example](https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb) provided by the [**Albumentations**](https://github.com/albu/albumentations/) authors.


In [None]:
import albumentations as A

In [None]:
def round_clip_0_1(x, **kwargs):
    """Snap mask values to {0, 1}: round to the nearest integer, then clip to [0, 1].

    Extra keyword arguments (supplied by ``albumentations.Lambda``) are ignored.
    """
    snapped = x.round()
    return snapped.clip(0, 1)

# define heavy augmentations
def get_training_augmentation():
    """Heavy training pipeline: flips, shift/scale/rotate, zero-pad + random crop
    to set_size, then noise / brightness / blur-sharpen / contrast jitter.

    The final Lambda snaps mask values back to {0, 1}, since the geometric
    transforms can leave interpolated (non-binary) mask values.
    """
    train_transform = [
        # horizontal flip
        A.HorizontalFlip(p=0.5),
        # vertical flip
        A.VerticalFlip(p=0.5),
        # scale, rotate and shift
        A.ShiftScaleRotate(scale_limit=0.3, rotate_limit=5, shift_limit=0.1, p=1, border_mode=0),
        
        # Example with set_size = 1024, set_padding = 1280:
        # pad the 1024x1024 image with a zero border up to 1280x1280
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        # then randomly crop a set_size patch (1024x1024 in the example) out of the padded image
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
        
        # add Gaussian noise with 20% probability
        A.IAAAdditiveGaussianNoise(p=0.2),
        
        # 50%: apply one of CLAHE / random brightness / random gamma
        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightness(p=1),
                A.RandomGamma(p=1),
            ],
            p=0.5,
        ),

        # 50%: apply one of sharpen / blur / motion blur
        A.OneOf(
            [
                A.IAASharpen(p=1),
                A.Blur(blur_limit=3, p=1),
                A.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.5,
        ),

        # 50%: random contrast
        A.OneOf(
            [
                A.RandomContrast(p=1),
            ],
            p=0.5,
        ),
        # round and clip mask values back to {0, 1}
        A.Lambda(mask=round_clip_0_1)
    ]
    return A.Compose(train_transform)


def get_validation_augmentation():
    """Deterministic validation pipeline: zero-pad to set_padding, then center-crop to set_size.

    Uses the same constant (zero) border as the training pipelines. It uses a
    center crop instead of a random one so every epoch sees identical
    validation inputs and metrics are comparable. For an input of exactly
    set_size x set_size, pad + center-crop returns the original image.
    """
    test_transform = [
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.CenterCrop(height=set_size, width=set_size, always_apply=True),
    ]
    return A.Compose(test_transform)

def get_training_enhance_augmentation():
    """Alternative training pipeline: horizontal flip, shift/scale/rotate (up to
    90 degrees), zero-pad to set_padding, random crop to set_size, and with
    90% probability one of CLAHE / random brightness / random gamma.

    Unlike get_training_augmentation it has no vertical flip, noise,
    blur/sharpen, contrast or mask-rounding step.
    """
    fine_tune_transform = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=90, shift_limit=0.1, p=1, border_mode=0),

        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightness(p=1),
                A.RandomGamma(p=1),
            ],
            p=0.9,
        )
    ]
    return A.Compose(fine_tune_transform)

def get_fine_tune_augmentation():
    """Fine-tuning pipeline: only zero-pad up to set_padding x set_padding.

    No crop is applied, so (for inputs no larger than set_padding) outputs keep
    the padded size rather than set_size.
    """
    fine_tune_transform = [
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        # The transforms below are currently disabled.
        #A.HorizontalFlip(p=0.5),
#         A.OneOf(
#             [
#                 A.CLAHE(p=1),
#                 A.RandomBrightness(p=1),
#                 A.RandomGamma(p=1),
#             ],
#             p=1,
#         ),
#         A.OneOf(
#             [
#                 A.RandomContrast(p=1),
#                 A.HueSaturationValue(p=1),
#             ],
#             p=1,
#         ),
#         A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, shift_limit=0.1, p=1, border_mode=0),
#         A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
#         A.RandomCrop(height=set_size, width=set_size, always_apply=True),
#         A.IAAAdditiveGaussianNoise(p=1),
#         A.OneOf(
#             [
#                 A.CLAHE(p=1),
#                 A.RandomBrightness(p=1),
#                 A.RandomGamma(p=1),
#             ],
#             p=0.5,
#         )
    ]
    return A.Compose(fine_tune_transform)

def get_preprocessing(preprocessing_fn):
    """Wrap a normalization function into an albumentations pipeline.

    Args:
        preprocessing_fn (callable): data normalization function applied to the
            image (can be specific for each pretrained neural network).

    Returns:
        albumentations.Compose: a single-step pipeline whose Lambda is given
        only an ``image`` function.
    """
    image_step = A.Lambda(image=preprocessing_fn)
    return A.Compose([image_step])

In [None]:
# Draw the same index twice through the random augmentation pipeline; the two
# draws may look different.
dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=['background', 'tooth'],
    augmentation=get_training_enhance_augmentation(),
)
print(dataset.class_values)

for _draw in range(2):
    image, mask = dataset[15]
    visualize(
        image=image,
        background=mask[..., 0].squeeze(),
        tooth=mask[..., 1].squeeze(),
    )
    print(image.shape)
    print(mask.shape)

# Segmentation model training

In [None]:
import segmentation_models as sm

# segmentation_models could also use `tf.keras` if you do not have Keras installed
# or you could switch to other framework using `sm.set_framework('tf.keras')`

In [None]:
# Encoder backbone name, passed to sm.Unet and sm.get_preprocessing.
BACKBONE = 'efficientnetb3'
BATCH_SIZE = 2  # original note: "change to 1"
# A single class -> binary segmentation with a one-channel sigmoid output.
CLASSES = ['tooth']
LR = 0.0001
# default LR = 0.0001
# NOTE(review): EPOCHS is not referenced by the training cells in this notebook;
# they hard-code epochs=40 (epochs=20 in the last two cells).
EPOCHS = 100

# Backbone-specific input normalization function.
preprocess_input = sm.get_preprocessing(BACKBONE)

In [None]:
# Binary task (one class) -> a single output channel; multi-class -> one channel
# per class plus background.
if len(CLASSES) == 1:
    n_classes = 1
else:
    n_classes = len(CLASSES) + 1

# A single channel uses sigmoid; several mutually-exclusive channels use softmax.
if n_classes == 1:
    activation = 'sigmoid'
else:
    activation = 'softmax'

model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)

In [None]:
# Adam optimizer with the configured learning rate.
optim = keras.optimizers.Adam(LR)

# Plain binary cross-entropy. segmentation_models losses can also be summed and
# scaled (e.g. sm.losses.DiceLoss() + sm.losses.BinaryFocalLoss()) if needed.
total_loss = sm.losses.BinaryCELoss()

# IoU and F1 on predictions binarized at 0.5; Keras logs them as
# 'iou_score' / 'f1-score' (and 'val_...' for validation).
metrics = [
    sm.metrics.IOUScore(threshold=0.5),
    sm.metrics.FScore(threshold=0.5),
]

model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

In [None]:
# # model.summary()
# model.summary()

In [None]:
# Dataset for train images
# Split the mask file names into train / valid / test: either load the split CSV
# at cut_df_path, or (cut_df_path == '') generate a random 70% / 10% / 20% split
# and save it to ./split_df.csv so it can be reused.
dataset_ids = [os.path.basename(x) for x in glob.glob(os.path.join(y_train_dir, "*.PNG"))]
if cut_df_path == "":
    random.shuffle(dataset_ids)
    train_end = int(len(dataset_ids) * 0.7)
    valid_end = int(len(dataset_ids) * 0.8)
    cut_df = []
    for index, ids in enumerate(dataset_ids):
        if index < train_end:
            cut_df.append([ids, 'train'])
        elif index < valid_end:
            cut_df.append([ids, 'valid'])
        else:
            cut_df.append([ids, 'test'])
    cut_df = pd.DataFrame(cut_df, columns=['image_id', 'type'])
    cut_df.to_csv('./split_df.csv', index=False)
else:
    # Read the configured split file rather than a hard-coded path.
    cut_df = pd.read_csv(cut_df_path)

assert len(cut_df) > 0, 'cut df is empty'
train_id = list(cut_df[cut_df['type'] == 'train']['image_id'])
valid_id = list(cut_df[cut_df['type'] == 'valid']['image_id'])

train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
    set_ids=train_id
)

# Validation uses the pad + crop-only pipeline instead of the random training
# augmentations, so validation metrics reflect the real images.
valid_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
    set_ids=valid_id
)
print('train', len(train_dataset), 'valid', len(valid_dataset))
# Should print an empty set: no file may appear in both splits.
print(set(train_dataset.ids) & set(valid_dataset.ids))

# training_enhance_dataset = Dataset(
#     x_train_dir,
#     y_train_dir,
#     classes=CLASSES,
#     augmentation=get_training_enhance_augmentation(),
#     preprocessing=get_preprocessing(preprocess_input),
# )

train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not matter, so there is no need to shuffle it.
valid_dataloader = Dataloder(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
# training_enhance_dataloader = Dataloder(training_enhance_dataset, batch_size=BATCH_SIZE, shuffle=True)

# check shapes for errors (build the first batch once, not twice)
sample_images, sample_masks = train_dataloader[0]
assert sample_images.shape == (BATCH_SIZE, set_size, set_size, 3)
assert sample_masks.shape == (BATCH_SIZE, set_size, set_size, n_classes)

# ModelCheckpoint and model.save write into ./weight/; make sure it exists.
os.makedirs('./weight', exist_ok=True)

# define callbacks for learning rate scheduling and best checkpoints saving
callbacks = [
    # The compiled metrics are IOUScore / FScore, so Keras logs 'val_iou_score'
    # (there is no 'val_acc', which made save_best_only never save). Higher is better.
    keras.callbacks.ModelCheckpoint('./weight/{}_train_ten_classes_20.h5'.format(data_folder_name), monitor='val_iou_score', save_weights_only=True, save_best_only=True, mode='max'),
    keras.callbacks.ReduceLROnPlateau(),
]

In [None]:
if training:
    # model.save below fails if ./weight/ is missing, which would only surface
    # after all 40 epochs have run, so create it up front.
    os.makedirs('./weight', exist_ok=True)
    # Stage 1: first 40 epochs (cumulative 40); `callbacks` are commented out here.
    history = model.fit_generator(
        train_dataloader,
        steps_per_epoch=len(train_dataloader),
        epochs=40,
#         callbacks=callbacks,
        validation_data=valid_dataloader,
        validation_steps=len(valid_dataloader),
    )

    model.save('./weight/{}_train_ten_classes_40.h5'.format(data_folder_name))
    

In [None]:
# Start accumulating per-epoch metrics; the later plotting cells append each new
# training stage's history to these lists.
his_iou = []
his_val_iou = []
his_loss = []
his_val_loss = []
his_iou += history.history['iou_score']
his_val_iou += history.history['val_iou_score']
his_loss += history.history['loss']
his_val_loss += history.history['val_loss']
plt.figure(figsize=(30, 5))
plt.subplot(121)
plt.plot(his_iou)
plt.plot(his_val_iou)
plt.title('Model iou_score')
plt.ylabel('iou_score')
plt.xlabel('Epoch')
# The second curve comes from valid_dataloader (the 'valid' split), not the test split.
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(122)
plt.plot(his_loss)
plt.plot(his_val_loss)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# savefig does not create missing folders.
os.makedirs('./result/train_result_image', exist_ok=True)
plt.savefig('./result/train_result_image/epoch40.png')
plt.show()

In [None]:
if training:
    # Stage 2: 40 more epochs (cumulative 80, as in the file name) with `callbacks` enabled.
    # NOTE: model.save writes into ./weight/, which must already exist.
    history = model.fit_generator(
        train_dataloader, 
        steps_per_epoch=len(train_dataloader), 
        epochs=40, 
        callbacks=callbacks, 
        validation_data=valid_dataloader, 
        validation_steps=len(valid_dataloader),
    )

    model.save('./weight/{}_train_ten_classes_80.h5'.format(data_folder_name))
    

In [None]:
# Append this stage's metrics to the running lists and re-plot the full history.
his_iou += history.history['iou_score']
his_val_iou += history.history['val_iou_score']
his_loss += history.history['loss']
his_val_loss += history.history['val_loss']
plt.figure(figsize=(30, 5))
plt.subplot(121)
plt.plot(his_iou)
plt.plot(his_val_iou)
plt.title('Model iou_score')
plt.ylabel('iou_score')
plt.xlabel('Epoch')
# The second curve comes from valid_dataloader (the 'valid' split), not the test split.
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(122)
plt.plot(his_loss)
plt.plot(his_val_loss)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# savefig does not create missing folders.
os.makedirs('./result/train_result_image', exist_ok=True)
plt.savefig('./result/train_result_image/epoch80.png')
plt.show()

In [None]:
if training:
    # Stage 3: 40 more epochs (cumulative 120); `callbacks` are commented out here.
    # NOTE: model.save writes into ./weight/, which must already exist.
    history = model.fit_generator(
        train_dataloader, 
        steps_per_epoch=len(train_dataloader), 
        epochs=40, 
#         callbacks=callbacks, 
        validation_data=valid_dataloader, 
        validation_steps=len(valid_dataloader),
    )

    model.save('./weight/{}_train_ten_classes_120.h5'.format(data_folder_name))

In [None]:
# Append this stage's metrics to the running lists and re-plot the full history.
his_iou += history.history['iou_score']
his_val_iou += history.history['val_iou_score']
his_loss += history.history['loss']
his_val_loss += history.history['val_loss']
plt.figure(figsize=(30, 5))
plt.subplot(121)
plt.plot(his_iou)
plt.plot(his_val_iou)
plt.title('Model iou_score')
plt.ylabel('iou_score')
plt.xlabel('Epoch')
# The second curve comes from valid_dataloader (the 'valid' split), not the test split.
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(122)
plt.plot(his_loss)
plt.plot(his_val_loss)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# savefig does not create missing folders.
os.makedirs('./result/train_result_image', exist_ok=True)
plt.savefig('./result/train_result_image/epoch120.png')
plt.show()

In [None]:
if training:
    # Stage 4: 40 more epochs (cumulative 160); `callbacks` are commented out here.
    # NOTE: model.save writes into ./weight/, which must already exist.
    history = model.fit_generator(
        train_dataloader, 
        steps_per_epoch=len(train_dataloader), 
        epochs=40, 
#         callbacks=callbacks, 
        validation_data=valid_dataloader, 
        validation_steps=len(valid_dataloader),
    )

    model.save('./weight/{}_train_ten_classes_160.h5'.format(data_folder_name))
    

In [None]:
# Append this stage's metrics to the running lists and re-plot the full history.
his_iou += history.history['iou_score']
his_val_iou += history.history['val_iou_score']
his_loss += history.history['loss']
his_val_loss += history.history['val_loss']
plt.figure(figsize=(30, 5))
plt.subplot(121)
plt.plot(his_iou)
plt.plot(his_val_iou)
plt.title('Model iou_score')
plt.ylabel('iou_score')
plt.xlabel('Epoch')
# The second curve comes from valid_dataloader (the 'valid' split), not the test split.
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(122)
plt.plot(his_loss)
plt.plot(his_val_loss)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# savefig does not create missing folders.
os.makedirs('./result/train_result_image', exist_ok=True)
plt.savefig('./result/train_result_image/epoch160.png')
plt.show()

In [None]:
if training:
    # Stage 5: 40 more epochs (cumulative 200) with `callbacks` enabled.
    # NOTE: model.save writes into ./weight/, which must already exist.
    history = model.fit_generator(
        train_dataloader, 
        steps_per_epoch=len(train_dataloader), 
        epochs=40, 
        callbacks=callbacks, 
        validation_data=valid_dataloader, 
        validation_steps=len(valid_dataloader),
    )
    model.save('./weight/{}_train_ten_classes_200.h5'.format(data_folder_name))
    

In [None]:
# Append this stage's metrics to the running lists and re-plot the full history.
his_iou += history.history['iou_score']
his_val_iou += history.history['val_iou_score']
his_loss += history.history['loss']
his_val_loss += history.history['val_loss']
plt.figure(figsize=(30, 5))
plt.subplot(121)
plt.plot(his_iou)
plt.plot(his_val_iou)
plt.title('Model iou_score')
plt.ylabel('iou_score')
plt.xlabel('Epoch')
# The second curve comes from valid_dataloader (the 'valid' split), not the test split.
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(122)
plt.plot(his_loss)
plt.plot(his_val_loss)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# savefig does not create missing folders.
os.makedirs('./result/train_result_image', exist_ok=True)
plt.savefig('./result/train_result_image/epoch200.png')
plt.show()
# Deliberately halt "Run All" here so the extra training cells below are only run by hand.
raise RuntimeError('Intentional stop after epoch 200: run the remaining cells manually if needed.')

In [None]:
if training:
    os.makedirs('./weight', exist_ok=True)
    # Extra stage: 20 more epochs on the training split only (validation is commented out).
    # NOTE(review): `callbacks` (ModelCheckpoint / ReduceLROnPlateau) monitor 'val_*'
    # quantities; without validation_data Keras only warns and they do nothing.
    history = model.fit_generator(
        train_dataloader,
        steps_per_epoch=len(train_dataloader),
        epochs=20,
        callbacks=callbacks,
#         validation_data=valid_dataloader,
#         validation_steps=len(valid_dataloader),
    )

    # 220 = cumulative epochs after the five 40-epoch stages above. The previous
    # name (_120) overwrote the weights saved after the third stage.
    model.save('./weight/{}_train_ten_classes_220.h5'.format(data_folder_name))
    

In [None]:
if training:
    os.makedirs('./weight', exist_ok=True)
    # Extra stage: 20 more epochs on the training split only (validation is commented out).
    # NOTE(review): `callbacks` (ModelCheckpoint / ReduceLROnPlateau) monitor 'val_*'
    # quantities; without validation_data Keras only warns and they do nothing.
    history = model.fit_generator(
        train_dataloader,
        steps_per_epoch=len(train_dataloader),
        epochs=20,
        callbacks=callbacks,
#         validation_data=valid_dataloader,
#         validation_steps=len(valid_dataloader),
    )

    # 240 = cumulative epochs after the 200-epoch stages plus two 20-epoch extras;
    # the previous name (_140) did not match the actual epoch count.
    model.save('./weight/{}_train_ten_classes_240.h5'.format(data_folder_name))
    