# Visualize the bounding boxes predicted by the model

# REMOTE COMPUTER

In [1]:
import sys
sys.path.append('..')
from modules.mischelpers import get_data_splits_bbox, open_hdf5, createDense
from modules.configfile import config
import importlib
import optparse
import os
import logging
from mayavi import mlab
import numpy as np

# Configure the root logger so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
try:
    # When this code runs from a file, name the logger after that file.
    logger = logging.getLogger(os.path.basename(__file__))
except NameError:
    # `__file__` is not defined in an interactive / notebook session; fall back
    # to the module name. Only NameError is caught so that unrelated errors
    # are not silently swallowed.
    logger = logging.getLogger(__name__)

# Command-line style options. In this notebook they are fed an explicit
# argument list through parser.parse_args([...]).
parser = optparse.OptionParser()
parser.add_option('--dm', '--defmodelfile',
                  dest="defmodelfile",
                  default=None,
                  type='str',
                  help="name of the module inside the 'defmodel' package that defines the model"
                  )
parser.add_option('--i', '--in-name',
                  dest="input_name",
                  default=None,
                  type='str',
                  help="file name of the saved model, looked up inside the directory selected by --type"
                  )
parser.add_option('--t', '--type',
                  dest="type",
                  default=None,
                  type='str',
                  help="'snapshots' selects the model-snapshots directory; "
                       "any other value selects the model-checkpoints directory"
                  )

<Option at 0x7f849cb2d248: --t/--type>

### Set arguments

In [2]:
# Explicit argument list used in place of sys.argv: model definition module,
# saved-model file name, and the kind of directory the file lives in.
notebook_argv = [
    '--dm', 'cnn',
    '--i', 'original_tuning_dice.h501--0.00.h5',
    '--t', 'checkpoints',
]
options, remainder = parser.parse_args(notebook_argv)

In [3]:
# Resolve the saved-model path: 'snapshots' selects the snapshot directory,
# any other (or missing) type selects the checkpoint directory.
model_root = '/local-scratch/cedar-rm/scratch/asa224'
if options.type == 'snapshots':
    model_dir = os.path.join(model_root, 'model-snapshots')
else:
    model_dir = os.path.join(model_root, 'model-checkpoints')

if options.input_name is None:
    logger.info('No input name defined, using default values')
    options.input_name = 'model_default.h5'
# os.path.join returns an already-absolute name unchanged, so running this
# cell a second time does not prepend the directory twice.
options.input_name = os.path.join(model_dir, options.input_name)
logger.info('Name of input file: {}'.format(options.input_name))

if options.defmodelfile is None:
    logger.info('No defmodel file name defined, using default model (cnn)')
    options.defmodelfile = 'cnn'

# Import the model-definition module by name from the `defmodel` package.
modeldefmodule = importlib.import_module('defmodel.' + options.defmodelfile, package=None)

INFO:__main__:Name of input file: /local-scratch/cedar-rm/scratch/asa224/model-checkpoints/original_tuning_dice.h501--0.00.h5
Using TensorFlow backend.
INFO:cnn.pyc:Setting keras backend data format to "channels_first"


In [4]:
# Index bounds handed to get_data_splits_bbox: training bounds 0 and 190,
# test bounds 190 and None.
split_bounds = dict(train_start=0, train_end=190, test_start=190, test_end=None)
x_train, y_train, x_test, y_test = get_data_splits_bbox(
    config['hdf5_filepath_prefix'],
    **split_bounds
)

Load the saved `y_pred` predictions on the local computer to start the visualization.

In [5]:
# Hand the path directly to np.load: it opens the file itself in binary mode
# and closes it afterwards, so no text-mode handle is left open.
y_pred = np.load('/local-scratch/cedar-rm/scratch/asa224/model-checkpoints/y_pred.npy')

In [6]:
# Echo the loaded predictions to eyeball their values.
y_pred

array([[ -1.51396794e+11,   1.25387186e+11,  -4.67916841e+10,
          8.79261696e+10,  -2.14184919e+10,   9.34011863e+10],
       [ -1.36416715e+11,   1.12954319e+11,  -4.21589320e+10,
          7.92078828e+10,  -1.92917176e+10,   8.41234842e+10],
       [ -2.61060428e+11,   2.16307302e+11,  -8.08439316e+10,
          1.51682253e+11,  -3.69366057e+10,   1.61233076e+11],
       [ -1.78053169e+11,   1.47583894e+11,  -5.52360018e+10,
          1.03437279e+11,  -2.52157952e+10,   1.10026047e+11],
       [ -1.42638088e+11,   1.18143066e+11,  -4.41745736e+10,
          8.28660367e+10,  -2.01822638e+10,   8.80596337e+10],
       [ -2.15181328e+11,   1.78343166e+11,  -6.66850181e+10,
          1.25010305e+11,  -3.04671293e+10,   1.32928954e+11],
       [ -9.16650312e+10,   7.59635395e+10,  -2.84251361e+10,
          5.32652442e+10,  -1.29778647e+10,   5.66572237e+10],
       [ -8.73351414e+11,   7.24095795e+11,  -2.71122596e+11,
          5.07653652e+11,  -1.23664515e+11,   5.40160786e+11],


In [7]:
# Dimensions of the prediction array.
y_pred.shape

(20, 6)

In [8]:
# Mean squared error between the predictions and y_test.
np.mean(np.square(y_pred - y_test))

577.67652179459094

# Visualize interactively in individual Mayavi windows

In [8]:
# start 3D visualization
# Each test subject gets its own Mayavi scene showing the segmentation
# surface, the stored bounding box (red) and the predicted bounding box (green).
# open the dataset
hdf5_file = open_hdf5()
# NOTE(review): the handle is never closed in this cell - confirm what
# open_hdf5() returns before adding an explicit close().
first_test_idx = 190  # same value as the test_start passed to get_data_splits_bbox
# The upper bound is derived from y_pred so the dataset index `i` and the
# prediction row `t` always stay paired and y_pred[t] cannot run out of rows.
for t, i in enumerate(range(first_test_idx, first_test_idx + len(y_pred))):
    # let's get the segmentation mask for this test subject
    seg = hdf5_file["training_data_segmasks_hgg"][i]

    # let's get its bounding box
    bbox = hdf5_file["bounding_box_hgg"][i]

    dense_bbox_orig = createDense(bbox, seg)
    # the prediction row is cast to int32 before it is passed to createDense
    dense_bbox_pred = createDense(y_pred[t].astype('int32'), seg)

    mlab.contour3d(seg, contours=[1])
    mlab.contour3d(dense_bbox_orig, opacity=0.1, color=(1, 0, 0), transparent=True)
    mlab.contour3d(dense_bbox_pred, opacity=0.1, color=(0, 1, 0), transparent=True)
    mlab.show()

ValueError: Index (210) out of range (0-209)

# Visualize in Matplotlib using subplots and screenshots of 3D scenes

In [9]:
# start 3D visualization
import matplotlib.pyplot as plt
%matplotlib
# open the dataset
hdf5_file = open_hdf5()
t = 0
start_test = 190
end_test= 210

num_subplots = end_test - start_test - 1

fig, ax = plt.subplots(nrows=5, ncols=4, sharex=False, sharey=False, squeeze=False, figsize=(50, 30))

ax = [i for ls in ax for i in ls]

for i in range(start_test, end_test):
    # lets get a segmentation masks for the test data
    seg = hdf5_file["training_data_segmasks_hgg"][i]
    pat_name = hdf5_file["training_data_hgg_pat_name"][i]
    
    # let's get its bounding box
    bbox = hdf5_file["bounding_box_hgg"][i]

    dense_bbox_orig = createDense(bbox, seg)
    dense_bbox_pred = createDense(y_pred[t].astype('int32'), seg)
    
    # start the visualization process
    mfig = mlab.figure() # size=(1024, 1024)
    
    mlab.contour3d(seg, contours=[1])
    mlab.contour3d(dense_bbox_orig, opacity=0.1, color=(1, 0, 0), transparent=True)
    mlab.contour3d(dense_bbox_pred, opacity=0.1, color=(0, 1, 0), transparent=True)
    
    comparison = mlab.screenshot(figure=mfig, mode='rgba', antialiased=True)
    mlab.clf(mfig)
    mlab.close()

    # Then later in a matplotlib fig:
    ax[t].imshow(comparison)
    ax[t].axis('off')
    ax[t].set_title(pat_name)
    plt.suptitle('Test Subjects')
#     plt.tight_layout
    logger.info('.')
    t += 1
plt.show()

Using matplotlib backend: Qt4Agg


INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
INFO:__main__:.
