# Classification of Alzheimer's Disease using Quantification of Hippocampal Volume

### Exploratory Data Analysis - Section 2 - HippoCampusDataLoader

This task involves building a recursive UNet model, then training, logging and testing it. 

In [29]:
! pip3 install medpy

Collecting medpy
[?25l  Downloading https://files.pythonhosted.org/packages/3b/70/c1fd5dd60242eee81774696ea7ba4caafac2bad8f028bba94b1af83777d7/MedPy-0.4.0.tar.gz (151kB)
[K     |██▏                             | 10kB 22.1MB/s eta 0:00:01[K     |████▎                           | 20kB 1.6MB/s eta 0:00:01[K     |██████▌                         | 30kB 2.1MB/s eta 0:00:01[K     |████████▋                       | 40kB 2.3MB/s eta 0:00:01[K     |██████████▉                     | 51kB 1.9MB/s eta 0:00:01[K     |█████████████                   | 61kB 2.2MB/s eta 0:00:01[K     |███████████████                 | 71kB 2.4MB/s eta 0:00:01[K     |█████████████████▎              | 81kB 2.6MB/s eta 0:00:01[K     |███████████████████▍            | 92kB 2.8MB/s eta 0:00:01[K     |█████████████████████▋          | 102kB 2.7MB/s eta 0:00:01[K     |███████████████████████▊        | 112kB 2.7MB/s eta 0:00:01[K     |██████████████████████████      | 122kB 2.7MB/s eta 0:00:01[K     

In [0]:
import os
from os import listdir
from os.path import isfile, join

import numpy as np
from medpy.io import load

#from utils.utils import med_reshape

This module loads the hippocampus dataset into RAM. Note that the data is small enough to fit into RAM. If it were not, we would need to use caching techniques. The PyTorch community discusses some options here: https://discuss.pytorch.org/t/best-practice-to-cache-the-entire-dataset-during-first-epoch/19608

In [0]:
def _normalize_to_unit_range(volume):
    '''
    Linearly rescales the intensities of a volume into the [0..1] range
    (min-max scaling).

    Arguments:
        volume {Numpy array} -- image volume with an arbitrary intensity range

    Returns:
        float32 Numpy array of the same shape as the input. An empty volume is
        returned as-is and a constant volume (max == min) is returned as all
        zeros, so that no division by zero takes place.
    '''
    pixels = np.asarray(volume).astype('float32')
    if pixels.size == 0:
        return pixels

    lowest = pixels.min()
    value_range = pixels.max() - lowest
    if value_range == 0:
        return np.zeros_like(pixels)

    return (pixels - lowest) / value_range


def LoadHippocampusData(root_dir, y_shape, z_shape):
    '''
    This function loads our dataset from disk into memory,
    normalizing image intensities and reshaping output to common size

    Arguments:
        root_dir {string} -- directory that contains the output/images and
            output/labels folders; a label file must have the same file name
            as its image file
        y_shape {int} -- size that the second dimension is reshaped to
        z_shape {int} -- size that the third dimension is reshaped to

    Returns:
        Numpy array of dictionaries with data stored in seg and image fields
        (as returned by med_reshape for the requested shape
        [AXIAL_WIDTH, Y_SHAPE, Z_SHAPE]) and the source file name stored in
        the filename field
    '''
    # NOTE(review): med_reshape is neither defined nor imported in the visible
    # cells (the utils.utils import is commented out) -- it has to be in scope
    # before this function is called.

    image_dir = os.path.join(root_dir + "/output/", 'images')
    label_dir = os.path.join(root_dir + "/output/", 'labels')

    # Hidden files (e.g. ".DS_Store") are skipped. The names are sorted because
    # listdir returns entries in arbitrary order, and a stable order keeps any
    # downstream index-based split reproducible between runs.
    images = sorted(
        f for f in listdir(image_dir)
        if isfile(join(image_dir, f)) and not f.startswith(".")
    )

    out = []
    for f in images:

        # We would benefit from mmap load method here if dataset doesn't fit into memory
        # Images are loaded here using MedPy's load method. We will ignore header
        # since we will not use it
        image, _ = load(os.path.join(image_dir, f))
        label, _ = load(os.path.join(label_dir, f))

        # Normalize all images (but not labels) so that values are in [0..1] range.
        # The scaling uses each volume's own min/max instead of a fixed 255, since
        # a fixed divisor only gives [0..1] for 8-bit data. The normalized result
        # is assigned back to `image` so that it is what gets reshaped and returned.
        # Reference: https://machinelearningmastery.com/how-to-manually-scale-image-pixel-data-for-deep-learning/
        image = _normalize_to_unit_range(image)

        # We need to reshape data since CNN tensors that represent minibatches
        # in our case will be stacks of slices and stacks need to be of the same size.
        # In the inference pathway we will need to crop the output to that
        # of the input image.
        # Note that since we feed individual slices to the CNN, we only need to
        # extend 2 dimensions out of 3. We choose to extend coronal and sagittal here
        image = med_reshape(image, new_shape=(image.shape[0], y_shape, z_shape))

        # Why do we need to cast label to int?
        # ANSWER: labels encode discrete class ids rather than intensities, so they
        # are kept as integers. NOTE(review): presumably the training loss expects
        # integer class targets -- confirm against the training loop.
        label = med_reshape(label, new_shape=(label.shape[0], y_shape, z_shape)).astype(int)

        out.append({"image": image, "seg": label, "filename": f})

    # Hippocampus dataset only takes about 300 Mb RAM, so we can afford to keep it all in RAM
    print(f"Processed {len(out)} files, total {sum([x['image'].shape[0] for x in out])} slices")
    return np.array(out)