<a href="https://colab.research.google.com/github/ssudhanshu488/SwinOnAlziehmer/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import timm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np

In [3]:
# Load dataset
def load_dataset(image_folder):
    image_paths = []
    labels = []
    for image_name in os.listdir(image_folder):
        image_path = os.path.join(image_folder, image_name)
        label = image_name.split('_')[0]
        image_paths.append(image_path)
        labels.append(label)
    return image_paths, labels

In [4]:
# Preprocess dataset
def preprocess_dataset(image_paths, labels):
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(labels)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels)
    return train_paths, val_paths, train_labels, val_labels, label_encoder


In [5]:
# Define Dataset class
class AlzheimerDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [6]:
# Training function with additional metrics
def train_model(train_loader, val_loader, model, criterion, optimizer, scheduler, num_epochs, device):
    best_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Evaluation phase
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Compute metrics
        accuracy = accuracy_score(all_labels, all_preds)
        report = classification_report(all_labels, all_preds, target_names=['AD', 'CN', 'MCI'], digits=4, output_dict=True)
        precision = report['weighted avg']['precision']
        recall = report['weighted avg']['recall']  # Sensitivity
        f1 = report['weighted avg']['f1-score']

        # Compute specificity
        cm = confusion_matrix(all_labels, all_preds)
        specificity = []
        for i in range(len(cm)):
            tn = cm.sum() - (cm[i, :].sum() + cm[:, i].sum() - cm[i, i])
            fp = cm[:, i].sum() - cm[i, i]
            specificity.append(tn / (tn + fp) if (tn + fp) > 0 else 0)
        avg_specificity = sum(specificity) / len(specificity)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}, "
              f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, "
              f"Sensitivity (Recall): {recall:.4f}, F1-Score: {f1:.4f}, Specificity: {avg_specificity:.4f}")

        # Save the best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_state = model.state_dict()

    return best_model_state, best_accuracy, precision, recall, f1, avg_specificity


In [7]:
# Define image folders
folder_1 = '/content/Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI'
folder_2 = '/content/Slices_Separate_Folders_T1_weighted/cr_AD_CN_MCI'
folder_3 = '/content/Slices_Separate_Folders_T1_weighted/sg_AD_CN_MCI'


In [8]:
# Define transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [9]:
# Run experiment function
def run_experiment(combination_name, folder_a, folder_b):
    print(f"Training on: {combination_name}")
    image_paths_a, labels_a = load_dataset(folder_a)
    image_paths_b, labels_b = load_dataset(folder_b)
    image_paths = image_paths_a + image_paths_b
    labels = labels_a + labels_b

    train_paths, val_paths, train_labels, val_labels, _ = preprocess_dataset(image_paths, labels)
    train_dataset = AlzheimerDataset(train_paths, train_labels, transform=train_transform)
    val_dataset = AlzheimerDataset(val_paths, val_labels, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    best_model_state, best_accuracy, precision, recall, f1, specificity = train_model(
        train_loader, val_loader, model, criterion, optimizer, scheduler, 10, device)

    print(f"Best Accuracy for {combination_name}: {best_accuracy:.4f}\n")
    torch.save(best_model_state, f"best_model_{combination_name}.pth")

    return best_accuracy

In [10]:
#Run and store accuracy values
acc_ax_cr = run_experiment("ax_cr", folder_1, folder_2)
acc_ax_sg = run_experiment("ax_sg", folder_1, folder_3)
acc_cr_sg = run_experiment("cr_sg", folder_2, folder_3)

Training on: ax_cr


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Epoch 1/10, Loss: 1.0209, Accuracy: 0.6642, Precision: 0.6636, Sensitivity (Recall): 0.6642, F1-Score: 0.6519, Specificity: 0.8321




Epoch 2/10, Loss: 0.8492, Accuracy: 0.6623, Precision: 0.7363, Sensitivity (Recall): 0.6623, F1-Score: 0.6186, Specificity: 0.8313




Epoch 3/10, Loss: 0.7221, Accuracy: 0.7444, Precision: 0.7446, Sensitivity (Recall): 0.7444, F1-Score: 0.7420, Specificity: 0.8722




Epoch 4/10, Loss: 0.6583, Accuracy: 0.7966, Precision: 0.7986, Sensitivity (Recall): 0.7966, F1-Score: 0.7932, Specificity: 0.8984




Epoch 5/10, Loss: 0.5406, Accuracy: 0.8153, Precision: 0.8365, Sensitivity (Recall): 0.8153, F1-Score: 0.8156, Specificity: 0.9076




Epoch 6/10, Loss: 0.4530, Accuracy: 0.8787, Precision: 0.8799, Sensitivity (Recall): 0.8787, F1-Score: 0.8781, Specificity: 0.9394




Epoch 7/10, Loss: 0.3728, Accuracy: 0.9160, Precision: 0.9171, Sensitivity (Recall): 0.9160, F1-Score: 0.9156, Specificity: 0.9580




Epoch 8/10, Loss: 0.3499, Accuracy: 0.9440, Precision: 0.9439, Sensitivity (Recall): 0.9440, F1-Score: 0.9440, Specificity: 0.9720




Epoch 9/10, Loss: 0.3273, Accuracy: 0.9459, Precision: 0.9464, Sensitivity (Recall): 0.9459, F1-Score: 0.9460, Specificity: 0.9729




Epoch 10/10, Loss: 0.3202, Accuracy: 0.9459, Precision: 0.9464, Sensitivity (Recall): 0.9459, F1-Score: 0.9459, Specificity: 0.9729
Best Accuracy for ax_cr: 0.9459

Training on: ax_sg




Epoch 1/10, Loss: 1.0218, Accuracy: 0.5187, Precision: 0.5530, Sensitivity (Recall): 0.5187, F1-Score: 0.5187, Specificity: 0.7594




Epoch 2/10, Loss: 0.8999, Accuracy: 0.5989, Precision: 0.6652, Sensitivity (Recall): 0.5989, F1-Score: 0.5240, Specificity: 0.7989




Epoch 3/10, Loss: 0.7982, Accuracy: 0.7183, Precision: 0.7250, Sensitivity (Recall): 0.7183, F1-Score: 0.7144, Specificity: 0.8592




Epoch 4/10, Loss: 0.6991, Accuracy: 0.6884, Precision: 0.7253, Sensitivity (Recall): 0.6884, F1-Score: 0.6830, Specificity: 0.8445




Epoch 5/10, Loss: 0.5953, Accuracy: 0.8470, Precision: 0.8496, Sensitivity (Recall): 0.8470, F1-Score: 0.8455, Specificity: 0.9235




Epoch 6/10, Loss: 0.4886, Accuracy: 0.8806, Precision: 0.8861, Sensitivity (Recall): 0.8806, F1-Score: 0.8816, Specificity: 0.9404




Epoch 7/10, Loss: 0.3951, Accuracy: 0.9142, Precision: 0.9159, Sensitivity (Recall): 0.9142, F1-Score: 0.9139, Specificity: 0.9571




Epoch 8/10, Loss: 0.3640, Accuracy: 0.9366, Precision: 0.9375, Sensitivity (Recall): 0.9366, F1-Score: 0.9366, Specificity: 0.9683




Epoch 9/10, Loss: 0.3407, Accuracy: 0.9328, Precision: 0.9342, Sensitivity (Recall): 0.9328, F1-Score: 0.9328, Specificity: 0.9665




Epoch 10/10, Loss: 0.3355, Accuracy: 0.9440, Precision: 0.9463, Sensitivity (Recall): 0.9440, F1-Score: 0.9442, Specificity: 0.9721
Best Accuracy for ax_sg: 0.9440

Training on: cr_sg




Epoch 1/10, Loss: 0.9947, Accuracy: 0.6236, Precision: 0.6778, Sensitivity (Recall): 0.6236, F1-Score: 0.6286, Specificity: 0.8115




Epoch 2/10, Loss: 0.8838, Accuracy: 0.5768, Precision: 0.7212, Sensitivity (Recall): 0.5768, F1-Score: 0.4680, Specificity: 0.7888




Epoch 3/10, Loss: 0.8197, Accuracy: 0.7509, Precision: 0.7502, Sensitivity (Recall): 0.7509, F1-Score: 0.7505, Specificity: 0.8754




Epoch 4/10, Loss: 0.6810, Accuracy: 0.7528, Precision: 0.8084, Sensitivity (Recall): 0.7528, F1-Score: 0.7540, Specificity: 0.8762




Epoch 5/10, Loss: 0.5831, Accuracy: 0.8708, Precision: 0.8761, Sensitivity (Recall): 0.8708, F1-Score: 0.8703, Specificity: 0.9354




Epoch 6/10, Loss: 0.4801, Accuracy: 0.8914, Precision: 0.8944, Sensitivity (Recall): 0.8914, F1-Score: 0.8909, Specificity: 0.9457




Epoch 7/10, Loss: 0.4365, Accuracy: 0.9139, Precision: 0.9153, Sensitivity (Recall): 0.9139, F1-Score: 0.9127, Specificity: 0.9569




Epoch 8/10, Loss: 0.3714, Accuracy: 0.9438, Precision: 0.9441, Sensitivity (Recall): 0.9438, F1-Score: 0.9436, Specificity: 0.9719




Epoch 9/10, Loss: 0.3514, Accuracy: 0.9476, Precision: 0.9480, Sensitivity (Recall): 0.9476, F1-Score: 0.9474, Specificity: 0.9738




Epoch 10/10, Loss: 0.3416, Accuracy: 0.9607, Precision: 0.9612, Sensitivity (Recall): 0.9607, F1-Score: 0.9605, Specificity: 0.9804
Best Accuracy for cr_sg: 0.9607



In [12]:
# Print ordered results
print("\nFinal Ordered Accuracies:")
print(f"ax_cr: {acc_ax_cr:.4f}")
print(f"ax_sg: {acc_ax_sg:.4f}")
print(f"cr_sg: {acc_cr_sg:.4f}")


Final Ordered Accuracies:
ax_cr: 0.9459
ax_sg: 0.9440
cr_sg: 0.9340


In [15]:
# Load dataset paths again
folder_1 = '/content/Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI'
folder_2 = '/content/Slices_Separate_Folders_T1_weighted/cr_AD_CN_MCI'
folder_3 = '/content/Slices_Separate_Folders_T1_weighted/sg_AD_CN_MCI'

# Reload dataset
image_paths_1, labels_1 = load_dataset(folder_1)
image_paths_2, labels_2 = load_dataset(folder_2)
image_paths_3, labels_3 = load_dataset(folder_3)

# Combine paths for validation dataset
image_paths = image_paths_1 + image_paths_2 + image_paths_3
labels = labels_1 + labels_2 + labels_3

# Re-split dataset into train & validation
_, val_paths, _, val_labels, _ = preprocess_dataset(image_paths, labels)

# Define transformations
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create validation dataset & dataloader
val_dataset = AlzheimerDataset(val_paths, val_labels, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print("✅ Validation dataset reloaded successfully!")


✅ Validation dataset reloaded successfully!




In [16]:
# Load trained models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_ax_cr = load_model("best_model_ax_cr.pth", device)
model_ax_sg = load_model("best_model_ax_sg.pth", device)
model_cr_sg = load_model("best_model_cr_sg.pth", device)

# Evaluate each model
metrics_ax_cr = evaluate_model(model_ax_cr, val_loader, device)
metrics_ax_sg = evaluate_model(model_ax_sg, val_loader, device)
metrics_cr_sg = evaluate_model(model_cr_sg, val_loader, device)

# Unpack metrics
acc_ax_cr, prec_ax_cr, sens_ax_cr, f1_ax_cr, spec_ax_cr = metrics_ax_cr
acc_ax_sg, prec_ax_sg, sens_ax_sg, f1_ax_sg, spec_ax_sg = metrics_ax_sg
acc_cr_sg, prec_cr_sg, sens_cr_sg, f1_cr_sg, spec_cr_sg = metrics_cr_sg

# Print results
print("\nFinal Ordered Metrics:")
print(f"{'Model':<10} {'Accuracy':<10} {'Precision':<10} {'Sensitivity':<12} {'F1-Score':<10} {'Specificity':<10}")
print(f"{'ax_cr':<10} {acc_ax_cr:.4f}    {prec_ax_cr:.4f}    {sens_ax_cr:.4f}       {f1_ax_cr:.4f}    {spec_ax_cr:.4f}")
print(f"{'ax_sg':<10} {acc_ax_sg:.4f}    {prec_ax_sg:.4f}    {sens_ax_sg:.4f}       {f1_ax_sg:.4f}    {spec_ax_sg:.4f}")
print(f"{'cr_sg':<10} {acc_cr_sg:.4f}    {prec_cr_sg:.4f}    {sens_cr_sg:.4f}       {f1_cr_sg:.4f}    {spec_cr_sg:.4f}")


  model.load_state_dict(torch.load(model_path, map_location=device))



Final Ordered Metrics:
Model      Accuracy   Precision  Sensitivity  F1-Score   Specificity
ax_cr      0.8257    0.8377    0.8257       0.8247    0.9128
ax_sg      0.8655    0.8659    0.8655       0.8652    0.9327
cr_sg      0.8207    0.8268    0.8207       0.8216    0.9103
