### Classification of AD vs. MCI vs. NC Using a Pre-Trained 3D ResNet-18 (r3d_18)

Import the packages

In [None]:
! sudo apt-get update
! pip install -r requirements.txt

In [None]:
! pip install monai
! pip install torchio

In [4]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch import nn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torchinfo import summary
import shutil
import data_manager as DM
# import torchvision.models as models
import torchvision.models.video as models
from torchvision.models.video import R3D_18_Weights
import data_setup, engine
from helper_functions import plot_loss_curves
from data_setup import create_dataloaders
import engine
from monai.transforms import Resize
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import DataLoader
from helper_functions import plot_loss_curves
import torchio as tio
from torchvision.models.video import R3D_18_Weights
from torch.optim.lr_scheduler import ReduceLROnPlateau
from Update_Git import git_add, git_commit, git_push


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cuda


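Add the parent directory to `sys.path` so that Python modules stored there can be imported. (Note: `sys.path` only affects module imports; data paths such as `Data/` are still resolved relative to the current working directory.)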

In [5]:
# Put the directory one level above the working directory on the module search path.
current_path = os.getcwd()
parrent_path = os.path.dirname(current_path)
sys.path.append(parrent_path)

## Git

In [None]:
# Stage, commit and push this notebook using the project's Update_Git helpers.
# NOTE(review): the saved output shows `git commit` failing with an empty error
# message (possibly because there was nothing new to commit) -- confirm how
# Update_Git reports that case before relying on this cell.
file_path = os.path.join(current_path, 'main.ipynb')
git_add(file_path)
git_commit('Updated dashboard test 1')  # typo fixed: was "dahsbord"
git_push('main')

Error executing command: ['git', 'commit', '-m', 'Updated dahsbord test 1']
Error message: 


''

Manage data:

✔ Read subject IDs from each sheet in Subject list.xlsx.

✔ Create Data/AD, Data/MCI, Data/NC folders.

✔ Find std_T1.nii for each subject inside ADNI/{subject_id}/.

✔ Copy and rename each file to Data/{category}/{subject_id}.nii.

In [None]:
# One-time data preparation: copy each subject's std_T1.nii from ADNI/ into
# Data/<category>/<subject_id>.nii using the subject lists in the Excel file.
# Disabled by default; set BUILD_DATA = True to (re)build the folders.
# NOTE(review): later cells read Data/train/<category> and Data/test/<category>,
# so a train/test split step seems to be needed after copy_data -- confirm.
BUILD_DATA = False

if BUILD_DATA:
    excel_file = "../Subject list.xlsx"
    source_root = "ADNI"
    destination_root = "Data"
    categories = ["AD", "MCI", "NC"]
    DM.copy_data(excel_file, source_root, destination_root, categories)

How many subjects are in each group? (Counts the files in each `Data/{train,test}/{group}` folder.)

In [None]:
data_root = "Data"
categories = ["AD", "MCI", "NC"]


def _count_files(split, category):
    """Return the number of directory entries in data_root/<split>/<category>."""
    return len(os.listdir(os.path.join(data_root, split, category)))


# Report, class by class, how many files sit in the train and test folders.
for c in categories:
    num_train_files = _count_files('train', c)
    print(f"{c} train: {num_train_files} files")
    num_test_files = _count_files('test', c)
    print(f"{c} test: {num_test_files} files")


## Classification Model

In [None]:
# Set random seeds for reproducibility (this also fixes the random
# initialisation of the new layers created below)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_CLASSES = 3  # AD, MCI, NC

# Load the 3D ResNet-18 video model (r3d_18) with its default pretrained weights
resnet3d = models.r3d_18(weights=R3D_18_Weights.DEFAULT).to(device)

# Freeze every pretrained parameter first. The layers replaced below are created
# afterwards, so they keep requires_grad=True and are the only trainable parts.
for param in resnet3d.parameters():
    param.requires_grad = False

# Replace the first convolution so the network accepts 1-channel 3D MRI volumes.
# NOTE: this layer is freshly (randomly) initialised; the pretrained stem
# weights expect 3 input channels and are discarded.
resnet3d.stem[0] = nn.Conv3d(
    in_channels=1,          # single-channel (grayscale) MRI input
    out_channels=64,        # same output channels as the original model's first conv
    kernel_size=(7, 7, 7),  # isotropic 7x7x7 filter
    stride=(2, 2, 2),       # roughly halves every spatial dimension
    padding=(3, 3, 3),      # zero-padding of kernel_size // 2 on each side
    bias=True,              # nn.Conv3d's default is bias=True
    # NOTE(review): torchvision's own r3d_18 stem conv uses bias=False because a
    # BatchNorm follows it -- consider matching.
).to(device)

# Conv output size along each spatial dimension:
#   out = floor((in + 2 * padding - kernel) / stride) + 1
# e.g. depth 79 -> floor((79 + 6 - 7) / 2) + 1 = 40

# Replace the final fully connected layer with a 3-class head (AD, MCI, NC).
# Read the feature width from the existing layer instead of hard-coding 512.
fc_in_features = resnet3d.fc.in_features
resnet3d.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=fc_in_features, out_features=NUM_CLASSES),
).to(device)

# Print model summary
summary(model=resnet3d,
        input_size=(16, 1, 79, 95, 79),  # (batch_size, channels, depth, height, width)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Data loader: Prepare the data for model training and testing

In [None]:
# Define the batch size
batch_size = 32

# Per-channel statistics applied after min-max scaling.
# NOTE(review): these are the ImageNet red-channel values; R3D_18 was pretrained
# on Kinetics-400 video, whose weights use different statistics -- confirm intent.
NORM_MEAN = [0.485]
NORM_STD = [0.229]


def minmax_scale(x):
    """Rescale a tensor to [0, 1] using its own min and max.

    A constant tensor (max == min) would otherwise give 0/0 = NaN; it is
    mapped to all zeros instead.
    """
    lo, hi = x.min(), x.max()
    if hi > lo:
        return (x - lo) / (hi - lo)
    return torch.zeros_like(x)


transform = transforms.Compose([
    transforms.Lambda(minmax_scale),                     # scale to [0, 1]
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Augmented copy of the data, to be added to the main dataset.
# The geometric ops run BEFORE Normalize: RandomRotation fills uncovered corners
# with 0, which after min-max scaling is the background minimum. If the volume
# were already normalised, 0 would instead correspond to intensity 0.485.
augmentation_transform = transforms.Compose([
    transforms.Lambda(minmax_scale),                     # scale to [0, 1]
    transforms.RandomHorizontalFlip(p=1),                # always flipped: every augmented copy is mirrored
    transforms.RandomVerticalFlip(p=0.3),                # flipped with probability 0.3
    transforms.RandomRotation(degrees=15, interpolation=InterpolationMode.NEAREST),  # random rotation in [-15, 15] degrees
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])


# Data loader
train_data_path = os.path.join(data_root, "train")
test_data_path = os.path.join(data_root, "test")

train_dataloader_pretrained, test_dataloader_pretrained, class_names = data_setup.create_dataloaders(
    train_dir=train_data_path,
    test_dir=test_data_path,
    transform=transform,
    batch_size=batch_size,
    augmentation_transform=augmentation_transform
)

print(' ')
print(f"Class names: {class_names}")
print(f"Number of classes: {len(class_names)}")

# len(loader) * batch_size over-counts whenever the last batch is partial, so
# count the samples in the underlying datasets instead (assumes torch DataLoaders
# whose datasets define __len__).
print(' ')
print("Number of training data: ", len(train_dataloader_pretrained.dataset))
print("Number of testing data: ", len(test_dataloader_pretrained.dataset))

image_batch, label_batch = next(iter(train_dataloader_pretrained))
print(image_batch.shape, label_batch.shape)

# Show up to 9 distinct volumes from the batch, one 2-D slice each.
# Assumes batches are (batch, 1, D, H, W), as in the model summary -- confirm.
num_images = 9
n_in_batch = image_batch.shape[0]  # separate name: keep the global batch_size intact
# randperm samples without replacement, so no volume is shown twice.
random_indices = torch.randperm(n_in_batch)[:num_images].tolist()

# Middle index of the LAST spatial axis -- the axis that is sliced below.
mid_slice_idx = image_batch.shape[-1] // 2

# Create 3x3 plot; panels without an image stay blank.
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    ax.axis('off')

for ax, idx in zip(axes.flat, random_indices):
    img = image_batch[idx, 0, :, :, mid_slice_idx].detach().cpu().numpy()  # middle slice
    label = class_names[int(label_batch[idx])]
    ax.imshow(img, cmap='gray')
    ax.set_title(f'Image {idx + 1} ({label})')

plt.tight_layout()
plt.show()


## TRAIN AND TEST

In [None]:
# Check the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Working on device: {device}")

# Training hyperparameters
LEARNING_RATE = 1e-6
WEIGHT_DECAY = 1e-7
NUM_EPOCHS = 20

# Optimise only the parameters that still require gradients: the replaced stem
# conv and the new fc head. Every pretrained parameter was frozen when the model
# was built.
trainable_params = [p for p in resnet3d.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the new stem conv and classifier head of the pretrained 3D ResNet-18
torch.manual_seed(42)
torch.cuda.manual_seed(42)

pretrained_RN_18_results = engine.train(
    model=resnet3d,
    train_dataloader=train_dataloader_pretrained,
    test_dataloader=test_dataloader_pretrained,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
    device=device
)


## PLOT

In [None]:
# Plot the per-epoch curves recorded in the results returned by engine.train
# (presumably train/test loss and accuracy -- see helper_functions.plot_loss_curves).
plot_loss_curves(pretrained_RN_18_results)