## Segmentation CNN

@ysbecca
A basic CNN that segments the training set into the following tissue layers:
- epithelial layer
- submucosa

In [8]:
import datetime
import os
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.python.client import device_lib

import scripts.cnn_helper as cn
import scripts.dataset as ds

Parameters.

In [12]:
# Directory for model checkpoints. Set the BCSP_CHECKPOINTS_DIR environment
# variable to override the author's local default. os.path.join(..., "")
# guarantees a trailing separator, because checkpoint paths are built as
# checkpoints_dir + model_name.
checkpoints_dir = os.path.join(
    os.environ.get("BCSP_CHECKPOINTS_DIR",
                   "/Users/ysbecca/ysbecca-projects/bcsp-expert/checkpoints/"),
    "")

# Image params: square 3-channel patches.
img_size = 16
num_channels = 3
img_size_flat = img_size * img_size * num_channels
img_shape = (img_size, img_size)

# Convolutional layer params: one entry per conv layer in each list.
filter_sizes = [3, 3]
num_filters = [16, 16]
num_layers = len(filter_sizes)

max_pools = [2, 2]

# Catch mismatched edits to the per-layer lists early, before graph building.
assert len(num_filters) == num_layers and len(max_pools) == num_layers, \
    "filter_sizes, num_filters and max_pools need one entry per conv layer"

# Fully connected layers, followed by the classification layer.
fc_1_size = 128
fc_2_size = 64  # NOTE: not used by model() as currently written.

num_classes = 2

Build TensorFlow graph structure.

In [13]:
# Flattened input images, reshaped to NHWC for the convolutional layers.
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])

# Labels (presumably one-hot) and their integer class indices.
# `axis` replaces the deprecated `dimension` keyword of tf.argmax.
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
y_true_cls = tf.argmax(y_true, axis=1)

# Dropout keep probability: feed < 1.0 while training, 1.0 when evaluating.
# Named so it can be looked up by name in a restored graph.
keep_prob = tf.placeholder(tf.float32, name='keep_prob')

Find device configuration.

In [14]:
# Names of the GPUs visible to TensorFlow. The comprehension variable is
# deliberately not called `x`: under Python 2, comprehension variables leak
# into the enclosing scope and would overwrite the `x` input placeholder.
gpus = [dev.name for dev in device_lib.list_local_devices()
        if dev.device_type == 'GPU']
num_gpus = len(gpus)
print("GPU devices found: " + str(gpus))

GPU devices found: []


Define model.

In [27]:
def model(x_image_, y_true_):
    '''Build the CNN forward pass and its loss for one batch (or one GPU shard).

    Architecture: `num_layers` conv layers (each with max-pooling and followed
    by dropout controlled by `keep_prob`), then a hidden fully-connected layer
    of `fc_1_size` units with ReLU, then a linear output layer producing
    `num_classes` logits.

    Args:
        x_image_: image tensor shaped like `x_image`
            ([batch, img_size, img_size, num_channels]).
        y_true_: label tensor shaped like `y_true` ([batch, num_classes]).

    Returns:
        (y_pred, y_pred_cls, cost): softmax class probabilities, predicted
        class indices, and the mean softmax cross-entropy over the batch.
    '''
    network = x_image_
    in_channels = num_channels
    for layer in range(num_layers):
        network, _ = cn.new_conv_layer(input=network,
                                       num_input_channels=in_channels,
                                       filter_size=filter_sizes[layer],
                                       num_filters=num_filters[layer],
                                       use_pooling=True,
                                       max_pool_size=max_pools[layer])
        network = tf.nn.dropout(network, keep_prob=keep_prob)
        in_channels = num_filters[layer]

    # With the default params this is 4 * 4 * 16 = 256 features.
    network, num_fc_features = cn.flatten_layer(network)

    # Hidden fully-connected layer: this is where the non-linearity belongs.
    # (Assumes `use_relu` toggles a ReLU on the helper's output.)
    network, _ = cn.new_fc_layer(input=network,
                                 num_inputs=num_fc_features,
                                 num_outputs=fc_1_size,
                                 use_relu=True)
    # Output layer must emit raw (unscaled) logits. A ReLU here would clamp
    # every negative logit to 0 before softmax / cross-entropy.
    logits, _ = cn.new_fc_layer(input=network,
                                num_inputs=fc_1_size,
                                num_outputs=num_classes,
                                use_relu=False)

    y_pred = tf.nn.softmax(logits)
    y_pred_cls = tf.argmax(logits, axis=1)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_true_)
    cost = tf.reduce_mean(cross_entropy)

    return y_pred, y_pred_cls, cost


Split operations (optionally) across GPUs.

In [28]:
def make_parallel(fn, num_gpus, **kwargs):
    '''Replicate `fn` across `num_gpus` GPUs, one shard of each input per GPU.

    Every keyword argument is split along axis 0 into `num_gpus` equal shards
    (tf.split requires the batch size to be divisible by `num_gpus`). Tower i
    runs `fn` on shard i, placed on GPU i. Variable reuse is enabled for every
    tower after the first.

    Returns:
        (y_pred, y_pred_cls, cost): the towers' predictions and class indices
        concatenated back into full-batch order, and the mean of the tower
        costs. The cost is a scalar, like the one `model` returns on the CPU path.

    Raises:
        ValueError: if `num_gpus` is less than 1.

    NOTE(review): `reuse` only shares variables created via tf.get_variable.
    If the cn helpers create weights with tf.Variable, each tower trains its
    own copy. Confirm in scripts/cnn_helper.
    '''
    if num_gpus < 1:
        raise ValueError("make_parallel needs at least one GPU, got %r" % (num_gpus,))

    in_splits = {name: tf.split(tensor, num_gpus) for name, tensor in kwargs.items()}

    y_pred_split, y_pred_cls_split, cost_split = [], [], []
    for i in range(num_gpus):
        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=i)):
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                shard_args = {name: shards[i] for name, shards in in_splits.items()}
                y_pred_, y_pred_cls_, cost_ = fn(**shard_args)
                y_pred_split.append(y_pred_)
                y_pred_cls_split.append(y_pred_cls_)
                cost_split.append(cost_)

    # The shards are equal-sized, so the mean of the per-tower mean costs
    # equals the full-batch mean. Returning a scalar also keeps the gradient
    # scale independent of the number of GPUs.
    cost = tf.reduce_mean(tf.stack(cost_split, axis=0))
    return tf.concat(y_pred_split, axis=0), tf.concat(y_pred_cls_split, axis=0), cost

# Each training step consumes one sub-batch per GPU, or a single batch on
# the CPU. Train/test batch sizes should therefore be multiples of
# total_batches, so that tf.split in make_parallel divides them evenly.
total_batches = max(num_gpus, 1)

Build the model (sharded across GPUs when available) and define the cost, optimizer, and accuracy ops.

In [29]:
# Build the forward pass and loss, sharding the batch across GPUs when present.
if num_gpus > 0:
    # Sharing variables between towers adds CPU<->GPU copy latency; two GPUs
    # (12 GB each on K80 nodes) give a good balance of speed-up vs. overhead.
    y_pred, y_pred_cls, cost = make_parallel(model, num_gpus, x_image_=x_image, y_true_=y_true)
else:
    # CPU-only version.
    y_pred, y_pred_cls, cost = model(x_image_=x_image, y_true_=y_true)

# Optimizer step size, named so it can be tuned (and reused if switching to
# tf.train.AdamOptimizer) instead of being a literal in the call.
learning_rate = 1e-4
optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate).minimize(cost)

# Fraction of examples whose predicted class matches the true class.
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Tensor("Conv2D_6:0", shape=(?, 16, 16, 16), dtype=float32)
Tensor("Relu_9:0", shape=(?, 16, 16, 16), dtype=float32)
Tensor("MaxPool_6:0", shape=(?, 8, 8, 16), dtype=float32)
Tensor("Conv2D_7:0", shape=(?, 8, 8, 16), dtype=float32)
Tensor("Relu_10:0", shape=(?, 8, 8, 16), dtype=float32)
Tensor("MaxPool_7:0", shape=(?, 4, 4, 16), dtype=float32)


Start the TensorFlow session and initialize all variables.

In [30]:
# With GPUs present, allow ops lacking a GPU kernel to fall back to the CPU
# and log where every op is placed. Otherwise use the default session config.
session = tf.Session(
    config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    if num_gpus > 0 else None)

session.run(tf.global_variables_initializer())
saver = tf.train.Saver()  # Used by save_model / restore_model.

# Global counter for the total number of training iterations performed so far.
total_iterations = 0

Supporting functions: checkpoint save/restore, the training loop, and test-set evaluation.

In [31]:
def save_model(iterations=False):
    '''Write a checkpoint of the current session into `checkpoints_dir`.

    The checkpoint is named "m-YYYY-MM-DD-HH-MM-<n>". <n> is `iterations`
    when it is truthy; otherwise it is the global `total_iterations` counter.

    Returns:
        The path reported by `saver.save` (previously computed and discarded).
    '''
    step = iterations if iterations else total_iterations
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
    model_name = 'm-' + timestamp + "-" + str(step)

    save_path = saver.save(session, checkpoints_dir + model_name)
    print("Model saved: " + model_name)
    return save_path

def restore_model(model_name):
    '''Load checkpoint `model_name` (relative to `checkpoints_dir`) into `session`.'''
    checkpoint_path = checkpoints_dir + model_name
    saver.restore(sess=session, save_path=checkpoint_path)
    
def optimize(dataset_train, num_iterations, dropout_keep_prob=0.9, print_opt_acc=False, silent=False,
             batch_size=None, print_every=200):
    '''Run `num_iterations` optimizer steps on mini-batches from `dataset_train`.

    Args:
        dataset_train: object whose `next_batch(n)` returns (images, labels).
            Images are reshaped to (batch, img_size_flat) before feeding.
        num_iterations: number of optimizer steps to run.
        dropout_keep_prob: keep probability fed to `keep_prob` during training.
        print_opt_acc: if True, print training-batch accuracy every
            `print_every` global iterations. Accuracy is measured with
            dropout disabled.
        silent: if True, suppress the final time-usage line.
        batch_size: mini-batch size. Defaults to the global `train_batch_size`
            (not defined in the cells shown here; TODO confirm it is set
            before calling).
        print_every: iteration interval for accuracy reports.

    Side effects:
        Advances the global `total_iterations` counter by one per completed step.
    '''
    global total_iterations

    if batch_size is None:
        batch_size = train_batch_size

    start_time = time.time()
    for i in range(total_iterations, total_iterations + num_iterations):
        x_batch, y_true_batch = dataset_train.next_batch(batch_size)
        x_batch = x_batch.reshape(len(x_batch), img_size_flat)
        feed_dict_train = {x: x_batch, y_true: y_true_batch, keep_prob: dropout_keep_prob}

        session.run(optimizer, feed_dict=feed_dict_train)
        # Update per step so an interrupted run still counts completed steps.
        total_iterations = i + 1

        if print_opt_acc and i % print_every == 0:
            # Evaluate with dropout off. Otherwise the reported accuracy
            # belongs to a randomly thinned network, not the trained model.
            feed_dict_eval = dict(feed_dict_train)
            feed_dict_eval[keep_prob] = 1.0
            acc = session.run(accuracy, feed_dict=feed_dict_eval)
            msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"
            print(msg.format(i + 1, acc))

    time_dif = time.time() - start_time
    if not silent:
        print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))
    
def print_test_accuracy(dataset_test, show_confusion_matrix=True, quieter=False, silent=False):
    '''Classify every image in `dataset_test` and report the accuracy.

    Images are fed in chunks of `test_batch_size`, with dropout disabled
    (keep_prob = 1.0).

    Args:
        dataset_test: object with `images` and `labels` (sliceable as [i:j, :])
            and `cls`, the true class index of each image.
        show_confusion_matrix: if True, plot it via cn.plot_confusion_matrix.
        quieter: print only the percentage instead of the full message.
        silent: print nothing; takes precedence over `quieter`.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if `dataset_test` contains no images.
    '''
    num_test = len(dataset_test.images)
    if num_test == 0:
        raise ValueError("dataset_test contains no images")

    # Builtin int: the np.int alias was deprecated in NumPy 1.20 and removed in 1.24.
    cls_pred = np.zeros(shape=num_test, dtype=int)
    for start in range(0, num_test, test_batch_size):
        end = min(start + test_batch_size, num_test)
        images = dataset_test.images[start:end, :].reshape(end - start, img_size_flat)
        labels = dataset_test.labels[start:end, :]
        feed_dict = {x: images, y_true: labels, keep_prob: 1.0}
        cls_pred[start:end] = session.run(y_pred_cls, feed_dict=feed_dict)

    # asarray so a plain list of labels also compares element-wise.
    cls_true = np.asarray(dataset_test.cls)
    correct_sum = int((cls_true == cls_pred).sum())
    acc = float(correct_sum) / num_test

    if not silent:
        if quieter:
            print("{0:.1%}".format(acc))
        else:
            msg = "Accuracy on validation set: {0:.1%} ({1} / {2})"
            print(msg.format(acc, correct_sum, num_test))
    if show_confusion_matrix:
        cn.plot_confusion_matrix(cls_true, cls_pred=cls_pred)
    return acc

Load data.