In [None]:
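# End-to-end training and testing of the new pytorch-based VGG16-deeplabv3 architecture
# NOTE(review): every model in this notebook is built with tf.keras -- confirm what "pytorch-based" refers to.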

In [None]:
# GPU selection, following
# https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
# The environment variables are written before TensorFlow is imported below.
import os
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "1",
})

# Session configuration, following
# https://stackoverflow.com/questions/56008683/could-not-create-cudnn-handle-cudnn-status-internal-error
import tensorflow as tf

# Cap the process at 95% of GPU memory and let the allocation grow on demand.
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.95)
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)

In [None]:
import importlib

import wasserstein_utils
import data_utils
import evaluate
import losses
import networks
import deeplabv3 as dlv3
import utils
import io_utils

In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import umap
from IPython import display

In [None]:
# Experiment identifiers; both are embedded in the weight-file names built below.
backbone = 'vgg16'
dataset = "abdomen"

# Image dimensions as H x W x C
img_shape = (256, 256, 3)

# 4 classes + void
num_classes = 5

batch_size = 16

# `epochs` is the number of iterations of the adaptation loop;
# `epoch_step` is the checkpoint/plot interval used by that loop.
epochs = 30000
epoch_step = 500

# Number of rows of the random projection matrix fed to the Wasserstein loss.
num_projections = 100

# All list files live directly under the processed-data directory.
data_dir = "./data/abdomen/processed-data/"
source_train_list = io_utils.read_list_file(data_dir + "mri_chaos_train_list")
source_val_list = io_utils.read_list_file(data_dir + "mri_chaos_val_list")
target_list = io_utils.read_list_file(data_dir + "ct_atlas_train_list")
target_test_list = io_utils.read_list_file(data_dir + "ct_atlas_test_list")

# Weight files share the stem weights/<dataset>/<backbone>_deeplabv3 and end in _<suffix>.h5
suffix = "run1"
_weights_stem = "weights/{}/{}_deeplabv3".format(dataset, backbone)

fn_w_dlv3 = "{}_debug_{}.h5".format(_weights_stem, suffix)
fn_w_cls = "{}_classifier_debug_{}.h5".format(_weights_stem, suffix)

fn_w_adapted_dlv3 = "{}_adapted_debug_{}.h5".format(_weights_stem, suffix)
fn_w_adapted_cls = "{}_classifier_adapted_debug_{}.h5".format(_weights_stem, suffix)

In [None]:
# Pick up any edits made to the local modules since they were first imported.
importlib.reload(losses)
importlib.reload(dlv3)

# Backbone network; its output is what the classifier head consumes.
deeplabv3 = dlv3.deeplabv3(
    activation=None,
    backbone=backbone,
    num_classes=num_classes,
    regularizer=tf.keras.regularizers.l2(1),
    dropout=.5,
)

X = deeplabv3.input
Y = tf.keras.layers.Input(
    shape=(img_shape[0], img_shape[1], num_classes,),
    dtype='float32',
    name='label_input',
)

# Stand-alone classifier head whose input shape mirrors the backbone's last layer.
C_in = tf.keras.layers.Input(
    shape=deeplabv3.layers[-1].output_shape[1:],
    dtype='float32',
    name='classifier_input',
)
_classifier_out = networks.classifier_layers(C_in, num_classes = num_classes, activation='softmax')
classifier = tf.keras.Model(C_in, _classifier_out)

# A combined model, giving the output of classifier(deeplabv3(X))
_combined_out = classifier(deeplabv3(X))
combined = tf.keras.Model(X, _combined_out)
combined.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False))

# A model outputting hxwx1 labels for each image (argmax over the class axis,
# cast to float32). Also useful to verify the mIoU with Keras' built-in
# function. Will however also consider the 'ignore' class.
_label_map = tf.cast(tf.keras.backend.argmax(combined(X), axis=-1), 'float32')
combined_ = tf.keras.Model(X, _label_map)
combined_.compile(metrics=[tf.keras.metrics.MeanIoU(num_classes=num_classes)])

In [None]:
# Restore the backbone and the classifier head from the non-adapted
# checkpoint files (fn_w_dlv3 / fn_w_cls) defined in the configuration cell.
deeplabv3.load_weights(fn_w_dlv3)
classifier.load_weights(fn_w_cls)

In [None]:
# Cache files holding the Gaussian parameters (means / covariances) for this
# backbone, dataset and run suffix.
_gaussians_dir = "./extras"
_gaussians_tag = backbone + "_deeplabv3_" + dataset + "_" + suffix + ".npy"
_means_path = _gaussians_dir + "/means_" + _gaussians_tag
_covs_path = _gaussians_dir + "/covs_" + _gaussians_tag

try:
    print("Loading pre-learnt gaussians")

    means = np.load(_means_path)
    covs = np.load(_covs_path)
except (OSError, ValueError, EOFError) as load_error:
    # Only a missing or unreadable cache triggers the (slow) relearning path.
    # Anything else -- notably KeyboardInterrupt -- propagates instead of being
    # silently swallowed as it was with the previous bare `except:`.
    print("Learning means and covariances")

    importlib.reload(utils)

    # means = np.load("./extras/submission/means_" + backbone + "_deeplabv3_" + dataset + "_" + suffix + ".npy")
    # covs = np.load("./extras/submission/covs_" + backbone + "_deeplabv3_" + dataset + "_" + suffix + ".npy")

    # First pass: only the means are kept.
    start_time = time.time()
    means, _, ct = utils.learn_gaussians(data_dir, source_train_list, deeplabv3, combined, batch_size,
                                         data_utils.label_ids_abdomen, rho=.97)
    print("computed means in", time.time() - start_time)

    # Second pass: rerun with the means from the first pass as `initial_means`
    # and keep both the means and the covariances it returns.
    start_time = time.time()
    means, covs, ct = utils.learn_gaussians(data_dir, source_train_list, deeplabv3, combined, batch_size,
                                            data_utils.label_ids_abdomen, rho=.97, initial_means=means)
    print("finished training gaussians in", time.time() - start_time)

    # Make sure the cache directory exists so the results of the long
    # computation above are not lost to a failing np.save.
    os.makedirs(_gaussians_dir, exist_ok=True)
    np.save(_means_path, means)
    np.save(_covs_path, covs)

In [None]:
# Generate data from the gmm model and plot it
importlib.reload(umap)

importlib.reload(utils)

start_time = time.time()

# Draw the same number of samples (2000) for every class.
n_samples = np.full(num_classes, 2000, dtype=int)

xx, yy = utils.sample_from_gaussians(means, covs, n_samples=n_samples)

NUM_COLORS = num_classes

# Reduce the sampled points with UMAP (default settings) for plotting.
reducer = umap.UMAP()

umap_embedding = reducer.fit_transform(xx)

plt.figure(figsize=(16,14))
cm = plt.get_cmap('gist_rainbow')

# One scatter series per label; colours are spread evenly over the colormap.
for idx, (label, label_id) in enumerate(data_utils.label_ids_abdomen.items()):
    ind = yy == label_id

    plt.scatter(umap_embedding[:,0][ind], umap_embedding[:,1][ind], label=label,
                color=cm(1.*idx/NUM_COLORS))

plt.title("Embedding scatter-plot")
plt.legend()

plt.show()

# Wall-clock seconds spent sampling, embedding and plotting.
print(time.time() - start_time)

In [None]:
# Build the adaptation objective and its training function.
# (This cell used to appear twice verbatim; the duplicate rebuilt the same
# loss graph and optimizer a second time and has been folded into this one.)
importlib.reload(wasserstein_utils)

# Symbolic inputs fed by the training loop:
#   Z_s   - samples drawn from the learnt Gaussians, reshaped to H x W x num_classes
#   Y_s   - one-hot labels of those samples
Z_s = tf.keras.layers.Input(shape=(img_shape[0], img_shape[1], num_classes,) )
Y_s = tf.keras.backend.placeholder(shape=(None, img_shape[0], img_shape[1], num_classes), dtype='float32') #labels of input images oneHot
lambda2 = .5

# Cross-entropy of the classifier on the sampled embeddings.
loss_function = losses.masked_ce_loss(num_classes, None)
wce_loss = loss_function(Y_s, classifier(Z_s), from_logits=False)

# Wasserstein matching loss between the backbone output on X and the sampled
# embeddings, using the projection matrix `theta`.
theta = tf.keras.backend.placeholder(shape = (num_projections, num_classes), dtype='float32')
matching_loss = wasserstein_utils.sWasserstein_hd(deeplabv3(X), Z_s, theta, nclass=num_classes, Cp=None, Cq=None,)

# Overall loss is a weighted combination of the two losses
total_loss = wce_loss + lambda2*matching_loss

# Both networks are updated jointly.
params = deeplabv3.trainable_weights + classifier.trainable_weights

# Optimizer and training setup
opt = tf.keras.optimizers.Adam(lr=5e-5, epsilon=1e-1, decay=1e-6)

# `train` runs one optimisation step and returns [total, CE, Wasserstein] losses.
updates = opt.get_updates(total_loss, params)
train = tf.keras.backend.function(inputs=[X,Z_s,Y_s,theta], outputs=[total_loss, wce_loss, matching_loss], updates=updates)

In [None]:
# Fresh histories for the adaptation run: per-iteration losses and target mIoU.
loss, target_miou = [], []

# Put both networks back to the non-adapted checkpoints before adapting.
classifier.load_weights(fn_w_cls)
deeplabv3.load_weights(fn_w_dlv3)

In [None]:
fig,ax=plt.subplots(2,figsize=(15,10))

# Number of label positions in one image (H * W). The sampled embeddings are
# reshaped into whole H x W images below, so the total number of samples must
# be a multiple of this value.
pixels_per_image = np.prod(img_shape[:-1])

for itr in range(epochs):
    target_train_data, target_train_labels = io_utils.sample_batch(data_dir, target_list,
                                                                   batch_size=batch_size, seed=itr)

    # make sure the #samples from gaussians match the distribution of the labels
    n_samples = np.zeros(num_classes, dtype=int)
    cls, ns = np.unique(target_train_labels, return_counts=True)
    for class_id, count in zip(cls, ns):
        n_samples[class_id] = count

    # Pad only when the sample count does not already fill whole images.
    # (The check previously used H*W*C while the remainder used H*W, which
    # added a full extra image whenever the count was a multiple of H*W but
    # not of H*W*C.)
    leftover = np.sum(n_samples) % pixels_per_image
    if leftover != 0:
        remaining = pixels_per_image - leftover

        # Spread the padding proportionally to the observed class frequencies.
        aux = np.copy(n_samples) / np.sum(n_samples)
        aux *= remaining
        aux = np.floor(aux).astype('int')

        n_samples += aux

        # in case there are extra samples left, dump them on the highest represented class
        n_samples[np.argmax(n_samples)] += remaining - np.sum(aux)

    # Sample embeddings + labels from the Gaussians and shape them like images.
    Yembed,Yembedlabels = utils.sample_from_gaussians(means, covs, n_samples = n_samples)
    Yembed = Yembed.reshape(-1, img_shape[0], img_shape[1], num_classes)
    Yembedlabels = Yembedlabels.reshape(-1, img_shape[0], img_shape[1])

    Yembedlabels = tf.keras.utils.to_categorical(Yembedlabels, num_classes=num_classes)

    # New projection matrix for every step, then one optimisation step.
    theta_instance = tf.keras.backend.variable(wasserstein_utils.generateTheta(num_projections,num_classes))
    loss.append(train(inputs=[target_train_data, Yembed, Yembedlabels, theta_instance]))

    # mIoU of the current model on the batch it was just trained on.
    target_miou.append(combined_.evaluate(target_train_data, target_train_labels, verbose=False)[-1])

    if itr%epoch_step==0 or itr < 1000:
        deeplabv3.save_weights(fn_w_adapted_dlv3)
        classifier.save_weights(fn_w_adapted_cls)

        # Debug plots. Top panel: log of the total, cross-entropy and weighted
        # Wasserstein losses. Bottom panel: mIoU on the sampled target batches.
        if itr != 0:
            ax[0].clear()

            ll = np.asarray(loss)
            ax[0].plot(np.log(ll[:,0]), label='log total loss')
            ax[0].plot(np.log(ll[:,1]), label='log ce loss')
            ax[0].plot(np.log(ll[:,2] * lambda2), label='log wasserstein loss')
            ax[0].legend()

        ax[0].set_title("Log Loss")
        ax[0].set_xlabel("Epochs")
        ax[0].set_ylabel("Log Loss")

        if itr != 0:
            ax[1].clear()
            ax[1].plot(np.asarray(target_miou))

        ax[1].set_title("MIOU on target domain")
        ax[1].set_xlabel("Epochs")
        ax[1].set_ylabel("Mean IOU")

        # Redraw the figure in place of the previous cell output.
        display.display(fig)
        display.clear_output(wait=True)

        time.sleep(1e-3)

In [None]:
# Persist the final state of both networks to the "adapted" checkpoint files
# (the same files the training loop writes at its checkpoint steps).
deeplabv3.save_weights(fn_w_adapted_dlv3)
classifier.save_weights(fn_w_adapted_cls)

In [None]:
# DICE of the label-map model on the target *test* list, ignoring label id 0.
start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(
    data_dir,
    target_test_list,
    combined_,
    data_utils.label_ids_abdomen,
    id_to_ignore=0,
)

# Per-category scores first, then the overall score.
for category, score in target_cat_dice.items():
    print(category, score)
print(target_dice)

print('Computed {} target DICE in'.format(dataset), time.time() - start_time)

# liver 0.8745730727705845
# right_kidney 0.7234397414832942
# left_kidney 0.8030892885012225
# spleen 0.8031550976644446
# 0.8010643001048865
# Computed abdomen target DICE in 23.268261671066284

In [None]:
# DICE of the label-map model on the target *training* list, ignoring label id 0.
start_time = time.time()

target_cat_dice, target_dice = utils.compute_dice(
    data_dir,
    target_list,
    combined_,
    data_utils.label_ids_abdomen,
    id_to_ignore=0,
)

# Per-category scores first, then the overall score.
for category, score in target_cat_dice.items():
    print(category, score)
print(target_dice)

print('Computed {} target DICE in'.format(dataset), time.time() - start_time)

# liver 0.8783129334054984
# right_kidney 0.7251318294154707
# left_kidney 0.782384974705287
# spleen 0.8428680590576928
# 0.8071744491459872
# Computed abdomen target DICE in 991.6162948608398