This notebook trains a multi-visit reconstruction network, which leverages a subject's previous scans to improve the quality of the initial (single-visit) reconstructions generated in Step 01.

This notebook was run in a Docker container built from the latest TensorFlow GPU Jupyter image. The command used to start it is:

docker run -p 8888:8888 --gpus all -v $(pwd):/tf tensorflow/tensorflow:latest-gpu-jupyter 

Using Docker eliminates the need to configure a Python environment with a working TensorFlow installation. When starting the container, it is best to do so from your home directory (or another folder that contains both the code and the data). The `-v $(pwd):/tf` option mounts only the directory the command is run in, so files in its parent folders are not accessible inside the container.

In [1]:
import sys
import subprocess

def install(name):
    """Install a pip package into the environment running this kernel.

    Runs ``<current interpreter> -m pip install <name>`` so the package lands
    in the same Python that executes the notebook; a bare ``pip`` found on
    PATH may belong to a different interpreter.

    Args:
        name: the pip requirement to install (e.g. ``'nibabel'``).

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status,
            instead of silently continuing to a later ImportError.
    """
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', name])

install('nibabel')
install('scikit-learn')

import numpy as np
import os
import glob
import nibabel as nib

# Importing our model: make the project's ../src/ directory importable
# (``sys`` is already imported at the top of this cell).
MY_UTILS_PATH = "../src/"
if MY_UTILS_PATH not in sys.path:
    sys.path.append(MY_UTILS_PATH)
import enhancement_unet as eu
import tensorflow as tf
# Importing callbacks and data augmentation utils
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam

Define input parameters

In [2]:
# --- input parameters ---
# In-plane size of each slice.
H, W = 512, 512
# Padding handed to the model builder (the model summary shows 512 -> 520).
# NOTE(review): when H or W is already divisible by 8 this evaluates to 4, not 0;
# confirm the extra margin is intended.
Hpad = (8 - H % 8) // 2
Wpad = (8 - W % 8) // 2
norm = np.sqrt(H * W)

# Training schedule.
patience = 10    # epochs without val_loss improvement before early stopping
batch_size = 16
epochs = 150
lr = 1e-4        # learning rate (higher rate unstable, lower slow)
decay = 1e-7

# Volumes are cropped to slices [c1:c2], i.e. the first 20 and the last 6 slices
# (little or no anatomical structure) are dropped.
c1, c2 = 20, -6

In [1]:
registration = 'elastic_' #type of registration performed in Step 02
# Load the training and validation file ids: one (previous scan, follow-up scan)
# pair of file names per row. ndmin=2 keeps the (n_pairs, 2) shape even when a
# split file holds a single pair; np.loadtxt would otherwise return a flat (2,)
# array, and ID[0] / ID[1] in the path-building cell would index characters
# instead of file names.
training_files = np.loadtxt("../data/train_val_test_split/train_long.txt", dtype=str, ndmin=2)
val_files = np.loadtxt("../data/train_val_test_split/val_long.txt", dtype=str, ndmin=2)

In [2]:
# Inspect the training pairs: column 0 is the previous visit, column 1 the follow-up visit.
training_files

array([['TUM37-20190108.nii', 'TUM37-20191105.nii'],
       ['TUM37-20190730.nii', 'TUM37-20191105.nii'],
       ['TUM36-20181030.nii', 'TUM36-20190527.nii'],
       ['TUM29-20180125.nii', 'TUM29-20180716.nii'],
       ['TUM28-20180821.nii', 'TUM28-20181120.nii'],
       ['TUM27-20180716.nii', 'TUM27-20190104.nii']], dtype='<U18')

In [3]:
# Experiment tag shared by the model and data folders
# (presumably the undersampling setting of Step 01 -- confirm).
_tag = '10x-iki'

model_path = f'../models/enhancement_unet_sc_{_tag}.h5'            # full model, saved every epoch
weights_path = f'../models/weights_enhancement_unet_sc_{_tag}.h5'  # best weights (lowest val_loss)
previous_path = f'../data/reference_reg_{_tag}/'                    # registered previous scans
follow_up_path = f'../data/predicted/{_tag}/'                       # initial (single-visit) reconstructions
reference_path = '../../../data/brain-cancer/'                      # entire reconstructed dataset

In [6]:
#get training files
train_previous_files = [previous_path + registration + ID[0][:-4] + '_' + ID[1][:14] + '.nii' for ID in training_files]
train_follow_up_files = [follow_up_path + ID[1][:-4] + '_predicted.nii'  for ID in training_files]
train_reference_files = [reference_path + ID[1][:14] + '.nii' for ID in training_files]

#get val files
val_previous_files = [previous_path + registration + ID[0][:-4] + '_' + ID[1][:14] + '.nii' for ID in val_files]
val_follow_up_files = [follow_up_path + ID[1][:-4] + '_predicted.nii' for ID in val_files]
val_reference_files = [reference_path + ID[1][:14] + '.nii' for ID in val_files]


In [7]:
#count training samples
aux=0
for ID in train_previous_files:
    aux_shape = nib.load(ID).shape
    aux += aux_shape[-1] - (c1 - c2)

print('number of training samples', aux)
#load training samples into single array (nsamples,height,width,2)
train = np.zeros((aux, W, H, 2))
train_ref = np.zeros((aux,W,H,1))
aux_counter = 0
for ii in range(len(train_previous_files)):
    prev = nib.load(train_previous_files[ii]).get_fdata()[:,:,c1:c2]
    next_ = nib.load(train_follow_up_files[ii]).get_fdata()[:,:,c1:c2]
    ref = nib.load(train_reference_files[ii]).get_fdata()[:,:,c1:c2]
    
    aux = prev.shape[-1]
    train[aux_counter:aux_counter+aux,:,:,0] = np.swapaxes(prev,0,2) / np.abs(prev).max()
    train[aux_counter:aux_counter+aux,:,:,1] = np.swapaxes(next_,0,2) / np.abs(next_).max()
    train_ref[aux_counter:aux_counter+aux,:,:,0] = np.swapaxes(ref,0,2) / np.abs(ref).max()
    aux_counter += aux


number of training samples 746


In [8]:
#count validation samples
aux=0
for ID in val_previous_files:
    aux_shape = nib.load(ID).shape
    aux += aux_shape[-1] - (c1 - c2)

print('number of validation samples', aux)
#load validation samples into single array (nsamples, height,width,channels=2)
#the two channels account for the previous scan and initial reconstruction
val = np.zeros((aux, W, H, 2))
val_ref = np.zeros((aux,W,H,1))
aux_counter = 0
for ii in range(len(val_previous_files)):
    prev = nib.load(val_previous_files[ii]).get_fdata()[:,:,c1:c2]
    next_ = nib.load(val_follow_up_files[ii]).get_fdata()[:,:,c1:c2]
    ref = nib.load(val_reference_files[ii]).get_fdata()[:,:,c1:c2]
    aux = prev.shape[-1]
    val[aux_counter:aux_counter+aux,:,:,0] = np.swapaxes(prev,0,2) / np.abs(prev).max()
    val[aux_counter:aux_counter+aux,:,:,1] = np.swapaxes(next_,0,2) / np.abs(next_).max()
    val_ref[aux_counter:aux_counter+aux,:,:,0] = np.swapaxes(ref,0,2) / np.abs(ref).max()
    aux_counter += aux

number of validation samples 726


In [9]:
#shuffle training data
# Shuffle inputs and targets with one shared permutation so the pairs stay aligned.
# A fixed seed (same value as the augmentation seed) makes the run reproducible; the
# unseeded global RNG produced a different slice order, and therefore different
# training batches, on every run.
shuffle_rng = np.random.default_rng(10)
indexes = shuffle_rng.permutation(train.shape[0])
train = train[indexes]
train_ref = train_ref[indexes]

In [10]:

# Callbacks
# Stop training once val_loss has not improved for `patience` consecutive epochs.
earlyStopping = EarlyStopping(monitor='val_loss', patience=patience,
                              verbose=0, mode='min')

# Keep only the weights of the best epoch (lowest val_loss).
checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', mode='min',
                             verbose=0, save_best_only=True,
                             save_weights_only=True)

# Save the full model after every epoch; the training cell reloads model_path
# when it exists. An integer save_freq counts *batches*, so save_freq=1 wrote
# the whole HDF5 model after every training step.
checkpoint2 = ModelCheckpoint(model_path, monitor='val_loss', mode='min',
                              verbose=0, save_best_only=False,
                              save_weights_only=False, save_freq='epoch')


In [11]:
#paramters for data augmentation
seed = 10
# One shared augmentation configuration for inputs and targets.
augmentation_args = dict(
    rotation_range=40,
    width_shift_range=0.075,
    height_shift_range=0.075,
    shear_range=0.25,
    zoom_range=0.25,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest')

image_datagen1 = ImageDataGenerator(**augmentation_args)  # inputs: previous scan + initial reconstruction
image_datagen2 = ImageDataGenerator(**augmentation_args)  # targets: reference reconstruction

# No .fit() call: fit() only computes the statistics used by featurewise_center,
# featurewise_std_normalization and zca_whitening, none of which is enabled here.
# With augment=True it merely built a throw-away augmented copy of the data.

# NOTE(review): input/target alignment relies on both flows, seeded identically,
# drawing the same random transform for each batch -- keep both configurations identical.
image_generator1 = image_datagen1.flow(train, batch_size=batch_size, seed=seed)
# Targets go through their own generator (previously image_datagen1 was reused here,
# leaving image_datagen2 configured but never used).
image_generator2 = image_datagen2.flow(train_ref, batch_size=batch_size, seed=seed)


  ' channels).')
  str(self.x.shape[channels_axis]) + ' channels).')


In [12]:
def combine_generator(gen1, gen2):
    """Pair input and target batch iterators into Keras ``(inputs, target)`` items.

    Args:
        gen1: iterator of input batches shaped (batch, height, width, 2);
            channel 0 is the previous scan, channel 1 the initial reconstruction.
        gen2: iterator of the matching target batches.

    Yields:
        ``([batch, batch[..., 1:2]], target)`` -- the model's two inputs (both
        channels, plus channel 1 alone as a 1-channel tensor) and the target
        batch. Stops when either iterator is exhausted.
    """
    # Builtin iteration works for any Python iterator, unlike the
    # Python-2-style `.next()` method.
    for batch_train, batch_train_ref in zip(gen1, gen2):
        yield [batch_train, batch_train[..., 1:2]], batch_train_ref

# Single training stream yielding ([inputs, initial reconstruction], target) batches.
combined = combine_generator(gen1=image_generator1, gen2=image_generator2)


In [13]:

#initialize model
#initialize model: resume from the saved full model if it exists,
#otherwise build and compile a fresh network
if os.path.isfile(model_path):
    model = load_model(model_path)
else:
    model = eu.enhancement_unet(H=H, W=W, Hpad=Hpad, Wpad=Wpad)
    # The legacy optimizer `decay` argument meant lr / (1 + decay * iterations).
    # Newer Keras optimizers (TF >= 2.11) reject `decay`, so express the same
    # per-step decay as an InverseTimeDecay schedule with decay_steps=1.
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=lr, decay_steps=1, decay_rate=decay)
    opt = Adam(learning_rate=lr_schedule)
    model.compile(loss='mse', optimizer=opt)

# summary() prints by itself and returns None; print(model.summary()) added a stray "None".
model.summary()
hist = model.fit(combined,
                 epochs=epochs,
                 # at least one step, even if there are fewer samples than batch_size
                 steps_per_epoch=max(1, train.shape[0] // batch_size),
                 verbose=1,
                 # same two-input layout as the training generator yields
                 validation_data=([val, val[:, :, :, 1, np.newaxis]], val_ref),
                 callbacks=[checkpoint2, checkpoint, earlyStopping])


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 520, 520, 2)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 520, 520, 48) 912         zero_padding2d[0][0]             
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 520, 520, 48) 20784       conv2d[0][0]                     
______________________________________________________________________________________________



Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
Epoch 27/150
Epoch 28/150
Epoch 29/150
Epoch 30/150
Epoch 31/150
Epoch 32/150
Epoch 33/150
Epoch 34/150
Epoch 35/150
Epoch 36/150
Epoch 37/150
Epoch 38/150
Epoch 39/150
Epoch 40/150
Epoch 41/150
Epoch 42/150
Epoch 43/150
Epoch 44/150
Epoch 45/150
