In [36]:
%matplotlib inline
import os
import sys
import re
import glob

import pandas as pd
import numpy as np
import torch
import torch.utils.data
import torch.nn

from random import randrange
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
import argparse
""" Training and hyperparameter search configurations """
curr_dir = os.getcwd()

parser = argparse.ArgumentParser(description='Final')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
args = parser.parse_args("")

# Set random seed to reproduce results
np.random.seed(args.seed)

### Download Files

In [42]:
import requests
import os
import tarfile

def download_file(url, local_filename):
    """
    Downloads a file from a given URL and saves it to a local path.
    """
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_filename

def download_oasis1(base_dir="/Users/valenetjong/Downloads/"):
    base_url = "https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc"
    total_disks = 12

    for i in range(1, total_disks + 1):
        url = f"{base_url}{i}.tar.gz"
        local_filename = f"oasis_cross-sectional_disc{i}.tar.gz"
        full_file_path = os.path.join(base_dir, local_filename)

        # Check if the file already exists
        if os.path.exists(full_file_path):
            print(f"File {local_filename} already exists. Skipping download.")
            continue

        print(f"Downloading: {url}")
        
        try:
            download_file(url, full_file_path)
            print(f"Downloaded {local_filename}")
        except Exception as e:
            print(f"Failed to download {local_filename}: {e}")

def extract_tar_gz(tar_path, extract_to_path):
    """
    Extracts a .tar.gz file to a specified directory.
    """
    with tarfile.open(tar_path, 'r:gz') as tar:
        tar.extractall(path=extract_to_path)
        print(f"Extracted {tar_path} to {extract_to_path}")

def extract_all_discs(base_disc_path="/Users/valenetjong/Downloads/", 
                    extract_to_path="/Users/valenetjong/Downloads/"):
    total_disks = 12

    for i in range(1, total_disks + 1):
        if os.path.exists(extract_to_path + f"/disc{i}") and os.path.isdir(extract_to_path + f"/disc{i}"):
            print(f"Folder for disc{i} already exists. Skipping extraction.")
            continue
        tar_path = os.path.join(base_disc_path, f"oasis_cross-sectional_disc{i}.tar.gz")
        os.makedirs(extract_to_path, exist_ok=True)
        extract_tar_gz(tar_path, extract_to_path)

        # Remove the tar.gz file after extraction
        # os.remove(tar_path)
        # print(f"Removed the archive: {tar_path}")

In [38]:
download_oasis1()

Downloading: https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc1.tar.gz
Downloaded oasis_cross-sectional_disc1.tar.gz
Downloading: https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc2.tar.gz
Downloaded oasis_cross-sectional_disc2.tar.gz
Downloading: https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc3.tar.gz
Downloaded oasis_cross-sectional_disc3.tar.gz
Downloading: https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc4.tar.gz
Downloaded oasis_cross-sectional_disc4.tar.gz
File oasis_cross-sectional_disc5.tar.gz already exists. Skipping download.
File oasis_cross-sectional_disc6.tar.gz already exists. Skipping download.
File oasis_cross-sectional_disc7.tar.gz already exists. Skipping download.
File oasis_cross-sectional_disc8.tar.gz already exists. Skipping download.
File oasis_cross-sectional_disc9.tar.gz already exists. Skipping download.
File oasis_cross-sectional_disc10.tar.gz already exists. Skipping download.
File oasis_cross-sectional_di

In [45]:
extract_all_discs()

Folder for disc1 already exists. Skipping extraction.
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc2.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc3.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc4.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc5.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc6.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc7.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc8.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc9.tar.gz to /Users/valenetjong/Downloads/
Extracted /Users/valenetjong/Downloads/oasis_cross-sectional_disc10.tar.gz to /Users/valen

### Pre-processing

In [112]:
import skimage.filters
import skimage.morphology
import cv2 as cv
import tempfile
import shutil

""" Pre-processing Functions """

DEMENTIA_MAP = {
    '0.0': "nondemented",
    '0.5': "mildly demented",
    '1.0': 'moderately demented',
    '2.0': 'severely demented'
}

# Pre-determined max dimensions of cropped images
CONV_WIDTH = 137
CONV_HEIGHT = 167

def normalize_intensity(img):
    """
    Normalizes the intensity of an image to the range [0, 255].

    Parameters:
    img: The image to be normalized.

    Returns:
    Normalized image.
    """
    img_min = img.min()
    img_max = img.max()
    normalized_img = (img - img_min) / (img_max - img_min) * 255
    return normalized_img.astype(np.uint8)

def pad_image_to_size(img, width, height):
    """
    Pads an image with zeros to the specified width and height.

    Parameters:
    img: The image to be padded.
    width: The desired width.
    height: The desired height.

    Returns:
    Padded image.
    """
    padded_img = np.zeros((height, width), dtype=img.dtype)
    y_offset = (height - img.shape[0]) // 2
    x_offset = (width - img.shape[1]) // 2
    padded_img[y_offset:y_offset+img.shape[0], x_offset:x_offset+img.shape[1]] = img
    return padded_img

def crop_black_boundary(mri_image):
    """
    Crops the black boundary from an MRI image.

    Parameters:
    mri_image: Input MRI image.

    Returns:
    Cropped MRI image with black boundaries removed.
    """
    _, thresh = cv.threshold(mri_image, 1, 255, cv.THRESH_BINARY)
    contours, _ = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv.contourArea)
    x, y, w, h = cv.boundingRect(largest_contour)
    cropped_image = mri_image[y:y+h, x:x+w]
    return cropped_image

def extract_files(base_dir, target_dir, oasis_csv_path):
    """
    Extracts and processes MRI files from a given directory.

    Parameters:
    base_dir: Directory containing MRI files.
    target_dir: Directory where processed files will be saved.
    oasis_csv_path: Path to the CSV file containing metadata.
    """
    oasis_df = pd.read_csv(oasis_csv_path)

    for subdir in filter(lambda d: d != '.DS_Store', os.listdir(base_dir)):
        source_dir = os.path.join(base_dir, subdir, "FSL_SEG")
        print("source_dir", source_dir)
        num = subdir.split('_')[1]
        id = f'OAS1_{num}_MR1'
        num = int(num)
        row = oasis_df.loc[oasis_df['ID'] == id]
        dementia_type = row['CDR'].item()
        
        if pd.isna(dementia_type):
            continue

        for n_suffix in ['n3', 'n4']:
            fn = os.path.join(source_dir, f"{subdir}_mpr_{n_suffix}_anon_"
                                  f"111_t88_masked_gfc_fseg_tra_90.gif")
            if os.path.exists(fn):
                process_image(fn, target_dir, dementia_type, id)

def process_image(fn, target_dir, dementia_type, id):
    """
    Processes a single MRI image file and saves it to the target directory.

    Parameters:
    fn: Path of the file to be processed.
    target_dir: Directory where the processed file will be saved.
    dementia_type: Type of dementia associated with the image.
    id: Patient identifier associated with the image.
    """
    with Image.open(fn) as img:
        img = np.array(img.convert('RGB'))
        img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    img = crop_black_boundary(img)
    img = normalize_intensity(img)
    img = pad_image_to_size(img, CONV_WIDTH, CONV_HEIGHT)

    target_subdir = os.path.join(target_dir, DEMENTIA_MAP[str(dementia_type)])
    os.makedirs(target_subdir, exist_ok=True)
    target_path = os.path.join(target_subdir, f"{id}.png")
    cv.imwrite(target_path, img)

def process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path):
    """
    Processes all discs found in the base directory.

    Parameters:
    base_disc_path: Base path where the discs are located.
    base_extraction_path: Base path where processed data will be saved.
    oasis_csv_path: Path to the OASIS CSV file.
    """
    total_disks = 12

    for i in range(1, total_disks + 1):
        disc_path = f'{base_disc_path}/disc{i}'
        if not os.path.exists(disc_path):
            print(f"Disc {i} does not exist at path {disc_path}. Skipping.")
            continue
        extract_files(disc_path, base_extraction_path, oasis_csv_path)
        print(f"Processed Disc {i}")

        # Cleanup: delete the folder after processing
        # cleanup_directory(disc_path)

def cleanup_directory(path):
    """
    Deletes a directory and all of its contents.

    Parameters:
    path: Path of the directory to be deleted.
    """
    try:
        shutil.rmtree(path)
        print(f"Cleaned up and deleted the directory: {path}")
    except OSError as e:
        print(f"Error: {e.filename} - {e.strerror}")

In [113]:
base_disc_path = '/Users/valenetjong/Downloads'
base_extraction_path = '/Users/valenetjong/alzheimer-classification/data'
oasis_csv_path = '/Users/valenetjong/alzheimer-classification/datacsv/oasis_cross-sectional.csv'

process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path)

source_dir /Users/valenetjong/Downloads/disc1/OAS1_0016_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0002_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0003_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0017_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0001_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0015_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0029_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0028_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0014_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0038_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0004_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0010_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0011_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OAS1_0005_MR1/FSL_SEG
source_dir /Users/valenetjong/Downloads/disc1/OA

In [114]:
import os
import torch
from torchvision import transforms
from PIL import Image
from collections import Counter

LABEL_MAP = {
    "nondemented": 0,
    "mildly demented": 1,
    'moderately demented': 2,
    'severely demented' : 3
}

def load_dataset(base_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.RandomRotation(degrees=20),  # increase model robustness
    ])
    
    all_images = []
    all_labels = []
    class_counts = Counter()

    # Automatically find all subdirectories in base_dir
    for folder_name in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, folder_name)
        if os.path.isdir(folder_path):  # Check if it's a directory
            class_label = LABEL_MAP[folder_name]
            for image_file in os.listdir(folder_path):
                image_path = os.path.join(folder_path, image_file)
                if os.path.isfile(image_path):
                    with Image.open(image_path) as img:
                        img_tensor = transform(img)
                        all_images.append(img_tensor)
                        all_labels.append(class_label)
                        class_counts[folder_name] += 1

    X = torch.stack(all_images)
    y = torch.tensor(all_labels, dtype=torch.long)  # Changed to long for integer labels
    return X, y, class_counts

base_dir = '/Users/valenetjong/alzheimer-classification/data'
X, y, class_counts = load_dataset(base_dir)

print(f"Combined Tensor Size: {X.size()}")
print(f"Labels Tensor Size: {y.size()}")
print(f"Class Counts: {class_counts}")

Combined Tensor Size: torch.Size([233, 1, 167, 137])
Labels Tensor Size: torch.Size([233])
Class Counts: Counter({'nondemented': 133, 'mildly demented': 70, 'moderately demented': 28, 'severely demented': 2})


In [124]:
import torch
from sklearn.model_selection import train_test_split

def train_val_split(X, y, test_size=0.2, random_state=42, stratified=False):
    # Convert X and y to numpy arrays if they are torch tensors
    X_np = X.numpy() if isinstance(X, torch.Tensor) else X
    y_np = y.numpy() if isinstance(y, torch.Tensor) else y

    # Stratified split
    if stratified:
        X_train, X_val, y_train, y_val = train_test_split(
            X_np, y_np, test_size=test_size, random_state=random_state, stratify=y_np
        )
    # Random split
    else:
        X_train, X_val, y_train, y_val = train_test_split(
            X_np, y_np, test_size=test_size, random_state=random_state
        )

    # Convert numpy arrays back to torch tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long)

    return X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor

X_train, X_val, y_train, y_val = train_val_split(X, y, test_size=0.2)

print(f'Training set size: {X_train.shape[0]}')
print(f'Validation set size: {X_val.shape[0]}')

Training set size: 186
Validation set size: 47


In [119]:
print(f"Number of nondemented in train dataset as percentage: {((y_train == 0).sum() / (X_train.shape[0])) * 100:0.2f}%")
print(f"Number of mildly demented in train dataset as percentage: {((y_train == 1).sum() / (X_train.shape[0])) * 100:0.2f}%")
print(f"Number of moderately demented in train dataset as percentage: {((y_train == 2).sum() / (X_train.shape[0])) * 100:0.2f}%")
print(f"Number of severely demented in train dataset as percentage: {((y_train == 3).sum() / (X_train.shape[0])) * 100:0.2f}%")

Number of nondemented in train dataset as percentage: 55.38%
Number of mildly demented in train dataset as percentage: 32.80%
Number of moderately demented in train dataset as percentage: 10.75%
Number of severely demented in train dataset as percentage: 1.08%


In [121]:
print(f"Number of nondemented in train dataset as percentage: {((y_val == 0).sum() / (X_val.shape[0])) * 100:0.2f}%")
print(f"Number of mildly demented in train dataset as percentage: {((y_val == 1).sum() / (X_val.shape[0])) * 100:0.2f}%")
print(f"Number of moderately demented in train dataset as percentage: {((y_val == 2).sum() / (X_val.shape[0])) * 100:0.2f}%")
print(f"Number of severely demented in train dataset as percentage: {((y_val == 3).sum() / (X_val.shape[0])) * 100:0.2f}%")

Number of nondemented in train dataset as percentage: 63.83%
Number of mildly demented in train dataset as percentage: 19.15%
Number of moderately demented in train dataset as percentage: 17.02%
Number of severely demented in train dataset as percentage: 0.00%


In [140]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

def apply_transforms(X, transform):
    transformed_data = []
    for x in X:
        x = transform(x)  # Apply the transformation
        transformed_data.append(x)
    return torch.stack(transformed_data)

# Define transformations
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=20),
    transforms.RandomResizedCrop(size=(CONV_HEIGHT, CONV_WIDTH), scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=None, shear=10),
])

# Apply transformations
X_train_transformed = apply_transforms(X_train, train_transform)

### Define CNN Model

In [125]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Convolutional Block 1
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)  
        self.bn1 = nn.BatchNorm2d(8)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 2
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.pool2 = nn.MaxPool2d(kernel_size=3)
        
        # Convolutional Block 3
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 4
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.pool4 = nn.MaxPool2d(kernel_size=3)

        # Compute the flattened size for the fully connected layer
        self._to_linear = None
        self._forward_conv(torch.randn(1, 1, 137, 167))

        # Fully connected layers
        self.fc1 = nn.Linear(self._to_linear, 128)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(128, 4)
        self.dropout2 = nn.Dropout(p=0.2)

    def _forward_conv(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self._forward_conv(x)
        x = x.view(-1, self._to_linear)  # Flatten the output for the fully connected layers
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return F.log_softmax(x, dim=1)

### Handle Disproportionate Classes

In [129]:
import torch
import torch.nn as nn
from collections import Counter

def calculate_class_weights(y_train):
    # Count the frequency of each class
    class_counts = Counter(y_train.numpy())
    total_samples = sum(class_counts.values())

    # Calculate weights: Inverse of frequency
    weights = {class_id: total_samples/class_counts[class_id] for class_id in class_counts}

    # Convert to a list in the order of class ids
    weights_list = [weights[i] for i in sorted(weights)]
    return torch.tensor(weights_list, dtype=torch.float32)

### Training and Validation

In [142]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn.functional as F

model = CNNModel()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
class_weights = calculate_class_weights(y_train)
loss_function = nn.CrossEntropyLoss(weight=class_weights)

# DataLoader
batch_size = 64
train_data = TensorDataset(X_train_transformed, y_train)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_data, batch_size=batch_size)

# Training Loop
def train_model(num_epochs):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            output = model(X_batch)
            loss = loss_function(output, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}')

        # Validation
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for X_batch, y_batch in val_loader:
                output = model(X_batch)
                _, predicted = torch.max(output.data, 1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
                
            print(f'Validation Accuracy: {100 * correct / total}%')

# Run training
num_epochs = 100  # Set the number of epochs
train_model(num_epochs)

Epoch 1/100, Loss: 5.127073049545288
Validation Accuracy: 17.02127659574468%
Epoch 2/100, Loss: 4.706859588623047
Validation Accuracy: 19.148936170212767%
Epoch 3/100, Loss: 3.936876734097799
Validation Accuracy: 63.829787234042556%
Epoch 4/100, Loss: 3.402120272318522
Validation Accuracy: 17.02127659574468%
Epoch 5/100, Loss: 1.7150770823160808
Validation Accuracy: 17.02127659574468%
Epoch 6/100, Loss: 1.3937952915827434
Validation Accuracy: 17.02127659574468%
Epoch 7/100, Loss: 1.4394262234369914
Validation Accuracy: 17.02127659574468%
Epoch 8/100, Loss: 1.3359441757202148
Validation Accuracy: 17.02127659574468%
Epoch 9/100, Loss: 1.3166981140772502
Validation Accuracy: 17.02127659574468%
Epoch 10/100, Loss: 1.3289816776911418
Validation Accuracy: 17.02127659574468%
Epoch 11/100, Loss: 1.2394487460454304
Validation Accuracy: 25.53191489361702%
Epoch 12/100, Loss: 1.240601658821106
Validation Accuracy: 44.680851063829785%
Epoch 13/100, Loss: 1.1370036602020264
Validation Accuracy: 46.