<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/MedMNIST_QNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
# Snippet 1: Data loading and preprocessing

# --- Install ---
!pip install medmnist --quiet

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import medmnist
import numpy as np

# Device: use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Record the library versions next to the results for later reference.
print("--- Library Versions ---")
print(f"Torch: {torch.__version__}")
print(f"MedMNIST: {medmnist.__version__}\n")

# MedMNIST subsets to prepare. Each name is lower-cased to index medmnist.INFO.
DATASET_NAMES = ['PathMNIST', 'DermaMNIST', 'BloodMNIST']
# Both dicts are keyed by dataset name, then by split ('train' / 'val' / 'test').
datasets, dataloaders = {}, {}

# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Asking
# for it on a CPU-only runtime gains nothing and makes DataLoader emit a
# warning, so enable it only when CUDA is actually available.
pin_memory = torch.cuda.is_available()

print("--- Loading Datasets (train/val/test) ---")
for name in DATASET_NAMES:
    info = medmnist.INFO[name.lower()]
    DataClass = getattr(medmnist, info['python_class'])
    n_channels = info.get('n_channels', 3)

    # Channel-aware normalization: one 0.5 mean/std entry per image channel,
    # so Normalize always receives statistics matching the channel count.
    mean, std = [0.5] * n_channels, [0.5] * n_channels

    tfm = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize(mean=mean, std=std)])

    # Use TRUE splits; do not train on test
    train_ds = DataClass(split='train', transform=tfm, download=True)
    val_ds   = DataClass(split='val',   transform=tfm, download=True)
    test_ds  = DataClass(split='test',  transform=tfm, download=True)

    datasets[name] = {'train': train_ds, 'val': val_ds, 'test': test_ds}

    # Reasonable batch sizes; only the training split is shuffled.
    dataloaders[name] = {
        'train': DataLoader(train_ds, batch_size=64, shuffle=True,
                            num_workers=2, pin_memory=pin_memory),
        'val':   DataLoader(val_ds,   batch_size=128, shuffle=False,
                            num_workers=2, pin_memory=pin_memory),
        'test':  DataLoader(test_ds,  batch_size=128, shuffle=False,
                            num_workers=2, pin_memory=pin_memory),
    }

    print(f"{name}: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}, "
          f"classes={len(train_ds.info['label'])}, channels={n_channels}")

print("\nAll datasets ready.\n")

Using device: cpu
--- Library Versions ---
Torch: 2.8.0+cu126
MedMNIST: 3.0.2

--- Loading Datasets (train/val/test) ---
PathMNIST: train=89996, val=10004, test=7180, classes=9, channels=3
DermaMNIST: train=7007, val=1003, test=2005, classes=7, channels=3
BloodMNIST: train=11959, val=1712, test=3421, classes=8, channels=3

All datasets ready.



In [12]:
# Snippet 2: Hybrid quantum-classical model architecture

import torch
import torch.nn as nn
import pennylane as qml

# Circuit width. The classifier below also uses it as the size of the encoder
# output and of the classification head's input.
n_qubits = 8
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface='torch', diff_method='backprop')
def quantum_feature_map(inputs):
    """Weight-free circuit mapping ``n_qubits`` angles to ``n_qubits`` expectation values.

    The inputs are angle-embedded as Y rotations, one per wire. A closed ring
    of CNOTs (wire k controls wire k+1, and the last wire controls wire 0)
    then entangles the register. The PauliZ expectation value of every wire
    is returned, in wire order.
    """
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits), rotation='Y')
    # The modulo wraps the final CNOT back to wire 0, closing the ring.
    for control in range(n_qubits):
        qml.CNOT(wires=[control, (control + 1) % n_qubits])
    return [qml.expval(qml.PauliZ(wire)) for wire in range(n_qubits)]

# --- Corrected QTFClassifier ---
# Incorporating the user-provided patch for PennyLane version compatibility

class QTFClassifier(nn.Module):
    """Hybrid classifier: CNN encoder -> quantum feature map -> linear head.

    The convolutional encoder reduces an image batch to ``n_qubits`` values per
    sample (``n_qubits`` is read from module scope). Those values are passed to
    ``q_feature_map``, wrapped as a ``qml.qnn.TorchLayer`` with no trainable
    weights, and the layer's outputs feed one linear layer that emits
    ``n_classes`` logits.

    Args:
        q_feature_map: QNode wrapped by ``qml.qnn.TorchLayer`` with an empty
            ``weight_shapes`` mapping.
        n_classes: Number of logits produced by the classification head.
        in_channels: Channel count expected by the first convolution.

    Note:
        The first fully connected layer expects ``32 * 7 * 7`` flattened
        features. The size-preserving convolutions and the two 2x2 poolings
        give that for e.g. 28x28 inputs.
    """

    def __init__(self, q_feature_map, n_classes=9, in_channels=3):
        super().__init__()
        # Two conv/ReLU/pool stages; each pooling halves the spatial size.
        conv_stages = [
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # Flatten and project down to one value per qubit.
        projection = [
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, n_qubits),
        ]
        self.feature_extractor = nn.Sequential(*conv_stages, *projection)
        # TorchLayer is built without a 'dtype' argument (unsupported by the
        # installed PennyLane); the dtype is aligned manually in forward().
        self.q_feature_map = qml.qnn.TorchLayer(q_feature_map, weight_shapes={})
        self.classifier = nn.Linear(n_qubits, n_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        angles = self.feature_extractor(x)
        # Cast the quantum layer's output back to the encoder's dtype so the
        # linear head receives the dtype its weights use.
        expvals = self.q_feature_map(angles).to(angles.dtype)
        return self.classifier(expvals)

# Visible confirmation that the definitions above executed without raising.
print("Patched Quantum Transfer Learning architecture defined successfully.")

Patched Quantum Transfer Learning architecture defined successfully.


In [13]:
# Snippet 3: Training and validation loop

from tqdm import tqdm

# --- Setup ---
# Seed torch's global RNG before the model is built so that weight
# initialisation (and later draws from the global generator) repeat across
# Restart & Run All instead of changing on every run.
SEED = 42
torch.manual_seed(SEED)

# Choose the dataset to train on
dataset_name = 'PathMNIST'
info = medmnist.INFO[dataset_name.lower()]
in_channels = info.get('n_channels', 3)
n_classes = len(info['label'])

# --- Model, Optimizer, and Loss Function ---
# Uses the QTFClassifier and quantum_feature_map defined in the previous cell.
model = QTFClassifier(q_feature_map=quantum_feature_map,
                      n_classes=n_classes,
                      in_channels=in_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# CrossEntropyLoss consumes raw logits together with integer class indices.
loss_fn = nn.CrossEntropyLoss()

# --- DataLoaders ---
train_loader = dataloaders[dataset_name]['train']
val_loader   = dataloaders[dataset_name]['val']

def labels_to_indices(y):
    """
    Turn a batch of targets into 1-D int64 class indices for CrossEntropyLoss.

    A 2-D tensor with more than one column is treated as per-class scores
    (e.g. one-hot rows) and reduced with argmax over the columns. Any other
    input is flattened and cast to long.
    """
    has_class_columns = y.ndim == 2 and y.size(1) > 1
    if not has_class_columns:
        flat = y.view(-1)
        return flat.long()
    return y.argmax(dim=1)

# --- Training Loop ---
# Trains for a fixed number of epochs, evaluates on the validation split after
# each one, and keeps a copy of the weights with the highest validation
# accuracy. Those weights are restored into `model` once training finishes.
num_epochs = 10
best_val = 0.0
# Reset explicitly so a snapshot left over from an earlier run of this cell
# can never be mistaken for this run's best weights.
best_state = None
print(f"Training on {dataset_name} for {num_epochs} epochs...")

for epoch in range(1, num_epochs + 1):
    # ---- Train ----
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [train]"):
        x = x.to(device, non_blocking=True)
        y = labels_to_indices(y).to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Gradient clipping
        optimizer.step()

        # loss is a batch mean; weight it by batch size so the epoch figure is
        # a true per-sample average even when the last batch is smaller.
        train_loss += loss.item() * x.size(0)
        train_correct += (logits.argmax(1) == y).sum().item()
        train_total += x.size(0)

    tr_loss = train_loss / train_total
    tr_acc  = train_correct / train_total

    # ---- Validate ----
    model.eval()
    val_correct, val_total, val_loss = 0, 0, 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [val]"):
            x = x.to(device, non_blocking=True)
            y = labels_to_indices(y).to(device, non_blocking=True)
            logits = model(x)
            loss = loss_fn(logits, y)
            val_loss += loss.item() * x.size(0)
            val_correct += (logits.argmax(1) == y).sum().item()
            val_total += x.size(0)

    v_loss = val_loss / val_total
    v_acc  = val_correct / val_total
    print(f"Epoch {epoch}: train_loss={tr_loss:.4f}  train_acc={tr_acc:.4f}  "
          f"val_loss={v_loss:.4f}  val_acc={v_acc:.4f}")

    # Save the best model based on validation accuracy.
    if v_acc > best_val:
        best_val = v_acc
        # copy=True is essential. state_dict() tensors share storage with the
        # live parameters, and a plain .cpu() returns the very same tensor when
        # it already lives on the CPU. Without a forced copy the "snapshot"
        # would keep changing with every later optimizer step.
        best_state = {k: v.detach().to('cpu', copy=True)
                      for k, v in model.state_dict().items()}

print(f"\nTraining complete. Best validation accuracy: {best_val:.4f}")

# Load the best performing model state (absent only if no epoch ever beat the
# initial best_val of 0.0).
if best_state is not None:
    model.load_state_dict(best_state)
    print("Loaded best model weights for explainability analysis.")

Training on PathMNIST for 10 epochs...


Epoch 1/10 [train]: 100%|██████████| 1407/1407 [02:20<00:00,  9.98it/s]
Epoch 1/10 [val]: 100%|██████████| 79/79 [00:07<00:00, 11.04it/s]


Epoch 1: train_loss=1.5969  train_acc=0.4014  val_loss=1.3190  val_acc=0.4790


Epoch 2/10 [train]: 100%|██████████| 1407/1407 [03:11<00:00,  7.34it/s]
Epoch 2/10 [val]: 100%|██████████| 79/79 [00:14<00:00,  5.39it/s]


Epoch 2: train_loss=1.2045  train_acc=0.5556  val_loss=1.0715  val_acc=0.6021


Epoch 3/10 [train]: 100%|██████████| 1407/1407 [03:40<00:00,  6.39it/s]
Epoch 3/10 [val]: 100%|██████████| 79/79 [00:19<00:00,  4.01it/s]


Epoch 3: train_loss=0.9879  train_acc=0.6245  val_loss=0.9315  val_acc=0.6409


Epoch 4/10 [train]: 100%|██████████| 1407/1407 [04:01<00:00,  5.84it/s]
Epoch 4/10 [val]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s]


Epoch 4: train_loss=0.8799  train_acc=0.6656  val_loss=0.8170  val_acc=0.6942


Epoch 5/10 [train]: 100%|██████████| 1407/1407 [03:44<00:00,  6.27it/s]
Epoch 5/10 [val]: 100%|██████████| 79/79 [00:15<00:00,  5.04it/s]


Epoch 5: train_loss=0.8021  train_acc=0.7027  val_loss=0.7725  val_acc=0.7169


Epoch 6/10 [train]: 100%|██████████| 1407/1407 [03:49<00:00,  6.13it/s]
Epoch 6/10 [val]: 100%|██████████| 79/79 [00:17<00:00,  4.60it/s]


Epoch 6: train_loss=0.7375  train_acc=0.7308  val_loss=0.8024  val_acc=0.7186


Epoch 7/10 [train]: 100%|██████████| 1407/1407 [03:56<00:00,  5.94it/s]
Epoch 7/10 [val]: 100%|██████████| 79/79 [00:19<00:00,  4.14it/s]


Epoch 7: train_loss=0.6888  train_acc=0.7485  val_loss=0.6705  val_acc=0.7684


Epoch 8/10 [train]: 100%|██████████| 1407/1407 [03:56<00:00,  5.95it/s]
Epoch 8/10 [val]: 100%|██████████| 79/79 [00:17<00:00,  4.44it/s]


Epoch 8: train_loss=0.6503  train_acc=0.7640  val_loss=0.6090  val_acc=0.7791


Epoch 9/10 [train]: 100%|██████████| 1407/1407 [03:57<00:00,  5.92it/s]
Epoch 9/10 [val]: 100%|██████████| 79/79 [00:18<00:00,  4.37it/s]


Epoch 9: train_loss=0.6236  train_acc=0.7767  val_loss=0.6554  val_acc=0.7693


Epoch 10/10 [train]: 100%|██████████| 1407/1407 [04:04<00:00,  5.76it/s]
Epoch 10/10 [val]: 100%|██████████| 79/79 [00:18<00:00,  4.36it/s]

Epoch 10: train_loss=0.5907  train_acc=0.7914  val_loss=0.6434  val_acc=0.7791

Training complete. Best validation accuracy: 0.7791
Loaded best model weights for explainability analysis.



