In [None]:
# ==============================
# CELL 1: ENVIRONMENT & GPU CHECK
# ==============================

import torch
import random
import numpy as np
import os

# Reproducibility: push one shared seed into every RNG this notebook imports.
SEED = 42
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_fn(SEED)

# Request deterministic cuDNN kernels and turn the cuDNN autotuner off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Environment info
_cuda_ok = torch.cuda.is_available()
print("PyTorch version:", torch.__version__)
print("CUDA available:", _cuda_ok)

if not _cuda_ok:
    print("WARNING: Running on CPU (training will be slow)")
else:
    print("GPU Name:", torch.cuda.get_device_name(0))

print("Environment check completed.")


PyTorch version: 2.9.0+cpu
CUDA available: False
Environment check completed.


In [None]:
# ============================
# CELL 2: MOUNT GOOGLE DRIVE
# ============================

from google.colab import drive

# Attach Google Drive to the Colab filesystem at a named mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

print("Google Drive mounted successfully.")


Mounted at /content/drive
Google Drive mounted successfully.


In [None]:
# =========================================
# CELL 3: DATASET EXTRACTION & VERIFICATION
# =========================================

import os
from zipfile import ZipFile

zip_path = "/content/drive/MyDrive/alzheimersdataset.zip"
extract_path = "/content/alzheimersdataset"

# Extract only once. The archive is unpacked into a scratch directory that is
# renamed to `extract_path` only after extraction has finished, so an
# interrupted run can never leave a half-populated folder that the existence
# check below would mistake for a complete dataset on the next execution.
if not os.path.exists(extract_path):
    partial_path = extract_path + ".partial"
    with ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(partial_path)
    os.rename(partial_path, extract_path)
    print(" Dataset extracted successfully.")
else:
    print(" Dataset already extracted.")

# Inspect extracted contents
print("\n Contents inside extracted folder:")
items = os.listdir(extract_path)
print(items)

# List the sub-directories found directly under the extraction folder.
# NOTE(review): the recorded output shows a single wrapper folder ('Data')
# here rather than the class folders themselves; those are one level down.
print("\n Detected class folders:")
for item in items:
    if os.path.isdir(os.path.join(extract_path, item)):
        print(" -", item)

print("\n Dataset structure check completed.")


 Dataset extracted successfully.

 Contents inside extracted folder:
['Data']

 Detected class folders:
 - Data

 Dataset structure check completed.


In [None]:
# =========================================
# CELL 4: VERIFY CLASS FOLDERS INSIDE /Data
# =========================================

import os

# Class folders are looked up inside the 'Data' sub-folder of the extraction
# directory, not in the extraction root itself.
data_path = "/content/alzheimersdataset/Data"

print(" Contents inside Data folder:")
data_entries = os.listdir(data_path)
print(data_entries)

print("\n Class folders detected:")
class_dirs = [
    entry for entry in data_entries
    if os.path.isdir(os.path.join(data_path, entry))
]
for class_dir in class_dirs:
    print(" -", class_dir)

print("\n Data folder structure verified.")


 Contents inside Data folder:
['Moderate Dementia', 'Non Demented', 'Very mild Dementia', 'Mild Dementia']

 Class folders detected:
 - Moderate Dementia
 - Non Demented
 - Very mild Dementia
 - Mild Dementia

 Data folder structure verified.


In [None]:
# =========================================
# CELL 5: STANDARDIZE CLASS FOLDER NAMES
# =========================================

import os
import shutil

# Original data location
src_base = "/content/alzheimersdataset/Data"

# Clean working directory
clean_base = "/content/oasis_clean"
os.makedirs(clean_base, exist_ok=True)

# Mapping: original folder name -> standardized class name
class_map = {
    "Non Demented": "NonDemented",
    "Very mild Dementia": "VeryMildDemented",
    "Mild Dementia": "MildDemented",
    "Moderate Dementia": "ModerateDemented"
}

for src_name, clean_name in class_map.items():
    src_path = os.path.join(src_base, src_name)
    dst_path = os.path.join(clean_base, clean_name)

    if os.path.exists(dst_path):
        print(f" {clean_name} already exists")
        continue

    # Copy into a staging folder that sits next to (not inside) `clean_base`
    # and rename it into place only when the copy has finished. An interrupted
    # copy therefore never leaves a partial class folder that the existence
    # check above would accept as complete, and the staging folder is never
    # picked up as a class by anything that lists `clean_base`.
    staging_path = f"{clean_base}.{clean_name}.partial"
    shutil.rmtree(staging_path, ignore_errors=True)  # leftovers of an aborted run
    shutil.copytree(src_path, staging_path)
    os.rename(staging_path, dst_path)
    print(f" Copied {src_name} -> {clean_name}")

# Verify
print("\n Clean dataset folders:")
print(os.listdir(clean_base))

print("\n Class name standardization complete.")


 Copied Non Demented ‚Üí NonDemented
 Copied Very mild Dementia ‚Üí VeryMildDemented
 Copied Mild Dementia ‚Üí MildDemented
 Copied Moderate Dementia ‚Üí ModerateDemented

 Clean dataset folders:
['NonDemented', 'MildDemented', 'ModerateDemented', 'VeryMildDemented']

 Class name standardization complete.


In [None]:
# =========================================
# CELL 6: CREATE TRAIN / VALIDATION SPLIT
# =========================================

import os
import shutil
from sklearn.model_selection import train_test_split

# Source (cleaned classes)
source_dir = "/content/oasis_clean"

# Output split directory
split_dir = "/content/oasis_split"
train_dir = os.path.join(split_dir, "train")
val_dir = os.path.join(split_dir, "val")

# Rebuild the split from scratch on every run. Files are copied, never moved,
# so re-running this cell with a different VAL_RATIO or SEED on top of an old
# split would leave stale copies behind and the same image could end up in
# both train/ and val/.
if os.path.isdir(split_dir):
    shutil.rmtree(split_dir)
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Split config
VAL_RATIO = 0.2
SEED = 42

# Only directories are classes; stray files in source_dir are ignored.
classes = sorted(
    entry for entry in os.listdir(source_dir)
    if os.path.isdir(os.path.join(source_dir, entry))
)

# The split is done independently inside each class, so every class keeps
# roughly the same train/val proportion.
# NOTE(review): the split is per image file. If several images originate from
# the same subject/scan, near-identical images can land on both sides and
# inflate validation scores -- confirm, and split by subject if that applies.
for cls in classes:
    cls_path = os.path.join(source_dir, cls)
    images = sorted(os.listdir(cls_path))  # sorted -> same split for the same SEED

    train_imgs, val_imgs = train_test_split(
        images,
        test_size=VAL_RATIO,
        random_state=SEED,
        shuffle=True
    )

    for subset_dir, subset_imgs in ((train_dir, train_imgs), (val_dir, val_imgs)):
        os.makedirs(os.path.join(subset_dir, cls), exist_ok=True)
        for img in subset_imgs:
            shutil.copy(
                os.path.join(cls_path, img),
                os.path.join(subset_dir, cls, img)
            )

print(" Train/Validation split created.\n")

# Sanity check: class counts
for split, path in [("TRAIN", train_dir), ("VAL", val_dir)]:
    print(f"--> {split} distribution:")
    for cls in classes:
        count = len(os.listdir(os.path.join(path, cls)))
        print(f"  {cls}: {count}")
    print()


 Train/Validation split created.

--> TRAIN distribution:
  MildDemented: 4001
  ModerateDemented: 390
  NonDemented: 53777
  VeryMildDemented: 10980

--> VAL distribution:
  MildDemented: 1001
  ModerateDemented: 98
  NonDemented: 13445
  VeryMildDemented: 2745



In [None]:
# =========================================
# CELL 7: TRANSFORMS + BALANCED DATALOADERS
# =========================================

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
import os

# Paths
DATA_DIR = "/content/oasis_split"

# Per-channel ImageNet statistics. The images are fed to an ImageNet-pretrained
# ResNet50 whose stem takes 3-channel input, and (per the torchvision docs)
# ImageFolder's default loader yields RGB images. A one-element mean/std would
# be broadcast over all three channels, which does not match the preprocessing
# the pretrained weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# MRI-safe transforms (NO ColorJitter): mild geometric augmentation for
# training, deterministic resize + normalize for validation.
data_transforms = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(
            degrees=15,
            translate=(0.05, 0.05),
            scale=(0.95, 1.05)
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
}

# ImageFolder datasets (one sub-folder per class under train/ and val/)
image_datasets = {
    x: datasets.ImageFolder(
        os.path.join(DATA_DIR, x),
        transform=data_transforms[x]
    )
    for x in ["train", "val"]
}

class_names = image_datasets["train"].classes
num_classes = len(class_names)

print(" Classes:", class_names)

# --------- CLASS BALANCING (CRITICAL) ---------
# Every training sample is drawn with probability proportional to the inverse
# of its class frequency, so each class is seen about equally often per epoch.
targets = image_datasets["train"].targets
class_counts = np.bincount(targets)

print("\n Raw class counts:")
for cls, cnt in zip(class_names, class_counts):
    print(f"{cls}: {cnt}")

class_weights = 1.0 / class_counts
sample_weights = [class_weights[t] for t in targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Pinned host memory only helps host-to-GPU copies; on a CPU-only runtime it
# just triggers a warning, so enable it only when CUDA is present.
use_pinned_memory = torch.cuda.is_available()

# DataLoaders
dataloaders = {
    "train": DataLoader(
        image_datasets["train"],
        batch_size=32,
        sampler=sampler,
        num_workers=2,
        pin_memory=use_pinned_memory
    ),
    "val": DataLoader(
        image_datasets["val"],
        batch_size=32,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned_memory
    )
}

print("\n DataLoaders created successfully.")


 Classes: ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

 Raw class counts:
MildDemented: 4001
ModerateDemented: 390
NonDemented: 53777
VeryMildDemented: 10980

 DataLoaders created successfully.


In [None]:
# =========================================
# CELL 8: RESNET50 MODEL[RESIDUAL NETWORK]
# =========================================

import torch
import torch.nn as nn
from torchvision import models

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_resnet50(num_classes):
    """Create a pretrained ResNet50 set up for fine-tuning on `num_classes`.

    Only `layer3`, `layer4` and the freshly created classification head are
    left trainable; every other backbone parameter is frozen. The head is
    BatchNorm1d -> Dropout(0.6) -> Linear. The model is returned already moved
    to the module-level `device`.
    """
    net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    # Mark each backbone parameter in one pass: trainable only if it belongs
    # to one of the two deepest residual stages.
    trainable_ids = {
        id(param)
        for stage in (net.layer3, net.layer4)
        for param in stage.parameters()
    }
    for param in net.parameters():
        param.requires_grad = id(param) in trainable_ids

    # New head sized from the feature width of the layer it replaces.
    in_dim = net.fc.in_features
    head = nn.Sequential(
        nn.BatchNorm1d(in_dim),
        nn.Dropout(0.6),
        nn.Linear(in_dim, num_classes),
    )
    net.fc = head

    return net.to(device)

model = build_resnet50(num_classes)

# Sanity check: count parameter tensors (not individual weights) that will
# receive gradients versus all parameter tensors in the model.
_param_flags = [p.requires_grad for p in model.parameters()]
trainable = sum(_param_flags)
total = len(_param_flags)

print(f"Trainable parameter tensors: {trainable}/{total}")
print("ResNet50 model ready.")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 168MB/s]


Trainable parameter tensors: 91/163
ResNet50 model ready.


In [None]:
# =========================================
# CELL 9: LOSS, OPTIMIZER, SCHEDULER
# =========================================

import torch.optim as optim
import torch.nn as nn
import numpy as np

# Compute class weights from training set (inverse class frequency)
class_counts = np.bincount(image_datasets["train"].targets)
class_weights = 1.0 / class_counts
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

print("Class weights:")
for cls, w in zip(class_names, class_weights):
    print(f"{cls}: {w:.6f}")

# Loss: weighted CrossEntropy
# NOTE(review): the training DataLoader already draws samples with
# inverse-frequency weights (WeightedRandomSampler), so these loss weights
# compensate for class imbalance a second time -- confirm this is intended.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer: AdamW, restricted to the parameters that are not frozen
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-4,
    weight_decay=1e-4
)

# Scheduler: cosine annealing. T_max is the number of scheduler steps over
# which the learning rate decays from its initial value to the minimum. The
# training loop steps the scheduler once per epoch, so T_max has to equal the
# number of epochs for the decay to actually finish; a larger value stops the
# run part-way down the curve.
NUM_EPOCHS = 15  # keep in sync with the num_epochs passed to train_model
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS
)

print("\n Loss, optimizer, and scheduler configured.")


Class weights:
MildDemented: 0.000250
ModerateDemented: 0.002564
NonDemented: 0.000019
VeryMildDemented: 0.000091

 Loss, optimizer, and scheduler configured.


In [None]:
# =========================================
# CELL 10: TRAINING LOOP
# =========================================

import copy
import time

def _run_phase(model, loader, criterion, optimizer, device, training):
    """Run one full pass over `loader` and return `(mean_loss, accuracy)`.

    Weights are updated only when `training` is true; otherwise the model is
    put in eval mode and the pass runs with gradients disabled. Both returned
    values are plain Python floats averaged over the samples seen.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train(training)

    loss_sum = 0.0
    correct = 0
    seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a per-batch mean, so weight it by the batch size
        # to get a correct per-sample average when batches differ in size.
        loss_sum += loss.item() * batch_size
        correct += int((preds == labels).sum().item())
        seen += batch_size

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / seen, correct / seen


def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=15):
    """Train `model`, validating after every epoch, and keep the best weights.

    Args:
        model: network to train. Batches are moved to the device of its
            parameters (CPU if it has none).
        dataloaders: mapping with "train" and "val" entries, each an iterable
            of `(inputs, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of train + validation passes.

    Returns:
        `(model, history)`. `model` has the weights of the epoch with the
        highest validation accuracy loaded. `history` maps "train_loss",
        "val_loss", "train_acc" and "val_acc" to lists with one float per
        epoch.

    Raises:
        ValueError: if a dataloader yields no samples.
    """
    since = time.time()

    # Follow the model's own placement instead of relying on a notebook-level
    # `device` global being defined in some earlier cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": []
    }

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        for phase in ["train", "val"]:
            epoch_loss, epoch_acc = _run_phase(
                model,
                dataloaders[phase],
                criterion,
                optimizer,
                run_device,
                training=(phase == "train"),
            )

            history[f"{phase}_loss"].append(epoch_loss)
            history[f"{phase}_acc"].append(epoch_acc)

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")

            # Snapshot the weights whenever validation accuracy improves.
            # deepcopy is required: state_dict() returns references to the
            # live tensors, which later optimizer steps would overwrite.
            if phase == "val" and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        scheduler.step()

    time_elapsed = time.time() - since
    print(f"\nTraining completed in {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s")
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")

    model.load_state_dict(best_model_wts)
    return model, history


# Launch the training run; `model` comes back with its best-validation weights
# loaded and `history` holds the per-epoch loss/accuracy curves.
model, history = train_model(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=15,
)



Epoch 1/15
------------------------------
Train Loss: 0.0935 | Acc: 0.7617
Val Loss: 0.7205 | Acc: 0.7264

Epoch 2/15
------------------------------
Train Loss: 0.0312 | Acc: 0.8727
Val Loss: 0.8002 | Acc: 0.7446

Epoch 3/15
------------------------------
Train Loss: 0.0322 | Acc: 0.8799
Val Loss: 0.7350 | Acc: 0.7665

Epoch 4/15
------------------------------
Train Loss: 0.0197 | Acc: 0.9058
Val Loss: 0.6896 | Acc: 0.7755

Epoch 5/15
------------------------------
Train Loss: 0.0188 | Acc: 0.9104
Val Loss: 0.5514 | Acc: 0.8219

Epoch 6/15
------------------------------
Train Loss: 0.0170 | Acc: 0.9199
Val Loss: 0.4691 | Acc: 0.8434

Epoch 7/15
------------------------------
Train Loss: 0.0146 | Acc: 0.9289
Val Loss: 0.5841 | Acc: 0.8048

Epoch 8/15
------------------------------
Train Loss: 0.0133 | Acc: 0.9340
Val Loss: 0.2951 | Acc: 0.8957

Epoch 9/15
------------------------------
Train Loss: 0.0091 | Acc: 0.9461
Val Loss: 0.3611 | Acc: 0.8782

Epoch 10/15
------------------------

In [None]:
# =========================================
# CELL 11: LOSS & ACCURACY CURVES
# =========================================

import matplotlib.pyplot as plt

# One x-axis tick per recorded epoch, starting at 1.
epochs = range(1, len(history["train_loss"]) + 1)

plt.figure(figsize=(14, 5))

# (key suffix in `history`, wording used in labels) for the left/right panel.
panels = [("loss", "Loss"), ("acc", "Accuracy")]

for position, (metric, wording) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epochs, history[f"train_{metric}"], label=f"Train {wording}", linewidth=2)
    plt.plot(epochs, history[f"val_{metric}"], label=f"Validation {wording}", linewidth=2)
    plt.xlabel("Epoch")
    plt.ylabel(wording)
    plt.title(f"Training vs Validation {wording}")
    plt.legend()
    plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(" Training curves plotted successfully.")


NameError: name 'history' is not defined

In [None]:
# =========================================
# CELL 12: CONFUSION MATRIX (VALIDATION)
# =========================================

import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Evaluation mode: dropout off, batch-norm layers use their running statistics.
model.eval()

all_preds = []
all_labels = []

# Collect predictions for the whole validation set without tracking gradients.
with torch.no_grad():
    for inputs, labels in dataloaders["val"]:
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Confusion Matrix
# Pass the full label set explicitly: otherwise a class that appears in
# neither the true labels nor the predictions is dropped from the matrix and
# the class-name tick labels below no longer line up with its rows/columns.
cm = confusion_matrix(
    all_labels,
    all_preds,
    labels=list(range(len(class_names)))
)

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names
)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Alzheimer's Stage Confusion Matrix (Validation)")
plt.tight_layout()
plt.show()

print(" Confusion matrix generated.")


NameError: name 'model' is not defined