In [None]:
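# Imports.
# The path appended below must point at the root of the git repo so that `src.models` can be imported;
# change '..' if the notebook is run from a different directory.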
import sys
sys.path.append('..')
import tensorflow as tf
import os
import src.models as models

In [None]:
# Configure the distribution strategy and the global batch size.
# BATCH_SIZE_PER_REPLICA may be changed freely. BATCH_SIZE is read by the dataset cell,
# so this cell cannot simply be skipped on CPU.
# NOTE(review): confirm what models.set_GPU() returns when no GPU is present.
BATCH_SIZE_PER_REPLICA = 1
strategy = models.set_GPU()

# The global batch is split evenly across all replicas kept in sync by the strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA


In [None]:
# Set parameters
#
# MODEL_SCALE - Number of nodes in the initial layer. Every time the feature scale is
#               decreased, the number of nodes doubles.
# model_type  - Type of model to train: 'vnet', 'evnet' or 'evcnet'.
# base_path   - Path of a previously saved model, if we are loading one.
# model_path  - Output path of the trained model.
# DROP_R      - Dropout rate of the model.
# EPOCHS      - Number of training epochs (9 epochs took about 48 hours in the original setup).
# l_r         - Learning rate.
# train_files / train_labels - Sorted file paths of the training images / labels.
# val_files / val_labels     - Sorted file paths of the validation images / labels.
MODEL_SCALE = 16
model_type = 'evnet'
base_path = ''
model_path = 'trained_models/evnet/'
DROP_R = 0.5
EPOCHS = 9
l_r = 0.01

# Example layout of the input directories. Adjust these to your own data.
TRAIN_IMAGE_DIR = 'inputs/train/images/'
TRAIN_LABEL_DIR = 'inputs/train/labels/'
VAL_IMAGE_DIR = 'inputs/val/images/'
VAL_LABEL_DIR = 'inputs/val/labels/'


def sorted_paths(directory):
    """Return the path of every entry in `directory`, sorted by filename."""
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]


train_files = sorted_paths(TRAIN_IMAGE_DIR)
train_labels = sorted_paths(TRAIN_LABEL_DIR)
val_files = sorted_paths(VAL_IMAGE_DIR)
val_labels = sorted_paths(VAL_LABEL_DIR)

# The image and label lists are handed to create_dataset side by side and are presumably
# paired by position. A count mismatch would silently misalign them, so fail early instead.
assert len(train_files) == len(train_labels), (
    f'{len(train_files)} training images but {len(train_labels)} training labels')
assert len(val_files) == len(val_labels), (
    f'{len(val_files)} validation images but {len(val_labels)} validation labels')

In [None]:
# Build the TensorFlow input pipelines. Training and validation differ only in their
# file lists and in the `training` flag; everything else is shared.
_dataset_kwargs = dict(model_type=model_type, batch_size=BATCH_SIZE)

train_ds = models.create_dataset(train_files, train_labels, training=True, **_dataset_kwargs)
val_ds = models.create_dataset(val_files, val_labels, training=False, **_dataset_kwargs)

In [None]:
# Load the pre-trained model or create a new model.
# Set base_path in the parameters cell to resume from a saved model; leave it empty
# to build and compile a fresh one.
if base_path:
    # A loaded model must NOT be compiled again here.
    model = tf.keras.models.load_model(base_path, custom_objects={'dice_coef': models.dice_coef})
else:
    # Use the configured hyper-parameters rather than hard-coded copies of their values,
    # so edits to MODEL_SCALE / DROP_R actually take effect.
    # NOTE(review): assumes models.load_model(scale, dropout_rate, model_type) argument order,
    # matching the original call load_model(16, 0.5, model_type).
    model = models.load_model(MODEL_SCALE, DROP_R, model_type)
    optimizer = tf.keras.optimizers.Adam(l_r)
    model.compile(loss=models.dice_coef(reduction=tf.keras.losses.Reduction.AUTO), optimizer=optimizer)

In [None]:
# Checkpoint that saves only the best weights seen so far (lowest validation loss).
# Usually wanted for the final run of a training session.
checkpoint_dir = 'tmp'
checkpoint_path = os.path.join(checkpoint_dir, model_type)
# exist_ok avoids the separate exists() check (and its check-then-create race).
os.makedirs(checkpoint_dir, exist_ok=True)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       monitor='val_loss',  # the callback's default, made explicit
                                       save_weights_only=True,
                                       save_best_only=True)
]

In [None]:
# Train the model. The checkpoint callback writes the best weights to checkpoint_path.
history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)

# Restore the best checkpointed weights, then export the whole model.
model.load_weights(checkpoint_path)
model.save(model_path, save_format='tf')