# Visualize the bounding boxes predicted by the model

# REMOTE COMPUTER

In [None]:
import sys
sys.path.append('..')
from modules.mischelpers import get_data_splits_bbox, open_hdf5, createDense
from modules.configfile import config
import importlib
import optparse
import os
import logging
from mayavi import mlab
import numpy as np

logging.basicConfig(level=logging.INFO)

# Name the logger after this file when run as a script. Inside a notebook
# __file__ is not defined (NameError), so fall back to the module name.
try:
    logger = logging.getLogger(os.path.basename(__file__))
except NameError:
    logger = logging.getLogger(__name__)

parser = optparse.OptionParser()
parser.add_option('--dm', '--defmodelfile',
                  dest="defmodelfile",
                  default=None,
                  type='str',
                  help="module name under the defmodel package that defines "
                       "the model (falls back to 'cnn' when omitted)",
                  )
parser.add_option('--i', '--in-name',
                  dest="input_name",
                  default=None,
                  type='str',
                  help="model file name, joined onto the snapshot or "
                       "checkpoint directory selected by --t",
                  )
parser.add_option('--t', '--type',
                  dest="type",
                  default=None,
                  type='str',
                  help="'snapshots' loads a snapshot; any other value "
                       "(e.g. 'checkpoints') loads a Keras checkpoint",
                  )

### Set arguments

In [None]:
# Arguments are hard-wired here because a notebook has no real command line.
notebook_argv = [
    '--dm', 'cnn',
    '--i', 'original_tuning_dice.h501--0.00.h5',
    '--t', 'checkpoints',
]
options, remainder = parser.parse_args(notebook_argv)

In [None]:
# Resolve the model file to load. Snapshots and checkpoints live in separate
# directories; any --t value other than 'snapshots' selects the checkpoint
# directory. A relative --i name is joined onto that directory; with no --i,
# 'model_default.h5' in that directory is used.
SNAPSHOT_DIR = '/local-scratch/cedar-rm/scratch/asa224/model-snapshots/'
CHECKPOINT_DIR = '/local-scratch/cedar-rm/scratch/asa224/model-checkpoints/'
model_dir = SNAPSHOT_DIR if options.type == 'snapshots' else CHECKPOINT_DIR

if options.input_name is None:
    logger.info('No input name defined, using default values')
    options.input_name = os.path.join(model_dir, 'model_default.h5')
else:
    options.input_name = os.path.join(model_dir, options.input_name)
logger.info('Name of input file: {}'.format(options.input_name))

if options.defmodelfile is None:
    logger.info('No defmodel file name defined, using default model (cnn)')
    options.defmodelfile = 'cnn'

# Model-definition module, e.g. defmodel.cnn; used below to reopen snapshots.
modeldefmodule = importlib.import_module('defmodel.' + options.defmodelfile, package=None)

In [None]:
# config['hdf5_filepath_prefix'] = '/local-scratch/cedar-rm/scratch/asa224/Datasets/BRATS2017/BRATS_vanilla.h5'

In [None]:
# Split boundaries by sample index: [0, 190) for training, 190 onward for test
# (test_end=None presumably means "through the last sample" — TODO confirm in
# modules.mischelpers.get_data_splits_bbox).
split_bounds = dict(train_start=0, train_end=190, test_start=190, test_end=None)
hdf5_path = config['hdf5_filepath_prefix']
x_train, y_train, x_test, y_test = get_data_splits_bbox(hdf5_path, **split_bounds)

In [None]:
if options.type == 'snapshots':
    model, hyparams, history = modeldefmodule.open_model_with_hyper_and_history(name=options.input_name)
else:
    from keras.models import load_model
    import keras.backend as K
    import tensorflow as tf

    def bbox_overlap_iou(y_true, y_pred):
        """Return the negative 3-D intersection-over-union of two box batches.

        Both tensors are split along axis 1 into six pieces, read as
        ``(xmin, xmax, ymin, ymax, zmin, zmax)``; presumably they are shaped
        (batch, 6) — TODO confirm against the model's output layer.

        The IoU is negated so that minimising this function as a Keras loss
        maximises box overlap. It is registered with ``load_model`` under the
        key ``'bbox_overlap_iou'`` so the checkpoint's custom objective can be
        resolved when the file is loaded.
        """
        xmin_true, xmax_true, ymin_true, ymax_true, zmin_true, zmax_true = tf.split(y_true, 6, axis=1)
        xmin_pred, xmax_pred, ymin_pred, ymax_pred, zmin_pred, zmax_pred = tf.split(y_pred, 6, axis=1)

        # Overlap length along each axis, clamped at 0: for disjoint boxes the
        # raw difference is negative, and two negative lengths would multiply
        # into a spurious positive intersection volume.
        dx = K.maximum(K.minimum(xmax_true, xmax_pred) - K.maximum(xmin_true, xmin_pred), 0.)
        dy = K.maximum(K.minimum(ymax_true, ymax_pred) - K.maximum(ymin_true, ymin_pred), 0.)
        dz = K.maximum(K.minimum(zmax_true, zmax_pred) - K.maximum(zmin_true, zmin_pred), 0.)

        intersection = dx * dy * dz

        # Box volumes: x-extent * y-extent * z-extent.
        vol_true = (xmax_true - xmin_true) * (ymax_true - ymin_true) * (zmax_true - zmin_true)
        vol_pred = (xmax_pred - xmin_pred) * (ymax_pred - ymin_pred) * (zmax_pred - zmin_pred)

        union = vol_true + vol_pred - intersection

        # K.epsilon() keeps degenerate (zero-volume) box pairs from giving NaN.
        iou = intersection / (union + K.epsilon())

        return -iou

    model = load_model(options.input_name, custom_objects={'bbox_overlap_iou': bbox_overlap_iou})

In [None]:
# Print the layer-by-layer architecture of the loaded model (Keras Model.summary).
model.summary()

In [None]:
# Evaluate the loaded model on the held-out split, one sample per batch.
# The reported loss is whatever objective the model was compiled with; for a
# checkpoint this is presumably bbox_overlap_iou (negative IoU) — TODO confirm.
model.evaluate(x_test, y_test, batch_size=1)

In [None]:
import numpy as np

# Predict one bounding box per held-out sample. The result has one row per
# entry of x_test, rather than a fixed 20 rows. A fixed size raised IndexError
# for larger test splits and left uninitialised rows for smaller ones.
# x_test is fed to the model unchanged, exactly as model.evaluate consumes it,
# so no hard-coded per-sample reshape is needed. float64 matches the dtype of
# the array previously written to y_pred.npy.
y_pred = np.asarray(model.predict(x_test, batch_size=1, verbose=1), dtype=np.float64)
logger.info('Predicted %d bounding boxes, array shape %s', y_pred.shape[0], y_pred.shape)

In [None]:
# Pass the path rather than a text-mode file handle: np.save writes binary
# data, so NumPy must open the file in binary mode, and it also closes it.
# The path already ends in '.npy', so NumPy does not append another suffix.
np.save('/local-scratch/cedar-rm/scratch/asa224/model-checkpoints/y_pred.npy', y_pred)