In [2]:
# =============== Imports & Hyperparameters ===============
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
import matplotlib.pyplot as plt

import medmnist
from medmnist import INFO
from sklearn.metrics import accuracy_score

# ----- Training configuration -----
DATASET_NAME = "pathmnist"      # we focus on PathMNIST
BATCH_SIZE = 256                # larger batch for better GPU utilization
EPOCHS_CNN = 8                  # epochs for SimpleCNN
EPOCHS_RESNET = 12              # epochs for ResNet18 (larger net)
LR_CNN = 1e-3                   # learning rate for SimpleCNN
LR_RESNET = 1e-4                # smaller LR for fine-tuning pretrained ResNet18
USE_PRETRAINED_RESNET = True    # use ImageNet weights

# ImageNet normalization (standard for ResNet-like models)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


# ====================== Utils ======================

def get_device():
    """Return GPU device if available, otherwise CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed=42):
    """Set random seed for reproducibility."""
    import random
    random.seed(seed)
    np.random.seed(seed) 
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def denormalize_image(img_tensor):
    """
    Undo ImageNet normalization for visualization.
    img_tensor: tensor [3,H,W], normalized by IMAGENET_MEAN/STD.
    Returns a tensor in [0,1].
    """
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1).to(img_tensor.device)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1).to(img_tensor.device)
    img = img_tensor * std + mean
    img = torch.clamp(img, 0.0, 1.0)
    return img


# ====================== Data ======================

def get_medmnist_dataloaders(dataset_name=DATASET_NAME, batch_size=BATCH_SIZE):
    """
    Load MedMNIST dataset (train/val/test) with preprocessing.

    Preprocessing:
      - resize to 64x64 (slightly larger than original 28x28)
      - convert to tensor in [0,1]
      - convert 1-channel images to 3-channel
      - apply ImageNet normalization

    Train loader uses data augmentation:
      - random horizontal flip
      - small random rotation
    """
    info = INFO[dataset_name]
    DataClass = getattr(medmnist, info["python_class"])
    num_classes = len(info["label"])

    # Data augmentation for training
    transform_train = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Evaluation transform (no augmentation)
    transform_eval = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    train_dataset = DataClass(split="train", transform=transform_train, download=True)
    val_dataset = DataClass(split="val", transform=transform_eval, download=True)
    test_dataset = DataClass(split="test", transform=transform_eval, download=True)

    # num_workers=0 to avoid multiprocessing issues in Jupyter on Windows
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader, test_loader, num_classes


# ====================== Models ======================

class SimpleCNN(nn.Module):
    """
    A convolutional neural network used as a baseline.
    """

    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        # Block 1
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # Block 2
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Block 3
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        # Block 4
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Input 64x64 -> 32 -> 16 -> 8 -> 4 after 4 poolings
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # [B,64,32,32]
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # [B,128,16,16]
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # [B,256,8,8]
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # [B,256,4,4]
        x = x.view(x.size(0), -1)                       # [B,256*4*4]
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


def get_resnet18_model(num_classes, use_pretrained=True):
    """
    Create a ResNet18 model with modified final layer.
    Used as a stronger model compared to SimpleCNN.
    """
    if use_pretrained:
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
    else:
        model = resnet18(weights=None)

    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


# ====================== Training & Evaluation ======================

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train model for one epoch and return average loss."""
    model.train()
    running_loss = 0.0

    for images, targets in dataloader:
        images = images.to(device)

        # Targets from MedMNIST can be tensors or numpy arrays
        if not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets)
        targets = targets.squeeze()
        if targets.ndim > 1:
            # Multi-hot -> single label via argmax
            targets = targets.argmax(dim=1)
        targets = targets.long().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    return running_loss / len(dataloader.dataset)


def evaluate(model, dataloader, device):
    """Evaluate model accuracy."""
    model.eval()
    preds_all = []
    labels_all = []

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)

            if not isinstance(targets, torch.Tensor):
                targets = torch.tensor(targets)
            targets = targets.squeeze()
            if targets.ndim > 1:
                targets = targets.argmax(dim=1)
            targets = targets.long()

            outputs = model(images)
            preds = outputs.argmax(dim=1).cpu()

            preds_all.extend(preds.tolist())
            labels_all.extend(targets.tolist())

    acc = accuracy_score(labels_all, preds_all)
    return acc


def train_model(model, train_loader, val_loader, device, num_epochs=5, lr=1e-3):
    """
    Train model for several epochs and keep the best one on validation set.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    best_state = None

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_acc = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_state = model.state_dict()

    if best_state is not None:
        model.load_state_dict(best_state)

    print("Best validation accuracy:", best_acc)
    return model


# ====================== Grad-CAM & Grad-CAM++ ======================

class GradCAM:
    """Implementation of Grad-CAM."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        self.forward_hook = target_layer.register_forward_hook(self.save_activation)
        self.backward_hook = target_layer.register_backward_hook(self.save_gradient)

    def save_activation(self, module, inp, out):
        self.activations = out.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, image, target_class, device):
        """
        Generate Grad-CAM heatmap for a single image.
        image: tensor [3,H,W], target_class: int
        """
        self.model.eval()
        self.model.zero_grad()

        image = image.unsqueeze(0).to(device)
        output = self.model(image)
        score = output[0, target_class]
        score.backward()

        gradients = self.gradients[0]     # [C,H,W]
        activations = self.activations[0] # [C,H,W]
        weights = gradients.mean(dim=(1, 2))  # [C]

        cam = torch.zeros(activations.shape[1:], device=device)
        for i, w in enumerate(weights):
            cam += w * activations[i]

        cam = F.relu(cam)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-6)
        return cam.cpu().numpy()

    def close(self):
        self.forward_hook.remove()
        self.backward_hook.remove()


class GradCAMPlusPlus(GradCAM):
    """Implementation of Grad-CAM++ using higher-order gradients."""

    def generate(self, image, target_class, device):
        self.model.eval()
        self.model.zero_grad()

        image = image.unsqueeze(0).to(device)
        output = self.model(image)
        score = output[0, target_class]
        score.backward(retain_graph=True)

        gradients = self.gradients[0]     # [C,H,W]
        activations = self.activations[0] # [C,H,W]

        grad2 = gradients ** 2
        grad3 = gradients ** 3
        sum_act = (activations * grad3).sum(dim=(1, 2), keepdim=True)

        alpha = grad2 / (2.0 * grad2 + sum_act + 1e-6)
        weights = (alpha * F.relu(gradients)).sum(dim=(1, 2))  # [C]

        cam = torch.zeros(activations.shape[1:], device=device)
        for i, w in enumerate(weights):
            cam += w * activations[i]

        cam = F.relu(cam)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-6)
        return cam.cpu().numpy()


# ====================== Visualization ======================

def overlay_heatmap(image, heatmap, alpha=0.4):
    """
    Overlay heatmap over the de-normalized image.
    image: tensor [3,H,W] (normalized)
    heatmap: numpy array [h,w] in [0,1] (may be smaller than image)
    """
    # de-normalize image for display
    img = denormalize_image(image).permute(1, 2, 0).cpu().numpy()  # [H,W,3]
    H, W = img.shape[:2]

    # upsample heatmap to match image size using bilinear interpolation
    heatmap_t = torch.tensor(heatmap, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # [1,1,h,w]
    heatmap_up = F.interpolate(heatmap_t, size=(H, W), mode="bilinear", align_corners=False)
    heatmap_up = heatmap_up.squeeze().cpu().numpy()  # [H,W]

    # convert heatmap to RGB
    heatmap_rgb = plt.cm.jet(heatmap_up)[..., :3]  # [H,W,3]

    # blend image and heatmap
    blended = (1 - alpha) * img + alpha * heatmap_rgb
    blended = np.clip(blended, 0.0, 1.0)
    return blended


def save_examples(model, gradcam, gradcampp, test_loader, device, outdir="xai_outputs"):
    """
    Save a few example images with Grad-CAM and Grad-CAM++ overlays.
    """
    os.makedirs(outdir, exist_ok=True)

    images, labels = next(iter(test_loader))
    images = images[:5]
    labels = labels.squeeze()
    if labels.ndim > 1:
        labels = labels.argmax(dim=1)
    labels = labels.long()

    for i in range(len(images)):
        img = images[i]
        cls = int(labels[i].item())

        cam = gradcam.generate(img, cls, device)
        campp = gradcampp.generate(img, cls, device)

        fig, axes = plt.subplots(1, 3, figsize=(9, 3))

        # Original (de-normalized)
        axes[0].imshow(denormalize_image(img).permute(1, 2, 0).cpu().numpy())
        axes[0].set_title(f"Original (class {cls})")
        axes[0].axis("off")

        axes[1].imshow(overlay_heatmap(img, cam))
        axes[1].set_title("Grad-CAM")
        axes[1].axis("off")

        axes[2].imshow(overlay_heatmap(img, campp))
        axes[2].set_title("Grad-CAM++")
        axes[2].axis("off")

        plt.tight_layout()
        fig.savefig(os.path.join(outdir, f"example_{i}.png"))
        plt.close(fig)

    print(f"Saved XAI examples to: {outdir}")


In [3]:
# ================== Training & XAI Pipeline ==================

set_seed(42)
device = get_device()
print("Using device:", device)

# ----- Load data -----
train_loader, val_loader, test_loader, num_classes = \
    get_medmnist_dataloaders(DATASET_NAME, BATCH_SIZE)

print(f"Dataset: {DATASET_NAME}, num_classes: {num_classes}")

# ----- Train SimpleCNN (baseline) -----
print("\nTraining SimpleCNN (baseline)...")
simple_cnn = SimpleCNN(num_classes)
simple_cnn = train_model(simple_cnn, train_loader, val_loader, device,
                         num_epochs=EPOCHS_CNN, lr=LR_CNN)
test_acc_cnn = evaluate(simple_cnn, test_loader, device)
print("SimpleCNN Test Accuracy:", test_acc_cnn)

os.makedirs("checkpoints", exist_ok=True)
torch.save(simple_cnn.state_dict(), os.path.join("checkpoints", "simple_cnn.pth"))

# ----- Train ResNet18 (improved, larger model) -----
print("\nTraining ResNet18 (improved model)...")
resnet18_model = get_resnet18_model(num_classes, use_pretrained=USE_PRETRAINED_RESNET)
resnet18_model = train_model(resnet18_model, train_loader, val_loader, device,
                             num_epochs=EPOCHS_RESNET, lr=LR_RESNET)
test_acc_resnet18 = evaluate(resnet18_model, test_loader, device)
print("ResNet18 Test Accuracy:", test_acc_resnet18)

torch.save(resnet18_model.state_dict(), os.path.join("checkpoints", "resnet18.pth"))

# ----- Use ResNet18 for Grad-CAM / Grad-CAM++ -----
final_model = resnet18_model
target_layer = final_model.layer4[-1].conv2  # last conv layer of ResNet18

gradcam = GradCAM(final_model, target_layer)
gradcampp = GradCAMPlusPlus(final_model, target_layer)

save_examples(final_model, gradcam, gradcampp, test_loader, device,
              outdir="xai_outputs")

gradcam.close()
gradcampp.close()

print("\nDone. Models saved in 'checkpoints', XAI images in 'xai_outputs'.")


Using device: cuda


100%|██████████| 206M/206M [00:09<00:00, 21.6MB/s] 


Dataset: pathmnist, num_classes: 9

Training SimpleCNN (baseline)...
Epoch 1/8 | Train Loss: 0.7970 | Val Acc: 0.7649
Epoch 2/8 | Train Loss: 0.4169 | Val Acc: 0.8780
Epoch 3/8 | Train Loss: 0.2951 | Val Acc: 0.9252
Epoch 4/8 | Train Loss: 0.2376 | Val Acc: 0.8973
Epoch 5/8 | Train Loss: 0.2071 | Val Acc: 0.9318
Epoch 6/8 | Train Loss: 0.1811 | Val Acc: 0.9441
Epoch 7/8 | Train Loss: 0.1629 | Val Acc: 0.9241
Epoch 8/8 | Train Loss: 0.1488 | Val Acc: 0.9298
Best validation accuracy: 0.9441223510595762


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\dashabi/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


SimpleCNN Test Accuracy: 0.8328690807799443

Training ResNet18 (improved model)...


100%|██████████| 44.7M/44.7M [00:00<00:00, 58.2MB/s]


Epoch 1/12 | Train Loss: 0.3361 | Val Acc: 0.9507
Epoch 2/12 | Train Loss: 0.1459 | Val Acc: 0.9646
Epoch 3/12 | Train Loss: 0.1026 | Val Acc: 0.9706
Epoch 4/12 | Train Loss: 0.0795 | Val Acc: 0.9757
Epoch 5/12 | Train Loss: 0.0601 | Val Acc: 0.9722
Epoch 6/12 | Train Loss: 0.0499 | Val Acc: 0.9778
Epoch 7/12 | Train Loss: 0.0403 | Val Acc: 0.9748
Epoch 8/12 | Train Loss: 0.0357 | Val Acc: 0.9690
Epoch 9/12 | Train Loss: 0.0301 | Val Acc: 0.9837
Epoch 10/12 | Train Loss: 0.0257 | Val Acc: 0.9806
Epoch 11/12 | Train Loss: 0.0241 | Val Acc: 0.9822
Epoch 12/12 | Train Loss: 0.0232 | Val Acc: 0.9810
Best validation accuracy: 0.9837065173930428
ResNet18 Test Accuracy: 0.9018105849582173


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved XAI examples to: xai_outputs

Done. Models saved in 'checkpoints', XAI images in 'xai_outputs'.
