In [1]:
# preprocess.py
# Install compatible versions of JAX, jaxlib, and PennyLane dependencies
!pip install jax==0.4.28 jaxlib==0.4.28 pennylane pennylane-lightning --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m28.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.5/77.5 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m64.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m66.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.8/930.8 kB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m58.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 MB[0m [31m99.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Preprocess

In [2]:
# preprocess.py
# Install compatible versions of JAX, jaxlib, and PennyLane dependencies
!pip install jax==0.4.28 jaxlib==0.4.28 pennylane pennylane-lightning --quiet

import os  # Import os for file system operations
import random  # Import random for generating random numbers
import torch  # Import PyTorch for tensor operations
import torch.nn as nn  # Import neural network modules from PyTorch
import torch.nn.functional as F  # Import functional operations from PyTorch
from torch.utils.data import Dataset, DataLoader  # Import Dataset and DataLoader for data handling
from torchvision import transforms  # Import transforms for image preprocessing
from PIL import Image  # Import PIL for image loading
import pennylane as qml  # Import PennyLane for quantum computing
import numpy as np  # Import NumPy for numerical operations
from tqdm import tqdm  # Import tqdm for progress bars
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix  # Import metrics for evaluation
import matplotlib.pyplot as plt  # Import matplotlib for plotting
import seaborn as sns  # Import seaborn for enhanced visualizations

DATA_ROOT = "/kaggle/input/imagesoasis/Data"  # Define path to OASIS dataset
WORKING_DIR = "/kaggle/working/"  # Define working directory

LABEL_MAP = {
    "Non Demented": 0,  # Map "Non Demented" to label 0
    "Very mild Dementia": 1,  # Map "Very mild Dementia" to label 1
    "Mild Dementia": 2,  # Map "Mild Dementia" to label 2
    "Moderate Dementia": 3,  # Map "Moderate Dementia" to label 3
}

class_names = list(LABEL_MAP.keys())  # Create list of class names from LABEL_MAP keys

if not os.path.exists(DATA_ROOT):  # Check if dataset directory exists
    raise RuntimeError(f"Dataset directory not found: {DATA_ROOT}")  # Raise error if directory not found
print("✔ Data root:", DATA_ROOT)  # Print the data root directory

✔ Data root: /kaggle/input/imagesoasis/Data


# Dataset

In [3]:
# =====================
# 1. Dataset: MIL Patch Extraction with Object Coverage Strategy
# =====================
class OASIS2DDataset(Dataset):
    def __init__(self, root_dir, class_names, transform=None, noise_std=0.05, patch_size=32, stride=16):  # Initialize with root directory, class names, transforms, noise level, patch size and stride
        self.data = []  # Initialize empty list for image paths
        self.labels = []  # Initialize empty list for labels
        self.transform = transform  # Store transform function
        self.noise_std = noise_std  # Store standard deviation for Gaussian noise
        self.patch_size = patch_size # Store patch size
        self.stride = stride # Store stride

        for class_name in class_names:  # Iterate through each class
            class_dir = os.path.join(root_dir, class_name)  # Get path to class subdirectory
            label = LABEL_MAP[class_name]  # Get integer label for class
            class_images = [os.path.join(class_dir, fname) for fname in os.listdir(class_dir) if fname.lower().endswith(".jpg")]  # Collect all JPEG image paths
            if class_name == "Moderate Dementia":  # For Moderate Dementia, use all available images (488)
                selected_images = class_images  # Take all 488 images
            elif len(class_images) >= 2000:  # For other classes with 2000 or more images
                selected_images = random.sample(class_images, 2000)  # Randomly sample exactly 2000 images
            else:
                raise ValueError(f"Class {class_name} has fewer than 2000 images: {len(class_images)}")  # Error if a class has < 2000 excluding Moderate Dementia
            self.data.extend(selected_images)  # Add selected image paths to data list
            self.labels.extend([label] * len(selected_images))  # Add corresponding labels to labels list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("L")
        if self.transform:
            image = self.transform(image)

        # Add noise to simulate clinical imaging variability
        image = image + self.noise_std * torch.randn_like(image)

        # Extract patches using sliding window to cover full object area
        patches = self.extract_patches(image)
        return torch.stack(patches), torch.tensor(label)

    def extract_patches(self, image):
        # Cover entire image using sliding window with overlap
        _, H, W = image.shape
        patches = []
        for y in range(0, H - self.patch_size + 1, self.stride):
            for x in range(0, W - self.patch_size + 1, self.stride):
                patch = image[:, y:y+self.patch_size, x:x+self.patch_size]
                patches.append(patch)
        return patches

# Model

In [4]:
def create_qnode(n_qubits=4):  # Define function to create quantum node
    dev = qml.device("default.qubit", wires=n_qubits)  # Initialize quantum device with specified qubits

    @qml.qnode(dev, interface="torch", diff_method="backprop")  # Define quantum node with PyTorch interface
    def qnode(inputs, weights):  # Define quantum circuit function
        qml.templates.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")  # Embed inputs as Y rotations
        qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))  # Apply entangling layer
        return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]  # Return PauliZ expectations for each qubit

    return qnode  # Return quantum node function

class QNNMILClassifier(nn.Module):
    def __init__(self, n_qubits=4, num_classes=4):
        super().__init__()
        self.n_qubits = n_qubits
        self.qnode = create_qnode(n_qubits)
        weight_shapes = {"weights": (1, n_qubits)} # Changed from (n_qubits,) to (1, n_qubits)
        self.q_layer = qml.qnn.TorchLayer(self.qnode, weight_shapes)

        # Attention scores to learn which patches are most relevant (MIL core)
        self.attention = nn.Sequential(
            nn.Linear(n_qubits, 8), nn.Tanh(), nn.Linear(8, 1)
        )
        self.classifier = nn.Linear(n_qubits, num_classes)  # final classification layer

    def forward(self, patches):
        B, N, C, H, W = patches.shape
        outputs = []
        for i in range(N):
            patch = patches[:, i].reshape(B, -1)[:, :self.n_qubits]
            q_out = self.q_layer(patch)
            outputs.append(q_out)
        patch_outputs = torch.stack(outputs, dim=1)
        attn_weights = F.softmax(self.attention(patch_outputs).squeeze(-1), dim=1)
        bag_rep = torch.sum(attn_weights.unsqueeze(-1) * patch_outputs, dim=1)
        return self.classifier(bag_rep), bag_rep  # also return embedding for contrastive use

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, bag_reps, labels, prototypes):
        bag_reps = F.normalize(bag_reps, dim=1)
        protos = torch.stack([prototypes[label.item()] for label in labels]).to(bag_reps.device)
        protos = F.normalize(protos, dim=1)
        sim = torch.sum(bag_reps * protos, dim=1) / self.temperature
        return -torch.mean(sim)

In [5]:
def compute_prototypes(model, loader, device):
    model.eval()
    features_by_class = {}
    with torch.no_grad():
        for patches, labels in tqdm(loader):
            patches, labels = patches.to(device), labels.to(device)
            _, bag_rep = model(patches)
            for i in range(len(labels)):
                c = labels[i].item()
                if c not in features_by_class:
                    features_by_class[c] = []
                features_by_class[c].append(bag_rep[i].detach().cpu())
    return {c: torch.stack(features).mean(0) for c, features in features_by_class.items()}


# Train_eval

In [6]:
def train_joint(model, dataloader, optimizer, criterion_cls, criterion_cont, device):  # Define joint training function
    model.train()  # Set model to training mode
    total_loss = 0.0  # Initialize total loss
    correct = 0  # Initialize correct predictions count
    all_preds = []  # Initialize list for predictions
    all_labels = []  # Initialize list for true labels

    for patches, labels in tqdm(dataloader, desc="Training"):  # Iterate through dataloader with progress bar
        patches, labels = patches.to(device), labels.to(device)  # Move data to device
        noisy = patches + .3 * torch.randn_like(patches)  # Add Gaussian noise to patches

        logits_clean, bag_rep_clean = model(patches)  # Forward pass on clean patches
        logits_noisy, bag_rep_noisy = model(noisy)  # Forward pass on noisy patches

        loss_cls = criterion_cls(logits_clean, labels)  # Compute classification loss
        loss_cont = criterion_cont(bag_rep_clean, labels, prototypes)  # Compute contrastive loss

        loss = loss_cls + 0.1 * loss_cont  # Combine losses with weighting
        optimizer.zero_grad()  # Clear gradients
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model parameters

        total_loss += loss.item()  # Accumulate loss
        preds = torch.argmax(logits_clean, dim=1)  # Get predicted classes
        correct += (preds == labels).sum().item()  # Count correct predictions
        all_preds.extend(preds.cpu().numpy())  # Collect predictions
        all_labels.extend(labels.cpu().numpy())  # Collect true labels

    avg_loss = total_loss / len(dataloader)  # Compute average loss
    accuracy = accuracy_score(all_labels, all_preds)  # Compute accuracy
    f1 = f1_score(all_labels, all_preds, average='weighted')  # Compute weighted F1 score
    return avg_loss, accuracy, f1, all_preds, all_labels  # Return metrics and predictions

def validate(model, dataloader, criterion_cls, device):  # Define validation function
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0  # Initialize total loss
    all_preds = []  # Initialize list for predictions
    all_labels = []  # Initialize list for true labels

    with torch.no_grad():  # Disable gradient computation
        for patches, labels in tqdm(dataloader, desc="Validating"):  # Iterate through dataloader with progress bar
            patches, labels = patches.to(device), labels.to(device)  # Move data to device
            logits, _ = model(patches)  # Forward pass
            loss = criterion_cls(logits, labels)  # Compute classification loss
            total_loss += loss.item()  # Accumulate loss
            preds = torch.argmax(logits, dim=1)  # Get predicted classes
            all_preds.extend(preds.cpu().numpy())  # Collect predictions
            all_labels.extend(labels.cpu().numpy())  # Collect true labels

    avg_loss = total_loss / len(dataloader)  # Compute average loss
    accuracy = accuracy_score(all_labels, all_preds)  # Compute accuracy
    f1 = f1_score(all_labels, all_preds, average='weighted')  # Compute weighted F1 score
    return avg_loss, accuracy, f1, all_preds, all_labels  # Return metrics and predictions

def plot_confusion_matrix(labels, preds, class_names, epoch):  # Define function to plot confusion matrix
    cm = confusion_matrix(labels, preds)  # Compute confusion matrix
    plt.figure(figsize=(8, 6))  # Create figure with size 8x6
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)  # Plot heatmap
    plt.title(f'Confusion Matrix (Epoch {epoch+1})')  # Set title with epoch number
    plt.xlabel('Predicted')  # Label x-axis
    plt.ylabel('True')  # Label y-axis
    plt.savefig(f'confusion_matrix_epoch_{epoch+1}.png')  # Save plot to file
    plt.close()  # Close figure

def plot_validation_curves(train_accs, val_accs, train_f1s, val_f1s, train_losses, val_losses):  # Define function to plot validation curves
    epochs = range(1, len(train_accs) + 1)  # Create range of epochs
    plt.figure(figsize=(12, 4))  # Create figure with size 12x4

    plt.subplot(1, 3, 1)  # Create first subplot for accuracy
    plt.plot(epochs, train_accs, 'b-', label='Train Accuracy')  # Plot training accuracy
    plt.plot(epochs, val_accs, 'r-', label='Validation Accuracy')  # Plot validation accuracy
    plt.title('Accuracy Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('Accuracy')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.subplot(1, 3, 2)  # Create second subplot for F1 score
    plt.plot(epochs, train_f1s, 'b-', label='Train F1 Score')  # Plot training F1 score
    plt.plot(epochs, val_f1s, 'r-', label='Validation F1 Score')  # Plot validation F1 score
    plt.title('F1 Score Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('F1 Score')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.subplot(1, 3, 3)  # Create third subplot for loss
    plt.plot(epochs, train_losses, 'b-', label='Train Loss')  # Plot training loss
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')  # Plot validation loss
    plt.title('Loss Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('Loss')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.savefig('validation_curves.png')  # Save plot to file
    plt.close()  # Close figure

if __name__ == "__main__":  # Main execution block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Select GPU if available, else CPU

    transform = transforms.Compose([  # Define image transformations
        transforms.Resize((128, 128)),  # Resize images to 128x128
        transforms.ToTensor(),  # Convert images to tensors
    ])

    dataset = OASIS2DDataset(DATA_ROOT, class_names, transform=transform)  # Create dataset instance
    train_size = int(0.8 * len(dataset))  # Calculate training set size (80%)
    test_size = len(dataset) - train_size  # Calculate test set size (20%)
    train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])  # Split dataset

    # Compute class counts in training set
    from collections import Counter
    train_indices = train_set.indices
    train_labels = [dataset.labels[i] for i in train_indices]
    class_counts = Counter(train_labels)
    num_classes = len(class_names)
    counts = [class_counts.get(i, 0) for i in range(num_classes)]
    max_count = max(counts)
    weights = [max_count / count if count > 0 else 0 for count in counts]
    weights_tensor = torch.tensor(weights, dtype=torch.float).to(device)

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)  # Create training DataLoader
    test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)  # Create test DataLoader

    model = QNNMILClassifier(n_qubits=4, num_classes=num_classes).to(device)  # Initialize model and move to device

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Initialize Adam optimizer
    loss_fn = nn.CrossEntropyLoss(weight=weights_tensor)  # Initialize cross-entropy loss with class weights
    contrast_fn = ContrastiveLoss()  # Initialize contrastive loss

    print("\n▶️  Training with contrastive + classification loss...\n")  # Print training start message
    num_epochs = 10  # Set number of epochs
    train_accs, val_accs = [], []  # Initialize lists for accuracies
    train_f1s, val_f1s = [], []  # Initialize lists for F1 scores
    train_losses, val_losses = [], []  # Initialize lists for losses

    for epoch in range(num_epochs):  # Iterate through epochs
        # Compute prototypes at the beginning of each epoch
        prototypes = compute_prototypes(model, train_loader, device)

        train_loss, train_acc, train_f1, train_preds, train_labels = train_joint(
            model, train_loader, optimizer, loss_fn, contrast_fn, device
        )  # Train model for one epoch
        val_loss, val_acc, val_f1, val_preds, val_labels = validate(
            model, test_loader, loss_fn, device
        )  # Validate model

        train_accs.append(train_acc)  # Store training accuracy
        val_accs.append(val_acc)  # Store validation accuracy
        train_f1s.append(train_f1)  # Store training F1 score
        val_f1s.append(val_f1)  # Store validation F1 score
        train_losses.append(train_loss)  # Store training loss
        val_losses.append(val_loss)  # Store validation loss

        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")  # Print epoch metrics

        plot_confusion_matrix(val_labels, val_preds, class_names, epoch)  # Plot confusion matrix for validation

    plot_validation_curves(train_accs, val_accs, train_f1s, val_f1s, train_losses, val_losses)  # Plot validation curves

    prototypes = compute_prototypes(model, train_loader, device)  # Compute prototypes on training set
    for cls_idx, proto in prototypes.items():  # Iterate through prototypes
        print(f"Prototype for class {cls_idx} ({class_names[cls_idx]}): {proto.numpy()[:5]} ...")  # Print prototype features


▶️  Training with contrastive + classification loss...



100%|██████████| 649/649 [05:06<00:00,  2.12it/s]
Training: 100%|██████████| 649/649 [13:46<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:17<00:00,  2.11it/s]


Epoch 01 | Train Loss: 1.1862 | Train Acc: 0.2934 | Train F1: 0.2828 | Val Loss: 1.3843 | Val Acc: 0.2928 | Val F1: 0.1326


100%|██████████| 649/649 [05:08<00:00,  2.10it/s]
Training: 100%|██████████| 649/649 [13:53<00:00,  1.28s/it]
Validating: 100%|██████████| 163/163 [01:16<00:00,  2.13it/s]


Epoch 02 | Train Loss: 1.1803 | Train Acc: 0.2961 | Train F1: 0.2892 | Val Loss: 1.3871 | Val Acc: 0.2928 | Val F1: 0.1406


100%|██████████| 649/649 [05:05<00:00,  2.12it/s]
Training: 100%|██████████| 649/649 [13:49<00:00,  1.28s/it]
Validating: 100%|██████████| 163/163 [01:17<00:00,  2.11it/s]


Epoch 03 | Train Loss: 1.1807 | Train Acc: 0.3079 | Train F1: 0.2923 | Val Loss: 1.3814 | Val Acc: 0.3213 | Val F1: 0.1562


100%|██████████| 649/649 [05:05<00:00,  2.13it/s]
Training: 100%|██████████| 649/649 [13:45<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:15<00:00,  2.15it/s]


Epoch 04 | Train Loss: 1.1820 | Train Acc: 0.3102 | Train F1: 0.2957 | Val Loss: 1.3827 | Val Acc: 0.2928 | Val F1: 0.1326


100%|██████████| 649/649 [05:02<00:00,  2.15it/s]
Training: 100%|██████████| 649/649 [13:46<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:16<00:00,  2.13it/s]


Epoch 05 | Train Loss: 1.1808 | Train Acc: 0.3091 | Train F1: 0.2948 | Val Loss: 1.3808 | Val Acc: 0.2897 | Val F1: 0.1355


100%|██████████| 649/649 [05:05<00:00,  2.12it/s]
Training: 100%|██████████| 649/649 [13:46<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:16<00:00,  2.13it/s]


Epoch 06 | Train Loss: 1.1817 | Train Acc: 0.3048 | Train F1: 0.2918 | Val Loss: 1.3802 | Val Acc: 0.3213 | Val F1: 0.1562


100%|██████████| 649/649 [05:02<00:00,  2.14it/s]
Training: 100%|██████████| 649/649 [13:42<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:16<00:00,  2.13it/s]


Epoch 07 | Train Loss: 1.1811 | Train Acc: 0.3035 | Train F1: 0.2868 | Val Loss: 1.3805 | Val Acc: 0.3213 | Val F1: 0.1562


100%|██████████| 649/649 [05:00<00:00,  2.16it/s]
Training: 100%|██████████| 649/649 [13:41<00:00,  1.27s/it]
Validating: 100%|██████████| 163/163 [01:16<00:00,  2.13it/s]


Epoch 08 | Train Loss: 1.1815 | Train Acc: 0.3054 | Train F1: 0.2919 | Val Loss: 1.3804 | Val Acc: 0.3213 | Val F1: 0.1562


100%|██████████| 649/649 [04:58<00:00,  2.17it/s]
Training: 100%|██████████| 649/649 [13:32<00:00,  1.25s/it]
Validating: 100%|██████████| 163/163 [01:17<00:00,  2.11it/s]


Epoch 09 | Train Loss: 1.1810 | Train Acc: 0.3015 | Train F1: 0.2857 | Val Loss: 1.3818 | Val Acc: 0.2928 | Val F1: 0.1326


100%|██████████| 649/649 [04:56<00:00,  2.19it/s]
Training: 100%|██████████| 649/649 [13:22<00:00,  1.24s/it]
Validating: 100%|██████████| 163/163 [01:14<00:00,  2.19it/s]


Epoch 10 | Train Loss: 1.1820 | Train Acc: 0.3060 | Train F1: 0.2770 | Val Loss: 1.3808 | Val Acc: 0.3074 | Val F1: 0.1446


100%|██████████| 649/649 [04:54<00:00,  2.20it/s]

Prototype for class 1 (Very mild Dementia): [ 1.0799805e-05  2.4797477e-02 -3.5426719e-04 -1.0487450e-06] ...
Prototype for class 0 (Non Demented): [ 1.0692320e-05  2.4659477e-02 -3.5113384e-04 -1.0357679e-06] ...
Prototype for class 2 (Mild Dementia): [ 1.0807049e-05  2.4786530e-02 -3.5426201e-04 -1.0491033e-06] ...
Prototype for class 3 (Moderate Dementia): [ 1.07539445e-05  2.48049926e-02 -3.53882031e-04 -1.04562844e-06] ...



