#### This notebook attempts to unfreeze the last layer of the DSSE model and fine-tune it as CRD based on user scribbles.
#### It follows the paper "Interactive Medical Image Segmentation using Deep Learning with Image-specific Fine-tuning"
#### by Guotai Wang et al.: https://arxiv.org/pdf/1710.04043.pdf

In [1]:
import sys
import os
import json
import glob
from datetime import datetime

import numpy as np
import nibabel as nib
from scipy import ndimage
from scipy.ndimage import morphology
import SimpleITK

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, UpSampling3D, Conv3DTranspose, Activation, Add, Concatenate, BatchNormalization, ELU, SpatialDropout3D, GlobalAveragePooling3D, Reshape, Dense, Multiply,  Permute
from tensorflow.keras import regularizers, metrics
from tensorflow.keras.utils import Sequence
#import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import time

# Project locations. The commented-out pair below is an alternative set of
# paths (presumably for another machine -- confirm before re-enabling).
#codeRootPath = '/home/wd974888/notebooks/MMGTVSeg'
#dataRootpath = '/home/wd974888/notebooks/MMGTVSeg/data'
codeRootPath = '/home/user/DMML/CodeAndRepositories/MMGTVSeg'
# The data folder sits directly under the code root.
dataRootpath = codeRootPath + '/data'

# Verbosity flag passed to the DSSE_VNet reader helpers in the prediction cell.
verbose = False

# The repository root has to be on sys.path before the project package
# can be imported.
sys.path.append(codeRootPath)
import src
from src.DSSENet import DSSE_VNet




In [2]:
# Load the trained DSSE-Net model of one cross-validation fold and compile it.
#
# Reads  : codeRootPath (defined in the first cell)
# Defines: trainInputParams, saveModelDirectory, out_dir,
#          groundTruthComparisonFilePath_out, cvFoldIndex, numCVFolds,
#          thisModelFileName, thisModelPath, model, optimizer

# Open configuration; the with-block closes the file, so no explicit close().
trainConfigFilePath = os.path.join(codeRootPath, 'input/trainInput_DSSENet.json')
with open(trainConfigFilePath) as f:
    trainInputParams = json.load(f)
# The entries 'loss_func' and 'acc_func' used below are read from the dict
# returned by this call.
trainInputParams = DSSE_VNet.sanityCheckTrainParams(trainInputParams)

# Other paths
saveModelDirectory = os.path.join(codeRootPath, 'output/DSSEModels')
# Make sure out_dir exists, else create it.
out_dir = os.path.join(codeRootPath, 'output/interactiveFineTuneExperiment')
DSSE_VNet.checkFolderExistenceAndCreate(out_dir)
groundTruthComparisonFilePath_out = os.path.join(codeRootPath, 'output/iftExperimentGTComparison.json')

# Get the location of the model file based on the CV fold index
cvFoldIndex = 2
# An earlier version had a trailing comma here, which made this the tuple (5,).
numCVFolds = 5
thisModelFileName = "{:>02d}FinalDSSENetModel.h5".format(cvFoldIndex)
thisModelPath = os.path.join(saveModelDirectory, thisModelFileName)
#print(thisModelPath)
# Make sure the model file exists. os.path.isfile() is already False for a
# missing path, and sys.exit() accepts exactly one argument, so the message
# has to be built as a single string.
if not os.path.isfile(thisModelPath):
    sys.exit('No file exists at ' + thisModelPath)

# We are testing with only one model. It is loaded uncompiled and compiled
# here with the loss/metric taken from the training configuration.
model = tf.keras.models.load_model(thisModelPath, compile=False)
# `learning_rate` is the documented argument name; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, decay=0.0)
if trainInputParams['AMP']:
    # Automatic mixed precision: wrap the optimizer with the TF1-compat graph rewrite.
    optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
# 'loss_func' is called with the data format and its result is used as the loss.
model.compile(optimizer=optimizer,
              loss=trainInputParams['loss_func'](data_format=trainInputParams['data_format']),
              metrics=[trainInputParams['acc_func']])
print('Loaded model: ' + thisModelPath)


Loaded model: /home/user/DMML/CodeAndRepositories/MMGTVSeg/output/DSSEModels/02FinalDSSENetModel.h5


In [3]:
# Print the per-layer table (layer type, output shape, parameter count,
# inbound connections) of the model loaded in the previous cell.
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 144, 144, 14 0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 144, 144, 144 4000        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 144, 144, 144 64          conv3d[0][0]                     
__________________________________________________________________________________________________
elu (ELU)                       (None, 144, 144, 144 0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [12]:
# Predict on test data.
#
# For every test patient: read CT, PET and (optionally) ground-truth volumes,
# build a batch of size 1, run the loaded model, take the arg-max class per
# voxel, save the prediction next to the other experiment outputs and, when
# ground truth is available, collect Dice and surface-distance results.
#
# Reads  : model, trainInputParams, dataRootpath, out_dir,
#          groundTruthComparisonFilePath_out, verbose (from earlier cells)

# Test patients and location
listOfTestPatientNames = ["CHUS097", "CHUM021", "CHGJ036", "CHUS026", "CHUM019",
                          "CHUS015", "CHUM036", "CHUM022", "CHUM038", "CHUS013"]
resampledTestDataLocation = os.path.join(dataRootpath, 'hecktor_train/resampled')
groundTruthPresent = True

# The cubes are going to be arranged as depth-index, row-index, col-index
cube_size = [trainInputParams["patientVol_Depth"], trainInputParams["patientVol_Height"],
             trainInputParams["patientVol_width"]]
DepthRange, RowRange, ColRange, numLabels, X_size, y_size, y_gt_size = \
    DSSE_VNet.getDimensions(cube_size, trainInputParams["labels_to_train"], groundTruthPresent,
                            trainInputParams['data_format'])
#print(DepthRange, RowRange, ColRange, numLabels, X_size, y_size, y_gt_size)

#verbose=True

# Axis that holds the class/channel dimension of a single (un-batched) sample:
# last axis for 'channels_last', first axis otherwise.
channelAxis = -1 if 'channels_last' == trainInputParams['data_format'] else 0

# Accumulators for the ground-truth comparison (initialised exactly once).
if groundTruthPresent:
    groundTruthTestComparison = []
    diceSum = 0.0

# Read each input and then predict with the loaded model
for patientName in listOfTestPatientNames:
    # suffixList[0] / [1] / [2] are the file-name suffixes of the CT, PET and
    # GTV volumes respectively (as used for ctData / ptData / gtvData below).
    ctFileName = patientName + trainInputParams["suffixList"][0]
    # NOTE: `scaleLag` is the keyword name expected by the helper; keep it as is.
    ctData = DSSE_VNet.readAndScaleImageData(fileName=ctFileName,
        folderName=resampledTestDataLocation, clipFlag = True,
        clipLow=trainInputParams["ct_low"], clipHigh =trainInputParams["ct_high"],
        scaleLag=True, scaleFactor=1000, meanSDNormalizeFlag = False, finalDataType = np.float32,
        isLabelData=False, labels_to_train_list=None, verbose=verbose)
    ptData = DSSE_VNet.readAndScaleImageData(fileName=patientName + trainInputParams["suffixList"][1],
        folderName=resampledTestDataLocation, clipFlag = True,
        clipLow=trainInputParams["pt_low"], clipHigh =trainInputParams["pt_high"],
        scaleLag=False, scaleFactor=1, meanSDNormalizeFlag = True, finalDataType = np.float32,
        isLabelData=False, labels_to_train_list=None, verbose=verbose)
    gtvData = None
    if groundTruthPresent:
        gtvData = DSSE_VNet.readAndScaleImageData(fileName=patientName + trainInputParams["suffixList"][2],
            folderName=resampledTestDataLocation, clipFlag = False,
            clipLow=0, clipHigh = 0,
            scaleLag=False, scaleFactor=1, meanSDNormalizeFlag = False, finalDataType = np.int16,
            isLabelData=True, labels_to_train_list=trainInputParams["labels_to_train"], verbose=verbose)

    # Create batch input
    batch_X, y_gt = DSSE_VNet.createSize1Batch(DepthRange, RowRange, ColRange,
                                     X_size, ctData, ptData,
                                     groundTruthPresent, y_gt_size, gtvData,
                                     trainInputParams['data_format'])
    # Predict; the timer covers the forward pass plus the arg-max below.
    t = time.time()
    batch_y_pred_softmax = model.predict(batch_X, batch_size=1)
    y_pred_softmax = batch_y_pred_softmax[0,...]
    # Now do the voting: convert the softmax output for N classes
    # (here N = 2, 0 and 1) into the class with the highest probability.
    y_pred = np.argmax(y_pred_softmax, axis=channelAxis).astype('int16')
    print('\nInference time for 1 sample ',  time.time() - t)
    if groundTruthPresent:
        # Drop the singleton channel axis of the ground truth so that it has
        # the same rank as y_pred.
        y_gt = np.squeeze(y_gt, axis=channelAxis)

    # Compute metrics if ground truth is available, and save the prediction
    key =  patientName if groundTruthPresent else None
    ctFilePath = os.path.join(resampledTestDataLocation, ctFileName)
    orgSpacing = SimpleITK.ReadImage(ctFilePath).GetSpacing()
    # Spacing tuple in reversed axis order -- presumably to match the
    # depth/row/col array layout used above; TODO confirm against the helper.
    tarnsposedSpacing = (orgSpacing[2], orgSpacing[1], orgSpacing[0])
    # The CT image is handed to the helper as `modelImage_nii` for saving.
    srcImage_nii =  nib.load(ctFilePath)
    destinationFilePath = os.path.join(out_dir, 'pred_' + patientName + trainInputParams["suffixList"][2])

    dice, msd, lbls, rms, hd = DSSE_VNet.evaluateAndSavePredictionHelper(y_pred=y_pred,
        evaluateFlag=groundTruthPresent, key=key, y_gt=y_gt, spacingForEvaluation=tarnsposedSpacing,
        saveFlag=True, destinationFilePath=destinationFilePath, modelImage_nii=srcImage_nii)
    if groundTruthPresent:
        # NOTE(review): the stored values come straight from numpy; json.dump
        # below accepts np.float64 but not np.float32 -- confirm the dtype
        # returned by the helper.
        groundTruthTestComparison.append({ 'patientName':patientName, 'SSMSD':np.transpose(msd)[0],  'dice':np.transpose(dice)[0] })
        diceSum += np.sum(np.transpose(dice))

if groundTruthPresent:
    # Guard against an empty patient list (would otherwise divide by zero).
    if groundTruthTestComparison:
        print('Average Dice: ', diceSum/len(groundTruthTestComparison))
    for result in groundTruthTestComparison :
        print(result)
    # The with-block closes the file; no explicit close() is needed.
    with open(groundTruthComparisonFilePath_out, 'w') as fp:
        json.dump(groundTruthTestComparison, fp)


Inference time for 1 sample  2.212066411972046
CHUS097  Surface to surface MSD [mm]:  [1.6507491]  dice:  [0.81649426]
***********************************************
predY_shape  (144, 144, 144)  transpose_GTV shape:  (145, 144, 144)

Inference time for 1 sample  1.1099872589111328
CHUM021  Surface to surface MSD [mm]:  [1.29499484]  dice:  [0.84044349]

Inference time for 1 sample  1.1186728477478027
CHGJ036  Surface to surface MSD [mm]:  [1.97313758]  dice:  [0.81697656]

Inference time for 1 sample  1.1167559623718262
CHUS026  Surface to surface MSD [mm]:  [146.99496396]  dice:  [0.]

Inference time for 1 sample  1.1312181949615479
CHUM019  Surface to surface MSD [mm]:  [0.73600702]  dice:  [0.8976378]

Inference time for 1 sample  1.1159629821777344
CHUS015  Surface to surface MSD [mm]:  [0.73922884]  dice:  [0.8501176]

Inference time for 1 sample  1.1171424388885498
CHUM036  Surface to surface MSD [mm]:  [138.33571783]  dice:  [0.]
**********************************************