In [1]:
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)


In [2]:
def preprocess_nifti(nifti_path, target_shape=(128, 128, 128)):
    # Normalize intensity to [0,1]
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
    # Resize to target shape: 
    img_resized = scipy.ndimage.zoom(img, np.array(target_shape) / np.array(img.shape), order=1)
    return img_resized

In [3]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [4]:
# TODO: Implement a simple function which returns the subject's image files in nib format based on subject id and optional date.
# Use the dip_project/adni_subject_file_ma.json to search for the file(s) or, os paths.
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths
    # nib_images = []
    # for file in files:
    #     nib_image = get_nib_image(f"{dir_}/{file}")
    #     nib_images.append(nib_image)
    # return nib_images

In [5]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [6]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def preprocess_nifti(self, nifti_path):
        img = nib.load(nifti_path).get_fdata()
        
        # Normalize intensity to [0,1]
        img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        
        # Resize to target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
        
        return torch.tensor(img_resized, dtype=torch.float32).unsqueeze(0)  # Add channel dim

    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

In [7]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [8]:
len(image_paths)

2294

In [9]:
len(labels)

2294

In [10]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.3, random_state=42, stratify=labels)
val_paths, test_paths, val_labels, test_labels = train_test_split(test_paths, test_labels, test_size=0.5, random_state=42, stratify=test_labels)

In [11]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
val_dataset = NiftiDataset(val_paths, val_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Val Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Val Batches: 86, Test Batches: 87


In [12]:
# Define the ResNet-based classifier
class ResNet3DClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ResNet3DClassifier, self).__init__()
        self.resnet = resnet18(spatial_dims=3, n_input_channels=1, num_classes=num_classes)

    def forward(self, x):
        return self.resnet(x)

# Instantiate model
num_classes = 3
model = ResNet3DClassifier(num_classes)

In [13]:
import torch.optim as optim

# Compute class weights
classes = np.unique(train_labels)
class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

# Define loss function & optimizer
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights))
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet3DClassifier(
  (resnet): ResNet(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, e

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm  # Progress bar
import os

# Ensure model directory exists
os.makedirs("models", exist_ok=True)

def train_model(model, train_loader, val_loader, num_epochs=10, accumulation_steps=2, device="cuda"):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    scaler = torch.cuda.amp.GradScaler()

    best_val_loss = float("inf")
    best_model_path = "models/best_model.pth"

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0
        correct, total = 0, 0

        # Use tqdm for progress bar
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        optimizer.zero_grad()

        for i, (images, labels) in progress_bar:
            images, labels = images.to(device), labels.to(device)

            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps  # Divide loss for accumulation

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item() * accumulation_steps  # Undo division for correct loss tracking
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Update progress bar
            progress_bar.set_postfix(loss=loss.item(), accuracy=100 * correct / total)

        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total

        # **Validation Phase**
        model.eval()  # Set model to evaluation mode
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        # Adjust LR based on **validation loss**
        scheduler.step(val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  🔹 Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
        print(f"  🔹 Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"✅ Best Model Saved! (Val Loss: {best_val_loss:.4f})")

# Train the model with validation and best model saving
train_model(model, train_loader, val_loader, num_epochs=20, device="cuda:1")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Epoch 1/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [23:52<00:00,  3.56s/it, accuracy=45.2, loss=0.368]


Epoch [1/20]
  🔹 Train Loss: 1.0828, Train Accuracy: 45.23%
  🔹 Val Loss: 1.3045, Val Accuracy: 48.26%
✅ Best Model Saved! (Val Loss: 1.3045)


Epoch 2/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [24:02<00:00,  3.59s/it, accuracy=47.2, loss=0.683]


Epoch [2/20]
  🔹 Train Loss: 1.0500, Train Accuracy: 47.23%
  🔹 Val Loss: 1.1458, Val Accuracy: 40.41%
✅ Best Model Saved! (Val Loss: 1.1458)


Epoch 3/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [18:05<00:00,  2.70s/it, accuracy=48.5, loss=0.558]


Epoch [3/20]
  🔹 Train Loss: 1.0374, Train Accuracy: 48.54%
  🔹 Val Loss: 1.0688, Val Accuracy: 48.55%
✅ Best Model Saved! (Val Loss: 1.0688)


Epoch 4/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [17:20<00:00,  2.59s/it, accuracy=48.8, loss=0.233]


Epoch [4/20]
  🔹 Train Loss: 1.0155, Train Accuracy: 48.79%
  🔹 Val Loss: 1.1436, Val Accuracy: 48.26%


Epoch 5/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [16:04<00:00,  2.40s/it, accuracy=49.2, loss=0.813]


Epoch [5/20]
  🔹 Train Loss: 1.0082, Train Accuracy: 49.22%
  🔹 Val Loss: 1.1440, Val Accuracy: 38.37%


Epoch 6/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [18:59<00:00,  2.83s/it, accuracy=51.7, loss=0.623]


Epoch [6/20]
  🔹 Train Loss: 0.9735, Train Accuracy: 51.71%
  🔹 Val Loss: 1.1430, Val Accuracy: 45.64%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████████████████████████| 402/402 [23:07<00:00,  3.45s/it, accuracy=54, loss=0.558]


Epoch [7/20]
  🔹 Train Loss: 0.9224, Train Accuracy: 54.02%
  🔹 Val Loss: 0.9170, Val Accuracy: 52.62%
✅ Best Model Saved! (Val Loss: 0.9170)


Epoch 8/20: 100%|██████████████████████████████████████████████████████████████████████████████| 402/402 [14:50<00:00,  2.21s/it, accuracy=57.4, loss=0.711]


Epoch [8/20]
  🔹 Train Loss: 0.8809, Train Accuracy: 57.38%
  🔹 Val Loss: 1.0591, Val Accuracy: 52.03%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████████████████████████| 402/402 [13:38<00:00,  2.04s/it, accuracy=60.2, loss=0.4]


Epoch [9/20]
  🔹 Train Loss: 0.8375, Train Accuracy: 60.19%
  🔹 Val Loss: 0.8038, Val Accuracy: 58.14%
✅ Best Model Saved! (Val Loss: 0.8038)


Epoch 10/20: 100%|███████████████████████████████████████████████████████████████████████████████| 402/402 [17:44<00:00,  2.65s/it, accuracy=66, loss=0.825]


Epoch [10/20]
  🔹 Train Loss: 0.7611, Train Accuracy: 65.98%
  🔹 Val Loss: 0.7650, Val Accuracy: 63.37%
✅ Best Model Saved! (Val Loss: 0.7650)


Epoch 11/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [14:40<00:00,  2.19s/it, accuracy=70.7, loss=0.718]


Epoch [11/20]
  🔹 Train Loss: 0.6658, Train Accuracy: 70.65%
  🔹 Val Loss: 0.8225, Val Accuracy: 65.12%


Epoch 12/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [14:33<00:00,  2.17s/it, accuracy=77.6, loss=0.579]


Epoch [12/20]
  🔹 Train Loss: 0.5138, Train Accuracy: 77.63%
  🔹 Val Loss: 0.8522, Val Accuracy: 65.41%


Epoch 13/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [16:01<00:00,  2.39s/it, accuracy=84.8, loss=0.453]


Epoch [13/20]
  🔹 Train Loss: 0.3976, Train Accuracy: 84.80%
  🔹 Val Loss: 0.8889, Val Accuracy: 64.83%


Epoch 14/20: 100%|████████████████████████████████████████████████████████████████████████████████| 402/402 [18:56<00:00,  2.83s/it, accuracy=91, loss=0.14]


Epoch [14/20]
  🔹 Train Loss: 0.2677, Train Accuracy: 90.97%
  🔹 Val Loss: 0.7400, Val Accuracy: 72.67%
✅ Best Model Saved! (Val Loss: 0.7400)


Epoch 15/20: 100%|███████████████████████████████████████████████████████████████████████████████| 402/402 [14:16<00:00,  2.13s/it, accuracy=94.9, loss=1.3]


Epoch [15/20]
  🔹 Train Loss: 0.1560, Train Accuracy: 94.89%
  🔹 Val Loss: 0.7943, Val Accuracy: 72.67%


Epoch 16/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [16:02<00:00,  2.39s/it, accuracy=96.7, loss=0.413]


Epoch [16/20]
  🔹 Train Loss: 0.1225, Train Accuracy: 96.70%
  🔹 Val Loss: 0.6733, Val Accuracy: 76.74%
✅ Best Model Saved! (Val Loss: 0.6733)


Epoch 17/20: 100%|███████████████████████████████████████████████████████████████████████████████| 402/402 [15:10<00:00,  2.27s/it, accuracy=97, loss=0.569]


Epoch [17/20]
  🔹 Train Loss: 0.1068, Train Accuracy: 97.01%
  🔹 Val Loss: 0.6760, Val Accuracy: 75.29%


Epoch 18/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [14:25<00:00,  2.15s/it, accuracy=98.4, loss=0.128]


Epoch [18/20]
  🔹 Train Loss: 0.0754, Train Accuracy: 98.38%
  🔹 Val Loss: 1.1639, Val Accuracy: 63.95%


Epoch 19/20: 100%|████████████████████████████████████████████████████████████████████████████| 402/402 [15:16<00:00,  2.28s/it, accuracy=98.6, loss=0.0569]


Epoch [19/20]
  🔹 Train Loss: 0.0704, Train Accuracy: 98.63%
  🔹 Val Loss: 0.7744, Val Accuracy: 75.29%


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████████████| 402/402 [15:13<00:00,  2.27s/it, accuracy=98.1, loss=0.368]


Epoch [20/20]
  🔹 Train Loss: 0.0812, Train Accuracy: 98.13%
  🔹 Val Loss: 0.7253, Val Accuracy: 76.45%


In [16]:
from sklearn.metrics import classification_report

def evaluate_model(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to("cuda:1"), labels.to("cuda:1")
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Get class with highest probability
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Generate classification report
    report = classification_report(all_labels, all_preds, digits=4)
    print("\n🔹 Classification Report:\n")
    print(report)

# Evaluate model
evaluate_model(model, test_loader)


🔹 Classification Report:

              precision    recall  f1-score   support

           0     0.8857    0.5849    0.7045       106
           1     0.6697    0.8743    0.7584       167
           2     0.6842    0.5417    0.6047        72

    accuracy                         0.7159       345
   macro avg     0.7465    0.6669    0.6892       345
weighted avg     0.7391    0.7159    0.7098       345

