# Preprocessing xVertSeg



## Imports and Constants, etc.

In [None]:
import datetime
import importlib
import keras
from keras.layers import (Dense, SimpleRNN, Input, Conv1D, 
                          LSTM, GRU, AveragePooling3D, MaxPooling3D, GlobalMaxPooling3D,
                          Conv3D, UpSampling3D, BatchNormalization, Concatenate, Add)
from keras.models import Model
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import projd
import random
import re
import scipy
import shutil
import SimpleITK
import sys
from sklearn.model_selection import train_test_split
import uuid

import matplotlib.pyplot as plt # data viz
import seaborn as sns # data viz

import imageio # display animated volumes
from IPython.display import Image # display animated volumes

from IPython.display import SVG # visualize model
from keras.utils.vis_utils import model_to_dot # visualize model

# Make the project's local modules importable from this notebook.
# NOTE(review): projd.cwd_token_dir('notebooks') presumably resolves the project root from the
# current working directory — confirm against the projd documentation.
src_dir = str(Path(projd.cwd_token_dir('notebooks')) / 'src') # $PROJECT_ROOT/src
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Import each local module and reload it immediately so that edits made to the source files
# are picked up when this cell is re-run in a live kernel.
import util
importlib.reload(util)
import preprocessing
importlib.reload(preprocessing)
import datagen
importlib.reload(datagen)
import modelutil
importlib.reload(modelutil)

# Run-wide settings.
SEED = 0
EPOCHS = 100
BATCH_SIZE = 1
PATCH_SHAPE = (32, 32, 32)

# Used to name the tensorboard log directory below and passed to the modelutil calls.
MODEL_NAME = 'model_09'

# Input data locations.
DATA_DIR = Path('/data2').expanduser()
NORMAL_SCANS_DIR = DATA_DIR / 'uvmmc/nifti_normals'
PROJECT_DATA_DIR = DATA_DIR / 'uvm_deep_learning_project'
PP_IMG_DIR = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed' # preprocessed scans dir
PP_MD_PATH = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed_metadata.pkl'

PP_XVERTSEG_IMG_DIR = PROJECT_DATA_DIR / 'xVertSeg.v1' / 'preprocessed' # preprocessed scans dir
PP_XVERTSEG_PATH = PROJECT_DATA_DIR / 'xVertSeg.v1' / 'preprocessed_metadata.pkl'

# Output locations for models, logs and scratch files.
MODELS_DIR = PROJECT_DATA_DIR / 'models'
LOG_DIR = PROJECT_DATA_DIR / 'log'
TENSORBOARD_LOG_DIR = PROJECT_DATA_DIR / 'tensorboard' / MODEL_NAME
TMP_DIR = DATA_DIR / 'tmp'

# Create every directory the notebook reads from or writes to.  exist_ok=True makes the cell
# safe to re-run and avoids the race between a separate exists() check and mkdir().
for d in [DATA_DIR, NORMAL_SCANS_DIR, PROJECT_DATA_DIR, PP_IMG_DIR, MODELS_DIR, LOG_DIR,
          TENSORBOARD_LOG_DIR, TMP_DIR, PP_MD_PATH.parent, PP_XVERTSEG_IMG_DIR, PP_XVERTSEG_PATH.parent]:
    d.mkdir(parents=True, exist_ok=True)
        
%matplotlib inline
sns.set()

%load_ext autoreload
%autoreload 2

## Read Data

In [None]:
# Root of the raw xVertSeg.v1 dataset; get_xvertseg_infos() looks for Data1/ and Data2/ under it.
XVERTSEG_DIR = DATA_DIR / 'xVertSeg.v1'

def get_mhd_raw_id(d):
    '''
    List the paired MetaImage files in a directory.

    d: a Path (or path string), a directory containing paired MetaImage format files
       named like image016.mhd / image016.raw.
    returns: a triple (mhds, raws, ids), ordered by filename:
        mhds: list of .mhd header paths, as strings.
        raws: list of the matching .raw data paths, as Path objects.
        ids: list of int scan ids, taken from the 3 digits just before '.mhd'.
    raises: ValueError if an .mhd filename does not end in 3 digits followed by '.mhd'.
    '''
    mhds, raws, ids = [], [], []
    # glob() order is filesystem dependent; sort so that results are reproducible.
    for mhd_path in sorted(Path(d).glob('*.mhd')):
        match = re.search(r'(\d\d\d)\.mhd$', mhd_path.name)
        if match is None:
            raise ValueError(f'Cannot extract a 3-digit scan id from {mhd_path}')

        mhds.append(str(mhd_path))
        # glob() results already include the directory, so derive the .raw path from the
        # .mhd path itself.  Joining it onto d again would double a relative directory.
        raws.append(mhd_path.with_suffix('.raw'))
        ids.append(int(match.group(1)))

    return mhds, raws, ids


def get_xvertseg_infos(xvertseg_dir):
    '''
    Index the xVertSeg files into a dataframe with one row per scan.

    Columns: id, image_mhd, image_raw, mask_mhd, mask_raw and dataset.
    id is the number embedded in the xvertseg filenames (e.g. image016.mhd, mask001.raw).
    dataset is 'data1' or 'data2':
      - Data1 is the labeled set (images and segmentation masks; 15 scans expected).
      - Data2 is the unlabeled test set (images only; 10 scans expected), so its
        mask_mhd and mask_raw values are missing (NaN).
    Data1 images are joined to their masks on id with the default (inner) merge, so an
    image without a mask is dropped.

    xvertseg_dir: the xVertSeg.v1 directory containing Data1 and Data2, as a Path.
    return: dataframe sorted by id, with a fresh 0..n-1 index.
    '''
    def image_frame(images_dir):
        # One row per image, keyed by scan id.
        mhds, raws, ids = get_mhd_raw_id(images_dir)
        return pd.DataFrame({'id': ids, 'image_mhd': mhds, 'image_raw': raws})

    labeled_root = xvertseg_dir / 'Data1'
    unlabeled_root = xvertseg_dir / 'Data2'

    labeled_images = image_frame(labeled_root / 'images')
    mask_mhds, mask_raws, mask_ids = get_mhd_raw_id(labeled_root / 'masks')
    masks = pd.DataFrame({'id': mask_ids, 'mask_mhd': mask_mhds, 'mask_raw': mask_raws})
    unlabeled = image_frame(unlabeled_root / 'images')
    unlabeled['dataset'] = ['data2'] * len(unlabeled)

    labeled = labeled_images.merge(masks, on='id')
    labeled['dataset'] = ['data1'] * len(labeled)

    combined = pd.concat([labeled, unlabeled])
    return combined.sort_values('id').reset_index(drop=True)




In [None]:
# Index every xVertSeg scan (file paths, id, dataset) into a dataframe sorted by id.
df = get_xvertseg_infos(XVERTSEG_DIR)

## Visualize Data


In [None]:
def load_xvertseg_img(path):
    '''
    Read a MetaImage (.mhd) file with SimpleITK.

    Approach borrowed from
    https://github.com/juliandewit/kaggle_ndsb2017/blob/master/step1_preprocess_luna16.py

    path: path of the .mhd header file.
    returns: (array, itk_image), where array is the voxel data as a numpy array and
             itk_image is the SimpleITK image, which carries origin, spacing and direction.
    '''
    itk_image = SimpleITK.ReadImage(path)
    return SimpleITK.GetArrayFromImage(itk_image), itk_image




### Look at an image and mask

In [None]:
# Load the image volume of the first row of df (the lowest scan id, since df is sorted by id).
img, itk = load_xvertseg_img(df.loc[0, 'image_mhd'])

In [None]:
# Animate a cropped region of the image volume.
# NOTE(review): the meaning of crop=(...) and step is defined by util.animate_crop, not shown here.
util.animate_crop(img, crop=(0.0, 1, 0.5, 0.8, 0.3, 0.6), step=20)

In [None]:
# Load the segmentation mask that belongs to the same (first) row of df.
mask, mitk = load_xvertseg_img(df.loc[0, 'mask_mhd'])

In [None]:
# Animate the mask with the same crop and step as the image above, for side-by-side comparison.
util.animate_crop(mask, crop=(0.0, 1, 0.5, 0.8, 0.3, 0.6), step=20)

### Look at mask

The mask has 6 unique values: 0, 200, 210, 220, 230 and 240.  These presumably correspond to the background (0) and the five lumbar vertebrae L1, L2, ..., L5 (200 through 240), although this has not been confirmed.


In [None]:
# List the distinct label values that occur in the mask.
np.unique(mask.ravel())

In [None]:
# Histogram of the mask label values.
plt.hist(mask.ravel())
plt.show()
# Histogram of the image intensities.  It resembles a typical CT scan in Hounsfield units, but
# there are no values near -1000 (air), so the units look like Hounsfield + 1000.
plt.hist(img.ravel(), bins=50)
plt.show()

## Resample/Resize Data

In [None]:
# https://github.com/juliandewit/kaggle_ndsb2017/blob/master/step1_preprocess_luna16.py


def normalize_xvertseg_image(img):
    '''
    Scale voxel intensities linearly from [MIN_BOUND, MAX_BOUND] = [0, 2000] to [0, 1],
    clipping values that fall outside that range.

    xVertSeg intensities appear to be Hounsfield units offset by +1000, which would make
    [0, 2000] correspond to roughly [-1000, 1000] HU (air: -1000, water: 0, bone: 200 and
    up; https://en.wikipedia.org/wiki/Hounsfield_scale).  This offset is an observation
    from the image histograms, not something the dataset documents here.

    img: an xvertseg image, as a numpy array (or array-like).  It is not modified.
    returns: a new floating point array of the same shape with values in [0, 1].
    '''
    MIN_BOUND = 0.0
    MAX_BOUND = 2000.0
    scaled = (np.asarray(img) - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    return np.clip(scaled, 0.0, 1.0)




def plot_image_historgrams():
    '''
    Show one 256-bin intensity histogram per xVertSeg image found under XVERTSEG_DIR.
    Each figure is titled with the scan id.
    '''
    scans = get_xvertseg_infos(XVERTSEG_DIR)
    for mhd_path, scan_id in zip(scans['image_mhd'], scans['id']):
        volume_zyx, _ = load_xvertseg_img(mhd_path)
        volume = np.swapaxes(volume_zyx, 0, 2)  # swap z and x.
        plt.hist(volume.ravel(), 256)
        plt.title('image histogram for id ' + str(scan_id))
        plt.show()
        

def get_preprocessed_xvertseg_image_path(id, preprocessed_dir):
    '''
    Build the path of the preprocessed image file for a scan.

    id: integer scan id; it is zero-padded to 3 digits in the filename.
    preprocessed_dir: directory holding the preprocessed .npy files.
    returns: the path as a string, e.g. <preprocessed_dir>/image007.npy.
    '''
    filename = 'image{:03}.npy'.format(id)
    return str(Path(preprocessed_dir) / filename)


def get_preprocessed_xvertseg_mask_path(id, preprocessed_dir):
    '''
    Build the path of the preprocessed mask file for a scan.

    id: integer scan id; it is zero-padded to 3 digits in the filename.
    preprocessed_dir: directory holding the preprocessed .npy files.
    returns: the path as a string, e.g. <preprocessed_dir>/mask007.npy.
    '''
    filename = 'mask{:03}.npy'.format(id)
    return str(Path(preprocessed_dir) / filename)


def resample_xvertseg_test(data_dir, out_dir, metadata_only=False):
    '''
    Trial run of the xVertSeg preprocessing on the first scan only: resample the image
    (and its mask, if it has one) to 1mm isotropic spacing, normalize the image
    intensities and save the image as .npy.

    data_dir: Path to an xVertSeg.v1 dir.
    out_dir: where to save the preprocessed image.
    metadata_only: if True, skip normalization and saving; the flag is also passed through
        to preprocessing.resample_image.
    returns: the infos dataframe, with pp_image_* (and pp_mask_*) columns holding the
        resampled spacing and shape of the processed scan.

    NOTE(review): the mask is resampled with preprocessing.resample_image and is not saved,
    as in the original code — confirm whether preprocessing.resample_mask and a save to
    get_preprocessed_xvertseg_mask_path() are wanted here.
    '''
    infos = get_xvertseg_infos(data_dir)
    target_spacing = (1.0, 1.0, 1.0)

    # Deliberately limited to the first scan (use range(len(infos)) for all of them).
    for i in range(min(1, len(infos))):

        img_zyx, itk = load_xvertseg_img(infos.loc[i, 'image_mhd'])
        img = np.swapaxes(img_zyx, 0, 2) # swap z and x.

        origin = np.array(itk.GetOrigin())      # x,y,z  Origin in world coordinates (mm)
        spacing = np.array(itk.GetSpacing())    # spacing of voxels in world coor. (mm)
        direction = np.array(itk.GetDirection())
        print('img shape:', img.shape)
        print('img origin:', origin)
        print('img spacing:', spacing)
        print('img direction:', direction)

        # resample image
        print('image spacing:', spacing)
        print('new spacing:', target_spacing)
        resampled_img, resampled_spacing = preprocessing.resample_image(
            img, spacing, target_spacing, metadata_only=metadata_only)
        print('resampled image spacing:', resampled_spacing)
        print('resampled image shape:', resampled_img.shape)
        for axis in range(3):
            infos.loc[i, f'pp_image_pixdim{axis}'] = resampled_spacing[axis]
        for axis in range(3):
            infos.loc[i, f'pp_image_dim{axis}'] = resampled_img.shape[axis]

        # Normalize voxel intensities
        if not metadata_only:
            normalized_img = normalize_xvertseg_image(resampled_img)
            print('Normalized image shape:', normalized_img.shape)

        # save processed image
        path = get_preprocessed_xvertseg_image_path(infos.loc[i, 'id'], out_dir)
        print(f'Saving preprocessed image to {path}.')
        if not metadata_only:
            np.save(path, normalized_img)

        # resample mask.  Only labeled (Data1) scans have one; mask_mhd is NaN otherwise.
        if pd.notna(infos.loc[i, 'mask_mhd']):
            mimg_zyx, mitk = load_xvertseg_img(infos.loc[i, 'mask_mhd'])
            mimg = np.swapaxes(mimg_zyx, 0, 2)
            mask_spacing = np.array(mitk.GetSpacing()) # xyz spacing
            print('unique mask values:', np.unique(mimg.ravel()))
            resampled_mask, resampled_spacing = preprocessing.resample_image(
                mimg, mask_spacing, target_spacing, metadata_only=metadata_only)
            print('resampled image spacing:', resampled_spacing)
            print('resampled image shape:', resampled_mask.shape)
            for axis in range(3):
                infos.loc[i, f'pp_mask_pixdim{axis}'] = resampled_spacing[axis]
            for axis in range(3):
                infos.loc[i, f'pp_mask_dim{axis}'] = resampled_mask.shape[axis]

    return infos
        
        
def resample_xvertseg_image_test(data_dir, dest_dir, num=None, metadata_only=False):
    '''
    Resample the first num xVertSeg images to 1mm isotropic spacing, normalize their
    intensities and save them as .npy files.  Masks are not processed.

    data_dir: Path to an xVertSeg.v1 dir.
    dest_dir: where to save preprocessed images.
    num: how many scans to process, in id order.  None (the default) means all of them;
        larger values are clamped to the number of scans available.
    metadata_only: if True, skip normalization and saving; the flag is also passed through
        to preprocessing.resample_image.
    returns: the infos dataframe, with pp_image_* columns holding the resampled spacing and
        shape of every processed scan.
    '''
    infos = get_xvertseg_infos(data_dir)
    num = len(infos) if num is None else min(num, len(infos))
    target_spacing = (1.0, 1.0, 1.0)

    for i in range(num):
        print('i:', i)

        img_zyx, itk = load_xvertseg_img(infos.loc[i, 'image_mhd'])
        img = np.swapaxes(img_zyx, 0, 2) # swap z and x.

        origin = np.array(itk.GetOrigin())      # x,y,z  Origin in world coordinates (mm)
        spacing = np.array(itk.GetSpacing())    # spacing of voxels in world coor. (mm)
        direction = np.array(itk.GetDirection())
        print('img shape:', img.shape)
        print('img origin:', origin)
        print('img spacing:', spacing)
        print('img direction:', direction)

        # resample image
        print('image spacing:', spacing)
        print('new spacing:', target_spacing)
        resampled_img, resampled_spacing = preprocessing.resample_image(
            img, spacing, target_spacing, metadata_only=metadata_only)
        print('resampled image spacing:', resampled_spacing)
        print('resampled image shape:', resampled_img.shape)
        for axis in range(3):
            infos.loc[i, f'pp_image_pixdim{axis}'] = resampled_spacing[axis]
        for axis in range(3):
            infos.loc[i, f'pp_image_dim{axis}'] = resampled_img.shape[axis]

        # Normalize voxel intensities
        if not metadata_only:
            normalized_img = normalize_xvertseg_image(resampled_img)
            print('Normalized image shape:', normalized_img.shape)

        # save processed image
        print(infos.head())
        path = get_preprocessed_xvertseg_image_path(infos.loc[i, 'id'], dest_dir)
        print(f'Saving preprocessed image to {path}.')
        if not metadata_only:
            np.save(path, normalized_img)

    # Return only after the loop.  Returning inside it stopped after the first scan.
    return infos
    
    


In [None]:
# Dry run requesting the first 2 scans: with metadata_only=True the function skips
# normalizing and saving the image and only records the resampled spacing and shape.
infos = resample_xvertseg_image_test(XVERTSEG_DIR, PP_XVERTSEG_IMG_DIR, num=2, metadata_only=True)

In [None]:
# Rebuild the scan index and plot the intensity histogram of `img`, the volume loaded in an
# earlier cell from row 0 of df.  The title is 'infos' followed by the id of row 0.
infos = get_xvertseg_infos(XVERTSEG_DIR)
plt.hist(img.ravel())
plt.title('infos' + str(infos.loc[0, 'id']))
plt.show()


In [None]:
# Resample the mask from its original spacing to 1mm isotropic spacing using several spline
# orders (3, 5 and 0), to see whether a different order reduces the blurring that resizing
# introduces into the mask.
spacing = np.array(mitk.GetSpacing())
target_spacing = (1.0, 1.0, 1.0)
nmask, nspacing = preprocessing.resample_mask(mask, spacing, target_spacing, order=3)
nmask5, nspacing5 = preprocessing.resample_mask(mask, spacing, target_spacing, order=5)
nmask0, nspacing0 = preprocessing.resample_mask(mask, spacing, target_spacing, order=0)
# original spacing -> requested spacing -> spacing actually obtained with order=3
print(spacing, '->', target_spacing, '->', nspacing)

In [None]:
# Compare the original and the order=3 resampled mask shapes, then animate the resampled mask.
print(mask.shape)
print(nmask.shape)
display(util.animate_crop(nmask, step=30))

In [None]:
# Same view for the order=5 resampled mask.
print(nmask5.shape)
display(util.animate_crop(nmask5, step=30))

In [None]:
# Same view for the order=0 resampled mask.
print(nmask0.shape)
display(util.animate_crop(nmask0, step=30))

In [None]:
# Value distribution of the order=3 resampled mask: all voxels, then only the non-zero ones.
rav = nmask.ravel()
prav = rav[rav > 0]

plt.hist(rav, bins=256) # looks the same
plt.show()
plt.hist(prav, bins=256) # only > 0.  spikes at 200, 210, 220, 230, 240.  Also a smaller spike at 255.
plt.show()
print(pd.unique(rav)) # except every number from 0 to 256 is present
plt.hist(pd.unique(rav), bins=256) # histogram of the distinct values: every number is present.
plt.show()

In [None]:
# Same value histograms for the order=0 and order=5 resampled masks: all voxels, then only
# the non-zero ones.
# NOTE(review): the inline observations below are worded identically to the order=3 cell;
# confirm that they were actually checked for these two orders.
rav0 = nmask0.ravel()
prav0 = rav0[rav0 > 0]
rav5 = nmask5.ravel()
prav5 = rav5[rav5 > 0]
plt.hist(rav0, bins=256) # looks the same
plt.show()
plt.hist(prav0, bins=256) # only > 0.  spikes at 200, 210, 220, 230, 240.  Also a smaller spike at 255.
plt.show()
plt.hist(rav5, bins=256) # looks the same
plt.show()
plt.hist(prav5, bins=256) # only > 0.  spikes at 200, 210, 220, 230, 240.  Also a smaller spike at 255.
plt.show()


In [None]:
# Plot the intensity histogram of every xVertSeg image, one figure per scan.
plot_image_historgrams()


## Normalization

## Metadata

## Preprocess xVertSeg

In [None]:
def build_residual_encoder_decoder_block(x, n_a, n_d=1, use_bn=True):
    '''
    Recursively build a u-net style encoder/decoder of depth n_d.

    Each level applies a conv block, pools down by 2, recurses one level deeper, upsamples
    back by 2, concatenates the result with the level's own conv output (skip connection)
    and applies a final conv block.  At n_d <= 0 only the first conv block is applied.

    x: input tensor.
    n_a: number of filters of every conv block.
    n_d: number of pooling/upsampling levels.
    use_bn: whether the conv blocks are preceded by batch normalization.
    returns: the output tensor.
    '''
    skip = batchnorm_conv_block(x, n_a, use_bn=use_bn)
    if n_d <= 0:
        return skip

    pooled = MaxPooling3D(padding='same')(skip)  # halves each spatial dimension
    inner = build_residual_encoder_decoder_block(pooled, n_a, n_d - 1, use_bn=use_bn)
    upsampled = UpSampling3D()(inner)  # back to the spatial size of `skip`
    merged = Concatenate()([skip, upsampled])  # channels: n_a + n_a
    return batchnorm_conv_block(merged, n_a, use_bn=use_bn)


def batchnorm_conv_block(x, n_a, use_bn=True):
    '''
    Optional batch normalization followed by a 3x3x3 'same'-padded ReLU convolution.

    x: input tensor.
    n_a: number of convolution filters (output channels).
    use_bn: whether to apply BatchNormalization before the convolution.
    returns: the output tensor; spatial size is unchanged, channels become n_a.
    '''
    conv_input = BatchNormalization()(x) if use_bn else x
    conv = Conv3D(n_a, kernel_size=(3, 3, 3), padding='same', activation='relu')
    return conv(conv_input)


def build_residual_block(x, n_a, n_l=1, use_bn=True):
    '''
    Residual block: add the input to the output of n_l stacked conv blocks.

    x: input tensor.  It is added element-wise to the residual path, so its channel count
       has to match n_a.
    n_a: number of filters of each convolution in the residual path.
    n_l: number of layers/convolutions in the residual path.
    use_bn: whether each convolution is preceded by batch normalization.
    returns: the output tensor, x + residual_path(x).
    '''
    residual = x
    for _ in range(n_l):
        residual = batchnorm_conv_block(residual, n_a, use_bn=use_bn)

    return Add()([x, residual])


def build_downsampling_conv_block(x, n_a, use_bn=True):
    '''
    Optional batch normalization followed by a stride-2 3x3x3 ReLU convolution, which halves
    each spatial dimension ('same' padding) in place of a pooling layer.

    x: input tensor.
    n_a: number of convolution filters (output channels).
    use_bn: whether to apply BatchNormalization before the convolution.
    returns: the downsampled output tensor.
    '''
    conv_input = BatchNormalization()(x) if use_bn else x
    strided_conv = Conv3D(n_a, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding='same',
                          activation='relu')
    return strided_conv(conv_input)
    
    
def build_model(input_shape, n_a=16, n_r=2, n_d=4, use_bn=True):
    '''
    Build and compile a 3D convolutional binary classifier.

    The network has five stages.  Each stage halves the spatial dimensions with a strided
    convolution and then applies a residual block, using n_a * 2, 4, 8, 16 and 32 filters
    in turn; the last stage has a second residual block.  The features are then globally
    max-pooled, optionally batch-normalized, passed through a Dense(n_a * 16) ReLU layer
    and reduced to a single sigmoid output.

    Each halving reduces a 3D volume to 1/8th of its size.

    input_shape: shape of one input volume, including the channel axis.
    n_a: base number of filters.
    n_r, n_d: currently unused.  They belonged to a u-net / autoencoder branch that is
        disabled; they are kept so that existing calls keep working.
    use_bn: whether to use batch normalization throughout.
    returns: the model, compiled with adam, binary cross-entropy loss and accuracy metric.
    '''
    model_input = Input(shape=input_shape)
    features = model_input

    # Downsampling stages: strided conv followed by one residual block of the same width.
    for width_factor in (2, 4, 8, 16, 32):
        width = n_a * width_factor
        features = build_downsampling_conv_block(features, n_a=width, use_bn=use_bn)
        features = build_residual_block(features, n_a=width, n_l=1, use_bn=use_bn)

    # The deepest stage gets one extra residual block.
    features = build_residual_block(features, n_a=n_a * 32, n_l=1, use_bn=use_bn)

    # pool and predict
    features = GlobalMaxPooling3D()(features)
    if use_bn:
        features = BatchNormalization()(features)

    features = Dense(n_a * 16, activation='relu')(features)
    prediction = Dense(1, activation='sigmoid')(features)

    classifier = Model(inputs=model_input, outputs=prediction)
    classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return classifier
   
   

In [None]:
# Build the classifier for single-channel volumes of any spatial size (None dimensions),
# with 4 base filters and no batch normalization.  n_r and n_d are accepted by build_model
# but are not used by it.
model = build_model(input_shape=(None, None, None, 1,), n_a=4, n_r=4, n_d=4, use_bn=False)

In [None]:
# summary() prints the layer table itself and returns None, so wrapping it in print()
# only added a stray "None" line to the output.
model.summary()


In [None]:
# Render the model graph inline as SVG (model_to_dot builds the graph; 'dot' lays it out).
SVG(model_to_dot(model).create(prog='dot', format='svg'))


## Train and Evaluate Model

- Add callbacks to save model every 20 epochs and to log performance stats every epoch, so we have the results saved somewhere for charting.


In [None]:
# Earlier call, kept for reference: train the in-memory model for 40 epochs.
# history, log_path = modelutil.train_model(
#     model, train_gen, val_gen, epochs=40, batch_size=BATCH_SIZE, models_dir=MODELS_DIR, model_name=MODEL_NAME, 
#     log_dir=LOG_DIR, tensorboard_log_dir=TENSORBOARD_LOG_DIR, max_queue_size=20, use_multiprocessing=True, 
#     class_weight={0: 1, 1: 5})
# Current call: no model object is passed, only MODEL_NAME and MODELS_DIR, with epoch=40 and
# epochs=200.
# NOTE(review): presumably this resumes from the model saved at epoch 40 and trains up to
# epoch 200 — confirm against modelutil.train_model_epoch.
# NOTE(review): train_gen and val_gen are not defined anywhere in the visible part of this
# notebook; they must be created before this cell is run.
history, log_path = modelutil.train_model_epoch(train_gen, val_gen, epoch=40, epochs=200, batch_size=BATCH_SIZE, models_dir=MODELS_DIR, model_name=MODEL_NAME, 
    log_dir=LOG_DIR, tensorboard_log_dir=TENSORBOARD_LOG_DIR, max_queue_size=20, use_multiprocessing=True, 
    class_weight={0: 1, 1: 5})

## Visualize Training Progress

In [None]:
# Read the per-epoch metrics from a training log file.
# The path is hard-coded to one specific run, which overrides the log_path returned by the
# training cell above.  Older alternative, kept for reference:
# log_path = LOG_DIR / (model_name + '_2018-04-26T17:29:02.902740_log.csv')
log_path = Path('/data2/uvm_deep_learning_project/log/model_09_2018-04-28T02:02:18.169239_log.csv')
metrics = pd.read_csv(log_path)

In [None]:
# Print a thinned-out view of the log: every 10th row plus the final row (the final row
# appears twice when it is itself one of the every-10th rows).
print(pd.concat([metrics[::10], metrics[-1:]])) # every 10th metric and the last one

In [None]:
# Accuracy curves (training and validation), with the y axis pinned to the full 0..1 range.
axes = plt.gca()
axes.set_ylim([0.0, 1.0])
for accuracy_column in ("acc", "val_acc"):
    plt.plot(metrics[accuracy_column])
plt.legend(['Training Accuracy', "Validation Accuracy"])
plt.show()

# Loss curves (training and validation), on an automatically scaled y axis.
for loss_column in ("loss", "val_loss"):
    plt.plot(metrics[loss_column])
plt.legend(['Training Loss', "Validation Loss"])
plt.show()



### Confusion Matrix Results Over Time

Visualize how the results of the model improve over time.


In [None]:
# Confusion matrices on the validation generator for epochs 1, 10 and 200.
# NOTE(review): presumably the models saved at those epochs are loaded from MODELS_DIR by
# MODEL_NAME — confirm against modelutil.confusion_matrix_by_epochs.  val_gen is not defined
# in the visible part of this notebook.
modelutil.confusion_matrix_by_epochs(MODELS_DIR, MODEL_NAME, [1, 10, 200], val_gen)
    