# Classification of AD vs MCI vs NC Using a MONAI 3D DenseNet121 Late-Fusion Model (MRI + PET)

## Import the packages

In [None]:
# Install system packages and the Python requirements via IPython shell escapes
# (apt commands assume a Debian/Ubuntu host with sudo access).
! sudo apt-get update
! pip install -r requirements.txt
# libgl1 provides the system OpenGL library; presumably required by an
# image-processing dependency in requirements.txt — TODO confirm which one.
! sudo apt install libgl1 -y

# Optional: install the CUDA 11.8 build of PyTorch instead (currently disabled).
#! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118



In [None]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch import nn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torchinfo import summary
import shutil
import data_manager as DM
# import torchvision.models as models
import torchvision.models.video as models
from torchvision.models.video import R3D_18_Weights
import data_setup, engine
from helper_functions import plot_loss_curves
from data_setup import create_dataloaders
import engine
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import DataLoader
from helper_functions import plot_loss_curves, plot_roc_auc
import torchio as tio
from torchvision.models.video import R3D_18_Weights
from torch.optim.lr_scheduler import ReduceLROnPlateau
from monai.networks.nets import DenseNet121
from torch.optim.lr_scheduler import StepLR
from monai.transforms import Compose, Resize, ScaleIntensity, NormalizeIntensity, RandFlip
import random
from monai.bundle import download, load
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

: 

## Add the parent path to `sys.path` because the data is there

In [None]:
# Make the parent directory importable. The membership check keeps sys.path
# free of duplicate entries when this notebook cell is re-run.
current_path = os.getcwd()
parrent_path = os.path.abspath(os.path.join(current_path, '..'))
if parrent_path not in sys.path:
    sys.path.append(parrent_path)

## Manage data:

✔ Read subject IDs from each sheet in Subject_list.xlsx.

✔ Create Data/AD, Data/MCI, Data/NC folders.

✔ Find the std_T1 (and SUV) image for each subject inside ADNI/{subject_id}/.

✔ Copy and rename each file to Data/{category}/{subject_id}.nii.

In [None]:
# Rebuild the local data tree from scratch: delete any previous copy of
# destination_root, then copy the std_T1 and SUV images for every subject
# listed in the Excel workbook into per-category folders via DM.copy_data.
categories = ["AD", "MCI", "NC"]
excel_file = "../Subject_list.xlsx"
source_root = "ADNI"
destination_root = "Data"
# Only remove the directory when it actually exists; shutil.rmtree is portable
# (no shell) and honours destination_root instead of a hard-coded name.
if os.path.isdir(destination_root):
    shutil.rmtree(destination_root)
source_dir = "../"
for image_type in ("std_T1", "SUV"):
    DM.copy_data(image_type, excel_file, source_root, source_dir, destination_root, categories)


How many subjects do we have in each group?

In [None]:
data_root = "Data"

# Report how many files each diagnostic group has in each split
# (train first, then test, for every category).
for category in categories:
    for split in ("train", "test"):
        n_files = len(os.listdir(os.path.join(data_root, split, category)))
        print(f"{category} {split}: {n_files} files")


## Classification Model

### MONAI DenseNet121 (late fusion of MRI and PET)

In [None]:
img_size = 64

# One single-channel 3-D DenseNet121 per modality; each branch maps its volume
# to a 1024-dimensional output that the fusion head treats as a feature vector.
mri_model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1024)  # MRI branch
pet_model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1024)  # PET branch

# Late Fusion Classifier
class LateFusionModel(nn.Module):
    """Late-fusion classifier over two single-modality networks.

    Channel 0 of the input goes through ``mri_model`` and channel 1 through
    ``pet_model``; their outputs are concatenated along dim 1 and passed to a
    small ReLU/Dropout/Linear head that produces ``num_classes`` logits.

    Args:
        mri_model: module applied to the ``[B, 1, ...]`` MRI channel.
        pet_model: module applied to the ``[B, 1, ...]`` PET channel.
        num_classes: number of output logits.

    The head's first Linear layer expects the two branch outputs to total
    2048 features (e.g. 1024 + 1024).
    """

    def __init__(self, mri_model, pet_model, num_classes):
        super().__init__()
        self.mri_model = mri_model
        self.pet_model = pet_model
        # The layer order fixes the state_dict keys (classifier.2.*,
        # classifier.5.*), so keep it stable for saved checkpoints.
        head_layers = [
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.6),  # heavy dropout for regularisation
            nn.Linear(in_features=2048, out_features=128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.6),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a 2-channel ``[B, 2, D, H, W]`` batch."""
        mri_volume = x[:, :1]    # [B, 1, D, H, W]
        pet_volume = x[:, 1:2]   # [B, 1, D, H, W]
        mri_features = self.mri_model(mri_volume)
        pet_features = self.pet_model(pet_volume)
        fused = torch.cat((mri_features, pet_features), dim=1)
        return self.classifier(fused)

# Build the two-branch late-fusion network for the three diagnostic classes.
Monai3d = LateFusionModel(mri_model, pet_model, num_classes=3)

# Layer-by-layer summary for one 2-channel (MRI, PET) img_size^3 volume;
# input_size is (batch_size, channels, D, H, W).
summary(
    Monai3d,
    input_size=(1, 2, img_size, img_size, img_size),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

## Data loader: Prepare the data for model training and testing

In [None]:
# Global data-loading settings.
batch_size = 64
random.seed(20)  # seeds Python's `random`, which random_intensity_adjust draws from

# Deterministic pipeline: resample each volume to an img_size^3 cube, then
# min-max scale intensities into [0, 1) (the 1e-5 avoids dividing by zero on
# constant volumes).
transform = Compose([
    Resize(spatial_size=(img_size,) * 3, mode='trilinear'),
    lambda vol: (vol - vol.min()) / (vol.max() - vol.min() + 1e-5),
])

# Your custom intensity adjustment function
def random_intensity_adjust(img):
    """Return *img* multiplied by a gain drawn uniformly from [0.8, 1.2].

    The gain comes from Python's global ``random`` generator, so results are
    reproducible after ``random.seed``.
    """
    return img * random.uniform(0.8, 1.2)

# Training-time pipeline: same resampling and min-max normalisation as the base
# transform, plus geometric and intensity perturbations.
augmentation_transform = Compose([
    Resize(spatial_size=(img_size,) * 3, mode='trilinear'),
    # transforms.RandomHorizontalFlip(p=1),
    # NOTE(review): p=1 flips *every* sample, so this is a deterministic mirror
    # rather than a random flip — confirm that is intended.
    transforms.RandomVerticalFlip(p=1),
    # torchvision treats tensors as [..., H, W], so this rotates in the last two
    # axes by a random angle in [-10, 10] degrees.
    transforms.RandomRotation(degrees=10),
    # transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random translation
    lambda vol: (vol - vol.min()) / (vol.max() - vol.min() + 1e-5),  # min-max normalise
    random_intensity_adjust,  # random gain in [0.8, 1.2]
])


# Split folders produced by the data-management step.
train_data_path, test_data_path = (
    os.path.join(data_root, split) for split in ("train", "test")
)

# Build the train/test DataLoaders; the helper also returns the class names.
loader_kwargs = dict(
    train_dir=train_data_path,
    test_dir=test_data_path,
    transform=transform,
    batch_size=batch_size,
    augmentation_transform=augmentation_transform,
)
train_dataloader, test_dataloader, class_names = create_dataloaders(**loader_kwargs)

# Summarise what was loaded.
print(' ')
print(f"Class names: {class_names}")
print(f"Number of classes: {len(class_names)}")
print(' ')
print("Number of training samples:", len(train_dataloader.dataset))
print("Number of testing samples:", len(test_dataloader.dataset))


# Visualize samples: middle slice of the PET channel (channel 1) for up to
# 16 distinct volumes from one training batch, titled with their labels.
image_batch, label_batch = next(iter(train_dataloader))
print(image_batch.shape, label_batch.shape)

num_images = 16
# Separate name so the configured global `batch_size` (used to build the
# DataLoaders) is not overwritten by the size of this single batch.
n_in_batch = image_batch.shape[0]
# Sample without replacement so no volume is shown twice; a batch smaller than
# num_images just leaves some grid cells empty.
random_indices = torch.randperm(n_in_batch)[:num_images]
mid_slice_idx = image_batch.shape[4] // 2  # middle index of the last spatial axis

fig, axes = plt.subplots(4, 4, figsize=(10, 6))
for ax in axes.flat:
    ax.axis('off')
for ax, idx in zip(axes.flat, random_indices.tolist()):
    img = image_batch[idx, 1, :, :, mid_slice_idx].detach().cpu().numpy()  # PET slice
    ax.imshow(img, cmap='gray')
    ax.set_title(f"Label: {class_names[label_batch[idx].item()]}")

plt.tight_layout()
plt.show()


## TRAIN AND TEST

In [None]:
# Check the device (and release cached GPU memory left over from earlier cells)
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Working on device: {device}")

# Model: move the late-fusion network to the device *before* constructing the
# optimizer, as the torch.optim documentation recommends. Module.to returns the
# same module, so `model` and `Monai3d` remain the same object.
model = Monai3d.to(device)

# Optimizer, loss and LR schedule (LR x0.6 every 5 scheduler.step() calls)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# loss_fn = torch.nn.NLLLoss()
loss_fn = torch.nn.CrossEntropyLoss()  # takes raw logits; the classifier ends in a Linear layer
scheduler = StepLR(optimizer, step_size=5, gamma=0.6)

# Seed torch's RNGs for the training run (dropout and any randomness inside
# engine.train). The whole late-fusion model is trained; its weights were
# already initialised when it was built, so this seed does not affect them.
torch.manual_seed(21)
torch.cuda.manual_seed(21)

model_results = engine.train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=30,
    device=device,
    scheduler=scheduler,
)

## PLOT

In [None]:
# Plot the training curves recorded in model_results by engine.train.
plot_loss_curves(model_results)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from itertools import cycle

# One-vs-rest ROC analysis of the trained model on the held-out test split.
# Scoring the training loader instead would measure fit to the training data
# and report optimistic AUCs.

# Step 1: Collect predicted class probabilities and ground-truth labels
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_dataloader:
        outputs = model(inputs.to(device))
        probs = torch.softmax(outputs, dim=1)  # logits -> class probabilities
        all_preds.append(probs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

# Convert to numpy arrays
y_score = np.concatenate(all_preds, axis=0)   # shape: (num_samples, num_classes)
y_true = np.concatenate(all_labels, axis=0)   # shape: (num_samples,)

# Step 2: One-hot encode the labels for per-class ROC computation
n_classes = y_score.shape[1]
y_true_bin = label_binarize(y_true, classes=np.arange(n_classes))
if y_true_bin.shape[1] == 1:
    # label_binarize returns a single column for two classes; expand it to one
    # column per class so y_true_bin[:, i] is valid for every class i.
    y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

# Step 3: ROC curve and AUC for each class
fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Step 4: Micro-average (pool every class decision) and macro-average
# (mean of per-class TPRs interpolated on a shared FPR grid)
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
print(f"Micro-average AUC: {roc_auc['micro']:.3f} | Macro-average AUC: {roc_auc['macro']:.3f}")

# Step 5: Plot the per-class ROC curves
plt.figure(figsize=(10, 8))

colors = cycle(['blue', 'green', 'red', 'orange', 'purple'])
for i, color in zip(range(n_classes), colors):
    # class_names is indexed by label id (as in the sample visualisation cell)
    name = class_names[i] if i < len(class_names) else f'Class {i}'
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f'{name} (AUC = {roc_auc[i]:.2f})')

# Plot micro/macro average ROC
# plt.plot(fpr["micro"], tpr["micro"], linestyle='--', color='deeppink',
#          label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})', linewidth=2)

# plt.plot(fpr["macro"], tpr["macro"], linestyle='--', color='navy',
#          label=f'Macro-average (AUC = {roc_auc["macro"]:.2f})', linewidth=2)

# Diagonal = chance-level classifier
plt.plot([0, 1], [0, 1], 'k--', lw=2)

# Final touches
plt.xlabel('False Positive Rate', fontsize=14, fontweight='bold')
plt.ylabel('True Positive Rate', fontsize=14, fontweight='bold')
plt.title('Multi-class ROC Curve', fontsize=16, fontweight='bold')
plt.legend(loc='lower right', fontsize=12)
plt.xticks(fontsize=12, fontweight='bold')
plt.yticks(fontsize=12, fontweight='bold')
plt.show()



## Save the model

In [None]:
def save_model(model, target_dir, model_name):
    """Write a model's weights (its ``state_dict``) to disk.

    Args:
        model: module whose ``state_dict()`` is saved with ``torch.save``.
        target_dir: output directory; created, with parents, if missing.
        model_name: file name inside ``target_dir`` (e.g. ``"net.pth"``).
    """
    os.makedirs(target_dir, exist_ok=True)
    checkpoint_path = os.path.join(target_dir, model_name)
    weights = model.state_dict()
    torch.save(weights, checkpoint_path)
    print(f"Model saved to {checkpoint_path}")

# Persist the trained weights to models/MONAI3D_AD_MCI_NC.pth.
save_model(model=model,
            target_dir="models",
            model_name="MONAI3D_AD_MCI_NC.pth")

## Load the model and evaluate it on the validation set

In [None]:
import torch
from torch.utils.data import DataLoader
from data_setup import MultiModalSeparateNiiDataset

# Reload the checkpoint written by save_model and evaluate it on the held-out
# validation split (Data/valid).
torch.manual_seed(21)
torch.cuda.manual_seed_all(21)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load("models/MONAI3D_AD_MCI_NC.pth", map_location=device))
model.to(device)
model.eval()

valid_dataset = MultiModalSeparateNiiDataset("Data/valid", categories, transform)
# Sample order does not change aggregate metrics; no shuffling keeps runs reproducible.
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

test_loss, test_acc, metrics, classification_summary = engine.test_step(
    model=model,
    dataloader=valid_loader,
    loss_fn=loss_fn,
    device=device,
)

print("Validation Accuracy:", test_acc)
# Assumes engine.test_step returns metrics as {class_index: sensitivity} — TODO confirm.
for class_index, sensitivity in metrics.items():
    print(f"Class {class_index}: Sensitivity = {sensitivity:.4f}")
