In [None]:
# !pip install keras==2.2.4

In [None]:
!pip install git+https://github.com/qubvel/segmentation_models

In [None]:
import os
import warnings
import numpy as np
import random
import matplotlib.pyplot as plt
import cv2
import keras

In [None]:
# print(keras.__version__)

In [None]:
from model import Deeplabv3

# Silence Python warnings for the rest of the session so that they do not
# flood the notebook output.
warnings.filterwarnings('ignore')

# Network configuration. NOTE(review): the shape is presumably
# (height, width, channels) -- confirm against the Deeplabv3 implementation.
INPUT_SHAPE = (768, 1152, 3)
NUM_CLASSES = 3

deeplab_model = Deeplabv3(input_shape=INPUT_SHAPE, classes=NUM_CLASSES)

In [None]:
"""
Directory structure
-home
    -username
        -deeplabv3
            -dataset
                -train
                    -images
                    -masks
                -test
                    -images
                    -masks
            -deeplabv3.ipynb
"""
# Root folder of the dataset, i.e. the folder containing `train/` and `test/`.
DATA_PATH = '/path/to/the/dataset/'

# Build the sub-folder paths with os.path.join instead of string
# concatenation: DATA_PATH already ends with a separator, so appending
# '/train/...' produced a doubled '//' in the resulting paths.
FRAME_PATH = os.path.join(DATA_PATH, 'train', 'images')
MASK_PATH = os.path.join(DATA_PATH, 'train', 'masks')

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Args:
        **images: one keyword argument per image. The keyword is used as the
            subplot title (underscores become spaces, words are title-cased);
            the value is handed to ``plt.imshow``.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, total, position)
        # Hide the axis ticks: only the picture itself is of interest.
        plt.xticks([])
        plt.yticks([])
        caption = key.replace('_', ' ').title()
        plt.title(caption)
        plt.imshow(images[key])
    plt.show()
    
# helper function for data visualization    
def denormalize(x):
    """Scale an image to the range 0..1 so that it plots correctly.

    The 2nd and 98th percentiles of ``x`` are mapped to 0 and 1 respectively;
    values outside that interval are clipped.

    Args:
        x: numeric array (any shape) holding the image data.

    Returns:
        Array of the same shape as ``x`` with values in [0, 1]. When both
        percentiles coincide (e.g. a constant image) an all-zero array is
        returned instead of dividing by zero and producing NaN/inf.
    """
    x_min, x_max = np.percentile(x, [2, 98])
    span = x_max - x_min
    if span == 0:
        # No contrast to stretch: avoid 0/0 and return a valid (black) image.
        return np.zeros(np.shape(x), dtype=float)
    return ((x - x_min) / span).clip(0, 1)
    

# classes for data loading and preprocessing
class Dataset:
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.

    Masks are read as single-channel images whose pixel values are class
    indices, i.e. positions in ``CLASSES``. Each mask file must have the same
    file name as its image.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        classes (list): names of the classes (case-insensitive, taken from
            ``CLASSES``) to extract from the segmentation mask. ``None``
            selects every class in ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.); called as
            ``augmentation(image=image, mask=mask)``.
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.); called as
            ``preprocessing(image=image, mask=mask)``.

    Raises:
        ValueError: if ``classes`` contains a name that is not in ``CLASSES``.
    """

    CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
               'tree', 'signsymbol', 'fence', 'car',
               'pedestrian', 'bicyclist', 'unlabelled', 'background']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort them so that
        # sample indices are reproducible between runs and machines.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # The default used to crash with a TypeError; treat it as "all classes".
        if classes is None:
            classes = self.CLASSES

        unknown = [name for name in classes if name.lower() not in self.CLASSES]
        if unknown:
            raise ValueError(
                "Unknown class name(s) {}; expected names from {}".format(unknown, self.CLASSES)
            )

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(name.lower()) for name in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair for sample ``i``.

        The mask has one channel per requested class plus, when more than one
        class was requested, a trailing background channel.

        Raises:
            OSError: if the image or the mask file cannot be read.
        """
        # read data; cv2.imread signals failure by returning None, not by raising
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise OSError("Could not read image file: {}".format(self.images_fps[i]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # Without this check `mask == v` would silently yield a wrong mask.
            raise OSError("Could not read mask file: {}".format(self.masks_fps[i]))

        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples, i.e. files found in ``images_dir``."""
        return len(self.ids)
    
    
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        # Shuffle once up front so that the first epoch is randomised as well.
        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch number ``i`` as a list of stacked arrays.

        Each sample returned by the dataset is a tuple (image, mask); the
        batch is the list ``[images, masks]`` with a new leading batch axis.
        """
        # collect batch data
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = []
        for j in range(start, stop):
            # Look the sample up through self.indexes: indexing the dataset
            # with j directly ignored the permutation, so shuffle=True had
            # no effect on the order of the samples.
            data.append(self.dataset[self.indexes[j]])

        # transpose list of lists
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        # Floor division: a trailing incomplete batch is dropped.
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)
    
#     def __next__(self):
#         if self.n >= self.max:
#             self.n = 0
#         result = self.__getitem__(self.n)
#         self.n += 1
#         return result

In [None]:
# Let's look at the data we have
dataset = Dataset(FRAME_PATH, MASK_PATH, classes=['sky', 'road'])

# Draw a random *valid* index. The previous hard-coded range 0..100 raised
# IndexError whenever the folder held fewer than 101 images and never
# reached samples beyond index 100.
image, mask = dataset[random.randrange(len(dataset))]  # get some sample
visualize(
    image=image,
    masks=mask[..., 0],  # channel 0 is the first requested class ('sky')
)

In [None]:
# Training samples come from the train/ folders configured above.
train_dataset = Dataset(images_dir=FRAME_PATH, masks_dir=MASK_PATH, classes=['sky', 'road'])

# Validation samples are read from the test/ folders of the same dataset root.
val_dataset = Dataset(
    images_dir=DATA_PATH+'/test/images/',
    masks_dir=DATA_PATH+'/test/masks/',
    classes=['sky', 'road'],
)

# Batches of two, reshuffled, for training; single samples in fixed order
# for validation.
train_dataloader = Dataloder(dataset=train_dataset, batch_size=2, shuffle=True)
val_dataloader = Dataloder(dataset=val_dataset, batch_size=1, shuffle=False)

In [None]:
import segmentation_models as sm

# segmentation_models loss objects can be added with '+' and scaled by an
# int or float factor, so the combined loss is written as a weighted sum.
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Actually, a combined loss can also be imported directly from the library;
# the lines above only show how losses can be manipulated:
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss

# Both metrics are constructed with threshold=0.5.
metrics = [
    metric_cls(threshold=0.5)
    for metric_cls in (sm.metrics.IOUScore, sm.metrics.FScore)
]

# compile the keras model with the chosen optimizer, loss and metrics
deeplab_model.compile(
    optimizer='adam',
    loss=total_loss,
    metrics=metrics,
)

In [None]:
# After the first run, previously saved weights can be restored by
# uncommenting the next line:
# deeplab_model.load_weights('./deeplab_v3_plus.h5')

# One pass over every batch of the training loader per epoch.
# NOTE(review): despite its name, `fitted_model` presumably holds the Keras
# training history object rather than a model -- confirm.
fitted_model = deeplab_model.fit_generator(
    train_dataloader,
    steps_per_epoch=len(train_dataloader),
    epochs=500,
)

# Persist the trained weights for later runs.
deeplab_model.save_weights('./deeplab_v3_plus.h5')

In [None]:
scores = deeplab_model.evaluate_generator(val_dataloader)

# The first entry is printed as the loss; the remaining entries are paired,
# in order, with the metrics the model was compiled with.
loss_value, metric_values = scores[0], scores[1:]
print("Loss: {:.5}".format(loss_value))
for metric, value in zip(metrics, metric_values):
    print("mean {}: {:.5}".format(metric.name, value))

In [None]:
n = 3
# Sample without replacement so the same validation image is never shown
# twice, and cap the sample size at the dataset size so that small
# validation sets still work.
ids = np.random.choice(len(val_dataset), size=min(n, len(val_dataset)), replace=False)

for i in ids:
    image, gt_mask = val_dataset[i]
    # The model is fed a batch, so add a leading batch axis of size 1.
    image = np.expand_dims(image, axis=0)
    # Round the predicted values to the nearest integer to get a hard mask.
    pr_mask = deeplab_model.predict(image).round()

    # Show the input next to channel 0 of the ground-truth and predicted masks.
    visualize(
        image=denormalize(image.squeeze()),
        gt_mask=gt_mask[..., 0].squeeze(),
        pr_mask=pr_mask[..., 0].squeeze(),
    )