<a href="https://colab.research.google.com/github/snekumar/AD_cnn/blob/main/Image_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



* `SimpleITK` for managing `.nii` files.
* `numpy` for matrix operations. It is necessary for `SimpleITK` to work.
* `pandas` for loading tables with basic information about the images, like their label.
* `matplotlib` for image visualization.
* `dltk.io.preprocessing` for some useful functions, like whitening.
* `skimage.filters`, to try some filters on the images.
* `os` for file interaction.

In [None]:
import os
import shutil

import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dltk.io import preprocessing
from skimage import filters

In [None]:
# path to .nii file
IMAGE = '/path/to/image/ADNI_XXXXX_XXXXX_XXXXX.nii'

In [None]:
# load in sitk format
sitk_image = sitk.ReadImage(IMAGE)
# transform into a numpy array
img = sitk.GetArrayFromImage(sitk_image)
# check the final shape
img.shape

In [None]:
plt.imshow(img[:, :, 70], cmap='gray')
plt.show()

In [None]:
otsu = filters.threshold_otsu(img)
otsu_img = img > otsu
plt.imshow(otsu_img[:, :, 70], cmap='gray')
plt.show()

In [None]:
def resample_img(itk_image, out_spacing=[2.0, 2.0, 2.0]):
    ''' This function resamples images to 2-mm isotropic voxels.

        Parameters:
            itk_image -- Image in simpleitk format, not a numpy array
            out_spacing -- Space representation of each voxel

        Returns:
            Resulting image in simpleitk format, not a numpy array
    '''

    # Resample images to 2mm spacing with SimpleITK
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()

    out_size = [
        int(np.round(original_size[0] * (original_spacing[0] / out_spacing[0]))),
        int(np.round(original_size[1] * (original_spacing[1] / out_spacing[1]))),
        int(np.round(original_size[2] * (original_spacing[2] / out_spacing[2])))]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(itk_image.GetPixelIDValue())

    resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(itk_image)

In [None]:
# resample to 2-mm voxels, then crop/pad and apply dltk's whitening
# to the *resampled* volume
res = resample_img(sitk_image)
res_img = sitk.GetArrayFromImage(res)
res_img = preprocessing.resize_image_with_crop_or_pad(res_img, img_size=(128, 192, 192), mode='symmetric')
res_img = preprocessing.whitening(res_img)
plt.imshow(res_img[:, 100, :], cmap='gray')
plt.show()

In [None]:
# folder name for each diagnostic group
CN_FOLDER, MCI_FOLDER, AD_FOLDER = 'CN', 'MCI', 'AD'

Create a few constants with the paths to the images and to the description file.

In [None]:
# Google Drive root folder, written twice: one for shell commands and one in
# plain Python format
DRIVE_SHELL_PATH = '/path/to/gdrive/'
DRIVE_PATH = '/path/to/gdrive/'

# folder on the drive that holds the skull-stripped images
DRIVE_SS_PATH = '/path/to/gdrive/skull_stripped/'
# CSV file with the information about every image (ID, group, ...)
DESCRIPTION_FILE = '/path/to/Description.csv'

# dataset location (stored on an external drive in this case) and the raw
# sub-directories the images were previously organized into
DATASET_PATH = '/path/to/organized/dataset'
DATASET_FOLDERS = [str(number) for number in range(1, 6)]

Open the `.csv` file with the main information about the dataset.

In [None]:
# read the description table from the CSV file configured above
description = pd.read_csv(DESCRIPTION_FILE)
description.head()

In [None]:
def process_and_upload(filename, path, skull_stripping=True, random_printing=False):
    ''' Look up the label of an image and copy it (optionally skull-stripped)
        into the folder for that label on Google Drive.

        The image ID is taken from the last '_'-separated token of the file
        name and looked up in the 'Image Data ID' column of the global
        `description` table; its 'Group' column gives the destination
        sub-folder.

        Parameters:
            filename -- Name of the image file (.nii)
            path -- The path where the image is located
            skull_stripping -- Whether or not to apply skull stripping
                               (the skull stripping method, `skull_strip_nii`,
                               is defined in a later section)
            random_printing -- If True, show a random cut of the saved image
                               10% of the time; useful to see if the skull
                               stripping is working as expected

        Raises:
            IndexError -- if the image ID is not present in `description`
            ValueError -- if the ID token is not a number
            RuntimeError -- from SimpleITK when the saved image cannot be read
    '''
    splitted_name = filename.strip().split('_')
    # residual macOS metadata files ('._<name>') are not images; ignore them
    if splitted_name[0] == '.':
        return

    # drop the first character (presumably the 'I' of the image ID) and the
    # 4-character '.nii' extension
    image_ID = splitted_name[-1][1:-4]
    # sometimes empty files appear, just ignore them
    if image_ID == '':
        return
    # int64 so it compares with the numeric ID column
    image_ID = np.int64(image_ID)

    # boolean selection with .loc works whatever index the table has
    # (iloc with a label would only be right for a default RangeIndex)
    rows = description.loc[description['Image Data ID'] == image_ID]
    label = rows['Group'].iloc[0]

    complete_file_path = os.path.join(path, filename)

    if skull_stripping:
        complete_new_path = os.path.join(DRIVE_SS_PATH, label, filename)
        skull_strip_nii(complete_file_path, complete_new_path)
    else:
        # full destination file path (not just the folder), so the image can
        # be read back below; DRIVE_PATH is the Python-format drive path
        complete_new_path = os.path.join(DRIVE_PATH, label, filename)
        shutil.copy(complete_file_path, complete_new_path)

    if random_printing and np.random.rand() < 0.1:
        sitk_image = sitk.ReadImage(complete_new_path)
        img = sitk.GetArrayFromImage(sitk_image)
        plt.figure(figsize=(10, 10))
        plt.imshow(img[:, :, np.random.randint(70, 160)],
                   cmap='gray')
        plt.show()

In [None]:
# copy every image found under the raw dataset folders into its label folder
exceptions = []
for subdir in DATASET_FOLDERS:
    # DATASET_PATH has no trailing slash, so join instead of concatenating
    for path, _, files in os.walk(os.path.join(DATASET_PATH, subdir)):
        for file in files:
            try:
                process_and_upload(file, path,
                                   skull_stripping=False,
                                   random_printing=True)
            except (RuntimeError, IndexError, ValueError):
                # RuntimeError: SimpleITK failures; IndexError: image ID not
                # in the description table; ValueError: ID is not a number
                exceptions.append(os.path.join(path, file))

---

## Image registration

### Testing

In this section, the `SimpleElastix` package from the `SimpleITK` library is used to perform image registration. Installing `SimpleElastix` takes quite a long time, but it is necessary for this code to work.

First, define the names and paths of several images.

In [None]:
# root folder and three sample images per class, used to test registration
ROOT = '/path/to/root/directory'

CN_IMGS = [
    'CN/someimage.nii',
    'CN/someimage.nii',
    'CN/someimage.nii',
]
MCI_IMGS = [
    'MCI/someimage.nii',
    'MCI/someimage.nii',
    'MCI/someimage.nii',
]
AD_IMGS = [
    'AD/someimage.nii',
    'AD/someimage.nii',
    'AD/someimage.nii',
]

In [None]:
sitk_moving = sitk.ReadImage(ROOT + CN_IMGS[0])
sitk_fixed = sitk.ReadImage(ROOT + CN_IMGS[1])

moving_img = sitk.GetArrayFromImage(sitk_moving)
fixed_img = sitk.GetArrayFromImage(sitk_fixed)

print('Fixed', fixed_img.shape)
print('Moving', moving_img.shape)

In [None]:
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(sitk_fixed)
elastixImageFilter.SetMovingImage(sitk_moving)

parameterMapVector = sitk.VectorOfParameterMap()
parameterMapVector.append(sitk.GetDefaultParameterMap("affine"))
# the following line is used for non-rigid registration
# it is commented because it is very slow and not very useful
#parameterMapVector.append(sitk.GetDefaultParameterMap("bspline"))
elastixImageFilter.SetParameterMap(parameterMapVector)

elastixImageFilter.Execute()
result = elastixImageFilter.GetResultImage()

In [None]:
img = sitk.GetArrayFromImage(result)
img.shape

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[:, :, 100], cmap='gray')
plt.show()

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(moving_img[:, :, 136], cmap='gray')
plt.show()

In [None]:
def registrate(sitk_fixed, sitk_moving, bspline=False):
    ''' Register `sitk_moving` onto `sitk_fixed` with SimpleElastix.

        An affine stage is always used; when `bspline` is True, a
        non-rigid B-spline stage is appended after it.

        Parameters:
            sitk_fixed -- Reference atlas (sitk .nii)
            sitk_moving -- Image to be registered (sitk .nii)
            bspline -- Whether or not to add the non-rigid stage.
                       Note: it usually deforms the images and takes
                       a lot of time

        Returns:
            The result image reported by the elastix filter (sitk image)
    '''
    stages = ['affine', 'bspline'] if bspline else ['affine']

    parameter_maps = sitk.VectorOfParameterMap()
    for stage in stages:
        parameter_maps.append(sitk.GetDefaultParameterMap(stage))

    elastix = sitk.ElastixImageFilter()
    elastix.SetFixedImage(sitk_fixed)
    elastix.SetMovingImage(sitk_moving)
    elastix.SetParameterMap(parameter_maps)
    elastix.Execute()

    return elastix.GetResultImage()

Try the method.

In [None]:
results = [fixed_img]
for img_name in imgs:
    moving = sitk.ReadImage(ROOT + img_name)
    result = registrate(sitk_fixed, moving)
    results.append(sitk.GetArrayFromImage(result))

In [None]:
for result in results:
    plt.figure(figsize=(10, 10))
    plt.imshow(result[:, :, 100], cmap='gray')
    plt.show()

In [None]:
DATABASE = '/Volumes/0SC4R/ADNI/RAW/'
DB_SUBFOLDERS = ['1/', '2/', '3/', '4/', '5/']

Get a reference image.

In [None]:
# reference image the other images are registered onto (relative to DATABASE)
FIXED_IMAGE = 'local/path/to/fixed.nii'
sitk_fixed = sitk.ReadImage(''.join([DATABASE, FIXED_IMAGE]))

#### Image with shape [124, 256, 256]

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(atlas, sitk_moving)
original = sitk.GetArrayFromImage(sitk_fixed)
img = sitk.GetArrayFromImage(result)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[50, :, :], cmap='gray')
plt.show()

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(sitk_fixed, sitk_moving)
original = sitk.GetArrayFromImage(sitk_fixed)
img = sitk.GetArrayFromImage(result)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[:, :, 60], cmap='gray')
plt.show()

#### Image with shape [146, 256, 256]

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(sitk_fixed, sitk_moving)
original = sitk.GetArrayFromImage(sitk_fixed)
img = sitk.GetArrayFromImage(result)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[:, :, 70], cmap='gray')
plt.show()

#### Image with shape [170, 256, 256]

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(atlas, sitk_moving)
original = sitk.GetArrayFromImage(sitk_fixed)
img = sitk.GetArrayFromImage(result)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[50, :, :], cmap='gray')
plt.show()

#### Image with shape [166, 256, 256]

These are the most common, so it is important to make sure the registration process works well.

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(sitk_fixed, sitk_moving)
original = sitk.GetArrayFromImage(sitk_fixed)
img = sitk.GetArrayFromImage(result)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[:, :, 70], cmap='gray')
plt.show()

#### Trying with MNI 305 mean atlas

In [None]:
atlas_path = '/path/to/atlas.nii'
atlas = sitk.ReadImage(atlas_path)

In [None]:
path = 'image.nii'
sitk_moving = sitk.ReadImage(path)

result = registrate(atlas, sitk_moving)
img = sitk.GetArrayFromImage(result)
original = sitk.GetArrayFromImage(sitk_moving)
print(img.shape)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(img[80, :, :], cmap='gray')
plt.show()

In [None]:
filenames = np.array([])
for path, _, files in os.walk(OASIS_DB):
    files_paths = [os.path.join(path, name) for name in files]
    filenames = np.concatenate((filenames, files_paths), axis=None)

In [None]:
mri_files = np.array([name for name in filenames if 'pet' not in name])
description = pd.read_csv(os.path.join(OASIS_RAW, DESCRIPTION))

In [None]:
import re

for name in mri_files:
    split_path = name.split('/')
    subject = split_path[-1][4:12]
    diagnosis = description.loc[description['Subject'] == subject, ['dx1']].iloc[0][0]
    if diagnosis in LABELS.keys():
        dest = os.path.join(OASIS_RAW, LABELS[diagnosis])
        ! cp $name $dest

In [None]:
atlas = sitk.ReadImage('/path/to/atlas.nii')
# resample the atlas to the desired spatial resolution
atlas = resample_img(atlas)

In [None]:
DB = '/path/to/database/'
DEST = '/path/to/destination/'
exceptions = []

for file in os.listdir(DB):
    try:
        name = os.path.join(DB, file)
        new_name = os.path.join(DEST, file)
        sitk_moving = sitk.ReadImage(name)
        sitk_moving = resample_img(sitk_moving)
        registrated = registrate(atlas, sitk_moving)
        sitk.WriteImage(registrated, new_name)
    except:
        exceptions.append(name)

In [None]:
# save the exceptions
with open(os.path.join(DB, 'exceptions.txt'), 'w') as f:
    for item in exceptions:
        f.write("%s\n" % item)

---

## Building the ADNI database



Define the paths where the images are, and where to store them.

In [None]:
# original database and its raw sub-folders '1/' ... '20/'
DATABASE = '/Volumes/0SC4R/TFM-Data/ADNI/MRI/RAW/'
DB_SUBFOLDERS = ['%d/' % number for number in range(1, 21)]

# registered and organized database
REG_DB = '/Volumes/0SC4R/TFM-Data/ADNI/MRI/REGISTERED/'
REG_DB_SUBFOLDERS = ['AD/', 'MCI/', 'CN/']

Load and resample the MNI 305 atlas, as well as the description file.

In [None]:
# MNI 305 atlas, resampled with resample_img's default spacing
atlas = resample_img(sitk.ReadImage('/path/to/atlas.nii'))

# table with the information (ID, group, ...) about every image
description = pd.read_csv('/Volumes/0SC4R/TFM-Data/ADNI/MRI/Description.csv')
description.head()

In [None]:
def register_and_save(filename, path, atlas, random_printing=False):
    ''' Resample an image, register it onto `atlas` and save the result in
        the REG_DB sub-folder of its label.

        The image ID is taken from the last '_'-separated token of the file
        name and looked up in the 'Image Data ID' column of the global
        `description` table; its 'Group' column gives the sub-folder.

        Parameters:
            filename -- Name of the image file (.nii)
            path -- The path where the image is located
            atlas -- Reference sitk image for registration
            random_printing -- If True, show a random cut (first array axis)
                               of the saved image 10% of the time; useful to
                               see if the registration is working as expected

        Raises:
            IndexError -- if the image ID is not present in `description`
            ValueError -- if the ID token is not a number
            RuntimeError -- from SimpleITK/SimpleElastix on read,
                            registration or write failures
    '''
    splitted_name = filename.strip().split('_')
    # residual macOS metadata files ('._<name>') are not images; ignore them
    if splitted_name[0] == '.':
        return

    # drop the first character (presumably the 'I' of the image ID) and the
    # 4-character '.nii' extension
    image_ID = splitted_name[-1][1:-4]
    # sometimes empty files appear, just ignore them
    if image_ID == '':
        return
    # int64 so it compares with the numeric ID column
    image_ID = np.int64(image_ID)

    # boolean selection with .loc works whatever index the table has
    # (iloc with a label would only be right for a default RangeIndex)
    rows = description.loc[description['Image Data ID'] == image_ID]
    label = rows['Group'].iloc[0]

    # resample to the common voxel spacing, then register onto the atlas
    complete_file_path = os.path.join(path, filename)
    sitk_moving = resample_img(sitk.ReadImage(complete_file_path))
    registrated = registrate(atlas, sitk_moving)

    complete_new_path = os.path.join(REG_DB, label, filename)
    sitk.WriteImage(registrated, complete_new_path)

    if random_printing and np.random.rand() < 0.1:
        # read back from disk to check what was actually saved
        sitk_image = sitk.ReadImage(complete_new_path)
        img = sitk.GetArrayFromImage(sitk_image)
        plt.figure(figsize=(10, 10))
        plt.imshow(img[np.random.randint(30, 70), :, :],
                   cmap='gray')
        plt.show()

In [None]:
# register every raw image and store it under REG_DB/<label>/
for subdir in DB_SUBFOLDERS:
    for path, _, files in os.walk(os.path.join(DATABASE, subdir)):
        for file in files:
            try:
                register_and_save(file,
                                  path,
                                  atlas,
                                  random_printing=False)
            except (RuntimeError, IndexError, ValueError):
                # RuntimeError: SimpleITK/SimpleElastix failures; IndexError:
                # image ID not in the description table; ValueError: ID is
                # not a number. Report and keep processing the other files.
                print('Exception with', os.path.join(path, file))

In [None]:
from nipype.interfaces import fsl
import matplotlib.pyplot as plt

In [None]:
def skull_strip_nii(original_img, destination_img, frac=0.3):
    ''' Apply skull stripping to the given image with FSL BET (through
        nipype) and save the result to a new .nii image.
        (https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/BET/UserGuide#Main_bet2_options:)

        Parameters:
            original_img -- Path of the original .nii image
            destination_img -- Path where the skull-stripped image is written
            frac -- Fractional intensity threshold for BET (0-1; per the BET
                    user guide, smaller values give larger brain outlines)

        Returns:
            The result object returned by nipype's `BET.run()`
    '''
    btr = fsl.BET()
    btr.inputs.in_file = original_img
    btr.inputs.frac = frac
    btr.inputs.out_file = destination_img
    return btr.run()

In [None]:
REG_DB = '/Volumes/0SC4R/TFM-Data/ADNI/MRI/REGISTERED/'
SKULL_STRIPPED_DB = '/Volumes/0SC4R/TFM-Data/ADNI/MRI/SKULL_STRIPPED/'
CLASS_FOLDERS = ['AD', 'MCI', 'CN']

In [None]:
exceptions = []
for folder in CLASS_FOLDERS:
    origin_folder = os.path.join(REG_DB, folder)
    dest_folder = os.path.join(SKULL_STRIPPED_DB, folder)
    for path, _, files in os.walk(origin_folder):
        for file in files:
            try:
                img = os.path.join(path, file)
                dest = os.path.join(dest_folder, file)
                skull_strip_nii(img, dest, frac=0.2)
            except RuntimeError:
                exceptions.append(img)

# save the exceptions
with open(os.path.join(SKULL_STRIPPED_DB, 'exceptions.txt'), 'w') as f:
    for item in exceptions:
        f.write("%s\n" % item)

In [None]:
REG_DB = '/Volumes/0SC4R/TFM-Data/IXI-T1/REGISTERED/'
SKULL_STRIPPED_DB = '/Volumes/0SC4R/TFM-Data/IXI-T1/SKULL-STRIPPED/'

files = [os.path.join(REG_DB, name) for name in os.listdir(REG_DB)]

In [None]:
exceptions = []
for filename in files:
    try:
        dest = os.path.join(SKULL_STRIPPED_DB, filename.split('/')[-1])
        skull_strip_nii(filename, dest, frac=0.5)

        result = sitk.ReadImage(dest)
        img = sitk.GetArrayFromImage(result)
        plt.imshow(img[np.random.randint(20, 70), :, :], cmap='gray')
        plt.show()
    except:
        exceptions.append(filename)

# save the exceptions
with open(os.path.join(SKULL_STRIPPED_DB, 'exceptions.txt'), 'w') as f:
    for item in exceptions:
        f.write("%s\n" % item)

---

## Working with PET images (DEPRECATED)

Using MRI images originally gave very poor results, so a different approach was tried: the focus switched to PET images obtained from the ADNI database. The dataset consists of 1251 images, post-processed with spatial normalization, baseline alignment and Tx Origin. In the end, a few tweaks made it possible to obtain much better results with MRI images, so this code is not part of the final work.

The location of the data is stored in the following constants:

In [None]:
# root of the ADNI PET data, the name of its description table and the
# sub-folder with the raw images
PET_DB_PATH, DESCRIPTION, DATA_FOLDER = (
    '/Volumes/0SC4R/TFM-Data/ADNI/PET/', 'PET.csv', 'RAW/')

We load the `.csv` file with all the information about the images.

In [None]:
pet_data = os.path.join(PET_DB_PATH, DATA_FOLDER)
# description table indexed by subject, so labels can be looked up by subject ID
description = (pd.read_csv(os.path.join(PET_DB_PATH, DESCRIPTION))
                 .set_index('Subject ID'))
description.head()

Save the complete path of every image in a numpy array, so that this information becomes easily accessible.

In [None]:
images = np.array([])
for path, _, files in os.walk(pet_data):
    images = np.concatenate((images,
                             [os.path.join(path, name)
                                  for name in files]),
                            axis=None)

In [None]:
rand = np.random.choice(images)
sitk_image = sitk.ReadImage(rand)
img = sitk.GetArrayFromImage(sitk_image)

plt.figure(figsize=(5, 5))
plt.imshow(img[55, :, :])
plt.show()

In [None]:
sitk_img = resample_img(itk_image=sitk_image)
img = sitk.GetArrayFromImage(sitk_image)
img.shape

All 1251 images have the same shape, 69×95×79. They seem to have already been resampled to 2 mm isotropic voxels ($2 \times 2 \times 2\,\mathrm{mm}^3$), because resampling them does not change their resolution. They also all have the same scaling, which means that we do not need to register them. The only thing left to do is to organize them into folders according to their class:

In [None]:
# folder in which the PET images are sorted into one sub-folder per class
ORGANIZED_FOLDER = 'Organized/'
destination = os.path.join(PET_DB_PATH,
                           ORGANIZED_FOLDER)

In [None]:
def get_label_for(filename, table=None):
    ''' Returns the label for a given image.
        In this case we do not have image IDs, for some
        reason, so we need to infer the label by using
        the subject ID.

        The subject ID is built from the 2nd, 3rd and 4th '_'-separated
        tokens of the file name (e.g. 'ADNI_002_S_0295_....nii' gives
        '002_S_0295') and looked up in the 'Research Group' column.

        Parameters:
            filename -- Complete path and filename
            table -- DataFrame indexed by subject ID with a 'Research Group'
                     column; defaults to the global `description` table

        Returns:
            label (CN/MCI/AD); the first one if the subject has several rows

        Raises:
            IndexError -- if the file name has fewer than four tokens
            KeyError -- if the subject ID is not in the table
    '''
    if table is None:
        table = description

    image_name = os.path.basename(filename).split('_')
    subject_ID = '_'.join([image_name[1],
                           image_name[2],
                           image_name[3]])
    label = table.loc[subject_ID, 'Research Group']
    # a subject with several images yields a Series instead of a scalar
    if isinstance(label, str):
        return label
    return label.iloc[0]

In [None]:
# copy each PET image into the sub-folder of its label; shutil.copy
# handles paths with spaces, unlike an unquoted shell `cp`
for filename in images:
    label = get_label_for(filename)
    copy_to = os.path.join(destination, label)
    shutil.copy(filename, copy_to)

Now the images can be uploaded to Google Drive. They are ready to be used by a CNN, although it may be necessary to create a TFRecord because there is too much data. The images will also need additional processing if they are going to be fed to a 3D RGB CNN.

---