In [None]:
import os
import numpy as np
from PIL import Image
import cv2

import tensorflow as tf

1. Create output folders for the semantic-segmentation (SS) logits, mirroring the CDnet category/video layout

In [None]:
# Dataset and output locations. Set the CDNET_ROOT / CDNET_SS_LOGITS_ROOT
# environment variables to override the defaults without editing the notebook.
# normpath strips any trailing separator so that later path splitting and
# joining is not thrown off by a stray '/'.
cdnet_root_path = os.path.normpath(
    os.environ.get('CDNET_ROOT', '/home/Datasets/CDnet2014/dataset'))
cdnet_ss_logits_root_path = os.path.normpath(
    os.environ.get('CDNET_SS_LOGITS_ROOT', '/home/Datasets/CDnet2014_SS_Logits'))

In [None]:
# Mirror the first level of the CDnet tree (the categories) under the logits
# root. Only the top-level directory listing is needed, so avoid walking the
# whole dataset. makedirs (unlike mkdir) also creates the logits root itself
# if it is missing; the isdir guard keeps the cell safe to re-run.
for category in sorted(os.listdir(cdnet_root_path)):
    if not os.path.isdir(os.path.join(cdnet_root_path, category)):
        continue
    level_1_dir_to_create = os.path.join(cdnet_ss_logits_root_path, category)
    if not os.path.isdir(level_1_dir_to_create):
        os.makedirs(level_1_dir_to_create)
    print(category)

In [None]:
# Mirror the second level of the CDnet tree (category/video) under the logits
# root; step 3 writes each video's logits into these folders.
for category in sorted(os.listdir(cdnet_root_path)):
    category_path = os.path.join(cdnet_root_path, category)
    if not os.path.isdir(category_path):
        continue
    for video in sorted(os.listdir(category_path)):
        if not os.path.isdir(os.path.join(category_path, video)):
            continue
        level_2_dir_to_create = os.path.join(cdnet_ss_logits_root_path, category, video)
        # makedirs also creates the category folder, so this cell works even
        # if the level-1 cell was skipped.
        if not os.path.isdir(level_2_dir_to_create):
            os.makedirs(level_2_dir_to_create)

2. Define and load the DeepLab semantic-segmentation (SS) model

In [None]:
class DeepLabModel(object):
    """Class to load a frozen DeepLab graph and run inference."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # NOTE(review): assumed to be the per-class score map of this particular
    # exported graph (it is saved as "logits" downstream); confirm the tensor
    # name by inspecting the graph, e.g. with write_graph().
    LOGITS_TENSOR_NAME = 'Mean:0'

    def __init__(self, frozen_inference_graph_path):
        """Creates and loads a pretrained DeepLab model.

        Args:
            frozen_inference_graph_path: Path to the serialized frozen GraphDef.

        Raises:
            RuntimeError: If no file exists at `frozen_inference_graph_path`.
        """
        # A freshly constructed GraphDef is never None, so a missing graph
        # must be detected on the path itself, before trying to read it.
        if not tf.gfile.Exists(frozen_inference_graph_path):
            raise RuntimeError(
                'Cannot find inference graph: {}'.format(frozen_inference_graph_path))

        self.graph = tf.Graph()

        graph_def = tf.GraphDef()
        # Read from frozen graph
        with tf.gfile.GFile(frozen_inference_graph_path, 'rb') as f:
            graph_def.ParseFromString(f.read())

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Runs inference on a single image.

        Args:
            image: A PIL.Image object, raw input image (converted to RGB here).

        Returns:
            image_array: The RGB input as a numpy array (not resized).
            seg_logits: Output of LOGITS_TENSOR_NAME for the image, cropped to
                (height, width, channels).
            seg_map: Output of OUTPUT_TENSOR_NAME cropped to (height, width).
        """
        width, height = image.size
        image_array = np.asarray(image.convert('RGB'))
        batch_seg_logits, batch_seg_map = self.sess.run(
            [self.LOGITS_TENSOR_NAME, self.OUTPUT_TENSOR_NAME],
            feed_dict={self.INPUT_TENSOR_NAME: [image_array]})
        # Drop the batch dimension and crop anything beyond the input size.
        seg_logits = batch_seg_logits[0]
        seg_map = batch_seg_map[0]
        return image_array, seg_logits[:height, :width, :], seg_map[:height, :width]

    def write_graph(self, log_path):
        """Writes the model graph to `log_path` as a TensorBoard event file."""
        writer = tf.summary.FileWriter(log_path, self.sess.graph)
        # close() flushes the pending graph event to disk; without it the
        # event file may be left empty.
        writer.close()

    def close(self):
        """Releases the TensorFlow session held by this model."""
        self.sess.close()

In [None]:
# Load the exported DeepLab graph once; MODEL is reused for every frame below.
# NOTE(review): frozen TF graphs are conventionally saved with a `.pb`
# extension; confirm that `.pd` matches the actual file on disk.
frozen_inference_graph_path = './Models/frozen_inference_graph.pd'

print('loading model, this might take a while...')
MODEL = DeepLabModel(frozen_inference_graph_path=frozen_inference_graph_path)
print('model loaded successfully!')

3. Compute SS logits and segmentation maps for every CDnet input frame

In [None]:
def create_cdnet_label_colormap():
    """Build the label-to-RGB lookup table used to visualise CDnet segmentations.

    Each label's bits are dealt out round-robin to the R, G and B channels,
    three bits per round, filling every channel from its most significant bit
    downwards (the same bit-interleaving scheme as the PASCAL VOC palette).

    Returns:
        A (256, 3) integer numpy array whose row ``k`` is the colour of label ``k``.
    """
    labels = np.arange(256, dtype=int)
    palette = np.zeros((256, 3), dtype=int)

    for bit in range(7, -1, -1):
        for ch in range(3):
            palette[:, ch] |= ((labels >> ch) & 1) << bit
        labels = labels >> 3

    return palette

In [None]:
def label_to_color_image(label, colormap=None):
    """Adds color defined by the dataset colormap to the label.

    Args:
        label: A 2D array (or nested sequence) of non-negative integers storing
            the segmentation label of each pixel.
        colormap: Optional (N, 3) lookup table. Defaults to
            ``create_cdnet_label_colormap()``; pass a precomputed table to
            avoid rebuilding it on every call.

    Returns:
        result: An array of shape ``label.shape + (3,)`` whose element
            ``[i, j]`` is ``colormap[label[i, j]]`` (same dtype as the colormap).

    Raises:
        ValueError: If label is not of rank 2, or contains a value that is
            negative or not smaller than the number of colormap entries.
    """
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    if colormap is None:
        colormap = create_cdnet_label_colormap()
    colormap = np.asarray(colormap)

    # np.min/np.max raise on empty arrays, so only range-check non-empty labels.
    if label.size:
        # Negative labels would silently wrap around to the end of the
        # colormap under numpy indexing, so reject them explicitly.
        if np.min(label) < 0:
            raise ValueError('label value must be non-negative.')
        if np.max(label) >= len(colormap):
            raise ValueError('label value too large.')

    return colormap[label]

In [None]:
def save_logits(seg_logits, save_dir, image_name):
    """Save each channel of a logits map as an 8-bit PNG.

    Writes ``<save_dir>/logit_<image_name>_<channel>.png`` for every channel
    along the last axis of ``seg_logits``.

    Args:
        seg_logits: Array indexed as (height, width, channel). Values are
            scaled by 255 before quantisation, so they are presumably expected
            in [0, 1] -- TODO confirm for the model's logits tensor.
        save_dir: Existing directory to write the PNGs into.
        image_name: Frame identifier embedded in each file name.

    Raises:
        IOError: If OpenCV fails to write one of the files.
    """
    for channel in range(seg_logits.shape[-1]):
        # Clip before casting: out-of-range floats would otherwise wrap around
        # (or be undefined) in the float -> uint8 conversion.
        quantised = np.clip(seg_logits[:, :, channel] * 255, 0, 255).astype(np.uint8)
        out_path = os.path.join(save_dir, 'logit_' + image_name + '_{}.png'.format(channel))
        # cv2.imwrite reports failure (e.g. missing directory) only via its
        # return value, so surface it instead of silently losing output.
        if not cv2.imwrite(out_path, quantised, [cv2.IMWRITE_PNG_COMPRESSION, 9]):
            raise IOError('Failed to write {}'.format(out_path))

In [None]:
# Run DeepLab on every input frame of every CDnet video. For each frame this
# writes one PNG per logit channel (via save_logits) plus a colour-coded
# segmentation map `seg_<index>.png` into the matching logits folder.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

for category in sorted(os.listdir(cdnet_root_path)):
    category_path = os.path.join(cdnet_root_path, category)
    if not os.path.isdir(category_path):
        continue
    for video in sorted(os.listdir(category_path)):
        dir_storing_images = os.path.join(category_path, video, 'input')
        # Skip entries that are not video folders with an input/ subfolder.
        if not os.path.isdir(dir_storing_images):
            continue
        dir_to_store_logits = os.path.join(cdnet_ss_logits_root_path, category, video)
        # Create the output folder here as well so this cell does not depend
        # on step 1 having been run (cv2.imwrite fails silently otherwise).
        if not os.path.isdir(dir_to_store_logits):
            os.makedirs(dir_to_store_logits)

        print(dir_to_store_logits)
        for file_name in sorted(os.listdir(dir_storing_images)):
            image_file = os.path.join(dir_storing_images, file_name)
            # Only top-level image files: ignore sub-folders and stray files
            # (e.g. Thumbs.db) that PIL cannot open.
            if not (file_name.lower().endswith(image_extensions) and os.path.isfile(image_file)):
                continue
            # CDnet frames are named like 'in000001.jpg'; keep the trailing
            # 6-digit frame index as the output identifier.
            file_index = os.path.splitext(file_name)[0][-6:]
            # `with` releases the file handle once inference is done.
            with Image.open(image_file) as image:
                _, seg_logits, seg_map = MODEL.run(image)
            seg_image = label_to_color_image(seg_map).astype(np.uint8)
            save_logits(seg_logits, dir_to_store_logits, file_index)
            seg_path = os.path.join(dir_to_store_logits, 'seg_' + file_index + '.png')
            if not cv2.imwrite(seg_path, cv2.cvtColor(seg_image, cv2.COLOR_RGB2BGR),
                               [cv2.IMWRITE_PNG_COMPRESSION, 9]):
                raise IOError('Failed to write {}'.format(seg_path))
