In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import numpy as np
import random
import pandas as pd

from sklearn.metrics import (
    roc_auc_score, f1_score, precision_score, recall_score,
    hamming_loss, accuracy_score
)

import matplotlib.pyplot as plt

import timm
from transformers import get_cosine_schedule_with_warmup
import csv

os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

In [None]:
class ImageDataset(Dataset):
    """Multi-label image dataset described by a CSV file plus an image folder.

    The CSV must contain a ``Path`` column holding each image's file name
    (joined onto ``image_dir`` when loading). The last 14 columns of the CSV
    are read as the label vector and stored as float32.

    Each item is a dict ``{'images': ndarray, 'labels': ndarray}``.
    """

    def __init__(self, csv_dir, image_dir, transforms=None):
        """
        csv_dir:    path of the label CSV (despite the name, a file path is read).
        image_dir:  folder that the ``Path`` entries are relative to.
        transforms: optional callable applied to the loaded PIL image.
        """
        frame = pd.read_csv(csv_dir)

        self.image_dir = image_dir
        self.transforms = transforms
        # File names come from 'Path'; labels are the trailing 14 columns.
        self.image_names = frame['Path'].tolist()
        self.labels = frame.iloc[:, -14:].values.astype("float32")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_names[idx])

        # Report which file failed before propagating the error.
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image: {img_path} - {e}")
            raise

        # Here `transforms` is the module-level torchvision import; the
        # constructor argument of the same name lives on self.transforms.
        pipeline = self.transforms if self.transforms else transforms.ToTensor()
        image = np.array(pipeline(image))

        label = np.array(self.labels[idx])

        return {'images': image, 'labels': label}

def compute_sample_weights(labels):
    """
    Per-sample weights for multi-label oversampling.

    labels: array of shape (N, num_classes) with 1 where a class is present.

    Each class gets the weight 1 / (number of positives + 1e-6); a sample's
    weight is the SUM of the weights of the classes it is labelled with, so a
    sample with no positive label gets weight 0.

    Returns an array of shape (N,).
    """
    positives_per_class = np.sum(labels, axis=0)
    # The small constant keeps the division finite for classes with no positives.
    inverse_frequency = 1.0 / (positives_per_class + 1e-6)
    return np.dot(labels, inverse_frequency)


In [None]:
# Square input resolution fed to the network.
dim = 224
size = (dim, dim)

# ImageNet normalisation statistics, shared by both pipelines.
_imagenet_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Evaluation pipeline: deterministic resize, 3-channel grayscale, tensor, normalise.
trans_val = transforms.Compose([
    transforms.Resize(size),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(**_imagenet_stats),
])

# Training pipeline: the same start and end as above, with random
# augmentations applied to the PIL image in between.
trans_train = transforms.Compose([
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(p=0.5),                     # flip half of the images
    transforms.RandomRotation(15),                              # rotate by up to 15 degrees either way
    transforms.ColorJitter(brightness=0.2, contrast=0.2),       # brightness / contrast jitter
    transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),       # random zoom / crop back to `size`
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),  # mild perspective warp
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(**_imagenet_stats),
])


ds_size = "Balanced"

# Root folder of the selected dataset variant. Paths are assembled with
# os.path.join so the notebook also runs on systems whose path separator is
# not a backslash (the result on Windows is unchanged).
data_root = f"NIH_X_14_{ds_size}"

# Each split has a label CSV and an image folder directly under the root.
# Only the training split gets the augmenting transform pipeline.
train_data = ImageDataset(
    os.path.join(data_root, "train_labels_encoded.csv"),
    os.path.join(data_root, "train"),
    transforms=trans_train,
)
test_data = ImageDataset(
    os.path.join(data_root, "test_labels_encoded.csv"),
    os.path.join(data_root, "test"),
    transforms=trans_val,
)
val_data = ImageDataset(
    os.path.join(data_root, "val_labels_encoded.csv"),
    os.path.join(data_root, "val"),
    transforms=trans_val,
)


# Display names of the 14 findings.
# NOTE(review): presumably in the same order as the 14 label columns of the
# CSV files — verify against the CSV header.
class_names = (
    "Atelectasis Cardiomegaly Consolidation Edema "
    "Effusion Emphysema Fibrosis Hernia Infiltration "
    "Mass Nodule Pleural_Thickening Pneumonia Pneumothorax"
).split()

In [None]:
# --- Oversampling weights derived from label frequency ---
label_array = np.array(train_data.labels)   # (N_samples, N_classes)
class_counts = label_array.sum(axis=0)      # positives per class

# Inverse-frequency class weights, rescaled so they sum to the number of classes.
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights / np.sum(class_weights) * len(class_weights)

print("Class counts:", class_counts)
print("Class weights (normalized):", np.round(class_weights, 4))

# A sample's weight is the sum of the weights of its positive classes,
# rescaled to average 1. Samples with no positive label start at 0, so the
# floor of 1e-6 keeps every weight strictly positive.
sample_weights = label_array @ class_weights   # (N_samples,)
sample_weights = sample_weights / np.sum(sample_weights) * len(sample_weights)
sample_weights = np.maximum(sample_weights, 1e-6)

print(f"Sample weights: mean={sample_weights.mean():.4e}, std={sample_weights.std():.4e}, max={sample_weights.max():.4e}")

imbalance_ratio = class_counts.max() / class_counts.min()
if imbalance_ratio > 10:
    print(f"Warning: Significant class imbalance detected (max/min = {imbalance_ratio:.2f}). Consider data augmentation or other sampling strategies.")

# Draw as many indices per epoch as there are training samples, with replacement.
sample_weights_tensor = torch.as_tensor(sample_weights, dtype=torch.float64)
sampler = WeightedRandomSampler(
    sample_weights_tensor,
    len(sample_weights_tensor),
    replacement=True,
)

# Sanity check: one pass of the sampler should reach at least half of the dataset.
unique_indices = set(sampler)
if len(unique_indices) < 0.5 * len(train_data):
    print(f"Sampler is using only {len(unique_indices)} unique samples out of {len(train_data)}. Check weights.")

# --- DataLoaders ---
batch_size = 32

# Training: indices come from the weighted sampler; the incomplete last batch
# is dropped so every optimisation step sees a full batch.
train_dl = DataLoader(train_data, batch_size=batch_size, sampler=sampler, drop_last=True, pin_memory=True)

# Evaluation loaders must visit every sample. With drop_last=True up to
# batch_size - 1 images were silently excluded from every reported metric.
# The test loader stays shuffled because the attention-overlay cell shows the
# first image of the first batch, so shuffling makes that a random sample.
# Shuffling validation data serves no purpose, so it is switched off.
test_dl = DataLoader(test_data, shuffle=True, batch_size=batch_size, drop_last=False, pin_memory=True)
val_dl = DataLoader(val_data, shuffle=False, batch_size=batch_size, drop_last=False, pin_memory=True)

In [None]:
# --- ViT Model ---
class ViT(nn.Module):
    """Wrapper around a timm model that starts with only its head trainable.

    After construction every parameter is frozen except those of the
    backbone's ``head`` and ``norm`` attributes (when the backbone has them).
    Call ``fine_tune()`` to make the whole network trainable.

    Attributes:
        model:          the timm model created from ``model_name``.
        model_name:     name passed to ``timm.create_model``.
        epochs_trained: counter starting at 0; not updated by this class.
    """

    def __init__(self, model_name, num_classes=14, pretrained=True, in_chans=3):
        super().__init__()
        self.model_name = model_name
        self.epochs_trained = 0
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=in_chans,
        )

        # Ask timm to (re)build the classifier for the requested class count.
        self.model.reset_classifier(num_classes)

        # Freeze everything, then re-enable gradients for the classifier head
        # and the `norm` module, when the backbone exposes them.
        self.model.requires_grad_(False)
        for trainable_attr in ('head', 'norm'):
            if hasattr(self.model, trainable_attr):
                getattr(self.model, trainable_attr).requires_grad_(True)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def fine_tune(self):
        """Make every parameter of the wrapped model trainable."""
        self.model.requires_grad_(True)


In [None]:
# --- Focal Loss ---
class FocalLoss(nn.Module):
    """Binary (per-label) focal loss on raw logits.

    For each element with probability ``p = sigmoid(logit)``:

        positives: -alpha       * (1 - p)**gamma * log(p)
        negatives: -(1 - alpha) * p**gamma       * log(1 - p)

    The modulating factor is ``(1 - p_t)**gamma`` where ``p_t`` is the
    probability assigned to the true class, so well-classified elements are
    down-weighted for both positives and negatives.

    Args:
        alpha:     weight of the positive term (negatives get ``1 - alpha``).
        gamma:     focusing exponent; 0 disables the modulation.
        reduction: 'mean', 'sum', or anything else to get the element-wise loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs:  logits of any shape.
        targets: tensor of the same shape holding 0/1 labels.
        """
        targets = targets.type(inputs.dtype)
        probs = torch.sigmoid(inputs)

        # logsigmoid(x) = log(p) and logsigmoid(-x) = log(1 - p); both stay
        # finite for large |x|, unlike log(p + eps).
        log_p = nn.functional.logsigmoid(inputs)
        log_one_minus_p = nn.functional.logsigmoid(-inputs)

        # Each term carries its own modulating factor. Using (1 - p)**gamma
        # for the negative term as well would shrink the loss of confident
        # false positives instead of emphasising them.
        pos_term = self.alpha * targets * (1 - probs) ** self.gamma * log_p
        neg_term = (1 - self.alpha) * (1 - targets) * probs ** self.gamma * log_one_minus_p
        focal_loss = -(pos_term + neg_term)

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
        
# --- Asymmetric Loss ---
class AsymmetricLoss(nn.Module):
    """Asymmetric loss for multi-label classification, on raw logits.

    Positives and negatives use different focusing exponents, and negatives
    additionally get a probability margin: with ``p = sigmoid(logit)`` the
    negative-side probability is ``min(1 - p + clip, 1)``, so negatives with
    ``p <= clip`` contribute exactly zero loss.

    Args:
        gamma_neg: focusing exponent for negative labels.
        gamma_pos: focusing exponent for positive labels.
        clip:      probability margin for negatives; ``None`` or 0 disables it.
        eps:       lower clamp applied before taking logarithms.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, preds, targets):
        """
        preds:   logits of any shape.
        targets: tensor of the same shape holding 0/1 labels.
        Returns the mean loss over all elements.
        """
        targets = targets.type(preds.dtype)

        p_pos = torch.sigmoid(preds)
        p_neg = 1.0 - p_pos

        # The margin shifts only the negative-side probability. Subtracting it
        # from the probability used by BOTH terms drives that probability to 0
        # for easy negatives, and then 0 * log(0) evaluates to NaN.
        if self.clip is not None and self.clip > 0:
            p_neg = (p_neg + self.clip).clamp(max=1.0)

        # Clamp before the log so neither term can reach -inf.
        pos_loss = targets * torch.log(p_pos.clamp(min=self.eps)) * ((1 - p_pos) ** self.gamma_pos)
        neg_loss = (1 - targets) * torch.log(p_neg.clamp(min=self.eps)) * ((1 - p_neg) ** self.gamma_neg)

        return - (pos_loss + neg_loss).mean()

In [None]:
# Report the runtime and pick the compute device: the GPU when CUDA is
# available, otherwise the CPU.
cuda_ok = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
print(f"Number of GPUs available: {torch.cuda.device_count()}")
if cuda_ok:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda") if cuda_ok else torch.device("cpu")

In [None]:
# Candidate timm model names, grouped by family. The list is informational:
# the model that is actually trained is named explicitly where it is built.
vit_models = (
    # Plain ViT, from lightweight to high-capacity
    [
        "vit_tiny_patch16_224",
        "vit_small_patch16_224",
        "vit_base_patch16_224",
        "vit_large_patch16_224",
    ]
    # Data-efficient Image Transformer (DeiT)
    + [
        "deit_tiny_patch16_224",
        "deit_small_patch16_224",
        "deit_base_patch16_224",
    ]
    # DeiT v3
    + [
        "deit3_small_patch16_224",
        "deit3_medium_patch16_224",
        "deit3_base_patch16_224",
        "deit3_large_patch16_224",
    ]
    # ViT with relative positional encoding
    + [
        "vit_relpos_base_patch16_224",
        "vit_relpos_small_patch16_224",
    ]
    # Swin Transformer (shifted windows)
    + [
        "swin_tiny_patch4_window7_224",
        "swin_small_patch4_window7_224",
        "swin_base_patch4_window7_224",
    ]
    # BEiT (BERT-style masked image pretraining)
    + [
        "beit_base_patch16_224",
        "beit_large_patch16_224",
    ]
)


# --- Model, optimiser and loss ---
eta = 1e-4  # initial learning rate handed to AdamW

# The wrapper starts with only head + norm trainable; fine_tune() then
# unfreezes the whole network before it is moved to the compute device.
model = ViT(model_name='swin_base_patch4_window7_224', num_classes=14)
model.fine_tune()
model = model.to(device)

opti = torch.optim.AdamW(
    model.parameters(),
    lr=eta,
    weight_decay=0.01,
)

# Independent sigmoid + binary cross-entropy per label, on raw logits.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# --- Learning-rate schedule: warm-up followed by cosine decay ---
# Both counts are in optimiser steps (batches): two epochs of warm-up out of
# `epochs` epochs in total.
epochs = 150
warmup_steps = 2 * len(train_dl)
total_steps = epochs * len(train_dl)
scheduler = get_cosine_schedule_with_warmup(
    opti,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# --- Per-epoch history, appended to by the training loop ---
epoch_train_loss = []
epoch_val_loss = []
val_roc_aucs_by_epoch = []
val_roc_auc_means = []
hammings = []
precisions = []
recalls = []
f1_macros = []
f1_micros = []

# One decision threshold per class, all starting at 0.35.
class_thresholds = np.full(shape=14, fill_value=0.35)

# --- Early-stopping state ---
best_auc = 0.0          # best mean validation ROC-AUC seen so far
early_stop_counter = 0  # consecutive epochs without improvement
patience = 5            # stop once the counter reaches this value

In [None]:
# Per-epoch metrics go to a CSV file named after the model.
metrics_log = []
metrics_csv_path = f"Metrics/{model.model_name}_training_metrics.csv"

# Both output folders must exist before anything is written to them.
os.makedirs(os.path.dirname(metrics_csv_path), exist_ok=True)
os.makedirs("Checkpoints", exist_ok=True)

# One ROC-AUC column per class, after the aggregate metrics.
roc_auc_columns = ["roc_auc_" + name for name in class_names]

# A new file gets a header row; an existing file is left as it is so that
# later rows are appended to it.
if not os.path.exists(metrics_csv_path):
    with open(metrics_csv_path, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                "epoch", "train_loss", "val_loss", "mean_auc", "mean_entropy",
                "f1_micro", "f1_macro", "precision", "recall", "hamming_loss",
            ]
            + roc_auc_columns
        )

# --- Training Loop ---
# Per epoch: (1) one optimisation pass over train_dl, (2) one no-grad pass over
# val_dl, (3) metrics printed and appended to the CSV, (4) checkpoint on a new
# best mean validation ROC-AUC, otherwise count towards early stopping.
for epoch in range(epochs):
    train_losses = []
    model.train()

    for d in tqdm(train_dl, desc=f"Training {epoch+1}/{epochs}"):
        opti.zero_grad()
        imgs = d['images'].to(device, dtype=torch.float)
        labels = d['labels'].to(device, dtype=torch.float)

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        # Cap the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opti.step()
        # The scheduler is advanced once per batch, not once per epoch.
        scheduler.step()

        train_losses.append(loss.item())

    epoch_train_loss.append(np.mean(train_losses))
    model.epochs_trained += 1

    # --- Validation ---
    model.eval()
    val_losses, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for d in tqdm(val_dl, desc="Validating"):
            imgs = d['images'].to(device, dtype=torch.float)
            labels = d['labels'].to(device, dtype=torch.float)

            logits = model(imgs)
            loss = criterion(logits, labels)
            val_losses.append(loss.item())

            preds = torch.sigmoid(logits)
            val_preds.append(preds.cpu())
            val_targets.append(labels.cpu())

    epoch_val_loss.append(np.mean(val_losses))
    all_val_preds = torch.cat(val_preds).numpy()
    all_val_targets = torch.cat(val_targets).numpy()

    # Mean binary entropy of the predicted probabilities (higher = less confident).
    entropy = - (all_val_preds * np.log(all_val_preds + 1e-7) + (1 - all_val_preds) * np.log(1 - all_val_preds + 1e-7))
    mean_entropy = np.mean(entropy)

    # Per-class ROC-AUC. roc_auc_score raises ValueError when a class has only
    # one label value in the validation targets; such classes are recorded as
    # NaN and ignored by the nanmean below.
    val_epoch_roc_aucs = []
    for i in range(all_val_targets.shape[1]):
        try:
            auc = roc_auc_score(all_val_targets[:, i], all_val_preds[:, i])
        except ValueError:
            auc = float('nan')
        val_epoch_roc_aucs.append(auc)

    val_roc_aucs_by_epoch.append(val_epoch_roc_aucs)
    mean_auc = np.nanmean(val_epoch_roc_aucs)
    val_roc_auc_means.append(mean_auc)

    # --- Thresholding ---
    # Broadcast the per-class thresholds across all rows at once; the result
    # keeps the dtype of the probabilities (0.0 / 1.0 values).
    binary_preds = (all_val_preds > class_thresholds).astype(all_val_preds.dtype)

    f1_micro = f1_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    f1_macro = f1_score(all_val_targets, binary_preds, average='macro', zero_division=0)
    precision = precision_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    recall = recall_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    hamming = hamming_loss(all_val_targets, binary_preds)
    label_accuracy = (binary_preds == all_val_targets).mean(axis=0)

    f1_micros.append(f1_micro)
    f1_macros.append(f1_macro)
    precisions.append(precision)
    recalls.append(recall)
    hammings.append(hamming)

    # --- Epoch Summary ---
    print(f"Train Loss: {epoch_train_loss[-1]:.4f} | Val Loss: {epoch_val_loss[-1]:.4f}")
    print(f"Mean ROC-AUC: {mean_auc:.4f} | Mean Entropy: {mean_entropy:.4f}")
    print(f"F1-score (micro): {f1_micro:.4f} | F1-score (macro): {f1_macro:.4f}")
    print(f"Precision: {precision:.4f} | Recall: {recall:.4f}")
    print(f"Hamming Loss: {hamming:.4f}")
    print(f"Label-wise Accuracy: {label_accuracy.round(3)}\n")

    # Append this epoch's row to the CSV. The epoch column uses
    # model.epochs_trained rather than the loop index.
    with open(metrics_csv_path, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            model.epochs_trained,
            epoch_train_loss[-1],
            epoch_val_loss[-1],
            mean_auc,
            mean_entropy,
            f1_micro,
            f1_macro,
            precision,
            recall,
            hamming
        ] + val_epoch_roc_aucs)

    # --- Checkpoint & Early Stopping ---
    # A NaN mean_auc compares False here and therefore counts as "no improvement".
    if mean_auc > best_auc:
        best_auc = mean_auc
        early_stop_counter = 0
        # The scheduler state is stored next to the optimiser state so that a
        # resumed run can continue the schedule instead of restarting warm-up.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimiser_state_dict': opti.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epochs_trained': model.epochs_trained,
        }, f"Checkpoints/{model.model_name}_best.pth")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

In [None]:
def _plot_history(curves, title, ylabel):
    """Draw one figure with a line per entry of ``curves`` (label -> per-epoch values)."""
    for label, values in curves.items():
        plt.plot(values, label=label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid()
    plt.show()


_plot_history({'Train Loss': epoch_train_loss, 'Validation Loss': epoch_val_loss}, "Training Loss", "Loss")
_plot_history({'F1 Micros': f1_micros, 'F1 Macros': f1_macros}, "F1 Scores", "Score")
# Precision and recall are rates, not losses, so the y-axis is labelled accordingly.
_plot_history({'Precision': precisions, 'Recall': recalls}, "P&R Rates", "Rate")
_plot_history({'Hamming Loss': hammings}, "Hamming Loss", "Loss")

val_roc_aucs_by_epoch_np = np.array(val_roc_aucs_by_epoch)
epochs_range = range(0, len(val_roc_aucs_by_epoch_np))

# Per-class curves plus the mean. Before any epoch has been recorded the array
# is not 2-D; that case is detected explicitly instead of catching every
# exception, so real plotting errors are no longer hidden.
if val_roc_aucs_by_epoch_np.ndim == 2 and val_roc_aucs_by_epoch_np.size > 0:
    for i in range(val_roc_aucs_by_epoch_np.shape[1]):
        plt.plot(epochs_range, val_roc_aucs_by_epoch_np[:, i], label=class_names[i], alpha=0.5)

    # Mean ROC-AUC
    plt.plot(epochs_range, val_roc_auc_means, label='Mean ROC-AUC', color='black', linewidth=2.5)

    plt.title("Validation ROC-AUC per Class and Mean")
    plt.xlabel("Epoch")
    plt.ylabel("ROC-AUC")
    plt.grid(True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()
else:
    print("Empty variables")

In [None]:
# Score the whole test loader in inference mode and keep logits,
# probabilities, targets and the input images for the cells below.
model.eval()

all_probs, all_logits, all_targets, all_images = [], [], [], []

with torch.no_grad():
    for d in tqdm(test_dl, desc="Generating predictions"):
        imgs = d['images'].to(device, dtype=torch.float)
        labels = d['labels'].to(device, dtype=torch.float)

        logits = model(imgs)
        prob = torch.sigmoid(logits)

        # Everything is moved to the CPU before being accumulated.
        for bucket, batch_values in (
            (all_logits, logits),
            (all_probs, prob),
            (all_targets, labels),
            (all_images, imgs),
        ):
            bucket.append(batch_values.cpu())

# Concatenate the per-batch pieces; the images stay a tensor, the rest become arrays.
all_probs = torch.cat(all_probs).numpy()
all_targets = torch.cat(all_targets).numpy()
all_logits = torch.cat(all_logits).numpy()
all_images = torch.cat(all_images)


In [None]:
best_thresholds = []

# Thresholds are tuned on the VALIDATION predictions of the last validation
# pass (all_val_preds / all_val_targets from the training loop), not on the
# test set. Choosing them on the test labels and then reporting test metrics
# with them would leak test information into the evaluation.
for i in range(all_val_targets.shape[1]):
    scores = []
    # Try thresholds 0.00, 0.01, ..., 0.99 and keep the one with the best F1
    # for this class (the lowest such threshold on ties).
    for t in tqdm(np.arange(0, 1, 0.01), desc=f"Optimising threshold of class {i+1}"):
        probs = (all_val_preds[:, i] > t).astype(int)
        f1 = f1_score(all_val_targets[:, i], probs, zero_division=0)
        scores.append((t, f1))
    best_threshold = max(scores, key=lambda x: x[1])[0]
    best_thresholds.append(best_threshold)

print("Best thresholds:", best_thresholds)

In [None]:
# Binarise each class column of the test probabilities with that class's
# tuned threshold; the result has one row per sample and one column per class.
all_preds = np.stack(
    [
        (all_probs[:, i] > best_thresholds[i]).astype(int)
        for i in range(all_probs.shape[1])
    ],
    axis=1,
)

In [None]:
roc_aucs = []

# ROC-AUC is a ranking metric and must be computed from the continuous scores
# (all_probs). Feeding it the thresholded 0/1 predictions collapses the curve
# to a single operating point and misreports the AUC.
for i in range(all_targets.shape[1]):
    try:
        auc = roc_auc_score(all_targets[:, i], all_probs[:, i])
    except ValueError:
        # Raised when the class has only one label value in the test targets.
        auc = float('nan')
    roc_aucs.append(auc)

print("\nPer-Class ROC-AUC:")
for i, auc in enumerate(roc_aucs):
    print(f"Class {i}: {auc:.4f}")

# Classes recorded as NaN are left out of the macro average.
print(f"\nMacro ROC-AUC: {np.nanmean(roc_aucs):.4f}")

In [None]:
opti_preds = np.zeros_like(all_preds)

# Apply the tuned thresholds to the PROBABILITIES. all_preds already holds
# 0/1 values, so thresholding it again does not test the model's scores.
for i in range(all_probs.shape[1]):
    opti_preds[:, i] = (all_probs[:, i] > best_thresholds[i]).astype(int)

In [None]:
def show_prediction(img_tensor, true_labels, pred_labels, opti_labels, class_names):
    """Display one image with its true and predicted label names in the title.

    img_tensor:  image as a tensor or array; a 3-D input whose first axis has
                 length 3 is treated as (C, H, W) and moved to (H, W, C).
    true_labels, pred_labels, opti_labels:
                 per-class indicator vectors (tensor or array); entries equal
                 to 1 are listed by name.
    class_names: names indexed like the label vectors.
    """
    def _active_names(flags):
        # Tensors are brought to numpy first; anything else is used as-is.
        if isinstance(flags, torch.Tensor):
            flags = flags.cpu().numpy()
        return ", ".join(class_names[pos] for pos, flag in enumerate(flags) if flag == 1)

    img = img_tensor.cpu().numpy() if isinstance(img_tensor, torch.Tensor) else np.array(img_tensor)

    channels_first = img.ndim == 3 and img.shape[0] == 3
    if channels_first:
        img = img.transpose(1, 2, 0)

    # Values outside [0, 1] (e.g. after normalisation) are clipped for display;
    # the normalisation itself is not undone here.
    img = img.clip(0, 1)

    true_txt = _active_names(true_labels)
    pred_txt = _active_names(pred_labels)
    opti_pred_txt = _active_names(opti_labels)

    plt.imshow(img)
    plt.axis('off')
    plt.title(f"True: {true_txt}\nPred(std): {pred_txt}\nPred(Opti): {opti_pred_txt}", fontsize=10)
    plt.show()


# Pick one test image uniformly at random and show its labels and predictions.
idx = random.randrange(len(all_images))
show_prediction(
    all_images[idx],
    all_targets[idx],
    all_preds[idx],
    opti_preds[idx],
    class_names,
)

In [None]:
# On multi-label indicator arrays accuracy_score counts a sample as correct
# only when every label matches, which is why it is reported as exact match.
subset_acc = accuracy_score(y_true=all_targets, y_pred=all_preds)
print(f"Exact match accuracy: {subset_acc * 100:.4f}%")

In [None]:
# torch.save does not create missing parent folders, and nothing earlier in
# the notebook creates "Models", so make sure it exists first.
os.makedirs("Models", exist_ok=True)
torch.save(model.state_dict(), fr'Models/{model.model_name}_{model.epochs_trained}.pth')

In [None]:
# Rebuild the network and restore the best checkpoint written during training.
# pretrained=False: every weight is overwritten by the checkpoint, so fetching
# the pretrained weights first would be wasted work.
model = ViT(num_classes=14, model_name='swin_base_patch4_window7_224', pretrained=False)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(fr'Checkpoints/{model.model_name}_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.epochs_trained = checkpoint.get('epochs_trained', 0)

model.fine_tune()
model = model.to(device)

# The optimiser has to be rebuilt for the NEW model before its state is
# restored: the previous `opti` still references the parameters of the old
# model object and would never update this one. The hyper-parameters passed
# here are placeholders that load_state_dict replaces with the saved ones.
opti = torch.optim.AdamW(model.parameters(), lr=eta, weight_decay=0.01)
opti.load_state_dict(checkpoint['optimiser_state_dict'])

model.eval()

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F

# Locate the attention module of the final block
def get_last_block_attn(model):
    """Return ``.attn`` of the last block in the last entry of ``model.model.layers``."""
    final_stage = model.model.layers[-1]
    final_block = final_stage.blocks[-1]
    return final_block.attn

# Recompute the attention weights of the last attention module via a forward hook
def get_attention_map(model, image_tensor):
    """Run ``model`` on ``image_tensor`` and return recomputed attention weights.

    A forward hook on the module returned by ``get_last_block_attn`` rebuilds
    softmax(q @ k^T * scale) from the module's own ``qkv`` projection and its
    input. The model is switched to eval mode as a side effect.

    Returns:
        Tensor of shape [B*nW, num_heads, N, N] (detached).

    Raises:
        RuntimeError: if the hooked module was never called during the forward pass.

    NOTE(review): only the scaled q.k product is recomputed here. Any additive
    bias or mask the attention module applies internally is not included —
    confirm against the timm implementation before reading the maps as exact.
    """
    attention_maps = []

    def hook_fn(module, input, output):
        qkv = module.qkv(input[0])  # [B*nW, N, 3*C]
        B, N, three_c = qkv.shape
        C = three_c // 3
        head_dim = C // module.num_heads
        q, k, _ = qkv.chunk(3, dim=2)  # the value projection is not needed

        q = q.reshape(B, N, module.num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, module.num_heads, head_dim).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn = attn.softmax(dim=-1)  # [B*nW, num_heads, N, N]
        attention_maps.append(attn.detach())

    attn_module = get_last_block_attn(model)
    handle = attn_module.register_forward_hook(hook_fn)

    model.eval()
    # Remove the hook even if the forward pass raises; otherwise it would stay
    # registered and fire on every later call of the model.
    try:
        with torch.no_grad():
            _ = model(image_tensor)
    finally:
        handle.remove()

    if not attention_maps:
        raise RuntimeError("Attention hook did not fire during the forward pass.")
    return attention_maps[0]  # [B*nW, num_heads, N, N]

# Overlay an attention heat-map on the input image
def visualize_attention_on_image(img_tensor, attn_weights, head=None, mean=None, std=None):
    """Show ``img_tensor`` with the attention of the centre token drawn on top.

    img_tensor:   image tensor with a leading batch axis of size 1 (it is squeezed away).
    attn_weights: tensor of shape [nW*B, num_heads, N, N]; N must be a perfect
                  square so the attention row can be reshaped to a grid.
    head:         index of the head to show, or None to average over heads.
    mean, std:    per-channel statistics used to undo normalisation; applied
                  only when both are given.
    """
    img = img_tensor.squeeze(0).cpu().clone()

    if mean is not None and std is not None:
        mean = torch.tensor(mean).view(-1, 1, 1)
        std = torch.tensor(std).view(-1, 1, 1)
        img = img * std + mean

    # to_pil_image scales float tensors by 255 and converts them to 8-bit, so
    # the values have to lie in [0, 1]; clamp to guarantee valid intensities
    # (this matters most when no mean/std is supplied).
    img = img.clamp(0.0, 1.0)
    pil_img = to_pil_image(img)

    # Attention paid by the centre token to every token
    nW_B, nH, N, _ = attn_weights.shape
    token_idx = N // 2
    if head is not None:
        attn = attn_weights[:, head, token_idx, :]  # [nW*B, N]
    else:
        attn = attn_weights[:, :, token_idx, :]     # [nW*B, nH, N]
        attn = attn.mean(1)                         # mean over heads

    attn = attn.mean(0)  # mean over windows
    side = int(N ** 0.5)
    attn = attn.view(side, side)

    # Upsample to the image resolution and rescale to [0, 1] for the colour map
    attn = attn.unsqueeze(0).unsqueeze(0)
    attn = F.interpolate(attn, size=(img.shape[1], img.shape[2]), mode='bilinear', align_corners=False)
    attn = attn.squeeze().cpu().numpy()
    attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-6)

    plt.imshow(pil_img)
    plt.imshow(attn, cmap='jet', alpha=0.5)
    plt.axis('off')
    plt.show()

# Attention overlay for the first image of the first batch
def run_attention_overlay(model, dataloader, mean, std, head=None):
    """Print the labels of the first image the loader yields and show its attention overlay.

    Does nothing when the loader yields no batch. ``mean``, ``std`` and
    ``head`` are forwarded to ``visualize_attention_on_image``.
    """
    device = next(model.parameters()).device
    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return

    # Keep a batch axis of size 1 for the forward pass.
    img_tensor = first_batch['images'][0].unsqueeze(0).to(device)
    print(first_batch['labels'][0])

    attn = get_attention_map(model, img_tensor)
    visualize_attention_on_image(img_tensor, attn, head=head, mean=mean, std=std)


In [None]:
# Normalisation statistics used by the transforms, needed to undo them for display.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# head=None averages the attention over all heads.
run_attention_overlay(model, test_dl, mean=mean, std=std, head=None)