### Requirements
- keras >= 2.2.0 or tensorflow >= 1.13
- segmentation-models==1.0.*
- albumentations==0.3.0

In [None]:
# # Install required libs

# ### please update Albumentations to version>=0.3.0 for `Lambda` transform support
# !pip install -U albumentations>=0.3.0 --user 
# !pip install -U --pre segmentation-models --user

# Loading dataset

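For this example we will use the **CamVid** dataset. It is a set of:
 - **train** images + segmentation masks
 - **validation** images + segmentation masks
 - **test** images + segmentation masks

All images are 320 pixels high and 480 pixels wide.
For more information about the dataset visit http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/.

Note: this notebook adapts the CamVid example to tooth segmentation; the actual image/mask folders, split file and image size (`set_size`) are read from `./0_data_json_setting/0_run.json`.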

In [None]:
# Import our own shared library code: folder_tool / vis_tool live in a
# sibling directory, so add it to the module search path before importing.
import sys
sys.path.append(r"../data/data_checking_code/")

import folder_tool
import vis_tool

import os
# An empty CUDA_VISIBLE_DEVICES hides all CUDA GPUs, so TensorFlow falls back
# to the CPU. It is set before TensorFlow is imported below so it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import cv2
import tensorflow.keras as keras
import shutil
import numpy as np
import matplotlib.pyplot as plt
import glob

# Notebook-friendly progress bars.
from tqdm.notebook import tqdm

In [None]:
# Setting
# Batch size of the inference dataloader (per the original note: used in cell 16).
BATCH_SIZE = 1
# Not read by any cell shown here; presumably a leftover switch from the training notebook.
training = False

In [None]:
# Experiment configuration: data folders, image size, split file and name.
data_info, result_info_path = folder_tool.data_json_setting_load('./0_data_json_setting/0_run.json')

# TODO: image size (e.g. 1024); Dataset resizes images and masks to set_size x set_size.
set_size = data_info['set_size']
# The augmentation pipelines pad to 1.25x set_size before cropping back.
set_padding = int(set_size * 1.25)

# TODO: input (image) / label (mask) folders.
# NOTE(review): images are taken from the 'x_train_dir' key while masks come
# from 'y_test_dir' -- confirm this pairing is intentional.
x_test_dir = data_info['x_train_dir']
y_test_dir = data_info['y_test_dir']

# TODO: split file to use (originally: if none is configured, a random split
# would be generated and split_df written out; could also be '').
cut_df_path = data_info['cut_df_path']
exp_name = data_info['name']

result_dir = f'./result/{exp_name}'
# TODO: weights to load.
weight_path = f'{result_dir}/weight_80.h5'

# Check the output folders (folder_tool semantics presumably: the experiment
# folder must already exist; the per-split prediction folders are created).
folder_tool.check_folder(result_dir, only_check_exist=True)
for img_type in ('train', 'val', 'test'):
    folder_tool.check_folder(f'{result_dir}/{img_type}_pred_raw')
folder_tool.check_folder(f'{result_dir}/seg_raw/')

# Dataloader and utility functions 

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in one row.

    Each keyword becomes a subplot title, with underscores turned into spaces
    and title-cased (``gt_mask`` -> ``Gt Mask``). Axis ticks are hidden.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()
    
# helper function for data visualization
def denormalize(x):
    """Scale an image to the range 0..1 for correct plotting.

    Values are mapped linearly so that the 2nd percentile becomes 0 and the
    98th percentile becomes 1, then clipped to [0, 1].

    Args:
        x: numeric array, e.g. a preprocessed image.

    Returns:
        Float array of the same shape with values in [0, 1] (never NaN).
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    span = x_max - x_min
    if span == 0:
        # Flat image: both percentiles coincide and the division below would be
        # 0/0 -> NaN (with RuntimeWarnings). Map values above the common level to
        # 1 and the rest to 0, which is what the clip yields for +/-inf.
        return (np.asarray(x) > x_min).astype(float)
    x = (x - x_min) / span
    x = x.clip(0, 1)
    return x
    

# classes for data loading and preprocessing
class Dataset:
    """Tooth segmentation dataset (adapted from the CamVid example).

    Reads an image and its mask, resizes both to ``set_size`` x ``set_size``
    (a module-level setting), converts the mask into one channel per requested
    class, then applies the optional augmentation and preprocessing.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder. Unless ``set_ids``
            is given, the samples are the ``*.PNG`` files found here (sorted by
            name), and each image is looked up under the same file name in
            ``images_dir``.
        classes (list of str): class names to extract from the mask; each must
            be a key of ``CLASSES_DICT`` (case-insensitive). Required.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        set_ids (sequence of str, optional): explicit file names to use
            instead of scanning ``masks_dir``.

    Indexing returns ``(image, mask, name)`` where ``name`` is the image file
    name.

    Raises:
        TypeError: if ``classes`` is not given.
        FileNotFoundError: on indexing, if the image or mask cannot be read.
    """

    # Grayscale mask value that encodes each class.
    CLASSES_DICT = {'background': 0,
                    'tooth': 255}

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            set_ids=None,
    ):
        if classes is None:
            # Fail early with a clear message instead of
            # "'NoneType' object is not iterable".
            raise TypeError(
                "classes must be a list of class names, e.g. ['tooth']; "
                f"available: {list(self.CLASSES_DICT)}"
            )

        # `is not None` (not `!= None`) so array-like set_ids work too.
        if set_ids is not None:
            self.ids = set_ids
        else:
            # There can be fewer masks than images, so the mask folder defines
            # the samples. Sorted because glob order is filesystem-dependent.
            mask_paths = glob.glob(os.path.join(masks_dir, "*.PNG"))
            self.ids = sorted(os.path.basename(p) for p in mask_paths)

        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES_DICT[cls.lower()] for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]
        # basename handles both '/' and the OS separator used by os.path.join.
        name = os.path.basename(image_path)

        # read data
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (set_size, set_size))

        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")
        # Nearest-neighbour keeps mask pixels on the class values; bilinear
        # would create intermediate values along edges that match no class.
        mask = cv2.resize(mask, (set_size, set_size), interpolation=cv2.INTER_NEAREST)

        # one binary channel per requested class
        channels = [(mask == v) for v in self.class_values]
        mask = np.stack(channels, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask, name

    def __len__(self):
        return len(self.ids)
    
    
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.

    Each batch is a list holding one stacked array per item the dataset
    returns (for ``Dataset``: images, masks, names). ``len()`` counts only
    full batches, so a trailing partial batch is not part of an epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        # Keras 3's PyDataset (aliased as Sequence) expects its __init__ to run;
        # for older Sequence this is just object.__init__.
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))
        self.on_epoch_end()

    def __getitem__(self, i):
        # collect batch data; go through self.indexes so the per-epoch shuffle
        # actually changes which samples are batched together. An out-of-range
        # position raises IndexError, as before.
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = [self.dataset[self.indexes[j]] for j in range(start, stop)]

        # transpose list of samples -> one stacked array per field
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

### Augmentations

Data augmentation is a powerful technique to increase the amount of training data and prevent model overfitting.  
If you are not familiar with this technique, read some of these articles:
 - [The Effectiveness of Data Augmentation in Image Classification using Deep
Learning](http://cs231n.stanford.edu/reports/2017/pdfs/300.pdf)
 - [Data Augmentation | How to use Deep Learning when you have Limited Data](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced)
 - [Data Augmentation Experimentation](https://towardsdatascience.com/data-augmentation-experimentation-3e274504f04b)

Since our dataset is very small, we will apply a large number of different augmentations:
 - horizontal flip
 - affine transforms
 - perspective transforms
 - brightness/contrast/color manipulations
 - image blurring and sharpening
 - gaussian noise
 - random crops

All these transforms can be easily applied with [**Albumentations**](https://github.com/albu/albumentations/), a fast augmentation library.
For a detailed explanation of image transformations, see the [Kaggle salt segmentation example](https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb) provided by the [**Albumentations**](https://github.com/albu/albumentations/) authors.


In [None]:
import albumentations as A

In [None]:
def round_clip_0_1(x, **kwargs):
    """Round to the nearest integer, then clamp into [0, 1].

    Used as an albumentations ``Lambda`` mask function; any extra keyword
    arguments it passes are accepted and ignored.
    """
    rounded = np.round(x)
    return np.clip(rounded, 0, 1)

# define heavy augmentations
def get_training_augmentation():
    """Build the heavy training augmentation pipeline.

    The steps run in this order:
      1. random horizontal/vertical flips and shift/scale/rotate;
      2. zero-padding to ``set_padding`` and a random ``set_size`` crop;
      3. Gaussian noise and perspective;
      4. three OneOf groups (photometric, blur/sharpen, contrast/colour);
      5. a final Lambda that rounds and clips the mask to 0/1.

    Uses the module-level ``set_size`` / ``set_padding`` settings.

    NOTE(review): the IAA* transforms and RandomBrightness/RandomContrast are
    deprecated in later albumentations releases; this assumes the 0.x version
    from the requirements.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.3, rotate_limit=5, shift_limit=0.1, p=1, border_mode=0),
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
    ]
    noise = [
        A.IAAAdditiveGaussianNoise(p=0.2),
        A.IAAPerspective(p=0.5),
    ]
    # Each OneOf fires with p=0.5 and then applies exactly one of its members.
    photometric = A.OneOf([A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)], p=0.5)
    sharpness = A.OneOf([A.IAASharpen(p=1), A.Blur(blur_limit=3, p=1), A.MotionBlur(blur_limit=3, p=1)], p=0.5)
    colour = A.OneOf([A.RandomContrast(p=1), A.HueSaturationValue(p=1)], p=0.5)
    mask_cleanup = A.Lambda(mask=round_clip_0_1)

    return A.Compose(geometric + noise + [photometric, sharpness, colour, mask_cleanup])


def get_validation_augmentation():
    """Deterministic validation transform producing a set_size x set_size image.

    Pads to ``set_padding`` (when smaller) and then takes the *center*
    ``set_size`` crop. For inputs that are already set_size x set_size, the
    symmetric pad and the center crop cancel out. Validation images are
    therefore neither randomly shifted nor partly cut off, as they were with
    RandomCrop.
    Uses the module-level ``set_size`` / ``set_padding`` settings.
    """
    test_transform = [
        A.PadIfNeeded(set_padding, set_padding),
        # CenterCrop instead of RandomCrop: validation must be deterministic.
        A.CenterCrop(height=set_size, width=set_size, always_apply=True),
    ]
    return A.Compose(test_transform)

def get_training_enhance_augmentation():
    """Training augmentation with strong rotation.

    The steps run in this order:
      1. random horizontal flip;
      2. shift/scale with rotation up to +/-90 degrees;
      3. zero-padding to ``set_padding`` and a random ``set_size`` crop;
      4. with p=0.9, one of CLAHE / brightness / gamma.

    Uses the module-level ``set_size`` / ``set_padding`` settings.
    """
    spatial = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=90, shift_limit=0.1, p=1, border_mode=0),
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
    ]
    intensity = A.OneOf([A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)], p=0.9)
    return A.Compose(spatial + [intensity])

def get_fine_tune_augmentation():
    """Mild augmentation for fine-tuning.

    The steps run in this order:
      1. random horizontal flip;
      2. shift/scale with rotation up to +/-15 degrees;
      3. zero-padding to ``set_padding`` and a random ``set_size`` crop;
      4. with p=0.5, one of CLAHE / brightness / gamma.

    Uses the module-level ``set_size`` / ``set_padding`` settings.
    """
    spatial = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, shift_limit=0.1, p=1, border_mode=0),
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
    ]
    intensity = A.OneOf([A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)], p=0.5)
    return A.Compose(spatial + [intensity])

def get_preprocessing(preprocessing_fn):
    """Wrap a normalization function as an albumentations pipeline.

    Args:
        preprocessing_fn (callable): data normalization function, typically
            backbone-specific (e.g. ``sm.get_preprocessing(BACKBONE)``). It is
            passed to ``Lambda`` as the ``image`` function only, so masks are
            not normalized.
    Return:
        transform: albumentations.Compose
    """
    normalize = A.Lambda(image=preprocessing_fn)
    return A.Compose([normalize])
def get_preprocessing_v1():
    """Alternative preprocessing: always apply CLAHE contrast equalization."""
    clahe = A.CLAHE(p=1)
    return A.Compose([clahe])

# Model definition

In [None]:
import segmentation_models as sm
# Encoder backbone of the U-Net ('inceptionv3' was tried as an alternative).
BACKBONE = 'efficientnetb3'
# Target classes; a single entry means binary (sigmoid) segmentation.
CLASSES = ['tooth']
# Training hyper-parameters (default LR = 0.0001).
LR, EPOCHS = 0.0001, 100

# Backbone-specific input normalization function from segmentation_models.
preprocess_input = sm.get_preprocessing(BACKBONE)

In [None]:
# Network head: one sigmoid channel for binary segmentation; for multiclass,
# one channel per class plus a background channel with softmax.
if len(CLASSES) == 1:
    n_classes = 1
else:
    n_classes = len(CLASSES) + 1
activation = 'softmax' if n_classes != 1 else 'sigmoid'

# create model
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)

In [None]:
# define optimizer
optim = keras.optimizers.Adam(LR)

# segmentation_models losses can be combined with '+' and scaled by an int or
# float factor. Weighted multi-class example kept for reference:
# dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.5, 1, 1, 1.5, 1.5, 2, 4, 2, 2, 2, 3]))
# focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
# total_loss = dice_loss + (1 * focal_loss)

# Plain binary cross-entropy (alternatives include sm.losses.categorical_focal_dice_loss).
total_loss = sm.losses.BinaryCELoss()

# IoU and F-score computed on predictions thresholded at 0.5.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

# Compile with keyword arguments: in Keras 3 the third positional parameter of
# Model.compile is `loss_weights`, not `metrics`, so passing the metrics
# positionally would bind them to the wrong parameter.
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

# Model Evaluation

In [None]:

# callbacks = [
#     keras.callbacks.ModelCheckpoint('../../weight/0301_ten_classes.h5', save_weights_only=True, save_best_only=True, mode='min'),
#     keras.callbacks.ReduceLROnPlateau(),
# ]

In [None]:
# Load the trained weights from weight_path (epoch-80 checkpoint by default;
# not necessarily the best one).
model.load_weights(filepath=weight_path)

# Visualization of results on test dataset

In [None]:
# Inference dataset: no augmentation, only the backbone-specific normalization
# (alternative: preprocessing=get_preprocessing_v1() for CLAHE instead).
test_dataset = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    classes=CLASSES,
    augmentation=None,
    preprocessing=get_preprocessing(preprocess_input),
)

test_dataloader = Dataloder(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Sanity check: the first batch holds BATCH_SIZE RGB images of set_size x set_size.
assert test_dataloader[0][0].shape == (BATCH_SIZE, set_size, set_size, 3)
# assert test_dataloader[0][1].shape == (BATCH_SIZE, 1024, 1024, n_classes)

In [None]:
def predict_to_mask(pr_mask):
    """Convert per-pixel class scores into a colour-coded RGB label image.

    For every pixel, the highest-scoring class among channels 1.. is chosen
    (channel 0 is ignored; ties go to the lowest channel). The pixel is then
    painted with that class's colour from the palette below.

    Args:
        pr_mask: array of shape (H, W, C) with C class scores per pixel;
            C must be between 2 and 11 (the palette size).

    Returns:
        uint8 array of shape (H, W, 3).
    """
    # Palette indexed by class channel. Index 0 (gray) is never selected
    # because channel 0 is excluded from the argmax below.
    palette = np.array([
        [125, 125, 125],  # 0: gray
        [255, 255, 255],  # 1: img_whiteside
        [0, 0, 0],        # 2: background
        [16, 78, 128],    # 3: bacmixgums
        [255, 255, 0],    # 4: artifical_crown
        [0, 255, 255],    # 5: tooth
        [255, 0, 255],    # 6: overlap
        [0, 0, 255],      # 7: cavity
        [0, 255, 0],      # 8: cej
        [255, 0, 0],      # 9: gums
        [3, 128, 253],    # 10: img_depressed
    ], dtype=np.uint8)
    pr_mask = np.asarray(pr_mask)
    # Vectorized: one argmax over the class axis instead of a per-pixel Python loop.
    class_index = np.argmax(pr_mask[..., 1:], axis=-1) + 1
    return palette[class_index]

def show_pixel_set(img_nparray):
    """Return a ``{value: count}`` histogram of the distinct values in an array."""
    values, occurrences = np.unique(img_nparray, return_counts=True)
    return {value: occurrence for value, occurrence in zip(values, occurrences)}

In [None]:
def _predict_flipped(model, img, axes):
    """Predict on ``img`` flipped along ``axes``, then undo the flip on the prediction."""
    if not axes:
        return model.predict(img)
    flipped = np.ascontiguousarray(np.flip(img, axis=axes))
    return np.flip(model.predict(flipped), axis=axes)


def ensemble_image(img, model_list, flip_=True):
    """Average predictions over several models and optional flip test-time augmentation.

    Args:
        img: batched input, presumably (N, H, W, C) channels-last as built by
            ``np.expand_dims(image, axis=0)`` in the prediction loop.
        model_list: non-empty iterable of objects with a Keras-style
            ``predict(batch)`` method.
        flip_: if True, each model also predicts on the horizontally,
            vertically and doubly flipped input. Those predictions are flipped
            back and averaged with the plain prediction.

    Returns:
        float32 array shaped like ``model.predict(img)``: the mean over models
        (and flips).

    Raises:
        ValueError: if ``model_list`` is empty.
    """
    models = list(model_list)
    if not models:
        raise ValueError('model_list must contain at least one model')

    # Spatial axes of an (N, H, W, ...) array to flip. In order: identity,
    # horizontal (cv2 flip code 1), vertical (code 0) and both (code -1).
    # NOTE: flipping predictions back along axes 1/2 assumes the model output
    # keeps the input's spatial layout (true for a U-Net; confirm for others).
    flip_axes = [(), (2,), (1,), (1, 2)] if flip_ else [()]

    per_model = []
    for model in models:
        preds = [np.asarray(_predict_flipped(model, img, axes), dtype='float32')
                 for axes in flip_axes]
        per_model.append(np.mean(preds, axis=0))
    return np.mean(per_model, axis=0)

In [None]:
import pandas as pd

# Display each image next to its binarized prediction.
show = False

# Use flip test-time augmentation (ensemble_image) instead of a single prediction.
ensemble = False

# Split file: rows with image_id and type ('train' / 'valid' / 'test').
split_df = pd.read_csv(cut_df_path)
train_img_name = list(split_df[split_df.type == 'train'].image_id)
val_img_name = list(split_df[split_df.type == 'valid'].image_id)
test_img_name = list(split_df[split_df.type == 'test'].image_id)

# Output sub-folder for each labelled image (O(1) lookup per image instead of
# scanning the lists). Filled test -> val -> train so that, as with an
# if/elif chain checking train first, 'train' wins if a name is listed twice.
pred_folder_of = {}
for folder, names in (('test_pred_raw', test_img_name),
                      ('val_pred_raw', val_img_name),
                      ('train_pred_raw', train_img_name)):
    for img_name in names:
        pred_folder_of[img_name] = folder

cnt = 0  # predictions written for images that appear in the split file

for i in tqdm(range(len(test_dataset)), total=len(test_dataset)):
    image, gt_mask, name = test_dataset[i]
    batch = np.expand_dims(image, axis=0)
    if ensemble:
        pr_mask = ensemble_image(batch, [model], flip_=True)
    else:
        pr_mask = model.predict(batch)
    pr_mask = np.squeeze(pr_mask)
    # Binarize at 0.5 to {0, 255}. np.where yields the platform's default
    # integer dtype, which cv2.imwrite may reject, so cast to 8-bit explicitly.
    result = np.where(pr_mask > 0.5, 255, 0).astype(np.uint8)
    # result = predict_to_mask(pr_mask)

    folder = pred_folder_of.get(name)
    if folder is not None:
        cnt += 1
    else:
        folder = 'seg_raw'  # not listed in the split file (unlabelled image)
    out_path = os.path.join(f'./result/{exp_name}/{folder}/', name)
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(out_path, result):
        raise OSError(f'failed to write prediction to {out_path}')

    if show:
        plt.figure(figsize=(20, 20))
        plt.subplot(1, 2, 1)
        # Rescaling for display is only needed when showing.
        plt.imshow(denormalize(image))
        plt.subplot(1, 2, 2)
        plt.imshow(result)
        plt.show()
print('generate', cnt, '/', len(split_df), 'label data, and', len(test_dataset) - cnt, 'unlabel data')