<a href="https://colab.research.google.com/github/vagmin27/DeepLearning/blob/main/ImageSegementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Library
https://github.com/qubvel/segmentation_models.pytorch

They have provided an example. For their example, they have used **CamVid** dataset. It is a set of:
 - **train** images + segmentation masks
 - **validation** images + segmentation masks
 - **test** images + segmentation masks

All images have 320 pixels height and 480 pixels width.
For more information about the dataset, visit http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/.

In [None]:
# Install required libs
!pip install -U segmentation-models-pytorch albumentations --user



# Loading Data

Here, we will use **GroZi Gap** dataset. It is a set of:
 - **train** images + segmentation masks
 - **test** images + segmentation masks

We will use a subset of the dataset in this example.

For more information about the dataset, visit (https://github.com/gapDetection/gapDetectionDatasets).

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
# Name of the segmentation architecture to build. Alternatives that were
# listed here: 'Unet', 'Linknet', 'FPN', 'PSPNet', 'DeepLabV3',
# 'DeepLabV3Plus', 'PAN', 'UnetPlusPlus'.
modelName = 'MAnet'

# Dataset name and fold identifier; both are used to build the data path.
DATASET_NAME = 'GroZi'
fold = 'fold_6'

# Root folder of the selected fold (keeps the trailing slash).
DATA_DIR = ''.join(['./FoldData/', DATASET_NAME, '/Fold_1/', fold, '/'])

In [None]:
# Image and mask folders of each split, all located directly under DATA_DIR.
x_train_dir, y_train_dir = (
    os.path.join(DATA_DIR, 'trainImages'),
    os.path.join(DATA_DIR, 'trainMasks'),
)

# NOTE(review): the validation split resolves to the same folders as the
# training split, so validation scores are measured on training data --
# confirm this is intended.
x_valid_dir, y_valid_dir = (
    os.path.join(DATA_DIR, 'trainImages'),
    os.path.join(DATA_DIR, 'trainMasks'),
)

x_test_dir, y_test_dir = (
    os.path.join(DATA_DIR, 'testImages'),
    os.path.join(DATA_DIR, 'testMasks'),
)

In [None]:
# Plotting helper used to inspect images and masks.
def visualize(**images):
    """Show every keyword-argument image side by side in one row.

    Each keyword name becomes the panel title, with underscores turned into
    spaces and the result title-cased (e.g. ``ground_truth_mask`` ->
    ``Ground Truth Mask``).
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()

### Dataloader

Writing a helper class for data extraction, transformation and preprocessing  
https://pytorch.org/docs/stable/data

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Gap-detection dataset: read image/mask pairs and apply transforms.

    Every image and mask is resized to ``IMAGE_SIZE``. Any non-zero mask pixel
    is mapped to class index 1 ('gap') and zero pixels stay at index 0
    ('nongap'). The returned mask holds one channel per requested class.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; each mask is
            looked up under the same file name as its image
        classes (list): names from ``CLASSES`` to extract from the
            segmentation mask (case-insensitive); ``None`` selects all of them
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        ValueError: if a name in ``classes`` is not in ``CLASSES``.
        FileNotFoundError: from ``__getitem__`` when an image or mask file
            cannot be read.

    """

    CLASSES = ['nongap', 'gap']

    # (width, height) passed to cv2.resize for both images and masks.
    IMAGE_SIZE = (480, 320)

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible between runs and machines.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # The documented default (None) previously crashed; treat it as
        # "extract every known class".
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read(path, flag):
        """Load ``path`` with cv2.imread, raising instead of returning None."""
        data = cv2.imread(path, flag)
        if data is None:
            raise FileNotFoundError('Could not read image file: {}'.format(path))
        return data

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i`` after all transforms."""
        # read the image as 3-channel BGR, resize, then convert to RGB
        image = self._read(self.images_fps[i], cv2.IMREAD_COLOR)
        image = cv2.resize(image, self.IMAGE_SIZE,
                           interpolation=cv2.INTER_NEAREST)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # read the mask as single-channel grayscale; nearest-neighbour
        # resizing avoids inventing intermediate label values
        mask = self._read(self.masks_fps[i], cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, self.IMAGE_SIZE,
                          interpolation=cv2.INTER_NEAREST)
        # binarize: every non-zero pixel becomes class index 1
        mask[mask > 0] = 1

        # one boolean plane per requested class, stacked on the last axis
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.ids)

In [None]:
# Lets look at data we have

dataset = Dataset(x_train_dir, y_train_dir, classes=['gap'])

image, mask = dataset[4]  # get some sample
visualize(
    image=image,
    # The keyword name becomes the panel title; the mask shows the 'gap'
    # class, so label it accordingly. squeeze() drops the single class axis.
    gap_mask=mask.squeeze(),
)

Data augmentation is a powerful technique to increase the amount of your data and prevent model overfitting.  
If you are not familiar with this trick, read some of these articles:
 - [The Effectiveness of Data Augmentation in Image Classification using Deep
Learning](http://cs231n.stanford.edu/reports/2017/pdfs/300.pdf)
 - [Data Augmentation | How to use Deep Learning when you have Limited Data](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced)
 - [Data Augmentation Experimentation](https://towardsdatascience.com/data-augmentation-experimentation-3e274504f04b)

Since our dataset is very small we will apply a large number of different augmentations:
 - horizontal flip
 - affine transforms
 - perspective transforms
 - brightness/contrast/colors manipulations
 - image blurring and sharpening
 - gaussian noise
 - random crops

All these transforms can be easily applied with [**Albumentations**](https://github.com/albu/albumentations/) - a fast augmentation library.
For a detailed explanation of image transformations you can look at the [kaggle salt segmentation example](https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb) provided by the [**Albumentations**](https://github.com/albu/albumentations/) authors.

In [None]:
import albumentations as albu

In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline applied to training samples.

    The pipeline flips, shifts/scales, pads and crops to 320x320, then
    randomly applies noise, perspective, brightness/gamma/CLAHE,
    sharpen/blur and contrast/hue changes.

    The imgaug-backed ``IAA*`` transforms, ``RandomBrightness``,
    ``RandomContrast`` and the ``always_apply`` argument are replaced with
    ``GaussNoise``, ``Perspective``, ``Sharpen``, ``RandomBrightnessContrast``
    and ``p=1``, because the former are deprecated or removed in newer
    albumentations releases.

    Returns:
        albumentations.Compose: the composed training transform
    """
    train_transform = [

        albu.HorizontalFlip(p=0.5),

        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        # p=1 replaces the deprecated always_apply=True
        albu.PadIfNeeded(min_height=320, min_width=320, p=1, border_mode=0),
        albu.RandomCrop(height=320, width=320, p=1),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # brightness-only jitter (contrast disabled)
                albu.RandomBrightnessContrast(contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                # contrast-only jitter (brightness disabled)
                albu.RandomBrightnessContrast(brightness_limit=0, p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Pad samples to at least 384x480 (height x width).

    Both target sides are multiples of 32.
    """
    padding = albu.PadIfNeeded(min_height=384, min_width=480)
    return albu.Compose([padding])


def to_tensor(x, **kwargs):
    """Move the last axis of ``x`` to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing transform applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function, applied to
            the image only (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose that first normalizes the image and
            then runs ``to_tensor`` on both image and mask

    """
    normalize = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, channels_first])

In [None]:
# #### Visualize resulted augmented images and masks
# augmented_dataset = Dataset(
#     x_train_dir,
#     y_train_dir,
#     augmentation=get_training_augmentation(),
#     classes=['gap'],
# )

# # same image with different random transforms
# for i in range(10):
#     image, mask = augmented_dataset[1]
#     visualize(image=image, mask=mask.squeeze(-1))

## Create model and train

In [26]:
import sys
!{sys.executable} -m pip install segmentation-models-pytorch




In [27]:
import torch
import numpy as np
import segmentation_models_pytorch as smp
print("SMP version:", smp.__version__)


In [None]:
ENCODER = 'se_resnext101_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['gap']
ACTIVATION = 'sigmoid'
DEVICE = 'cuda'

# Architectures this notebook can build. Each name is also the name of the
# corresponding class in segmentation_models_pytorch.
SUPPORTED_MODELS = (
    'Unet',
    'Linknet',
    'FPN',
    'PSPNet',
    'DeepLabV3',
    'DeepLabV3Plus',
    'PAN',
    'UnetPlusPlus',
    'MAnet',
)

# Compare by value (the previous `is` checks compared object identity) and
# fail loudly on an unknown name instead of leaving `model` undefined.
if modelName not in SUPPORTED_MODELS:
    raise ValueError(
        'Unknown modelName {!r}; expected one of {}'.format(modelName, SUPPORTED_MODELS)
    )

# create segmentation model with pretrained encoder
model = getattr(smp, modelName)(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Normalization function matching the chosen encoder and its pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Training samples: random augmentation followed by encoder preprocessing.
train_dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Validation samples: padding only, followed by the same preprocessing.
valid_dataset = Dataset(
    images_dir=x_valid_dir,
    masks_dir=y_valid_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Batches of 8 loaded by 4 workers; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=8, shuffle=False, num_workers=4)

In [None]:
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

# Dice loss is optimized; IoU thresholded at 0.5 is the reported metric.
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam with a single parameter group covering every model parameter.
optimizer = torch.optim.Adam(
    [{'params': model.parameters(), 'lr': 0.0001}]
)

In [None]:
# Epoch runners: each one iterates once over the samples of a dataloader.
# The training runner also receives the optimizer; the validation runner
# does not.
train_epoch = smp.utils.train.TrainEpoch(
    model=model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model=model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Train for NUM_EPOCHS epochs and keep the checkpoint with the best
# validation IoU.
NUM_EPOCHS = 1       # number of passes over the training loader
LR_DECAY_EPOCH = 25  # epoch index at which the learning rate drops to 1e-5;
                     # only reached when NUM_EPOCHS > 25
CHECKPOINT_PATH = './best_model_' + modelName + '_' + DATASET_NAME + '.pth'

max_score = 0

for i in range(NUM_EPOCHS):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Save the whole model object whenever the validation IoU improves.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, CHECKPOINT_PATH)
        print('Model saved!')

    # The optimizer has a single parameter group, so this lowers the
    # learning rate of the entire model.
    if i == LR_DECAY_EPOCH:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')

## Test best saved model

In [None]:
# Load the checkpoint stored under the model/dataset-specific file name.
# NOTE(review): the training loop saves the whole model object with
# torch.save(model, ...). Recent PyTorch releases reportedly default
# torch.load to weights_only=True, which may reject such files -- confirm
# against the installed version.
best_model = torch.load(''.join(['./best_model_', modelName, '_', DATASET_NAME, '.pth']))

In [None]:
import time

# Wall-clock reference taken before the test dataset is built.
start = time.time()

# Test samples: padding only, followed by encoder preprocessing.
test_dataset = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# DataLoader defaults are used (no batch size or shuffle argument is passed).
test_dataloader = DataLoader(dataset=test_dataset)

In [None]:
# evaluate model on test set
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)
print('Result of ' + modelName)
logs = test_epoch.run(test_dataloader)

# Stop the timer only after the evaluation pass has finished; previously
# `end` was taken before `run`, so the printed duration excluded the
# evaluation itself. `start` is set where the test dataset is created.
end = time.time()
print(end - start)

## Visualize predictions

In [None]:
# Test samples with no augmentation and no preprocessing, kept for
# displaying the images.
test_dataset_vis = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    classes=CLASSES,
)

In [30]:
# for i in range(1):
#     n = np.random.choice(len(test_dataset))

#     fileName = test_dataset.ids[n]
#     print(fileName)
#     image_vis = test_dataset_vis[n][0].astype('uint8')
#     image, gt_mask = test_dataset[n]

#     gt_mask = gt_mask.squeeze()

#     x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
#     pr_mask = best_model.predict(x_tensor)
#     pr_mask = (pr_mask.squeeze().cpu().numpy().round())

#     visualize(
#         image=image_vis,
#         ground_truth_mask=gt_mask,
#         predicted_mask=pr_mask
#     )

In [28]:
# #### SAVING BINARY MASK IN STORAGE DEVICE ######

# from PIL import Image
# from torchvision.utils import save_image
# import matplotlib

# for i in range(len(test_dataset)):
# # for i in range(1):
#     fileName = test_dataset.ids[i]
#     print(fileName)
# #     fileName = fileName[:-4]
# #     print(fileName)
# #     fileName = fileName + '.bmp'

#     image_vis = test_dataset_vis[i][0].astype('uint8')
#     image, gt_mask = test_dataset[i]

#     gt_mask = gt_mask.squeeze()

#     x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
#     pr_mask_Ori = best_model.predict(x_tensor)
#     pr_mask_Ori = pr_mask_Ori.squeeze().round()
#     pr_mask_bin = pr_mask_Ori.cpu().numpy()

# #     matplotlib.image.imsave(DATA_DIR+'predictedMask'+'/'+fileName, pr_mask_bin)

#     if not os.path.exists(DATA_DIR+'predictedMask_'+ modelName):
#         os.mkdir(DATA_DIR+'predictedMask_'+ modelName)

#     plt.imsave(DATA_DIR+'predictedMask_'+modelName+'/'+fileName,pr_mask_bin)