<a href="https://colab.research.google.com/github/ssudhanshu488/SwinOnAlziehmer/blob/main/letsseeifitworks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import copy
import os

import timm
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [1]:
!unzip /content/Slices_Separate_Folders_T1_weighted.zip

Archive:  /content/Slices_Separate_Folders_T1_weighted.zip
   creating: Slices_Separate_Folders_T1_weighted/
   creating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_1.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_10.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_100.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_101.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_102.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_103.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_104.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_105.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_106.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax_AD_CN_MCI/AD_ax_107.jpg  
  inflating: Slices_Separate_Folders_T1_weighted/ax

In [3]:
def load_dataset(image_folder):
    """Collect image file paths and their class labels from a flat folder.

    The label is the part of the filename before the first underscore,
    e.g. ``AD_ax_1.jpg`` -> ``'AD'``.

    Entries are visited in sorted filename order. ``os.listdir`` returns an
    arbitrary, filesystem-dependent order, and the seeded train/validation
    split downstream is only reproducible if its input order is fixed.
    Hidden files (e.g. ``.DS_Store``) and sub-directories are skipped.
    Otherwise their names would be turned into spurious labels.

    Args:
        image_folder: Directory containing the image files.

    Returns:
        ``(image_paths, labels)``: two parallel lists of equal length.
    """
    image_paths = []
    labels = []
    for image_name in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_name)
        if image_name.startswith('.') or not os.path.isfile(image_path):
            continue
        image_paths.append(image_path)
        labels.append(image_name.split('_')[0])
    return image_paths, labels

In [4]:
def preprocess_dataset(image_paths, labels, test_size=0.2, random_state=42):
    """Integer-encode string labels and make a stratified train/validation split.

    ``LabelEncoder`` assigns codes in sorted order of the distinct labels
    (e.g. ``'AD' -> 0, 'CN' -> 1, 'MCI' -> 2``).

    NOTE(review): the split is per image. If several slices come from the
    same subject, that subject can land in both train and validation, which
    inflates validation scores. Consider a group-aware split (e.g. by subject
    id) if one can be recovered from the filenames. TODO confirm.

    Args:
        image_paths: Image file paths, parallel to ``labels``.
        labels: String class labels, one per path.
        test_size: Fraction of samples held out for validation.
        random_state: Seed for the split, so it is repeatable for a fixed
            input order.

    Returns:
        ``(train_paths, val_paths, train_labels, val_labels, label_encoder)``.
        The label arrays hold the integer codes; ``label_encoder`` maps them
        back to names.
    """
    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(labels)
    # Stratify so each split keeps the overall class proportions.
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths,
        encoded_labels,
        test_size=test_size,
        random_state=random_state,
        stratify=encoded_labels,
    )
    return train_paths, val_paths, train_labels, val_labels, label_encoder

In [5]:
class AlzheimerDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image file paths.

    Each image is decoded with PIL and converted to 3-channel RGB, which
    matches the 3-channel normalisation statistics used by the transforms.
    The optional ``transform`` is then applied.
    """

    def __init__(self, image_paths, labels, transform=None):
        """
        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of labels, parallel to ``image_paths``.
            transform: Optional callable applied to each PIL image.

        Raises:
            ValueError: If ``image_paths`` and ``labels`` differ in length.
        """
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle deterministically.
        # convert() loads the pixel data and returns a new image, so it
        # remains valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

In [6]:
from sklearn.metrics import classification_report, confusion_matrix

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0


def _collect_predictions(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(true_labels, predicted_labels)`` lists."""
    model.eval()
    true_labels, predicted_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            predicted_labels.extend(outputs.argmax(dim=1).cpu().tolist())
            true_labels.extend(labels.tolist())
    return true_labels, predicted_labels


def _mean_specificity(cm):
    """Macro-averaged specificity TN / (TN + FP) from a square confusion matrix (rows = true, cols = predicted)."""
    true_pos = cm.diagonal()
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - true_pos - false_pos - false_neg
    per_class = [tn / (tn + fp) if (tn + fp) > 0 else 0.0 for tn, fp in zip(true_neg, false_pos)]
    return sum(per_class) / len(per_class)


def train_model(train_loader, val_loader, model, criterion, optimizer, scheduler, num_epochs, device,
                class_names=('AD', 'CN', 'MCI')):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Per epoch the function:
    * trains once over ``train_loader``;
    * steps ``scheduler`` once;
    * evaluates on ``val_loader`` and prints the mean training loss plus
      validation accuracy, weighted precision/recall/F1 and macro-averaged
      specificity.

    The weighted-average recall equals accuracy for single-label multi-class
    data, which is why the two columns match in the log.

    Args:
        train_loader, val_loader: Loaders yielding ``(images, labels)`` batches.
            Labels are integer class indices ``0 .. len(class_names) - 1``.
        model, criterion, optimizer, scheduler: Standard PyTorch training objects.
        num_epochs: Number of training epochs.
        device: Device that batches are moved to.
        class_names: Display names indexed by class id. The default matches the
            alphabetical ``LabelEncoder`` encoding of ``AD``/``CN``/``MCI``.

    Returns:
        ``(best_model_state, best_accuracy)``. ``best_model_state`` is an
        independent deep copy of ``model.state_dict()`` from the epoch with the
        highest validation accuracy. It is ``None`` only when
        ``num_epochs == 0``.

        Note that ``best_accuracy`` is chosen on the same validation set it
        reports, so it is an optimistic estimate.
    """
    # Pin the label set so the report and confusion matrix always cover every
    # class, even if one is absent from a given epoch's labels and predictions.
    class_indices = list(range(len(class_names)))
    best_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        scheduler.step()

        # Validation phase
        all_labels, all_preds = _collect_predictions(model, val_loader, device)
        accuracy = accuracy_score(all_labels, all_preds)
        report = classification_report(
            all_labels, all_preds, labels=class_indices, target_names=list(class_names),
            digits=4, output_dict=True, zero_division=0)
        precision = report['weighted avg']['precision']
        recall = report['weighted avg']['recall']  # Sensitivity
        f1 = report['weighted avg']['f1-score']
        avg_specificity = _mean_specificity(confusion_matrix(all_labels, all_preds, labels=class_indices))

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, "
              f"Validation Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, "
              f"Recall (Sensitivity): {recall:.4f}, F1-Score: {f1:.4f}, Specificity: {avg_specificity:.4f}")

        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place. Deep-copy it so this really
        # is the best epoch's snapshot rather than the final weights.
        if best_model_state is None or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_state = copy.deepcopy(model.state_dict())

    return best_model_state, best_accuracy


In [7]:
# Root of the extracted dataset; change this one constant if the zip is unpacked elsewhere.
DATA_ROOT = '/content/Slices_Separate_Folders_T1_weighted'
# One folder per slice orientation (prefixes ax/cr/sg, presumably axial/coronal/sagittal).
# Each folder holds AD, CN and MCI images.
folder_1 = os.path.join(DATA_ROOT, 'ax_AD_CN_MCI')
folder_2 = os.path.join(DATA_ROOT, 'cr_AD_CN_MCI')
folder_3 = os.path.join(DATA_ROOT, 'sg_AD_CN_MCI')

In [8]:
# Image preprocessing. Both pipelines finish by converting to a tensor and
# normalising with the standard ImageNet per-channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = 224


def _tensor_and_normalize_steps():
    """Return fresh ToTensor + Normalize steps shared by the train and validation pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]


# Training: mild random crop/scale, flip and small rotation for augmentation.
# NOTE(review): a horizontal flip of sagittal slices mirrors anterior/posterior
# rather than left/right. Confirm this is intended for the sg_ folder.
_train_augmentations = [
    transforms.RandomResizedCrop(_INPUT_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
]
train_transform = transforms.Compose(_train_augmentations + _tensor_and_normalize_steps())

# Validation: deterministic resize only, so metrics are comparable across epochs.
val_transform = transforms.Compose(
    [transforms.Resize((_INPUT_SIZE, _INPUT_SIZE))] + _tensor_and_normalize_steps()
)

In [9]:
def run_experiment(combination_name, folder_a, folder_b, num_epochs=10, seed=42):
    """Fine-tune a Swin-B classifier on images pooled from two folders and save the best weights.

    Args:
        combination_name: Used in log messages and in the checkpoint filename
            ``best_model_{combination_name}.pth`` (written to the working directory).
        folder_a, folder_b: Folders read with ``load_dataset``. Their images are
            pooled before the stratified train/validation split.
        num_epochs: Number of training epochs. Also used as the cosine-annealing
            period ``T_max``, so the learning-rate schedule spans the whole run.
        seed: Seed for torch's global RNG. It covers classifier-head init,
            DataLoader shuffling/worker seeds and the torchvision augmentations.
            This gives best-effort rather than bit-exact reproducibility on GPU.

    Raises:
        ValueError: If the filename prefixes do not produce exactly the classes
            ``AD``, ``CN``, ``MCI`` that the training report is labelled with.
    """
    print(f"Training on: {combination_name}")
    torch.manual_seed(seed)

    image_paths_a, labels_a = load_dataset(folder_a)
    image_paths_b, labels_b = load_dataset(folder_b)
    # NOTE(review): the split is per slice image. If the same subject contributes
    # slices to both folders/splits, validation accuracy is optimistic. TODO confirm.
    train_paths, val_paths, train_labels, val_labels, label_encoder = preprocess_dataset(
        image_paths_a + image_paths_b, labels_a + labels_b)

    # train_model names its report columns ['AD', 'CN', 'MCI'] by index.
    # Fail fast here rather than after a full training run.
    found_classes = list(label_encoder.classes_)
    if found_classes != ['AD', 'CN', 'MCI']:
        raise ValueError(f"Expected classes ['AD', 'CN', 'MCI'], found {found_classes}")

    train_dataset = AlzheimerDataset(train_paths, train_labels, transform=train_transform)
    val_dataset = AlzheimerDataset(val_paths, val_labels, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = timm.create_model('swin_base_patch4_window7_224', pretrained=True,
                              num_classes=len(found_classes)).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(num_epochs, 1))

    best_model_state, best_accuracy = train_model(
        train_loader, val_loader, model, criterion, optimizer, scheduler, num_epochs, device)
    print(f"Best Accuracy for {combination_name}: {best_accuracy:.4f}\n")
    torch.save(best_model_state, f"best_model_{combination_name}.pth")

In [10]:
# Train one model per pair of orientation folders; each run saves its own checkpoint.
experiment_plan = [
    ("Folder1+Folder2", folder_1, folder_2),
    ("Folder2+Folder3", folder_2, folder_3),
    ("Folder1+Folder3", folder_1, folder_3),
]
for experiment_name, first_folder, second_folder in experiment_plan:
    run_experiment(experiment_name, first_folder, second_folder)

Training on: Folder1+Folder2


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Epoch 1/10, Loss: 1.0051, Validation Accuracy: 0.5131, Precision: 0.6947, Recall (Sensitivity): 0.5131, F1-Score: 0.4741, Specificity: 0.7563




Epoch 2/10, Loss: 0.8938, Validation Accuracy: 0.6399, Precision: 0.6810, Recall (Sensitivity): 0.6399, F1-Score: 0.6133, Specificity: 0.8201




Epoch 3/10, Loss: 0.7506, Validation Accuracy: 0.7369, Precision: 0.7479, Recall (Sensitivity): 0.7369, F1-Score: 0.7341, Specificity: 0.8684




Epoch 4/10, Loss: 0.6551, Validation Accuracy: 0.7780, Precision: 0.7764, Recall (Sensitivity): 0.7780, F1-Score: 0.7768, Specificity: 0.8890




Epoch 5/10, Loss: 0.5420, Validation Accuracy: 0.8619, Precision: 0.8678, Recall (Sensitivity): 0.8619, F1-Score: 0.8617, Specificity: 0.9309




Epoch 6/10, Loss: 0.4743, Validation Accuracy: 0.8675, Precision: 0.8769, Recall (Sensitivity): 0.8675, F1-Score: 0.8672, Specificity: 0.9338




Epoch 7/10, Loss: 0.4047, Validation Accuracy: 0.9384, Precision: 0.9412, Recall (Sensitivity): 0.9384, F1-Score: 0.9385, Specificity: 0.9692




Epoch 8/10, Loss: 0.3583, Validation Accuracy: 0.9086, Precision: 0.9151, Recall (Sensitivity): 0.9086, F1-Score: 0.9089, Specificity: 0.9543




Epoch 9/10, Loss: 0.3369, Validation Accuracy: 0.9459, Precision: 0.9474, Recall (Sensitivity): 0.9459, F1-Score: 0.9462, Specificity: 0.9729




Epoch 10/10, Loss: 0.3267, Validation Accuracy: 0.9422, Precision: 0.9427, Recall (Sensitivity): 0.9422, F1-Score: 0.9423, Specificity: 0.9711
Best Accuracy for Folder1+Folder2: 0.9459

Training on: Folder2+Folder3




Epoch 1/10, Loss: 1.0381, Validation Accuracy: 0.6030, Precision: 0.6215, Recall (Sensitivity): 0.6030, F1-Score: 0.6006, Specificity: 0.8013




Epoch 2/10, Loss: 0.8673, Validation Accuracy: 0.7228, Precision: 0.7448, Recall (Sensitivity): 0.7228, F1-Score: 0.7152, Specificity: 0.8614




Epoch 3/10, Loss: 0.7578, Validation Accuracy: 0.8296, Precision: 0.8330, Recall (Sensitivity): 0.8296, F1-Score: 0.8282, Specificity: 0.9148




Epoch 4/10, Loss: 0.6362, Validation Accuracy: 0.8315, Precision: 0.8367, Recall (Sensitivity): 0.8315, F1-Score: 0.8268, Specificity: 0.9158




Epoch 5/10, Loss: 0.5273, Validation Accuracy: 0.9082, Precision: 0.9125, Recall (Sensitivity): 0.9082, F1-Score: 0.9064, Specificity: 0.9542




Epoch 6/10, Loss: 0.4529, Validation Accuracy: 0.9120, Precision: 0.9215, Recall (Sensitivity): 0.9120, F1-Score: 0.9130, Specificity: 0.9559




Epoch 7/10, Loss: 0.3996, Validation Accuracy: 0.9345, Precision: 0.9346, Recall (Sensitivity): 0.9345, F1-Score: 0.9341, Specificity: 0.9672




Epoch 8/10, Loss: 0.3550, Validation Accuracy: 0.9382, Precision: 0.9418, Recall (Sensitivity): 0.9382, F1-Score: 0.9380, Specificity: 0.9692




Epoch 9/10, Loss: 0.3343, Validation Accuracy: 0.9513, Precision: 0.9519, Recall (Sensitivity): 0.9513, F1-Score: 0.9511, Specificity: 0.9757




Epoch 10/10, Loss: 0.3266, Validation Accuracy: 0.9588, Precision: 0.9589, Recall (Sensitivity): 0.9588, F1-Score: 0.9587, Specificity: 0.9794
Best Accuracy for Folder2+Folder3: 0.9588

Training on: Folder1+Folder3




Epoch 1/10, Loss: 1.0995, Validation Accuracy: 0.5261, Precision: 0.3616, Recall (Sensitivity): 0.5261, F1-Score: 0.4250, Specificity: 0.7624


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2/10, Loss: 0.9545, Validation Accuracy: 0.5765, Precision: 0.6102, Recall (Sensitivity): 0.5765, F1-Score: 0.5786, Specificity: 0.7885




Epoch 3/10, Loss: 0.8623, Validation Accuracy: 0.6287, Precision: 0.7123, Recall (Sensitivity): 0.6287, F1-Score: 0.6102, Specificity: 0.8144




Epoch 4/10, Loss: 0.7476, Validation Accuracy: 0.7724, Precision: 0.7704, Recall (Sensitivity): 0.7724, F1-Score: 0.7705, Specificity: 0.8862




Epoch 5/10, Loss: 0.6604, Validation Accuracy: 0.8022, Precision: 0.8148, Recall (Sensitivity): 0.8022, F1-Score: 0.7936, Specificity: 0.9009




Epoch 6/10, Loss: 0.5638, Validation Accuracy: 0.8246, Precision: 0.8290, Recall (Sensitivity): 0.8246, F1-Score: 0.8260, Specificity: 0.9124




Epoch 7/10, Loss: 0.4428, Validation Accuracy: 0.8974, Precision: 0.8998, Recall (Sensitivity): 0.8974, F1-Score: 0.8964, Specificity: 0.9487




Epoch 8/10, Loss: 0.4024, Validation Accuracy: 0.9235, Precision: 0.9257, Recall (Sensitivity): 0.9235, F1-Score: 0.9237, Specificity: 0.9618




Epoch 9/10, Loss: 0.3588, Validation Accuracy: 0.9384, Precision: 0.9407, Recall (Sensitivity): 0.9384, F1-Score: 0.9387, Specificity: 0.9693




Epoch 10/10, Loss: 0.3448, Validation Accuracy: 0.9459, Precision: 0.9467, Recall (Sensitivity): 0.9459, F1-Score: 0.9460, Specificity: 0.9730
Best Accuracy for Folder1+Folder3: 0.9459

