In [None]:
# Goal: train the best possible source-domain (MR) model, i.e. a validation Dice above 0.82.

In [None]:
# Expose only the first GPU to this process, numbering devices in PCI bus order.
# Set these before TensorFlow initialises CUDA, i.e. before importing it.
# https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
import os
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})

# Configure the GPU allocator: grow on demand (allow_growth) with
# per_process_gpu_memory_fraction=0.95, as suggested for CUDNN_STATUS_INTERNAL_ERROR.
# https://stackoverflow.com/questions/56008683/could-not-create-cudnn-handle-cudnn-status-internal-error
import tensorflow as tf
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.95,
                                      allow_growth=True)
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
session = tf.compat.v1.Session(config=config)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import importlib

import wasserstein_utils
import data_utils
import losses
import networks
import deeplabv3 as dlv3
import utils
import io_utils

In [None]:
import time
import numpy as np

from IPython import display

import umap

In [None]:
backbone = 'vgg16'
dataset = "mmwhs"

# Network input shape: H x W x C
img_shape = (256, 256, 3)

# 4 classes + void
num_classes = 5

batch_size = 16

# True: train on the source domain; False: load previously saved weights instead.
do_training = True

# NOTE: despite its name, `epochs` is the number of training *iterations*
# (one randomly sampled mini-batch each), not full passes over the dataset.
epochs = 90000
# After the first 1000 iterations, the loss plot is refreshed every `epoch_step` iterations.
epoch_step = 250

num_projections = 100

# Dataset root. Every split list is resolved relative to it, so relocating the
# data only requires editing this one path.
data_dir = "./data/mmwhs/PnpAda_release_data/train&val/"
source_list = io_utils.read_list_file(os.path.join(data_dir, "mr_train_list"))      # source domain: MR
source_val_list = io_utils.read_list_file(os.path.join(data_dir, "mr_val_list"))
target_list = io_utils.read_list_file(os.path.join(data_dir, "ct_train_list"))      # target domain: CT
target_val_list = io_utils.read_list_file(os.path.join(data_dir, "ct_val_list"))
target_test_list = io_utils.read_list_file(os.path.join(data_dir, "ct_test_list"))

# Checkpoint files for the DeepLabV3 backbone and the classifier head.
weights_dir = os.path.join("weights", dataset)
fn_w_dlv3 = os.path.join(weights_dir, backbone + "_deeplabv3_debug.h5")
fn_w_cls = os.path.join(weights_dir, backbone + "_deeplabv3_classifier_debug.h5")

In [None]:
# Reload the loss module so edits to losses.py take effect without a kernel restart.
importlib.reload(losses)

# Segmentation backbone. No output activation is requested here; the separate
# classifier head below is built with activation='softmax'.
deeplabv3 = dlv3.deeplabv3(activation=None,
                           backbone=backbone,
                           num_classes=num_classes,
                           regularizer=tf.keras.regularizers.l2(1e-4))

X = deeplabv3.input
# One-hot ground-truth placeholder: (H, W, num_classes) per sample.
Y = tf.keras.layers.Input(shape=(img_shape[0], img_shape[1], num_classes),
                          dtype='float32', name='label_input')

# Classifier head consuming the backbone output (shape taken from its last layer).
C_in = tf.keras.layers.Input(shape=deeplabv3.layers[-1].output_shape[1:],
                             dtype='float32', name='classifier_input')
classifier = tf.keras.Model(C_in, networks.classifier_layers(C_in, num_classes=num_classes,
                                                             activation='softmax'))

# A combined model, giving the output of classifier(deeplabv3(X)).
combined = tf.keras.Model(X, classifier(deeplabv3(X)))
combined.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False))

# A model outputting the H x W arg-max label map (as float32) for each image.
# Also useful to verify the mIoU with Keras' built-in MeanIoU, which will however
# also count the 'ignore' class.
combined_ = tf.keras.Model(X, tf.cast(tf.keras.backend.argmax(combined(X), axis=-1), 'float32'))
combined_.compile(metrics=[tf.keras.metrics.MeanIoU(num_classes=num_classes)])

# Training loss from losses.masked_ce_loss, applied to the classifier's softmax
# output (hence from_logits=False).
loss_function = losses.masked_ce_loss(num_classes, None)
wce_loss = loss_function(Y, classifier(deeplabv3(X)), from_logits=False)

# Identically configured Adam instances. The training loop moves on to the next
# instance after each 1/num_optimizer_stages of the iterations, so every stage starts
# with a fresh optimizer (presumably to restart Adam's moments / decay schedule;
# TODO confirm intent). `learning_rate` replaces the deprecated `lr` alias.
num_optimizer_stages = 3
opt = [tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-6, decay=1e-6)
       for _ in range(num_optimizer_stages)]

# Backbone and head are optimized jointly.
# https://stackoverflow.com/questions/55434653/batch-normalization-doesnt-have-gradient-in-tensorflow-2-0
params = deeplabv3.trainable_weights + classifier.trainable_weights

# One Keras backend function per optimizer: calling it with [images, one-hot labels]
# applies that optimizer's updates and returns [wce_loss].
updates = [o.get_updates(wce_loss, params) for o in opt]
train = [tf.keras.backend.function(inputs=[X, Y], outputs=[wce_loss], updates=u) for u in updates]

In [None]:
importlib.reload(io_utils)

# Set to True to warm-start from the checkpoints at fn_w_dlv3 / fn_w_cls;
# False trains from scratch.
resume_training = False

# Training on source domain
if do_training:
    if resume_training:
        try:
            deeplabv3.load_weights(fn_w_dlv3)
            classifier.load_weights(fn_w_cls)
            print("Successfully loaded model. Continuing training.")
        except (OSError, ValueError) as err:
            # Missing/unreadable file or mismatched architecture: fall back to scratch.
            print("Could not load previous model weights. Is a new model present?", err)
    else:
        print("Training from scratch")

    # save_weights() fails if the checkpoint directory does not exist yet.
    for weights_path in (fn_w_dlv3, fn_w_cls):
        os.makedirs(os.path.dirname(weights_path) or ".", exist_ok=True)

    # Iterations handled by each optimizer before switching to the next one in `train`.
    iters_per_optimizer = epochs // len(train) + 1
    # Checkpoint roughly 10 times per run; max() avoids modulo-by-zero when epochs < 10.
    checkpoint_every = max(1, epochs // 10)

    start_time = time.time()
    fig, ax = plt.subplots(1, figsize=(10, 7))
    loss_history = []

    for itr in range(epochs):
        opt_idx = itr // iters_per_optimizer

        # seed=itr ties each iteration's batch to its index (assuming sample_batch is
        # deterministic for a given seed).
        source_train_data, source_train_labels = io_utils.sample_batch(data_dir, source_list,
                                                                       batch_size=batch_size, seed=itr)
        source_train_labels = tf.keras.utils.to_categorical(source_train_labels, num_classes=num_classes)

        loss_history.append(train[opt_idx](inputs=[source_train_data, source_train_labels]))

        # Redraw the loss curve on each of the first 1000 iterations, then every epoch_step.
        if itr % epoch_step == 0 or itr < 1000:
            if itr != 0:
                ax.clear()
                ax.plot(np.log(np.asarray(loss_history)))

            ax.set_title("Training loss on source domain")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Log Loss")

            display.clear_output(wait=True)
            display.display(fig)
            time.sleep(1e-3)

        if itr % checkpoint_every == 0 or itr == epochs - 1:
            deeplabv3.save_weights(fn_w_dlv3)
            classifier.save_weights(fn_w_cls)

    training_time = time.time() - start_time
    print("Training took {:.1f} s".format(training_time))
else:
    deeplabv3.load_weights(fn_w_dlv3)
    classifier.load_weights(fn_w_cls)
    print("Loaded model weights")

In [None]:
# Visual sanity check on one source-domain *validation* sample: an input channel,
# its ground-truth label map and the model's predicted label map.
# Stored under *_val_* names so the training-batch variables are not overwritten.
source_val_data, source_val_labels = io_utils.sample_batch(data_dir, source_val_list,
                                                           batch_size=batch_size, seed=42)

print(source_val_data.shape, source_val_labels.shape)
# Sample to inspect; clamped so small batch sizes do not raise IndexError.
idx = min(10, len(source_val_data) - 1)

img_h, img_w = img_shape[:2]
# Fixed colour range so each class id maps to the same colour in both label plots.
label_vmin, label_vmax = 0, num_classes - 1

plt.imshow(source_val_data[idx][:, :, 1])
plt.show()

plt.imshow(source_val_labels[idx].reshape(img_h, img_w), vmin=label_vmin, vmax=label_vmax)
plt.show()

predicted_labels = combined_.predict(source_val_data[idx].reshape((1,) + tuple(img_shape)))
plt.imshow(predicted_labels[0], vmin=label_vmin, vmax=label_vmax)
plt.show()

In [None]:
# Per-class IoU and mean IoU on the source (MR) validation split. Label id 0 is
# passed as id_to_ignore; the recorded output lists only the four foreground structures.
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

source_cat_iou, source_miou = utils.compute_miou(
    data_dir,
    source_val_list,
    combined_,
    data_utils.label_ids_mmwhs,
    id_to_ignore=0,
)

for class_name, class_iou in source_cat_iou.items():
    print(class_name, class_iou)
print(source_miou)

elapsed = time.time() - start_time
print('Computed ' + dataset + ' mIoU in', elapsed)

# Recorded output of a previous run:
# lv_myo 0.6391655877566917
# la_blood 0.7761180002868887
# lv_blood 0.8552707982717727
# aa 0.6573924014969161
# 0.7319866969530673
# Computed mmwhs mIoU in 198.71120285987854

In [None]:
# Per-class Dice and mean Dice on the source (MR) validation split (label id 0 ignored).
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

dice_args = (data_dir, source_val_list, combined_, data_utils.label_ids_mmwhs)
source_cat_dice, source_dice = utils.compute_dice(*dice_args, id_to_ignore=0)

for structure, dice_score in source_cat_dice.items():
    print(structure, dice_score)
print(source_dice)

print('Computed ' + dataset + ' DICE in', time.time() - start_time)

# Recorded output of a previous run:
# lv_myo 0.7798670159143992
# la_blood 0.8739486905279106
# lv_blood 0.9219902550813358
# aa 0.7932851639758641
# 0.8422727813748774
# Computed mmwhs DICE in 199.5630497932434

In [None]:
# Per-class Dice and mean Dice of the source-trained model on the target (CT)
# validation split, i.e. before any domain adaptation (label id 0 ignored).
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(
    data_dir, target_val_list, combined_, data_utils.label_ids_mmwhs,
    id_to_ignore=0,
)

for structure_name in target_cat_dice:
    structure_dice = target_cat_dice[structure_name]
    print(structure_name, structure_dice)
print(target_dice)

elapsed = time.time() - start_time
print('Computed ' + dataset + ' target DICE in', elapsed)

# Recorded output of a previous run:
# lv_myo 0.1805723386505549
# la_blood 0.732765818219422
# lv_blood 0.274775071361346
# aa 0.6930902059047827
# 0.47030085853402637
# Computed mmwhs target DICE in 98.94450497627258

In [None]:
# Per-class Dice and mean Dice of the source-trained model on the held-out target (CT)
# *test* split (label id 0 ignored). The results go into target_cat_dice / target_dice,
# replacing whatever the target-validation cell stored in those names.
importlib.reload(utils)
importlib.reload(data_utils)
importlib.reload(io_utils)

start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(data_dir, target_test_list, combined_,
                                                  data_utils.label_ids_mmwhs, id_to_ignore=0)

for k in target_cat_dice:
    print(k, target_cat_dice[k])
print(target_dice)

# Name the split explicitly so this line is distinguishable from the validation run's output.
print('Computed ' + dataset + ' target test DICE in', time.time() - start_time)

# Recorded output of a previous run (logged before the split name was added to the message):
# lv_myo 0.30680336428773775
# la_blood 0.8255339521423518
# lv_blood 0.47421624761212994
# aa 0.7822472557192597
# 0.5972002049403697
# Computed mmwhs target DICE in 82.1158561706543