In [1]:
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)


In [2]:
def preprocess_nifti(nifti_path, target_shape=(128, 128, 128)):
    # Normalize intensity to [0,1]
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
    # Resize to target shape: 
    img_resized = scipy.ndimage.zoom(img, np.array(target_shape) / np.array(img.shape), order=1)
    return img_resized

In [3]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [4]:
# TODO: Implement a simple function which returns the subject's image files in nib format based on subject id and optional date.
# Use the dip_project/adni_subject_file_ma.json to search for the file(s) or, os paths.
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths
    # nib_images = []
    # for file in files:
    #     nib_image = get_nib_image(f"{dir_}/{file}")
    #     nib_images.append(nib_image)
    # return nib_images

In [5]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [6]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def preprocess_nifti(self, nifti_path):
        img = nib.load(nifti_path).get_fdata()
        
        # Normalize intensity to [0,1]
        img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        
        # Resize to target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
        
        return torch.tensor(img_resized, dtype=torch.float32).unsqueeze(0)  # Add channel dim

    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

In [7]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [8]:
len(image_paths)

2294

In [9]:
len(labels)

2294

In [10]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)

In [11]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 459, Test Batches: 115


In [12]:
# Define the ResNet-based classifier
class ResNet3DClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ResNet3DClassifier, self).__init__()
        self.resnet = resnet18(spatial_dims=3, n_input_channels=1, num_classes=num_classes)

    def forward(self, x):
        return self.resnet(x)

# Instantiate model
num_classes = 3
model = ResNet3DClassifier(num_classes)

In [13]:
import torch.optim as optim

# Define loss function & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet3DClassifier(
  (resnet): ResNet(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, e

In [14]:
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm  # Progress bar
import os

# Ensure model directory exists
os.makedirs("models", exist_ok=True)

def train_model(model, train_loader, test_loader, num_epochs=10, accumulation_steps=2):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    scaler = torch.cuda.amp.GradScaler()

    best_loss = float("inf")
    best_model_path = "models/best_model.pth"

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        optimizer.zero_grad()

        # Use tqdm for progress bar
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for i, (images, labels) in progress_bar:
            images, labels = images.to(device), labels.to(device)

            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps  # Divide loss for accumulation

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Update progress bar
            progress_bar.set_postfix(loss=loss.item(), accuracy=100 * correct / total)

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total

        # Adjust LR based on loss
        scheduler.step(epoch_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        # Save best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"✅ Best Model Saved! (Loss: {best_loss:.4f})")

# Train the model with progress bars and best model saving
train_model(model, train_loader, test_loader, num_epochs=10)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Epoch 1/10: 100%|██████████████████████████████████████████| 459/459 [12:55<00:00,  1.69s/it, accuracy=45.4, loss=0.438]


Epoch [1/10], Loss: 0.5431, Accuracy: 45.40%
✅ Best Model Saved! (Loss: 0.5431)


Epoch 2/10: 100%|██████████████████████████████████████████| 459/459 [12:59<00:00,  1.70s/it, accuracy=46.4, loss=0.339]


Epoch [2/10], Loss: 0.5259, Accuracy: 46.38%
✅ Best Model Saved! (Loss: 0.5259)


Epoch 3/10: 100%|██████████████████████████████████████████| 459/459 [13:09<00:00,  1.72s/it, accuracy=46.6, loss=0.658]


Epoch [3/10], Loss: 0.5274, Accuracy: 46.65%


Epoch 4/10: 100%|████████████████████████████████████████████| 459/459 [12:41<00:00,  1.66s/it, accuracy=47, loss=0.348]


Epoch [4/10], Loss: 0.5219, Accuracy: 46.98%
✅ Best Model Saved! (Loss: 0.5219)


Epoch 5/10: 100%|██████████████████████████████████████████| 459/459 [13:14<00:00,  1.73s/it, accuracy=47.6, loss=0.786]


Epoch [5/10], Loss: 0.5165, Accuracy: 47.63%
✅ Best Model Saved! (Loss: 0.5165)


Epoch 6/10: 100%|██████████████████████████████████████████| 459/459 [12:59<00:00,  1.70s/it, accuracy=46.6, loss=0.413]


Epoch [6/10], Loss: 0.5223, Accuracy: 46.59%


Epoch 7/10: 100%|██████████████████████████████████████████| 459/459 [12:59<00:00,  1.70s/it, accuracy=49.3, loss=0.569]


Epoch [7/10], Loss: 0.5102, Accuracy: 49.32%
✅ Best Model Saved! (Loss: 0.5102)


Epoch 8/10: 100%|███████████████████████████████████████████| 459/459 [13:00<00:00,  1.70s/it, accuracy=50.2, loss=0.59]


Epoch [8/10], Loss: 0.5033, Accuracy: 50.19%
✅ Best Model Saved! (Loss: 0.5033)


Epoch 9/10: 100%|██████████████████████████████████████████| 459/459 [12:57<00:00,  1.69s/it, accuracy=49.7, loss=0.478]


Epoch [9/10], Loss: 0.5028, Accuracy: 49.70%
✅ Best Model Saved! (Loss: 0.5028)


Epoch 10/10: 100%|█████████████████████████████████████████| 459/459 [12:47<00:00,  1.67s/it, accuracy=51.2, loss=0.281]


Epoch [10/10], Loss: 0.4888, Accuracy: 51.17%
✅ Best Model Saved! (Loss: 0.4888)


In [15]:
from sklearn.metrics import classification_report

def evaluate_model(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Get class with highest probability
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Generate classification report
    report = classification_report(all_labels, all_preds, digits=4)
    print("\n🔹 Classification Report:\n")
    print(report)

# Evaluate model
evaluate_model(model, test_loader)


🔹 Classification Report:

              precision    recall  f1-score   support

           0     0.5437    0.4058    0.4647       138
           1     0.5271    0.8259    0.6435       224
           2     0.4000    0.0206    0.0392        97

    accuracy                         0.5294       459
   macro avg     0.4903    0.4174    0.3825       459
weighted avg     0.5052    0.5294    0.4620       459

