# In [None]:
#https://github.com/ellisdg/3DUnetCNN
#https://www.biorxiv.org/content/10.1101/2020.02.25.964015v1.full
import os, sys
import glob
sys.path.append('/lustre_scratch/sandbox/salam/cnn/3DUnetCNN/')
os.environ["CUDA_VISIBLE_DEVICES"]="3,7,2,1"
from unet3d.data import write_data_to_file, open_data_file
from unet3d.generator_sequence import get_validation_split
from unet3d.generator_sequence import DataGenerator
from unet3d.model import isensee2017
from unet3d.training import load_old_model, train_model
from keras.utils.vis_utils import plot_model
import tensorflow as tf
import threading
import faulthandler
import unet3d.generator
faulthandler.enable()

# Input directory (one sub-directory per subject) and the output directory for
# the HDF5 data file, train/validation split pickles and model checkpoints.
file_path = "/lustre_scratch/sandbox/salam/cnn/3DUnetCNN/brats/data/training/inex_train_uniform/inex_train_iso_v1/inex_whole_brain/"
data_path = "/lustre_scratch/sandbox/salam/models_tmp/"
output_model_path = data_path
# BraTS_2020_subject_ID
config = dict()
# Passed as `gpu=` to the model builder; presumably should equal the number of
# devices listed in CUDA_VISIBLE_DEVICES (4 at the time of writing).
config["gpu"] = 4
config["image_shape"] = (320, 256, 320)  # This determines what shape the images will be cropped/resampled to.
config["patch_shape"] = (320, 256, 80)  # switch to None to train on the whole image
# The label numbers on the input image. Must be a tuple: `(1)` is just the int 1.
config["labels"] = (1,)
config["n_base_filters"] = 16  # these are doubled after each downsampling
config["n_labels"] = len(config["labels"])  # one output channel per label (== 1 here)
config["all_modalities"] = ["brain"]  # modality suffixes of the input files (<id>_<modality>.nii.gz)
config["training_modalities"] = config["all_modalities"]  # change this if you want to only use some of the modalities
config["nb_channels"] = len(config["training_modalities"])
# Network input is (channels, *spatial); use the patch shape when patch training is enabled.
_spatial_shape = config["patch_shape"] if config.get("patch_shape") is not None else config["image_shape"]
config["input_shape"] = tuple([config["nb_channels"]] + list(_spatial_shape))
config["truth_channel"] = config["nb_channels"]
config["deconvolution"] = False  # if False, will use upsampling instead of deconvolution

config["batch_size"] = 4
config["validation_batch_size"] = 4
config["n_epochs"] = 16  # cutoff the training after this many epochs
config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
# Training will be stopped after this many epochs without validation improvement.
# NOTE: this exceeds n_epochs, so early stopping cannot trigger with the current settings.
config["early_stop"] = 100
config["initial_learning_rate"] = 5e-5
config["learning_rate_drop"] = 0.7  # factor by which the learning rate will be reduced
config["validation_split"] = 0.8  # portion of the data that will be used for training
config["flip"] = False  # augments the data by randomly flipping an axis during training
config["permute"] = False  # data shape must be a cube. Augments the data by permuting in various directions
config["distort"] = False  # distortion factor for augmentation; a falsy value (False/None) disables distortion
config["augment"] = config["flip"] or config["distort"]
config["patch_overlap"] = 0  # if > 0, during training, validation patches will be overlapping
config["training_patch_start_offset"] = (0, 0, 0)  # randomly offset the first patch index by up to this offset
config["skip_blank"] = True  # if True, then patches without any target will be skipped

config["data_file"] = os.path.abspath(os.path.join(data_path, "brain_data.h5"))
# Keras-style checkpoint template: '{epoch:02d}' is filled in per saved epoch.
config["model_file"] = os.path.abspath(os.path.join(output_model_path, "brain_unet_model-{epoch:02d}.h5"))
config["training_file"] = os.path.abspath(os.path.join(data_path, "brain_training_ids.pkl"))
config["validation_file"] = os.path.abspath(os.path.join(data_path, "brain_validation_ids.pkl"))
config["overwrite"] = False  # If True, will overwrite previously written files. If False, will reuse them.

#write_data_to_file-->Utils.read_image_files

def fetch_brats_2020_files(modalities, group="Training", include_truth=True, return_subject_ids=False,
                           data_dir=None):
    """Build per-subject tuples of NIfTI paths from a one-directory-per-subject layout.

    Every immediate sub-directory of ``data_dir`` is one subject; its name is the
    subject ID. For each subject a tuple of paths
    ``<data_dir>/<id>/<id>_<modality>.nii.gz`` is produced, one entry per
    modality in order, with ``seg`` appended last when ``include_truth`` is True.
    Files are not checked for existence.

    Args:
        modalities: iterable of modality suffixes (e.g. ``["brain"]``).
        group: unused; kept so existing callers passing it keep working.
        include_truth: append the ``seg`` label-map path to every tuple.
        return_subject_ids: also return the list of subject IDs.
        data_dir: directory holding one sub-directory per subject. Defaults
            to the module-level ``file_path``.

    Returns:
        A list of path tuples, or ``(list_of_tuples, subject_ids)`` when
        ``return_subject_ids`` is True. Subjects are in sorted name order.
    """
    if data_dir is None:
        data_dir = file_path
    modalities = list(modalities)
    if include_truth:
        modalities = modalities + ["seg"]

    training_data_files = []
    subject_ids = []
    # glob order is filesystem-dependent: sort so the subject -> HDF5 row mapping
    # (and hence the pickled train/validation index split) is reproducible.
    for subject_dir in sorted(glob.glob(os.path.join(data_dir, "*"))):
        # Stray files would otherwise become bogus "subjects" with paths that don't exist.
        if not os.path.isdir(subject_dir):
            continue
        subject_id = os.path.basename(subject_dir)
        subject_ids.append(subject_id)
        training_data_files.append(tuple(
            os.path.join(subject_dir, subject_id + "_" + modality + ".nii.gz")
            for modality in modalities
        ))

    if return_subject_ids:
        return training_data_files, subject_ids
    return training_data_files


def fetch_training_data_files(return_subject_ids=False):
    """Return the training file tuples for every subject.

    Delegates to ``fetch_brats_2020_files`` using ``config["training_modalities"]``
    and always appends the ground-truth ``seg`` path to each tuple.
    """
    requested_modalities = config["training_modalities"]
    return fetch_brats_2020_files(
        modalities=requested_modalities,
        include_truth=True,
        return_subject_ids=return_subject_ids,
    )


def _shared_generator_kwargs(overwrite):
    """Return the DataGenerator keyword arguments common to training and validation.

    Both generators are built from the same config-derived options; keeping
    them in one place stops the two calls from drifting apart.
    """
    return dict(
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        patch_overlap=config["patch_overlap"],
    )


def main(overwrite=False):
    """Convert the images to HDF5 (if needed), build the U-Net, and train it.

    Args:
        overwrite: when True, rewrite the HDF5 data file even if it exists, and
            pass ``overwrite`` through to the train/validation split and the
            data generators.

    Returns:
        Whatever ``train_model`` returns (the caller keeps it as ``history``).
    """
    # Convert input images into an HDF5 file.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even if model construction or training raises.
    try:
        # NOTE: resuming via load_old_model(config["model_file"]) was disabled here;
        # model_file is an epoch template ('{epoch:02d}'), so os.path.exists on it
        # would never match a saved checkpoint anyway.
        model = isensee2017.isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                              initial_learning_rate=config["initial_learning_rate"],
                                              n_base_filters=config["n_base_filters"], depth=5,
                                              dropout_rate=0.08, gpu=config["gpu"])

        # One lock shared by both generators, which read from the same opened data file.
        lock = threading.Lock()

        training_list, validation_list = get_validation_split(data_file_opened,
                                                              data_split=config["validation_split"],
                                                              overwrite=overwrite,
                                                              training_file=config["training_file"],
                                                              validation_file=config["validation_file"])

        generator_kwargs = _shared_generator_kwargs(overwrite)
        # NOTE(review): the original call (`ndex_list=..., locks)`) did not parse;
        # keyword names `index_list` / `lock` are inferred and should be confirmed
        # against DataGenerator's signature in unet3d.generator_sequence.
        # NOTE(review): `overwrite` is also forwarded to DataGenerator; confirm it does
        # not re-split the data after get_validation_split has already done so.
        train_generator = DataGenerator(data_file_opened, index_list=training_list, lock=lock,
                                        **generator_kwargs)
        validation_generator = DataGenerator(data_file_opened, index_list=validation_list, lock=lock,
                                             **generator_kwargs)

        n_train_steps = train_generator.num_training_steps
        n_validation_steps = validation_generator.num_validation_steps

        # Run training.
        history = train_model(model=model,
                              model_file=config["model_file"],
                              training_generator=train_generator,
                              validation_generator=validation_generator,
                              steps_per_epoch=n_train_steps,
                              validation_steps=n_validation_steps,
                              initial_learning_rate=config["initial_learning_rate"],
                              learning_rate_drop=config["learning_rate_drop"],
                              learning_rate_patience=config["patience"],
                              early_stopping_patience=config["early_stop"],
                              n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
    return history


# Run the whole pipeline (HDF5 conversion, model build, training) with the
# configured overwrite flag; keep main()'s return value for inspection.
history = main(config["overwrite"])
