# Data Generation

Code to create Keras data generators from the preprocessed NIfTI scans from UVMMC.


## Imports and Constants, etc.

In [None]:
import datetime
import importlib
import keras
from keras.layers import (Dense, SimpleRNN, Input, Conv1D, 
                          LSTM, GRU, AveragePooling3D, Conv3D, 
                          UpSampling3D, BatchNormalization)
from keras.models import Model
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import projd
import random
import re
import scipy
import shutil
import sys
from sklearn.model_selection import train_test_split
import uuid

import matplotlib.pyplot as plt # data viz
import seaborn as sns # data viz

import imageio # display animated volumes
from IPython.display import Image # display animated volumes

from IPython.display import SVG # visualize model
from keras.utils.vis_utils import model_to_dot # visualize model

# for importing local code
# Append the project's src directory to sys.path (only once) so that the
# project-local modules below can be imported from this notebook.
src_dir = str(Path(projd.cwd_token_dir('notebooks')) / 'src') # $PROJECT_ROOT/src
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Each local module is reloaded right after import so that edits made to it
# are picked up when this cell is re-run.
import util
importlib.reload(util)
import preprocessing
importlib.reload(preprocessing)

# Run configuration.
SEED = 0
EPOCHS = 10
BATCH_SIZE = 1
PATCH_SHAPE = (32, 32, 32)  # shape of the random crop taken from each volume

MODEL_NAME = 'model_01'

# Filesystem layout: everything lives under DATA_DIR.
DATA_DIR = Path('/data2').expanduser()
NORMAL_SCANS_DIR = DATA_DIR / 'uvmmc/nifti_normals'
PROJECT_DATA_DIR = DATA_DIR / 'uvm_deep_learning_project'
PP_IMG_DIR = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed' # preprocessed scans dir
PP_MD_PATH = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed_metadata.pkl'

MODELS_DIR = PROJECT_DATA_DIR / 'models'
LOG_DIR = PROJECT_DATA_DIR / 'log'
TENSORBOARD_LOG_DIR = PROJECT_DATA_DIR / 'tensorboard'
TMP_DIR = DATA_DIR / 'tmp'

# Make sure every directory used by the notebook exists.  exist_ok=True
# replaces the separate exists() check, so an already-present directory is
# simply left alone and there is no gap between checking and creating.
for d in [
    DATA_DIR,
    NORMAL_SCANS_DIR,
    PROJECT_DATA_DIR,
    PP_IMG_DIR,
    MODELS_DIR,
    LOG_DIR,
    TENSORBOARD_LOG_DIR,
    TMP_DIR,
    PP_MD_PATH.parent,
]:
    d.mkdir(parents=True, exist_ok=True)
        
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
# Apply seaborn's default plot styling.
sns.set()


## Data Generators

Data generators yield batch-sized random samples of training and validation data.  We use the Keras analogue of a generator, `keras.utils.Sequence`.

In [None]:
def random_crop(img, shape):
    '''
    Randomly crop an image to a shape.  Location is chosen at random from
    all possible crops of the given shape.

    img: a volume to crop.  Only its first len(shape) axes are cropped; any
        further axes are kept whole.
    shape: size of cropped volume.  e.g. (32, 32, 32)

    Returns the cropped volume, whose leading dimensions equal shape.

    Raises AssertionError if img is smaller than shape along a cropped axis.
    '''
    assert all(dim >= size for dim, size in zip(img.shape, shape)), \
        'crop shape {} does not fit in image of shape {}'.format(tuple(shape), img.shape)

    # Largest valid starting offset along each cropped axis.
    # if img.shape[i] == 32 and shape[i] == 32, the only valid offset is 0.
    max_starts = [dim - size for dim, size in zip(img.shape, shape)]
    # the starting corner of the crop; randint is inclusive at both ends,
    # so every valid corner can be drawn.
    starts = [random.randint(0, m) for m in max_starts]
    # Index with a *tuple* of slices.  A list of slices is not treated as
    # multidimensional indexing by current NumPy releases and fails there.
    window = tuple(slice(start, start + size) for start, size in zip(starts, shape))
    return img[window]
        

def augment_image(img, crop_shape):
    '''
    Augment a volume.  The only augmentation applied is a random crop of
    img to crop_shape.

    img: a volume to augment
    crop_shape: size of the returned crop.  e.g. (32, 32, 32)
    '''
    augmented = random_crop(img, shape=crop_shape)
    return augmented


class ScanSequence(keras.utils.Sequence):
    '''
    Keras Sequence that serves batches of randomly cropped preprocessed
    images.  Every batch is returned as (x, y) where y holds the same data
    as x, i.e. each input is also its own target.
    '''

    def __init__(self, x_infos, batch_size, crop_shape, shuffle=True):
        '''
        x_infos: DataFrame with one row per preprocessed image.  It must have
            a 'pp_path' column, which is passed to
            preprocessing.get_preprocessed_image to load the image.
        batch_size: number of images per batch
        crop_shape: size of the random crop taken from each image.
            e.g. (32, 32, 32)
        shuffle: if True, the rows are reordered at the end of every epoch
        '''
        super().__init__()
        self.x = x_infos.reset_index()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.crop_shape = crop_shape

    def __len__(self):
        '''
        Return number of batches, based on batch_size.  The last batch may
        hold fewer than batch_size images.
        '''
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        '''
        idx: batch index

        Returns a tuple (x, y) of two equal arrays.  Each stacks the cropped
        images of the batch, with a trailing channel axis of size 1 added to
        every image.
        '''
        start = idx * self.batch_size
        stop = start + self.batch_size
        # Slice by position, not by index label: label-based .loc slicing
        # selects the wrong rows (or raises KeyError) once the rows have been
        # shuffled and the labels are no longer in order.
        batch_x_paths = list(self.x['pp_path'].iloc[start:stop])
        # add channel dimension to each augmented (randomly cropped) image.
        batch_x = [np.expand_dims(augment_image(preprocessing.get_preprocessed_image(path),
                                                crop_shape=self.crop_shape), axis=-1)
                   for path in batch_x_paths]

        # return x and y batches as two independent arrays
        batch = np.array(batch_x)
        return (batch, batch.copy())

    def on_epoch_end(self):
        '''
        Reorder the rows at random when shuffling is enabled.
        '''
        if self.shuffle:
            # Renumber after shuffling so the index stays 0..n-1 in row order.
            self.x = self.x.sample(frac=1).reset_index(drop=True)
    

def get_datagens(preprocessed_metadata_path, batch_size, crop_shape, seed=0, validation_split=0.25):
    '''
    Return a tuple of training ScanSequence and validation ScanSequence

    preprocessed_metadata_path: path handed to
        preprocessing.read_preprocessed_metadata to load the metadata table
    batch_size: batch size used by both sequences
    crop_shape: size of the random crop used by both sequences.
        e.g. (32, 32, 32)
    seed: random_state of the shuffle done before splitting, so the same
        seed gives the same train/validation split
    validation_split: fraction of the rows that goes to the validation set;
        the row count is truncated to an integer
    '''
    # Data generator
    infos = preprocessing.read_preprocessed_metadata(preprocessed_metadata_path)
    print('Data set size:', len(infos))
    shuffled = infos.sample(frac=1, random_state=seed)
    n_val = int(len(shuffled) * validation_split)
    # The first n_val shuffled rows are validation, the rest training.  No
    # index handling is needed here: ScanSequence resets the index of the
    # frame it receives.
    val = shuffled.iloc[:n_val, :]
    train = shuffled.iloc[n_val:, :]
    print('Validation set size:', len(val))
    print('Train set size:', len(train))
    train_gen = ScanSequence(train, batch_size, crop_shape)
    val_gen = ScanSequence(val, batch_size, crop_shape)
    return train_gen, val_gen



### Testing and Validating Functions

In [None]:
# Test that the random crop is producing what look like random crops.
# The first metadata row's image is shown whole, followed by five
# independent random crops of PATCH_SHAPE for visual comparison.
# NOTE(review): animate_crop is neither defined nor imported in the visible
# part of this notebook -- presumably it comes from util or another cell;
# confirm before running.
img = preprocessing.get_preprocessed_image(preprocessing.read_preprocessed_metadata(PP_MD_PATH).loc[0, 'pp_path'])
display(animate_crop(img, step=1))
for i in range(5):
    display(animate_crop(random_crop(img, PATCH_SHAPE), step=1))

In [None]:
# test getting a batch of data from ScanSequence
# Only the training sequence is kept; the validation sequence is discarded.
seq, _ = get_datagens(preprocessed_metadata_path=PP_MD_PATH, batch_size=BATCH_SIZE, crop_shape=PATCH_SHAPE)
# number of batches per epoch
print(len(seq))

In [None]:
# Fetch the first batch; ScanSequence returns the same data as x and y.
batch_x, batch_y = seq[0]

In [None]:
# test that a batch picture looks like a preprocessed image.
# NOTE(review): animate_crop is not defined or imported in the visible part
# of this notebook -- confirm where it comes from.
print(batch_x.shape, batch_y.shape)
display(animate_crop(batch_x[0, :, :, :, 0])) # drop the example and channel dimensions
display(animate_crop(batch_y[0, :, :, :, 0]))

### Examine preprocessed metadata for any weirdness

Found one scan, for id 082222_190, with a bogus shape (only one slice).

In [None]:
# Load the preprocessed metadata table for inspection.
infos = preprocessing.read_preprocessed_metadata(PP_MD_PATH)

In [None]:
# Summary statistics of the three pp_dim columns (presumably the dimensions
# of each preprocessed volume -- confirm against the preprocessing module).
infos[['pp_dim0', 'pp_dim1', 'pp_dim2']].describe()

In [None]:
# Show the rows whose pp_dim0 equals 500.
infos[infos['pp_dim0'] == 500]

In [None]:
# Drop the scan with id 082222_190 from the local table.
# NOTE(review): this only filters the `infos` variable in this notebook;
# get_datagens re-reads the metadata from its path argument, so the scan is
# still part of the generated data unless it is also removed there.
infos = infos[infos['id'] != '082222_190']

In [None]:
# Summary statistics of the pp_dim columns again, now without the dropped scan.
infos[['pp_dim0', 'pp_dim1', 'pp_dim2']].describe()


In [None]:
# Show the rows matching any of these specific dimension values
# (pp_dim0 == 234, pp_dim1 == 280 or pp_dim2 == 278).
infos[(infos['pp_dim0'] == 234) | (infos['pp_dim1'] == 280) | (infos['pp_dim2'] == 278)]
# infos[infos['pp_dim1'] == 280]
# infos[infos['pp_dim0'] == 278]
