In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import time
import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt

# local python package
sys.path.append(os.path.join(os.getcwd(), '..'))
from voc_loader import VOC_loader
import fcn_vgg
from loss import loss as get_loss

In [None]:
def plot_result(img, gt_seg, pred_seg, palette):
    """Show an image beside its ground-truth map and, optionally, a prediction.

    Args:
        img: image array; cast to uint8 for display.
        gt_seg: ground-truth segmentation map; cast to uint8 for display.
        pred_seg: per-pixel predicted label indices, or None to plot only the
            image and the ground truth. The array is NOT modified.
        palette: flat palette list passed to PIL's ``Image.putpalette`` to
            colour the predicted label indices.
    """
    num_panels = 2 if pred_seg is None else 3
    _, axes = plt.subplots(1, num_panels)

    axes[0].imshow(img.astype(np.uint8))
    axes[1].imshow(gt_seg.astype(np.uint8))

    if pred_seg is not None:
        # Label 21 is the boundary label in this notebook; draw it with palette
        # index 255 (presumably the boundary colour of the loader's palette).
        # np.where builds a new array, so the caller's predictions stay intact.
        pred_idx = np.where(pred_seg == 21, 255, pred_seg).astype(np.uint8)
        pred_img = Image.fromarray(pred_idx, 'P')
        pred_img.putpalette(palette)
        axes[2].imshow(pred_img)

    plt.show()

## Download the VOC 2012 dataset and pretrained VGG16 weights

In [None]:
def _run_checked(cmd):
    """Run ``cmd`` through the shell and raise if it exits with a non-zero status.

    ``os.system`` only returns the exit status; without this check a failed
    download or extraction would go unnoticed until a later cell breaks.
    """
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError('Command failed with status %d: %s' % (status, cmd))


# Pascal VOC 2012 images, segmentation maps and split files (extracts VOCdevkit/).
if not os.path.exists('VOCdevkit'):
    # -O pins the output file name; plain wget saves a second download as
    # VOC2012.tar.1 when a stale partial VOC2012.tar is already present.
    _run_checked('wget -O VOC2012.tar http://cvlab.postech.ac.kr/~jonghwan/VOC2012.tar')
    _run_checked('tar xvf VOC2012.tar')
    os.remove('VOC2012.tar')

# Pretrained VGG16 weights. Download under a temporary name and rename only on
# success, so an interrupted download never leaves a truncated vgg16.npy that
# the isfile() check above would accept on the next run.
if not os.path.isfile('vgg16.npy'):
    _run_checked('wget -O vgg16.npy.part http://cvlab.postech.ac.kr/~jonghwan/vgg16.npy')
    os.rename('vgg16.npy.part', 'vgg16.npy')

### Create loader

In [None]:
# VOC 2012 segmentation loader configuration. num_classes counts 22 label ids;
# later cells treat label 21 as the boundary label (excluded from evaluation).
loader_params = dict(
    num_classes=22,
    image_size=448,
    split_root='VOCdevkit/VOC2012/ImageSets/Segmentation',
    image_root='VOCdevkit/VOC2012/JPEGImages',
    segmap_root='VOCdevkit/VOC2012/SegmentationClass',
)
loader = VOC_loader(loader_params)

# Dataset metadata reused by the training and evaluation cells.
class_names = loader.get_class_names()
num_classes, img_size = loader_params['num_classes'], loader_params['image_size']

### Check the dataset

In [None]:
# Pull a small batch (loader's default split) for a quick visual sanity check.
preview_size = 10
batch = loader.get_batch(preview_size)

In [None]:
# Show a random example from the preview batch: image and ground-truth map.
# Draw the index from the batch's actual length so it stays in range if the
# preview batch size in the previous cell is changed.
bi = np.random.randint(len(batch['images']))
plot_result(batch['images'][bi], batch['seg_maps'][bi],
            None, loader.get_palette())

## Build FCN

In [None]:
# Model updating parameters
lr_params = {}
lr_params['initial_lr'] = 0.0001
lr_params['decay_step'] = 10000
lr_params['decay_rate'] = 0.8

batch_size = 16
num_epochs = 50
save_path = 'fcn_checkpoints/fcn_vgg'
if not os.path.exists('fcn_checkpoints'): os.makedirs('fcn_checkpoints')
iteration_per_epoch = int(math.floor(loader.get_num_train_examples() / batch_size))
save_checkpoint_frequency = 100
print_frequency = 10

In [None]:
# Build the training graph. Closing a session left over from a previous run of
# this cell and resetting the default graph keeps the cell re-runnable;
# otherwise a re-run adds a second copy of every op and variable to the graph
# and leaves the old session (and its device memory) alive.
if 'sess' in globals():
    sess.close()
tf.reset_default_graph()

with tf.variable_scope('inputs'):
    images = tf.placeholder(dtype=tf.float32,
                            shape=[batch_size, img_size, img_size, 3],
                            name='images')
    labels = tf.placeholder(dtype=tf.int64,
                            shape=[batch_size, img_size, img_size],
                            name='labels')

# build FCN in training mode; use the configured class count rather than a
# duplicated literal so it cannot drift from the loader settings.
vgg_fcn = fcn_vgg.FCN()
with tf.name_scope("content_vgg"):
    vgg_fcn.build(images, train=True, num_classes=num_classes, debug=False)

# optimizer: Adam with a staircase exponential learning-rate decay driven by
# global_step, which minimize() increments once per training step.
global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
lr = tf.train.exponential_decay(lr_params['initial_lr'],
                                global_step,
                                lr_params['decay_step'],
                                lr_params['decay_rate'],
                                staircase=True)
total_loss = get_loss(vgg_fcn.upscore, labels, num_classes)
train_op = tf.train.AdamOptimizer(lr).minimize(total_loss, global_step=global_step)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

## Train the FCN

In [None]:
# Locate the most recent checkpoint written by the training loop (None when
# there is none yet). To resume from a specific checkpoint instead, assign its
# prefix directly, e.g. checkpoint_path = 'fcn_checkpoints/fcn_vgg-5000'.
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir='fcn_checkpoints/')
print('Last checkpoint path is %s' % checkpoint_path)

In [None]:
# Create saver and, if the previous cell found a checkpoint, resume from it.
# tf.train.latest_checkpoint returns None (not '') when nothing is found, so
# test truthiness; comparing against '' would call restore(None) and fail.
saver = tf.train.Saver()
if checkpoint_path:
    saver.restore(sess, save_path=checkpoint_path)
    print('Model is restored from %s' % checkpoint_path)

# Create summary node and file writer. merge_all() returns None when the graph
# has no summary ops, and None cannot be fetched by sess.run, so only fetch
# the merged summary when it exists.
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter('./log_train', sess.graph)
fetches = [train_op, total_loss]
if merged is not None:
    fetches.append(merged)

try:
    for ie in range(num_epochs):
        for ii in range(iteration_per_epoch):
            # Load a batch of training data
            batch = loader.get_batch(batch_size, 'train')
            feed_dict = {images: batch['images'], labels: batch['seg_labels']}

            # Run one optimizer step
            results = sess.run(fetches, feed_dict=feed_dict)
            tf_loss = results[1]
            # Read the step counter after the update has run: fetching
            # global_step in the same run as train_op leaves the order of the
            # read and the increment undefined.
            iteration = tf.train.global_step(sess, global_step)
            if merged is not None:
                train_writer.add_summary(results[2], iteration)

            # Print the loss of the current batch
            if iteration % print_frequency == 0:
                print('%d Epoch %d iteration - Loss (%.3f)' % (ie+1, ii+1, tf_loss))

            # Save checkpoint. saver.save returns the prefix it actually wrote
            # (save_path + '-' + step), which is what gets reported.
            if iteration % save_checkpoint_frequency == 0:
                saved_path = saver.save(sess, save_path=save_path, global_step=global_step)
                print('Saved checkpoint %s' % saved_path)

    # Also save the final weights so the last partial save interval is kept.
    saved_path = saver.save(sess, save_path=save_path, global_step=global_step)
    print('Saved checkpoint %s' % saved_path)
finally:
    # Flush pending summaries to disk even if training is interrupted.
    train_writer.close()

## Test the model

### Load the checkpoint

In [None]:
# reset the graph and session
# Release the training session first, then start a fresh default graph for
# evaluation. TF documents that calling reset_default_graph while a session
# is still active results in undefined behavior.
sess.close()
tf.reset_default_graph()

In [None]:
# Rebuild the network in inference mode (train=False) for evaluation.
# NOTE(review): 7 is presumably chosen so whole batches cover the test split
# (e.g. 1449 VOC val images = 7 * 207); confirm against the loader.
batch_size = 7
with tf.variable_scope('inputs'):
    images = tf.placeholder(dtype=tf.float32,
                            shape=[batch_size, img_size, img_size, 3],
                            name='images')
    labels = tf.placeholder(dtype=tf.int64,
                            shape=[batch_size, img_size, img_size],
                            name='labels')

# build FCN with the configured class count
vgg_fcn = fcn_vgg.FCN()
with tf.name_scope("content_vgg"):
    vgg_fcn.build(images, train=False, num_classes=num_classes, debug=False)

# No optimizer or loss is built here: inference only needs the model
# variables, and a Saver over this graph restores just those from the
# training checkpoint (which additionally holds global_step and Adam state).
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# Set to a specific checkpoint prefix (e.g. 'fcn_checkpoints/fcn_vgg-3000',
# as written by the training loop) to evaluate it; None uses the latest one.
pinned_checkpoint = None
checkpoint_path = pinned_checkpoint or tf.train.latest_checkpoint(checkpoint_dir='fcn_checkpoints/')
print('Last checkpoint path is %s' % (checkpoint_path))
if not checkpoint_path:
    raise ValueError('No checkpoint found in fcn_checkpoints/; run the training cells first.')

# Create saver and restore the trained weights
saver = tf.train.Saver()
saver.restore(sess, save_path=checkpoint_path)
print('Model is restored from %s' % checkpoint_path)

### Visualize the result

In [None]:
# Reset the loader before drawing test batches.
# NOTE(review): presumably rewinds VOC_loader's read position so the cells
# below start from the first test example; confirm VOC_loader.reset semantics.
loader.reset()

In [None]:
# Load one test batch, run the network on it and plot each prediction next to
# its image and ground-truth map.
batch = loader.get_batch(batch_size, 'test')
feed_dict = {images: batch['images'], labels: batch['seg_labels']}

# Forward pass; sess.run on a one-element fetch list returns a one-element list.
score = sess.run([vgg_fcn.pred_up], feed_dict=feed_dict)
pred_labels = score[0]

for bi in range(batch_size):
    plot_result(batch['images'][bi], batch['seg_maps'][bi],
                pred_labels[bi], loader.get_palette())

### Compute per-class accuracy (intersection over union)

In [None]:
# Predict segmentation labels for all test examples and accumulate a confusion
# matrix with rows = ground-truth labels and columns = predicted labels.
# num_classes is the loader setting (22); label num_classes-1 (21) marks
# boundary pixels, which are not counted.
num_test_examples = loader.get_num_test_examples()
# The input placeholders have a fixed batch dimension, so only whole batches
# can be evaluated; report any examples that are left out instead of silently
# dropping them.
num_iterations = int(num_test_examples // batch_size)
num_skipped = num_test_examples - num_iterations * batch_size
if num_skipped:
    print('Skipping the last %d test examples (not a full batch)' % num_skipped)
boundary_label = num_classes - 1
conf_counts = np.zeros((num_classes, num_classes))

loader.reset()
for ii in range(num_iterations):
    # Load a batch of test data
    batch = loader.get_batch(batch_size, 'test')
    feed_dict = {images: batch['images'], labels: batch['seg_labels']}

    # Forward pass only: per-pixel predicted labels for the whole batch
    pred = np.asarray(sess.run(vgg_fcn.pred_up, feed_dict=feed_dict), dtype=np.int64)
    gt = np.asarray(batch['seg_labels'], dtype=np.int64)

    # Do not count boundary labels. Encode each (gt, pred) pair as
    # gt * num_classes + pred (int64, so narrow label dtypes cannot overflow);
    # reshaping the flat bincount row-major then gives row = gt, col = pred.
    valid = gt < boundary_label
    codes = gt[valid] * num_classes + pred[valid]
    counts = np.bincount(codes, minlength=num_classes * num_classes)
    conf_counts += counts.reshape(num_classes, num_classes)

    # Report progress
    if ((ii + 1) % 10 == 0) or ((ii + 1) == num_iterations):
        print('TEST %d/%d Done' % (ii+1, num_iterations))

In [None]:
# Per-class intersection over union (reported as "accuracy"):
#   IoU_c = conf[c, c] / (sum(conf[c, :]) + sum(conf[:, c]) - conf[c, c])
# Index 0 (background) is skipped, matching class_names, which is indexed from
# class 1 (class_names[ic-1]). Index num_classes-1 is the boundary label, whose
# ground-truth pixels were excluded from conf_counts, so it is left out of both
# the printed list and the mean. Classes with an empty union get NaN and are
# ignored by the mean instead of producing 0/0 warnings.
intersection = np.diag(conf_counts)
union = conf_counts.sum(axis=1) + conf_counts.sum(axis=0) - intersection
acc = np.full(num_classes, np.nan)
present = union > 0
acc[present] = 100.0 * intersection[present] / union[present]
for ic in range(1, num_classes - 1):
    print('%s accuracy %.3f' % (class_names[ic-1], acc[ic]))
print('Mean accuracy %.3f' % (np.nanmean(acc[1:num_classes-1])))