In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from util import *
from tensorflow.keras.optimizers import Adam

# ---- Run configuration ----
bs = 32
image_size = (224, 224)

load_pretrained_weights = True
better_model = False

# Loss / metric helpers provided by `util`.
# NOTE(review): the visible compile() call does not reference these two names.
metrics = {'pred_mask' : [Jaccard, sparse_accuracy_ignoring_last_label]}
losses = sparse_crossentropy_ignoring_last_label

backbone = 'mobilenetv2' #mobilenetv2, xception

PATH = '../dataset_preprocess/VOCdevkit/VOC2012/'
NET = 'deeplab_{}'.format(backbone)

# One entry of the class table is excluded from the class count.
n_classes = len(get_VOC2012_classes()) - 1

# ---- Echo the configuration so it is recorded in the cell output ----
print('Num workers: {}'.format(workers))
print('Backbone: {}'.format(backbone))
print('Path to dataset: {}'.format(PATH))
print('N classes: {}'.format(n_classes))
print('Image size: {}'.format(image_size))
print('Batch size: {}'.format(bs))


# ---- Segmentation helper object used by the rest of the notebook ----
SegClass = SegModel(PATH, image_size)
SegClass.set_batch_size(bs)

In [None]:
# Build the segmentation network. The two variants differ only in the `net`
# argument, so pick it once instead of duplicating the whole call.
net_variant = 'subpixel' if better_model else 'original'
model = SegClass.create_seg_model(net=net_variant, n=n_classes, load_weights=load_pretrained_weights,
                                  multi_gpu=False, backbone=backbone)

# `learning_rate` is the supported argument name; `lr` is only a deprecated alias.
# NOTE(review): `decay` is presumably only accepted by the legacy Keras optimizers --
# confirm against the installed TensorFlow version.
# The masks are one-hot encoded (see rgb_to_onehot), hence plain categorical cross-entropy.
model.compile(optimizer = Adam(learning_rate=7e-4, epsilon=1e-8, decay=1e-6),
              loss = "categorical_crossentropy", metrics = ["accuracy"])
print('Weights path:', SegClass.modelpath)

In [None]:
def parse_code(l):
    '''Parse one line of a tab-separated label file.

    Accepted layouts are "<codes>\\t<name>" and "<codes>\\t<other>\\t<name>", where
    <codes> is a whitespace-separated list of integers.

    Inputs:
        l - one raw line of text (surrounding whitespace is stripped)
    Output: (codes, name) where codes is a tuple of ints and name is the last field
    Raises: ValueError if the line does not have 2 or 3 tab-separated fields, or if a
            code is not an integer
    '''
    fields = l.strip().split("\t")
    if len(fields) not in (2, 3):
        raise ValueError("expected 2 or 3 tab-separated fields, got {}: {!r}".format(len(fields), l))
    # split() without an argument tolerates repeated spaces between the colour components
    codes = tuple(int(i) for i in fields[0].split())
    return codes, fields[-1]

In [None]:
# Read the colour/name table. The context manager closes the file handle, and blank
# lines (e.g. a trailing newline) are skipped instead of crashing the parser.
with open("../dataset_preprocess/data/label_colors.txt") as label_file:
    parsed_labels = [parse_code(l) for l in label_file if l.strip()]

label_codes, label_names = zip(*parsed_labels)
label_codes, label_names = list(label_codes), list(label_names)
label_codes[:5], label_names[:5]

In [None]:
# Two-way lookup between a label's position in the file and its colour tuple.
id2code = dict(enumerate(label_codes))
code2id = {code: idx for idx, code in id2code.items()}

In [None]:
# Two-way lookup between a label's position in the file and its name.
id2name = dict(enumerate(label_names))
name2id = {name: idx for idx, name in id2name.items()}

In [None]:
# Display the id -> colour tuple mapping as the cell output.
id2code

In [None]:
def rgb_to_onehot(rgb_image, colormap = None):
    '''Function to one hot encode colour-coded mask labels
        Inputs:
            rgb_image - image matrix (eg. 256 x 256 x 3 dimension numpy ndarray)
            colormap - dictionary of label id to colour tuple; when omitted, the
                       notebook-level id2code is looked up at call time
        Output: One hot encoded int8 image of dimensions (height x width x num_classes)
                where num_classes = len(colormap). Channel i corresponds to the i-th
                entry of colormap in iteration order; pixels matching no colour stay
                all-zero.
    '''
    # Resolve the default lazily so that re-running the cell that builds id2code
    # is picked up without having to re-define this function.
    if colormap is None:
        colormap = id2code
    num_classes = len(colormap)
    shape = rgb_image.shape[:2]+(num_classes,)
    encoded_image = np.zeros( shape, dtype=np.int8 )
    # Iterate over the stored colours themselves rather than indexing the dict with the
    # loop counter, which only worked when the keys happened to be exactly 0..n-1.
    for i, color in enumerate(colormap.values()):
        encoded_image[:,:,i] = np.all(rgb_image == color, axis=-1)
    return encoded_image


def onehot_to_rgb(onehot, colormap = None):
    '''Function to decode encoded mask labels
        Inputs:
            onehot - one hot encoded image matrix (height x width x num_classes)
            colormap - dictionary of label id to colour tuple; when omitted, the
                       notebook-level id2code is looked up at call time
        Output: Decoded uint8 colour image (height x width x 3). Pixels whose winning
                channel index is not a key of colormap stay (0, 0, 0).
    '''
    # Resolve the default lazily so a rebuilt id2code is used without re-defining this function.
    if colormap is None:
        colormap = id2code
    single_layer = np.argmax(onehot, axis=-1)
    # Allocate the result as uint8 directly instead of filling a float64 buffer and casting.
    output = np.zeros( onehot.shape[:2]+(3,), dtype=np.uint8 )
    for k, color in colormap.items():
        output[single_layer==k] = color
    return output

In [None]:
path = PATH + 'data'
import cv2

# Load every training frame together with the mask of the same file name, resize both
# to `image_size` and one-hot encode the mask.
# NOTE(review): cv2.imread returns channels in BGR order -- confirm that the colours in
# label_colors.txt are stored in the same order as the masks are read here.
train_frames_dir = os.path.join(path, 'train_frames', 'train')
train_masks_dir = os.path.join(path, 'train_masks', 'train')

train_images = []
train_masks = []

i=0
# sorted(): os.listdir returns names in arbitrary order; sorting makes runs reproducible
for image_name in sorted(os.listdir(train_frames_dir)):
    if(i%20==0):
        print(i)

    img = cv2.imread(os.path.join(train_frames_dir, image_name))
    mask = cv2.imread(os.path.join(train_masks_dir, image_name))

    # imread returns None for anything it cannot decode (or that does not exist);
    # skip the pair unless both the frame and its mask were read.
    if img is None or mask is None:
        continue

    img = cv2.resize(img, image_size)
    # nearest-neighbour keeps the mask colours exact (no blended colours at class borders)
    mask = cv2.resize(mask, image_size, interpolation = cv2.INTER_NEAREST)

    train_images.append(img)
    train_masks.append(rgb_to_onehot(mask))

    i += 1

train_images = np.array(train_images)
train_masks = np.array(train_masks)

In [None]:
# Load at most 200 validation frames together with the mask of the same file name,
# resize both to `image_size` and one-hot encode the mask.
val_frames_dir = os.path.join(path, 'val_frames', 'val')
val_masks_dir = os.path.join(path, 'val_masks', 'val')

val_images = []
val_masks = []

i=0
# sorted(): os.listdir order is arbitrary, which would make the 200-image subset vary between runs
for image_name in sorted(os.listdir(val_frames_dir)):
    if(i%20==0):
        print(i)
    if(i==200):
        break

    img = cv2.imread(os.path.join(val_frames_dir, image_name))
    mask = cv2.imread(os.path.join(val_masks_dir, image_name))

    # imread returns None for anything it cannot decode (or that does not exist);
    # skip the pair unless both the frame and its mask were read.
    if img is None or mask is None:
        continue

    img = cv2.resize(img, image_size)
    # nearest-neighbour keeps the mask colours exact (no blended colours at class borders)
    mask = cv2.resize(mask, image_size, interpolation = cv2.INTER_NEAREST)

    val_images.append(img)
    val_masks.append(rgb_to_onehot(mask))

    i += 1

# Convert the *validation* lists; the previous version converted train_images/train_masks
# here, so validation was silently run on the training set.
val_images = np.array(val_images)
val_masks = np.array(val_masks)

val_data = (val_images, val_masks)

In [None]:
# Sanity check: shapes of the full arrays followed by the shape of one mask and one frame.
print(
    *(arr.shape for arr in (train_images, train_masks, val_images, val_masks,
                            train_masks[0], train_images[0])),
    sep='\n'
)

In [None]:
# Augmenting generator for the training split.
# NOTE(review): `train_generator` is not referenced by any other cell visible here;
# training below is fed the in-memory arrays instead.
train_generator = SegClass.create_generators(
    mode='train',
    n_classes=n_classes,
    crop_shape=None,
    blur=5,
    brightness=0.3,
    zoom=0.1,
    rotation=False,
    horizontal_flip=True,
    vertical_flip=False,
    do_ahisteq=False,
    validation_split=.15,
    seed=7,
)
# valid_generator = SegClass.create_generators(blur=0, crop_shape=None, mode='validation', 
#                                              n_classes=n_classes, horizontal_flip=True, vertical_flip=False, 
#                                              brightness=.1, rotation=False, zoom=.05, validation_split=.15, 
#                                              seed=7, do_ahisteq=False)

In [None]:
# fine-tune model (train only last conv layers)
if load_pretrained_weights:
    # Freeze every layer that comes before 'concat_projection'; that layer and all
    # layers after it stay trainable.
    flag = False
    for l in model.layers:
        if l.name == 'concat_projection':
            flag = True
        l.trainable = flag
    # NOTE(review): these flags are changed after model.compile() ran in an earlier cell;
    # presumably SegClass.train() recompiles the model -- confirm, otherwise recompile here.

# Pass real booleans: the former string values 'False' / 'True' are both truthy, so
# save_best_only was effectively switched ON although it was written as 'False'.
mc = ModelCheckpoint(mode='max', filepath='checkpoint-{epoch:003d}-{val_accuracy:.2f}.h5', monitor='accuracy',
                     save_best_only=False, save_weights_only=True, verbose=1)

SegClass.set_num_epochs(100)
history = SegClass.train(model, train_images, train_masks, val_data, callback_list=[mc])

In [None]:
def mIOU(gt, preds):
    '''Mean intersection-over-union between two label arrays.

    Only the label values that occur in `gt` are scored; the per-label IoU values are
    averaged and the mean is rounded to two decimals.
    '''
    present_labels = np.unique(gt)
    per_label_iou = np.zeros(len(present_labels))
    for slot, label in enumerate(present_labels):
        in_gt = gt == label
        in_pred = preds == label
        overlap = (in_gt & in_pred).sum()
        combined = (in_gt | in_pred).sum()
        per_label_iou[slot] = overlap/combined
    return np.round(per_label_iou.mean(), 2)

In [None]:
def plotImages(img, y_true, y_pred):
    '''Show a frame, its ground-truth mask and a predicted mask side by side.

        Inputs:
            img - image array drawn as-is in the first panel
            y_true - one hot encoded ground-truth mask (height x width x num_classes)
            y_pred - batch of one hot / score masks; only y_pred[0] is drawn
    '''
    # NOTE(review): frames in this notebook are read with cv2.imread, so the channel
    # order of `img` may need converting before display -- confirm.
    fig = plt.figure(figsize=(20,8))

    panels = (
        ('Actual frame', img),
        ('Ground truth labels', onehot_to_rgb(y_true, id2code)),
        ('Predicted labels', onehot_to_rgb(y_pred[0], id2code)),
    )
    for position, (title, picture) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow(picture)
        ax.set_title(title)
        # Explicitly switch the grid off. The former grid(b=None) relied on a keyword that
        # current Matplotlib no longer accepts, and None toggles the grid rather than hiding it.
        ax.grid(False)

    plt.show()

In [None]:
# Visualise how the prediction for one fixed validation sample evolves across checkpoints.
img = val_images[10]
x = np.expand_dims(img, axis=0)
y_true = val_masks[10]

checkpoint_files = (
    'checkpoint-001-0.86.h5',
    'checkpoint-003-0.84.h5',
    'checkpoint-010-0.94.h5',
    'checkpoint-021-0.95.h5',
    'checkpoint-047-0.96.h5',
)
for weights_file in checkpoint_files:
    # After the loop the model keeps the weights of the last checkpoint listed above.
    model.load_weights(weights_file)
    y_pred = model.predict(x)
    plotImages(img, y_true, y_pred)

In [None]:
# Score one random validation sample, with and without CRF post-processing.
i = np.random.randint(0, len(val_images))
x = np.expand_dims(val_images[i], axis=0)
y = np.expand_dims(val_masks[i], axis=0)
preds1 = np.argmax(model.predict(x), -1)[0].reshape(image_size)

print(preds1.shape)
print(x.shape)

# Collapse the one-hot class axis. Without `axis`, argmax returns a single flat index,
# which cannot be reshaped to `image_size`.
y_true = np.argmax(y, axis=-1).reshape(image_size)

im = x[0].astype('uint8')
gt = y_true.astype('int32')

plt.figure(figsize=(14,10))
plt.subplot(141)
plt.imshow(im)
plt.imshow(preds1, alpha=.5)
plt.title('Original DeepLab\nmIOU: '+str(mIOU(gt, preds1)))
plt.subplot(143)
MAP = do_crf(im, preds1, zero_unsure=False)
plt.imshow(im)
plt.imshow(MAP, alpha=.5)
plt.title('Original DeepLab + CRF\nmIOU: '+str(mIOU(gt, MAP)))

In [None]:
def calculate_iou(model, nb_classes = 21, images = None, masks = None):
    '''Build the confusion matrix of `model` over a set of images.

        Inputs:
            model - object whose predict(X, batch_size=...) returns per-pixel class scores
                    with the class axis last
            nb_classes - number of classes; ground-truth ids equal to nb_classes are ignored
            images - array-like of input images; defaults to the notebook-level val_images
            masks - array-like of one hot encoded masks (class axis last); defaults to the
                    notebook-level val_masks
        Output: (nb_classes x nb_classes) float array; rows are ground-truth ids, columns are
                predicted ids
        Raises: ValueError if predictions and labels do not cover the same number of pixels
    '''
    if images is None:
        images = val_images
    if masks is None:
        masks = val_masks

    X = np.array(images, dtype='float32')
    # The masks are one-hot, so the ground-truth id of a pixel is the argmax over the class axis.
    flat_label = np.argmax(np.array(masks), axis=-1).ravel().astype('int')

    preds = model.predict(X, batch_size=1)
    flat_pred = np.argmax(preds, axis=-1).ravel().astype('int')
    if flat_pred.size != flat_label.size:
        raise ValueError('prediction / label size mismatch: {} vs {}'.format(flat_pred.size, flat_label.size))

    ignored = flat_label == nb_classes
    in_range = ((flat_label >= 0) & (flat_label < nb_classes) &
                (flat_pred >= 0) & (flat_pred < nb_classes))
    n_invalid = int(np.count_nonzero(~in_range & ~ignored))
    if n_invalid:
        print('Invalid entry encountered, skipping! Count: ', n_invalid)

    # Ids are 0-based, so they index the matrix directly (the former l-1 / p-1 shifted
    # class 0 into the last row/column). Counting all (label, prediction) pairs at once
    # replaces the per-pixel Python loop.
    pair_index = flat_label[in_range] * nb_classes + flat_pred[in_range]
    conf_m = np.bincount(pair_index, minlength=nb_classes * nb_classes)
    return conf_m.reshape((nb_classes, nb_classes)).astype(float)
# Confusion matrix of the current model weights, computed for 21 classes.
conf_1 = calculate_iou(model, nb_classes = 21)

In [None]:
# Class names for the matrix axes: every entry of the class table except the last one.
classes = list(get_VOC2012_classes().values())[:-1]

plt.figure(figsize=(12, 8))
plt.subplot(1, 2, 1)
cm1 = plot_confusion_matrix(conf_1, classes, normalize=True)
# The title reports the mean of the diagonal of the matrix returned by plot_confusion_matrix.
plt.title('Original DeepLab\nMean IOU: ' + str(np.round(np.diag(cm1).mean(), 2)))