In [35]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os
import json

from subtle.dnn.generators import GeneratorUNet2D, GeneratorBranchUNet2D, GeneratorFBoostUNet2D
import subtle.utils.io as suio
from subtle.data_loaders.slice_loader import SliceLoader
import subtle.subtle_loss as suloss
from subtle.dnn.helpers import clear_keras_memory, set_keras_memory
from keras.optimizers import Adam

from keract import get_activations, display_activations, display_heatmaps
from keras.utils.vis_utils import plot_model

# Expose only GPU 1, numbering devices by PCI bus id.
# NOTE(review): CUDA reads these when the TensorFlow backend first initializes the GPU;
# keras is already imported above, so confirm no session/device has been created yet.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

plt.rcParams['figure.figsize'] = (10, 8)
# Default colormap for imshow. Setting the rcParam directly (rather than plt.set_cmap)
# avoids creating an empty "current figure" as a side effect.
plt.rcParams['image.cmap'] = 'gray'

def transfer_weights(src_model, dest_model, branch_num, src_output_name='model_output'):
    """Copy a single-branch network's weights into one branch of a multi-branch model and freeze them.

    Every layer of ``dest_model.model`` whose name contains the keyword ``'b<branch_num>_'``
    is matched to the ``src_model.model`` layer whose name equals the destination name with
    that keyword removed (e.g. ``'b1_conv_enc_1_1'`` -> ``'conv_enc_1_1'``,
    ``'relu_b1_conv_enc_1_0'`` -> ``'relu_conv_enc_1_0'``). Matched layers receive the source
    weights and get ``trainable = False``. Keyword layers with no source counterpart are left
    untouched and listed in a printed message. Finally, the source output layer's weights are
    copied into the destination layer ``'branch<branch_num>_output'``, which is also frozen.

    Parameters
    ----------
    src_model : object with a ``.model`` Keras model
        Pretrained source network.
    dest_model : object with a ``.model`` Keras model
        Multi-branch network that receives the weights; modified in place.
    branch_num : int or str
        Branch identifier used to build the layer-name keyword and the output-layer name.
    src_output_name : str, optional
        Name of the source model's output layer (default ``'model_output'``).

    Returns
    -------
    dest_model
        The same object that was passed in.

    Raises
    ------
    ValueError
        Propagated from Keras if a weight shape does not match or an output layer is missing.

    Notes
    -----
    Per the Keras FAQ, changes to ``trainable`` only take effect once the model is (re)compiled.
    """
    kw = 'b{}_'.format(branch_num)
    op_layer = 'branch{}_output'.format(branch_num)

    print('Transferring weights -> {}'.format(kw))

    # Name -> layer lookup; setdefault keeps the first occurrence, mirroring list.index().
    src_by_name = {}
    for src_layer in src_model.model.layers:
        src_by_name.setdefault(src_layer.name, src_layer)

    unmatched = []
    for layer in dest_model.model.layers:
        if kw not in layer.name:
            continue
        src_layer = src_by_name.get(layer.name.replace(kw, ''))
        if src_layer is None:
            unmatched.append(layer.name)
            continue
        layer.set_weights(src_layer.get_weights())
        layer.trainable = False

    if unmatched:
        # These branch layers keep their initial weights and stay trainable; make that visible.
        print('  no source layer for: {}'.format(', '.join(unmatched)))

    out_layer = dest_model.model.get_layer(op_layer)
    out_layer.set_weights(src_model.model.get_layer(src_output_name).get_weights())
    out_layer.trainable = False
    return dest_model

<Figure size 720x576 with 0 Axes>

In [None]:
clear_keras_memory()
set_keras_memory(1.0)

loss_function = suloss.mixed_loss(l1_lambda=1.0, ssim_lambda=0.0)
metrics_monitor = [suloss.l1_loss]

# Pretrained single-branch checkpoints, listed in the same order as before.
CHECKPOINT_DIR = '/home/srivathsa/projects/studies/gad/tiantan/train/checkpoints'
pretrained_ckpts = [
    os.path.join(CHECKPOINT_DIR, fname)
    for fname in (
        'enh_vgg.checkpoint',
        'fboost_t1_t2.checkpoint',
        'fboost_t1_fl.checkpoint',
        'fboost_t1_uad.checkpoint',
    )
]

# Uncompiled 35-channel FBoost UNet seeded from the checkpoints above.
model_main = GeneratorFBoostUNet2D(
    num_channel_input=35,
    num_channel_output=1,
    img_rows=240,
    img_cols=240,
    loss_function=loss_function,
    metrics_monitor=metrics_monitor,
    lr_init=0.001,
    verbose=1,
    compile_model=False,
    model_config='base',
    fpaths_pre=pretrained_ckpts,
)

In [32]:
clear_keras_memory()
set_keras_memory(1.0)

loss_function = suloss.mixed_loss(l1_lambda=1.0, ssim_lambda=0.0)
metrics_monitor = [suloss.l1_loss]

CHECKPOINT_DIR = '/home/srivathsa/projects/studies/gad/tiantan/train/checkpoints'


def load_pretrained_unet(checkpoint_name, loss_fn, metrics):
    """Build an uncompiled 14-channel 'base' GeneratorUNet2D and load a checkpoint into it.

    Parameters
    ----------
    checkpoint_name : str
        Checkpoint file name, resolved relative to ``CHECKPOINT_DIR``.
    loss_fn
        Passed through as ``loss_function``.
    metrics : list
        Passed through as ``metrics_monitor``.

    Returns
    -------
    GeneratorUNet2D
        The generator wrapper after ``load_weights()`` has been called on it.
    """
    unet = GeneratorUNet2D(
        num_channel_input=14, num_channel_output=1,
        img_rows=240, img_cols=240,
        loss_function=loss_fn,
        metrics_monitor=metrics,
        lr_init=0.001,
        verbose=1,
        compile_model=False,
        model_config='base',
        checkpoint_file=os.path.join(CHECKPOINT_DIR, checkpoint_name))
    unet.load_weights()
    return unet


# Multi-branch fusion model; its b1_..b4_ layers are filled later via transfer_weights.
model_main = GeneratorBranchUNet2D(
            num_channel_input=35, num_channel_output=1,
            img_rows=240, img_cols=240,
            loss_function=loss_function,
            metrics_monitor=metrics_monitor,
            lr_init=0.001,
            verbose=1,
            compile_model=False,
            model_config='fusion_single_ch')

# Pretrained single-branch source networks, built in the same order as before
# (Keras auto-generated names such as model_input_2.. depend on construction order).
model_t1 = load_pretrained_unet('enh_vgg.checkpoint', loss_function, metrics_monitor)
model_t2 = load_pretrained_unet('fboost_t1_t2.checkpoint', loss_function, metrics_monitor)
model_fl = load_pretrained_unet('fboost_t1_fl.checkpoint', loss_function, metrics_monitor)
model_uad = load_pretrained_unet('fboost_t1_uad.checkpoint', loss_function, metrics_monitor)

Building branch_unet2d-fusion_single_ch model...
inputs Tensor("model_input:0", shape=(?, 240, 240, 35), dtype=float32)
t1 pre Tensor("t1_pre/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
t1 low Tensor("t1_low/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
t2 Tensor("t2/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
fl Tensor("fl/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
uad Tensor("uad/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
Tensor("relu_b1_conv_enc_1_2/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("b1_maxpool_1/MaxPool:0", shape=(?, 120, 120, 32), dtype=float32)
Tensor("relu_b1_conv_enc_2_2/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("b1_maxpool_2/MaxPool:0", shape=(?, 60, 60, 64), dtype=float32)
Tensor("relu_b1_conv_enc_3_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("b1_maxpool_3/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_b1_conv_center/Relu:0"

relu_b1_conv_enc_1_0 (Activatio (None, 240, 240, 32) 0           b1_conv_enc_1_0[0][0]            
__________________________________________________________________________________________________
relu_b2_conv_enc_1_0 (Activatio (None, 240, 240, 32) 0           b2_conv_enc_1_0[0][0]            
__________________________________________________________________________________________________
relu_b3_conv_enc_1_0 (Activatio (None, 240, 240, 32) 0           b3_conv_enc_1_0[0][0]            
__________________________________________________________________________________________________
relu_b4_conv_enc_1_0 (Activatio (None, 240, 240, 32) 0           b4_conv_enc_1_0[0][0]            
__________________________________________________________________________________________________
b1_conv_enc_1_1 (Conv2D)        (None, 240, 240, 32) 9248        relu_b1_conv_enc_1_0[0][0]       
__________________________________________________________________________________________________
b2_conv_en

Tensor("relu_conv_enc_3_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("maxpool_3/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_conv_center/Relu:0", shape=(?, 30, 30, 128), dtype=float32)
conv center... Tensor("add_center/add:0", shape=(?, 30, 30, 128), dtype=float32)
Tensor("relu_conv_dec_2_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("cat_1/concat:0", shape=(?, 60, 60, 256), dtype=float32)
Tensor("relu_conv_dec_3_2/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("cat_2/concat:0", shape=(?, 120, 120, 192), dtype=float32)
Tensor("relu_conv_dec_4_2/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("cat_3/concat:0", shape=(?, 240, 240, 96), dtype=float32)
Tensor("linear_model_output_1/Identity:0", shape=(?, 240, 240, 1), dtype=float32)
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #   

Building unet2d model...
Tensor("model_input_2:0", shape=(?, 240, 240, 14), dtype=float32)
Tensor("relu_conv_enc_1_2_1/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("maxpool_1_1/MaxPool:0", shape=(?, 120, 120, 32), dtype=float32)
Tensor("relu_conv_enc_2_2_1/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("maxpool_2_1/MaxPool:0", shape=(?, 60, 60, 64), dtype=float32)
Tensor("relu_conv_enc_3_2_1/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("maxpool_3_1/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_conv_center_1/Relu:0", shape=(?, 30, 30, 128), dtype=float32)
conv center... Tensor("add_center_1/add:0", shape=(?, 30, 30, 128), dtype=float32)
Tensor("relu_conv_dec_2_2_1/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("cat_1_1/concat:0", shape=(?, 60, 60, 256), dtype=float32)
Tensor("relu_conv_dec_3_2_1/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("cat_2_1/concat:0", shape=(?, 120, 120, 192), dtype=float32

Building unet2d model...
Tensor("model_input_3:0", shape=(?, 240, 240, 14), dtype=float32)
Tensor("relu_conv_enc_1_2_2/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("maxpool_1_2/MaxPool:0", shape=(?, 120, 120, 32), dtype=float32)
Tensor("relu_conv_enc_2_2_2/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("maxpool_2_2/MaxPool:0", shape=(?, 60, 60, 64), dtype=float32)
Tensor("relu_conv_enc_3_2_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("maxpool_3_2/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_conv_center_2/Relu:0", shape=(?, 30, 30, 128), dtype=float32)
conv center... Tensor("add_center_2/add:0", shape=(?, 30, 30, 128), dtype=float32)
Tensor("relu_conv_dec_2_2_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("cat_1_2/concat:0", shape=(?, 60, 60, 256), dtype=float32)
Tensor("relu_conv_dec_3_2_2/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("cat_2_2/concat:0", shape=(?, 120, 120, 192), dtype=float32

Building unet2d model...
Tensor("model_input_4:0", shape=(?, 240, 240, 14), dtype=float32)
Tensor("relu_conv_enc_1_2_3/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("maxpool_1_3/MaxPool:0", shape=(?, 120, 120, 32), dtype=float32)
Tensor("relu_conv_enc_2_2_3/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("maxpool_2_3/MaxPool:0", shape=(?, 60, 60, 64), dtype=float32)
Tensor("relu_conv_enc_3_2_3/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("maxpool_3_3/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_conv_center_3/Relu:0", shape=(?, 30, 30, 128), dtype=float32)
conv center... Tensor("add_center_3/add:0", shape=(?, 30, 30, 128), dtype=float32)
Tensor("relu_conv_dec_2_2_3/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("cat_1_3/concat:0", shape=(?, 60, 60, 256), dtype=float32)
Tensor("relu_conv_dec_3_2_3/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("cat_2_3/concat:0", shape=(?, 120, 120, 192), dtype=float32

In [36]:
# Copy each pretrained source network into its branch of the fusion model (branches 1..4).
branch_sources = (('1', model_t1), ('2', model_t2), ('3', model_fl), ('4', model_uad))
for branch_id, branch_src in branch_sources:
    model_main = transfer_weights(branch_src, model_main, branch_id)

Transferring weights -> b1_
Transferring weights -> b2_
Transferring weights -> b3_
Transferring weights -> b4_


In [37]:
# Print the layer-by-layer summary of the underlying Keras model (after the weight transfer above).
model_main.model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        (None, 240, 240, 35) 0                                            
__________________________________________________________________________________________________
t1_pre (Lambda)                 (None, 240, 240, 7)  0           model_input[0][0]                
__________________________________________________________________________________________________
t1_low (Lambda)                 (None, 240, 240, 7)  0           model_input[0][0]                
__________________________________________________________________________________________________
t2 (Lambda)                     (None, 240, 240, 7)  0           model_input[0][0]                
____________________________________________________________________________________________

In [None]:
# NOTE(review): stale experiment. transfer_weights builds its keyword as 'b{}_'.format(branch_num),
# so passing 'fbst_' would search for layers containing 'bfbst__' and an output layer named
# 'branchfbst__output'. Update or delete before re-enabling.
# model_main = transfer_weights(model_t1, model_main, 'fbst_')

In [None]:
# Evaluation case. The input volume and its UAD mask are both derived from CASE_ID,
# so switching cases cannot leave the two paths pointing at different subjects.
CASE_ID = 'NO6'
PREPROCESS_DIR = '/home/srivathsa/projects/studies/gad/tiantan/preprocess'

prediction_generator = SliceLoader(
    data_list=[os.path.join(PREPROCESS_DIR, 'data_t2_fl', '{}.h5'.format(CASE_ID))],
    file_ext='h5',
    input_idx=[0, 1, 3, 4],  # NOTE(review): presumably indices into the .h5 contrast stack — confirm
    output_idx=[2],
    batch_size=1,
    shuffle=False,           # deterministic ordering, so indexed lookups match prediction order
    verbose=0,
    residual_mode=False,
    slices_per_input=7,
    num_channel_output=1,
    resize=240,
    slice_axis=[0],
    brain_only=False,
    predict=False,           # NOTE(review): presumably keeps targets in each batch (unpacked as `net_ip, op` later) — confirm
    use_enh_uad=True,
    use_uad_ch_input=True,
    uad_ip_channels=7,
    brain_only_mode=None,
    enh_mask_t2=False,
    enh_pfactor=1,
    fpath_uad_masks=[os.path.join(PREPROCESS_DIR, 'uad_masks_fl', '{}.npy'.format(CASE_ID))],
    uad_file_ext='npy',
)

In [None]:
# Single-process, in-thread data loading for inference.
predict_kwargs = dict(
    max_queue_size=16,
    workers=1,
    use_multiprocessing=False,
    verbose=0,
)
y_pred = model_main.model.predict_generator(generator=prediction_generator, **predict_kwargs)
print(y_pred.shape)

In [None]:
BATCH_IDX = 131
lname = 'fusion_boost_avg'

# Fetch one (input, target) batch; subscripting dispatches to __getitem__.
net_ip, op = prediction_generator[BATCH_IDX]
print(net_ip.shape)
print(op.shape)

# keract returns a mapping keyed by layer name; keep just the requested layer's array.
activations = get_activations(model_main.model, net_ip, layer_name=lname)[lname]
print(activations.shape)

In [None]:
# Side by side for batch item 0: input channel 15, activation channel 0, and target channel 0.
fig, ax = plt.subplots()
ax.imshow(np.hstack([net_ip[0, ..., 15], activations[0, ..., 0], op[0, ..., 0]]))
ax.set_title('input[..., 15] | activation[..., 0] | target[..., 0]')
plt.show()

In [None]:
# Activation map alone (batch item 0, channel 0) at full size.
fig, ax = plt.subplots()
ax.imshow(activations[0, ..., 0])
ax.set_title('activation[0, ..., 0]')
plt.show()

In [31]:
clear_keras_memory()
set_keras_memory(1.0)

loss_function = suloss.mixed_loss(l1_lambda=1.0, ssim_lambda=0.0)
metrics_monitor = [suloss.l1_loss]

# Uncompiled 35-channel branch UNet, 'simple_weighted' configuration.
branch_kwargs = dict(
    num_channel_input=35,
    num_channel_output=1,
    img_rows=240,
    img_cols=240,
    loss_function=loss_function,
    metrics_monitor=metrics_monitor,
    lr_init=0.001,
    verbose=1,
    compile_model=False,
)
model_br = GeneratorBranchUNet2D(model_config='simple_weighted', **branch_kwargs)

Building branch_unet2d-simple_weighted model...
inputs Tensor("model_input:0", shape=(?, 240, 240, 35), dtype=float32)
t1 pre Tensor("t1_pre/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
t1 low Tensor("t1_low/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
t2 Tensor("t2/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
fl Tensor("fl/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
uad Tensor("uad/strided_slice:0", shape=(?, 240, 240, 7), dtype=float32)
Tensor("relu_b1_conv_enc_1_2/Relu:0", shape=(?, 240, 240, 32), dtype=float32) Tensor("b1_maxpool_1/MaxPool:0", shape=(?, 120, 120, 32), dtype=float32)
Tensor("relu_b1_conv_enc_2_2/Relu:0", shape=(?, 120, 120, 64), dtype=float32) Tensor("b1_maxpool_2/MaxPool:0", shape=(?, 60, 60, 64), dtype=float32)
Tensor("relu_b1_conv_enc_3_2/Relu:0", shape=(?, 60, 60, 128), dtype=float32) Tensor("b1_maxpool_3/MaxPool:0", shape=(?, 30, 30, 128), dtype=float32)
conv center before add Tensor("relu_b1_conv_center/Relu:0",

In [21]:
# NOTE(review): `_updated_config` is a private Keras method. Its output has the same
# {'class_name': ..., 'config': {...}} structure as the to_json() string shown in the next
# cell, so json.loads(model_br.model.to_json()) is a public-API alternative — confirm before switching.
model_br.model._updated_config()

{'class_name': 'Model',
 'config': {'name': 'model_1',
  'layers': [{'name': 'model_input',
    'class_name': 'InputLayer',
    'config': {'batch_input_shape': (None, 240, 240, 35),
     'dtype': 'float32',
     'sparse': False,
     'name': 'model_input'},
    'inbound_nodes': []},
   {'name': 't1_pre',
    'class_name': 'Lambda',
    'config': {'name': 't1_pre',
     'trainable': True,
     'dtype': 'float32',
     'function': ('4wEAAAAAAAAAAQAAAAUAAAATAAAAcxIAAAB8AGQBiABkAIgBhQNmAhkAUwApAk4uqQApAdoCaXAp\nAtoDaWR42ghudW1fbW9kc3IBAAAA+k8vaG9tZS9zcml2YXRoc2EvcHJvamVjdHMvU3VidGxlR2Fk\nL3RyYWluL3N1YnRsZS9kbm4vZ2VuZXJhdG9ycy9icmFuY2hfdW5ldDJkLnB52gg8bGFtYmRhPqwA\nAABzAAAAAA==\n',
      None,
      (4, 5)),
     'function_type': 'lambda',
     'output_shape': None,
     'output_shape_type': 'raw',
     'arguments': {}},
    'inbound_nodes': [[['model_input', 0, 0, {}]]]},
   {'name': 't1_low',
    'class_name': 'Lambda',
    'config': {'name': 't1_low',
     'trainable': True,
     'dtype'

In [22]:
# Serialize the architecture to a JSON string. Lambda layers (t1_pre, t1_low, ...) are stored
# with "function_type": "lambda" and an encoded blob in "function" (see output).
# NOTE(review): such lambda blobs are generally not portable across Python versions — confirm
# before relying on this JSON to rebuild the model elsewhere.
model_br.model.to_json()

'{"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "model_input", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 240, 240, 35], "dtype": "float32", "sparse": false, "name": "model_input"}, "inbound_nodes": []}, {"name": "t1_pre", "class_name": "Lambda", "config": {"name": "t1_pre", "trainable": true, "dtype": "float32", "function": ["4wEAAAAAAAAAAQAAAAUAAAATAAAAcxIAAAB8AGQBiABkAIgBhQNmAhkAUwApAk4uqQApAdoCaXAp\\nAtoDaWR42ghudW1fbW9kc3IBAAAA+k8vaG9tZS9zcml2YXRoc2EvcHJvamVjdHMvU3VidGxlR2Fk\\nL3RyYWluL3N1YnRsZS9kbm4vZ2VuZXJhdG9ycy9icmFuY2hfdW5ldDJkLnB52gg8bGFtYmRhPqwA\\nAABzAAAAAA==\\n", null, [4, 5]], "function_type": "lambda", "output_shape": null, "output_shape_type": "raw", "arguments": {}}, "inbound_nodes": [[["model_input", 0, 0, {}]]]}, {"name": "t1_low", "class_name": "Lambda", "config": {"name": "t1_low", "trainable": true, "dtype": "float32", "function": ["4wEAAAAAAAAAAQAAAAUAAAATAAAAcxIAAAB8AGQBiABkAIgBhQNmAhkAUwApAk4uqQApAdoCaXAp\\nAt