In [25]:
%matplotlib inline
import os
import sys
import re
import glob

import pandas as pd
import numpy as np
import torch
import torch.utils.data
import torch.nn

from random import randrange
from PIL import Image
import matplotlib.pyplot as plt

### Pre-processing

In [45]:
import skimage.filters
import skimage.morphology
import cv2 as cv
import tempfile
import shutil
import os
import fnmatch


""" Pre-processing Functions """

DEMENTIA_MAP = {
    '0.0': "nondemented",
    '0.5': "mildly demented",
    '1.0': 'moderately demented',
    '2.0': 'severely demented'
}

# Pre-determined max dimensions of cropped images
CONV_WIDTH = 137
CONV_HEIGHT = 167

def normalize_intensity(img):
    """
    Normalizes the intensity of an image to the range [0, 255].

    Parameters:
    img: The image to be normalized.

    Returns:
    Normalized image.
    """
    img_min = img.min()
    img_max = img.max()
    normalized_img = (img - img_min) / (img_max - img_min) * 255
    return normalized_img.astype(np.uint8)

def pad_image_to_size(img, width, height):
    """
    Pads an image with zeros to the specified width and height.

    Parameters:
    img: The image to be padded.
    width: The desired width.
    height: The desired height.

    Returns:
    Padded image.
    """
    if img.shape[0] > height or img.shape[1] > width:
        scaling_factor = min(width / img.shape[1], height / img.shape[0])
        img = cv.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv.INTER_AREA)

    padded_img = np.zeros((height, width), dtype=img.dtype)
    y_offset = (height - img.shape[0]) // 2
    x_offset = (width - img.shape[1]) // 2
    padded_img[y_offset:y_offset+img.shape[0], x_offset:x_offset+img.shape[1]] = img
    return padded_img

def crop_black_boundary(mri_image):
    """
    Crops the black boundary from an MRI image.

    Parameters:
    mri_image: Input MRI image.

    Returns:
    Cropped MRI image with black boundaries removed.
    """
    _, thresh = cv.threshold(mri_image, 1, 255, cv.THRESH_BINARY)
    contours, _ = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv.contourArea)
    x, y, w, h = cv.boundingRect(largest_contour)
    cropped_image = mri_image[y:y+h, x:x+w]
    return cropped_image

def extract_files(base_dir, target_dir, oasis_csv_path):
    oasis_df = pd.read_csv(oasis_csv_path)

    for subdir in filter(lambda d: d != '.DS_Store', os.listdir(base_dir)):
        source_dir = os.path.join(base_dir, subdir, "PROCESSED", "MPRAGE", "T88_111")
        identifier = subdir
        if identifier not in oasis_df['ID'].values:
            continue
        row = oasis_df[oasis_df['ID'] == identifier]
        dementia_type = row['CDR'].iloc[0]
        if pd.isna(dementia_type):
            continue

        for n_suffix in ['n3', 'n4', 'n5']:
            for img_type in ['tra', 'cor', 'sag']:
                pattern = f"{subdir}_mpr_{n_suffix}_anon_111_t88_gfc_{img_type}_*.gif"
                for file in os.listdir(source_dir):
                    if fnmatch.fnmatch(file, pattern):
                        fn = os.path.join(source_dir, file)
                        process_image(fn, target_dir, dementia_type, identifier, img_type)
                        break


def process_image(fn, target_dir, dementia_type, id):
    """
    Processes a single MRI image file and saves it to the target directory.

    Parameters:
    fn: Path of the file to be processed.
    target_dir: Directory where the processed file will be saved.
    dementia_type: Type of dementia associated with the image.
    id: Patient identifier associated with the image.
    """
    with Image.open(fn) as img:
        img = np.array(img.convert('RGB'))
        img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    img = crop_black_boundary(img)
    img = normalize_intensity(img)
    img = pad_image_to_size(img, CONV_WIDTH, CONV_HEIGHT)

    target_subdir = os.path.join(target_dir, DEMENTIA_MAP[str(dementia_type)])
    os.makedirs(target_subdir, exist_ok=True)
    target_path = os.path.join(target_subdir, f"{id}.png")
    cv.imwrite(target_path, img)

def process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path):
    """
    Processes all discs found in the base directory.

    Parameters:
    base_disc_path: Base path where the discs are located.
    base_extraction_path: Base path where processed data will be saved.
    oasis_csv_path: Path to the OASIS CSV file.
    """
    total_disks = 12

    for i in range(1, total_disks + 1):
        disc_path = os.path.join(base_disc_path, f'disc{i}')
        print(f"Processing Disc {i} at path: {disc_path}")

        if not os.path.exists(disc_path):
            print(f"Disc {i} does not exist at path {disc_path}. Skipping.")
            continue

        extract_files(disc_path, base_extraction_path, oasis_csv_path)
        print(f"Finished processing Disc {i}")

        # Cleanup: delete the folder after processing
        # cleanup_directory(disc_path)

def cleanup_directory(path):
    """
    Deletes a directory and all of its contents.

    Parameters:
    path: Path of the directory to be deleted.
    """
    try:
        shutil.rmtree(path)
        print(f"Cleaned up and deleted the directory: {path}")
    except OSError as e:
        print(f"Error: {e.filename} - {e.strerror}")

In [46]:
# Example usage:
base_path = '/Users/msturman00/Documents/GitHub/alzheimer-classification/'
base_disc_path = base_path+'data2'
base_extraction_path = './data'
oasis_csv_path = base_path + './datacsv/oasis_cross-sectional.csv'
process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path)

Processing Disc 1 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc1
Finished processing Disc 1
Processing Disc 2 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc2
Finished processing Disc 2
Processing Disc 3 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc3
Finished processing Disc 3
Processing Disc 4 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc4
Finished processing Disc 4
Processing Disc 5 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc5
Finished processing Disc 5
Processing Disc 6 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc6
Finished processing Disc 6
Processing Disc 7 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc7
Finished processing Disc 7
Processing Disc 8 at path: /Users/msturman00/Documents/GitHub/alzheimer-classification/data2/disc8
Finished processing

In [49]:
# Uniqueify the identifiers. 

import os

def count_images(data_dir):
    class_counts = {}
    for class_name in os.listdir(data_dir):
        class_dir = os.path.join(data_dir, class_name)
        if os.path.isdir(class_dir):
            count = len([name for name in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, name))])
            class_counts[class_name] = count
    return class_counts

# Example usage
data_dir = './data/'  # Replace with your actual data directory path
counts = count_images(data_dir)
for class_name, count in counts.items():
    print(f"{class_name}: {count} images")

# This will output the count of images in each subdirectory of 'notebooks/data'.


moderately demented: 28 images
mildly demented: 70 images
severely demented: 2 images
nondemented: 134 images
