In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import numpy as np
import random
import pandas as pd
import csv

from sklearn.metrics import (
    roc_auc_score, f1_score, precision_score, recall_score,
    hamming_loss, accuracy_score
)

import matplotlib.pyplot as plt
from torchvision.models import ResNet18_Weights, ResNet50_Weights, DenseNet121_Weights, efficientnet_b0, EfficientNet_B0_Weights, efficientnet_b3, EfficientNet_B3_Weights, ResNet34_Weights, convnext_base, ConvNeXt_Base_Weights

In [None]:
class ImageDataset(Dataset):
    """Multi-label image dataset indexed by a CSV file.

    The CSV must contain a ``Path`` column with image file names (relative to
    ``image_dir``); its last 14 columns are read as the label vector of each
    image and stored as ``float32``.

    Args:
        csv_dir: Path of the CSV index file.
        image_dir: Directory that the ``Path`` entries are relative to.
        transforms: Optional callable applied to the loaded PIL image. When it
            is falsy, the image is converted with ``transforms.ToTensor()``.
    """

    def __init__(self, csv_dir, image_dir, transforms=None):
        self.image_dir = image_dir
        self.transforms = transforms

        frame = pd.read_csv(csv_dir)

        # File names come from the 'Path' column, labels from the 14 right-most columns.
        self.image_names = frame['Path'].tolist()
        self.labels = frame.iloc[:, -14:].values.astype("float32")

    def __len__(self):
        """Number of images listed in the CSV."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with ``'images'`` (the transformed image as a numpy array)
            and ``'labels'`` (the label vector as a numpy array).

        Raises:
            Exception: whatever PIL raised while opening/converting the file;
                the offending path is printed before re-raising.
        """
        img_path = os.path.join(self.image_dir, self.image_names[idx])

        try:
            # Always decode as 3-channel RGB.
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image: {img_path} - {e}")
            raise e

        # Fall back to a plain tensor conversion when no pipeline was supplied.
        pipeline = self.transforms if self.transforms else transforms.ToTensor()
        image = np.array(pipeline(image))

        return {'images': image, 'labels': np.array(self.labels[idx])}

def compute_sample_weights(labels):
    """Derive a per-sample weight for multi-label oversampling.

    Each class gets the weight ``1 / (number of positives + 1e-6)``. A sample's
    weight is the SUM of the class weights of the labels it carries, so samples
    with rare labels (or many labels) receive larger weights and samples with
    no positive label receive 0.

    Args:
        labels: array-like of shape (N, num_classes) with 0/1 entries.

    Returns:
        numpy array of shape (N,) with one weight per sample.
    """
    positives_per_class = np.sum(labels, axis=0)
    # The epsilon keeps classes without any positive sample from dividing by zero.
    inverse_frequency = 1.0 / (positives_per_class + 1e-6)

    return np.dot(labels, inverse_frequency)


In [None]:
# Side length (pixels) of the square images fed to the network.
dim = 224
size = (dim, dim)

# Training pipeline: geometric / photometric augmentation, then conversion to a
# 3-channel grayscale tensor normalised with the ImageNet statistics.
trans_train = transforms.Compose([
    transforms.Resize(size),                      # Resize first
    transforms.RandomHorizontalFlip(p=0.5),      # 50% chance flip
    transforms.RandomRotation(15),                # ±15 degrees rotation
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # brightness/contrast jitter
    transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),  # random zoom/crop
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5), # optional perspective distortion
    transforms.Grayscale(num_output_channels=3), # convert to 3 channel grayscale
    transforms.ToTensor(),                         # convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalisation standard
                        std=[0.229, 0.224, 0.225]),
])

# Evaluation pipeline: deterministic, no augmentation.
trans_val = transforms.Compose([
    transforms.Resize(size),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
])

ds_size = "Balanced"

# Build every path with os.path.join: the previous hard-coded backslash
# separators only resolve on Windows (on POSIX a backslash is an ordinary
# file-name character, so the files were never found).
data_root = f"NIH_X_14_{ds_size}"

train_data = ImageDataset(os.path.join(data_root, "train_labels_encoded.csv"), os.path.join(data_root, "train"), transforms=trans_train)
test_data = ImageDataset(os.path.join(data_root, "test_labels_encoded.csv"), os.path.join(data_root, "test"), transforms=trans_val)
val_data = ImageDataset(os.path.join(data_root, "val_labels_encoded.csv"), os.path.join(data_root, "val"), transforms=trans_val)


# Human-readable class names.
# NOTE(review): assumed to follow the order of the last 14 CSV columns — confirm.
class_names = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'
]

In [None]:
# --- Compute sample weights from label frequency ---
label_array = np.array(train_data.labels)  # Shape: (N_samples, N_classes)
class_counts = np.sum(label_array, axis=0)  # total positives per class
class_weights = 1.0 / (class_counts + 1e-6)  # inverse frequency
class_weights = class_weights / np.sum(class_weights) * len(class_weights)  # normalize

print("Class counts:", class_counts)
print("Class weights (normalized):", np.round(class_weights, 4))

# A sample's weight is the sum of the weights of its positive labels,
# rescaled so the weights average to 1 and floored so no sample has weight 0.
sample_weights = np.dot(label_array, class_weights)  # (N_samples,)
sample_weights = sample_weights / np.sum(sample_weights) * len(sample_weights)
sample_weights = np.clip(sample_weights, a_min=1e-6, a_max=None)

print(f"Sample weights: mean={sample_weights.mean():.4e}, std={sample_weights.std():.4e}, max={sample_weights.max():.4e}")

imbalance_ratio = class_counts.max() / class_counts.min()
if imbalance_ratio > 10:
    print(f"Warning: Significant class imbalance detected (max/min = {imbalance_ratio:.2f}). Consider data augmentation or other sampling strategies.")

sample_weights_tensor = torch.DoubleTensor(sample_weights)
sampler = WeightedRandomSampler(sample_weights_tensor, len(sample_weights_tensor), replacement=True)

# Sanity check: draw one full epoch of indices and make sure the sampler is
# not collapsing onto a small part of the training set.
unique_indices = set(sampler)
if len(unique_indices) < 0.5 * len(train_data):
    print(f"Sampler is using only {len(unique_indices)} unique samples out of {len(train_data)}. Check weights.")

# --- DataLoaders ---
batch_size = 256
train_dl = DataLoader(train_data, batch_size=batch_size, sampler=sampler, drop_last=True, pin_memory=True)
# Evaluation loaders must see EVERY sample: with drop_last=True a different
# (shuffled) tail of up to batch_size-1 images was silently excluded from the
# metrics on each pass, and a split smaller than one batch produced no batches
# at all. Shuffling is kept because the Grad-CAM cell visualises the first
# image of the first validation batch.
test_dl = DataLoader(test_data, shuffle=True, batch_size=batch_size, drop_last=False, pin_memory=True)
val_dl = DataLoader(val_data, shuffle=True, batch_size=batch_size, drop_last=False, pin_memory=True)

In [None]:
def _resnet_fine_tune(self, n_layers):
    for param in self.net.parameters():
        param.requires_grad = False

    res_layers = [self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
    for block in res_layers[-n_layers:]:
        for param in block.parameters():
            param.requires_grad = True

    for param in self.net.fc.parameters():
        param.requires_grad = True

def _efficientnet_fine_tune(self, n_layers):
    for param in self.model.features.parameters():
        param.requires_grad = False
    for block in self.model.features[-n_layers:]:
        for param in block.parameters():
            param.requires_grad = True
    for param in self.model.classifier.parameters():
        param.requires_grad = True

In [None]:
# ResNet50
class ResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 with a fresh linear head.

    After construction only the head (``net.fc``) is trainable; call
    :meth:`fine_tune` to additionally unfreeze trailing residual stages.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            stem convolution is replaced by a randomly initialised one, which
            is kept trainable because it carries no pretrained weights.
    """

    def __init__(self, num_classes, in_chans=3):
        super(ResNet50, self).__init__()
        self.net = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            self.net.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)

        self.in_features = self.net.fc.in_features
        self.net.fc = nn.Linear(self.in_features, num_classes)

        # Freeze the backbone, train the head.
        for param in self.net.parameters():
            param.requires_grad = False
        for param in self.net.fc.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

        self.model_name = "ResNet50"
        self.epochs_trained = 0

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new ``conv1`` was frozen together with the pretrained
        backbone, leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.net.conv1.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.net(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` residual stages plus the head."""
        _resnet_fine_tune(self, n_layers)
        # The helper freezes everything first, so re-enable a custom stem.
        self._unfreeze_custom_stem()

# ResNet34
class ResNet34(nn.Module):
    """ImageNet-pretrained ResNet-34 with a fresh linear head.

    After construction only the head (``net.fc``) is trainable; call
    :meth:`fine_tune` to additionally unfreeze trailing residual stages.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            stem convolution is replaced by a randomly initialised one, which
            is kept trainable because it carries no pretrained weights.
    """

    def __init__(self, num_classes, in_chans=3):
        super(ResNet34, self).__init__()
        self.net = models.resnet34(weights=ResNet34_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            self.net.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)

        self.in_features = self.net.fc.in_features
        self.net.fc = nn.Linear(self.in_features, num_classes)

        # Freeze the backbone, train the head.
        for param in self.net.parameters():
            param.requires_grad = False
        for param in self.net.fc.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

        self.model_name = "ResNet34"
        self.epochs_trained = 0

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new ``conv1`` was frozen together with the pretrained
        backbone, leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.net.conv1.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.net(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` residual stages plus the head."""
        _resnet_fine_tune(self, n_layers)
        # The helper freezes everything first, so re-enable a custom stem.
        self._unfreeze_custom_stem()

# ResNet18
class ResNet18(nn.Module):
    """ImageNet-pretrained ResNet-18 with a fresh linear head.

    After construction only the head (``net.fc``) is trainable; call
    :meth:`fine_tune` to additionally unfreeze trailing residual stages.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            stem convolution is replaced by a randomly initialised one, which
            is kept trainable because it carries no pretrained weights.
    """

    def __init__(self, num_classes, in_chans=3):
        super(ResNet18, self).__init__()
        self.net = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            self.net.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)

        self.in_features = self.net.fc.in_features
        self.net.fc = nn.Linear(self.in_features, num_classes)

        # Freeze the backbone, train the head.
        for param in self.net.parameters():
            param.requires_grad = False
        for param in self.net.fc.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

        self.model_name = "ResNet18"
        self.epochs_trained = 0

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new ``conv1`` was frozen together with the pretrained
        backbone, leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.net.conv1.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.net(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` residual stages plus the head."""
        _resnet_fine_tune(self, n_layers)
        # The helper freezes everything first, so re-enable a custom stem.
        self._unfreeze_custom_stem()

# DenseNet121
class DenseNet121(nn.Module):
    """ImageNet-pretrained DenseNet-121 feature extractor with a small MLP head.

    The head is ``Flatten -> Linear(in_features, 512) -> ReLU -> Dropout(0.4)
    -> Linear(512, num_classes)`` applied to globally average-pooled features.
    After construction only the head is trainable; call :meth:`fine_tune` to
    also unfreeze trailing dense blocks.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            stem convolution is replaced by a randomly initialised one, which
            is kept trainable because it carries no pretrained weights.
    """

    def __init__(self, num_classes, in_chans=3):
        super(DenseNet121, self).__init__()
        base_model = models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            base_model.features.conv0 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)

        self.features = base_model.features
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        in_features = base_model.classifier.in_features
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

        # Freeze the feature extractor, train the head.
        for param in self.features.parameters():
            param.requires_grad = False
        for param in self.classifier.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

        self.model_name = "DenseNet121"
        self.epochs_trained = 0

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new ``conv0`` was frozen with the pretrained features,
        leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.features.conv0.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return logits: features -> global average pool -> classifier head."""
        # NOTE(review): no ReLU is applied between `features` and the pooling;
        # confirm this is intended before comparing with the stock model.
        x = self.features(x)
        x = self.avg_pool(x)
        return self.classifier(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` dense blocks plus the head.

        Args:
            n_layers: number of trailing dense blocks (out of 4) to unfreeze.
                ``0`` trains only the head; values above 4 unfreeze all four.

        Raises:
            ValueError: if ``n_layers`` is negative.
        """
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers}")

        for param in self.features.parameters():
            param.requires_grad = False
        blocks = [self.features.denseblock1, self.features.denseblock2, self.features.denseblock3, self.features.denseblock4]
        # Explicit start index: `blocks[-0:]` would select (and unfreeze) every block.
        first_trainable = max(len(blocks) - n_layers, 0)
        for block in blocks[first_trainable:]:
            for param in block.parameters():
                param.requires_grad = True
        for param in self.classifier.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

# EfficientNetB0
class EfficientNetB0(nn.Module):
    """ImageNet-pretrained EfficientNet-B0 with a replaced final linear layer.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            first convolution is replaced by a randomly initialised one with
            the same geometry; it is kept trainable even when ``freeze`` is set.
        freeze: when True, all (pretrained) ``features`` parameters are frozen.
    """

    def __init__(self, num_classes, in_chans=3, freeze=True):
        super(EfficientNetB0, self).__init__()
        self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            old_conv = self.model.features[0][0]
            self.model.features[0][0] = nn.Conv2d(
                in_channels=in_chans,
                out_channels=old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None
            )
        self.model_name = "EfficientNetB0"
        self.epochs_trained = 0

        if freeze:
            for param in self.model.features.parameters():
                param.requires_grad = False
            self._unfreeze_custom_stem()

        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new conv was frozen with the pretrained features,
        leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.model.features[0][0].parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.model(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` feature blocks plus the classifier."""
        _efficientnet_fine_tune(self, n_layers)
        # The helper freezes all features first, so re-enable a custom stem.
        self._unfreeze_custom_stem()

# EfficientNetB3
class EfficientNetB3(nn.Module):
    """ImageNet-pretrained EfficientNet-B3 with a replaced final linear layer.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            first convolution is replaced by a randomly initialised one with
            the same geometry; it is kept trainable even when ``freeze`` is set.
        freeze: when True, all (pretrained) ``features`` parameters are frozen.
    """

    def __init__(self, num_classes, in_chans=3, freeze=True):
        super(EfficientNetB3, self).__init__()
        self.model = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)
        self._custom_stem = in_chans != 3
        if self._custom_stem:
            old_conv = self.model.features[0][0]
            self.model.features[0][0] = nn.Conv2d(
                in_channels=in_chans,
                out_channels=old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None
            )
        self.model_name = "EfficientNetB3"
        self.epochs_trained = 0

        if freeze:
            for param in self.model.features.parameters():
                param.requires_grad = False
            self._unfreeze_custom_stem()

        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new conv was frozen with the pretrained features,
        leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.model.features[0][0].parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.model(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` feature blocks plus the classifier."""
        _efficientnet_fine_tune(self, n_layers)
        # The helper freezes all features first, so re-enable a custom stem.
        self._unfreeze_custom_stem()

# ConvNeXt
class ConvNeXtBase(nn.Module):
    """ImageNet-pretrained ConvNeXt-Base with a replaced final linear layer.

    Args:
        num_classes: size of the output layer (one logit per class).
        in_chans: number of input channels. For any value other than 3 the
            first convolution is replaced by a randomly initialised, bias-free
            one with the same geometry; it is kept trainable even when
            ``freeze`` is set.
        freeze: when True, all (pretrained) ``features`` parameters are frozen.
    """

    def __init__(self, num_classes, in_chans=3, freeze=True):
        super(ConvNeXtBase, self).__init__()
        self.model = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)

        self._custom_stem = in_chans != 3
        if self._custom_stem:
            old_conv = self.model.features[0][0]
            self.model.features[0][0] = nn.Conv2d(
                in_chans,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )

        if freeze:
            for param in self.model.features.parameters():
                param.requires_grad = False
            self._unfreeze_custom_stem()

        in_features = self.model.classifier[2].in_features
        self.model.classifier[2] = nn.Linear(in_features, num_classes)

        self.model_name = "ConvNeXtBase"
        self.epochs_trained = 0

    def _unfreeze_custom_stem(self):
        """Keep a replaced (randomly initialised) stem conv trainable.

        Previously the new conv was frozen with the pretrained features,
        leaving a random, never-updated first layer.
        """
        if self._custom_stem:
            for param in self.model.features[0][0].parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the raw logits of the wrapped network for input batch ``x``."""
        return self.model(x)

    def fine_tune(self, n_layers=2):
        """Unfreeze the last ``n_layers`` feature stages plus the classifier.

        Args:
            n_layers: number of trailing entries of ``model.features`` to
                unfreeze. ``0`` trains only the classifier.

        Raises:
            ValueError: if ``n_layers`` is negative.
        """
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers}")

        for param in self.model.features.parameters():
            param.requires_grad = False
        # Explicit start index: `features[-0:]` would select (and unfreeze) everything.
        first_trainable = max(len(self.model.features) - n_layers, 0)
        for block in self.model.features[first_trainable:]:
            for param in block.parameters():
                param.requires_grad = True
        for param in self.model.classifier.parameters():
            param.requires_grad = True
        self._unfreeze_custom_stem()

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on top of element-wise binary cross-entropy with logits.

    Per element the loss is ``alpha * (1 - pt) ** gamma * bce`` with
    ``pt = exp(-bce)``. The same ``alpha`` factor is applied to every element.

    Args:
        alpha: constant scale applied to every element.
        gamma: focusing exponent; larger values down-weight easy elements more.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced, element-wise loss tensor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against ``targets``."""
        element_bce = self.bce(inputs, targets)
        # exp(-bce) is the probability the model assigns to the true label.
        true_label_prob = torch.exp(-element_bce)
        modulation = (1 - true_label_prob) ** self.gamma
        element_loss = self.alpha * modulation * element_bce

        if self.reduction == 'sum':
            return element_loss.sum()
        if self.reduction == 'mean':
            return element_loss.mean()
        return element_loss
        
def optimise_thresholds(all_targets, all_probs):
    """Pick, per class, the decision threshold that maximises the F1 score.

    For every class the thresholds 0.00, 0.01, ..., 0.99 are tried; a
    prediction is positive when its probability is strictly greater than the
    threshold. On ties the lowest threshold wins.

    Args:
        all_targets: array of shape (N, num_classes) with binary ground truth.
        all_probs: array of shape (N, num_classes) with predicted probabilities.

    Returns:
        list with one threshold per class.
    """
    best_thresholds = []

    for class_idx in range(all_targets.shape[1]):
        candidates = np.arange(0, 1, 0.01)
        f1_per_candidate = [
            f1_score(
                all_targets[:, class_idx],
                (all_probs[:, class_idx] > candidate).astype(int),
                zero_division=0,
            )
            for candidate in tqdm(candidates, desc=f"Optimising threshold of class {class_idx+1}")
        ]
        # argmax returns the first maximum, i.e. the lowest best threshold.
        best_thresholds.append(candidates[int(np.argmax(f1_per_candidate))])

    return best_thresholds



In [None]:
# Report the runtime environment and select the compute device.
cuda_ok = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
print(f"Number of GPUs available: {torch.cuda.device_count()}")
if cuda_ok:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

# Prefer the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if cuda_ok else torch.device("cpu")

In [None]:
# Base learning rate; the fine-tuning phase later divides it by 10.
eta = 1e-4

# Network under training, moved to the selected device.
model = DenseNet121(num_classes=14).to(device)

# Multi-label objective: one independent sigmoid/BCE term per class.
criterion = nn.BCEWithLogitsLoss()

# The optimiser is handed ALL parameters (frozen ones included).
opti = torch.optim.AdamW(params=model.parameters(), lr=eta, weight_decay=0.01)


In [None]:
# Upper bound on epochs per training phase and early-stopping patience.
epochs = 250
patience = 5

# Early-stopping bookkeeping.
best_auc = 0.0
early_stop_counter = 0

# Fixed per-class decision threshold used for the validation metrics.
class_thresholds = np.full(14, 0.35)

# Per-epoch histories (appended to by the training loops, plotted afterwards).
epoch_train_loss = []
epoch_val_loss = []
val_roc_aucs_by_epoch = []
val_roc_auc_means = []
hammings = []
precisions = []
recalls = []
f1_macros = []
f1_micros = []

In [None]:
metrics_log = []
metrics_csv_path = f"Metrics/{model.model_name}_training_metrics.csv"

# Make sure the metrics and checkpoint folders exist before training starts.
for required_dir in (os.path.dirname(metrics_csv_path), "Checkpoints"):
    os.makedirs(required_dir, exist_ok=True)

roc_auc_columns = [f"roc_auc_{name}" for name in class_names]

# Only a brand-new file gets a header; an existing file is appended to later.
if not os.path.exists(metrics_csv_path):
    header_row = [
        "epoch", "train_loss", "val_loss", "mean_auc", "mean_entropy",
        "f1_micro", "f1_macro", "precision", "recall", "hamming_loss"
    ] + roc_auc_columns
    with open(metrics_csv_path, mode='w', newline='') as f:
        csv.writer(f).writerow(header_row)

# --- Training Loop ---
# Phase 1: train with the model's initial freeze configuration. Each epoch runs
# one pass over train_dl, evaluates on val_dl, appends a row to the metrics
# CSV, checkpoints on a new best mean ROC-AUC and stops early after `patience`
# consecutive epochs without improvement.
for epoch in range(epochs):
    train_losses = []
    model.train()

    for d in tqdm(train_dl, desc=f"Training {epoch+1}/{epochs}"):
        opti.zero_grad()
        imgs = d['images'].to(device, dtype=torch.float)
        labels = d['labels'].to(device, dtype=torch.float)

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimiser step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opti.step()

        train_losses.append(loss.item())

    epoch_train_loss.append(np.mean(train_losses))
    # Cumulative epoch counter stored on the model (also written to the CSV/checkpoint).
    model.epochs_trained += 1

    # --- Validation ---
    model.eval()
    val_losses, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for d in tqdm(val_dl, desc="Validating"):
            imgs = d['images'].to(device, dtype=torch.float)
            labels = d['labels'].to(device, dtype=torch.float)

            logits = model(imgs)
            loss = criterion(logits, labels)
            val_losses.append(loss.item())

            # Sigmoid turns the logits into independent per-class probabilities.
            preds = torch.sigmoid(logits)
            val_preds.append(preds.cpu())
            val_targets.append(labels.cpu())

    epoch_val_loss.append(np.mean(val_losses))
    all_val_preds = torch.cat(val_preds).numpy()
    all_val_targets = torch.cat(val_targets).numpy()

    # Mean binary entropy of the predicted probabilities (1e-7 avoids log(0)).
    entropy = - (all_val_preds * np.log(all_val_preds + 1e-7) + (1 - all_val_preds) * np.log(1 - all_val_preds + 1e-7))
    mean_entropy = np.mean(entropy)

    # Per-class ROC-AUC; roc_auc_score raises ValueError when it cannot be
    # computed for a class, which is recorded as NaN and skipped by nanmean.
    val_epoch_roc_aucs = []
    for i in range(all_val_targets.shape[1]):
        try:
            auc = roc_auc_score(all_val_targets[:, i], all_val_preds[:, i])
        except ValueError:
            auc = float('nan')
        val_epoch_roc_aucs.append(auc)

    val_roc_aucs_by_epoch.append(val_epoch_roc_aucs)
    mean_auc = np.nanmean(val_epoch_roc_aucs)
    val_roc_auc_means.append(mean_auc)

    # --- Thresholding ---
    # Binarise the probabilities with the fixed per-class thresholds.
    binary_preds = np.zeros_like(all_val_preds)
    for i, thresh in enumerate(class_thresholds):
        binary_preds[:, i] = (all_val_preds[:, i] > thresh).astype(int)

    f1_micro = f1_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    f1_macro = f1_score(all_val_targets, binary_preds, average='macro', zero_division=0)
    precision = precision_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    recall = recall_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    hamming = hamming_loss(all_val_targets, binary_preds)
    # Fraction of samples predicted correctly, separately for every label.
    label_accuracy = (binary_preds == all_val_targets).mean(axis=0)

    f1_micros.append(f1_micro)
    f1_macros.append(f1_macro)
    precisions.append(precision)
    recalls.append(recall)
    hammings.append(hamming)

    # --- Epoch Summary ---
    print(f"Train Loss: {epoch_train_loss[-1]:.4f} | Val Loss: {epoch_val_loss[-1]:.4f}")
    print(f"Mean ROC-AUC: {mean_auc:.4f} | Mean Entropy: {mean_entropy:.4f}")
    print(f"F1-score (micro): {f1_micro:.4f} | F1-score (macro): {f1_macro:.4f}")
    print(f"Precision: {precision:.4f} | Recall: {recall:.4f}")
    print(f"Hamming Loss: {hamming:.4f}")
    print(f"Label-wise Accuracy: {label_accuracy.round(3)}\n")

    # Save metrics to CSV using model.epochs_trained
    with open(metrics_csv_path, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            model.epochs_trained,
            epoch_train_loss[-1],
            epoch_val_loss[-1],
            mean_auc,
            mean_entropy,
            f1_micro,
            f1_macro,
            precision,
            recall,
            hamming
        ] + val_epoch_roc_aucs)
    
    # --- Checkpoint & Early Stopping ---
    # A new best mean ROC-AUC overwrites the "best" checkpoint and resets the
    # patience counter; otherwise the counter advances towards early stopping.
    if mean_auc > best_auc:
        best_auc = mean_auc
        early_stop_counter = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimiser_state_dict': opti.state_dict(),
            'epochs_trained': model.epochs_trained,
        }, f"Checkpoints/{model.model_name}_best.pth")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

In [None]:
# Run the model over the test loader (evaluation mode, no gradients) and keep
# logits, probabilities, targets and the input images of every batch.
model.eval()

all_logits, all_probs, all_targets, all_images = [], [], [], []

with torch.no_grad():
    for d in tqdm(test_dl, desc="Generating predictions"):
        imgs, labels = (d[key].to(device, dtype=torch.float) for key in ('images', 'labels'))

        logits = model(imgs)
        prob = torch.sigmoid(logits)

        # Move everything back to the CPU before storing it.
        for store, batch_tensor in (
            (all_logits, logits),
            (all_probs, prob),
            (all_targets, labels),
            (all_images, imgs),
        ):
            store.append(batch_tensor.cpu())

# Stack into single arrays (the images stay a torch tensor).
all_probs, all_targets, all_logits = (
    torch.cat(chunks).numpy() for chunks in (all_probs, all_targets, all_logits)
)
all_images = torch.cat(all_images)

In [None]:
# Secondary training (fine-tuning): unfreeze the trailing backbone blocks via
# model.fine_tune() and continue training with a 10x smaller learning rate.

model.fine_tune()

# NOTE(review): these thresholds are tuned on the TEST-set predictions
# (all_targets / all_probs) and then used for the validation metrics below;
# consider tuning them on validation predictions instead to avoid leakage.
best_thresholds = optimise_thresholds(all_targets, all_probs)

new_lr = eta / 10
for param_group in opti.param_groups:
    param_group['lr'] = new_lr

# Start this phase with a fresh patience budget. When the first phase ended by
# early stopping the counter is already at `patience`, so a single
# non-improving epoch would otherwise abort fine-tuning immediately.
# best_auc is deliberately kept so the "best" checkpoint only gets replaced by
# a genuinely better model.
early_stop_counter = 0

# --- Training Loop ---
for epoch in range(epochs):
    train_losses = []
    model.train()

    for d in tqdm(train_dl, desc=f"Training {epoch+1}/{epochs}"):
        opti.zero_grad()
        imgs = d['images'].to(device, dtype=torch.float)
        labels = d['labels'].to(device, dtype=torch.float)

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimiser step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opti.step()

        train_losses.append(loss.item())

    epoch_train_loss.append(np.mean(train_losses))
    model.epochs_trained += 1

    # --- Validation ---
    model.eval()
    val_losses, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for d in tqdm(val_dl, desc="Validating"):
            imgs = d['images'].to(device, dtype=torch.float)
            labels = d['labels'].to(device, dtype=torch.float)

            logits = model(imgs)
            loss = criterion(logits, labels)
            val_losses.append(loss.item())

            preds = torch.sigmoid(logits)
            val_preds.append(preds.cpu())
            val_targets.append(labels.cpu())

    epoch_val_loss.append(np.mean(val_losses))
    all_val_preds = torch.cat(val_preds).numpy()
    all_val_targets = torch.cat(val_targets).numpy()

    # Mean binary entropy of the predicted probabilities (1e-7 avoids log(0)).
    entropy = - (all_val_preds * np.log(all_val_preds + 1e-7) + (1 - all_val_preds) * np.log(1 - all_val_preds + 1e-7))
    mean_entropy = np.mean(entropy)

    # Per-class ROC-AUC; classes for which it cannot be computed become NaN.
    val_epoch_roc_aucs = []
    for i in range(all_val_targets.shape[1]):
        try:
            auc = roc_auc_score(all_val_targets[:, i], all_val_preds[:, i])
        except ValueError:
            auc = float('nan')
        val_epoch_roc_aucs.append(auc)

    val_roc_aucs_by_epoch.append(val_epoch_roc_aucs)
    mean_auc = np.nanmean(val_epoch_roc_aucs)
    val_roc_auc_means.append(mean_auc)

    # --- Thresholding ---
    # Binarise with the per-class thresholds optimised above.
    binary_preds = np.zeros_like(all_val_preds)
    for i, thresh in enumerate(best_thresholds):
        binary_preds[:, i] = (all_val_preds[:, i] > thresh).astype(int)

    f1_micro = f1_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    f1_macro = f1_score(all_val_targets, binary_preds, average='macro', zero_division=0)
    precision = precision_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    recall = recall_score(all_val_targets, binary_preds, average='micro', zero_division=0)
    hamming = hamming_loss(all_val_targets, binary_preds)
    label_accuracy = (binary_preds == all_val_targets).mean(axis=0)

    f1_micros.append(f1_micro)
    f1_macros.append(f1_macro)
    precisions.append(precision)
    recalls.append(recall)
    hammings.append(hamming)

    # --- Epoch Summary ---
    print(f"Train Loss: {epoch_train_loss[-1]:.4f} | Val Loss: {epoch_val_loss[-1]:.4f}")
    print(f"Mean ROC-AUC: {mean_auc:.4f} | Mean Entropy: {mean_entropy:.4f}")
    print(f"F1-score (micro): {f1_micro:.4f} | F1-score (macro): {f1_macro:.4f}")
    print(f"Precision: {precision:.4f} | Recall: {recall:.4f}")
    print(f"Hamming Loss: {hamming:.4f}")
    print(f"Label-wise Accuracy: {label_accuracy.round(3)}\n")

    # Save metrics to CSV using model.epochs_trained
    with open(metrics_csv_path, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            model.epochs_trained,
            epoch_train_loss[-1],
            epoch_val_loss[-1],
            mean_auc,
            mean_entropy,
            f1_micro,
            f1_macro,
            precision,
            recall,
            hamming
        ] + val_epoch_roc_aucs)

    # --- Checkpoint & Early Stopping ---
    if mean_auc > best_auc:
        best_auc = mean_auc
        early_stop_counter = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimiser_state_dict': opti.state_dict(),
            'epochs_trained': model.epochs_trained,
        }, f"Checkpoints/{model.model_name}_best.pth")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

In [None]:
# Training / validation loss per epoch (both phases concatenated).
plt.plot(epoch_train_loss, label='Train Loss')
plt.plot(epoch_val_loss, label='Validation Loss')
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.show()

# Micro- and macro-averaged F1 on the validation set.
plt.plot(f1_micros, label='F1 Micros')
plt.plot(f1_macros, label='F1 Macros')
plt.title("F1 Scores")
plt.xlabel("Epoch")
plt.ylabel("Score")
plt.legend()
plt.grid()
plt.show()

# Micro-averaged precision and recall on the validation set.
plt.plot(precisions, label='Precision')
plt.plot(recalls, label='Recall')
plt.title("P&R Rates")
plt.xlabel("Epoch")
plt.ylabel("Rate")  # was mislabelled "Loss"; these curves are rates in [0, 1]
plt.legend()
plt.grid()
plt.show()

plt.plot(hammings, label='Hamming Loss')
plt.title("Hamming Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.show()

val_roc_aucs_by_epoch_np = np.array(val_roc_aucs_by_epoch)
epochs_range = range(0, len(val_roc_aucs_by_epoch_np))

# Per-class curves
# Only the "nothing recorded yet" case is handled explicitly (an empty history
# is 1-D and has no class axis). Any other failure now surfaces instead of
# being hidden behind a bare `except`.
if val_roc_aucs_by_epoch_np.ndim == 2:
    for i in range(val_roc_aucs_by_epoch_np.shape[1]):
        plt.plot(epochs_range, val_roc_aucs_by_epoch_np[:, i], label=class_names[i], alpha=0.5)

    # Mean ROC-AUC
    plt.plot(epochs_range, val_roc_auc_means, label='Mean ROC-AUC', color='black', linewidth=2.5)

    plt.title("Validation ROC-AUC per Class and Mean")
    plt.xlabel("Epoch")
    plt.ylabel("ROC-AUC")
    plt.grid(True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()
else:
    print("Empty variables")

In [None]:
# Re-run test-set inference with the fine-tuned model (evaluation mode, no
# gradients), keeping logits, probabilities, targets and input images.
model.eval()

all_logits, all_probs, all_targets, all_images = [], [], [], []

with torch.no_grad():
    for d in tqdm(test_dl, desc="Generating predictions"):
        imgs, labels = (d[key].to(device, dtype=torch.float) for key in ('images', 'labels'))

        logits = model(imgs)
        prob = torch.sigmoid(logits)

        # Move everything back to the CPU before storing it.
        for store, batch_tensor in (
            (all_logits, logits),
            (all_probs, prob),
            (all_targets, labels),
            (all_images, imgs),
        ):
            store.append(batch_tensor.cpu())

# Stack into single arrays (the images stay a torch tensor).
all_probs, all_targets, all_logits = (
    torch.cat(chunks).numpy() for chunks in (all_probs, all_targets, all_logits)
)
all_images = torch.cat(all_images)


In [None]:
# Per-class F1-optimal thresholds. This cell used to duplicate the body of
# optimise_thresholds() line by line; it now calls the shared helper.
# NOTE(review): the thresholds are tuned on the same test predictions that the
# following cells evaluate, which makes the reported test metrics optimistic.
best_thresholds = optimise_thresholds(all_targets, all_probs)

print("Best thresholds:", best_thresholds)

In [None]:
# Binarise the test probabilities column by column with the tuned thresholds
# (a prediction is positive when its probability is strictly above the threshold).
all_preds = np.zeros(all_probs.shape, dtype=int)
for class_idx in range(all_probs.shape[1]):
    all_preds[:, class_idx] = all_probs[:, class_idx] > best_thresholds[class_idx]

In [None]:
# Per-class ROC-AUC on the test set.
roc_aucs = []

for i in range(all_targets.shape[1]):
    try:
        # ROC-AUC is a ranking metric and must be fed the continuous scores.
        # Passing the thresholded 0/1 predictions (as before) collapses the
        # ROC curve to a single operating point and under-reports the AUC.
        auc = roc_auc_score(all_targets[:, i], all_probs[:, i])
    except ValueError:
        # Raised when the metric cannot be computed for this class; record NaN
        # so that the macro average below simply skips it.
        auc = float('nan')
    roc_aucs.append(auc)

print("\nPer-Class ROC-AUC:")
for i, auc in enumerate(roc_aucs):
    print(f"Class {i}: {auc:.4f}")

print(f"\nMacro ROC-AUC: {np.nanmean(roc_aucs):.4f}")

In [None]:
# Predictions obtained with the optimised per-class thresholds.
opti_preds = np.zeros_like(all_preds)

for i in range(all_probs.shape[1]):
    # Threshold the PROBABILITIES; the previous code compared the
    # already-binarised all_preds against the thresholds.
    opti_preds[:, i] = (all_probs[:, i] > best_thresholds[i]).astype(int)

In [None]:
def show_prediction(img_tensor, true_labels, pred_labels, opti_labels, class_names, mean=None, std=None):
    """Display one image together with its true and predicted label names.

    Args:
        img_tensor: image as a torch tensor or array-like; a 3-channel
            channel-first image (3, H, W) is transposed to (H, W, 3).
        true_labels: binary ground-truth vector (tensor or array).
        pred_labels: binary predictions shown as "Pred(std)".
        opti_labels: binary predictions shown as "Pred(Opti)".
        class_names: sequence mapping a label index to its display name.
        mean: optional per-channel mean used when the image was normalised.
        std: optional per-channel std used when the image was normalised.
            When both ``mean`` and ``std`` are given, the normalisation is
            undone before display; otherwise the values are only clipped to
            [0, 1] as before (which distorts normalised images).
    """
    # Convert image to numpy array if it's a tensor
    if isinstance(img_tensor, torch.Tensor):
        img = img_tensor.cpu().numpy()
    else:
        img = np.array(img_tensor)

    # Transpose if image has 3 channels (C, H, W) -> (H, W, C)
    if img.ndim == 3 and img.shape[0] == 3:
        img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C)

    # Undo the per-channel normalisation when its statistics were supplied;
    # (H, W, C) broadcasts against the length-C mean/std vectors.
    if mean is not None and std is not None:
        img = img * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)

    # Clip values to the displayable [0, 1] range
    img = np.clip(img, 0, 1)

    # Convert labels to numpy if they're tensors
    if isinstance(true_labels, torch.Tensor):
        true_labels = true_labels.cpu().numpy()
    if isinstance(pred_labels, torch.Tensor):
        pred_labels = pred_labels.cpu().numpy()
    if isinstance(opti_labels, torch.Tensor):
        opti_labels = opti_labels.cpu().numpy()

    # Get active labels
    true_txt = ", ".join([class_names[i] for i, v in enumerate(true_labels) if v == 1])
    pred_txt = ", ".join([class_names[i] for i, v in enumerate(pred_labels) if v == 1])
    opti_pred_txt = ", ".join([class_names[i] for i, v in enumerate(opti_labels) if v == 1])

    # Plotting
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"True: {true_txt}\nPred(std): {pred_txt}\nPred(Opti): {opti_pred_txt}", fontsize=10)
    plt.show()


# Visualise one randomly chosen test image with its true and predicted labels.
idx = random.randrange(len(all_images))
show_prediction(
    img_tensor=all_images[idx],
    true_labels=all_targets[idx],
    pred_labels=all_preds[idx],
    opti_labels=opti_preds[idx],
    class_names=class_names,
)

In [None]:
# Exact-match (subset) accuracy: a sample counts only when every label is right.
subset_acc = accuracy_score(y_true=all_targets, y_pred=all_preds)
subset_acc_percent = subset_acc * 100
print(f"Exact match accuracy: {subset_acc_percent:.4f}%")

In [None]:
# Manual snapshot of the current model / optimiser state.
# The target folder is created first: unlike "Metrics" and "Checkpoints",
# "Models" is not created anywhere else, so saving failed on a fresh checkout.
os.makedirs("Models", exist_ok=True)
torch.save({
            'model_state_dict': model.state_dict(),
            'optimiser_state_dict': opti.state_dict(),
            'epochs_trained': model.epochs_trained,
        },
          f"Models/{model.model_name}_manual.pth")

In [None]:
# Restore the best checkpoint into a fresh model.
model = DenseNet121(num_classes=14)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(fr'Checkpoints/{model.model_name}_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# The existing optimiser still references the parameters of the PREVIOUS model
# object, so stepping it would never update the model created above. Build a
# new optimiser over the new parameters, then restore its saved state (which
# also restores the learning rate stored in the checkpoint).
opti = torch.optim.AdamW(model.parameters(), lr=eta, weight_decay=0.01)
opti.load_state_dict(checkpoint['optimiser_state_dict'])

model.epochs_trained = checkpoint.get('epochs_trained', 0)
model.eval()
# Unfreeze the trailing blocks again (a freshly built model starts frozen).
model.fine_tune()

In [None]:
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt

class GradCAM:
    """Grad-CAM heat maps for one named layer of a model.

    A forward hook on the target layer stores its output (the activations) and
    attaches a tensor hook to that output, which stores the gradient flowing
    into it during ``backward``.

    Args:
        model: the network to explain.
        target_layer: name of the layer as listed by ``model.named_modules()``.

    Raises:
        ValueError: if no module with that name exists.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _store_gradient(self, grad):
        """Tensor hook: remember the gradient w.r.t. the target layer's output."""
        self.gradients = grad.detach()

    def _register_hooks(self):
        """Attach the forward hook to the target layer."""
        def forward_hook(module, input, output):
            self.activations = output.detach()
            # Capture the gradient with a tensor hook on the output instead of
            # the deprecated Module.register_backward_hook. An output that
            # does not require grad cannot carry a hook; generate_cam reports
            # that case explicitly.
            if output.requires_grad:
                output.register_hook(self._store_gradient)

        for name, module in self.model.named_modules():
            if name == self.target_layer:
                self.hook_handles.append(module.register_forward_hook(forward_hook))
                print(f"[GradCAM] Hooks registered on layer: {name}")
                break
        else:
            raise ValueError(f"Layer {self.target_layer} not found in model.")

    def generate_cam(self, input_tensor, target_class=None):
        """Compute the Grad-CAM map for the first sample of ``input_tensor``.

        Args:
            input_tensor: model input; with ``target_class=None`` it must hold
                exactly one sample because the arg-max is read via ``.item()``.
            target_class: output index to explain; defaults to the arg-max logit.

        Returns:
            numpy array with the ReLU-ed, min-max normalised ([0, 1]) map at
            the spatial resolution of the target layer.

        Raises:
            RuntimeError: if activations or gradients were not captured (for
                example when nothing upstream of the layer requires grad).
        """
        self.model.eval()
        self.gradients = None
        self.activations = None

        output = self.model(input_tensor)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        self.model.zero_grad()
        loss = output[0, target_class]
        # The graph is not reused afterwards, so it does not need to be retained.
        loss.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured. Check hook registration.")

        # Channel weights = spatially averaged gradients.
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        weighted_activations = weights * self.activations
        cam = weighted_activations.sum(dim=1).squeeze()
        cam = F.relu(cam)

        # Min-max normalise; an all-zero map is left untouched.
        cam -= cam.min()
        if cam.max() != 0:
            cam /= cam.max()

        return cam.cpu().numpy()

    def clear_hooks(self):
        """Remove every registered hook; safe to call more than once."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

def unnormalize(tensor, mean, std):
    """Undo a per-channel normalisation IN PLACE.

    Each slice along the first dimension is multiplied by its ``std`` entry
    and then shifted by its ``mean`` entry. Processing stops at the shortest
    of ``tensor``, ``mean`` and ``std``.

    Args:
        tensor: tensor whose first dimension indexes the channels; modified in place.
        mean: per-channel means.
        std: per-channel standard deviations.

    Returns:
        The same tensor object that was passed in.
    """
    for channel, (channel_mean, channel_std) in zip(tensor, zip(mean, std)):
        channel.mul_(channel_std)
        channel.add_(channel_mean)
    return tensor

def run_gradcam_and_save(model, val_dl, target_layer, mean, std, pathology_names=None, save_path="gradcam_overlay.png"):
    """Overlay a Grad-CAM heat map on the first image of the next ``val_dl`` batch.

    The overlay is written to ``save_path`` and shown with matplotlib; the
    labels of the chosen image are printed.

    Args:
        model: network to explain.
        val_dl: loader yielding dicts with ``'images'`` and ``'labels'``.
        target_layer: module name (as in ``model.named_modules()``) to hook.
        mean: per-channel mean used to normalise the images.
        std: per-channel std used to normalise the images.
        pathology_names: optional list mapping a label index to a name.
        save_path: file the overlay image is written to.
    """
    grad_cam = GradCAM(model, target_layer)

    # Always remove the hooks again, even if the forward/backward pass fails,
    # so that a failed run does not leave hooks attached to the model.
    try:
        sample = next(iter(val_dl))
        input_tensor = sample['images'][0].unsqueeze(0).to(next(model.parameters()).device)

        # Print pathology labels (supports multi-label or single-label)
        labels = sample['labels'][0]  # Assume shape [batch, num_classes] or [batch]

        if pathology_names:
            if labels.ndim == 1 and labels.numel() > 1:
                # Multi-label binary vector
                present = [pathology_names[i] for i, val in enumerate(labels) if val > 0]
                print("Pathologies in image:", ", ".join(present) if present else "None")
            elif labels.numel() == 1:
                # Single-label index; cast because a float label cannot index a list
                print("Pathology label:", pathology_names[int(labels.item())])
            else:
                print("Labels shape or type unexpected; raw labels:", labels)
        else:
            print("Labels:", labels)

        heatmap = grad_cam.generate_cam(input_tensor)
    finally:
        grad_cam.clear_hooks()

    # Recover a displayable uint8 image from the normalised tensor.
    original_img = sample['images'][0].clone()
    original_img = unnormalize(original_img, mean, std)
    original_img = torch.clamp(original_img, 0, 1)
    original_img_np = original_img.permute(1, 2, 0).cpu().numpy()
    original_img_np = np.ascontiguousarray(np.uint8(255 * original_img_np))
    # The tensor is RGB while OpenCV (applyColorMap, imwrite) works in BGR:
    # convert before blending so both layers share one channel order.
    original_img_bgr = cv2.cvtColor(original_img_np, cv2.COLOR_RGB2BGR)

    heatmap_resized = cv2.resize(heatmap, (original_img_bgr.shape[1], original_img_bgr.shape[0]))

    heatmap_uint8 = np.uint8(255 * heatmap_resized)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)

    overlayed_img = cv2.addWeighted(original_img_bgr, 0.6, heatmap_color, 0.4, 0)

    cv2.imwrite(save_path, overlayed_img)
    print(f"Grad-CAM overlay saved to {save_path}")

    # matplotlib expects RGB, so convert the BGR overlay back for display.
    plt.imshow(cv2.cvtColor(overlayed_img, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()


In [None]:
# Normalisation statistics used by the transforms (needed to undo them for display).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Explain the last convolution of the final dense block on one validation image.
run_gradcam_and_save(
    model=model,
    val_dl=val_dl,
    target_layer="features.denseblock4.denselayer16.conv2",
    mean=mean,
    std=std,
    pathology_names=class_names,
)