# UNet for WBC Segmentation in TF-slim
__________________________________
> Contains an implementation of U-Net in the TF-Slim framework that segments white blood cells (WBCs) by demarcating their boundaries.

In [1]:
# load the necessary
import matplotlib.pylab as plt
%matplotlib inline
import os
import numpy as np
import math as m
import sys
print('Python version:',sys.version)
import tensorflow as tf
print('TF version:',tf.__version__)
from data_utils import data_augmentation 
slim = tf.contrib.slim
from colorama import Fore, Style

Python version: 3.5.2 (default, Nov 23 2017, 16:37:01) 
[GCC 5.4.0 20160609]


  from ._conv import register_converters as _register_converters


TF version: 1.8.0


__________________________
## Architecture

> 1. The model is implemented in the TensorFlow-Slim (TF-Slim) framework for better readability.   
> 2. The preprocessing pipeline uses the `tf.data` Dataset API, and a feedable iterator switches dynamically between the training and validation sets.  
> 3. The number of layers and the number of feature channels in each U-Net layer are configurable.  
> 4. For regularization, dropout is implemented.
> 5. For better training convergence, batch norm is implemented after every convolutional layer. 
> 6. Following the recommendation in https://distill.pub/2016/deconv-checkerboard/, I upsample with image resizing instead of transposed convolution. Both implementations are available and can be selected.

In [2]:
#Hyperparameters
# slim is also bound in the import cell; repeated here so this cell is self-contained
slim = tf.contrib.slim

# Decoder upsampling: True -> slim.conv2d_transpose, False -> nearest-neighbour resize
T_CONV = False
# Dropout keep probability applied at the end of every encoder/decoder block
KEEP_PROB = 0.7

# Encoder block
def encoder_block(input_layer, output_channels, is_train, scope):
    """Contracting-path stage of the U-Net.

    Applies two 3x3 'SAME' convolutions (each followed by slim batch norm and
    slim.conv2d's default activation), then dropout, then 2x2 'SAME' max pooling.

    Args:
        input_layer: input feature map (NHWC).
        output_channels: number of filters in each of the two convolutions.
        is_train: Python bool or tf.bool tensor; switches batch norm and dropout
            between training and inference behaviour.
        scope: name of the variable scope that holds this block's variables.

    Returns:
        (skip, downsampled): the dropout output, which is kept for the skip
        connection to the matching decoder block, and its max-pooled version,
        which feeds the next encoder block.
    """
    # updates_collections=None makes batch norm update its moving statistics in place
    bn_params = {"is_training": is_train, 'updates_collections': None}

    with tf.variable_scope(scope, reuse=None), \
            tf.contrib.framework.arg_scope([slim.conv2d],
                                           normalizer_fn=slim.batch_norm,
                                           normalizer_params=bn_params,
                                           padding='SAME',
                                           kernel_size=(3, 3)):
        # Two stacked convolutions, created as conv/conv_1 and conv/conv_2
        convolved = slim.repeat(input_layer, 2, slim.conv2d, output_channels, scope='conv')
        skip = slim.dropout(convolved, keep_prob=KEEP_PROB, is_training=is_train)
        downsampled = slim.max_pool2d(skip, kernel_size=[2, 2], padding='SAME', scope='pool')

    return skip, downsampled
# Decoder block
def decoder_block(input_layer, concat_layer, output_channels, is_train, scope):
    """Expanding-path stage of the U-Net.

    Upsamples `input_layer`, concatenates it with the skip connection
    `concat_layer` along the channel axis, applies two 3x3 'SAME' convolutions
    (with batch norm) of `output_channels // 2` filters, then dropout.

    Args:
        input_layer: output of the previous (deeper) block (NHWC).
        concat_layer: skip-connection tensor from the matching encoder block.
        output_channels: filters of the transposed convolution when T_CONV is
            set; the two following convolutions use output_channels // 2.
        is_train: Python bool or tf.bool tensor; switches batch norm and dropout
            between training and inference behaviour.
        scope: name of the variable scope that holds this block's variables.

    Returns:
        The dropout output tensor.
    """
    with tf.variable_scope(scope, reuse=None):
        with tf.contrib.framework.arg_scope([slim.conv2d],
                                            normalizer_fn=slim.batch_norm,
                                            normalizer_params={"is_training": is_train, 'updates_collections': None},
                                            padding='SAME', kernel_size=(3, 3)):

            # Found image resize works better than conv2d_transpose
            # (checkerboard artefacts, https://distill.pub/2016/deconv-checkerboard/)
            if T_CONV:
                # NOTE: this path always doubles H and W, so it still requires
                # encoder sizes that halve exactly at every level.
                conv_t = slim.conv2d_transpose(input_layer, output_channels, kernel_size=[2, 2], stride=2, scope='up_conv')
            else:
                # Upsample straight to the skip connection's spatial size instead
                # of to 2x the input. With SAME max pooling an odd encoder size H
                # pools to ceil(H/2), and 2*ceil(H/2) != H would make the concat
                # below fail. For even sizes this is exactly the 2x upsample.
                skip_hw = tf.shape(concat_layer)[1:3]
                conv_t = tf.image.resize_nearest_neighbor(input_layer, skip_hw, name='up_sample')

            # concat the layers along the channel axis
            concat = tf.concat([conv_t, concat_layer], 3)
            # two stacked convolutions, created as conv/conv_1 and conv/conv_2
            conv = slim.repeat(concat, 2, slim.conv2d, output_channels // 2, scope='conv')

            # dropout for regularization
            drop = slim.dropout(conv, keep_prob=KEEP_PROB, is_training=is_train)
            return drop

In [3]:
#Unet implementation

def Unet(inputs, is_train, layers=4, feature_multiple=64, enc_feature_multiple=32, num_classes=2):
    """U-Net built from `encoder_block` / `decoder_block`.

    Args:
        inputs: image tensor [batch, HEIGHT, WIDTH, 3].
        is_train: tf.bool (or Python bool), True during training, False otherwise.
        layers: number of encoder (and decoder) blocks.
        feature_multiple: decoder width. Decoder block `i` (0-based, counted from
            the shallowest) gets output_channels = feature_multiple * 2**(i + 1),
            so its convolutions have feature_multiple * 2**i filters.
        enc_feature_multiple: encoder width. Encoder block `i` has
            enc_feature_multiple * 2**i filters.
        num_classes: number of output channels (per-pixel classes).

    Returns:
        Logits tensor [batch, HEIGHT, WIDTH, num_classes] with the same spatial
        size as `inputs`. No softmax is applied.
    """
    # Contracting path: keep each block's pre-pool output for the skip connections.
    enc_convs = []
    input_layer = inputs
    for layer in range(layers):
        features = 2 ** layer * enc_feature_multiple
        conv, input_layer = encoder_block(input_layer, features, is_train, 'enc_block' + str(layer + 1))
        enc_convs.append(conv)

    # Expanding path, deepest block first, each paired with its encoder skip output.
    out_layer = input_layer
    for layer in reversed(range(layers)):
        features = 2 ** (layer + 1) * feature_multiple
        out_layer = decoder_block(out_layer, enc_convs[layer], features, is_train, 'dec_block' + str(layer + 1))

    # Final 1x1 conv to per-pixel logits. It is created outside the blocks'
    # arg_scope, so apart from activation_fn=None slim.conv2d's own defaults apply.
    out_layer = slim.conv2d(out_layer, num_classes, [1, 1], activation_fn=None)

    return out_layer

In [4]:
# Check the above graph: build it on a dummy 128x128 RGB batch and list its variables
tf.reset_default_graph()
dummy_input = tf.placeholder(tf.float32, (None, 128, 128, 3))
net = Unet(dummy_input, True)
for variable in tf.global_variables():
    print(variable.name)


enc_block1/conv/conv_1/weights:0
enc_block1/conv/conv_1/BatchNorm/beta:0
enc_block1/conv/conv_1/BatchNorm/moving_mean:0
enc_block1/conv/conv_1/BatchNorm/moving_variance:0
enc_block1/conv/conv_2/weights:0
enc_block1/conv/conv_2/BatchNorm/beta:0
enc_block1/conv/conv_2/BatchNorm/moving_mean:0
enc_block1/conv/conv_2/BatchNorm/moving_variance:0
enc_block2/conv/conv_1/weights:0
enc_block2/conv/conv_1/BatchNorm/beta:0
enc_block2/conv/conv_1/BatchNorm/moving_mean:0
enc_block2/conv/conv_1/BatchNorm/moving_variance:0
enc_block2/conv/conv_2/weights:0
enc_block2/conv/conv_2/BatchNorm/beta:0
enc_block2/conv/conv_2/BatchNorm/moving_mean:0
enc_block2/conv/conv_2/BatchNorm/moving_variance:0
enc_block3/conv/conv_1/weights:0
enc_block3/conv/conv_1/BatchNorm/beta:0
enc_block3/conv/conv_1/BatchNorm/moving_mean:0
enc_block3/conv/conv_1/BatchNorm/moving_variance:0
enc_block3/conv/conv_2/weights:0
enc_block3/conv/conv_2/BatchNorm/beta:0
enc_block3/conv/conv_2/BatchNorm/moving_mean:0
enc_block3/conv/conv_2/Ba

________________________________
### Create input pipelines for both train and validation
________________________________

In [5]:
#input pipeline for both test and validation

def binary_threshold(x, thres=0.5):
    """Binarize `x` elementwise: 0 where x < thres, 1 otherwise.

    The result has the same shape and dtype as `x`.
    """
    ones = tf.ones_like(x)
    zeros = tf.zeros_like(x)
    at_or_above = tf.logical_not(tf.less(x, thres))
    return tf.where(at_or_above, ones, zeros)

def _parse_function(line):
    """Parse one CSV line "image_path,mask_path" into an (image, mask) pair.

    Returns:
        image: float32 tensor [H, W, 3] scaled to [0, 1].
        mask: int32 tensor [H, W, 1] with values in {0, 1}.
    """
    image_raw, mask_raw = tf.decode_csv(line, record_defaults=[[""], [""]])
    # Force the channel counts. The graph feeds images into a [..., 3] placeholder
    # and masks into a [..., 1] placeholder. Without `channels`, decode_jpeg keeps
    # whatever the file stores (e.g. a mask saved as RGB would yield 3 channels).
    image = tf.image.decode_jpeg(tf.read_file(image_raw), channels=3)
    mask = tf.image.decode_jpeg(tf.read_file(mask_raw), channels=1)
    # Threshold the mask at 0.5 (after scaling to [0, 1]) to get hard 0/1 labels,
    # e.g. to remove intermediate values left by JPEG compression.
    mask = tf.cast(binary_threshold(tf.cast(mask, tf.float32) / 255.), tf.int32)
    return tf.cast(image, tf.float32) / 255., mask

def _augment(image, mask):
    """Apply `data_augmentation` to an (image, mask) pair.

    The mask is converted to float32 before augmentation and thresholded back
    to int32 {0, 1} afterwards, since augmentation may produce in-between values.
    """
    augmented_image, augmented_mask = data_augmentation(image, tf.cast(mask, tf.float32))
    return augmented_image, tf.cast(binary_threshold(augmented_mask), tf.int32)

def _resize_valid(image, mask, multiple=64):
    """Resize a validation (image, mask) pair so H and W are multiples of `multiple`.

    Each side is rounded DOWN to a multiple of `multiple`, but never below
    `multiple` itself. Otherwise an image smaller than `multiple` would be
    resized to zero size. Sizes that are not divisible by 2 at every pooling
    level can lead to concatenation issues in the decoder.

    Args:
        image: float tensor [H, W, C].
        mask: tensor [H, W, C] holding {0, 1} labels (from `_parse_function`).
        multiple: size granularity. Presumably it should be divisible by
            2**(number of pooling layers); the default 64 covers 6 levels.

    Returns:
        (image, mask). Both are resized with `tf.image.resize_images`' default
        (bilinear) method, and the mask is then re-binarized to int32 {0, 1}.
    """
    def _snap(size):
        # floor to a multiple, clamped to at least one multiple
        return tf.maximum(tf.truncatediv(size, multiple) * multiple, multiple)

    shape = tf.shape(image)
    height = _snap(shape[0])
    width = _snap(shape[1])

    image = tf.image.resize_images(image, [height, width])
    mask = tf.image.resize_images(mask, [height, width])
    # binarize mask to 0-1 (bilinear resizing blends label values at edges)
    mask = tf.cast(binary_threshold(mask), tf.int32)

    return image, mask
    
# The CSV file must contain, on each line, the path of an image and the path of its mask
def input_pipeline(filename, batch_size, validation=False):
    """Build a tf.data pipeline of batched (image, mask) pairs from a CSV file.

    Args:
        filename: path of the CSV file (one "image_path,mask_path" per line).
        batch_size: number of samples per batch.
        validation: when False, the parsed samples are concatenated with an
            augmented copy (`_augment`), shuffled and repeated indefinitely.
            Otherwise no augmentation is applied: samples are resized by
            `_resize_valid`, lightly shuffled and not repeated, so an iterator
            over them ends with OutOfRangeError.

    Returns:
        A `tf.data.Dataset` of batched (image, mask) pairs.
    """
    parsed = tf.data.TextLineDataset([filename]).map(_parse_function, num_parallel_calls=4)

    is_training_set = (validation == False)
    if not is_training_set:
        return parsed.map(_resize_valid).shuffle(buffer_size=5).batch(batch_size)

    # Training: original samples plus their augmented versions
    with_augmented = parsed.concatenate(parsed.map(_augment, num_parallel_calls=4))
    return with_augmented.shuffle(buffer_size=1000).repeat().batch(batch_size)


_____________________________________
## Create the graph and run it in a session.
> The input pipelines for both train and validation can be dynamically switched.
> Since the validation images have different sizes, they are fed one at a time (batch size 1).
> 

In [6]:
def dice_coef(y_true, y_pred, axis=None, smooth = 0.001):
    """Negated mean Dice coefficient, usable as a loss to minimize.

    Dice = (2 * |T * P| + smooth) / (|T| + |P| + smooth), summed over `axis`
    and then averaged over all remaining elements.

    Args:
        y_true: ground-truth tensor (cast to float32).
        y_pred: prediction tensor of the same shape (cast to float32).
        axis: axes to sum over; defaults to [1, 2] (assumed to be the spatial
            axes of [batch, H, W, ...] inputs; TODO confirm against callers).
        smooth: additive smoothing that avoids 0/0 on empty masks.

    Returns:
        Scalar tensor equal to -mean(Dice).
    """
    # NOTE(review): not referenced by the visible training graph.
    reduce_axes = [1, 2] if axis is None else axis
    truth = tf.cast(y_true, dtype=tf.float32)
    guess = tf.cast(y_pred, dtype=tf.float32)

    overlap = tf.reduce_sum(truth * guess, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(guess, axis=reduce_axes)

    per_sample_dice = (2. * overlap + smooth) / (total + smooth)
    return -tf.reduce_mean(per_sample_dice)


In [7]:
#Create the graph
from datetime import datetime
import time

BATCH_SIZE = 32
LEARNING_RATE = 3e-5
# NOTE(review): INPUT_WIDTH / INPUT_HEIGHT are not referenced in the visible cells
INPUT_WIDTH = 128
INPUT_HEIGHT = 128
LOG_FREQ = 100
MODEL_DIR = './unet_model'

tf.reset_default_graph()

train_graph = tf.Graph()
with train_graph.as_default():
    global_step = tf.train.get_or_create_global_step()

    #-------------------------------------------
    #1. Create valid and train iterators
    #-------------------------------------------
    # A feedable iterator (the string `handle` placeholder below) selects which
    # dataset `get_next()` reads from: the training batches or the one-image
    # validation batches. The separate `is_train` bool switches batch norm and
    # dropout between training and inference behaviour.

    training_filenames = 'train.csv'
    validation_filenames = 'validation.csv'

    # Create valid and train datasets
    training_dataset = input_pipeline(training_filenames, BATCH_SIZE)
    validation_dataset = input_pipeline(validation_filenames, 1, validation=True)

    # A bool to switch between training loop and testing loop
    is_train = tf.placeholder(dtype=tf.bool, name='is_train')

    # A feedable iterator is defined by a handle placeholder and its structure.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   training_dataset.output_types,
                                                   training_dataset.output_shapes)
    # Returns a batch of image and mask at every call
    image_batch, mask_batch = iterator.get_next()

    # Create an initializable iterator for the valid dataset,
    # so that the dataset is the same for every valid loop.
    validation_iterator = validation_dataset.make_initializable_iterator()
    training_iterator = training_dataset.make_one_shot_iterator()

    #------------------------------------------------
    # 2. load Unet to the graph
    #------------------------------------------------

    # Placeholder definitions so individual images can be fed directly when testing
    X = tf.placeholder_with_default(image_batch, shape=[None, None, None, 3], name='X')
    y = tf.placeholder_with_default(mask_batch, shape=[None, None, None, 1], name='y')

    #Unet logits
    logits = Unet(X, is_train)

    #------------------------------------------------
    # 3. Loss and accuracy
    #------------------------------------------------

    # Drop only the trailing channel axis. A bare tf.squeeze() would also drop
    # the batch axis when the batch size is 1 (the validation pipeline) and any
    # spatial axis of size 1, leaving labels whose rank no longer matches logits.
    y = tf.squeeze(y, axis=3)

    loss = tf.losses.softmax_cross_entropy(onehot_labels=tf.one_hot(y, 2), logits=logits)

    solver = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    train_op = slim.learning.create_train_op(loss, solver, global_step=global_step)

    # Probabilities of the outputs
    prob = tf.nn.softmax(logits, name='prob')

    # Prediction: per-pixel class index
    pred = tf.argmax(prob, 3, name='pred')

    # Pixel accuracy
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(pred, tf.int32), tf.cast(y, tf.int32)), tf.float32), name='acc')


In [8]:
# Run the graph in the session
with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())
    train_handle, valid_handle = sess.run([training_iterator.string_handle(),
                                           validation_iterator.string_handle()])

    # variables for saving model and early exit
    loss_sum = 0.
    best_v_loss = float('inf')
    early_stopping = 0
    saver = tf.train.Saver()

    while True:
        try:

            # the train call
            _, l, step = sess.run([train_op, loss, global_step], {is_train: True, handle: train_handle})
            if (step % LOG_FREQ) == 0:
                tf.logging.info(Fore.BLUE + Style.BRIGHT + 'Step:{}:Train Loss:{:.5f}'.format(step, l) + Fore.RESET + Style.RESET_ALL)
                # Inspection-only forward pass on a fresh training batch, run in
                # inference mode. Batch norm is built with updates_collections=None,
                # so is_train=True here would update the moving mean/variance (and
                # apply dropout) outside a training step.
                train_pred, train_mask = sess.run([pred, y], {is_train: False, handle: train_handle})

            # the validation call
            if (step % (LOG_FREQ * 5)) == 0:
                valid_loss = 0.
                valid_acc = 0.
                count = 0
                # Re-initialize so every validation pass covers the whole validation set
                sess.run(validation_iterator.initializer, {is_train: False})
                while True:
                    try:
                        acc, valid_pred, valid_mask, l = sess.run([accuracy, pred, y, loss], {is_train: False, handle: valid_handle})

                        valid_loss += l
                        valid_acc += acc
                        count += 1
                    except tf.errors.OutOfRangeError:
                        break

                # An empty validation set would otherwise divide by zero below
                if count == 0:
                    raise ValueError('Validation dataset produced no samples; check the validation CSV.')

                valid_loss = valid_loss / count
                tf.logging.info(Fore.GREEN + Style.BRIGHT + 'Step:{}:Valid Loss:{:.5f}:Accuracy:{:.5f}'.format(step, valid_loss, valid_acc / count) + Fore.RESET + Style.RESET_ALL)

                # Save the best model based on valid loss
                if (best_v_loss > valid_loss) and (step > 0):
                    tf.logging.info(Fore.RED + Style.BRIGHT + 'Saving the model...' + Fore.RESET + Style.RESET_ALL)
                    saver.save(sess, MODEL_DIR)
                    best_v_loss = valid_loss
                    early_stopping = 0
                else:
                    early_stopping += 1

                # Stop after more than 3 consecutive validations without improvement
                if early_stopping > 3:
                    tf.logging.info(Fore.RED + Style.BRIGHT + 'Stopping the training...' + Fore.RESET + Style.RESET_ALL)
                    break

        except tf.errors.OutOfRangeError:
            # Presumably unreachable with the repeat()-ed training dataset; kept as a safety net
            print('Completed training...')
            break



INFO:tensorflow:[34m[1mStep:0:Train Loss:0.70553[39m[0m
INFO:tensorflow:[32m[1mStep:0:Valid Loss:0.65323:Accuracy:0.98655[39m[0m
INFO:tensorflow:[34m[1mStep:100:Train Loss:0.16027[39m[0m
INFO:tensorflow:[34m[1mStep:200:Train Loss:0.13283[39m[0m
INFO:tensorflow:[34m[1mStep:300:Train Loss:0.12480[39m[0m
INFO:tensorflow:[34m[1mStep:400:Train Loss:0.10738[39m[0m
INFO:tensorflow:[34m[1mStep:500:Train Loss:0.13480[39m[0m
INFO:tensorflow:[32m[1mStep:500:Valid Loss:0.67323:Accuracy:0.51926[39m[0m
INFO:tensorflow:[31m[1mSaving the model...[39m[0m
INFO:tensorflow:[34m[1mStep:600:Train Loss:0.08831[39m[0m
INFO:tensorflow:[34m[1mStep:700:Train Loss:0.12008[39m[0m
INFO:tensorflow:[34m[1mStep:800:Train Loss:0.08028[39m[0m
INFO:tensorflow:[34m[1mStep:900:Train Loss:0.10024[39m[0m
INFO:tensorflow:[34m[1mStep:1000:Train Loss:0.10024[39m[0m
INFO:tensorflow:[32m[1mStep:1000:Valid Loss:0.52609:Accuracy:0.99206[39m[0m
INFO:tensorflow:[31m[1mSavi

----- EOF -----