In [1]:
import os
import json
import time
import copy
import random
from collections import Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler

import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, roc_auc_score
)

from tqdm import tqdm

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters / Config
IMG_SIZE = 256
BATCH_SIZE = 32
NUM_EPOCHS = 20
LR_HEAD = 3e-4
LR_BACKBONE = 1e-4
WEIGHT_DECAY = 1e-4
LABEL_SMOOTH = 0.02
HEAD_WARMUP_EPOCHS = 3
USE_MIXUP = False
MIXUP_ALPHA = 0.2
USE_TTA = True



Using device: cuda


In [2]:
# Dataset discovery from KaggleHub-style layout or explicit path
EXPECTED_CLASSES = ['Non Demented', 'Very mild Dementia', 'Mild Dementia', 'Moderate Dementia']

def resolve_dataset_path(root_candidates):
    for root in root_candidates:
        if not root:
            continue
        candidates = [os.path.join(root, 'Data'), root]
        for cand in candidates:
            if os.path.isdir(cand):
                subdirs = [d for d in os.listdir(cand) if os.path.isdir(os.path.join(cand, d))]
                found = [d for d in subdirs if d in EXPECTED_CLASSES]
                if len(found) == len(EXPECTED_CLASSES):
                    return cand, found
                if len(found) > 0:
                    return cand, found
    return None, []


DATASET_ROOT = os.environ.get('ALZ_DATASET_ROOT', None)


common_roots = [
    DATASET_ROOT,
    os.path.expanduser(r"~/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1"),
    os.path.expanduser(r"~/kaggle/input/imagesoasis"),
    r"C:\\Users\\sajib\\.cache\\kagglehub\\datasets\\ninadaithal\\imagesoasis\\versions\\1",
]

dataset_path, class_names = resolve_dataset_path(common_roots)
if not dataset_path:
    raise FileNotFoundError("Could not resolve dataset path. Set ALZ_DATASET_ROOT to the dataset root or adjust common_roots.")

print(f"Dataset path: {dataset_path}")
print(f"Classes: {class_names}")



Dataset path: C:\Users\sajib/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1\Data
Classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']


In [3]:

class AlzheimerDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.class_to_idx = {
            'Non Demented': 0,
            'Very mild Dementia': 1,
            'Mild Dementia': 2,
            'Moderate Dementia': 3
        }
        if len(self.labels) > 0 and isinstance(self.labels[0], str):
            self.labels = [self.class_to_idx[lbl] for lbl in self.labels]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('L')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label


def get_transforms(img_size=224, is_training=False):
    if is_training:
        return T.Compose([
            T.RandomResizedCrop(img_size, scale=(0.8, 1.0), ratio=(0.95, 1.05)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=10),
            T.Grayscale(num_output_channels=3),
            T.ColorJitter(brightness=0.15, contrast=0.15),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        return T.Compose([
            T.Resize(int(img_size*1.14)),
            T.CenterCrop(img_size),
            T.Grayscale(num_output_channels=3),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])



In [4]:


def load_and_split(dataset_path, class_names, test_size=0.15, val_size=0.15, random_state=SEED):
    image_paths, labels = [], []
    for cls in class_names:
        cls_dir = os.path.join(dataset_path, cls)
        files = [f for f in os.listdir(cls_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        for f in files:
            image_paths.append(os.path.join(cls_dir, f))
            labels.append(cls)
    print(f"Total images: {len(image_paths)}")

    X_temp, X_test, y_temp, y_test = train_test_split(
        image_paths, labels, test_size=test_size, stratify=labels, random_state=random_state
    )
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=random_state
    )
    return (X_train, y_train), (X_val, y_val), (X_test, y_test)


def make_loaders(dataset_path, class_names, batch_size=BATCH_SIZE, img_size=IMG_SIZE, num_workers=0):
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = load_and_split(dataset_path, class_names)

    train_tf = get_transforms(img_size, True)
    eval_tf = get_transforms(img_size, False)

    train_ds = AlzheimerDataset(X_train, y_train, train_tf)
    val_ds = AlzheimerDataset(X_val, y_val, eval_tf)
    test_ds = AlzheimerDataset(X_test, y_test, eval_tf)

   
    num_classes = 4
    train_label_tensor = torch.tensor(train_ds.labels, dtype=torch.long)
    binc = torch.bincount(train_label_tensor, minlength=num_classes).float()
    binc[binc == 0] = 1.0
    inv_sqrt = (1.0 / torch.sqrt(binc))
    class_weights = (inv_sqrt / inv_sqrt.sum()) * num_classes  # moderate
    print(f"Class weights: {class_weights.numpy()}")

   
    per_sample_w = (1.0 / torch.sqrt(binc))[train_label_tensor]
    sampler = WeightedRandomSampler(weights=per_sample_w, num_samples=len(per_sample_w), replacement=True)

    loaders = {
        'train': DataLoader(train_ds, batch_size=batch_size, sampler=sampler, shuffle=False, num_workers=num_workers, pin_memory=False),
        'val': DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False),
        'test': DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False),
        'class_weights': class_weights
    }
    return loaders, {'class_names': class_names}

loaders, meta = make_loaders(dataset_path, class_names)



Total images: 86437
Class weights: [0.21495749 0.47572607 0.7879393  2.521377  ]


In [5]:
# Models

def create_model(name: str, num_classes: int = 4, pretrained: bool = True) -> nn.Module:
    if name == 'vgg16':
        m = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None)
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    if name == 'vgg19':
        m = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1 if pretrained else None)
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    if name == 'resnet50':
        m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name == 'resnet101':
        m = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name == 'resnet152':
        m = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name == 'densenet121':
        m = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None)
        m.classifier = nn.Linear(m.classifier.in_features, num_classes)
        return m
    if name == 'densenet201':
        m = models.densenet201(weights=models.DenseNet201_Weights.IMAGENET1K_V1 if pretrained else None)
        m.classifier = nn.Linear(m.classifier.in_features, num_classes)
        return m
    if name == 'mobilenetv3_large':
        m = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None)
        m.classifier[3] = nn.Linear(m.classifier[3].in_features, num_classes)
        return m
    if name == 'shufflenet_v2_x1_0':
        m = models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 if pretrained else None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unknown model: {name}")

ALL_MODELS = [
    'vgg16', 'vgg19',
    'resnet50', 'resnet101', 'resnet152',
    'densenet121', 'densenet201',
    'mobilenetv3_large', 'shufflenet_v2_x1_0'
]



In [6]:


def one_hot_targets(labels, num_classes=4, smoothing=LABEL_SMOOTH):
    with torch.no_grad():
        true_dist = torch.zeros((labels.size(0), num_classes), device=labels.device)
        true_dist.fill_(smoothing / (num_classes - 1))
        true_dist.scatter_(1, labels.unsqueeze(1), 1.0 - smoothing)
    return true_dist


def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, (y_a, y_b), lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_epoch(model, loader, criterion, optimizer, epoch):
    
    use_mixup = (epoch < max(HEAD_WARMUP_EPOCHS + 2, 4)) and USE_MIXUP
    model.train()
    scaler = GradScaler('cuda', enabled=torch.cuda.is_available())
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=f"Train {epoch+1}")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            if use_mixup:
                images, (ya, yb), lam = mixup_data(images, labels, MIXUP_ALPHA)
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, ya, yb, lam)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
        total += images.size(0)
        pbar.set_postfix(loss=running_loss/total, acc=correct/total)
    return running_loss/total, correct/total


def eval_epoch(model, loader, criterion, epoch, phase="Val"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=f"{phase} {epoch+1}")
    all_probs, all_labels, all_preds = [], [], []
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            with autocast('cuda', enabled=torch.cuda.is_available()):
                outputs = model(images)
                loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            probs = F.softmax(outputs, dim=1)
            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += images.size(0)
            all_probs.append(probs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_preds.append(preds.detach().cpu())
            pbar.set_postfix(loss=running_loss/total, acc=correct/total)
    all_probs = torch.cat(all_probs).numpy()
    all_labels = torch.cat(all_labels).numpy()
    all_preds = torch.cat(all_preds).numpy()
    return running_loss/total, correct/total, all_probs, all_labels, all_preds



In [7]:


def train_model(name, loaders, num_epochs=NUM_EPOCHS, lr=LR_BACKBONE, weight_decay=WEIGHT_DECAY):
    print(f"\n==== Training {name} ====")
    model = create_model(name, num_classes=4, pretrained=True).to(device)
    class_weights = loaders['class_weights'].to(device)
    base_criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)


    head_params, backbone_params = [], []
    for n, p in model.named_parameters():
        if any(k in n for k in ['fc', 'classifier']):
            head_params.append(p)
        else:
            backbone_params.append(p)
    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': lr},
        {'params': head_params, 'lr': LR_HEAD}
    ], weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_acc = 0.0
    best_state = None

    for epoch in range(num_epochs):
        if epoch < HEAD_WARMUP_EPOCHS:
            for p in backbone_params:
                p.requires_grad = False
        else:
            for p in backbone_params:
                p.requires_grad = True

        tr_loss, tr_acc = train_epoch(model, loaders['train'], base_criterion, optimizer, epoch)
        va_loss, va_acc, va_probs, va_labels, va_preds = eval_epoch(model, loaders['val'], base_criterion, epoch, phase='Val')
        history['train_loss'].append(tr_loss)
        history['train_acc'].append(tr_acc)
        history['val_loss'].append(va_loss)
        history['val_acc'].append(va_acc)
        scheduler.step()
        # extra metrics
        va_precision = precision_score(va_labels, va_preds, average='macro', zero_division=0)
        va_recall = recall_score(va_labels, va_preds, average='macro', zero_division=0)
        va_f1 = f1_score(va_labels, va_preds, average='macro', zero_division=0)
        print(f"Epoch {epoch+1}: train_acc={tr_acc:.4f} val_acc={va_acc:.4f} val_f1={va_f1:.4f} val_prec={va_precision:.4f} val_rec={va_recall:.4f}")
        if va_acc > best_acc:
            best_acc = va_acc
            best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)
    te_loss, te_acc, te_probs, te_labels, te_preds = eval_epoch(model, loaders['test'], base_criterion, -1, phase='Test')

    if USE_TTA:
        model.eval()
        with torch.no_grad():
            tta_probs = []
            for images, _ in tqdm(loaders['test'], desc='TTA'):
                images = images.to(device)
                p1 = F.softmax(model(images), dim=1)
                p2 = F.softmax(model(torch.flip(images, dims=[3])), dim=1)
                tta_probs.append(((p1 + p2) / 2).cpu())
            te_probs = torch.cat(tta_probs).numpy()

    try:
        y_true_bin = F.one_hot(torch.tensor(te_labels), num_classes=4).numpy()
        auc_macro = roc_auc_score(y_true_bin, te_probs, average='macro', multi_class='ovr')
    except Exception:
        auc_macro = 0.0

    # detailed test metrics
    te_precision = precision_score(te_labels, te_preds, average='macro', zero_division=0)
    te_recall = recall_score(te_labels, te_preds, average='macro', zero_division=0)
    te_f1 = f1_score(te_labels, te_preds, average='macro', zero_division=0)

    print(f"Test: acc={te_acc:.4f} f1={te_f1:.4f} prec={te_precision:.4f} rec={te_recall:.4f} auc_macro={auc_macro:.4f}")
    print(f"Model: {name} | ImgSize: {IMG_SIZE} | Batch: {BATCH_SIZE} | Epochs: {NUM_EPOCHS} | LR(head/backbone): {LR_HEAD}/{LR_BACKBONE} | Warmup: {HEAD_WARMUP_EPOCHS}")

    return {
        'model_name': name,
        'best_val_acc': best_acc,
        'test_acc': te_acc,
        'test_f1_macro': te_f1,
        'test_precision_macro': te_precision,
        'test_recall_macro': te_recall,
        'test_auc_macro': auc_macro,
        'config': {
            'img_size': IMG_SIZE,
            'batch_size': BATCH_SIZE,
            'num_epochs': NUM_EPOCHS,
            'lr_head': LR_HEAD,
            'lr_backbone': LR_BACKBONE,
            'weight_decay': WEIGHT_DECAY,
            'label_smooth': LABEL_SMOOTH,
            'warmup_epochs': HEAD_WARMUP_EPOCHS,
            'mixup': USE_MIXUP
        },
        'history': history,
        'state_dict': best_state
    }



In [None]:

results = []
os.makedirs('models', exist_ok=True)


priority_models = [
    'resnet50', 'resnet101', 'resnet152'
]

for name in priority_models:
    try:
        print(f"Training {name}")
        res = train_model(name, loaders)
        results.append(res)
        torch.save({
            'model_name': name,
            'state_dict': res['state_dict'],
            'history': res['history'],
            'meta': meta,
            'results': {k: res[k] for k in ['best_val_acc', 'test_acc', 'test_auc_macro']}
        }, os.path.join('models', f'{name}_finetuned.pth'))
    except Exception as e:
        print(f"Error training {name}: {e}")


if results:
    df = pd.DataFrame([
        {
            'model': r['model_name'],
            'val_acc': r['best_val_acc'],
            'test_acc': r['test_acc'],
            'test_f1_macro': r.get('test_f1_macro', None),
            'test_precision_macro': r.get('test_precision_macro', None),
            'test_recall_macro': r.get('test_recall_macro', None),
            'test_auc_macro': r['test_auc_macro'],
            'img_size': r['config']['img_size'],
            'batch_size': r['config']['batch_size'],
            'epochs': r['config']['num_epochs'],
            'lr_head': r['config']['lr_head'],
            'lr_backbone': r['config']['lr_backbone'],
            'weight_decay': r['config']['weight_decay'],
            'label_smooth': r['config']['label_smooth'],
            'warmup_epochs': r['config']['warmup_epochs'],
            'mixup': r['config']['mixup']
        }
        for r in results
    ])
    print(df.sort_values('test_acc', ascending=False))
    df.to_csv('models/summary.csv', index=False)
else:
    print("No models trained.")


Training resnet50

==== Training resnet50 ====


  scaler = GradScaler('cuda', enabled=torch.cuda.is_available())
  with autocast(enabled=torch.cuda.is_available()):
Train 1:   0%|          | 0/1891 [00:00<?, ?it/s]


Error training resnet50: full() received an invalid combination of arguments - got (tuple, str, device=torch.device, dtype=torch.dtype), but expected one of:
 * (tuple of ints size, Number fill_value, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, Number fill_value, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)

Training resnet101

==== Training resnet101 ====


Train 1:   0%|          | 0/1891 [00:00<?, ?it/s]


Error training resnet101: full() received an invalid combination of arguments - got (tuple, str, device=torch.device, dtype=torch.dtype), but expected one of:
 * (tuple of ints size, Number fill_value, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, Number fill_value, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)

Training resnet152

==== Training resnet152 ====
Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to C:\Users\sajib/.cache\torch\hub\checkpoints\resnet152-f82ba261.pth


 27%|██▋       | 62.5M/230M [00:40<01:02, 2.81MB/s]