In [None]:
import os
import glob

from unet3d.data import write_data_to_file, open_data_file
from unet3d.generator import get_training_and_validation_generators
from unet3d.model import isensee2017_model
from unet3d.training import load_old_model, train_model

In [None]:
# Base model/data configuration. Keys are inserted in a fixed order that later cells extend.
config = {
    "image_shape": (256, 128, 256),  # shape the images will be cropped/resampled to
    "patch_shape": (64, 64, 64),  # set to None to train on the whole image
    "labels": (0, 1, 2, 3),  # label values present in the segmentation images
    "n_base_filters": 16,
}
config["n_labels"] = len(config["labels"])
config["all_modalities"] = ["t1"]
# Same list object as all_modalities; replace with a subset to train on fewer modalities.
config["training_modalities"] = config["all_modalities"]
config["nb_channels"] = len(config["training_modalities"])
# Network input = channel count followed by the spatial dims: the patch when patching,
# otherwise the full (resampled) image.
config["input_shape"] = (config["nb_channels"],) + tuple(
    config["image_shape"] if config.get("patch_shape") is None else config["patch_shape"]
)
config["truth_channel"] = config["nb_channels"]
config["deconvolution"] = True  # if False, upsampling is used instead of deconvolution

In [None]:
# Optimisation schedule and data-augmentation settings.
config.update(
    batch_size=2,
    validation_batch_size=4,
    n_epochs=500,  # hard cap on the number of training epochs
    patience=10,  # epochs without validation-loss improvement before the learning rate is reduced
    early_stop=50,  # epochs without validation-loss improvement before training is stopped
    initial_learning_rate=5e-4,
    learning_rate_drop=0.5,  # factor by which the learning rate is reduced
    validation_split=0.8,  # portion of the data used for TRAINING (despite the key name)
    flip=False,  # augment by randomly flipping an axis during training
    permute=True,  # augment by permuting axes; the data shape must be a cube
    distort=False,  # augmentation distortion factor; currently off (original note: None for no distortion)
)
# The generic augmentation switch is on only if flipping or distortion is; permute is passed separately.
config["augment"] = config["flip"] or config["distort"]
config.update(
    validation_patch_overlap=0,  # if > 0, validation patches overlap during training
    training_patch_start_offset=(16, 16, 16),  # first patch index randomly offset by up to this much
    skip_blank=True,  # if True, patches without any target are skipped
)

In [None]:
# Files read/written by this notebook, resolved to absolute paths in the working directory.
config.update(
    data_file=os.path.abspath("data.h5"),  # HDF5 file holding the converted input images
    model_file=os.path.abspath("model.h5"),  # model saved by train_model / read by load_old_model
    training_file=os.path.abspath("train.pkl"),  # keys file for the training split
    validation_file=os.path.abspath("val.pkl"),  # keys file for the validation split
    overwrite=True,  # if True, replace previous files; if False, reuse previously written ones
)

In [None]:
def _split_file_pairs(data_dir, split):
    """Return index-aligned (image, segmentation) path pairs for one data split."""
    split_dir = os.path.join(data_dir, "Data", split)
    # glob order is filesystem-dependent; sort both lists so the i-th image is
    # paired with the i-th segmentation deterministically.
    images = sorted(glob.glob(os.path.join(split_dir, "img", "*")))
    segmentations = sorted(glob.glob(os.path.join(split_dir, "seg", "*")))
    if len(images) != len(segmentations):
        raise ValueError(
            f"{split}: found {len(images)} image files but {len(segmentations)} "
            f"segmentation files under {split_dir}"
        )
    return list(zip(images, segmentations))


def fetch_training_data_files(return_subject_ids=False, data_dir="."):
    """Collect (image, segmentation) file pairs from the Train and Validation folders.

    Expected layout: ``<data_dir>/Data/<split>/img/*`` and ``<data_dir>/Data/<split>/seg/*``
    for split in ("Train", "Validation"). Files are paired by sorted filename order,
    so each subject's image and segmentation must sort to the same position in their
    respective folders. Train pairs come first, then Validation pairs, in one list.

    Args:
        return_subject_ids: if True, also return the subject ids (image file basenames).
        data_dir: directory containing the ``Data`` folder (default: current directory).

    Returns:
        A list of (image_path, segmentation_path) tuples, or
        ``(that_list, subject_ids)`` when ``return_subject_ids`` is True.

    Raises:
        ValueError: if a split has a different number of image and segmentation files.
    """
    training_data_files = []
    subject_ids = []
    for split in ("Train", "Validation"):
        for image_path, seg_path in _split_file_pairs(data_dir, split):
            subject_ids.append(os.path.basename(image_path))
            training_data_files.append((image_path, seg_path))

    if return_subject_ids:
        return training_data_files, subject_ids
    return training_data_files

In [None]:
# Honour the configured overwrite flag instead of a hard-coded value, so that
# config["overwrite"] = False actually reuses previously written files.
overwrite = config["overwrite"]

# Convert the input images into a single HDF5 file (skipped if it exists and overwrite is off).
if overwrite or not os.path.exists(config["data_file"]):
    training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
    write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                       subject_ids=subject_ids)

# Kept open for the generators below; closed by data_file_opened.close() in the last cell.
data_file_opened = open_data_file(config["data_file"])

if not overwrite and os.path.exists(config["model_file"]):
    # Load the previously saved model instead of building a new one.
    model = load_old_model(config["model_file"])
else:
    # Instantiate a new model.
    model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              n_base_filters=config["n_base_filters"])

# Author's note: the printed output of the data-writing step can make it look as if
# each file is loaded twice; this was checked and each file is loaded only once.

In [None]:
# Build the batch generators (and their steps-per-epoch counts) from the opened HDF5 file.
(train_generator, validation_generator,
 n_train_steps, n_validation_steps) = get_training_and_validation_generators(
    data_file_opened,
    # Train/validation split and where the split keys are stored; `overwrite` is
    # passed through (presumably decides whether existing key files are reused — confirm).
    data_split=config["validation_split"],
    training_keys_file=config["training_file"],
    validation_keys_file=config["validation_file"],
    overwrite=overwrite,
    # Label encoding.
    n_labels=config["n_labels"],
    labels=config["labels"],
    # Batching and patch extraction.
    batch_size=config["batch_size"],
    validation_batch_size=config["validation_batch_size"],
    patch_shape=config["patch_shape"],
    training_patch_start_offset=config["training_patch_start_offset"],
    validation_patch_overlap=config["validation_patch_overlap"],
    skip_blank=config["skip_blank"],
    # Augmentation.
    augment=config["augment"],
    augment_flip=config["flip"],
    augment_distortion_factor=config["distort"],
    permute=config["permute"],
)


In [None]:
# Fit the model with the learning-rate schedule and early-stopping settings from the config.
# model_file is passed as the model path (presumably where the trained model is saved — confirm).
train_model(
    model=model,
    # Data.
    training_generator=train_generator,
    steps_per_epoch=n_train_steps,
    validation_generator=validation_generator,
    validation_steps=n_validation_steps,
    # Schedule.
    n_epochs=config["n_epochs"],
    initial_learning_rate=config["initial_learning_rate"],
    learning_rate_drop=config["learning_rate_drop"],
    learning_rate_patience=config["patience"],
    early_stopping_patience=config["early_stop"],
    # Output.
    model_file=config["model_file"],
)

In [None]:
# Release the HDF5 data file handle obtained from open_data_file(config["data_file"]).
data_file_opened.close()