# 1. Define the Background Segmentation Model

In [1]:
import os

import cv2
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

  from ._conv import register_converters as _register_converters


In [2]:
# Prerequisites
# Frame geometry (rows, columns) that every model tensor is sized for.
image_height, image_width = 240, 320
# 20 equal-width intervals over [0, 1] are delimited by 21 bin positions.
num_bin = 21
bin_width = 1.0 / (num_bin - 1)
# Exclusive upper bound of the frame indices consumed by the training loop.
num_training_frame = 20

In [3]:
class FuzzyHistogramModel():
    """Per-pixel fuzzy-histogram background model on a private TF1 graph/session.

    Every pixel owns a ``num_bin``-bin histogram over gray levels in [0, 1].
    An incoming gray value lies between two neighbouring bins and is shared
    between them with linear (fuzzy) membership weights.

    * ``histogram_checking`` reads the histogram at each pixel's gray value and
      stores the interpolated result in ``raw_segmentation``.
    * ``update_weight_calculation`` derives a per-pixel ``update_mask``.
    * ``update_histogram`` adds the masked memberships to the histogram and
      rescales every pixel's histogram so that its largest bin equals 1.

    All TensorFlow ops are created once in ``__init__``; the public methods only
    run them, so the graph does not grow with the number of processed frames.
    """

    def __init__(self):
        """Creates the graph, the session, the state variables and all ops."""
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            self.model_scope = tf.variable_scope('model', reuse=tf.AUTO_REUSE)
            self.input_scope = tf.variable_scope('input', reuse=tf.AUTO_REUSE)
            self.output_scope = tf.variable_scope('output', reuse=tf.AUTO_REUSE)
            with self.model_scope:
                # The background model itself: one histogram per pixel, starting flat at 1.0
                self.fuzzy_histogram = tf.get_variable('fuzzy_histogram',
                                                       shape=[image_height, image_width, num_bin],
                                                       initializer=tf.constant_initializer(1.0),
                                                       trainable=False)
                # Per-pixel update strength; 1.0 until update_weight_calculation overwrites it
                self.update_mask = tf.get_variable('update_mask',
                                                   shape=[image_height, image_width],
                                                   initializer=tf.constant_initializer(1.0),
                                                   trainable=False)
                # Intermediate variables: written by histogram_checking, read by update_histogram
                self.data_position_in_histogram = tf.get_variable('data_position_in_histogram',
                                                                  shape=[image_height, image_width],
                                                                  trainable=False)
                self.pre_index = tf.get_variable('previous_index_in_histogram',
                                                 shape=[image_height, image_width],
                                                 dtype=tf.int32,
                                                 trainable=False)
                self.pre_weight = tf.get_variable('previous_weight_in_histogram',
                                                  shape=[image_height, image_width],
                                                  trainable=False)
                self.next_index = tf.get_variable('next_index_in_histogram',
                                                  shape=[image_height, image_width],
                                                  dtype=tf.int32,
                                                  trainable=False)
                self.next_weight = tf.get_variable('next_weight_in_histogram',
                                                   shape=[image_height, image_width],
                                                   trainable=False)
                self.max_bin_value = tf.get_variable('max_bin_value',
                                                     shape=[image_height, image_width],
                                                     trainable=False)
                # Intermediate constant: the (row, col) coordinate of every pixel, used to
                # address one histogram bin per pixel with tf.gather_nd
                self.indexing_constant = tf.constant([[[r, c] for c in range(image_width)] for r in range(image_height)],
                                                     name='indexing_constant')
            with self.input_scope:
                # Gray image to check, float in [0, 1]
                self.input_image = tf.placeholder(tf.float32,
                                                  shape=[image_height, image_width],
                                                  name='image')
                # Synthesis result that drives the update mask
                self.synthesis_result = tf.placeholder(tf.float32,
                                                       shape=[image_height, image_width],
                                                       name='synthesis_result')
            with self.output_scope:
                # Raw output: histogram value found at each pixel's gray level
                self.raw_segmentation = tf.get_variable('compatibility',
                                                        shape=[image_height, image_width],
                                                        trainable=False)
            self._build_ops()

    def _bin_lookup_indices(self, bin_index):
        """Returns [height, width, 3] (row, col, bin) indices for tf.gather_nd."""
        return tf.concat([self.indexing_constant, tf.expand_dims(bin_index, -1)], -1)

    def _build_ops(self):
        """Creates, exactly once, every op that the public methods run.

        Must be called while ``self.graph`` is the default graph.
        """
        last_bin = num_bin - 1

        # ---- ops for histogram_checking (run in list order) ----
        # Clamp so that an input of exactly 1.0 (or slightly out of range) can
        # never address a bin outside [0, last_bin].
        position = tf.clip_by_value(self.input_image / bin_width, 0.0, float(last_bin))
        lower_bin_value = tf.gather_nd(self.fuzzy_histogram,
                                       self._bin_lookup_indices(self.pre_index))
        upper_bin_value = tf.gather_nd(self.fuzzy_histogram,
                                       self._bin_lookup_indices(self.next_index))
        self._checking_ops = [
            self.data_position_in_histogram.assign(position),
            self.pre_index.assign(tf.cast(tf.floor(self.data_position_in_histogram), tf.int32)),
            self.next_index.assign(tf.minimum(self.pre_index + 1, last_bin)),
            # Linear membership: the upper bin receives the fractional part of the
            # position and the lower bin the remainder, so a value lying exactly
            # on a bin belongs entirely to that bin.
            self.next_weight.assign(self.data_position_in_histogram
                                    - tf.cast(self.pre_index, tf.float32)),
            self.pre_weight.assign(1.0 - self.next_weight),
            self.raw_segmentation.assign(lower_bin_value * self.pre_weight
                                         + upper_bin_value * self.next_weight),
        ]
        self._raw_segmentation_value = self.raw_segmentation.value()

        # ---- op for update_weight_calculation ----
        # a*y^5/(x+b) with a = 0.0792, b = 0.1585
        self._update_mask_op = self.update_mask.assign(
            0.0792 * self.synthesis_result ** 5 / (self.raw_segmentation + 0.1585))

        # ---- ops for update_histogram (run in list order) ----
        # one_hot places each pixel's masked membership in its own bin and zeros elsewhere.
        lower_membership = (tf.one_hot(self.pre_index, num_bin)
                            * tf.expand_dims(self.pre_weight * self.update_mask, -1))
        upper_membership = (tf.one_hot(self.next_index, num_bin)
                            * tf.expand_dims(self.next_weight * self.update_mask, -1))
        self._update_ops = [
            self.fuzzy_histogram.assign_add(lower_membership),
            self.fuzzy_histogram.assign_add(upper_membership),
            self.max_bin_value.assign(tf.reduce_max(self.fuzzy_histogram, axis=-1)),
            # Rescale every pixel's histogram so that its largest bin equals 1
            self.fuzzy_histogram.assign(self.fuzzy_histogram
                                        / tf.expand_dims(self.max_bin_value, -1)),
        ]

    # Initialization
    def initialize_sess(self):
        """Runs the global variable initializer in the model's session."""
        with self.graph.as_default():
            self.sess.run(tf.global_variables_initializer())

    # Define histogram checking
    def histogram_checking(self, gray_float_image):
        '''
        Looks up every pixel of the image in that pixel's histogram.

        Inputs:
            gray_float_image: [height, width], float (from 0.0 to 1.0)
        Outputs:
            None. The interpolated histogram value is stored in the
            raw_segmentation variable (read it with get_raw_segmentation); the
            bin indices and weights are kept for a following update_histogram.
        '''
        feed = {self.input_image: gray_float_image}
        # Only the first op reads the placeholder; the others read the variables
        # written by the ops before them, hence the strictly sequential runs.
        self.sess.run(self._checking_ops[0], feed_dict=feed)
        for op in self._checking_ops[1:]:
            self.sess.run(op)

    def get_raw_segmentation(self):
        """Returns the current raw_segmentation as a [height, width] float array."""
        return self.sess.run(self._raw_segmentation_value)

    def fake_synthesis_generate(self):
        """Returns a median-blurred copy of the raw segmentation as a float array.

        The raw segmentation is scaled by 255, cast to uint8, median-filtered
        with a 9x9 kernel and divided by 256.
        """
        # The following can be replaced by avr_pooling
        raw_segmentation_scaled_to_255 = 255.0 * self.get_raw_segmentation()
        raw_segmentation_uint8 = raw_segmentation_scaled_to_255.astype('uint8')
        blurred_mask_uint8 = cv2.medianBlur(raw_segmentation_uint8, 9)
        blurred_mask_float = blurred_mask_uint8.astype('float') / 256
        return blurred_mask_float

    def update_weight_calculation(self, synthesis_map):
        """Sets update_mask to a*y^5/(x+b) with y = synthesis_map, x = raw_segmentation.

        Inputs:
            synthesis_map: [height, width], float
        """
        self.sess.run(self._update_mask_op,
                      feed_dict={self.synthesis_result: synthesis_map})

    def update_histogram(self):
        """Adds the last checked frame to the histogram and renormalises it.

        Uses the bin indices and weights stored by the most recent
        histogram_checking call, each scaled by update_mask.
        """
        for op in self._update_ops:
            self.sess.run(op)

    def write_graph(self,log_path):
        """Writes the graph to ``log_path`` as a TensorFlow summary event file."""
        writer = tf.summary.FileWriter(log_path, self.sess.graph)
        # Close the writer so the event file is flushed to disk
        writer.close()

    def __del__(self):
        # 'sess' is missing if __init__ raised before the session was created
        sess = getattr(self, 'sess', None)
        if sess is not None:
            sess.close()

## 1.1. Test the fuzzy histogram based background model

In [4]:
# Folder holding the input frames (in000001.jpg, ...). Set the FH_IMAGE_FOLDER
# environment variable to point at another copy of the dataset without editing
# the notebook. Later cells build file names by plain string concatenation, so
# the value must end with a path separator.
_image_folder_override = os.environ.get('FH_IMAGE_FOLDER')
if _image_folder_override:
    # Joining with '' appends a trailing separator when one is missing
    image_folder = os.path.join(_image_folder_override, '')
else:
    image_folder = 'F:\\dataset2014\\dataset\\baseline\\highway\\input\\'

In [5]:
# Build the model: the constructor creates its own tf.Graph and tf.Session and
# declares all variables and placeholders (the variables are not initialized yet).
FH_MODEL = FuzzyHistogramModel()

In [6]:
# Inspect the freshly built graph: every global variable with its shape,
# the pixel-coordinate constant, and the placeholder ops.
fh_graph = FH_MODEL.graph
for variable in fh_graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(variable.name, variable.shape)
print(FH_MODEL.indexing_constant)
for operation in fh_graph.get_operations():
    if operation.type != "Placeholder":
        continue
    # NOTE(review): this prints the bound method object itself, not any
    # attribute value; call get_attr with an attribute name if values are wanted.
    print(operation.get_attr)

model/fuzzy_histogram:0 (240, 320, 21)
model/update_mask:0 (240, 320)
model/data_position_in_histogram:0 (240, 320)
model/previous_index_in_histogram:0 (240, 320)
model/previous_weight_in_histogram:0 (240, 320)
model/next_index_in_histogram:0 (240, 320)
model/next_weight_in_histogram:0 (240, 320)
model/max_bin_value:0 (240, 320)
output/compatibility:0 (240, 320)
Tensor("model/indexing_constant:0", shape=(240, 320, 2), dtype=int32)
<bound method Operation.get_attr of <tf.Operation 'input/image' type=Placeholder>>
<bound method Operation.get_attr of <tf.Operation 'input/synthesis_result' type=Placeholder>>


In [7]:
# Run the global variable initializer in the model's session so the variables
# hold their initial values before the loops below use them.
FH_MODEL.initialize_sess()

Training the model with the first 19 frames (`in000001`–`in000019`, i.e. frame indices 1 to `num_training_frame - 1`):

In [8]:
# Training pass over frames 1 .. num_training_frame-1.
for i_train in range(1, num_training_frame):
    frame_path = image_folder + 'in{0:06d}'.format(i_train) + '.jpg'
    cv_BGR_image = cv2.imread(frame_path)
    if cv_BGR_image is None:
        # cv2.imread reports a missing or unreadable file by returning None;
        # fail here with the path instead of with an opaque error in cvtColor
        raise IOError('Cannot read frame: ' + frame_path)
    cv_gray_image = cv2.cvtColor(cv_BGR_image, cv2.COLOR_BGR2GRAY)
    # Divide by 256 (not 255) so the scaled value stays below 1.0 and the upper
    # neighbouring bin index never goes past the last histogram bin
    cv_float_gray_image = cv_gray_image.astype('float')/256.0
    cv2.imshow('cv_float_gray_image',cv_float_gray_image)
    cv2.waitKey(1)
    FH_MODEL.histogram_checking(cv_float_gray_image)
    raw_segmentation = FH_MODEL.get_raw_segmentation()
    cv2.imshow('raw_segmentation',raw_segmentation)
    cv2.waitKey(1)
    fake_synthesis = FH_MODEL.fake_synthesis_generate()
    cv2.imshow('synthesis_image',fake_synthesis)
    cv2.waitKey(50)
    # update_weight_calculation is not called during training, so update_mask
    # keeps its initial value of 1.0 and every pixel is learned with full weight
    FH_MODEL.update_histogram()

Test the model on the next 80 consecutive frames (`in000020`–`in000099`):

In [9]:
# Test pass over frames num_training_frame .. 99; the model keeps adapting, but
# each update is weighted per pixel by the mask derived from the synthesis image.
for i_test in range(num_training_frame, 100):
    frame_path = image_folder + 'in{0:06d}'.format(i_test) + '.jpg'
    cv_BGR_image = cv2.imread(frame_path)
    if cv_BGR_image is None:
        # cv2.imread reports a missing or unreadable file by returning None;
        # fail here with the path instead of with an opaque error in cvtColor
        raise IOError('Cannot read frame: ' + frame_path)
    cv_gray_image = cv2.cvtColor(cv_BGR_image, cv2.COLOR_BGR2GRAY)
    # Divide by 256 (not 255) so the scaled value stays below 1.0 and the upper
    # neighbouring bin index never goes past the last histogram bin
    cv_float_gray_image = cv_gray_image.astype('float')/256.0
    cv2.imshow('cv_float_gray_image',cv_float_gray_image)
    cv2.waitKey(1)
    FH_MODEL.histogram_checking(cv_float_gray_image)
    raw_segmentation = FH_MODEL.get_raw_segmentation()
    cv2.imshow('raw_segmentation',raw_segmentation)
    cv2.waitKey(1)
    fake_synthesis = FH_MODEL.fake_synthesis_generate()
    cv2.imshow('synthesis_image',fake_synthesis)
    cv2.waitKey(200)
    FH_MODEL.update_weight_calculation(fake_synthesis)
    FH_MODEL.update_histogram()

In [11]:
# Write the model's graph to the 'Logs' directory with tf.summary.FileWriter.
FH_MODEL.write_graph('Logs')

In [12]:
# Close the OpenCV preview windows opened by the cv2.imshow calls above.
cv2.destroyAllWindows()