In [1]:
from distutils.version import LooseVersion
import glob
from glob import glob
import numpy as np
import os.path
import re
import random
import scipy.misc
import shutil
import tensorflow as tf
from tqdm import tqdm
import time
from urllib.request import urlretrieve
import warnings
import zipfile

In [2]:
# Filesystem layout. Every dataset path is derived from DATA_DIRECTORY so that
# relocating the data only requires changing a single constant.
DATA_DIRECTORY = './data'
RUNS_DIRECTORY = './runs'
TRAINING_DATA_DIRECTORY = os.path.join(DATA_DIRECTORY, 'data_road', 'training')
# Count the files the batch generator actually iterates over (image_2/*.png)
# so the progress denominator printed during training matches the data fed.
NUMBER_OF_IMAGES = len(glob(os.path.join(TRAINING_DATA_DIRECTORY, 'image_2', '*.png')))
VGG_PATH = os.path.join(DATA_DIRECTORY, 'vgg')

In [3]:
# Segmentation setup: labels are one-hot over two channels
# (background / not background), and every image is resized to IMAGE_SHAPE,
# which is indexed as (image_shape[0], image_shape[1]) when reshaping outputs.
NUMBER_OF_CLASSES = 2
IMAGE_SHAPE = (160, 576)

# Training schedule.
EPOCHS = 1
BATCH_SIZE = 1

# Optimiser settings. NOTE: DROPOUT is fed to the `keep_prob` tensor during
# training, i.e. it is the probability of *keeping* a unit, not of dropping it.
LEARNING_RATE = 1e-4
DROPOUT = 0.75

In [4]:
# Graph inputs fed on every training step.
# `correct_label` carries one-hot masks: (batch, rows, cols, classes).
correct_label = tf.placeholder(
    dtype=tf.float32,
    shape=[None, IMAGE_SHAPE[0], IMAGE_SHAPE[1], NUMBER_OF_CLASSES],
)
# Scalar learning rate, supplied through the feed dict.
learning_rate = tf.placeholder(dtype=tf.float32)
# NOTE(review): the training cell rebinds `keep_prob` to the tensor returned
# by load_vgg, so this placeholder is not the one that ends up being fed.
keep_prob = tf.placeholder(dtype=tf.float32)

In [5]:
def generate_batch_function(data_folder, image_shape):
    """Create a batch-generator factory for a training folder.

    Args:
        data_folder: Directory containing `image_2/*.png` input images and
            `gt_image_2/*_road_*.png` ground-truth images.
        image_shape: Target size handed to `scipy.misc.imresize` for both the
            input image and its ground truth.

    Returns:
        A function `get_batches_function(batch_size)` that yields
        `(images, gt_images)` NumPy arrays, one pair per batch.
    """
    def get_batches_function(batch_size):
        """Yield shuffled batches of (images, one-hot label masks).

        The file list is re-globbed and re-shuffled on every call, so each
        call (typically one per epoch) visits the images in a new order.

        Raises:
            KeyError: if an input image has no matching ground-truth file.
        """
        image_paths = glob(os.path.join(data_folder, 'image_2', '*.png'))
        # Map "<prefix>_<id>.png" -> ground-truth path by stripping the
        # "_road_"/"_lane_" infix from the ground-truth file name.
        label_paths = {
            re.sub(r'_(lane|road)_', '_', os.path.basename(path)): path
            for path in glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))}
        background_color = np.array([255, 0, 0])
        random.shuffle(image_paths)

        for start in range(0, len(image_paths), batch_size):
            images = []
            gt_images = []

            for image_file in image_paths[start:start + batch_size]:
                gt_image_file = label_paths[os.path.basename(image_file)]

                image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape)
                # Labels are colour-coded classes compared for *exact*
                # equality below. Nearest-neighbour resampling keeps every
                # pixel an original colour; the default bilinear filter would
                # blend colours along class boundaries and mislabel them.
                gt_image = scipy.misc.imresize(
                    scipy.misc.imread(gt_image_file), image_shape, interp='nearest')

                # Channel 0: pixel is background; channel 1: pixel is not.
                gt_bg = np.all(gt_image == background_color, axis=2)
                gt_bg = gt_bg[..., np.newaxis]
                gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2)

                images.append(image)
                gt_images.append(gt_image)

            yield np.array(images), np.array(gt_images)
    return get_batches_function

In [6]:
def generate_test_output(sess, logits, keep_prob, image_pl, data_folder, image_shape):
    """Yield test images with the predicted segmentation painted on top.

    Args:
        sess: Session used to evaluate the network.
        logits: Logits tensor; after softmax its column 1 is thresholded.
        keep_prob: Keep-probability tensor, fed 1.0 for inference.
        image_pl: Input-image tensor, fed a single-image batch.
        data_folder: Directory containing `image_2/*.png` test images.
        image_shape: Target size handed to `scipy.misc.imresize`.

    Yields:
        `(file_name, image_array)` for every PNG found in
        `<data_folder>/image_2`.
    """
    # Build the softmax op once. Creating it inside the loop would add a new
    # node to the graph for every test image.
    softmax = tf.nn.softmax(logits)
    rows, cols = image_shape[0], image_shape[1]
    # RGBA colour painted over pixels classified as class 1.
    overlay_rgba = np.array([[0, 255, 0, 127]])

    for image_file in glob(os.path.join(data_folder, 'image_2', '*.png')):
        image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape)

        probabilities = sess.run(softmax, {keep_prob: 1.0, image_pl: [image]})
        # Column 1 is the "not background" channel produced by the batch
        # generator's one-hot encoding.
        class1_prob = probabilities[:, 1].reshape(rows, cols)

        segmentation = (class1_prob > 0.5).reshape(rows, cols, 1)

        # (rows, cols, 1) x (1, 4) -> an RGBA mask that is zero wherever the
        # pixel was not selected.
        mask = scipy.misc.toimage(np.dot(segmentation, overlay_rgba), mode='RGBA')

        street_im = scipy.misc.toimage(image)
        street_im.paste(mask, box=None, mask=mask)

        yield os.path.basename(image_file), np.array(street_im)

In [7]:
def save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image):
    """Run inference on the test set and write the annotated images to disk.

    Images go to `<runs_dir>/epochs <EPOCHS>`; an existing directory of that
    name is deleted first so the folder only holds the current run's output.
    Test images are read from `<data_dir>/data_road/testing`.
    """
    target_dir = os.path.join(runs_dir, 'epochs {0}'.format(EPOCHS))

    # Always start from an empty output directory.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)

    print('Training Finished. Saving test images to: {}'.format(target_dir))

    test_folder = os.path.join(data_dir, 'data_road/testing')
    annotated = generate_test_output(sess, logits, keep_prob, input_image, test_folder, image_shape)

    for file_name, picture in annotated:
        scipy.misc.imsave(os.path.join(target_dir, file_name), picture)

In [8]:
def load_vgg(sess, vgg_path):
    """Load the saved 'vgg16' model and return the tensors the decoder needs.

    Args:
        sess: Session the saved model is loaded with.
        vgg_path: Directory of the saved model.

    Returns:
        Tuple `(image_input, keep_prob, layer3_out, layer4_out, layer7_out)`
        looked up by name in the default graph.
    """
    tf.saved_model.loader.load(sess, ['vgg16'], vgg_path)

    default_graph = tf.get_default_graph()
    # Order matters: callers unpack the result positionally.
    wanted = (
        'image_input:0',
        'keep_prob:0',
        'layer3_out:0',
        'layer4_out:0',
        'layer7_out:0',
    )
    return tuple(default_graph.get_tensor_by_name(tensor_name) for tensor_name in wanted)

In [9]:
def _score_1x1(inputs, num_classes, name):
    """Project a feature map to `num_classes` channels with a 1x1 convolution."""
    return tf.layers.conv2d(inputs = inputs,
                            filters = num_classes,
                            kernel_size = (1, 1),
                            strides = (1, 1),
                            name = name)


def _upsample(inputs, num_classes, kernel, stride, name):
    """Upsample with a 'same'-padded transposed convolution of the given stride."""
    return tf.layers.conv2d_transpose(inputs = inputs,
                                      filters = num_classes,
                                      kernel_size = (kernel, kernel),
                                      strides = (stride, stride),
                                      padding = 'same',
                                      name = name)


def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes = NUMBER_OF_CLASSES):
    """Build the decoder on top of three encoder feature maps.

    Each encoder output is reduced to `num_classes` channels by a 1x1
    convolution. The layer-7 scores are upsampled x2 and added to the layer-4
    scores, that sum is upsampled x2 and added to the layer-3 scores, and the
    result is upsampled x8 to produce the output.

    Args:
        vgg_layer3_out: Encoder tensor added at the second skip connection.
        vgg_layer4_out: Encoder tensor added at the first skip connection.
        vgg_layer7_out: Deepest encoder tensor; start of the decoder.
        num_classes: Number of output channels of every layer built here.

    Returns:
        The tensor produced by the final transposed convolution.
    """
    # `num_classes` (not the module constant) drives every layer, so the
    # argument a caller passes is actually honoured.
    layer3x = _score_1x1(vgg_layer3_out, num_classes, 'layer3conv1x1')
    layer4x = _score_1x1(vgg_layer4_out, num_classes, 'layer4conv1x1')
    layer7x = _score_1x1(vgg_layer7_out, num_classes, 'layer7conv1x1')

    # Each tf.add requires the upsampled tensor and the skip tensor to have
    # compatible shapes.
    decoderlayer1 = _upsample(layer7x, num_classes, 4, 2, 'decoderlayer1')
    decoderlayer2 = tf.add(decoderlayer1, layer4x, name = 'decoderlayer2')

    decoderlayer3 = _upsample(decoderlayer2, num_classes, 4, 2, 'decoderlayer3')
    decoderlayer4 = tf.add(decoderlayer3, layer3x, name = 'decoderlayer4')

    decoderlayer_output = _upsample(decoderlayer4, num_classes, 16, 8, 'decoderlayer_output')

    return decoderlayer_output

In [10]:
def optimize(nn_last_layer, correct_label, learning_rate, num_classes = NUMBER_OF_CLASSES):
    """Attach the loss and the Adam training op to the network output.

    Args:
        nn_last_layer: Output tensor of the network.
        correct_label: Label tensor with the same layout as `nn_last_layer`.
        learning_rate: Learning rate handed to the Adam optimizer.
        num_classes: Size of the last axis after flattening.

    Returns:
        Tuple `(logits, train_op, loss)` where `logits` is the network output
        reshaped to `(-1, num_classes)` and `loss` is the mean softmax
        cross-entropy.
    """
    flat_shape = (-1, num_classes)

    # Flatten both tensors so that each row is one class-score vector.
    logits = tf.reshape(nn_last_layer, flat_shape)
    flat_labels = tf.reshape(correct_label, flat_shape)

    per_row_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels = flat_labels, logits = logits)
    mean_loss = tf.reduce_mean(per_row_loss)

    adam = tf.train.AdamOptimizer(learning_rate)
    train_op = adam.minimize(mean_loss)

    return logits, train_op, mean_loss

In [11]:
def train_nn(sess, epochs, batch_size, get_batches_function, train_optimizer, cross_entropy_loss, input_image, correct_label, keep_prob, learning_rate):
    """Run the training loop and print per-iteration and per-epoch losses.

    Args:
        sess: Session used to run the training op.
        epochs: Number of passes over the data.
        batch_size: Batch size forwarded to `get_batches_function`.
        get_batches_function: Callable taking a batch size and yielding
            `(images, labels)` pairs.
        train_optimizer: Training op to run on every batch.
        cross_entropy_loss: Loss tensor fetched alongside the training op.
        input_image: Tensor fed with the batch images.
        correct_label: Tensor fed with the batch labels.
        keep_prob: Tensor fed with the module-level `DROPOUT` value.
        learning_rate: Tensor fed with the module-level `LEARNING_RATE` value.

    Returns:
        None. Progress is reported through `print`.
    """
    start = time.time()

    # Expected number of iterations per epoch (ceil division), so the
    # progress denominator stays correct for batch sizes other than 1.
    if batch_size > 0:
        batches_per_epoch = -(-NUMBER_OF_IMAGES // batch_size)
    else:
        batches_per_epoch = NUMBER_OF_IMAGES

    # Use the `epochs` / `batch_size` arguments rather than the module
    # constants, so callers can train with values other than the defaults.
    print('epoch: 0', '/', epochs, 'training loss: N/A')

    for epoch in range(epochs):
        losses = []
        for i, (images, labels) in enumerate(get_batches_function(batch_size), start = 1):
            feed = {input_image: images,
                    correct_label: labels,
                    keep_prob: DROPOUT,
                    learning_rate: LEARNING_RATE}

            _, partial_loss = sess.run([train_optimizer, cross_entropy_loss], feed_dict = feed)

            print('-> iteration: ', i, '/', batches_per_epoch, 'partial loss: ', partial_loss)
            losses.append(partial_loss)

        # An epoch that produced no batches has no meaningful mean loss.
        training_loss = sum(losses) / len(losses) if losses else float('nan')

        print('epoch: ', epoch + 1, '/', epochs, 'training loss: ', training_loss)

    # Measured after the loop so it is defined even when `epochs` is 0.
    training_time = time.time() - start
    print('training time: ', training_time)

In [12]:
# Bind the training folder and target image size once; the resulting callable
# only needs a batch size to start yielding (images, labels) batches.
get_batches_function = generate_batch_function(
    data_folder=TRAINING_DATA_DIRECTORY,
    image_shape=IMAGE_SHAPE,
)

In [13]:
# End-to-end run: load the encoder, build the decoder and loss, train, then
# write annotated test images. Everything happens inside one session so the
# trained variables are still alive when inference runs.
with tf.Session() as session:
    # NOTE: `keep_prob` is rebound here to the tensor that lives in the
    # loaded VGG graph; that tensor is the one fed during training/inference.
    image_input, keep_prob, layer3, layer4, layer7 = load_vgg(sess=session, vgg_path=VGG_PATH)
    model_output = layers(
        vgg_layer3_out=layer3,
        vgg_layer4_out=layer4,
        vgg_layer7_out=layer7,
        num_classes=NUMBER_OF_CLASSES,
    )

    logits, train_optimizer, cross_entropy_loss = optimize(
        nn_last_layer=model_output,
        correct_label=correct_label,
        learning_rate=learning_rate,
        num_classes=NUMBER_OF_CLASSES,
    )

    # Initialise variables only after the whole graph (including the
    # optimizer's slots) has been built.
    session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    train_nn(
        sess=session,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        get_batches_function=get_batches_function,
        train_optimizer=train_optimizer,
        cross_entropy_loss=cross_entropy_loss,
        input_image=image_input,
        correct_label=correct_label,
        keep_prob=keep_prob,
        learning_rate=learning_rate,
    )

    save_inference_samples(
        runs_dir=RUNS_DIRECTORY,
        data_dir=DATA_DIRECTORY,
        sess=session,
        image_shape=IMAGE_SHAPE,
        logits=logits,
        keep_prob=keep_prob,
        input_image=image_input,
    )

INFO:tensorflow:Restoring parameters from b'./data/vgg/variables/variables'
epoch: 0 / 1 training loss: N/A
---> iteration:  1 / 289 partial loss:  57.317
---> iteration:  2 / 289 partial loss:  30.9959
---> iteration:  3 / 289 partial loss:  46.8647
---> iteration:  4 / 289 partial loss:  21.2122
---> iteration:  5 / 289 partial loss:  21.4382
---> iteration:  6 / 289 partial loss:  17.7256
---> iteration:  7 / 289 partial loss:  14.3689
---> iteration:  8 / 289 partial loss:  20.6739
---> iteration:  9 / 289 partial loss:  14.4892
---> iteration:  10 / 289 partial loss:  12.6492
---> iteration:  11 / 289 partial loss:  12.45
---> iteration:  12 / 289 partial loss:  11.0529
---> iteration:  13 / 289 partial loss:  8.04862
---> iteration:  14 / 289 partial loss:  7.4452
---> iteration:  15 / 289 partial loss:  6.22143
---> iteration:  16 / 289 partial loss:  6.0769
---> iteration:  17 / 289 partial loss:  5.25337
---> iteration:  18 / 289 partial loss:  5.76766
---> iteration:  19 / 28

---> iteration:  165 / 289 partial loss:  0.858921
---> iteration:  166 / 289 partial loss:  0.801933
---> iteration:  167 / 289 partial loss:  0.787094
---> iteration:  168 / 289 partial loss:  0.784124
---> iteration:  169 / 289 partial loss:  0.771223
---> iteration:  170 / 289 partial loss:  0.776431
---> iteration:  171 / 289 partial loss:  0.878011
---> iteration:  172 / 289 partial loss:  0.7807
---> iteration:  173 / 289 partial loss:  0.86002
---> iteration:  174 / 289 partial loss:  0.84983
---> iteration:  175 / 289 partial loss:  0.841655
---> iteration:  176 / 289 partial loss:  0.789239
---> iteration:  177 / 289 partial loss:  0.786108
---> iteration:  178 / 289 partial loss:  0.839303
---> iteration:  179 / 289 partial loss:  0.766587
---> iteration:  180 / 289 partial loss:  0.862978
---> iteration:  181 / 289 partial loss:  0.734057
---> iteration:  182 / 289 partial loss:  0.757243
---> iteration:  183 / 289 partial loss:  0.822639
---> iteration:  184 / 289 partial 