In [None]:
# Goal: obtain the best possible source-domain training results (Dice above 0.82)

In [None]:
# Restrict TensorFlow to GPU 0, with device indices ordered by PCI bus ID.
# These environment variables are set before TensorFlow is imported so they take effect.
# https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Work around "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR" by capping the
# GPU memory fraction and letting the allocation grow on demand.
# https://stackoverflow.com/questions/56008683/could-not-create-cudnn-handle-cudnn-status-internal-error
import tensorflow as tf
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.95)
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
# Register the session with Keras; otherwise Keras creates/uses its own session and the
# GPU options above are never applied to the models built later in this notebook.
# NOTE(review): under TF2 eager execution this v1 session is not used at all; there,
# memory growth needs tf.config.experimental.set_memory_growth — confirm which mode runs.
tf.compat.v1.keras.backend.set_session(session)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import importlib

import wasserstein_utils
import data_utils
import losses
import networks
import deeplabv3 as dlv3
import utils
import io_utils

In [None]:
import time
import numpy as np

from IPython import display

import umap

In [None]:
# ---- Experiment configuration ----
backbone = 'vgg16'
dataset = "abdomen"

# Input image shape: H x W x C
img_shape = (256, 256, 3)

# 4 organ classes + void (the Dice cells below pass id_to_ignore=0)
num_classes = 5

batch_size = 16

# True: run the training loop; False: load previously saved weights instead.
do_training = True

# NOTE: "epochs" is really the number of training iterations (one sampled batch each).
epochs = 90000
# The loss plot is refreshed every `epoch_step` iterations (and every iteration before 1000).
epoch_step = 500

# NOTE(review): not used in this notebook; presumably the number of random projections for
# the sliced-Wasserstein utilities (wasserstein_utils) — confirm.
num_projections = 100

data_dir = "./data/abdomen/processed-data/"
# Derive list-file paths from data_dir so that changing data_dir moves all of them together.
source_train_list = io_utils.read_list_file(data_dir + "mri_chaos_train_list")
source_val_list = io_utils.read_list_file(data_dir + "mri_chaos_val_list")
target_list = io_utils.read_list_file(data_dir + "ct_atlas_train_list")
target_test_list = io_utils.read_list_file(data_dir + "ct_atlas_test_list")

suffix = "run1"
weights_dir = "weights/" + dataset + "/"
fn_w_dlv3 = weights_dir + backbone + "_deeplabv3_" + suffix + ".h5"
fn_w_cls = weights_dir + backbone + "_deeplabv3_classifier_" + suffix + ".h5"
# Writing .h5 weights does not create missing parent directories, so ensure it exists
# before the training loop tries to checkpoint.
os.makedirs(weights_dir, exist_ok=True)

In [None]:
importlib.reload(losses)
importlib.reload(dlv3)

# Segmentation feature extractor; activation=None so it outputs raw (pre-softmax) values.
# NOTE(review): if dlv3 attaches this l2(1) regularizer to its layers, the penalty ends up in
# deeplabv3.losses, which wce_loss below does NOT include — confirm whether that is intended.
deeplabv3 = dlv3.deeplabv3(activation=None,
                           backbone=backbone,
                           num_classes=num_classes,
                           regularizer=tf.keras.regularizers.l2(1),
                           dropout=.5)

X = deeplabv3.input
# One-hot label input: H x W x num_classes
Y = tf.keras.layers.Input(shape=(img_shape[0], img_shape[1], num_classes,), dtype='float32', name='label_input')

# Classifier head applied on top of deeplabv3's output, ending in a softmax.
C_in = tf.keras.layers.Input(shape=deeplabv3.layers[-1].output_shape[1:], dtype='float32', name='classifier_input')
classifier = tf.keras.Model(C_in, networks.classifier_layers(C_in, num_classes=num_classes, activation='softmax'))

# A combined model, giving the output of classifier(deeplabv3(X))
combined = tf.keras.Model(X, classifier(deeplabv3(X)))
combined.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False))

# A model outputting hxwx1 labels for each image. Also useful to verify the
# mIoU with Keras' built-in function. Will however also consider the 'ignore' class.
combined_ = tf.keras.Model(X, tf.cast(tf.keras.backend.argmax(combined(X), axis=-1), 'float32'))
combined_.compile(metrics=[tf.keras.metrics.MeanIoU(num_classes=num_classes)])

# Set up the loss functions (classifier outputs are already softmax probabilities)
loss_function = losses.weighted_ce_loss(num_classes, None)
wce_loss = loss_function(Y, classifier(deeplabv3(X)), from_logits=False)

# Set up training: one independent Adam optimizer per training phase. The training loop
# switches to the next optimizer (fresh moment estimates) after each phase.
# `learning_rate` is the canonical name for the deprecated `lr` alias.
num_opt_phases = 3
opt = [tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-6, decay=1e-6)
       for _ in range(num_opt_phases)]

# https://stackoverflow.com/questions/55434653/batch-normalization-doesnt-have-gradient-in-tensorflow-2-0
params = deeplabv3.trainable_weights + classifier.trainable_weights

# One backend function per optimizer; calling train[i]([x, y]) runs a single update step
# with optimizer i and returns [wce_loss].
updates = [o.get_updates(wce_loss, params) for o in opt]
train = [tf.keras.backend.function(inputs=[X, Y], outputs=[wce_loss], updates=u) for u in updates]

In [None]:
importlib.reload(io_utils)

# Set to True to continue from the checkpoints at fn_w_dlv3 / fn_w_cls instead of
# training from scratch (only used when do_training is True).
resume_training = False

# Training on source domain
if do_training:
    if resume_training:
        try:
            deeplabv3.load_weights(fn_w_dlv3)
            classifier.load_weights(fn_w_cls)
            print("Successfully loaded model. Continuing training.")
        except (OSError, ValueError) as err:
            # Missing/unreadable file or incompatible weights: fall back to fresh weights.
            print("Could not load previous model weights. Is a new model present?", err)
    else:
        print("Training from scratch")

    start_time = time.time()
    fig, ax = plt.subplots(1, figsize=(10, 7))
    loss_history = []

    # Each optimizer in `train` is used for one contiguous phase of iterations; the +1
    # guarantees opt_idx < len(train) for every itr < epochs.
    phase_length = epochs // len(train) + 1
    # Checkpoint about 100 times per run; max() avoids a modulo-by-zero when epochs < 100.
    save_every = max(1, epochs // 100)

    for itr in range(epochs):
        opt_idx = itr // phase_length

        source_train_data, source_train_labels = io_utils.sample_batch(data_dir, source_train_list,
                                                                       batch_size=batch_size, seed=itr)
        source_train_labels = tf.keras.utils.to_categorical(source_train_labels, num_classes=num_classes)

        loss_history.append(train[opt_idx](inputs=[source_train_data, source_train_labels]))

        # Redraw every iteration for the first 1000, then every epoch_step iterations.
        if itr % epoch_step == 0 or itr < 1000:
            if itr != 0:
                ax.clear()
                ax.plot(np.log(np.asarray(loss_history)))

            ax.set_title("Training loss on source domain")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Log Loss")

            display.clear_output(wait=True)
            display.display(plt.gcf())
            time.sleep(1e-3)

        if itr % save_every == 0 or itr == epochs - 1:
            deeplabv3.save_weights(fn_w_dlv3)
            classifier.save_weights(fn_w_cls)

    training_time = time.time() - start_time
    # The last frame is already displayed; close so the inline backend does not render it again.
    plt.close(fig)
    print("Training time (s):", training_time)
else:
    deeplabv3.load_weights(fn_w_dlv3)
    classifier.load_weights(fn_w_cls)
    print("Loaded model weights")

In [None]:
# Restore the feature extractor and the classifier head from their checkpoint files.
for _net, _weights_file in ((deeplabv3, fn_w_dlv3), (classifier, fn_w_cls)):
    _net.load_weights(_weights_file)

In [None]:
# Visual sanity check: one source training image, its ground-truth labels and the prediction.
source_train_data, source_train_labels = io_utils.sample_batch(data_dir, source_train_list,
                                                               batch_size=batch_size, seed=42)

print(source_train_data.shape, source_train_labels.shape)
idx = 5

# Fix the colour range to the label ids so ground truth and prediction use the same scale.
label_vmin, label_vmax = 0, num_classes - 1

plt.imshow(source_train_data[idx][:, :, 1])
plt.title("Source image (channel 1)")
plt.show()

plt.imshow(source_train_labels[idx].reshape(img_shape[:2]), vmin=label_vmin, vmax=label_vmax)
plt.title("Ground-truth labels")
plt.show()

# Add a leading batch dimension of 1 for predict().
myans = combined_.predict(source_train_data[idx].reshape((1,) + tuple(img_shape)))
plt.imshow(myans[0], vmin=label_vmin, vmax=label_vmax)
plt.title("Predicted labels")
plt.show()

In [None]:
# Per-class and overall Dice on the source training list (mri_chaos_train_list).
# id_to_ignore=0 presumably excludes the void label from the scores.
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

source_cat_dice, source_dice = utils.compute_dice(
    data_dir,
    source_train_list,
    combined_,
    data_utils.label_ids_abdomen,
    id_to_ignore=0,
)

for organ, score in source_cat_dice.items():
    print(organ, score)
print(source_dice)

print(f'Computed {dataset} source training DICE in', time.time() - start_time)

# liver 0.9787260747631922
# right_kidney 0.9483143525935519
# left_kidney 0.9440179099644721
# spleen 0.9490219109321446
# 0.9550200620633401
# Computed abdomen source training DICE in 838.5641431808472

In [None]:
# Per-class and overall Dice on the source validation list (mri_chaos_val_list).
# id_to_ignore=0 presumably excludes the void label from the scores.
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

source_cat_dice, source_dice = utils.compute_dice(
    data_dir,
    source_val_list,
    combined_,
    data_utils.label_ids_abdomen,
    id_to_ignore=0,
)

for organ, score in source_cat_dice.items():
    print(organ, score)
print(source_dice)

print(f'Computed {dataset} source validation DICE in', time.time() - start_time)

# liver 0.8757004829320048
# right_kidney 0.8712244593434707
# left_kidney 0.8717021572035188
# spleen 0.8433080796990774
# 0.8654837947945179
# Computed abdomen source validation DICE in 205.81820273399353

In [None]:
# Print the Keras model summary of the deeplabv3 feature extractor built above.
deeplabv3.summary()

In [None]:
# Per-class and overall Dice of the source-trained model on the target training list
# (ct_atlas_train_list), i.e. before any domain adaptation.
# id_to_ignore=0 presumably excludes the void label from the scores.
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(
    data_dir,
    target_list,
    combined_,
    data_utils.label_ids_abdomen,
    id_to_ignore=0,
)

for organ, score in target_cat_dice.items():
    print(organ, score)
print(target_dice)

print(f'Computed {dataset} target DICE in', time.time() - start_time)

# liver 0.6530598731222946
# right_kidney 0.5666970426966222
# left_kidney 0.6155851272696546
# spleen 0.4065931413558
# 0.5604837961110929
# Computed abdomen target DICE in 1262.0187606811523

In [None]:
# Per-class and overall Dice of the source-trained model on the held-out target test list
# (ct_atlas_test_list). The timing label says "target test" so this output is not confused
# with the target-training-list cell, which reports "target DICE".
# id_to_ignore=0 presumably excludes the void label from the scores.
importlib.reload(utils)
importlib.reload(data_utils)

start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(data_dir, target_test_list, combined_, data_utils.label_ids_abdomen,
                                                  id_to_ignore=0)

for k in target_cat_dice:
    print(k, target_cat_dice[k])
print(target_dice)

print('Computed ' + dataset + ' target test DICE in', time.time() - start_time)

# Previously recorded output (printed with the older "target DICE" label):
# liver 0.670681529231044
# right_kidney 0.41137495310217714
# left_kidney 0.6314386452919504
# spleen 0.3881504453307543
# 0.5254113932389815
# Computed abdomen target DICE in 28.56124997138977