In [None]:
# Show the kernel's working directory; the relative "./weights/..." paths
# used further down in this notebook resolve against it.
!pwd

# Helper modules used by the cells below (presumably project-local files
# sitting next to this notebook -- TODO confirm).
import config_vacbag_512_2D as config
import paths_2D as paths
import explore_2D as explore
import generator_2D as generator
import loss_2D as loss
# NOTE: the training-setup cells later rebind the name `model` to the built
# network; those cells import this module again before doing so.
import model_ronneberger_512 as model

# PREPARE DATA

In [None]:
# Collect the dataset paths found under config.DATA_PATH, then split the
# input paths into train / validation / test subsets using config.RATIO.
patient_paths, input_paths, label_paths = paths.get_paths(config.DATA_PATH)
train_paths, valid_paths, test_paths = paths.split_paths(input_paths, config.RATIO)

# Sizes of the three subsets, keyed by the label printed in the report.
split_sizes = {
    "Train:": len(train_paths),
    "Valid:": len(valid_paths),
    "Test:": len(test_paths),
}

# Quick sanity report: patient count, length of each entry of input_paths,
# and the size of every split.
print("Patient:", len(patient_paths))
print(list(map(len, input_paths)))
print("Total:", sum(split_sizes.values()))
print("---------------")
for split_label, split_size in split_sizes.items():
    print(split_label, split_size)

In [None]:
import numpy as np

# Global mean / standard deviation over the training images only. Both
# values are handed to every generator below as `train_mean` / `train_std`.
train_images = explore.get_images(train_paths)
train_mean, train_std = np.mean(train_images), np.std(train_images)
# Drop the reference once the statistics are computed.
del train_images

print("Mean:", train_mean)
print("Std:", train_std)

In [None]:
# Keyword arguments shared by all three generators: every split receives the
# training-set mean/std, the batch/grid/structure settings from `config`,
# and has augmentation switched off.
shared_gen_kwargs = dict(
    train_mean=train_mean,
    train_std=train_std,
    batch_size=config.BATCH_SIZE,
    grid_size=config.GRID_SIZE,
    structure_names=config.STRUCTURE_NAMES,
    augment=False,
)

# One generator per split; only the list of input paths differs.
train_gen = generator.make_gen(train_paths, label_paths, **shared_gen_kwargs)
valid_gen = generator.make_gen(valid_paths, label_paths, **shared_gen_kwargs)
test_gen = generator.make_gen(test_paths, label_paths, **shared_gen_kwargs)

# PRE-TRAINING

In [None]:
import tensorflow as tf
import model_ronneberger_512 as model

# --- Pre-training configuration: binary cross-entropy loss ---

# ModelCheckpoint fills in "{epoch:02d}", so one weights file is written per
# epoch into ./weights/.
MODEL_SAVE = "./weights/bce_vacbag_512.{epoch:02d}.hdf5"

LOSS = tf.keras.losses.BinaryCrossentropy()
INITIAL_LR = 1e-4
STOPPING_PATIENCE = 50
LR_PATIENCE = 4

# The Dice loss is only tracked as a metric here; optimisation uses LOSS.
METRICS = [loss.dsc_loss]
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras releases.
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=INITIAL_LR)
# Factor applied to the learning rate by ReduceLROnPlateau.
LR_SCALE = 0.5


# All three callbacks leave `monitor` at its Keras default.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=STOPPING_PATIENCE, verbose=1, restore_best_weights=True)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=LR_SCALE, patience=LR_PATIENCE, verbose=1)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE, save_weights_only=True, verbose=1)

# Rebinds `model` from the imported module to the network instance. The
# import at the top of this cell restores the module, so the cell can be
# re-run safely.
model = model.model(config.GRID_SIZE, len(config.STRUCTURE_NAMES))

model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

In [None]:
from matplotlib import pyplot as plt

# Visual sanity check on the first training batch: run the current network
# on the batch inputs and plot input, label and prediction next to each other.
# The prediction is left unrounded.
y_input, y_true = train_gen[0]
y_pred = model.predict(y_input)

print("Input shape:", y_input.shape)
print("Truth shape: ", y_true.shape)
print("Predict shape:", y_pred.shape)

# One figure per sample in the batch; only channel 0 of each array is shown.
panel_count = len(config.STRUCTURE_NAMES) + 3
for sample_idx in range(y_input.shape[0]):
    fig, axs = plt.subplots(1, panel_count, figsize=(10, 20))

    image_slice = y_input[sample_idx, ..., 0]
    truth_slice = y_true[sample_idx, ..., 0]
    pred_slice = y_pred[sample_idx, ..., 0]

    axs[0].imshow(image_slice)
    axs[1].imshow(truth_slice)
    axs[2].imshow(pred_slice)

    # Fourth panel: label and prediction overlaid, each half transparent.
    axs[3].imshow(truth_slice, alpha=0.5)
    axs[3].imshow(pred_slice, alpha=0.5)

In [None]:
#model.save("./weights/bce_vacbag_512.initial_2.hdf5")

In [None]:
EPOCHS = 1

# Callbacks created in the training-setup cell.
fit_callbacks = [early_stopping, model_checkpoint, reduce_lr]

print("\n Training...")
# One full pass over each generator per epoch: the step counts are taken
# from the generators' own lengths.
train_history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=EPOCHS,
    steps_per_epoch=train_gen.__len__(),
    validation_steps=valid_gen.__len__(),
    callbacks=fit_callbacks,
    verbose=1,
)

In [None]:
from matplotlib import pyplot as plt

# Visual check on the first test batch: plot input, label and the rounded
# (binarised) prediction for every sample in the batch.
y_input, y_true = test_gen[0]
# Feed the input images to the network -- not the labels.
y_pred = model.predict(y_input)
y_pred = np.round(y_pred)

print("Input shape:", y_input.shape)
print("Truth shape: ", y_true.shape)
print("Predict shape:", y_pred.shape)

# One figure per sample; only channel 0 of each array is shown.
for batch_index in range(y_input.shape[0]):
    fig, axs = plt.subplots(1, len(config.STRUCTURE_NAMES)+3, figsize=(10,20))
    axs[0].imshow(y_input[batch_index,...,0])
    axs[0].set_title("Input")
    axs[1].imshow(y_true[batch_index,...,0])
    axs[1].set_title("Truth")
    axs[2].imshow(y_pred[batch_index,...,0])
    axs[2].set_title("Prediction")

    # Label and prediction overlaid, each half transparent.
    axs[3].imshow(y_true[batch_index,...,0], alpha=0.5)
    axs[3].imshow(y_pred[batch_index,...,0], alpha=0.5)
    axs[3].set_title("Overlay")

In [None]:
# Restore the saved "initial" weights before inspecting the predictions.
model.load_weights("./weights/bce_vacbag_512.initial.hdf5")

from matplotlib import pyplot as plt

y_input, y_true = test_gen[0]
# Feed the input images to the network -- not the labels. The prediction is
# left unrounded so the difference image below shows the raw error.
y_pred = model.predict(y_input)

print("Input shape:", y_input.shape)
print("Truth shape: ", y_true.shape)
print("Predict shape:", y_pred.shape)

# Signed difference between label and prediction, computed once for the
# whole batch instead of once per loop iteration.
diff = y_true - y_pred

# One figure per sample; only channel 0 of each array is shown.
for batch_index in range(y_input.shape[0]):
    fig, axs = plt.subplots(1, len(config.STRUCTURE_NAMES)+3, figsize=(10,20))
    axs[0].imshow(y_input[batch_index,...,0])
    axs[0].set_title("Input")
    axs[1].imshow(y_true[batch_index,...,0])
    axs[1].set_title("Truth")
    axs[2].imshow(y_pred[batch_index,...,0])
    axs[2].set_title("Prediction")

    # Fourth panel shows the difference image rather than an overlay.
    axs[3].imshow(diff[batch_index,...,0], alpha=0.5)
    axs[3].set_title("Truth - Prediction")


In [None]:
import tensorflow as tf
import model_ronneberger_512 as model


# --- Second training stage: Dice loss, started from saved BCE weights ---

# ModelCheckpoint fills in "{epoch:02d}", so one weights file is written per
# epoch into ./weights/.
MODEL_SAVE = "./weights/dice_vacbag_512.{epoch:02d}.hdf5"

LOSS = loss.dsc_loss
INITIAL_LR = 1e-4
STOPPING_PATIENCE = 50
LR_PATIENCE = 4

METRICS = [
    loss.dice_metric,
    tf.keras.metrics.Recall()
]
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras releases.
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=INITIAL_LR)
# Factor applied to the learning rate by ReduceLROnPlateau.
LR_SCALE = 0.5


# All three callbacks leave `monitor` at its Keras default.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=STOPPING_PATIENCE, verbose=1, restore_best_weights=True)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=LR_SCALE, patience=LR_PATIENCE, verbose=1)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE, save_weights_only=True, verbose=1)

# Rebinds `model` from the imported module to a newly built network. The
# import at the top of this cell restores the module, so the cell can be
# re-run safely.
model = model.model(config.GRID_SIZE, len(config.STRUCTURE_NAMES))

model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

# Start from the saved "initial" BCE weights rather than a random
# initialisation.
model.load_weights("./weights/bce_vacbag_512.initial.hdf5")

In [None]:
EPOCHS = 10

# Callbacks created in the training-setup cell.
fit_callbacks = [early_stopping, model_checkpoint, reduce_lr]

print("\n Training...")
# One full pass over each generator per epoch: the step counts are taken
# from the generators' own lengths.
train_history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=EPOCHS,
    steps_per_epoch=train_gen.__len__(),
    validation_steps=valid_gen.__len__(),
    callbacks=fit_callbacks,
    verbose=1,
)

In [None]:
# Inspect the checkpoint written after epoch 1 of the Dice training run.
model.load_weights("./weights/dice_vacbag_512.01.hdf5")

from matplotlib import pyplot as plt

y_input, y_true = train_gen[0]
# Feed the input images to the network -- not the labels. The prediction is
# left unrounded.
y_pred = model.predict(y_input)

print("Input shape:", y_input.shape)
print("Truth shape: ", y_true.shape)
print("Predict shape:", y_pred.shape)

# One figure per sample; only channel 0 of each array is shown.
for batch_index in range(y_input.shape[0]):
    fig, axs = plt.subplots(1, len(config.STRUCTURE_NAMES)+3, figsize=(10,20))
    axs[0].imshow(y_input[batch_index,...,0])
    axs[0].set_title("Input")
    axs[1].imshow(y_true[batch_index,...,0])
    axs[1].set_title("Truth")
    axs[2].imshow(y_pred[batch_index,...,0])
    axs[2].set_title("Prediction")

    # Label and prediction overlaid, each half transparent.
    axs[3].imshow(y_true[batch_index,...,0], alpha=0.5)
    axs[3].imshow(y_pred[batch_index,...,0], alpha=0.5)
    axs[3].set_title("Overlay")

In [None]:
# Inspect the checkpoint written after epoch 2 of the Dice training run.
model.load_weights("./weights/dice_vacbag_512.02.hdf5")

from matplotlib import pyplot as plt

y_input, y_true = train_gen[0]
# Feed the input images to the network -- not the labels -- then round the
# output to a binary mask.
y_pred = model.predict(y_input)
y_pred = np.round(y_pred)

print("Input shape:", y_input.shape)
print("Truth shape: ", y_true.shape)
print("Predict shape:", y_pred.shape)

# One figure per sample; only channel 0 of each array is shown.
for batch_index in range(y_input.shape[0]):
    fig, axs = plt.subplots(1, len(config.STRUCTURE_NAMES)+3, figsize=(10,20))
    axs[0].imshow(y_input[batch_index,...,0])
    axs[0].set_title("Input")
    axs[1].imshow(y_true[batch_index,...,0])
    axs[1].set_title("Truth")
    axs[2].imshow(y_pred[batch_index,...,0])
    axs[2].set_title("Prediction")

    # Label and prediction overlaid, each half transparent.
    axs[3].imshow(y_true[batch_index,...,0], alpha=0.5)
    axs[3].imshow(y_pred[batch_index,...,0], alpha=0.5)
    axs[3].set_title("Overlay")