# One-Shot Video Object Segmentation (OSVOS)
The implementation is based on the paper [One-Shot Video Object Segmentation](https://arxiv.org/abs/1611.05198) and is written entirely in TensorFlow Keras. The dataset used for training is DAVIS. The notebook follows this pipeline:
- Import of modules and packages.
- Data processing and analysis.
- Data generator with augmentation function.
- Model creation.
- Training script.
- Fine-tuning and testing on videos.

 ## Import of different modules and packages.

In [1]:
import numpy as np
import tensorflow as tf
import cv2
import os
import glob
import random
import tensorflow as tf
from tensorflow.keras import layers, Model, regularizers
from tensorflow import keras


# Dataset / weight locations. Each one can be overridden with an environment
# variable so the notebook runs on other machines without editing the code;
# the defaults are the author's original paths.
img_folder_path = os.environ.get(
    "OSVOS_IMG_FOLDER",
    "/Users/tangerine/PycharmProjects/dataset/DAVIS2017/train_data/Train/")
img_annotation_path = os.environ.get(
    "OSVOS_ANNOTATION_FOLDER",
    "/Users/tangerine/PycharmProjects/dataset/DAVIS2017/train_data/Train_Annotated/")
# VGG16 (no-top) weights file; loaded with by_name=True when building the model.
weight_path = os.environ.get(
    "OSVOS_VGG16_WEIGHTS",
    "/Users/tangerine/stryker/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")

## Data Processing and analysis.

In [2]:
class Data_analysis:
    """Build per-sequence train/validation splits of frame and mask paths.

    Every sub-directory found under ``img_folder_path`` is treated as one
    sequence. Its annotation directory is the sub-directory with the same name
    under ``img_annotation_path``. Each frame ``<stem>.<ext>`` is paired with
    the mask ``<stem>.png``. Within each sequence the frames are shuffled
    (with the global ``random`` module) and split 85 % train / 15 % val.
    """

    def __init__(self, img_folder_path, img_annotation_path):
        self.img_folder_path = img_folder_path
        self.img_annotation_path = img_annotation_path

        self.img_path_train = []     # Training list containing image paths
        self.target_path_train = []  # Training list containing mask paths
        self.img_path_val = []       # Validation list containing image paths
        self.target_path_val = []    # Validation list containing mask paths

    def __call__(self, visualize=False):
        """Populate the train/val path lists (appends on every call).

        Args:
            visualize: if True, call :meth:`visualization` after splitting.
        """
        for root, dir_names, _ in os.walk(self.img_folder_path):
            for dir_name in dir_names:
                dir_path_image = os.path.join(root, dir_name)
                dir_path_anno = os.path.join(self.img_annotation_path, dir_name)
                image_files = os.listdir(dir_path_image)
                random.shuffle(image_files)  # shuffle before splitting into train/val
                length_train = int(len(image_files) * 0.85)

                self._add_pairs(dir_path_image, dir_path_anno,
                                image_files[:length_train],
                                self.img_path_train, self.target_path_train)
                self._add_pairs(dir_path_image, dir_path_anno,
                                image_files[length_train:],
                                self.img_path_val, self.target_path_val)

        if visualize:
            self.visualization()

    @staticmethod
    def _add_pairs(dir_path_image, dir_path_anno, file_names, img_list, target_list):
        """Append the frame path and its ``<stem>.png`` mask path for each file."""
        for file_name in file_names:
            # splitext handles any extension length (e.g. ".jpeg"), unlike [:-4].
            stem = os.path.splitext(file_name)[0]
            img_list.append(os.path.join(dir_path_image, file_name))
            target_list.append(os.path.join(dir_path_anno, stem + '.png'))

    def visualization(self):
        """Placeholder for dataset visualization (not implemented yet)."""
        pass
# Build the train/val frame and mask path lists once, when the cell runs.
data = Data_analysis(img_folder_path=img_folder_path,
                     img_annotation_path=img_annotation_path)
data(visualize=False)

## DataGenerator and Augmentation.

In [3]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (frame, binary mask) pairs.

    Frames are read with OpenCV (BGR channel order), resized to ``img_size``
    and scaled to [0, 1]. Masks are read as grayscale, resized and binarised
    (any non-zero pixel -> 1). A trailing partial batch is dropped.

    Args:
        batch_size: number of samples per batch.
        input_img_paths: frame paths.
        target_img_paths: mask paths, index-aligned with ``input_img_paths``.
        img_size: output size as ``(width, height)`` (the ``cv2.resize``
            convention). Batches have shape ``(batch, height, width, C)``.
        use_augmentation: if True (default, the original behaviour) every
            sample is passed through :meth:`augment`. Pass False for
            validation/evaluation data.
    """

    def __init__(self, batch_size, input_img_paths, target_img_paths,
                 img_size=(300, 300), use_augmentation=True):
        super().__init__()
        self.batch_size = batch_size
        self.img_size = tuple(img_size)
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths
        self.use_augmentation = use_augmentation

    def __len__(self):
        # Number of full batches (steps) per epoch.
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        batch_inp = self.input_img_paths[start:start + self.batch_size]
        batch_target = self.target_img_paths[start:start + self.batch_size]
        width, height = self.img_size
        x = np.zeros((self.batch_size, height, width, 3), dtype="float32")
        y = np.zeros((self.batch_size, height, width, 1), dtype="uint8")

        for j, (img_path, mask_path) in enumerate(zip(batch_inp, batch_target)):
            img = cv2.imread(img_path)
            if img is None:
                raise FileNotFoundError(f"Could not read image: {img_path}")
            img = cv2.resize(img, self.img_size)

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            # Nearest-neighbour keeps label boundaries sharp. Linear
            # interpolation followed by ">0" would dilate the object.
            mask = cv2.resize(mask, self.img_size, interpolation=cv2.INTER_NEAREST)
            mask = np.where(mask > 0, 1, 0).astype('float32')

            if self.use_augmentation:
                img, mask = self.augment(img, mask)
            x[j] = img / 255.0  # Normalization to [0, 1]
            y[j] = np.expand_dims(mask, 2)

        return x, y

    def augment(self, img, mask, max_angle=15):
        """Apply one random augmentation to an image/mask pair.

        With equal probability either flip both horizontally, or rotate both
        about the image centre by an angle drawn uniformly from
        ``[-max_angle, max_angle]`` degrees. The mask is warped with
        nearest-neighbour interpolation so it stays binary.
        """
        if np.random.randint(0, 2) == 0:  # horizontal flip
            return cv2.flip(img, 1), cv2.flip(mask, 1)

        height, width = img.shape[:2]
        center = (width / 2, height / 2)
        angle = float(np.random.uniform(-max_angle, max_angle))
        rotate_mat = cv2.getRotationMatrix2D(center=center, angle=angle, scale=1)
        img = cv2.warpAffine(src=img, M=rotate_mat, dsize=(width, height))
        mask = cv2.warpAffine(src=mask, M=rotate_mat, dsize=(width, height),
                              flags=cv2.INTER_NEAREST)
        return img, mask
    

# for x,y in data_gen_train:
    # print(x.shape,y.shape)

## Parent Network and Training.
### Parent Network.
The parent network consists of a base network, skip connections and transposed convolutional layers. The base network used here is
VGG16, initialized with ImageNet weights. From each block (starting from block 2), the output of the last convolutional layer before the max-pooling layer is passed through its own convolutional layer followed by a transposed convolutional layer. The upsampled outputs are then fused, forming skip connections.
### Loss Function.
The cost function used for training is weighted pixelwise cross entropy to deal with class imbalance.


In [4]:
def weighted_pixelwise_cross_entropy(label, output):
    """Class-balanced pixel-wise binary cross-entropy (OSVOS loss).

    Args:
        label: ground-truth mask; pixels > 0 count as foreground and
            pixels < 1 as background.
        output: raw logits from the network (the model has no final sigmoid,
            so it is accounted for here).

    Returns:
        Scalar ``-(N_neg / N) * sum_{fg} log(p) - (N_pos / N) * sum_{bg} log(1 - p)``
        with ``p = sigmoid(output)``. Each class term is weighted by the
        fraction of the *other* class, so the usually dominant background does
        not swamp the foreground.
    """
    labels_pos = tf.cast(tf.greater(label, 0), tf.float32)
    labels_neg = tf.cast(tf.less(label, 1), tf.float32)

    num_labels_pos = tf.reduce_sum(labels_pos)
    num_labels_neg = tf.reduce_sum(labels_neg)
    num_total = num_labels_pos + num_labels_neg

    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)), computed
    # stably. The previous log(sigmoid(x) + 1e-5) saturated: for confidently
    # wrong pixels its gradient vanished, and 1 - sigmoid(x) lost precision.
    loss_pos = tf.reduce_sum(tf.multiply(labels_pos, tf.math.log_sigmoid(output)))
    loss_neg = tf.reduce_sum(tf.multiply(labels_neg, tf.math.log_sigmoid(-output)))

    final_loss = -num_labels_neg / num_total * loss_pos - num_labels_pos / num_total * loss_neg

    return final_loss

def vgg16_osvs(input_shape,weight='imagenet'):
    """Build the OSVOS parent network on a VGG16-style backbone.

    The five VGG16 conv blocks are rebuilt layer by layer, using the Keras
    VGG16 layer names (``block{b}_conv{n}``, ``block{b}_pool``). Before each
    pooling layer a 3x3, 16-filter "aux" conv taps the block output. The taps
    of blocks 2-5 are upsampled with transposed convolutions (strides 2, 4, 8,
    16), cropped to the input height/width, concatenated along channels and
    fused by a 1x1 conv into one channel.

    Args:
        input_shape: shape of a single input image, e.g. ``(None, None, 3)``
            to accept any spatial size.
        weight: ``'imagenet'`` loads weights from the module-level
            ``weight_path`` with ``by_name=True``, so only layers whose names
            match the file are initialised (presumably just the backbone
            convs). Any other value keeps Keras' default initialisation.

    Returns:
        A ``tf.keras.Model`` mapping an image batch to per-pixel logits of
        shape ``(batch, H, W, 1)``. No sigmoid is applied; the loss function
        and the test code apply it.
    """
    # One regularizer instance shared by every backbone conv layer.
    kernel_regularizer=regularizers.l1_l2(l1=1e-3, l2=1e-3)
    # Backbone layout: ['conv', filters] entries followed by one ['pool'] per
    # block, mirroring the five VGG16 blocks.
    vgg_arch =[
        # block1
        [['conv', 64 ],['conv', 64 ],['pool']],
        # block2
        [['conv', 128],['conv', 128],['pool']],
        #block3
        [['conv', 256],['conv', 256],['conv', 256  ],['pool']],
        #block4
        [['conv', 512],['conv', 512],['conv', 512  ],['pool']],
        #block5
        [['conv', 512],['conv', 512],['conv', 512  ],['pool']],
    ]

    img_input = layers.Input(shape=input_shape)
    # Dynamic input height/width, used below to crop the upsampled side
    # outputs back to the input resolution.
    _,h,w,_ = tf.shape(img_input)
    block_cnt = 0
    aux_tensor = []  # one 16-channel side output per block, taken before pooling
    for block in vgg_arch:
        block_cnt +=1
        lyr_cnt = 0
        for i,lyr in enumerate(block):
            lyr_cnt+=1
            if lyr[0] == 'conv':
                out_ch = lyr[1]
                # Keras VGG16 naming, so load_weights(by_name=True) can match.
                name = f'block{block_cnt}_conv{lyr_cnt}'
                if lyr_cnt == 1 and block_cnt == 1:
                    x = layers.Conv2D(out_ch,(3,3),padding='same',activation='relu',
                                      kernel_regularizer=kernel_regularizer,name = name)(img_input)
                else:
                    x = layers.Conv2D(out_ch,(3,3),padding='same',activation= 'relu',
                                      kernel_regularizer=kernel_regularizer,name=name)(x)

            elif lyr[0] == 'pool':
                # Side output of this block (3x3 conv, 16 filters, no activation).
                aux_lyr = f'aux_lyr1_{block_cnt}'
                aux_lyr = layers.Conv2D(16,(3,3),padding='same',name=aux_lyr)(x)
                aux_tensor.append(aux_lyr)

                name = f'block{block_cnt}_pool'
                x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=name)(x)

    ### Main Output ####
    # NOTE(review): aux_tensor[0] (block 1) and the block5 pooling output are
    # built but never reach the output, so they are not part of the model graph.
    stage = 'transposed_lyr_'

    # Each side output is upsampled back towards input resolution. With the
    # default padding='valid' the result can be slightly larger than (h, w),
    # and [:h, :w] keeps the top-left corner.
    # NOTE(review): a top-left crop may shift the maps by a few pixels relative
    # to the input; confirm whether a centred crop was intended.
    tr_lyr2 = layers.Conv2DTranspose(16,(4,4),strides=2,name = stage+'1')(aux_tensor[1])
    tr_lyr2 = tr_lyr2[:,:h,:w,:]

    tr_lyr3 = layers.Conv2DTranspose(16,(8,8),strides=4,name = stage+'2')(aux_tensor[2])
    tr_lyr3 = tr_lyr3[:,:h,:w,:]

    tr_lyr4 = layers.Conv2DTranspose(16,(16,16),strides=8,name = stage+'3')(aux_tensor[3])
    tr_lyr4 = tr_lyr4[:,:h,:w,:]

    tr_lyr5 = layers.Conv2DTranspose(16,(32,32),strides=16,name = stage+'4')(aux_tensor[4])
    tr_lyr5 = tr_lyr5[:,:h,:w,:]

    # Fuse the four 16-channel maps (64 channels) into one logit channel.
    concat = tf.concat([tr_lyr2,tr_lyr3,tr_lyr4,tr_lyr5],axis=3)

    # NOTE(review): this conv has no explicit name, so its auto-generated name
    # depends on creation order. Reordering layers in this function may break
    # loading of previously saved weight files.
    output = layers.Conv2D(1,(1,1))(concat)
    # output = tf.nn.sigmoid(output)

    model = Model(inputs = img_input,outputs=output)

    if weight=='imagenet':
        model.load_weights(weight_path,by_name=True)

    return model

In [5]:
def parent_training(data_gen_train, dat_gen_test, epochs=10, learning_rate=0.001,
                    model_save_path="/model/parent_model/"):
    """Train the OSVOS parent network starting from VGG16 weights.

    Weights are checkpointed only when ``val_loss`` improves, to
    ``model_save_path/weights-improvement-<epoch>-<val_loss>.hdf5``.

    Args:
        data_gen_train: training data (e.g. a ``DataGenerator``).
        dat_gen_test: validation data passed as ``validation_data``.
        epochs: number of training epochs.
        learning_rate: Adam learning rate.
        model_save_path: checkpoint directory. It is created if missing, so a
            non-writable location fails immediately instead of after the
            first epoch.

    Returns:
        The ``History`` object returned by ``fit`` (``history.model`` is the
        trained network).
    """
    os.makedirs(model_save_path, exist_ok=True)
    checkpoint_filepath = os.path.join(
        model_save_path, "weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5")
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True)

    parent_model = vgg16_osvs((None, None, 3), weight='imagenet')

    parent_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=weighted_pixelwise_cross_entropy,
        # The model emits logits: a pixel is foreground when logit > 0
        # (sigmoid > 0.5). Plain 'accuracy' would threshold the logits at 0.5.
        metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)])

    return parent_model.fit(data_gen_train,
                            epochs=epochs,
                            verbose=1,
                            validation_data=dat_gen_test,
                            callbacks=[model_checkpoint_callback])
    
batch_size = 16
img_size = [300,300]  # (width, height) used for parent training
# Pass img_size explicitly so that editing it here actually changes the
# resolution the generators produce.
data_gen_train = DataGenerator(batch_size, data.img_path_train,
                               data.target_path_train, img_size=tuple(img_size))
dat_gen_test = DataGenerator(batch_size, data.img_path_val,
                             data.target_path_val, img_size=tuple(img_size))

# parent_training(data_gen_train,dat_gen_test)

## Fine-tuning and Testing.
### Fine-tuning.
During fine-tuning, the parent network is trained on one or more frame/ground-truth pairs for n epochs. The frame used for fine-tuning is the frame at time (t-1). The fine-tuned model is then used to segment the subsequent frames in the sequence.

In [6]:
def fine_tunning(fine_tune_gen, weight_parent, epochs=100,
                 save_path='./model/finetunned_model.h5'):
    """Fine-tune the parent network on the annotated frame(s) of one sequence.

    Args:
        fine_tune_gen: data with the frame/mask pair(s) to fine-tune on.
        weight_parent: path of the parent-network weights to start from.
        epochs: number of fine-tuning epochs.
        save_path: where the fine-tuned model is saved with ``model.save``.
            The parent directory is created if it does not exist.

    Returns:
        The ``History`` object returned by ``fit``.
    """
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # Create the directory up front so a bad path fails before training.
        os.makedirs(save_dir, exist_ok=True)

    model_ft = vgg16_osvs((None, None, 3), weight=None)
    model_ft.load_weights(weight_parent)

    model_ft.compile(
        optimizer='adam',
        loss=weighted_pixelwise_cross_entropy,
        # The model emits logits, so threshold at 0 (== sigmoid 0.5).
        metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)])

    history = model_ft.fit(fine_tune_gen, epochs=epochs, verbose=1)
    model_ft.save(save_path)
    return history
    
    
weight_parent = "/Users/tangerine/Downloads/weights-improvement-73-20985.05.hdf5"

# First frame of the sequence and its ground-truth mask (the (t-1) pair).
image_path = ["/Users/tangerine/PycharmProjects/dataset/DAVIS2017/test_data/image/cows/00000.jpg"]
mask_path = ["/Users/tangerine/PycharmProjects/dataset/DAVIS2017/test_data/mask/cows/00000.png"]

img_size = [300,300]  # same as parent training
batch_size = 1
# Pass img_size explicitly so it actually controls the generator resolution.
fine_tune_gen = DataGenerator(batch_size, image_path, mask_path,
                              img_size=tuple(img_size))
# fine_tunning(fine_tune_gen,weight_parent)

### Testing.
In the testing phase, the model fine-tuned on the (t-1) frame/ground-truth pair is used to predict masks for the frames at time >= t.

In [8]:
import time

def iou_calc(gt_arr, pred_arr):
    """Print and return the mean intersection-over-union across frames.

    Args:
        gt_arr: sequence of ground-truth masks, one per frame (non-zero
            values are foreground).
        pred_arr: sequence of predicted masks, aligned with ``gt_arr``.

    Returns:
        The mean IoU as a float. A frame where both masks are empty counts as
        IoU 1 (perfect agreement) instead of producing 0/0. Returns NaN if no
        frames are given.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(gt_arr) != len(pred_arr):
        raise ValueError(
            f"gt_arr has {len(gt_arr)} frames but pred_arr has {len(pred_arr)}")

    iou_list = []
    for gt, pred in zip(gt_arr, pred_arr):
        # Binarise so non-{0,1} masks (e.g. 0/255) are handled correctly.
        gt = np.asarray(gt).flatten() != 0
        pred = np.asarray(pred).flatten() != 0
        union = np.count_nonzero(gt | pred)
        if union == 0:
            iou_list.append(1.0)
            continue
        iou_list.append(np.count_nonzero(gt & pred) / union)

    mean_iou = float(np.mean(iou_list)) if iou_list else float('nan')
    print(mean_iou)
    return mean_iou
            
def testing_seq(fine_tunned_model, image_dir, annotation_dir=None, threshold=0.5):
    """Segment every frame of a sequence and collect predicted and GT masks.

    Args:
        fine_tunned_model: path to saved weights, loaded with ``load_weights``
            into a freshly built ``vgg16_osvs`` network.
        image_dir: directory of frames, processed in sorted filename order.
            Entries OpenCV cannot decode (e.g. ``.DS_Store``) are skipped.
        annotation_dir: directory holding the ground-truth ``<stem>.png``
            masks. Defaults to the module-level ``anno_dir``, as before.
        threshold: sigmoid probability above which a pixel is foreground.

    Returns:
        ``(pred_mask_arr, gt_mask_arr)``: lists with one ``(1, H, W)`` array
        per frame. Predictions are integer 0/1; ground truth is float32 0/1.

    Raises:
        FileNotFoundError: if a frame has no readable annotation.
    """
    if annotation_dir is None:
        annotation_dir = anno_dir  # module-level default (original behaviour)

    pred_mask_arr = []
    gt_mask_arr = []

    model_test = vgg16_osvs((None, None, 3), weight=None)
    model_test.load_weights(fine_tunned_model)

    for i, file in enumerate(sorted(os.listdir(image_dir))):
        frame = cv2.imread(os.path.join(image_dir, file))
        if frame is None:
            continue  # not an image (e.g. .DS_Store)
        image = (frame / 255.0).astype('float32')  # same scaling as training
        h, w, _ = image.shape
        batch = np.expand_dims(image, axis=0)

        st_time = time.time()
        # Calling the model directly avoids predict()'s per-call overhead for
        # a single frame.
        output = tf.nn.sigmoid(model_test(batch, training=False)).numpy()
        end_time = time.time()
        print(f"inference time for {i}. Frame is {end_time-st_time}")

        output = np.where(output > threshold, 1, 0).reshape(h, w)
        pred_mask_arr.append(np.expand_dims(output, axis=0))

        gt_anno_path = os.path.join(annotation_dir, os.path.splitext(file)[0] + '.png')
        anno = cv2.imread(gt_anno_path, cv2.IMREAD_GRAYSCALE)
        if anno is None:
            raise FileNotFoundError(f"Could not read annotation: {gt_anno_path}")
        annot = np.where(anno > 0, 1, 0).reshape(h, w).astype('float32')
        gt_mask_arr.append(np.expand_dims(annot, axis=0))

    return pred_mask_arr, gt_mask_arr


# Evaluate the fine-tuned model on one DAVIS test sequence.
fine_tunned_model = "./model/finetunned_model.h5"
image_dir = "/Users/tangerine/PycharmProjects/dataset/DAVIS2017/test_data/image/cows/"
anno_dir = "/Users/tangerine/PycharmProjects/dataset/DAVIS2017/test_data/mask/cows/"

pred_mask_list, gt_mask_list = testing_seq(fine_tunned_model, image_dir)

# Each list holds (1, H, W) arrays. Stacking gives (N, 1, H, W), and moving
# the singleton axis to the end yields (N, H, W, 1).
pred_mask_arr = np.moveaxis(np.asarray(pred_mask_list), 1, -1)
gt_mask_arr = np.moveaxis(np.asarray(gt_mask_list), 1, -1)
iou_calc(gt_mask_arr, pred_mask_arr)