Lung segmentation <br/>
Copyright (C) 2017 Therapixel (Pierre Fillard).

Input data: LIDC (https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) with
LUNA16 lung segmentations (https://luna16.grand-challenge.org/data/). <br />
Beforehand, all series must be converted to a volumetric format (ITK MHD), with each file
named after the UID of its series (`<seriesuid>.mhd`). The binaries in the tools/ folder
(dataImporter, seriesExporter) can be used for that.

Annotation files are provided in CSV format. Each row gives a position (in real-world
coordinates) inside a series, identified by its globally unique series UID, together with
a label: 1 when the position lies inside the lung, 0 otherwise. The 1/0 labels were derived
from the lung segmentations provided by LUNA16.

This notebook guides you through the process of training a deep net to classify
positions as inside vs. outside the lung. The following steps are involved:
- data conversion: all annotations are turned into h5 arrays by extracting a patch
of size 64x64x64 around each position. Images are first resampled to a common
isotropic voxel size of 2x2x2 mm.
- model training: the model is trained with data parallelism across several GPUs
(4 GPUs were used originally; set `num_gpus` accordingly).

In [None]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from datetime import datetime
import os.path
import glob
import time
from time import sleep
import sys
sys.path.append('../')
sys.path.append('../../')
import subprocess
import shutil

os.environ["CUDA_VISIBLE_DEVICES"] =  "0,1,2,3"
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected, convolution2d, flatten, batch_norm, max_pool2d, dropout, l2_regularizer
from tensorflow.python.ops.nn import relu, elu, relu6, relu1, sigmoid, tanh, softmax
from tensorflow.python.ops import variable_scope
import h5py as h5
import numpy as np
import lidc as lidc
import TherapixelDL.image as tpxdli
from six.moves import xrange
import scipy as sp
from scipy import ndimage
from TherapixelDL.confusionmatrix import ConfusionMatrix
import matplotlib
import matplotlib.pyplot as plt
import csv
import SimpleITK as sitk

print(tf.__version__)

In [None]:
import importlib
importlib.reload(lidc)
importlib.reload(tpxdli)

In [None]:
def readCSV(filename):
    """Read a CSV file and return all of its rows.

    Args:
        filename: path of the CSV file to read.

    Returns:
        A list with one entry per row (the header row, if any, included); each
        entry is the list of that row's fields as strings.
    """
    # newline='' hands line-ending handling to the csv module, as its docs
    # require; otherwise newlines embedded in quoted fields get translated.
    with open(filename, 'r', newline='') as f:
        return list(csv.reader(f))

In [None]:
def csv_to_h5(input_file, data_directory, output_file):
    """Extract 64^3 patches around annotated positions and save them to HDF5.

    The first row of ``input_file`` is a header and is skipped; every other
    non-empty row is read as ``seriesuid, c0, c1, c2, label``, where
    ``(c0, c1, c2)`` are world coordinates converted with
    ``lidc.worldToVoxelCoord`` and ``label`` is stored as int32. Each series is
    loaded from ``<data_directory>/<seriesuid>.mhd``, resampled to a 2x2x2
    voxel spacing, normalized with ``lidc.normalizePlanes`` and zero-padded by
    half a patch on every side, so that the patch centered on any voxel of the
    resampled volume can be sliced out.

    Args:
        input_file: CSV annotation file (see above).
        data_directory: directory containing the ``<seriesuid>.mhd`` volumes.
        output_file: HDF5 file to create (overwritten if present) with the
            datasets ``PATCHES`` (float32, N x 64 x 64 x 64), ``LABELS``
            (int32, N) and ``ID`` (S100, N; series UID of each patch), where N
            is the number of annotated positions.
    """
    target_spacing = [2., 2., 2.]
    patch_size = 64
    offset = patch_size // 2

    # Group positions by series so that each volume is loaded only once.
    input_dict = {}
    total_count = 0
    for row in readCSV(input_file)[1:]:  # skip header
        if not row:
            continue  # csv yields [] for blank lines; ignore them
        seriesuid = row[0]
        pos = [float(row[1]), float(row[2]), float(row[3]), float(row[4])]
        input_dict.setdefault(seriesuid, []).append(pos)
        total_count += 1

    patches = np.zeros(shape=(total_count, patch_size, patch_size, patch_size), dtype=np.float32)
    labels = np.zeros(shape=(total_count), dtype=np.int32)
    ids = np.zeros(shape=(total_count), dtype="S100")

    index = 0  # next free row in patches / labels / ids
    num_series = len(input_dict)
    for series_number, (seriesuid, positions) in enumerate(input_dict.items(), start=1):
        # Progress is counted in series (previously the patch index was
        # printed against the total number of positions).
        print('processing series', seriesuid, '%d/%d' % (series_number, num_series))
        series_filename = data_directory + '/' + seriesuid + '.mhd'
        itk_image = lidc.load_itk_image(series_filename)
        volume, origin, spacing, orientation = lidc.parse_itk_image(itk_image)

        # Resample to target_spacing. volume.shape is unpacked as (z, y, x)
        # while spacing and target_spacing are indexed as (x, y, z).
        padding_value = volume.min()
        img_z_orig, img_y_orig, img_x_orig = volume.shape
        img_z_new = int(np.round(img_z_orig * spacing[2] / target_spacing[2]))
        img_y_new = int(np.round(img_y_orig * spacing[1] / target_spacing[1]))
        img_x_new = int(np.round(img_x_orig * spacing[0] / target_spacing[0]))
        itk_image = lidc.resample_itk_image(itk_image, [img_x_new, img_y_new, img_z_new],
                                            target_spacing, int(padding_value))

        volume, origin, spacing, orientation = lidc.parse_itk_image(itk_image)
        volume = lidc.normalizePlanes(volume.astype(np.float32))
        # Pad by half a patch so that voxel (k, j, i) of the unpadded volume is
        # the center of padded[k:k+patch_size, j:j+patch_size, i:i+patch_size].
        volume = np.pad(volume, ((offset, offset),) * 3, 'constant', constant_values=0)

        for pos in positions:
            kk, jj, ii = lidc.worldToVoxelCoord(pos[0:3], origin=origin, spacing=spacing,
                                                orientation=orientation)
            kk, jj, ii = int(round(kk)), int(round(jj)), int(round(ii))
            patches[index] = volume[kk:kk + patch_size, jj:jj + patch_size, ii:ii + patch_size]
            labels[index] = pos[3]
            ids[index] = seriesuid
            index += 1

    # The context manager closes the file even if a dataset write fails.
    with h5.File(output_file, 'w') as h5_file:
        h5_file.create_dataset('PATCHES', data=patches, dtype=np.float32)
        h5_file.create_dataset('LABELS', data=labels, dtype=np.int32)
        h5_file.create_dataset('ID', data=ids, dtype="S100")

In [None]:
# Convert the CSV annotations into an HDF5 patch file; each series volume is
# read from data_directory as <seriesuid>.mhd.
data_directory = '/media/data/LIDC/LIDC-MHD/'
csv_to_h5(input_file='lung_segmentation_positions.csv',
          data_directory=data_directory,
          output_file='lung_segmentation_positions.h5')

In [None]:
# hyperameters of the model
# Hyperparameters of the model.

# Input geometry: cubic single-channel patches of patch_size voxels per side.
patch_size = 64
depth = height = width = patch_size
channels = 1

# Optional intensity scalings/offsets forwarded to the data generators;
# disabled here (alternatives tried: np.array([1.0, 1.5]) / np.array([0.0, -0.05])).
scalings = None
offsets = None

# Training setup.
num_classes = 2
batch_size = 32        # samples per GPU tower; a full batch is batch_size * num_gpus
num_gpus = 2
gpu_mem_ratio = 1.0    # per-process GPU memory fraction given to TensorFlow

In [None]:
def readAndSplitHDF5(filename, val_ratio=0.1, seed=1234, min_val_count=None):
    """Load patches and labels from an HDF5 file and split them into train/validation.

    The sample indices are shuffled with ``seed``; the first
    ``val_count = max(int(N * val_ratio), min_val_count)`` shuffled samples go to
    validation and the rest to training.

    Args:
        filename: HDF5 file with ``PATCHES`` and ``LABELS`` datasets (as written
            by ``csv_to_h5``).
        val_ratio: fraction of the samples reserved for validation.
        seed: shuffle seed. It is applied to NumPy's *global* RNG
            (``np.random.seed``), so it also affects later ``np.random`` draws.
        min_val_count: lower bound on the number of validation samples.
            Defaults to ``batch_size * num_gpus`` (module-level settings) so
            validation holds at least one full multi-GPU batch.

    Returns:
        ``(train_data, train_targets, validation_data, validation_targets)``.
    """
    # Read only the datasets used below; the context manager closes the file
    # even if a read fails.
    with h5.File(filename, 'r') as h5_file:
        patches = h5_file['PATCHES'][...]
        labels = h5_file['LABELS'][...]

    if min_val_count is None:
        min_val_count = batch_size * num_gpus

    indices = np.arange(patches.shape[0])
    # Global reseed kept as-is so existing splits (and any downstream
    # np.random draws) stay reproducible.
    np.random.seed(seed)
    np.random.shuffle(indices)
    val_count = max(int(patches.shape[0] * val_ratio), min_val_count)

    val_idx, train_idx = indices[:val_count], indices[val_count:]
    return patches[train_idx], labels[train_idx], patches[val_idx], labels[val_idx]

In [None]:
# Base set of lung-segmentation patches.
(train_data, train_targets,
 validation_data, validation_targets) = readAndSplitHDF5(
    filename='/media/data/LIDC/lung_segmentation_patches_64.h5')

In [None]:
# First hard-negative patch set (per the file name).
(train_hn_data, train_hn_targets,
 validation_hn_data, validation_hn_targets) = readAndSplitHDF5(
    filename='/media/data/LIDC/lung_segmentation_hard_negs_1_patches_64.h5')

In [None]:
# Second hard-negative patch set (per the file name).
(train_hn2_data, train_hn2_targets,
 validation_hn2_data, validation_hn2_targets) = readAndSplitHDF5(
    filename='/media/data/LIDC/lung_segmentation_hard_negs_2_patches_64.h5')

In [None]:
# Third hard-negative patch set (per the file name).
(train_hn3_data, train_hn3_targets,
 validation_hn3_data, validation_hn3_targets) = readAndSplitHDF5(
    filename='/media/data/LIDC/lung_segmentation_hard_negs_3_patches_64.h5')

In [None]:
shift_range = 0.05

# Training generator: random rotation, shifts, shear, zoom and flips along all
# three axes; the windowing ranges are explicitly set to 0.
image_gen_3d = tpxdli.ImageDataGenerator3D(
    rotation_range=10.0,
    width_shift_range=shift_range,
    height_shift_range=shift_range,
    depth_shift_range=shift_range,
    shear_range=0.1,
    zoom_range=np.array([0.95, 1.05], dtype=np.float32),
    horizontal_flip=True,
    vertical_flip=True,
    depth_flip=True,
    windowing_scale_range=0.0,
    windowing_intercept_range=0.0,
    dim_ordering='tf')

# Validation generator: every augmentation disabled (identity zoom, no flips)
# so validation batches resemble real-life, unaugmented data.
image_gen_3d_val = tpxdli.ImageDataGenerator3D(
    rotation_range=0.0,
    width_shift_range=0.0,
    height_shift_range=0.0,
    depth_shift_range=0.0,
    shear_range=0.0,
    zoom_range=np.array([1.0, 1.0], dtype=np.float32),
    horizontal_flip=False,
    vertical_flip=False,
    depth_flip=False,
    dim_ordering='tf')

In [None]:
def train(train_data, train_targets, validation_data, validation_targets, lr_scheme, num_gpus=1, num_epochs=100,
          output_dir='', prev_model=''):
    """Train the lung-segmentation patch classifier on data-parallel GPU towers.

    Builds a TF1 graph with one tower per GPU (``lidc.inference_emphyseme`` +
    ``lidc.loss``), averages the tower gradients, applies them with Adam and
    maintains an exponential moving average of the trainable variables. Runs
    an initial validation pass (unless ``skip_first_val`` is set), then
    ``num_epochs`` epochs of training, each followed by validation, saving
    checkpoints along the way. Finally plots train/validation accuracy per
    epoch. A KeyboardInterrupt (Ctrl-C) stops training early without raising.

    Args:
        train_data, train_targets: lists of patch arrays and matching label
            arrays, fed through ``image_gen_3d.flowList`` (balanced, shuffled).
        validation_data, validation_targets: same for validation, fed through
            ``image_gen_3d_val.flowList`` (no balancing, no shuffling).
        lr_scheme: per-epoch learning rates indexed by epoch; needs at least
            ``num_epochs`` entries.
        num_gpus: number of towers; each batch of ``batch_size * num_gpus``
            samples is split evenly across them.
        num_epochs: number of training epochs.
        output_dir: prefix prepended to checkpoint names ('best_model_loss',
            'best_model_acc', 'checkpoint_epoch'), so it should end with a path
            separator. Best-model saves are skipped when it is empty.
        prev_model: optional checkpoint path restored before training.

    Also reads module-level settings (height, width, depth, channels,
    batch_size, patch_size, num_classes, gpu_mem_ratio, scalings, offsets) and
    the ``image_gen_3d`` / ``image_gen_3d_val`` generators.
    """
        
    # reset graph first (the model itself is built in the fresh tf.Graph()
    # opened just below)
    tf.reset_default_graph()
    
    # Default device is the CPU; each tower's ops are pinned to its GPU below.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.contrib.framework.get_or_create_global_step()
        
        is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
    
        # Setting up placeholder, this is where your data enters the graph!
        x_pl = tf.placeholder(tf.float32, shape=(None, height, width, depth, channels), name='data_x')
        # shape=(None) is just None, i.e. a placeholder of unknown shape.
        y_pl = tf.placeholder(tf.int32, shape=(None), name='data_y')
    
        # defining our optimizer
        learning_rate = tf.placeholder(tf.float32, shape=[])
        
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        # Calculate the gradients for each model tower.
        tower_grads = []    
        losses = []
        y = []
    
        # tf.split with an integer split count requires the fed batch size to
        # be divisible by num_gpus.
        x_splits = tf.split(x_pl, num_or_size_splits=num_gpus)
        y_splits = tf.split(y_pl, num_or_size_splits=num_gpus)
    
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for i in range(num_gpus):
                with tf.device('/gpu:%d' % i):
                    # (this `scope` shadows the one above; neither is used)
                    with tf.name_scope('tower_%d' % (i)) as scope:
                        logits = lidc.inference_emphyseme(x_splits[i],
                                                          is_training=is_training,
                                                          num_outputs=num_classes)
                        # NOTE(review): squeezes every size-1 dimension, which would
                        # include the batch dimension if a tower got a single sample.
                        logits = tf.squeeze(logits)
                        l = lidc.loss(logits=logits, labels=y_splits[i], with_regularization=False)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data
                        grads = optimizer.compute_gradients(l)


                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)
                        losses.append(l)
                        y.append(tf.nn.softmax(logits))

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        if (num_gpus>1):
            grads = lidc.average_gradients(tower_grads)    
        else:
            grads = tower_grads[0]
    
        # Apply the gradients to adjust the shared variables.
        # NOTE(review): global_step is not passed here and nothing visible in
        # this function increments it.
        apply_gradient_op = optimizer.apply_gradients(grads)
    
        # Track the moving averages of all trainable variables.
        # NOTE(review): global_step is used as num_updates; if it indeed stays
        # at 0, the effective decay is capped at (1+0)/(10+0) = 0.1 — confirm.
        variable_averages = tf.train.ExponentialMovingAverage(lidc.MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())
    
        # One training step = apply the gradients, then update the averages.
        with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
            train_op = tf.no_op(name='train')
            
        # Restore the moving average version of the learned variables for eval.
        # NOTE(review): a Saver built from variables_to_restore() maps each
        # moving-average name to the *raw* variable, so save() writes the raw
        # weights under the moving-average names — confirm this is intended.
        variable_averages = tf.train.ExponentialMovingAverage(lidc.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore, max_to_keep=None)
        
        # restricting memory usage, TensorFlow is greedy and will use all memory otherwise
        gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_ratio)
        # initialize the Session
        # (log_device_placement=True logs the device of every op, which is verbose)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts, 
                                                allow_soft_placement=True, 
                                                log_device_placement=True)) # allow_soft_placement=True needed to make batch_normalization work across GPU
            
        sess.run(tf.global_variables_initializer())
        if (prev_model):
            print('restoring model', prev_model)
            saver.restore(sess, prev_model)            
    
    # Epoch size: count each class over all training arrays and use
    # num_classes * (majority class count) samples — presumably matching the
    # balance=True sampling of flowList below (TODO confirm).
    train_negative_count = 0
    train_positive_count = 0
    for i in range(len(train_targets)):
        train_negative_count += (train_targets[i]==0).sum()
        train_positive_count += (train_targets[i]==1).sum()
    n_train_samples = max(train_negative_count,train_positive_count) * num_classes
    train_capacity = batch_size * num_gpus
    num_batches_train = n_train_samples // train_capacity
    
    train_iterator_3d = image_gen_3d.flowList(X=train_data, 
                                              Y=train_targets, 
                                              batch_size=train_capacity,
                                              balance=True,
                                              shuffle=True, 
                                              output_depth=patch_size, 
                                              output_rows=patch_size, 
                                              output_cols=patch_size,
                                              num_output_channels=channels,
                                              scalings=scalings,
                                              offsets=offsets)
    
    val_capacity = batch_size * num_gpus
    val_iterator_3d = image_gen_3d_val.flowList(X=validation_data, 
                                                Y=validation_targets,
                                                batch_size=val_capacity,
                                                balance=False,
                                                shuffle=False,
                                                output_depth=patch_size, 
                                                output_rows=patch_size, 
                                                output_cols=patch_size,
                                                num_output_channels=channels,
                                                scalings=scalings,
                                                offsets=offsets)
    
    n_val_samples = 0
    for i in range(len(validation_targets)):
        n_val_samples += validation_targets[i].shape[0]
    # Only whole batches are consumed per validation pass; the remaining
    # n_val_samples % val_capacity samples are not evaluated by the loops below.
    num_batches_valid = n_val_samples // val_capacity    

    print('training with parameters:\n\t- train capacity: %d\n\t- val capacity: %d\n\t- batch size: %d\n\t- patch size: %d\n\t'\
          '- num gpu: %d\n\t- num epochs: %d\n\t- previous model: %s' % (n_train_samples, n_val_samples, batch_size, patch_size,
                                                                         num_gpus, num_epochs, prev_model))           
    
    print('number of training batches per epoch', num_batches_train)
    print('number of validation batches per epoch', num_batches_valid)
    
    # Per-epoch accuracy histories (plotted at the end). train_loss/valid_loss
    # are rebound to floats just below; test_acc/test_loss are unused.
    train_acc, train_loss = [], []
    valid_acc, valid_loss = [], []
    test_acc, test_loss = [], []
    lr = -1
    best_val_loss = -1.  # -1 means "not initialized yet"
    best_val_acc = 0.
    train_loss = 0.
    valid_loss = 0.
    
    # Background batch producers; produce() is called once per pass and
    # num_batches_* items are then taken from get_queue() (behavior of
    # tpxdli.QueuedIterator assumed from usage — TODO confirm).
    train_queue = tpxdli.QueuedIterator(train_iterator_3d, num_batches_train)
    val_queue = tpxdli.QueuedIterator(val_iterator_3d, num_batches_valid)
    
    # Set to True to skip the validation pass run before the first epoch.
    skip_first_val = False
    
    try:
        # init best_val_loss before training
        if not skip_first_val:
            confusion_valid = ConfusionMatrix(num_classes)            
            val_queue.produce()
            for i in range(num_batches_valid):
                (batch_val_x, batch_val_y) = val_queue.get_queue().get()
                feed_dict_eval = {
                    x_pl: batch_val_x,
                    y_pl: batch_val_y,
                    is_training: False
                }
                fetches_eval = [y, losses]#, x_splits_list, y_splits, x_splits]
                # running the validation
                res = sess.run(fetches=fetches_eval, feed_dict=feed_dict_eval)
                # collecting and storing predictions
                # (sum of the tower losses; divided by num_gpus below)
                cur_loss = np.sum(res[1])
                preds = np.argmax(np.concatenate(res[0]), axis=-1)             
                confusion_valid.batch_add(batch_val_y, preds)
                # Incremental running mean of the per-batch loss.
                if i==0:
                    valid_loss = cur_loss/num_gpus
                else:
                    valid_loss = valid_loss*i/(i+1) + cur_loss/(num_gpus*(i+1))
                val_queue.get_queue().task_done()
                sys.stdout.write('\rValidation. batch: %d/%d. loss: %f, acc.: %f'%(i+1,num_batches_valid,valid_loss,confusion_valid.accuracy()))
                sys.stdout.flush()
                # NOTE(review): adds 1 s per batch; purpose unclear.
                sleep(1)
            
            best_val_loss = valid_loss
            valid_acc_cur = confusion_valid.accuracy()
            best_val_acc = valid_acc_cur
            print('\nInitial validation loss and accuracy are: %f / %f'%(best_val_loss, valid_acc_cur))                
    
        for epoch in range(num_epochs):
            
            # Learning rate for this epoch comes from lr_scheme.
            if (lr != lr_scheme[epoch]):
                lr = lr_scheme[epoch]
                print('using lr', lr)
        
            t0 = time.time()                
                        
            confusion_train = ConfusionMatrix(num_classes)
            train_queue.produce()
            for i in range(num_batches_train):
                (batch_train_x, batch_train_y) = train_queue.get_queue().get()
                feed_dict_train = {
                    x_pl: batch_train_x,
                    y_pl: batch_train_y,
                    is_training: True, 
                    learning_rate: lr
                }
                fetches_train = [train_op, losses, y]
                res = sess.run(fetches=fetches_train, feed_dict=feed_dict_train)
                cur_loss = np.sum(res[1])
                preds = np.argmax(np.concatenate(res[2]), axis=-1)
                confusion_train.batch_add(batch_train_y, preds)
                # Incremental running mean of the per-batch loss.
                if i==0:
                    train_loss = cur_loss/(num_gpus)
                else:
                    train_loss = train_loss*i/(i+1) + cur_loss/(num_gpus*(i+1))                                                
                train_queue.get_queue().task_done()
                sys.stdout.write('\rTraining. batch: %d/%d, loss: %f, acc.: %f'%(i+1,num_batches_train,train_loss,confusion_train.accuracy()))
                sys.stdout.flush()
                # NOTE(review): adds 1 s per training batch; purpose unclear.
                sleep(1)
                    
            t1 = time.time()
            epoch_time = t1 - t0
        
            sys.stdout.write("\n")
        
            # Validation pass after each epoch.
            confusion_valid = ConfusionMatrix(num_classes)            
            val_queue.produce()
            for i in range(num_batches_valid):
                (batch_val_x, batch_val_y) = val_queue.get_queue().get()
                feed_dict_eval = {
                    x_pl: batch_val_x,
                    y_pl: batch_val_y,
                    is_training: False
                }
                fetches_eval = [y, losses]
                # running the validation
                res = sess.run(fetches=fetches_eval, feed_dict=feed_dict_eval)
                # collecting and storing predictions
                cur_loss = np.sum(res[1])
                preds = np.argmax(np.concatenate(res[0]), axis=-1)             
                confusion_valid.batch_add(batch_val_y, preds)
                if i==0:
                    valid_loss = cur_loss/(num_gpus)
                else:
                    valid_loss = valid_loss*i/(i+1) + cur_loss/(num_gpus*(i+1))
                val_queue.get_queue().task_done()
                sys.stdout.write('\rValidation. batch: %d/%d, loss: %f, acc.: %f'%(i+1,num_batches_valid,valid_loss,confusion_valid.accuracy()))
                sys.stdout.flush()
                # NOTE(review): adds 1 s per validation batch; purpose unclear.
                sleep(1)
                            
            sys.stdout.write("\n")
            
            train_acc_cur = confusion_train.accuracy()
            valid_acc_cur = confusion_valid.accuracy()

            train_acc += [train_acc_cur]
            valid_acc += [valid_acc_cur]
            print ("Epoch %i: train loss %e, train acc. %f, valid loss %f, valid acc %f, epoch time %.2f s " \
            % (epoch+1, train_loss, train_acc_cur, valid_loss, valid_acc_cur, epoch_time))
        
            # best_val_loss is still -1 only if the initial validation was skipped.
            if (best_val_loss<0):
                best_val_loss = valid_loss
            
            if ((best_val_loss>=0) and (valid_loss<best_val_loss)):
                print('val loss improved from %f to %f, saving model' % (best_val_loss, valid_loss))
                best_val_loss = valid_loss
                if (output_dir):
                    filename = output_dir + 'best_model_loss'
                    print('saving model to file:',filename)
                    saver.save(sess, filename)
                    
            if (best_val_acc<valid_acc_cur):
                print('val acc improved from %f to %f, saving model' % (best_val_acc, valid_acc_cur))
                best_val_acc = valid_acc_cur
                if (output_dir):
                    filename = output_dir + 'best_model_acc'
                    print('saving model to file:',filename)
                    saver.save(sess, filename)
                        
            # Periodic checkpoint every 10 epochs; unlike the best-model saves
            # this runs even when output_dir is empty (relative path then).
            if (((epoch+1)%10)==0):
                saver.save(sess, output_dir+'checkpoint_epoch')
                
            # NOTE(review): no effect — the for loop rebinds epoch each iteration.
            epoch += 1
            
            train_loss = 0.
            valid_loss = 0.

    except KeyboardInterrupt:        
        # Ctrl-C stops training early; fall through to cleanup and plotting.
        pass
    
    # NOTE(review): after an interrupt mid-pass, items may remain in a queue
    # without task_done(), in which case join() could block — depends on
    # tpxdli.QueuedIterator, confirm.
    train_queue.get_queue().join()
    val_queue.get_queue().join()
    
    # NOTE(review): not reached if an exception other than KeyboardInterrupt
    # escapes the loop above.
    sess.close()

    # Plot accuracy curves (epoch is reused as the x-axis array).
    epoch = np.arange(len(train_acc))
    plt.figure()
    plt.plot(epoch, train_acc,'r', epoch, valid_acc,'b')
    plt.legend(['Train Acc','Val Acc'])
    plt.xlabel('Epochs'), plt.ylabel('Acc'), plt.ylim([0.6,1.03])

In [None]:
# Training run configuration.
num_epochs = 200
prev_model = ''  # checkpoint to resume from ('' = start from scratch)

# train() prepends this directory to its checkpoint file names.
output_directory = 'model/'
if not os.path.exists(output_directory):
    os.makedirs(output_directory)

# Step learning-rate schedule: lr for the first 5 epochs, lr / lr_decay after.
lr = 1e-3
lr_decay = 10.
lr_scheme = np.zeros(shape=(num_epochs), dtype=np.float32)
lr_scheme[:5] = lr
lr = lr / lr_decay
lr_scheme[5:] = lr

# Base patch set followed by the three hard-negative sets (hn, hn2, hn3).
train_data_list = [train_data, train_hn_data, train_hn2_data, train_hn3_data]
train_target_list = [train_targets, train_hn_targets, train_hn2_targets, train_hn3_targets]
validation_data_list = [validation_data, validation_hn_data, validation_hn2_data, validation_hn3_data]
validation_target_list = [validation_targets, validation_hn_targets, validation_hn2_targets, validation_hn3_targets]

# Train the model.
train(lr_scheme=lr_scheme,
      num_epochs=num_epochs,
      num_gpus=num_gpus,
      train_data=train_data_list,
      train_targets=train_target_list,
      validation_data=validation_data_list,
      validation_targets=validation_target_list,
      output_dir=output_directory,
      prev_model=prev_model)