## Training the U-Net


@authors M. Schultheiss, T. Dorosti


In [None]:
import numpy as np
import pandas as pd
import h5py
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import datetime
from scipy import ndimage
from skimage.transform import rescale, resize
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.layers import Conv2D
from keras.preprocessing.image import ImageDataGenerator
import keras.initializers
from keras.models import load_model

from UNet import UNet
from pyDeleClasses import Slice2D, Slice2DSet
from configuration_paper import cparam, CURRENT_CONFIG, PATH_SERVER_CACHE, POSITION
from functions import getSyntheticData

In [None]:
# Set up the GPU that TensorFlow is allowed to use.
# CURRENT_GPU is a placeholder: replace '?' with the index of the GPU to use.
CURRENT_GPU = '?'
# Order the devices by PCI bus id, then expose only the selected one.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(CURRENT_GPU)

# Sanity checks: CUDA build flag, visible GPU devices and the GPU device name.
print(tf.test.is_built_with_cuda())
visible_gpus = tf.config.list_physical_devices('GPU')
print(visible_gpus)
print(tf.test.gpu_device_name())
# tf.test.is_gpu_available() is deprecated and does not take a device name
# (its first positional parameter is the boolean `cuda_only`), so report GPU
# availability from the list of physical devices instead.
print(bool(visible_gpus))

### 1. Get Input Synthetic Radiograph Data and Lung Thickness Ground Truth Labels

#### 1.A) Luna16 Data:

In [None]:
# Load the synthetic Luna16 data and split it into train / validation parts.
metadataPath = '???'  # path to metadata .csv file
dataset_name = 'luna16'
radiograph_range = range(10)  # projections 0-9: all 10 projections for luna
slices_luna = getSyntheticData(
    metadataPath,
    PATH_SERVER_CACHE,
    POSITION,
    dataset_name,
    CURRENT_CONFIG,
    radiograph_range,
)
# The third and fourth entries of the split (the test data) are discarded:
# the training script does not keep them.
(tx_luna, ty_luna, _, _,
 valx_luna, valy_luna) = slices_luna.split_data_by_csv("splits/lungvolume")

print("Train Len {}".format(len(tx_luna)))
print("Val Len {}".format(len(valx_luna)))

#### 1.B) PE Data:

In [None]:
# Load the synthetic PE data and split it into train / validation parts.
metadataPath = '???'  # path to metadata .csv file
dataset_name = 'PE'
radiograph_range = range(4, 5)  # only projection 4: the central projection for PE
slices_PE = getSyntheticData(
    metadataPath,
    PATH_SERVER_CACHE,
    POSITION,
    dataset_name,
    CURRENT_CONFIG,
    radiograph_range,
)
# The third and fourth entries of the split (the test data) are discarded:
# the training script does not keep them.
(tx_PE, ty_PE, _, _,
 valx_PE, valy_PE) = slices_PE.split_data_by_csv("splits/lungvolumePE")

print("Train Len {}".format(len(tx_PE)))
print("Val Len {}".format(len(valx_PE)))

In [None]:
# Combine the Luna16 and PE splits into the final train and validation data.
# list() makes sure that `+` concatenates the samples: on NumPy arrays a bare
# `+` would add the arrays element-wise instead of appending them.
tx = list(tx_luna) + list(tx_PE)
ty = list(ty_luna) + list(ty_PE)
valx = list(valx_luna) + list(valx_PE)
valy = list(valy_luna) + list(valy_PE)

print("Train Len", len(tx))
print("Val Len", len(valx))

### 2. Set up model:

In [None]:
# define data generator
def generator(x_train, y_train, batch_size, data_gen_args, start_seed):
    """Yield augmented (input, target) batches forever.

    Two ``ImageDataGenerator`` flows are built from the same augmentation
    arguments and the same seed: one over the inputs and one over the
    targets. The shared seed is presumably what keeps the random
    transformations of inputs and targets in sync (verify against the
    Keras version in use).

    Args:
        x_train: input images handed to ``ImageDataGenerator.flow``.
        y_train: target images handed to ``ImageDataGenerator.flow``.
        batch_size: batch size handed to both flows.
        data_gen_args: dict of keyword arguments for ``ImageDataGenerator``.
        start_seed: seed used for both flows.

    Yields:
        Tuple ``(batchx, batchy)``: one augmented input batch and one
        augmented target batch.
    """
    # Each array is also passed as its own "label" so that every flow
    # returns an (images, labels) pair; the label half is ignored below.
    datagen = ImageDataGenerator(**data_gen_args).flow(
        x_train, x_train, batch_size, seed=start_seed)
    maskgen = ImageDataGenerator(**data_gen_args).flow(
        y_train, y_train, batch_size, seed=start_seed)
    while True:
        # The built-in next() goes through the iterator protocol and does not
        # depend on the iterator exposing a public .next() method.
        batchx, _ = next(datagen)
        batchy, _ = next(maskgen)
        yield batchx, batchy

In [None]:
# Look at some augmentations: draw a few augmented batches and display one
# input / target pair side by side.
data_gen_args = dict(featurewise_center=False,
                     featurewise_std_normalization=False,
                     rotation_range=10.0,
                     height_shift_range=0.20,
                     width_shift_range=0.20,
                     fill_mode='constant',
                     zoom_range=0,
                     cval=0.,
                     horizontal_flip=False)

batch_size = 8
train_gen = generator(np.array(tx)[:,:,:,np.newaxis], np.array(ty)[:,:,:,np.newaxis], batch_size, data_gen_args, 42)
# Draw three batches; only the last one is kept for display.
for _ in range(3):
    a, b = next(train_gen)
    #print(a.shape, b.shape)
i = 3  # index of the displayed sample inside the batch
# The batches carry a trailing channel axis of length 1 (added above with
# np.newaxis). Select channel 0 explicitly so that imshow receives a plain
# 2-D image, which is one of its documented input shapes.
plt.subplot(121)
plt.imshow(a[i, :, :, 0], cmap='gray')
plt.colorbar()
plt.subplot(122)
plt.imshow(b[i, :, :, 0], cmap='gray')
plt.colorbar()

In [None]:
# Build the U-Net and prepare everything the training call needs:
# loss / metric, optimizer, callbacks and augmentation settings.
model = UNet(
    slice_shape=[cparam("IMAGEWIDTH"), cparam("IMAGEWIDTH"), 1],
    layer_depth=6,
    filter_count=32,
    kernel_size_down=(3, 3),
    kernel_size_pool=(2, 2),
    dilation_rates=[1, 2, 1, 2, 1, 1],
    dropout=False,
    activation="relu",
).get_keras_model()

# Placeholders: fill in the model name and the checkpoint directory.
model_name = '???'
checkpoint_dir = '???'

# Training hyperparameters.
loss_function = "mean_squared_error"
monitor = "mean_squared_error"  # metric reported next to the loss
opt = keras.optimizers.Adam(learning_rate=1e-4)
epochs = 120

# TensorBoard log directory: model name followed by a timestamp.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs_unet", model_name + timestamp)

model.compile(
    loss=loss_function,
    optimizer=opt,
    metrics=[monitor],
)

# Keep only the model with the lowest validation loss.
checkpoint = ModelCheckpoint(
    '{}/{}_model.h5'.format(checkpoint_dir, model_name),
    monitor='val_loss',
    verbose=2,
    save_best_only=True,
    mode='min',
)

# Augmentation settings handed to the data generator.
data_gen_args = {
    'featurewise_center': False,
    'featurewise_std_normalization': False,
    'rotation_range': 10.0,
    'height_shift_range': 0.20,
    'width_shift_range': 0.20,
    'fill_mode': 'constant',
    'zoom_range': 0,
    'cval': 0.,
    'horizontal_flip': False,
}

tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
#model.summary()

### 3. Train:

In [None]:
# Stack the samples into arrays and append a trailing channel axis of length 1.
x_train = np.array(tx)[:, :, :, np.newaxis]
y_train = np.array(ty)[:, :, :, np.newaxis]
x_val = np.array(valx)[:, :, :, np.newaxis]
y_val = np.array(valy)[:, :, :, np.newaxis]

# NOTE(review): the validation batches are produced with the same augmentation
# arguments as the training batches - confirm that this is intended.
train_batches = generator(x_train, y_train, batch_size, data_gen_args, 42)
val_batches = generator(x_val, y_val, batch_size, data_gen_args, 42)

# Both generators are endless, so the number of batches per epoch is given
# explicitly for training and for validation.
history = model.fit(
    train_batches,
    validation_data=val_batches,
    steps_per_epoch=len(tx) // batch_size,
    epochs=epochs,
    validation_steps=len(valx) // batch_size,
    callbacks=[checkpoint, tensorboard_callback],
)


In [None]:
# Save the history to plot the loss later.
# Convert the history.history dict to a pandas DataFrame:
hist_df = pd.DataFrame(history.history)
# Save to csv:
run = '?'  # placeholder: replace '?' with the label of this run (part of the file name)
# NOTE(review): the file name is appended without a path separator, so it only
# ends up inside the CURRENT_CONFIG directory if CURRENT_CONFIG itself ends
# with a separator - confirm.
hist_csv_file = os.path.join(PATH_SERVER_CACHE,CURRENT_CONFIG)+'{}_modelHistory_{}.csv'.format(model_name, run)
# Hand the path to pandas instead of an already opened handle: pandas expects
# file objects to be opened with newline='', which the plain open(..., 'w')
# used before did not do.
hist_df.to_csv(hist_csv_file)

In [None]:
# Show the keys of the recorded history dict (the columns available in hist_df).
history.history.keys()

In [None]:
# Plot the validation and training loss curves from the saved history.
loss = 'val_loss'
for column, curve_label in ((loss, 'val'), ('loss', 'train')):
    plt.plot(hist_df[column], label=curve_label)
plt.legend()
plt.xlabel('epochs')
plt.ylabel(loss)