### Requirements
- keras >= 2.2.0 or tensorflow >= 1.13
- segmentation-models==1.0.*
- albumentations==0.3.0

In [None]:
# # Install required libs

# ### please update Albumentations to version>=0.3.0 for `Lambda` transform support
# !pip install -U albumentations>=0.3.0 --user 
# !pip install -U --pre segmentation-models --user

In [None]:
# Print the Python interpreter version (IPython `!` shell escape).
!python --version

In [None]:
# Print the pip version (IPython `!` shell escape).
!pip --version

In [None]:
# Label-set tag for this run.
# NOTE(review): not referenced by any visible cell; possibly meant "augmentation_type".
argumentation_type = '7_classes'

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import glob
import cv2
import keras
import shutil
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm as tqdm

In [None]:
# Training data location, relative to the current working directory.
data_folder_name = '1002_7_classes_data/train'
DATA_DIR = './{}'.format(data_folder_name)

# Image sizes (pixels) used by the augmentation pipelines: images are padded
# to set_padding x set_padding and, in some pipelines, randomly cropped back
# to set_size x set_size.
set_size = 1024
set_padding = 1280

In [None]:
def remove_check_point(path):
    """Delete checkpoint directories (e.g. ``.ipynb_checkpoints``) directly inside *path*.

    Every top-level entry of *path* whose name contains ``'checkpoints'``
    and that is a directory is removed recursively. Matching entries that
    are not directories are left in place (``shutil.rmtree`` cannot delete
    them). Sub-directories of *path* are not scanned.

    Args:
        path (str): folder to clean, e.g. an image or label directory.
    """
    for file_name in os.listdir(path):
        full_path = os.path.join(path, file_name)
        # rmtree raises on non-directories, so only target matching folders
        if 'checkpoints' in file_name and os.path.isdir(full_path):
            shutil.rmtree(full_path)
# remove_check_point(x_train_dir)
# remove_check_point(y_train_dir)

# Dataloader and utility functions 

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword name becomes its subplot title, with underscores turned
    into spaces and the words title-cased (``gt_mask`` -> "Gt Mask").
    Axis ticks are hidden and the figure is shown immediately.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()
    
# helper function for data visualization
def denormalize(x):
    """Scale an image to the range 0..1 for plotting.

    The 2nd and 98th percentiles of *x* are mapped to 0 and 1, and values
    outside that band are clipped, so a few extreme pixels do not wash out
    the display.

    Returns:
        A float array with the same shape as *x*. If both percentiles are
        equal (e.g. a constant image), an all-zero array is returned instead
        of dividing by zero and producing NaNs.
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    value_range = x_max - x_min
    if value_range == 0:
        return np.zeros(np.shape(x), dtype=float)
    x = (x - x_min) / value_range
    return x.clip(0, 1)
    

# classes for data loading and preprocessing
class Dataset:
    """Segmentation dataset: read image/mask pairs, then augment and preprocess them.

    Masks are read as single-channel images whose pixel values are compared
    against the indices of the requested class names in ``CLASSES``. Each
    requested class becomes one binary channel. When more than one channel
    is produced, one extra channel equal to ``1 - sum(other channels)`` is
    appended, covering every pixel not matched by a requested class value.

    Args:
        images_path_list (list of str): paths to the input images.
        masks_path_list (list of str): paths to the masks, index-aligned with
            ``images_path_list``.
        classes (list of str, optional): class names (case-insensitive) taken
            from ``CLASSES``. Defaults to all of ``CLASSES``. An unknown name
            raises ``ValueError``.
        augmentation (albumentations.Compose, optional): transformation
            pipeline applied to image and mask (e.g. flip, scale).
        preprocessing (albumentations.Compose, optional): preprocessing
            pipeline applied after augmentation (e.g. normalization).
    """

    CLASSES = ['background', 'artifical_crown','tooth','overlap','cavity','cej','gums']

    def __init__(
            self,
            images_path_list,
            masks_path_list,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.images_fps = images_path_list
        self.masks_fps = masks_path_list
        # os.path.basename also handles Windows separators returned by glob
        self.ids = [os.path.basename(mask_path) for mask_path in masks_path_list]

        if classes is None:
            classes = self.CLASSES
        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index *i*.

        The image is converted to RGB. The mask is one-hot encoded as
        described in the class docstring, then both are passed through
        augmentation and preprocessing if configured.

        Raises:
            OSError: if OpenCV cannot read the image or mask file.
        """
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # cv2.imread signals failure by returning None instead of raising
        image = cv2.imread(image_path)
        if image is None:
            raise OSError('Could not read image: {}'.format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, 0)  # 0 -> read as single-channel grayscale
        if mask is None:
            raise OSError('Could not read mask: {}'.format(mask_path))

        # one binary channel per requested class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # for multi-channel masks, append a channel for all unmatched pixels
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Number of samples (one per mask path)."""
        return len(self.ids)
    
    
class Dataloder(keras.utils.Sequence):
    """Load samples from a dataset and group them into batches for Keras.

    Args:
        dataset: object supporting ``len()`` and integer indexing, where each
            item is a tuple of arrays (``Dataset`` yields ``(image, mask)``).
        batch_size: number of samples per batch.
        shuffle: if True, the sample order is re-permuted at construction
            and at the end of every epoch.

    Only full batches are produced: ``len(self)`` uses floor division, so up
    to ``batch_size - 1`` trailing samples are not visited in an epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))
        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch *i* as a list of stacked arrays, e.g. ``[images, masks]``."""
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        # Look samples up through self.indexes so the permutation made in
        # on_epoch_end() actually changes batch contents. Indexing per
        # position (not slicing) keeps the IndexError past the end.
        data = [self.dataset[self.indexes[j]] for j in range(start, stop)]

        # transpose list of samples into one stacked array per field
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

In [None]:
# Lets look at data we have
# img_list = glob.glob(x_train_dir+'/*.PNG')
# img_list.sort()
# mask_list = glob.glob(y_train_dir+'/*.PNG')
# mask_list.sort()
# dataset = Dataset(img_list, mask_list, classes= ['background', 'artifical_crown','tooth','overlap','cavity','cej','gums'])
# print(dataset.class_values)
# print(len(dataset))
# image, mask = dataset[65] # get some sample

# visualize(
#     image=image, 
#     arg_black_space=mask[..., 0].squeeze(),
#     background_mask=mask[..., 1].squeeze(),
#     artifical_crown_mask =mask[..., 2].squeeze(),
#     tooth_mask =mask[..., 3].squeeze(),
#     overlap_mask=mask[..., 4].squeeze(),
#     cavity_mask =mask[..., 5].squeeze(),
#     cej_mask=mask[..., 6].squeeze(),
#     gums_mask =mask[..., 7].squeeze(),
# )
# print(image.shape)

### Augmentations

Data augmentation is a powerful technique to increase the amount of your data and prevent model overfitting.  
If you are not familiar with this technique, read some of these articles:
 - [The Effectiveness of Data Augmentation in Image Classification using Deep
Learning](http://cs231n.stanford.edu/reports/2017/pdfs/300.pdf)
 - [Data Augmentation | How to use Deep Learning when you have Limited Data](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced)
 - [Data Augmentation Experimentation](https://towardsdatascience.com/data-augmentation-experimentation-3e274504f04b)

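Since our dataset is very small, we use a wide range of augmentations. The list below is the general menu. See `get_training_augmentation` for the transforms that are actually enabled; several of them (noise, blurring/sharpening, random crops) are currently commented out there: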
 - horizontal flip
 - affine transforms
 - perspective transforms
 - brightness/contrast/colors manipulations
 - image blurring and sharpening
 - gaussian noise
 - random crops

All these transforms can be easily applied with [**Albumentations**](https://github.com/albu/albumentations/), a fast augmentation library.
For a detailed explanation of image transformations, see the [kaggle salt segmentation example](https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb) provided by the [**Albumentations**](https://github.com/albu/albumentations/) authors.


In [None]:
import albumentations as A

In [None]:
def round_clip_0_1(x, **kwargs):
    """Round values to the nearest integer and bound them to [0, 1].

    Intended for ``A.Lambda(mask=...)``; any extra keyword arguments passed
    by the caller are accepted and ignored.
    """
    snapped = x.round()
    return snapped.clip(0, 1)

# define heavy augmentations
def get_training_augmentation():
    """Build the training-time augmentation pipeline.

    Returns:
        albumentations.Compose applying, in order: horizontal flip (p=0.5),
        vertical flip (p=0.5), shift/scale/rotate (always, border_mode=0),
        one of CLAHE / brightness / gamma (p=0.5), random contrast (p=0.5),
        and finally ``round_clip_0_1`` on the mask only.

    Optional transforms (pad + random crop, Gaussian noise, sharpen/blur)
    are kept below as commented-out toggles.

    NOTE(review): RandomBrightness / RandomContrast match the pinned
    albumentations 0.3.0; later releases deprecate them in favour of
    RandomBrightnessContrast. Confirm before upgrading.
    """
    train_transform = [
        # horizontal flip
        A.HorizontalFlip(p=0.5),
        # vertical flip
        A.VerticalFlip(p=0.5),
        # scale, rotate and shift
        A.ShiftScaleRotate(scale_limit=0.3, rotate_limit=5, shift_limit=0.1, p=1, border_mode=0),
        
#         # set_padding = 1280, set_size = 1024
#         # original image is 1024*1024; pad the outside up to 1280*1280
#         A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
#         # crop a 1024*1024 patch out of the 1280*1280 image
#         A.RandomCrop(height=set_size, width=set_size, always_apply=True),
        
#         # add Gaussian noise with 20% probability
#         A.IAAAdditiveGaussianNoise(p=0.2),
        
        # apply one of CLAHE / random brightness / random gamma
        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightness(p=1),
                A.RandomGamma(p=1),
            ],
            p=0.5,
        ),

#         A.OneOf(
#             [
#                 A.IAASharpen(p=1),
#                 A.Blur(blur_limit=3, p=1),
#                 A.MotionBlur(blur_limit=3, p=1),
#             ],
#             p=0.5,
#         ),

        A.OneOf(
            [
                A.RandomContrast(p=1),
            ],
            p=0.5,
        ),
        # keep mask values binary after the geometric transforms
        A.Lambda(mask=round_clip_0_1)
    ]
    return A.Compose(train_transform)


def get_validation_augmentation():
    """Build the validation pipeline: pad, then crop.

    Pads to ``set_padding`` x ``set_padding`` (albumentations' default
    border mode), then crops a ``set_size`` x ``set_size`` patch.

    NOTE(review): the crop is ``RandomCrop``, so validation samples differ
    between calls. Confirm this is intended; a centre crop would make
    validation metrics reproducible.
    """
    return A.Compose([
        A.PadIfNeeded(set_padding, set_padding),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
    ])

def get_training_enhance_augmentation():
    """Build a heavier geometric training pipeline.

    Applies, in order:
    - horizontal flip (p=0.5);
    - shift/scale/rotate with up to 90 degrees of rotation (always,
      border_mode=0);
    - padding to ``set_padding`` x ``set_padding`` (border_mode=0);
    - a random ``set_size`` x ``set_size`` crop;
    - one of CLAHE / brightness / gamma with probability 0.9.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=90, shift_limit=0.1, p=1, border_mode=0),
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        A.RandomCrop(height=set_size, width=set_size, always_apply=True),
    ]
    photometric = A.OneOf(
        [A.CLAHE(p=1), A.RandomBrightness(p=1), A.RandomGamma(p=1)],
        p=0.9,
    )
    return A.Compose(geometric + [photometric])

def get_fine_tune_augmentation():
    """Build the fine-tuning pipeline.

    Currently this only pads images to ``set_padding`` x ``set_padding``
    (border_mode=0). The other fine-tuning transforms are kept below as
    commented-out toggles.
    """
    fine_tune_transform = [
        A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
        #A.HorizontalFlip(p=0.5),
#         A.OneOf(
#             [
#                 A.CLAHE(p=1),
#                 A.RandomBrightness(p=1),
#                 A.RandomGamma(p=1),
#             ],
#             p=1,
#         ),
#         A.OneOf(
#             [
#                 A.RandomContrast(p=1),
#                 A.HueSaturationValue(p=1),
#             ],
#             p=1,
#         ),
#         A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, shift_limit=0.1, p=1, border_mode=0),
#         A.PadIfNeeded(min_height=set_padding, min_width=set_padding, always_apply=True, border_mode=0),
#         A.RandomCrop(height=set_size, width=set_size, always_apply=True),
#         A.IAAAdditiveGaussianNoise(p=1),
#         A.OneOf(
#             [
#                 A.CLAHE(p=1),
#                 A.RandomBrightness(p=1),
#                 A.RandomGamma(p=1),
#             ],
#             p=0.5,
#         )
    ]
    return A.Compose(fine_tune_transform)

def get_preprocessing(preprocessing_fn):
    """Wrap a normalization function as an image-only albumentations pipeline.

    Args:
        preprocessing_fn (callable): normalization applied to the image only
            (e.g. the function returned by ``sm.get_preprocessing(BACKBONE)``).
            Masks pass through unchanged.

    Returns:
        albumentations.Compose containing a single ``A.Lambda`` step.
    """
    normalize = A.Lambda(image=preprocessing_fn)
    return A.Compose([normalize])

# Segmentation model training

In [None]:
import segmentation_models as sm

# segmentation_models could also use `tf.keras` if you do not have Keras installed
# or you could switch to other framework using `sm.set_framework('tf.keras')`

In [None]:
def intit_Unet_model():
    """Build and compile the multiclass U-Net used in this notebook.

    Returns:
        A compiled ``segmentation_models`` U-Net with:
        - an EfficientNet-B3 backbone;
        - Adam (lr=1e-4);
        - a weighted Dice loss plus a focal loss;
        - IoU and F-score metrics at threshold 0.5.

    Raises:
        ValueError: if the Dice class-weight vector does not have exactly one
            entry per output channel.
    """
    BACKBONE = 'efficientnetb3'
    CLASSES = ['background', 'artifical_crown','tooth','overlap','cavity','cej','gums']
    #['background','bacmixgums', 'artifical_crown','tooth','gums','overlap','cavity','cej']
    LR = 0.0001  # default LR = 0.0001

    # Binary -> one sigmoid channel. Multiclass -> one softmax channel per
    # class plus one extra, matching the extra channel Dataset.__getitem__
    # appends to multi-channel masks.
    n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)
    activation = 'sigmoid' if n_classes == 1 else 'softmax'

    # Per-output-channel Dice weights; fail fast if they drift out of sync
    # with the channel count instead of erroring deep inside the loss.
    class_weights = np.array([0.5, 1, 1.5, 3, 2, 2, 3, 3])
    if class_weights.shape[0] != n_classes:
        raise ValueError(
            'DiceLoss class_weights has {} entries but the model has {} output channels'.format(
                class_weights.shape[0], n_classes))

    model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
    optim = keras.optimizers.Adam(LR)

    # segmentation_models losses can be summed and scaled by a number
    dice_loss = sm.losses.DiceLoss(class_weights=class_weights)
    focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)

    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

    model.compile(optim, total_loss, metrics)
    return model

In [None]:
# Global settings for the cells below (CLASSES and preprocess_input feed the
# test Dataset). These mirror values hard-coded inside intit_Unet_model.
BACKBONE = 'efficientnetb3'
BATCH_SIZE = 2 # (note: change to 1)
CLASSES = ['background', 'artifical_crown','tooth','overlap','cavity','cej','gums']
#['background','bacmixgums', 'artifical_crown','tooth','gums','overlap','cavity','cej']
LR = 0.0001 
# default LR = 0.0001
EPOCHS = 100

# Backbone-specific input normalization function.
preprocess_input = sm.get_preprocessing(BACKBONE)

# 1 channel for binary segmentation, otherwise one per class plus one extra.
n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)

In [None]:
# Sort both lists so the i-th image lines up with the i-th mask: glob.glob
# returns entries in arbitrary, filesystem-dependent order.
# Assumes images and labels share file names — TODO confirm.
img_list_train = sorted(glob.glob('./1002_7_classes_data/train/image/*.PNG'))
mask_list_train = sorted(glob.glob('./1002_7_classes_data/train/label/*.PNG'))

In [None]:
def predict_to_mask(pr_mask):
    """Convert per-pixel class scores into a colour-coded label image.

    Args:
        pr_mask: array indexed as ``pr_mask[row, col, channel]`` (the caller
            passes a squeezed model prediction). Channel 0 is never selected:
            each pixel takes the argmax over channels 1.., and that channel
            index is looked up in the palette below.

    Returns:
        numpy.ndarray of dtype uint8 and shape ``(rows, cols, 3)``.

    NOTE(review): the caller saves this with ``cv2.imwrite``, which treats
    the last axis as BGR. For example, ``gums`` = (255, 0, 0) is stored as
    blue. Confirm this matches the label colour convention.
    """
    # Palette indexed by output channel. Index 0 (gray) is unreachable
    # because channel 0 is excluded from the argmax.
    palette = np.array(
        [
            [125, 125, 125],  # gray
            [0, 0, 0],        # background
            [255, 255, 0],    # artifical_crown
            [0, 255, 255],    # tooth
            [255, 0, 255],    # overlap
            [0, 0, 255],      # cavity
            [0, 255, 0],      # cej
            [255, 0, 0],      # gums
        ],
        dtype=np.uint8,
    )
    # Vectorized replacement for the per-pixel Python loop. np.argmax breaks
    # ties toward the lowest index, just like the per-pixel argmax did.
    labels = np.argmax(pr_mask[..., 1:], axis=-1) + 1
    return palette[labels]

def show_pixel_set(img_nparray):
    """Count how often each distinct value occurs in an array.

    Returns:
        dict mapping each unique value (ascending, as produced by
        ``np.unique``) to its number of occurrences.
    """
    values, occurrences = np.unique(img_nparray, return_counts=True)
    return {value: count for value, count in zip(values, occurrences)}

In [None]:
def ensemble_image(img, model_list, flip_=True):
    """Average predictions over several models, optionally with flip TTA.

    Args:
        img: batched input passed directly to ``model.predict``. Flips are
            applied along axes 1 (rows) and 2 (columns), so this assumes an
            NHWC layout such as ``(1, H, W, 3)`` — TODO confirm for any other
            callers.
        model_list: non-empty sequence of objects with a ``predict`` method.
        flip_: when True, each model also predicts on the horizontally,
            vertically and doubly flipped input. Those outputs are flipped
            back and averaged with the plain prediction (4 passes per model).

    Returns:
        numpy.ndarray: the mean prediction with one extra leading axis, i.e.
        ``np.expand_dims(mean, 0)``. The caller ``np.squeeze``s it. The
        channel count follows whatever the models output.

    Raises:
        ValueError: if ``model_list`` is empty.
    """
    if len(model_list) == 0:
        raise ValueError('ensemble_image needs at least one model')

    # Axes to flip on an NHWC batch: horizontal, vertical, both.
    flip_axes = [(2,), (1,), (1, 2)] if flip_ else []

    total = None
    for model in model_list:
        member = np.asarray(model.predict(img), dtype='float32')
        for axes in flip_axes:
            flipped_pred = model.predict(np.flip(img, axis=axes))
            # undo the flip so the prediction aligns with the original image
            member = member + np.flip(flipped_pred, axis=axes)
        member = member / (1 + len(flip_axes))
        total = member if total is None else total + member

    mean_prediction = total / len(model_list)
    # Extra leading axis kept so the return shape matches what callers
    # already squeeze away.
    return np.expand_dims(mean_prediction, axis=0)

In [None]:
# NOTE(review): Set_epoch is not referenced by any visible cell.
Set_epoch = 100
# Build and compile the U-Net; trained weights are loaded later via model.load_weights.
model = intit_Unet_model()

# Generate mask 

In [None]:
# TODO: change weight path if you want
weight_path = './weight/0801_7_classes_2_train_ten_classes_100.h5'
show = False  # display each colour-coded prediction with matplotlib
ensemble=True  # predict with flip test-time augmentation via ensemble_image

# Images to segment; output files reuse each image's base name.
img_list = glob.glob('../data/resize_data/*.PNG')
# TODO: change image folder

print(len(img_list))

In [None]:

# Inference: predict a colour-coded segmentation for every image in img_list
# and save it under all_classes_save_folder with the same file name.

# No ground-truth masks exist here, so the image paths are passed again as
# placeholder masks; the resulting mask (gt_mask) is not used below.
test_dataset = Dataset(
    img_list,
    img_list,
    classes=CLASSES,
    augmentation=None,
    preprocessing=get_preprocessing(preprocess_input),
)

test_dataloader = Dataloder(test_dataset, batch_size=1, shuffle=False)


model.load_weights(weight_path)

all_classes_save_folder = './result/raw_seg'
# Create the output folder if needed instead of asserting it exists
# (assert statements are skipped entirely under `python -O`).
os.makedirs(all_classes_save_folder, exist_ok=True)

print('Get',len(test_dataset), 'data for testing')


for i, path in tqdm(enumerate(img_list), total=len(img_list)):
    name = os.path.basename(path)
    image, gt_mask = test_dataset[i]
    image = np.expand_dims(image, axis=0)  # add the batch axis for predict
    if ensemble:
        pr_mask = ensemble_image(image,[model], flip_=True)
    else:
        pr_mask = model.predict(image)
    image=denormalize(image.squeeze())
    pr_mask = np.squeeze(pr_mask)
    result = predict_to_mask(pr_mask)
    if show:
        plt.imshow(result)
        plt.show()
    out_path = os.path.join(all_classes_save_folder, name)
    # cv2.imwrite reports failure via its return value, not an exception;
    # warn and keep going so one bad file does not abort the whole batch.
    if not cv2.imwrite(out_path, result):
        print('WARNING: failed to write', out_path)