In [1]:

import nibabel
import nibabel.affines
from PIL import Image
import os
import fnmatch
import numpy as np
import ants
import concurrent.futures
import pandas as pd
import glob
import xml.etree.ElementTree as ET
import xmltodict
import shutil
import pprint
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

np.set_printoptions(precision=2, suppress=True)

## File handling and loading functions

In [2]:
# Load an image into nibabel
def load_image(data_path, filename):
    """Load the file ``filename`` located in ``data_path`` with nibabel.

    The two arguments are joined with a single forward slash and the
    resulting path is handed to ``nibabel.load``, whose result is returned.
    """
    image_path = f"{data_path}/{filename}"

    return nibabel.load(image_path)

# Display the middle slice of a nibabel image
def display_image(image):
    """Show the middle slice of an image object.

    ``image`` must expose a ``dataobj`` attribute; it is converted to a
    NumPy array and passed on to ``display_array``. Nothing is returned.
    """
    display_array(np.asarray(image.dataobj))

# Display the middle slice of a 3d array
def display_array(array):
    """Render the middle slice (along axis 0) of a 3D array as an 8-bit image.

    The slice is rescaled so that its largest value maps to 255, converted to
    ``uint8`` and shown with ``display``. A slice whose maximum is 0 is shown
    as an all-black image instead of being divided by zero.
    """
    # Get middle slice (named so that the built-in ``slice`` is not shadowed)
    middle_slice = array[array.shape[0] // 2, :, :]

    peak = np.max(middle_slice)

    if peak == 0:
        # Dividing by a zero maximum would fill the slice with NaN/inf
        scaled = np.zeros(middle_slice.shape, dtype=np.uint8)
    else:
        # Scale the image such that the maximum pixel value is 255
        scaled = ((middle_slice / peak) * 255).astype(np.uint8)

    # Display the scaled image
    display(Image.fromarray(scaled))

    return

def display_image_3d(image):
    """Plot an image object in 3D.

    ``image`` must expose a ``dataobj`` attribute; it is converted to a
    NumPy array and passed on to ``display_array_3d``. Nothing is returned.
    """
    display_array_3d(np.asarray(image.dataobj))

def display_array_3d(array):
    """Plot the positive voxels of a 3D array, coloured by intensity.

    Values are normalised by the array maximum and mapped through the
    viridis colormap; every voxel with a value above zero is drawn with black
    edges. An all-zero array produces an empty plot rather than a division
    by zero.
    """
    peak = np.max(array)

    # Normalise to 0..1 for the colormap, guarding against a zero maximum
    if peak != 0:
        normalised = array / peak
    else:
        normalised = np.zeros(array.shape)

    # Set up the plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # No explicit x/y/z: voxels() needs corner coordinates one element larger
    # than the filled mask along every axis, which np.indices(array.shape) is
    # not. Left out, matplotlib builds that integer corner grid itself.
    ax.voxels(normalised > 0, facecolors=plt.cm.viridis(normalised), edgecolors='k')

    plt.show()

    return

# Display the middle slice of a 3d array
# NOTE: this repeats the display_array defined earlier in this cell; being the
# later definition, it is the one in effect once the cell has run.
def display_array(array):
    """Render the middle slice (along axis 0) of a 3D array as an 8-bit image.

    The slice is rescaled so that its largest value maps to 255, converted to
    ``uint8`` and shown with ``display``. A slice whose maximum is 0 is shown
    as an all-black image instead of being divided by zero.
    """
    # Get middle slice (named so that the built-in ``slice`` is not shadowed)
    middle_slice = array[array.shape[0] // 2, :, :]

    peak = np.max(middle_slice)

    if peak == 0:
        # Dividing by a zero maximum would fill the slice with NaN/inf
        scaled = np.zeros(middle_slice.shape, dtype=np.uint8)
    else:
        # Scale the image such that the maximum pixel value is 255
        scaled = ((middle_slice / peak) * 255).astype(np.uint8)

    # Display the scaled image
    display(Image.fromarray(scaled))

    return

# Find files whose name ends with one of the given extensions and return a list. Non-recursive
def list_files_ext(data_path, extensions):
    """Return the names of the entries in ``data_path`` that end with ``extensions``.

    ``extensions`` may be a single suffix string or any iterable of suffix
    strings (tuple, list, set, ...). Only ``data_path`` itself is inspected,
    and the returned names are relative to it, in ``os.listdir`` order.
    """
    # str.endswith accepts a str or a tuple only, so normalise other iterables
    if isinstance(extensions, str):
        suffixes = extensions
    else:
        suffixes = tuple(extensions)

    return [name for name in os.listdir(data_path) if name.endswith(suffixes)]

# Return the path to all files matching a filename in a directory. Recursive
def list_files_fname(data_path, filename):
    """Collect every file under ``data_path`` whose name matches ``filename``.

    ``filename`` is an ``fnmatch`` pattern, so shell-style wildcards are
    honoured. The directory tree is walked with ``os.walk`` and each hit is
    returned joined to the directory it was found in, so the paths are
    rooted at ``data_path``.
    """
    return [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(data_path)
        for name in fnmatch.filter(names, filename)
    ]

# Delete the files returned by list_files_fname
def delete_file_matching(data_path, filename):
    """Remove every file that ``list_files_fname(data_path, filename)`` reports.

    Each reported path is deleted with ``os.remove``. Nothing is returned.
    """
    doomed_paths = list_files_fname(data_path, filename)

    for doomed in doomed_paths:
        os.remove(doomed)

def delete_aux_files(subject):
    """Delete every file recorded in ``subject.aux_file_list`` and empty that collection.

    Entries without a directory component (the bare names that ``Subject``
    collects with ``os.listdir``) are resolved against ``subject.path``;
    entries that already carry a directory are used as they are. An entry is
    only dropped from the collection once its file has been removed.
    """
    # Iterate over a snapshot: the collection is mutated inside the loop, and
    # it may be a set (as built by Subject), which cannot be sliced with [:].
    for file in list(subject.aux_file_list):

        if os.path.dirname(file):
            target = file
        else:
            target = os.path.join(subject.path, file)

        os.remove(os.path.normpath(target))

        subject.aux_file_list.remove(file)


## Subject class

In [22]:

class Subject:
    """Paths and metadata for one subject directory.

    The constructor records the expected input files under ``path``, parses
    the first ``*.xml`` file found there (if any), reads
    ``stats/aseg+DKT.stats`` into a DataFrame, and notes which derived files
    (aligned / cropped outputs) already exist on disk.
    """

    # Constructor assumes that the directory has already been processed in the specific format using fastsurfer
    # See preprocess.py
    def __init__(self, path):

        # Existing before object creation
        self.path = path

        self.orig_nu = os.path.join(path, "mri/orig_nu.mgz")

        self.mask = os.path.join(path, "mri/mask.mgz")

        self.aparc = os.path.join(path, "mri/aparc.DKTatlas+aseg.deep.mgz")

        xml_files = glob.glob(os.path.join(path, "*.xml"))

        # First XML file in the subject directory, or None when there is none
        self.xml_path = xml_files[0] if xml_files else None

        # Parsed XML content as nested dicts; stays None without an XML file
        # (opening a None path would raise TypeError)
        self.xml_df = None

        if self.xml_path is not None:

            with open(self.xml_path, 'r') as file:

                self.xml_df = xmltodict.parse(file.read())

        # Manually assign the column headers
        # NOTE(review): 'ColHeaders' looks like the label of the commented header line
        # rather than a data column - confirm the names line up with the data rows.
        header = ['ColHeaders', 'Index', 'SegId', 'NVoxels', 'Volume_mm3', 'StructName', 'normMean', 'normStdDev', 'normMin', 'normMax', 'normRange']

        # Raw string: '\s' in a normal string literal is an invalid escape sequence
        self.aseg_stats = pd.read_csv(os.path.join(path, 'stats/aseg+DKT.stats'), delimiter=r'\s+', comment='#', header=None, names=header)

        # Existing after object creation
        # Each attribute below holds the file's path when it is already on disk, else None

        # Affine aligned brain
        self.brain_aligned = self._existing_file("brain_aligned.nii")

        # Affine alignment matrix from ANTsPy
        self.affine_alignment = self._existing_file('affine_alignment.mat')

        # Aparc file aligned with matrix
        self.aparc_aligned = self._existing_file("aparc.DKTatlas+aseg.deep_aligned.nii")

        # Aligned and cropped brain
        self.brain_aligned_cropped = self._existing_file("brain_aligned_cropped.nii")

        # NB specific regions e.g hippocampus are not stored in the object. Access them using their path from the aux file list

        # Set of all files for convenience (bare names of the regular files directly inside path)
        self.aux_file_list = {f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))}

    def _existing_file(self, name):
        """Return ``<self.path>/<name>`` if that file exists, otherwise None."""
        candidate = os.path.join(self.path, name)

        return candidate if os.path.isfile(candidate) else None
        
        
# Searches data_path for subject directories and creates an object for each of them
def find_subjects(data_path):
    """Build a ``Subject`` for every valid subject directory directly inside ``data_path``.

    An entry counts as a subject when it is a directory that contains an
    ``mri`` sub-directory holding both ``orig_nu.mgz`` and ``mask.mgz``.
    Everything else is skipped. Returns the list of created objects.
    """
    required_files = ('orig_nu.mgz', 'mask.mgz')

    found = []

    for entry in os.listdir(data_path):

        candidate = os.path.join(data_path, entry)

        # Plain files cannot be subjects
        if not os.path.isdir(candidate):
            continue

        scans_dir = os.path.join(candidate, 'mri')

        # A subject needs its MRI directory
        if not os.path.isdir(scans_dir):
            continue

        # Both the image and its mask must be present
        if all(os.path.isfile(os.path.join(scans_dir, name)) for name in required_files):
            found.append(Subject(candidate))

    return found

  self.aseg_stats = pd.read_csv(os.path.join(path, 'stats/aseg+DKT.stats'), delimiter='\s+', comment='#', header=None, names=header)


## Image manipulation functions

### Brain extraction

In [4]:
# Performs brain extraction by multiplying an image with its brain mask
def extract_brain(orig_file, mask_file):
    """Return the voxel array of ``orig_file`` multiplied element-wise by that of ``mask_file``.

    Both files are opened with ``nibabel.load`` and their ``dataobj`` is
    converted to a NumPy array. The mask entries are expected to be 1 or 0,
    so the product keeps the masked-in voxels and zeroes the rest.
    """
    image_array, mask_array = (
        np.asarray(nibabel.load(path).dataobj) for path in (orig_file, mask_file)
    )

    return image_array * mask_array

### Reference brain (global)

In [None]:
# NB seems that fastsurfer brain is better
# Directory holding the MNI ICBM152 template files and their FastSurfer-processed copy.
# Kept in one place so the location only has to be changed once.
REFERENCE_DIR = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/mni_icbm152_lin_nifti"

# Reference brain extracted from the template image and its mask
reference_brain_array_mni = extract_brain(
    os.path.join(REFERENCE_DIR, "icbm_avg_152_t1_tal_lin.nii"),
    os.path.join(REFERENCE_DIR, "icbm_avg_152_t1_tal_lin_mask.nii"),
)

display_array(reference_brain_array_mni)

# Reference brain extracted from the FastSurfer-processed template;
# this is the array that alignment() registers subjects against
reference_brain_array_fastsurfer = extract_brain(
    os.path.join(REFERENCE_DIR, "fastsurfer-processed/mri/orig_nu.mgz"),
    os.path.join(REFERENCE_DIR, "fastsurfer-processed/mri/mask.mgz"),
)

display_array(reference_brain_array_fastsurfer)


### Affine alignment

In [6]:
# Affine align a single subject
def alignment(subject):
    """Register one subject's extracted brain to the reference brain and save the results.

    The moving image is the subject's brain from ``extract_brain``; the fixed
    image is the module-level ``reference_brain_array_fastsurfer``. ANTsPy
    runs an 'AffineFast' registration, after which:

    * the first forward transform file is copied to
      ``<subject.path>/affine_alignment.mat`` (stored in ``subject.affine_alignment``);
    * the warped brain is saved as ``<subject.path>/brain_aligned.nii`` with
      an identity affine (stored in ``subject.brain_aligned``).

    Returns the aligned brain as a nibabel ``Nifti1Image``.
    """
    # Subject brain (moving) and reference brain (fixed) as ANTsPy images
    moving_image = ants.from_numpy(extract_brain(subject.orig_nu, subject.mask))
    fixed_image = ants.from_numpy(reference_brain_array_fastsurfer)

    registration = ants.registration(fixed=fixed_image, moving=moving_image, type_of_transform='AffineFast')

    aligned_brain_array = registration['warpedmovout'].numpy()

    # ANTsPy writes the transform to a temporary file; keep a copy with the subject
    transform_path = os.path.join(subject.path, 'affine_alignment.mat')
    shutil.copy(registration['fwdtransforms'][0], transform_path)
    subject.affine_alignment = transform_path

    # Wrap the aligned array in a nibabel image (identity affine) and write it out
    aligned_image = nibabel.Nifti1Image(aligned_brain_array, np.eye(4))
    aligned_image_path = os.path.join(subject.path, 'brain_aligned.nii')
    nibabel.save(aligned_image, aligned_image_path)
    subject.brain_aligned = aligned_image_path

    return aligned_image

# Affine align a list of subjects in parallel
def alignment_parallel(subject_list):
    """Run ``alignment`` for every subject in a process pool and display each result.

    Each worker process receives a pickled copy of its subject, so the
    attributes that ``alignment`` sets on that copy never reach the objects
    in ``subject_list``. Once a job has finished, the output paths written by
    ``alignment`` are therefore recorded on the caller's subject here.
    """
    # Use ProcessPoolExecutor to run affine alignment in parallel
    with concurrent.futures.ProcessPoolExecutor() as executor:

        # Remember which subject each future belongs to
        futures = {executor.submit(alignment, subject): subject for subject in subject_list}

        for future in concurrent.futures.as_completed(futures):

            # Re-raises here if the worker failed, before any attribute is touched
            aligned_image = future.result()

            subject = futures[future]

            # Same file names that alignment() writes into the subject directory
            subject.affine_alignment = os.path.join(subject.path, 'affine_alignment.mat')
            subject.brain_aligned = os.path.join(subject.path, 'brain_aligned.nii')

            display_image(aligned_image)

    return


# Uses affine_alignment.mat of a subject to align another file (for example the aseg file)
def aux_alignment(subject, file, is_aparc):
    """Apply a subject's stored affine transform to another image of that subject.

    ``subject.brain_aligned`` and ``subject.affine_alignment`` must already be
    set, because the aligned brain is loaded as the fixed image and the
    transform file is passed to ANTsPy. The result is saved in the subject
    directory as ``<file stem>_aligned.nii`` with an identity affine.

    When ``is_aparc`` is true the image is stored with an int32 dtype and its
    path is recorded in ``subject.aparc_aligned``; otherwise the path is added
    to ``subject.aux_file_list``. Returns the saved nibabel image.
    """
    # Open both images
    fixed_image = ants.from_numpy(np.asarray(nibabel.load(subject.brain_aligned).get_fdata()))
    moving_image = ants.from_numpy(np.asarray(nibabel.load(file).get_fdata()))

    # Nearest-neighbour interpolation keeps discrete labels intact (no blurring)
    warped = ants.apply_transforms(fixed_image, moving_image, subject.affine_alignment, interpolator='nearestNeighbor')

    ants.plot(fixed_image)
    ants.plot(warped)

    stem = os.path.splitext(os.path.basename(file))[0]
    path = os.path.join(subject.path, (stem + '_aligned.nii'))

    # Parcellation files hold discrete values, so they are stored as integers
    dtype_options = {'dtype': np.int32} if is_aparc else {}

    # Convert to nibabel image and write it out
    transformed_image = nibabel.Nifti1Image(warped.numpy(), np.eye(4), **dtype_options)
    nibabel.save(transformed_image, path)

    if is_aparc:
        subject.aparc_aligned = path
    else:
        subject.aux_file_list.add(path)

    return transformed_image

# Align aparc files in parallel
def aux_alignment_parallel(subject_list):
    """Run ``aux_alignment`` on every subject's aparc file in a process pool and display each result.

    Each worker process receives a pickled copy of its subject, so the
    ``aparc_aligned`` attribute that ``aux_alignment`` sets on that copy never
    reaches the objects in ``subject_list``. Once a job has finished, the
    output path is therefore recorded on the caller's subject here.
    """
    # Use ProcessPoolExecutor to run affine alignment in parallel
    with concurrent.futures.ProcessPoolExecutor() as executor:

        # Remember which subject each future belongs to
        futures = {executor.submit(aux_alignment, subject, subject.aparc, True): subject for subject in subject_list}

        for future in concurrent.futures.as_completed(futures):

            # Re-raises here if the worker failed, before any attribute is touched
            aligned_aparc = future.result()

            subject = futures[future]

            # Same output path that aux_alignment() builds for the aparc file
            aparc_stem = os.path.splitext(os.path.basename(subject.aparc))[0]
            subject.aparc_aligned = os.path.join(subject.path, aparc_stem + '_aligned.nii')

            display_image(aligned_aparc)

    return


### Cropping

In [7]:
# Crop images to the minimum size whilst retaining whole dataset
# Can only be done on the whole dataset as the dataset has to be checked before
def crop_subjects(subject_list, relative_path, is_full_brain=False):
    """Crop one image per subject to the smallest box containing every subject's non-zero voxels.

    ``relative_path`` is the image file relative to each subject's directory.
    A first pass finds the union of all per-subject bounding boxes; a second
    pass crops every image to that box, displays it and saves it next to the
    original as ``<name>_cropped.nii``.

    When ``is_full_brain`` is true the saved path is stored in
    ``subject.brain_aligned_cropped``; otherwise (the default) it is added to
    ``subject.aux_file_list``.

    Raises ``ValueError`` if an image has no non-zero voxels.
    """

    def bounding_box(image_array):
        # Inclusive (min_x, min_y, min_z, max_x, max_y, max_z) of the non-zero voxels
        non_zero_indices = np.nonzero(image_array)

        lows = [np.min(axis_indices) for axis_indices in non_zero_indices[:3]]
        highs = [np.max(axis_indices) for axis_indices in non_zero_indices[:3]]

        return (*lows, *highs)

    # Nothing to measure or crop
    if not subject_list:
        return

    max_bbox = (np.inf, np.inf, np.inf, -np.inf, -np.inf, -np.inf)

    # Find maximum bbox
    for subject in subject_list:

        image_path = os.path.join(subject.path, relative_path)

        image_array = nibabel.load(image_path).get_fdata()

        if not np.any(image_array):
            raise ValueError(f"{image_path} contains no non-zero voxels; cannot compute a bounding box")

        min_x, min_y, min_z, max_x, max_y, max_z = bounding_box(image_array)

        # Update max_bbox
        max_bbox = (
            min(max_bbox[0], min_x),
            min(max_bbox[1], min_y),
            min(max_bbox[2], min_z),
            max(max_bbox[3], max_x),
            max(max_bbox[4], max_y),
            max(max_bbox[5], max_z)
        )

    # The box maxima are inclusive indices while a slice end is exclusive,
    # hence the + 1 (without it the last voxel plane on each axis is lost)
    crop = tuple(
        slice(int(low), int(high) + 1)
        for low, high in zip(max_bbox[:3], max_bbox[3:])
    )

    fname = os.path.splitext(os.path.basename(relative_path))[0]

    for subject in subject_list:

        image = nibabel.load(os.path.join(subject.path, relative_path))

        # Crop the image array using the global bounding box
        cropped_array = image.get_fdata()[crop]

        # Create a new NiBabel image from the cropped array
        cropped_image = nibabel.Nifti1Image(cropped_array, image.affine)

        display_image(cropped_image)

        cropped_path = os.path.join(subject.path, (fname + '_cropped.nii'))

        nibabel.save(cropped_image, cropped_path)

        if is_full_brain:

            subject.brain_aligned_cropped = cropped_path

        else:

            subject.aux_file_list.add(cropped_path)

    return


### Region extraction

In [None]:
# Extracts brain regions using their number label found from freesurfer LUT
# Takes regions as a list of label numbers
def extract_region(subject, values_list, brain, aparc, is_aligned, lut_path=None):
    """Mask ``brain`` down to the parcellation labels in ``values_list`` and save the result.

    Parameters:
        subject: object with ``path`` and a set-like ``aux_file_list``.
        values_list: label numbers to keep, as they appear in ``aparc``.
        brain: path of the intensity image.
        aparc: path of the parcellation image covering the same voxel grid.
        is_aligned: when true the output file name gets an ``_aligned`` suffix.
        lut_path: colour lookup table used to turn labels into names for the
            file name; defaults to the project's FreeSurferColorLUT.txt.

    The masked image is saved in the subject directory, named after the
    matching LUT region names joined by ``_``, and its path is added to
    ``subject.aux_file_list``.

    Returns the masked array. If no voxel carries a requested label, prints
    an error, saves nothing and returns the all-zero mask instead.
    """
    if lut_path is None:
        lut_path = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/preprocessing/FreeSurferColorLUT.txt"

    aparc_array = nibabel.load(aparc).get_fdata()

    image_array = nibabel.load(brain).get_fdata()

    # Create a mask from regions in list
    filtered_array = np.where(np.isin(aparc_array, values_list), 1, 0)

    # Check for empty array
    if np.all(filtered_array == 0):

        print("Error: region empty")

        return filtered_array

    # Extract region using mask
    extracted_region = image_array * filtered_array

    # Look up the name of the region for the filename
    # (raw string: '\s' in a normal string literal is an invalid escape sequence)
    lut = pd.read_csv(lut_path, delimiter=r'\s+', comment='#', header=None)

    region_names = lut[lut[0].isin(values_list)][1]

    suffix = '_aligned.nii' if is_aligned else '.nii'

    region_image_path = os.path.join(subject.path, ('_'.join(region_names) + suffix))

    # Save the regions as a nii file
    region_image = nibabel.Nifti1Image(extracted_region, np.eye(4))

    nibabel.save(region_image, region_image_path)

    # aux_file_list is a set (see Subject), so add() rather than append()
    subject.aux_file_list.add(region_image_path)

    return extracted_region



## Main code

In [9]:
#print(subject_list[0].aux_file_list)

In [None]:
# Disabled cleanup snippet, kept as a string literal so that it does not run:
# it would list and then delete every file named 'aligned_brain.nii' under data_path.
'''print(list_files_fname(data_path, 'aligned_brain.nii'))
    
delete_file_matching(data_path, 'aligned_brain.nii')'''

In [None]:

    
# Disabled cleanup snippet, kept as a string literal so that it does not run:
# it would list and then delete every file named 'affine_alignment.pkl' under data_path.
'''pprint.pprint(list_files_fname(data_path, 'affine_alignment.pkl'))
    
delete_file_matching(data_path, 'affine_alignment.pkl')'''

In [None]:
# Disabled cleanup snippet, kept as a string literal so that it does not run:
# it would list and then delete every file named 'aligned_brain_cropped.nii' under data_path.
'''print(list_files_fname(data_path, 'aligned_brain_cropped.nii'))
    
delete_file_matching(data_path, 'aligned_brain_cropped.nii')'''

In [24]:
# Root directory of the dataset. Every sub-directory that passes the checks
# in find_subjects (an mri/ folder with orig_nu.mgz and mask.mgz) becomes a Subject.
data_path = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/scratch disk/full-datasets/adni1-complete-1yr-3t"

subject_list = find_subjects(data_path)


In [None]:
# Affine-align every subject's brain to the reference brain (one worker process per job)
alignment_parallel(subject_list)

In [None]:
# Crop the aligned brains to the bounding box shared by the whole dataset
crop_subjects(subject_list, 'brain_aligned.nii', True)

In [None]:
# NB affine aligned aparcs may be inaccurate
# Applies each subject's stored affine transform to its aparc file
aux_alignment_parallel(subject_list)

In [None]:
# Extract hippocampus from non-aligned brains (serial)
# 17 and 53 are the parcellation labels passed to extract_region; the crop
# step that follows expects the output 'Left-Hippocampus_Right-Hippocampus.nii'
for subject in subject_list:
    
    extract_region(subject, [17, 53], subject.orig_nu, subject.aparc, False)
    
    # Progress marker, one per subject
    print('.')

In [None]:

# Crop hippocampi
# is_full_brain=False: the cropped paths are added to each subject's aux_file_list
crop_subjects(subject_list, 'Left-Hippocampus_Right-Hippocampus.nii', False)



In [23]:

# Extract hippocampus from aligned brains (serial)
# Needs subject.brain_aligned and subject.aparc_aligned to be set, i.e. both
# alignment steps must have produced their files for every subject
for subject in subject_list:
    
    extract_region(subject, [17, 53], subject.brain_aligned, subject.aparc_aligned, True)
    
    # Progress marker, one per subject
    print('.')


AttributeError: 'Subject' object has no attribute 'aparc_aligned'

In [None]:

# Crop hippocampi (aligned)
# is_full_brain must be passed explicitly; False because the hippocampus images
# are auxiliary files, so their cropped paths go into each subject's aux_file_list
crop_subjects(subject_list, 'Left-Hippocampus_Right-Hippocampus_aligned.nii', False)


In [None]:
# Disabled inspection snippet, kept as a string literal so that it does not run:
# for the first subject only (note the break) it would print the parsed XML
# subject record and show the aligned, cropped brain.
'''

for subject in subject_list:
    
    research_group = subject.xml_df['idaxs']['project']['subject']['researchGroup']
    
    print(subject.xml_df['idaxs']['project']['subject'])
    
    #print(research_group)
    
    display_image(nibabel.load(subject.brain_aligned_cropped))
    
    break '''
