# A universal framework for fusing semantic information and temporal consistency for background segmentation

This notebook illustrates the training process of the temporal-consistency-based model and of the top model. Training of the DeepLab model is not covered here.

— by Zhi Zeng, Sep. 28, 2018

# 1. Train the temporal consistency-based model

In [None]:
import os
import time

import cv2
import numpy as np
import plotly.graph_objs as go
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from plotly.offline import init_notebook_mode, iplot

## 1.0. Check paths

In [None]:
# Root folder of the training dataset; later cells look for scene folders
# two levels below it (<category>/<scene>).
dataset_root_path = './Ours_Dataset'
print('dataset_root_path is: \n', dataset_root_path, sep='')

In [None]:
# Root folder for all trained models, and the sub-folder holding one
# fuzzy-histogram (FH) model per scene.
model_root_path = './Our_Models'
# os.path.join uses the platform separator: on Windows this is the same
# './Our_Models\\FH\\per_scene' string as a hand-written backslash path,
# and on POSIX it is a real nested path instead of a file name with
# backslashes in it.
per_scene_FH_model_root_path = os.path.join(model_root_path, 'FH', 'per_scene')
print('model_root_path is: \n' + model_root_path + '\n')
print('per_scene_FH_model_root_path is: \n' + per_scene_FH_model_root_path)

## 1.1. Load the model

In [None]:
from Utilities.fuzzy_partition_histogram import FuzzyHistogramModel

In [None]:
# Build the fuzzy-histogram (FH) temporal-consistency background model.
# The same instance is reused for every scene in sections 1.3 and 1.4;
# its session is re-initialized per scene via FH_MODEL.initialize_sess().
FH_MODEL = FuzzyHistogramModel()

## 1.2. Check the fuzzy histogram based background model

In [None]:
# Inspect the FH model: list its global variables, print its indexing
# constant, and list the placeholders (graph inputs) of its graph.
with FH_MODEL.graph.as_default():
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        print(var.name, var.shape)
    print(FH_MODEL.indexing_constant)
    for op in tf.get_default_graph().get_operations():
        if op.type == "Placeholder":
            # `op.get_attr` must be called with an attribute name; printing
            # the bare method only shows "<bound method ...>". A Placeholder
            # op has a single output tensor carrying its shape and dtype.
            placeholder = op.outputs[0]
            print(op.name, placeholder.shape, placeholder.dtype)

## 1.3. Train the model for each scene on its 200 training frames

In [None]:

# Section 1.3: train one fuzzy-histogram (FH) model per scene folder
# (<dataset_root_path>/<category>/<scene>) from its Train_GTs / Train_Inputs
# frames and save it to <per_scene_FH_model_root_path>/<category>/<scene>/classic_model.


def _read_resized(path, interpolation):
    """Read *path* with OpenCV and resize it to 320x240.

    Args:
        path: image file to read.
        interpolation: OpenCV interpolation flag passed to cv2.resize.

    Returns:
        The resized BGR image.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*. cv2.imread signals
            this by returning None, which would otherwise surface as an
            opaque assertion failure inside cv2.resize.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        raise FileNotFoundError('Cannot read image: ' + path)
    return cv2.resize(bgr_image, (320, 240), interpolation=interpolation)


def _list_files_sorted(directory):
    """Return the sorted names of the regular files directly inside *directory*.

    Sorting makes the frame order deterministic (os.walk / os.listdir order
    is filesystem-dependent). A missing directory yields an empty list.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


for root, dirs, _ in os.walk(dataset_root_path):
    # Visit categories/scenes in a deterministic (alphabetical) order.
    dirs.sort()
    # Path components relative to the dataset root. relpath + os.sep work on
    # both Windows and POSIX, unlike splitting on a literal backslash.
    relative_path_hierarch_list = os.path.relpath(root, dataset_root_path).split(os.sep)
    if len(relative_path_hierarch_list) >= 2:
        # Nothing below scene level is processed here, so don't descend.
        dirs[:] = []
    if len(relative_path_hierarch_list) != 2:
        continue
    category, scene = relative_path_hierarch_list
    # Get the training data
    Train_GTs_path = os.path.join(root, 'Train_GTs')
    Train_Inputs_path = os.path.join(root, 'Train_Inputs')
    FH_model_path = os.path.join(per_scene_FH_model_root_path, category, scene, 'classic_model')
    # Characters [2:8] of a ground-truth file name select the input frame
    # 'in<...>.jpg' (presumably 'gtXXXXXX.png' -> 'inXXXXXX.jpg' — TODO confirm).
    gt_names = _list_files_sorted(Train_GTs_path)
    truth_file_list = [os.path.join(Train_GTs_path, name) for name in gt_names]
    image_file_list = [
        os.path.join(Train_Inputs_path, 'in' + name[2:8] + '.jpg') for name in gt_names
    ]
    # Initialize the model
    FH_MODEL.initialize_sess()
    # Train the model frame by frame, sorted by file name (frame order when
    # the names carry zero-padded frame indices).
    for image_file, truth_file in zip(image_file_list, truth_file_list):
        # Read the image. Divide by 256 rather than 255 so values stay below
        # 1.0; per the original note this avoids reaching the 22nd bin of the
        # histogram.
        cv_gray_image = cv2.cvtColor(_read_resized(image_file, cv2.INTER_CUBIC), cv2.COLOR_BGR2GRAY)
        cv_float_gray_image = cv_gray_image.astype('float') / 256.0
        cv2.imshow('cv_float_gray_image', cv_float_gray_image)
        cv2.waitKey(1)
        # Load the ground truth (nearest-neighbour resize keeps labels crisp).
        # 1 - gt/255 maps 255 -> 0.0 and 0 -> 1.0 (presumably foreground is 255
        # in the masks, so this is a background weight — TODO confirm).
        cv_gray_truth = cv2.cvtColor(_read_resized(truth_file, cv2.INTER_NEAREST), cv2.COLOR_BGR2GRAY)
        cv_float_gray_truth = 1.0 - cv_gray_truth.astype('float') / 255.0
        cv2.imshow('cv_float_gray_truth', cv_float_gray_truth)
        # Check the histogram
        raw_segmentation = FH_MODEL.histogram_checking(cv_float_gray_image)
        cv2.imshow('raw_segmentation', raw_segmentation)
        cv2.waitKey(10)
        # Update the histogram
        FH_MODEL.update_histogram(cv_float_gray_image, cv_float_gray_truth, train_flag=True)
    FH_MODEL.reduce_histogram()
    cv2.destroyAllWindows()
    # Save the model
    FH_MODEL.save_model(FH_model_path)

## 1.4. Run the model over the training frames

Note: these results, combined with the semantic segmentations, are used to train the top model.

In [None]:
# Section 1.4: run each scene's saved FH model over its training frames and
# write the segmentations to <scene>/FH_Training_Results/result<frame>.png
# (read back in section 3.2 to train the top model).


def _read_resized(path, interpolation):
    """Read *path* with OpenCV and resize it to 320x240.

    Args:
        path: image file to read.
        interpolation: OpenCV interpolation flag passed to cv2.resize.

    Returns:
        The resized BGR image.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path* (cv2.imread returns
            None instead of raising).
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        raise FileNotFoundError('Cannot read image: ' + path)
    return cv2.resize(bgr_image, (320, 240), interpolation=interpolation)


def _list_files_sorted(directory):
    """Return the sorted names of the regular files directly inside *directory*.

    A missing directory yields an empty list.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


def _write_segmentation(path, segmentation):
    """Save *segmentation* (presumably values in [0, 1]) as an 8-bit image.

    Scales by 255, rounds and saturates to uint8 explicitly instead of
    relying on cv2.imwrite's implicit conversion of float arrays.

    Raises:
        OSError: if OpenCV reports the file could not be written
            (cv2.imwrite returns False rather than raising).
    """
    scaled = np.rint(np.asarray(segmentation, dtype=np.float64) * 255.0)
    image_u8 = np.clip(scaled, 0, 255).astype(np.uint8)
    if not cv2.imwrite(path, image_u8):
        raise OSError('Failed to write result image: ' + path)


for root, dirs, _ in os.walk(dataset_root_path):
    # Visit categories/scenes in a deterministic (alphabetical) order.
    dirs.sort()
    # Portable relative path components (see section 1.3).
    relative_path_hierarch_list = os.path.relpath(root, dataset_root_path).split(os.sep)
    if len(relative_path_hierarch_list) >= 2:
        # Nothing below scene level is processed here, so don't descend.
        dirs[:] = []
    if len(relative_path_hierarch_list) != 2:
        continue
    category, scene = relative_path_hierarch_list
    # Get the training data
    Train_GTs_path = os.path.join(root, 'Train_GTs')
    Train_Inputs_path = os.path.join(root, 'Train_Inputs')
    FH_Training_Results_path = os.path.join(root, 'FH_Training_Results')
    FH_model_path = os.path.join(per_scene_FH_model_root_path, category, scene, 'classic_model')
    gt_names = _list_files_sorted(Train_GTs_path)
    truth_file_list = [os.path.join(Train_GTs_path, name) for name in gt_names]
    image_file_list = [
        os.path.join(Train_Inputs_path, 'in' + name[2:8] + '.jpg') for name in gt_names
    ]
    result_file_list = [
        os.path.join(FH_Training_Results_path, 'result' + name[2:8] + '.png') for name in gt_names
    ]
    # cv2.imwrite does not create missing folders.
    os.makedirs(FH_Training_Results_path, exist_ok=True)
    # Initialize the model and load the one trained in section 1.3
    FH_MODEL.initialize_sess()
    FH_MODEL.load_model(FH_model_path)
    # Run the model (inference only; the histogram is not updated here)
    for image_file, truth_file, result_file in zip(image_file_list, truth_file_list, result_file_list):
        # Read the image (divide by 256, not 255, exactly as in training)
        cv_gray_image = cv2.cvtColor(_read_resized(image_file, cv2.INTER_CUBIC), cv2.COLOR_BGR2GRAY)
        cv_float_gray_image = cv_gray_image.astype('float') / 256.0
        cv2.imshow('cv_float_gray_image', cv_float_gray_image)
        cv2.waitKey(1)
        # Load the ground truth (displayed for visual comparison only)
        cv_gray_truth = cv2.cvtColor(_read_resized(truth_file, cv2.INTER_NEAREST), cv2.COLOR_BGR2GRAY)
        cv_float_gray_truth = 1.0 - cv_gray_truth.astype('float') / 255.0
        cv2.imshow('cv_float_gray_truth', cv_float_gray_truth)
        # Check the histogram
        raw_segmentation = FH_MODEL.histogram_checking(cv_float_gray_image)
        cv2.imshow('raw_segmentation', raw_segmentation)
        cv2.waitKey(10)
        # Record the performance of the model
        _write_segmentation(result_file, raw_segmentation)
    cv2.destroyAllWindows()

# 2. DeepLab Model

## 2.1. Load the model

In [None]:
from Utilities.deeplab import DeepLabModel,vis_segmentation,vis_segmentation_map_calculate

In [None]:
# Location of the DeepLab semantic-segmentation (SS) model archive.
model_root_path = './Our_Models'
# os.path.join: same string as a backslash path on Windows, valid on POSIX.
deeplab_model_path = os.path.join(model_root_path, 'SS_model', 'deeplab_model.tar.gz')
print('deeplab_model_path: \n', deeplab_model_path)

In [None]:
# Load the DeepLab semantic-segmentation (SS) model from the .tar.gz archive
# at deeplab_model_path; SS_MODEL.run(image) is used below to get
# (resized image, per-pixel scores, label map) for a PIL image.
SS_MODEL = DeepLabModel(deeplab_model_path)
print('model loaded successfully!')

In [None]:
def run_visualization_local(img_path):
    """Run the DeepLab model on a local image and visualize the result.

    Args:
        img_path: path of the image file to segment.

    Returns:
        The ``(resized_im, seg_logits, seg_map)`` tuple from ``SS_MODEL.run``,
        or ``None`` if the image cannot be opened (a message is printed).
        Returning the outputs lets callers reuse them instead of running
        inference a second time.
    """
    try:
        original_im = Image.open(img_path)
    except IOError:
        print('Cannot retrieve image. Please check path: ' + img_path)
        return None

    print('running deeplab on image %s...' % img_path)
    resized_im, seg_logits, seg_map = SS_MODEL.run(original_im)
    print('seg_logits.shape: ', seg_logits.shape)
    print('seg_map.shape: ', seg_map.shape)

    vis_segmentation(resized_im, seg_map)
    return resized_im, seg_logits, seg_map


img_path = 'in000534.jpg'
# Reuse the outputs of the visualization run rather than running DeepLab again.
demo_outputs = run_visualization_local(img_path)
if demo_outputs is not None:
    resized_im, seg_logits, seg_map = demo_outputs

# 3. Train the top model

In [None]:
from Utilities.synthesis_model_white_box_rectified_F_score import SynthesisModel,calculate_double_mask,single_feature_builder

## 3.1. Load the model

In [None]:
# Build the top (synthesis) model. In section 3.2 it is trained on features
# built by single_feature_builder from the FH segmentation and the normalized
# DeepLab scores.
TOP_Model = SynthesisModel()

In [None]:
# Collect the trainable variables from the top model's own graph
# (tf.trainable_variables reads the current default graph), then list them.
with TOP_Model.graph.as_default():
    trainable_variable_list = tf.trainable_variables()
for trainable_variable in trainable_variable_list:
    print(trainable_variable)

## 3.2. Train the top model on mini-batches of the training frames (200 per scene) from the selected categories

In [None]:
# Categories (first-level folders under dataset_root_path) whose scenes are
# used to train the top model.
care_catagory_list = [
    'baseline',
    'dynamicBackground',
    'intermittentObjectMotion',
    'badWeather',
    'shadow',
    'cameraJitter',
    'lowFramerate',
]

In [None]:
# Gather parallel lists of (ground truth, input image, FH result) files for
# the training frames of every scene whose category is in care_catagory_list.
# Files are listed in sorted order so that the seeded shuffle used during
# training yields the same batches on every machine.


def _list_files_sorted(directory):
    """Return the sorted names of the regular files directly inside *directory*.

    A missing directory yields an empty list.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


truth_file_list = []
image_file_list = []
result_file_list = []

for root, dirs, _ in os.walk(dataset_root_path):
    # Visit categories/scenes in a deterministic (alphabetical) order.
    dirs.sort()
    # Portable relative path components (relpath + os.sep, not a literal '\\').
    relative_path_hierarch_list = os.path.relpath(root, dataset_root_path).split(os.sep)
    if len(relative_path_hierarch_list) >= 2:
        # Nothing below scene level is processed here, so don't descend.
        dirs[:] = []
    if (len(relative_path_hierarch_list) != 2
            or relative_path_hierarch_list[0] not in care_catagory_list):
        continue

    Train_GTs_path = os.path.join(root, 'Train_GTs')
    Train_Inputs_path = os.path.join(root, 'Train_Inputs')
    FH_Training_Results_path = os.path.join(root, 'FH_Training_Results')

    # Get the training data; characters [2:8] of a ground-truth name select the
    # matching input frame and FH result written in section 1.4.
    for name in _list_files_sorted(Train_GTs_path):
        frame_id = name[2:8]
        truth_file_list.append(os.path.join(Train_GTs_path, name))
        image_file_list.append(os.path.join(Train_Inputs_path, 'in' + frame_id + '.jpg'))
        result_file_list.append(os.path.join(FH_Training_Results_path, 'result' + frame_id + '.png'))

In [None]:
# Sanity check: the three lists are appended in lockstep, so all counts
# should be equal.
for caption, file_list in (
    ('Number of truth files is:', truth_file_list),
    ('Number of image files is:', image_file_list),
    ('Number of result files is:', result_file_list),
):
    print(caption, len(file_list))

In [None]:
from sklearn.utils import shuffle

# Folder where the top model's logs and weights are written. Built from
# model_root_path ('./Our_Models') with os.path.join: on Windows this is the
# same './Our_Models\\FH\\model' string as before, and on POSIX a valid path.
TOP_model_path = os.path.join(model_root_path, 'FH', 'model')
TOP_Model.initialize_sess(log_path=TOP_model_path)

In [None]:
# Train the model
num_epochs = 10
batch_size = 50
num_batchs = len(image_file_list)//batch_size -1

temp_count = 0
for epoch in range(num_epochs):
    image_file_list_shuffled, truth_file_list_shuffled, result_file_list_shuffled = shuffle(image_file_list, 
                                                                                            truth_file_list, 
                                                                                            result_file_list, 
                                                                                            random_state=epoch)
    for batch_num in range(num_batchs):
        image_batch_file_list = image_file_list_shuffled[batch_num*batch_size:(1+batch_num)*batch_size]
        truth_batch_file_list = truth_file_list_shuffled[batch_num*batch_size:(1+batch_num)*batch_size]
        result_batch_file_list = result_file_list_shuffled[batch_num*batch_size:(1+batch_num)*batch_size]
        composit_feature_batch = []
        positive_mask_batch = []
        negative_mask_batch = []
        # Build the learning batch
        for image_file,truth_file,result_file in zip(image_batch_file_list,truth_batch_file_list,result_batch_file_list):
            # Read the image
            cv_BGR_image = cv2.resize(cv2.imread(image_file),(320, 240), interpolation = cv2.INTER_CUBIC)
            cv_gray_image = cv2.cvtColor(cv_BGR_image, cv2.COLOR_BGR2GRAY)
            cv_float_gray_image = cv_gray_image.astype('float')/256.0 # Avoid reaching the 22th bin of a histogram (not using ./255)
            cv2.imshow('cv_float_gray_image',cv_float_gray_image)
            cv2.waitKey(1)
            # Load the groundtruth
            cv_BGR_truth = cv2.resize(cv2.imread(truth_file),(320, 240), interpolation = cv2.INTER_NEAREST)
            cv_gray_truth = cv2.cvtColor(cv_BGR_truth, cv2.COLOR_BGR2GRAY)
            cv_float_gray_truth = 1.0-cv_gray_truth.astype('float')/255.0 
            cv2.imshow('cv_float_gray_truth',cv_float_gray_truth)
            # Load the result
            cv_BGR_result = cv2.imread(result_file)
            cv_gray_result = cv2.cvtColor(cv_BGR_result, cv2.COLOR_BGR2GRAY)
            raw_segmentation = cv_gray_result.astype('float')/255.0 
            cv2.imshow('raw_segmentation',raw_segmentation)
            cv2.waitKey(1)
            # Calculate SS result
            Image_image = Image.open(image_file)
            _, seg_logits, seg_map = SS_MODEL.run(Image_image)
            seg_logits_channel_sum = np.sum(seg_logits,axis=-1)
            seg_logits_channel_sum_tile = np.dstack([seg_logits_channel_sum for i in range(seg_logits.shape[-1])])
            seg_logits_normalized = seg_logits/seg_logits_channel_sum_tile
            seg_map_show = vis_segmentation_map_calculate(seg_map)
            cv2.imshow('seg_map_show',seg_map_show)
            cv2.waitKey(1)
            # Calculate double masks
            positive_mask,negative_mask = calculate_double_mask(cv_gray_truth)
            cv2.imshow('positive_mask',positive_mask)
            cv2.imshow('negative_mask',negative_mask)
            cv2.waitKey(1)
            composit_feature_batch.append(single_feature_builder(raw_segmentation,seg_logits_normalized))
            positive_mask_batch.append(positive_mask)
            negative_mask_batch.append(negative_mask)

        # Learning
        current_loss,synthesis_result = TOP_Model.train(composit_feature_batch,
                                                        positive_mask_batch,
                                                        negative_mask_batch,
                                                        step=temp_count)
        cv2.imshow('synthesis_result',synthesis_result[0])
        cv2.waitKey(1)

        temp_count += 1
        if temp_count%50==1:
            print('temp_count: ',temp_count,'current_loss:',current_loss)

cv2.destroyAllWindows()
TOP_Model.save_model(TOP_model_path)