# Preprocess

In [1]:
# Install compatible versions of JAX, jaxlib, and PennyLane dependencies
!pip install jax==0.4.28 jaxlib==0.4.28 pennylane pennylane-lightning --quiet

import os  # Import os for file system operations
import random  # Import random for generating random numbers
import torch  # Import PyTorch for tensor operations
import torch.nn as nn  # Import neural network modules from PyTorch
import torch.nn.functional as F  # Import functional operations from PyTorch
from torch.utils.data import Dataset, DataLoader  # Import Dataset and DataLoader for data handling
from torchvision import transforms  # Import transforms for image preprocessing
from PIL import Image  # Import PIL for image loading
import pennylane as qml  # Import PennyLane for quantum computing
import numpy as np  # Import NumPy for numerical operations
from tqdm import tqdm  # Import tqdm for progress bars
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix  # Import metrics for evaluation
import matplotlib.pyplot as plt  # Import matplotlib for plotting
import seaborn as sns  # Import seaborn for enhanced visualizations

DATA_ROOT = "/kaggle/input/imagesoasis/Data"  # Define path to OASIS dataset
WORKING_DIR = "/kaggle/working/"  # Define working directory

LABEL_MAP = {
    "Non Demented": 0,  # Map "Non Demented" to label 0
    "Very mild Dementia": 1,  # Map "Very mild Dementia" to label 1
    "Mild Dementia": 2,  # Map "Mild Dementia" to label 2
    "Moderate Dementia": 3,  # Map "Moderate Dementia" to label 3
}

class_names = list(LABEL_MAP.keys())  # Create list of class names from LABEL_MAP keys

if not os.path.exists(DATA_ROOT):  # Check if dataset directory exists
    raise RuntimeError(f"Dataset directory not found: {DATA_ROOT}")  # Raise error if directory not found
print("✔ Data root:", DATA_ROOT)  # Print the data root directory

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.5/77.5 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m61.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m72.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.8/930.8 kB[0m [31m50.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m81.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 MB[0m [31m86.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Dataset

In [2]:
class OASIS2DDataset(Dataset):  # Define custom Dataset class for OASIS data
    def __init__(self, root_dir, class_names, transform=None, noise_std=0.05):  # Initialize with root directory, class names, transforms, and noise level
        self.data = []  # Initialize empty list for image paths
        self.labels = []  # Initialize empty list for labels
        self.transform = transform  # Store transform function
        self.noise_std = noise_std  # Store standard deviation for Gaussian noise

        for class_name in class_names:  # Iterate through each class
            class_dir = os.path.join(root_dir, class_name)  # Get path to class subdirectory
            label = LABEL_MAP[class_name]  # Get integer label for class
            for fname in os.listdir(class_dir):  # Iterate through files in class directory
                if fname.lower().endswith(".jpg"):  # Check if file is a JPEG
                    self.data.append(os.path.join(class_dir, fname))  # Add image path to data list
                    self.labels.append(label)  # Add corresponding label to labels list

    def __len__(self):  # Define length method for dataset
        return len(self.data)  # Return number of images

    def __getitem__(self, idx):  # Define method to get item by index
        img_path = self.data[idx]  # Get image path at index
        label = self.labels[idx]  # Get label at index

        image = Image.open(img_path).convert("L")  # Load image as grayscale
        if self.transform:  # Check if transform is provided
            image = self.transform(image)  # Apply transform to image

        noisy_image = image + self.noise_std * torch.randn_like(image)  # Add Gaussian noise to image

        patches = self.extract_patches(noisy_image)  # Extract random patches from noisy image
        return torch.stack(patches), torch.tensor(label, dtype=torch.long)  # Return stacked patches and label

    def extract_patches(self, image, patch_size=32, num_patches=5):  # Define method to extract patches
        _, H, W = image.shape  # Get image dimensions
        patches = []  # Initialize empty list for patches
        for _ in range(num_patches):  # Loop to extract specified number of patches
            top = random.randint(0, H - patch_size)  # Random top coordinate
            left = random.randint(0, W - patch_size)  # Random left coordinate
            patch = image[:, top : top + patch_size, left : left + patch_size]  # Crop patch from image
            patches.append(patch)  # Add patch to list
        return patches  # Return list of patches

# Model

In [3]:
def create_qnode(n_qubits=4):  # Define function to create quantum node
    dev = qml.device("default.qubit", wires=n_qubits)  # Initialize quantum device with specified qubits

    @qml.qnode(dev, interface="torch", diff_method="backprop")  # Define quantum node with PyTorch interface
    def qnode(inputs, weights):  # Define quantum circuit function
        qml.templates.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")  # Embed inputs as Y rotations
        qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))  # Apply entangling layer
        return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]  # Return PauliZ expectations for each qubit

    return qnode  # Return quantum node function

class QNNMILClassifier(nn.Module):  # Define quantum neural network classifier
    def __init__(self, n_qubits=4, num_classes=4):  # Initialize with qubits and number of classes
        super().__init__()  # Call parent class initializer
        self.n_qubits = n_qubits  # Store number of qubits

        self.qnode = create_qnode(n_qubits)  # Create quantum node
        weight_shapes = {"weights": (1, n_qubits)}  # Define shape for quantum weights
        self.q_layer = qml.qnn.TorchLayer(self.qnode, weight_shapes)  # Create quantum layer with PyTorch interface

        self.attention = nn.Sequential(  # Define attention network
            nn.Linear(n_qubits, 8),  # Linear layer from qubits to 8 units
            nn.Tanh(),  # Apply Tanh activation
            nn.Linear(8, 1),  # Linear layer to output attention score
        )

        self.classifier = nn.Linear(n_qubits, num_classes)  # Define final classifier layer

    def forward(self, patches):  # Define forward pass
        B, N, C, H, W = patches.shape  # Get batch size, number of patches, channels, height, width
        outputs = []  # Initialize list for quantum outputs

        for i in range(N):  # Iterate through patches
            patch = patches[:, i].reshape(B, -1)[:, : self.n_qubits]  # Flatten patch and slice to n_qubits
            q_out = self.q_layer(patch)  # Pass patch through quantum layer
            outputs.append(q_out)  # Append quantum output

        patch_outputs = torch.stack(outputs, dim=1)  # Stack outputs to (B, N, n_qubits)

        attn_logits = self.attention(patch_outputs)  # Compute attention logits
        attn_weights = F.softmax(attn_logits.squeeze(-1), dim=1)  # Apply softmax to get attention weights

        bag_rep = torch.sum(attn_weights.unsqueeze(-1) * patch_outputs, dim=1)  # Compute weighted sum of patch outputs

        return self.classifier(bag_rep)  # Return classification logits

class ContrastiveLoss(nn.Module):  # Define contrastive loss class
    def __init__(self, temperature=0.5):  # Initialize with temperature parameter
        super().__init__()  # Call parent class initializer
        self.temperature = temperature  # Store temperature value

    def forward(self, z1, z2):  # Define forward pass for loss
        z1 = F.normalize(z1, dim=1)  # Normalize first input tensor
        z2 = F.normalize(z2, dim=1)  # Normalize second input tensor
        sim = torch.matmul(z1, z2.T) / self.temperature  # Compute similarity matrix
        labels = torch.arange(z1.size(0)).to(z1.device)  # Create labels for contrastive loss
        return F.cross_entropy(sim, labels)  # Return cross-entropy loss

# Train_eval

In [4]:
def train_joint(model, dataloader, optimizer, criterion_cls, criterion_cont, device):  # Define joint training function
    model.train()  # Set model to training mode
    total_loss = 0.0  # Initialize total loss
    correct = 0  # Initialize correct predictions count
    all_preds = []  # Initialize list for predictions
    all_labels = []  # Initialize list for true labels

    for patches, labels in tqdm(dataloader, desc="Training"):  # Iterate through dataloader with progress bar
        patches, labels = patches.to(device), labels.to(device)  # Move data to device
        noisy = patches + .3 * torch.randn_like(patches)  # Add Gaussian noise to patches

        logits_clean = model(patches)  # Forward pass on clean patches
        logits_noisy = model(noisy)  # Forward pass on noisy patches

        loss_cls = criterion_cls(logits_clean, labels)  # Compute classification loss
        loss_cont = criterion_cont(logits_clean, logits_noisy)  # Compute contrastive loss

        loss = loss_cls + 0.1 * loss_cont  # Combine losses with weighting
        optimizer.zero_grad()  # Clear gradients
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model parameters

        total_loss += loss.item()  # Accumulate loss
        preds = torch.argmax(logits_clean, dim=1)  # Get predicted classes
        correct += (preds == labels).sum().item()  # Count correct predictions
        all_preds.extend(preds.cpu().numpy())  # Collect predictions
        all_labels.extend(labels.cpu().numpy())  # Collect true labels

    avg_loss = total_loss / len(dataloader)  # Compute average loss
    accuracy = accuracy_score(all_labels, all_preds)  # Compute accuracy
    f1 = f1_score(all_labels, all_preds, average='weighted')  # Compute weighted F1 score
    return avg_loss, accuracy, f1, all_preds, all_labels  # Return metrics and predictions

def validate(model, dataloader, criterion_cls, device):  # Define validation function
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0  # Initialize total loss
    all_preds = []  # Initialize list for predictions
    all_labels = []  # Initialize list for true labels

    with torch.no_grad():  # Disable gradient computation
        for patches, labels in tqdm(dataloader, desc="Validating"):  # Iterate through dataloader with progress bar
            patches, labels = patches.to(device), labels.to(device)  # Move data to device
            logits = model(patches)  # Forward pass
            loss = criterion_cls(logits, labels)  # Compute classification loss
            total_loss += loss.item()  # Accumulate loss
            preds = torch.argmax(logits, dim=1)  # Get predicted classes
            all_preds.extend(preds.cpu().numpy())  # Collect predictions
            all_labels.extend(labels.cpu().numpy())  # Collect true labels

    avg_loss = total_loss / len(dataloader)  # Compute average loss
    accuracy = accuracy_score(all_labels, all_preds)  # Compute accuracy
    f1 = f1_score(all_labels, all_preds, average='weighted')  # Compute weighted F1 score
    return avg_loss, accuracy, f1, all_preds, all_labels  # Return metrics and predictions

def compute_prototypes(model, loader, device):  # Define function to compute class prototypes
    model.eval()  # Set model to evaluation mode
    class_feats = {}  # Initialize dictionary for class features

    with torch.no_grad():  # Disable gradient computation
        for patches, labels in loader:  # Iterate through dataloader
            patches, labels = patches.to(device), labels.to(device)  # Move data to device
            feats = model(patches)  # Forward pass to get features (logits)

            for i, lbl in enumerate(labels):  # Iterate through batch
                c = lbl.item()  # Get class label
                if c not in class_feats:  # Check if class is in dictionary
                    class_feats[c] = []  # Initialize list for class features
                class_feats[c].append(feats[i].cpu())  # Append features to class

    prototypes = {c: torch.stack(class_feats[c]).mean(0) for c in class_feats}  # Compute mean feature for each class
    return prototypes  # Return prototypes

def plot_confusion_matrix(labels, preds, class_names, epoch):  # Define function to plot confusion matrix
    cm = confusion_matrix(labels, preds)  # Compute confusion matrix
    plt.figure(figsize=(8, 6))  # Create figure with size 8x6
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)  # Plot heatmap
    plt.title(f'Confusion Matrix (Epoch {epoch+1})')  # Set title with epoch number
    plt.xlabel('Predicted')  # Label x-axis
    plt.ylabel('True')  # Label y-axis
    plt.savefig(f'confusion_matrix_epoch_{epoch+1}.png')  # Save plot to file
    plt.close()  # Close figure

def plot_validation_curves(train_accs, val_accs, train_f1s, val_f1s, train_losses, val_losses):  # Define function to plot validation curves
    epochs = range(1, len(train_accs) + 1)  # Create range of epochs
    plt.figure(figsize=(12, 4))  # Create figure with size 12x4

    plt.subplot(1, 3, 1)  # Create first subplot for accuracy
    plt.plot(epochs, train_accs, 'b-', label='Train Accuracy')  # Plot training accuracy
    plt.plot(epochs, val_accs, 'r-', label='Validation Accuracy')  # Plot validation accuracy
    plt.title('Accuracy Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('Accuracy')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.subplot(1, 3, 2)  # Create second subplot for F1 score
    plt.plot(epochs, train_f1s, 'b-', label='Train F1 Score')  # Plot training F1 score
    plt.plot(epochs, val_f1s, 'r-', label='Validation F1 Score')  # Plot validation F1 score
    plt.title('F1 Score Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('F1 Score')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.subplot(1, 3, 3)  # Create third subplot for loss
    plt.plot(epochs, train_losses, 'b-', label='Train Loss')  # Plot training loss
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')  # Plot validation loss
    plt.title('Loss Curve')  # Set title
    plt.xlabel('Epoch')  # Label x-axis
    plt.ylabel('Loss')  # Label y-axis
    plt.legend()  # Show legend
    plt.grid(True)  # Enable grid

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.savefig('validation_curves.png')  # Save plot to file
    plt.close()  # Close figure

if __name__ == "__main__":  # Main execution block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Select GPU if available, else CPU

    transform = transforms.Compose([  # Define image transformations
        transforms.Resize((128, 128)),  # Resize images to 128x128
        transforms.ToTensor(),  # Convert images to tensors
    ])

    dataset = OASIS2DDataset(DATA_ROOT, class_names, transform=transform)  # Create dataset instance
    train_size = int(0.8 * len(dataset))  # Calculate training set size (80%)
    test_size = len(dataset) - train_size  # Calculate test set size (20%)
    train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])  # Split dataset

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)  # Create training DataLoader
    test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)  # Create test DataLoader

    model = QNNMILClassifier(n_qubits=4, num_classes=len(class_names)).to(device)  # Initialize model and move to device

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Initialize Adam optimizer
    loss_fn = nn.CrossEntropyLoss()  # Initialize cross-entropy loss
    contrast_fn = ContrastiveLoss()  # Initialize contrastive loss

    print("\n▶️  Training with contrastive + classification loss...\n")  # Print training start message
    num_epochs = 10  # Set number of epochs
    train_accs, val_accs = [], []  # Initialize lists for accuracies
    train_f1s, val_f1s = [], []  # Initialize lists for F1 scores
    train_losses, val_losses = [], []  # Initialize lists for losses

    for epoch in range(num_epochs):  # Iterate through epochs
        train_loss, train_acc, train_f1, train_preds, train_labels = train_joint(
            model, train_loader, optimizer, loss_fn, contrast_fn, device
        )  # Train model for one epoch
        val_loss, val_acc, val_f1, val_preds, val_labels = validate(
            model, test_loader, loss_fn, device
        )  # Validate model

        train_accs.append(train_acc)  # Store training accuracy
        val_accs.append(val_acc)  # Store validation accuracy
        train_f1s.append(train_f1)  # Store training F1 score
        val_f1s.append(val_f1)  # Store validation F1 score
        train_losses.append(train_loss)  # Store training loss
        val_losses.append(val_loss)  # Store validation loss

        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")  # Print epoch metrics

        plot_confusion_matrix(val_labels, val_preds, class_names, epoch)  # Plot confusion matrix for validation

    plot_validation_curves(train_accs, val_accs, train_f1s, val_f1s, train_losses, val_losses)  # Plot validation curves

    prototypes = compute_prototypes(model, train_loader, device)  # Compute prototypes on training set
    for cls_idx, proto in prototypes.items():  # Iterate through prototypes
        print(f"Prototype for class {cls_idx} ({class_names[cls_idx]}): {proto.numpy()[:5]} ...")  # Print prototype features


▶️  Training with contrastive + classification loss...



Training: 100%|██████████| 8644/8644 [21:34<00:00,  6.68it/s]
Validating: 100%|██████████| 2161/2161 [01:47<00:00, 20.12it/s]


Epoch 01 | Train Loss: 0.8918 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6907 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:26<00:00,  6.72it/s]
Validating: 100%|██████████| 2161/2161 [01:46<00:00, 20.28it/s]


Epoch 02 | Train Loss: 0.8892 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6903 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:27<00:00,  6.71it/s]
Validating: 100%|██████████| 2161/2161 [01:46<00:00, 20.25it/s]


Epoch 03 | Train Loss: 0.8885 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6908 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:15<00:00,  6.78it/s]
Validating: 100%|██████████| 2161/2161 [01:44<00:00, 20.69it/s]


Epoch 04 | Train Loss: 0.8887 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6903 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:18<00:00,  6.76it/s]
Validating: 100%|██████████| 2161/2161 [01:44<00:00, 20.71it/s]


Epoch 05 | Train Loss: 0.8886 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6909 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:18<00:00,  6.76it/s]
Validating: 100%|██████████| 2161/2161 [01:43<00:00, 20.91it/s]


Epoch 06 | Train Loss: 0.8888 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6905 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:06<00:00,  6.83it/s]
Validating: 100%|██████████| 2161/2161 [01:44<00:00, 20.65it/s]


Epoch 07 | Train Loss: 0.8886 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6905 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:30<00:00,  6.70it/s]
Validating: 100%|██████████| 2161/2161 [01:47<00:00, 20.18it/s]


Epoch 08 | Train Loss: 0.8885 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6956 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:41<00:00,  6.64it/s]
Validating: 100%|██████████| 2161/2161 [01:45<00:00, 20.44it/s]


Epoch 09 | Train Loss: 0.8886 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6906 | Val Acc: 0.7719 | Val F1: 0.6725


Training: 100%|██████████| 8644/8644 [21:39<00:00,  6.65it/s]
Validating: 100%|██████████| 2161/2161 [01:45<00:00, 20.47it/s]


Epoch 10 | Train Loss: 0.8887 | Train Acc: 0.7792 | Train F1: 0.6824 | Val Loss: 0.6903 | Val Acc: 0.7719 | Val F1: 0.6725
Prototype for class 0 (Non Demented): [ 1.5503255   0.05553889 -1.0122899  -3.3829305 ] ...
Prototype for class 1 (Very mild Dementia): [ 1.5503255   0.05553889 -1.0122899  -3.3829305 ] ...
Prototype for class 2 (Mild Dementia): [ 1.5503255   0.05553889 -1.01229    -3.3829305 ] ...
Prototype for class 3 (Moderate Dementia): [ 1.5503258   0.05553888 -1.0122899  -3.38293   ] ...
