In [None]:
import datetime
import imageio # creating animated gif
import importlib # reloading module.  maybe should just use ipython magic
from IPython.display import Image # display animated gif
from IPython.display import SVG # visualize model
import keras
from keras.layers import Dense, SimpleRNN, Input, Conv1D
from keras.models import Model
from keras.utils.vis_utils import model_to_dot # visualize model
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection # what is this for?
import nibabel as nib # work with nifti files
from nibabel.testing import data_path
import numpy as np
import os
import pandas as pd
from pathlib import Path
import projd # finding rootdir of project for setting python path
import pydicom # for reading dicom files
import random
import re
import seaborn as sns
import scipy.ndimage # image resizing
from scipy.ndimage.interpolation import rotate
from skimage import morphology
from skimage import measure
from skimage.transform import resize
from sklearn.model_selection import train_test_split
import sys
import uuid

# Make local project code importable by putting $PROJECT_ROOT/src on sys.path.
# projd.cwd_token_dir('notebooks') is assumed to return the project root, found by
# locating the 'notebooks' directory in the current working path -- TODO confirm.
# The membership check keeps re-runs of this cell from appending duplicate entries.
src_dir = str(Path(projd.cwd_token_dir('notebooks')) / 'src') # $PROJECT_ROOT/src
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Local module import, currently disabled. importlib.reload re-imports datagen so
# edits to it (presumably under src/) take effect without restarting the kernel.
# import datagen
# importlib.reload(datagen)

# Compact numpy printing: 2 decimal places, no scientific notation for small values.
np.set_printoptions(precision=2, suppress=True)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Apply seaborn's default plot styling to all subsequent figures.
sns.set()

## Constants

In [None]:
# Dataset root for the 2018 scans, with '~' expanded to the user's home directory.
data_dir = Path(os.path.expanduser('~/data/2018'))
# Directory tree holding the normal NIfTI scans.
normal_scans_dir = data_dir / 'uvmmc/nifti_normals'
# Default output location (as a str) for animated GIFs.
gif_path = str(Path(os.path.expanduser('~/Downloads/test.gif')))


In [None]:
def temp_gif_path():
    '''
    Return a new, unique GIF file path (as a str) inside ~/Downloads.

    The file name is 'tmp_' followed by a random uuid4 hex string and '.gif'.
    Only the path is generated; no file is created.
    '''
    downloads = Path('~/Downloads').expanduser()
    unique_name = 'tmp_' + uuid.uuid4().hex + '.gif'
    return str(downloads / unique_name)

    
def get_nifti_files(path):
    '''
    Recursively collect NIfTI files below a directory.

    path: directory root (str or Path) containing nifti files, possibly in subdirs.
    return: list of Path objects, one per entry matching '*.nii' anywhere under path.
      Compressed '.nii.gz' files do not match this pattern and are not returned.
    '''
    root = Path(path)
    matches = root.glob('**/*.nii')
    return [match for match in matches]


def sample_stack(stack, rows=3, cols=3, start_with=0, show_every=3, r=0):
    '''
    Display a rows x cols grid of 2-d slices taken along the third axis of a volume.

    stack: 3-d voxel array; the k-th candidate slice is stack[:, :, k].
    rows, cols: grid shape; rows * cols slices are shown.
    start_with: index of the first slice shown.
    show_every: stride between consecutive slices shown.
    r: rotation angle in degrees passed to scipy's rotate; 0 means no rotation.
    Grid cells whose slice index falls past the end of the volume are left blank.
    '''
    # squeeze=False keeps `axes` 2-d even when rows == 1 or cols == 1, so the
    # [row, col] indexing below works for every grid shape.
    _, axes = plt.subplots(rows, cols, figsize=[20, 20], squeeze=False)
    n_slices = stack.shape[2]
    for i in range(rows * cols):
        ind = start_with + i * show_every
        ax = axes[i // cols, i % cols]
        if ind < n_slices:
            ax.set_title('slice %d' % ind)
            slice_2d = stack[:, :, ind]
            if r != 0:
                slice_2d = rotate(slice_2d, r)
            ax.imshow(slice_2d, cmap='gray')
        ax.axis('off')
    plt.show()


def make_animated_gif(path, img, start=0, stop=None, step=1):
    '''
    Create an animated gif of a 3-d volume, where each frame is a 2-d slice
    taken by iterating across the 3rd dimension, e.g. one frame is img[:, :, i].

    path: where to save the animated gif.
    img: a 3-d volume (array-like).
    start: index of the 3rd dimension to start iterating at.  Default 0.
    stop: index of the 3rd dimension to stop at, not inclusive.  Default None,
      meaning stop at img.shape[2].  start/stop/step follow Python slice
      semantics: negative indices count from the end, out-of-range values are clipped.
    step: stride between frames, e.g. 3 uses every third slice.
    Intensities are rescaled linearly to 0..255; a constant volume yields black frames.
    '''
    # Work in float64 so (volume - imin) cannot overflow for narrow integer dtypes.
    volume = np.asarray(img, dtype=np.float64)
    imin = volume.min()
    imax = volume.max()
    span = imax - imin
    if span > 0:
        volume = 255 * ((volume - imin) / span)  # scale to 0..255
    else:
        # Constant volume: avoid 0/0 (NaN) and emit all-zero frames instead.
        volume = np.zeros_like(volume)
    # convert to uint8 to suppress warnings from imageio
    frames = np.uint8(volume)

    # Honour `stop` (previously ignored) using standard slice semantics.
    indices = range(*slice(start, stop, step).indices(frames.shape[2]))
    with imageio.get_writer(path, mode='I') as writer:
        for i in indices:
            writer.append_data(frames[:, :, i])

    
def animate_crop(img, crop=(0, 1, 0, 1, 0, 1), axis=2, step=5):
    '''
    Crop a 3-d volume and return an animated gif of it for display in a notebook.

    img: a 3d volume to be cropped and animated.
    crop: 6 element sequence: axis 0 start position, axis 0 end position, axis 1
      start position, etc.  Each position is a number in [0.0, 1.0] representing
      the position as a proportion of that axis.  0.0 is the beginning, 1.0 the
      end, and 0.5 the middle.
    axis: 0, 1, 2: the axis to animate along.  The axes of img are cyclically
      rotated so that this axis becomes the 3rd axis.
    step: only include every nth frame in the animation, where each frame is a
      2d slice of img.
    return: ipython Image, for display in a notebook.
    raises ValueError: if crop does not have 6 elements or axis is not 0, 1 or 2.
    '''
    if len(crop) != 6:
        raise ValueError('crop must have 6 elements, got %d' % len(crop))
    if axis not in (0, 1, 2):
        raise ValueError('axis must be 0, 1, or 2, got %r' % (axis,))

    # Convert each (start, end) proportion pair into integer voxel bounds for that axis.
    bounds = [(int(crop[2 * a] * img.shape[a]), int(crop[2 * a + 1] * img.shape[a]))
              for a in range(3)]
    cropped_img = img[tuple(slice(lo, hi) for lo, hi in bounds)]

    # Cyclic axis rotation that moves `axis` last (the dimension make_animated_gif
    # iterates over): axis 0 -> (1, 2, 0), axis 1 -> (2, 0, 1), axis 2 -> (0, 1, 2).
    order = [(k + axis + 1) % 3 for k in range(3)]
    cropped_img = cropped_img.transpose(order)

    tmp_path = temp_gif_path()
    print('temp gif path:', tmp_path)
    make_animated_gif(tmp_path, cropped_img, step=step)
    return Image(filename=tmp_path)


def animate_scan_info_crop(scan_info, i, crop=(0, 1, 0, 1, 0, 1), axis=0, step=3):
    '''
    Load the scan at row `i` of scan_info and return an animated, cropped view of it.

    scan_info: DataFrame with a 'path' column of NIfTI file paths (e.g. from get_scan_info).
    i: index label of the row to animate (looked up with .loc).
    crop, axis, step: forwarded to animate_crop.
    return: the value returned by animate_crop (an IPython Image of the gif).
    '''
    path = scan_info.loc[i, 'path']
    print('scan path:', path)
    # get_data() is deprecated (removed in nibabel 5); np.asanyarray(dataobj) is
    # nibabel's documented replacement that keeps the stored dtype and scaling.
    img = np.asanyarray(nib.load(path).dataobj)
    print('scan img shape:', img.shape)
    return animate_crop(img, crop, axis=axis, step=step)
    

def get_scan_info(normal_dirs, fracture_dirs):
    '''
    Build a DataFrame describing every NIfTI file found under the given directories.

    normal_dirs: iterable of directories whose .nii files are labelled 'normal'.
    fracture_dirs: iterable of directories whose .nii files are labelled 'fracture'.
    return: DataFrame, one row per file (normal files first), with columns
      id: file name without the trailing '.nii'
      path: file path as str
      class: 'normal' or 'fracture'
      nft: the object returned by nib.load
      header, affine: the loaded image's header and affine
      pixdim, dim: header['pixdim'][1:4] and header['dim'][1:4]
      qform_code, sform_code, sizeof_hdr: the matching header fields
      pixdim0..pixdim2, dim0..dim2: the individual elements of pixdim / dim
      desc: header['descrip']
    '''
    def _files_in(dirs):
        # Flatten the per-directory file lists into one list, preserving order.
        return [f for d in dirs for f in get_nifti_files(d)]

    normal_files = _files_in(normal_dirs)
    fracture_files = _files_in(fracture_dirs)
    files = normal_files + fracture_files
    classes = ['normal'] * len(normal_files) + ['fracture'] * len(fracture_files)
    # Raw string: '\.' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    scan_info = pd.DataFrame({'id': [re.sub(r'\.nii$', '', p.name) for p in files],
                              'path': [str(p) for p in files],
                              'class': classes})
    scan_info['nft'] = scan_info['path'].apply(nib.load)
    scan_info['header'] = scan_info['nft'].apply(lambda nft: nft.header)
    scan_info['affine'] = scan_info['nft'].apply(lambda nft: nft.affine)
    scan_info['pixdim'] = scan_info['header'].apply(lambda h: h['pixdim'][1:4])
    scan_info['dim'] = scan_info['header'].apply(lambda h: h['dim'][1:4])
    for field in ('qform_code', 'sform_code', 'sizeof_hdr'):
        # Default-arg binding captures the current field, not the loop's last value.
        scan_info[field] = scan_info['header'].apply(lambda h, f=field: h[f])
    # Split the 3-element pixdim / dim arrays into one scalar column per element.
    for col in ('pixdim', 'dim'):
        for k in range(3):
            scan_info['%s%d' % (col, k)] = scan_info[col].apply(lambda x, k=k: x[k])
    scan_info['desc'] = scan_info['header'].apply(lambda h: h['descrip'])
    return scan_info


In [None]:
scan_info = get_scan_info(normal_dirs=[normal_scans_dir], fracture_dirs=[])  # normal scans only; no fracture dirs passed

In [None]:
scan_info.head(n=5)

In [None]:
scan_paths = get_nifti_files(path=normal_scans_dir)  # every .nii file under the normal-scans tree

In [None]:
# scan info will contain metadata (id and path) about each of the scans in the small dataset we are examining.
# NOTE: this reassigns `scan_info`, replacing the richer frame returned by
# get_scan_info earlier (class, header fields, dims, ...) with an id/path-only table.
# Raw string: '\.' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
scan_info = pd.DataFrame({'id': [re.sub(r'\.nii$', '', p.name) for p in scan_paths],
                          'path': [str(p) for p in scan_paths]})
scan_info.head()

- TODO: resample to a uniform voxel size
- TODO: crop
- TODO: save into train/val/test splits, with one subdirectory per class (the directory layout Keras expects)?