In [23]:
!pip install -q monai nibabel scikit-learn


In [25]:
import os                    
import pandas as pd           
import torch                  
import torch.nn as nn         
import torch.optim as optim   
from torch.utils.data import Dataset, DataLoader, random_split 


In [27]:
import monai
print("MONAI version:", monai.__version__)

MONAI version: 1.5.0


In [30]:
#MONAI
from monai.transforms import (
    Compose,           
    LoadImaged,                 
    ScaleIntensityd,    
    ResizeWithPadOrCropd,
    ToTensord,
    Lambda           
)

from monai.data import Dataset

# MONAI model 
from monai.networks.nets import resnet          
from monai.networks.layers import Norm          

# For accuracy, precision, recall, F1
from sklearn.metrics import classification_report  



In [22]:
csv_path = "/kaggle/input/labelled/mri_diagnosis_mapping.csv" 
df = pd.read_csv(csv_path)

label_map = {"AD": 0, "MCI": 1, "NC": 2, "CN": 2}

df.columns = df.columns.str.strip() 
df["Diagnosis"] = df["Diagnosis"].str.strip()  

# Build list of dicts for MONAI
data_dicts = []

for _, row in df.iterrows():
    img_path = row["MRI Path"]
    label = row["Diagnosis"]
    if label not in label_map:
        continue
    data_dicts.append({
        "image": img_path,
        "label": label_map[label]
    })

print(f"Loaded {len(data_dicts)} samples")
print("Example sample:\n", data_dicts[0])


Loaded 460 samples
Example sample:
 {'image': '/kaggle/input/preprocessed-output-zipped/preprocess_output_f/133_S_0913/MPR__GradWarp__B1_Correction__N3__Scaled_2/2007-07-18_15_09_17.0/I119636/ADNI_133_S_0913_MR_MPR__GradWarp__B1_Correction__N3__Scaled_2_Br_20081008095615528_S35319_I119636_final.nii', 'label': 1}


In [32]:
#MONAI Pipeline Transform
transforms = Compose([
    LoadImaged(keys=["image"]),  # Load NIfTI image
    Lambda(func=lambda x: {"image": x["image"][None, ...], "label": x["label"]}),
    ScaleIntensityd(keys=["image"]),
    ResizeWithPadOrCropd(keys=["image"], spatial_size=(96, 96, 96)),
    ToTensord(keys=["image", "label"])
])


In [33]:
from torch.utils.data import random_split, DataLoader

full_dataset = Dataset(data=data_dicts, transform=transforms)

torch.manual_seed(42)

#80% train, 10% val, 10% test
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size  # Remainder to make total exact


train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size]
)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)


In [35]:
from monai.networks.nets.resnet import resnet18  # Import ResNet18

# AD, MCI, CN
num_classes = 3

# 3D ResNet18 model
model = resnet18(
    spatial_dims=3,         # 3D convolutions
    n_input_channels=1,     # MRI volumes are grayscale → 1 channel
    num_classes=num_classes # 3-way classification (AD/MCI/NC)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [39]:
# loss function
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [40]:
num_epochs = 10  
best_val_accuracy = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_train_loss = train_loss / total
    train_acc = correct / total

    print(f"Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_acc:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    print(f"Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_acc:.4f}")


    # Save best model (optional)
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model so far.")



Epoch 1/10
Train Loss: 0.8668 | Train Accuracy: 0.5951
Val Loss: 1.6259 | Val Accuracy: 0.4130
Saved best model so far.

Epoch 2/10
Train Loss: 0.8330 | Train Accuracy: 0.6196
Val Loss: 1.4117 | Val Accuracy: 0.3913

Epoch 3/10
Train Loss: 0.7107 | Train Accuracy: 0.6929
Val Loss: 1.3272 | Val Accuracy: 0.3696

Epoch 4/10
Train Loss: 0.6280 | Train Accuracy: 0.7201
Val Loss: 1.2502 | Val Accuracy: 0.4348
Saved best model so far.

Epoch 5/10
Train Loss: 0.4348 | Train Accuracy: 0.8370
Val Loss: 1.1064 | Val Accuracy: 0.5870
Saved best model so far.

Epoch 6/10
Train Loss: 0.2926 | Train Accuracy: 0.8967
Val Loss: 1.5468 | Val Accuracy: 0.3913

Epoch 7/10
Train Loss: 0.1778 | Train Accuracy: 0.9457
Val Loss: 1.1845 | Val Accuracy: 0.5870

Epoch 8/10
Train Loss: 0.1523 | Train Accuracy: 0.9620
Val Loss: 1.0990 | Val Accuracy: 0.6304
Saved best model so far.

Epoch 9/10
Train Loss: 0.0792 | Train Accuracy: 0.9755
Val Loss: 1.1289 | Val Accuracy: 0.6304

Epoch 10/10
Train Loss: 0.0309 | Tr

In [42]:
# Load best saved model
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

test_correct = 0
test_total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        test_correct += (preds == labels).sum().item()
        test_total += labels.size(0)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = test_correct / test_total
print(f"\n Final Test Accuracy: {test_acc:.4f}")



 Final Test Accuracy: 0.6087
