In [None]:
# ── Standard Library ───────────────────────────────────────────────────────────
import os
import csv
import time
import glob
from pathlib import Path
from itertools import cycle

# ── Scientific & Data Handling ────────────────────────────────────────────────
import numpy as np
from numpy import asarray
import pandas as pd

# ── PyTorch & Training Utilities ──────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import timm
from tqdm import tqdm

# ── Computer Vision & Augmentations ───────────────────────────────────────────
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
from PIL import Image
from pydicom import dcmread

# ── Scikit-Learn & Statistics ────────────────────────────────────────────────
from sklearn.metrics import (
    roc_auc_score, auc, roc_curve, average_precision_score
)
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils.multiclass import type_of_target
from sklearn.utils import shuffle
from sklearn.preprocessing import label_binarize
import sklearn.metrics as metrics

from scipy import interp
from scipy.stats import ttest_rel, ttest_ind

# ── Visualization ─────────────────────────────────────────────────────────────
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import gridspec

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ── Grad-CAM & Explainability ────────────────────────────────────────────────
from pytorch_grad_cam import (
    GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, FullGrad
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import \
    DeepFeatureFactorization
import pytorch_grad_cam.utils.model_targets
import pytorch_grad_cam.utils.reshape_transforms
import pytorch_grad_cam.metrics.cam_mult_image
import pytorch_grad_cam.metrics.road

In [None]:
# Switch the working directory first: every relative path below is resolved from here.
os.chdir("Dataset Directory")
model_save_dir = "Directory_Model"
visualization_dir = "visualization_dir"

# Read the sample table and attach a running integer id (0 .. n-1) to each row.
df = pd.read_csv("PLOSONE_DF-Stages").assign(
    image_id=lambda table: np.arange(len(table))
)
len(df)

# MODEL

In [None]:
# Ordered class names; a name's position in this list is its integer class index.
class_names = ['class_0', 'class_1', 'class_2', 'class_3', 'class_4', 'class_5']
# Reverse lookup: class name -> integer index.
class_labels = {name: i for i, name in enumerate(class_names)}

N_EPOCHS = 30      # upper bound on epochs per training run
BATCH_SIZE = 20    # samples per DataLoader batch
SIZE = 496         # images are resized to SIZE x SIZE
# Number of classes as an integer (previously this was bound to the
# name->index dict itself, which is not usable as a count).
num_classes = len(class_names)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Training pipeline: random shift/scale/rotate (applied with probability 0.8,
# rotation limit 1.0, constant border), then resize and convert to a tensor.
# NOTE(review): none of the three pipelines contains an intensity
# normalisation step -- confirm that raw pixel values are what the model expects.
transforms_train = A.Compose(
    [
        A.ShiftScaleRotate(p=0.8, rotate_limit=1.0, border_mode=0),
        A.Resize(width=SIZE, height=SIZE, p=1.0),
        ToTensorV2(p=1.0),
    ]
)

# Validation pipeline: deterministic resize + tensor conversion only.
transforms_valid = A.Compose(
    [
        A.Resize(width=SIZE, height=SIZE, p=1.0),
        ToTensorV2(p=1.0),
    ]
)

# Test pipeline: same steps as validation.
transforms_test = A.Compose(
    [
        A.Resize(width=SIZE, height=SIZE, p=1.0),
        ToTensorV2(p=1.0),
    ]
)

In [None]:
class OctModel(nn.Module):
    """timm ``resnet18`` followed by a linear layer producing class logits.

    The backbone is created with ``pretrained=True`` and ``in_chans=3``; its
    1000-dimensional output is fed to ``self.logit``, which maps it to
    ``num_classes`` values.

    Args:
        num_classes: size of the output vector (default 6).
    """

    def __init__(self, num_classes=6):
        super().__init__()
        backbone = timm.create_model('resnet18', pretrained=True, in_chans=3)
        self.model = backbone
        self.logit = nn.Linear(in_features=1000, out_features=num_classes)

    def forward(self, x):
        """Return the logits for input batch ``x``."""
        return self.logit(self.model(x))

# Choose the compute device and report which one was selected.
if not torch.cuda.is_available():
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
else:
    print("CUDA is available. Using GPU.")
    device = torch.device("cuda:0")

# Build the 6-class network and move it onto the selected device.
model = OctModel(num_classes=6).to(device)

In [None]:
class OctDataset(torch.utils.data.Dataset):
    """Dataset over a DataFrame with 'Path', 'image_id' and the ``class_names`` columns.

    Each item is ``(image, labels, sample_id)``:
      * ``image``  - the file at ``Path`` opened as RGB; either the ``'image'``
        entry returned by ``transforms`` or, without transforms, a float32
        channel-first tensor scaled by 1/255.
      * ``labels`` - float32 tensor built from the ``class_names`` columns.
      * ``sample_id`` - the row's ``image_id`` value.

    Args:
        df: source table; its index is reset so positions run 0 .. len-1.
        transforms: optional callable invoked as ``transforms(image=ndarray)``
            and expected to return a mapping with an ``'image'`` key.
    """

    def __init__(self, df, transforms=None):
        self.transforms = transforms
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        pil_image = Image.open(row['Path']).convert('RGB')
        labels = torch.from_numpy(row[class_names].values.astype(np.float32))
        sample_id = row['image_id']

        if not self.transforms:
            # No pipeline supplied: scale to [0, 1] and move channels first.
            scaled = np.array(pil_image, dtype=np.float32) / 255.0
            image = torch.tensor(scaled, dtype=torch.float32).permute(2, 0, 1)
        else:
            image = self.transforms(image=np.array(pil_image))['image']

        return image, labels, sample_id

In [None]:
def reset_weights(m):
    """Re-initialise a single module in place; intended for ``model.apply(reset_weights)``.

    Every module that exposes ``reset_parameters`` is reset, not only
    ``nn.Conv2d`` / ``nn.Linear``. This includes normalisation layers, whose
    affine parameters and running statistics would otherwise survive the reset
    and carry state from one training run into the next. Modules without
    ``reset_parameters`` (activations, pooling, containers) are left untouched.

    Args:
        m: the module to reset.
    """
    reset = getattr(m, "reset_parameters", None)
    if callable(reset):
        reset()

In [None]:
def apply_gradcam_and_overlay(image, model_output, cam):
    """Blend a class-activation heatmap over ``image`` and return the result.

    Args:
        image: torch tensor that is transposed with ``(1, 2, 0)``, i.e. it is
            treated as channel-first.
        model_output: model scores; a 1-D tensor is promoted to a batch of one.
        cam: object providing ``generate_heatmap(image, predicted_class)``.
            NOTE(review): the heatmap is multiplied by 255 and cast to uint8,
            so it is presumably expected in [0, 1] and to match the image's
            spatial size -- confirm against the ``cam`` implementation.

    Returns:
        uint8 array: 50/50 weighted sum of the min-max scaled image and the
        JET-coloured heatmap.
        NOTE(review): ``cv2.applyColorMap`` produces BGR channel order; confirm
        the channel order callers expect.
    """
    # Ensure model_output is 2D: [batch_size, num_classes]
    if model_output.dim() == 1:
        model_output = model_output.unsqueeze(0)

    # The class with the highest score drives the heatmap.
    _, predicted_class = model_output.max(dim=1)

    heatmap = cam.generate_heatmap(image, predicted_class)

    # Scale to 8 bit and colourise for visualisation.
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # detach() so a tensor that still tracks gradients can be converted.
    image_np = image.detach().cpu().numpy().transpose(1, 2, 0)

    # Min-max scale to 0..255. A constant image has zero range; map it to all
    # zeros instead of dividing by zero (which would yield NaNs before the cast).
    lowest = np.min(image_np)
    value_range = np.max(image_np) - lowest
    if value_range > 0:
        image_np = (image_np - lowest) / value_range * 255
    else:
        image_np = np.zeros_like(image_np)
    image_np = image_np.astype(np.uint8)

    # Overlay the heatmap on the original image
    overlayed_img = cv2.addWeighted(image_np, 0.5, heatmap, 0.5, 0)

    return overlayed_img

In [None]:
def train_one_fold(v, t, model, criterion, optimizer, dataloader_train, dataloader_valid, model_save_dir, patience=10):
    """Train ``model`` for up to ``N_EPOCHS`` epochs with early stopping on validation macro AP.

    After every epoch the macro average-precision score is computed on
    ``dataloader_valid``. Whenever it improves, the model's ``state_dict`` is
    written to ``<model_save_dir>/Model{v}_T{t}Stage.pth``. Training stops once
    ``patience`` consecutive epochs bring no improvement.

    Args:
        v: validation-fold identifier (stored under 'VAL', used in the file name).
        t: test-fold identifier (stored under 'TEST', used in the file name).
        model: network to train; moved between train/eval mode in place.
        criterion: loss called as ``criterion(outputs, labels.squeeze(-1))``.
        optimizer: optimizer stepping ``model``'s parameters.
        dataloader_train: yields batches whose items 0 and 1 are images and labels.
        dataloader_valid: same batch layout, used for validation.
        model_save_dir: directory for the checkpoint; created if missing.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        ``(val_fold_results, best_fold_ap, best_model_path)`` where
        ``val_fold_results`` holds one dict per completed epoch (including the
        epoch that triggers early stopping), ``best_fold_ap`` is the best
        validation AP seen and ``best_model_path`` is the checkpoint path, or
        ``None`` if no epoch ran.
    """
    # torch.save does not create missing directories.
    os.makedirs(model_save_dir, exist_ok=True)

    val_fold_results = []
    best_fold_ap = 0
    best_model_path = None  # stays None only when N_EPOCHS == 0
    epochs_no_improve = 0

    for epoch in range(N_EPOCHS):
        print(f'  Epoch {epoch + 1}/{N_EPOCHS}')

        # ---- TRAINING PHASE ----
        model.train()
        tr_loss = 0

        for batch in dataloader_train:
            images, labels = batch[0].to(device).float(), batch[1].to(device).float()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.squeeze(-1))
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        # ---- VALIDATION PHASE ----
        model.eval()
        val_loss = 0
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in dataloader_valid:
                images, labels = batch[0].to(device).float(), batch[1].to(device).float()

                outputs = model(images)
                loss = criterion(outputs, labels.squeeze(-1))
                val_loss += loss.item()
                val_preds.extend(outputs.detach().cpu().numpy())
                val_labels.extend(labels.detach().cpu().numpy())

        # Compute Average Precision Score (Macro) on the raw model outputs.
        val_preds = np.array(val_preds)
        val_labels = np.array(val_labels)
        ap_score_current = average_precision_score(val_labels, val_preds, average="macro")

        print(f"AP Score Current Validation (Macro): {ap_score_current}")

        # Record the epoch before any early-stop decision so the final epoch
        # is never dropped from the history.
        val_fold_results.append({
            'TEST': t,
            'VAL': v,
            'epoch': epoch,
            'train_loss': tr_loss / len(dataloader_train),
            'valid_loss': val_loss / len(dataloader_valid),
            'valid_ap_score': ap_score_current
        })

        # Save the best model based on AP Score. The first epoch is always
        # saved so that a checkpoint exists even if the score never exceeds
        # the initial value of 0.
        if best_model_path is None or ap_score_current > best_fold_ap:
            best_fold_ap = max(best_fold_ap, ap_score_current)
            epochs_no_improve = 0
            best_model_path = os.path.join(model_save_dir, f"Model{v}_T{t}Stage.pth")
            torch.save(model.state_dict(), best_model_path)
            print(f"Updated Best Model Saved at: {best_model_path}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch: {epoch + 1}")
            break

    return val_fold_results, best_fold_ap, best_model_path

In [None]:
# Nested cross-validation driver.
# For every (test_fold, val_fold) pair with test_fold != val_fold the model is
# reset, trained on the remaining folds, the best validation checkpoint is
# restored, and the held-out test fold is scored and visualised with Grad-CAM.
# One DataFrame of test probabilities per pair is collected in
# `aggregated_results`.
os.makedirs(visualization_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)
model = OctModel(num_classes=6).to(device)

aggregated_results = []
for test_fold in range(5):
    for val_fold in range(5):
        if test_fold == val_fold:
            continue

        aggregated_test_targets = []
        aggregated_test_probabilities = []
        aggregated_test_ids = []  # To store the IDs of the test images

        # NOTE(review): OctModel is built with pretrained=True, and this call
        # re-initialises its layers, so the pretrained weights are discarded
        # before training -- confirm training from scratch is intended.
        model.apply(reset_weights)  # Reset model weights

        # Prepare dataset splits
        val_df = df[df['fold'] == val_fold].reset_index(drop=True)
        train_df = df[(df['fold'] != test_fold) & (df['fold'] != val_fold)].reset_index(drop=True)
        test_df = df[df['fold'] == test_fold].reset_index(drop=True)

        # Prepare dataloaders
        dataset_train = OctDataset(df=train_df, transforms=transforms_train)
        dataset_valid = OctDataset(df=val_df, transforms=transforms_valid)
        dataset_test = OctDataset(df=test_df, transforms=transforms_test)

        dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        dataloader_valid = DataLoader(dataset_valid, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
        dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

        # Training setup
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=5e-5)

        # Train and validate. Arguments are passed by keyword: train_one_fold's
        # first two parameters are (v, t) = (validation fold, test fold), and the
        # previous positional call handed them over in the opposite order, which
        # mislabelled both the checkpoint file and the per-epoch records.
        # `best_fold_roc_auc` keeps its historical name but holds the best
        # validation average-precision score returned by train_one_fold.
        val_fold_results, best_fold_roc_auc, best_model_path = train_one_fold(
            v=val_fold,
            t=test_fold,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            dataloader_train=dataloader_train,
            dataloader_valid=dataloader_valid,
            model_save_dir=model_save_dir,
            patience=10,
        )

        # Evaluate the best checkpoint rather than whatever weights the final
        # (possibly worse) epoch left in memory.
        if best_model_path is not None and os.path.isfile(best_model_path):
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        model.eval()

        fold_test_targets = []
        fold_test_probabilities = []
        fold_test_ids = []  # To store test IDs for the fold

        fold_tag = f"test{test_fold}_val{val_fold}"
        n_visualized = 0  # running image counter; unique within this fold pair

        # GradCAM attaches forward hooks to the target layer that store a copy
        # of its activations on every forward pass. It is therefore created
        # only for the test pass and its hooks are released afterwards, so they
        # never fire (and accumulate memory) during training.
        cam = GradCAM(model=model, target_layers=[model.model.layer4[-1].conv2])
        try:
            for batch in dataloader_test:
                images = batch[0].to(device).float()
                labels = batch[1].to(device).float()
                ids = batch[2]

                # Tensors are already float and on `device`; no CUDA-only cast,
                # so the loop also runs on CPU.
                with torch.no_grad():
                    outputs = model(images)
                    probabilities = torch.softmax(outputs, dim=1).cpu().numpy()

                fold_test_targets.extend(labels.cpu().numpy())
                fold_test_probabilities.extend(probabilities.tolist())
                fold_test_ids.extend(ids.numpy())

                for img, output, label in zip(images, outputs, labels):
                    n_visualized += 1

                    # Define true and predicted labels
                    true_label_idx = label.argmax().item()
                    pred_class_idx = output.argmax().item()

                    # Grad-CAM needs gradients, so it runs outside no_grad.
                    input_tensor = img.unsqueeze(0)
                    target = ClassifierOutputTarget(pred_class_idx)
                    cam_mask = cam(input_tensor=input_tensor, targets=[target])
                    cam_mask = cam_mask[0, :]

                    # Min-max scale the image to [0, 1] for display; a constant
                    # image has zero range and is shown as all zeros.
                    img_np = img.cpu().numpy().transpose(1, 2, 0)
                    value_range = img_np.max() - img_np.min()
                    if value_range > 0:
                        img_normalized = (img_np - img_np.min()) / value_range
                    else:
                        img_normalized = np.zeros_like(img_np)

                    # Apply the CAM mask to the image
                    visualization = show_cam_on_image(img_normalized, cam_mask, use_rgb=True)

                    # Plot the original and Grad-CAM visualizations side by side
                    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

                    axs[0].imshow(img_normalized, cmap='gray')
                    axs[0].set_title(f"Original Image\nTrue: {true_label_idx}")
                    axs[0].axis('off')

                    axs[1].imshow(visualization)
                    axs[1].set_title(f"Predicted: {pred_class_idx}, True: {true_label_idx}")
                    axs[1].axis('off')

                    # The fold tag keeps different fold pairs from overwriting
                    # each other; the running counter avoids the index
                    # collisions caused by a shorter last batch.
                    image_filename = (
                        f"{fold_tag}_image_{n_visualized}"
                        f"_true_{true_label_idx}_pred_{pred_class_idx}.png"
                    )
                    fig.savefig(os.path.join(visualization_dir, image_filename))
                    plt.close(fig)
        finally:
            # Detach the Grad-CAM hooks before the next training run.
            cam.activations_and_grads.release()

        print("Fold completed with ROC-AUC:", roc_auc_score(np.stack(fold_test_targets), np.stack(fold_test_probabilities)))

        aggregated_test_targets.extend(fold_test_targets)
        aggregated_test_probabilities.extend(fold_test_probabilities)
        aggregated_test_ids.extend(fold_test_ids)

        sub = pd.DataFrame(aggregated_test_probabilities, columns=["pred0", "pred1", "pred2","pred3","pred4","pred5"])
        sub["image_id"] = fold_test_ids
        sub["test_fold"] = test_fold
        sub["val_fold"] = val_fold
        aggregated_results.append(sub)

AP Score Current Validation (Macro): 0.46544990044752277
Early stopping triggered at epoch: 28
Fold completed with ROC-AUC: 0.9442127786265396
  Epoch 1/30
AP Score Current Validation (Macro): 0.40194245081023056
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T0Stages6-UE100e-PAPERFINALEE.pth
  Epoch 2/30
AP Score Current Validation (Macro): 0.41446235073693516
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T0Stages6-UE100e-PAPERFINALEE.pth
  Epoch 3/30
AP Score Current Validation (Macro): 0.41502886817967033
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T0Stages6-UE100e-PAPERFINALEE.pth
  Epoch 4/30
AP Score Current Validation (Macro): 0.442364623680959
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T0Stages6-UE100e-PAPERFINALEE.pth
  Epoch 5/30
AP Score Current Validation (Macro): 0.4493133397196188
Updated Best Model Saved at: D:/weight

AP Score Current Validation (Macro): 0.44227589526618444
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T3Stages6-UE100e-PAPERFINALEE.pth
  Epoch 4/30
AP Score Current Validation (Macro): 0.44971257459231634
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T3Stages6-UE100e-PAPERFINALEE.pth
  Epoch 5/30
AP Score Current Validation (Macro): 0.46537312874450404
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T3Stages6-UE100e-PAPERFINALEE.pth
  Epoch 6/30
AP Score Current Validation (Macro): 0.4255749354209137
  Epoch 7/30
AP Score Current Validation (Macro): 0.4469133951894683
  Epoch 8/30
AP Score Current Validation (Macro): 0.4396246028477239
  Epoch 9/30
AP Score Current Validation (Macro): 0.4451012712945613
  Epoch 10/30
AP Score Current Validation (Macro): 0.47167884483108796
Updated Best Model Saved at: D:/weights/SKEVAS/REG/SKEVAS_REG_GRAD_OCT-GRAD_fold-V4_T3Stages6-UE100e-P