In [None]:
# ── Standard Library ───────────────────────────────────────────────────────────
import os
import csv
import time
import glob
from pathlib import Path
from itertools import cycle

# ── Data Handling & Numerics ───────────────────────────────────────────────────
import numpy as np
from numpy import asarray
import pandas as pd

# ── PyTorch Ecosystem ──────────────────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import timm
from tqdm import tqdm

# ── Computer Vision & Augmentations ────────────────────────────────────────────
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
from PIL import Image
from pydicom import dcmread

# ── Scikit-Learn & Statistics ─────────────────────────────────────────────────
import sklearn
from sklearn.metrics import (
    roc_auc_score, auc, roc_curve, average_precision_score
)
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils.multiclass import type_of_target
from sklearn.utils import shuffle
from sklearn.preprocessing import label_binarize
import sklearn.metrics as metrics

from scipy import interp
from scipy.stats import ttest_rel, ttest_ind

# ── Visualization ─────────────────────────────────────────────────────────────
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import gridspec

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ── Grad-CAM & Explainability ─────────────────────────────────────────────────
from pytorch_grad_cam import (
    GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, FullGrad
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import \
    DeepFeatureFactorization
import pytorch_grad_cam.utils.model_targets
import pytorch_grad_cam.utils.reshape_transforms
import pytorch_grad_cam.metrics.cam_mult_image
import pytorch_grad_cam.metrics.road


In [None]:
# "Dataset Directory" is a placeholder: point it at the folder holding the CSV and DICOMs.
os.chdir("Dataset Directory")
visualization_dir = "visualization_dir"

# Single source of truth for where train_one_fold writes checkpoints.
model_save_dir = "model_checkpoints"
os.makedirs(model_save_dir, exist_ok=True)

df = pd.read_csv("PLOSONE_DF-MacularStatus.csv")
# Sequential integer id per row; OctDataset returns it as each sample's id.
df["image_id"] = np.arange(len(df))
len(df)

# MODEL

In [None]:
class_names = ['class_0', 'class_1', 'class_2']
# Label column name -> integer class index (its position in class_names).
class_labels = {name: i for i, name in enumerate(class_names)}

N_EPOCHS = 20
BATCH_SIZE = 6
# Number of output classes as an int (not the class_labels mapping).
num_classes = len(class_names)

# Seed the global RNGs used for weight initialisation and DataLoader shuffling
# so repeated runs are comparable (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Pick the compute device: first CUDA GPU when present, CPU otherwise,
# and report the choice.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(
    "CUDA is available. Using GPU."
    if device.type == "cuda"
    else "CUDA is not available. Using CPU."
)

class OctModel3D(nn.Module):
    """torchvision 3D ResNet-18 (``r3d_18``) adapted to single-channel OCT volumes.

    Args:
        num_classes: Output size of the final linear layer (logits).
        pretrained: Load torchvision's pretrained r3d_18 weights (default True,
            as before). Pass False to build the architecture without
            downloading weights, e.g. for offline smoke tests.

    ``forward`` expects input of shape (batch, 1, depth, height, width).
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        # Imported here because the notebook's import cell never brings r3d_18
        # into scope; the bare name raised NameError.
        from torchvision.models.video import r3d_18

        try:
            # torchvision >= 0.13 exposes a weights enum; `pretrained=` is deprecated there.
            from torchvision.models.video import R3D_18_Weights
            self.model = r3d_18(weights=R3D_18_Weights.DEFAULT if pretrained else None)
        except ImportError:
            # Older torchvision only understands the boolean flag.
            self.model = r3d_18(pretrained=pretrained)

        # Replace the 3-channel stem convolution with a 1-channel one of the same geometry.
        old_stem = self.model.stem[0]
        new_stem = nn.Conv3d(
            in_channels=1,
            out_channels=old_stem.out_channels,
            kernel_size=old_stem.kernel_size,
            stride=old_stem.stride,
            padding=old_stem.padding,
            bias=False,
        )
        if pretrained:
            # Keep the pretrained stem filters instead of discarding them: summing over
            # the RGB input channels reproduces the response to a grayscale volume that
            # was replicated across three channels.
            with torch.no_grad():
                new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
        self.model.stem[0] = new_stem

        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, 1, D, H, W)."""
        return self.model(x)

# Build one classifier with an output per entry of class_names, place it on the
# selected device and switch it to inference mode. (The cross-validation cell
# below constructs its own fresh model for every split.)
model = OctModel3D(num_classes=len(class_names)).to(device)
model.eval()

In [None]:
def _instance_number(ds):
    """Return the integer DICOM InstanceNumber of ``ds``, or None if absent/unparseable."""
    try:
        return int(ds.InstanceNumber)
    except (AttributeError, TypeError, ValueError):
        return None


def _read_dicom_slices(dicom_dir):
    """Read pixel arrays from one DICOM file, or from every ``*.dcm`` file in a folder.

    Folder slices are ordered by InstanceNumber when every file carries one,
    otherwise by filename (lexicographic).
    """
    if os.path.isfile(dicom_dir):
        return [dcmread(dicom_dir).pixel_array]

    datasets = [
        dcmread(os.path.join(dicom_dir, filename))
        for filename in sorted(os.listdir(dicom_dir))
        if filename.endswith('.dcm')
    ]
    # Lexicographic names misorder slices ("10.dcm" sorts before "2.dcm");
    # InstanceNumber is the slice order recorded in the files themselves.
    numbers = [_instance_number(ds) for ds in datasets]
    if datasets and None not in numbers:
        order = sorted(range(len(datasets)), key=numbers.__getitem__)
        datasets = [datasets[i] for i in order]
    return [ds.pixel_array for ds in datasets]


def load_dicom_volume(dicom_dir, target_shape=(64, 256, 256), cache_dir=None, use_gpu=True):
    """
    Load a DICOM file or a folder of DICOM slices as a normalized, resized volume.

    Args:
        dicom_dir: Path to a single DICOM file (possibly multi-frame) or to a folder
            containing ``*.dcm`` slice files.
        target_shape: (depth, height, width) reached with trilinear interpolation.
        cache_dir: Optional folder in which the resized tensor is cached as
            ``<sha256>.pt``, keyed on the absolute path and ``target_shape``.
            NOTE: the cache is not invalidated when the source DICOMs change.
        use_gpu: Do the resize on CUDA when available; the result is moved back to CPU.
            NOTE(review): this runs inside Dataset.__getitem__, so it is only safe with
            DataLoader num_workers=0 (what this notebook uses).

    Returns:
        float32 tensor of shape ``(1, *target_shape)``, min-max scaled to [0, 1].

    Raises:
        ValueError: no ``.dcm`` slices found, or an unexpected shape before/after resizing.
    """
    cache_path = None
    if cache_dir is not None:
        import hashlib  # only needed when caching is enabled
        key_source = f"{os.path.abspath(dicom_dir)}|{tuple(target_shape)}"
        cache_name = hashlib.sha256(key_source.encode("utf-8")).hexdigest() + ".pt"
        cache_path = os.path.join(cache_dir, cache_name)
        if os.path.isfile(cache_path):
            return torch.load(cache_path)

    slices = _read_dicom_slices(dicom_dir)
    if not slices:
        raise ValueError(f"No valid DICOM slices found in: {dicom_dir}")

    # Stack slices => shape [Depth, Height, Width] (or [N, F, H, W] for multi-frame inputs)
    volume = np.stack(slices, axis=0).astype(np.float32)

    # Min-max normalize to [0, 1]; epsilon avoids division by zero on constant volumes.
    vmin, vmax = volume.min(), volume.max()
    volume = (volume - vmin) / (vmax - vmin + 1e-8)

    volume = torch.from_numpy(volume)

    # Add batch/channel dims => shape [1, 1, D, H, W]
    if volume.ndim == 3:
        volume = volume.unsqueeze(0).unsqueeze(0)
    elif volume.ndim == 4:
        volume = volume.unsqueeze(0)
    else:
        raise ValueError(f"Unexpected volume shape {volume.shape}")

    # Resize to target_shape => [1, 1, *target_shape]
    if use_gpu and torch.cuda.is_available():
        volume = volume.cuda()
        volume = F.interpolate(volume, size=target_shape, mode="trilinear", align_corners=False)
        volume = volume.cpu()
    else:
        volume = F.interpolate(volume, size=target_shape, mode="trilinear", align_corners=False)

    # Remove batch dimension => shape [1, D, H, W]
    volume = volume.squeeze(0)

    if volume.shape != (1, *target_shape):
        raise ValueError(f"Unexpected shape after resizing: {volume.shape}, expected: (1, {target_shape})")

    if cache_path is not None:
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(volume, cache_path)

    return volume

class OctDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame of OCT volumes.

    Each item is ``(volume, labels, sample_id)``:
        volume    -- tensor ``(1, *target_shape)`` from ``load_dicom_volume(row['path3D'], ...)``
        labels    -- float32 tensor of the row's ``class_names`` column values
        sample_id -- the row's ``'image_id'`` value

    Args:
        df: DataFrame with ``'path3D'``, ``'image_id'`` and every column in ``class_names``.
            Its index is reset so positions 0..len-1 address the rows.
        class_names: Label column names; their order defines the class index.
        target_shape, cache_dir, use_gpu: Forwarded to ``load_dicom_volume``.
    """

    def __init__(self, df, class_names, target_shape=(64, 256, 256), cache_dir=None, use_gpu=True):
        self.df = df.reset_index(drop=True)
        self.class_names = class_names
        self.target_shape = target_shape
        self.cache_dir = cache_dir
        self.use_gpu = use_gpu

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # one row lookup instead of three
        dicom_dir = row['path3D']
        try:
            volume = load_dicom_volume(
                dicom_dir,
                target_shape=self.target_shape,
                cache_dir=self.cache_dir,
                use_gpu=self.use_gpu,
            )
        except ValueError as e:
            # Chain the cause so the traceback still shows the underlying loader error.
            raise RuntimeError(f"Error loading DICOM at index {idx}: {e}") from e

        # volume => shape [1, D, H, W]
        if volume.ndim != 4:
            raise ValueError(f"Volume shape is {volume.shape}, expected 4D (1, D, H, W)")

        labels = torch.from_numpy(row[self.class_names].values.astype(np.float32))
        return volume, labels, row['image_id']


In [None]:
def reset_weights(m):
    """Re-initialise ``m`` in place when it is an ``nn.Conv2d`` or ``nn.Linear``; otherwise no-op.

    NOTE(review): the next cell redefines ``reset_weights``, so this version is
    shadowed once that cell runs. It also leaves ``nn.Conv3d`` layers (used by
    the 3D model) untouched.
    """
    resettable_types = (nn.Conv2d, nn.Linear)
    if not isinstance(m, resettable_types):
        return
    m.reset_parameters()

In [None]:
def reset_weights(m):
    """Call ``m.reset_parameters()`` if the module defines it; otherwise do nothing.

    Meant for ``model.apply(reset_weights)``. Every layer exposing
    ``reset_parameters`` (convolutions, linear, batch-norm, ...) is re-initialised in place.
    """
    if not hasattr(m, 'reset_parameters'):
        return
    m.reset_parameters()

def get_current_time():
    """Return the local wall-clock time as ``'YYYY-MM-DD HH:MM:SS'`` (used as a log prefix)."""
    # Imported locally: the notebook's import cell never imports `datetime`,
    # so the bare name raised NameError on first use.
    from datetime import datetime
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

##################################################
# TRAIN FUNCTION
##################################################
def _labels_to_indices(labels):
    """Convert (N, C) one-hot/per-class label rows to class indices (N,); pass 1-D labels through."""
    if labels.ndim == 2 and labels.shape[1] > 1:
        labels = torch.argmax(labels, dim=1)
    return labels.long()


def _macro_average_precision(logit_batches, target_batches):
    """One-vs-rest macro average precision from per-batch logits (B, C) and integer targets (B,).

    Returns 0.0 when there are no batches, when sklearn raises ValueError, or when the
    score is NaN. A degenerate validation fold therefore cannot block checkpointing.
    """
    if not logit_batches:
        return 0.0
    logits = torch.cat(logit_batches)
    targets = torch.cat(target_batches).numpy()
    probs = F.softmax(logits, dim=1).numpy()
    # One-hot targets so sklearn scores each class against the rest.
    one_hot = np.eye(logits.shape[1], dtype=np.float32)[targets]
    try:
        score = float(average_precision_score(one_hot, probs, average="macro"))
    except ValueError:
        return 0.0
    return 0.0 if np.isnan(score) else score


def _plot_epoch_times(epoch_times, v, t):
    """Plot per-epoch wall-clock time for one fold pair, then release the figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(epoch_times) + 1), epoch_times, marker='o')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (seconds)")
    ax.set_title(f"Epoch Timing - Fold {v}-{t}")
    ax.grid(True)
    plt.show()
    plt.close(fig)  # many folds are trained in one session; free each figure


def train_one_fold(v, t, model, criterion, optimizer,
                   dataloader_train, dataloader_valid,
                   model_save_dir, patience=10, n_epochs=None, device=None):
    """Train ``model`` on one split, early-stopping on validation macro average precision.

    Args:
        v: Validation fold id. Stored as ``'VAL'`` and used as ``V`` in the checkpoint name.
        t: Test fold id. Stored as ``'TEST'`` and used as ``T`` in the checkpoint name.
        model: Classifier returning logits of shape (batch, n_classes).
        criterion: Loss taking (logits, integer class targets), e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader_train, dataloader_valid: Loaders yielding ``(images, labels, ...)``.
            Labels may be class indices (N,) or one-hot/per-class rows (N, C).
        model_save_dir: Directory for the best checkpoint (created if missing).
        patience: Consecutive epochs without improvement before stopping early.
        n_epochs: Maximum number of epochs; defaults to the notebook-level ``N_EPOCHS``.
        device: Device batches are moved to; defaults to the device of ``model``'s parameters.

    Returns:
        (per-epoch record dicts, best validation macro AP, path of the best checkpoint).
        A checkpoint is written in the first epoch, so the path is not None whenever
        at least one epoch runs.
    """
    if n_epochs is None:
        n_epochs = N_EPOCHS
    if device is None:
        device = next(model.parameters()).device
    os.makedirs(model_save_dir, exist_ok=True)

    val_fold_results = []
    # -inf rather than 0.0 so the first epoch always checkpoints. Otherwise a fold
    # whose AP never exceeds 0.0 would return best_model_path=None.
    best_fold_avg_precision = float("-inf")
    epochs_no_improve = 0
    best_model_path = None
    epoch_times = []

    for epoch in range(n_epochs):
        start_time = time.time()
        print(f"[{get_current_time()}] Epoch {epoch + 1}/{n_epochs}")

        # ---------------- training ----------------
        model.train()
        tr_loss = 0.0
        for batch in dataloader_train:
            images = batch[0].float().to(device)
            labels = _labels_to_indices(batch[1]).to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        # ---------------- validation ----------------
        model.eval()
        val_loss = 0.0
        val_logits, val_targets = [], []
        with torch.no_grad():
            for batch in dataloader_valid:
                images = batch[0].float().to(device)
                labels = _labels_to_indices(batch[1]).to(device)

                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                val_logits.append(outputs.float().cpu())
                val_targets.append(labels.cpu())

        avg_precision_current = _macro_average_precision(val_logits, val_targets)
        print(f"[{get_current_time()}] Val Avg Precision (macro): {avg_precision_current:.4f}")

        if avg_precision_current > best_fold_avg_precision:
            best_fold_avg_precision = avg_precision_current
            epochs_no_improve = 0
            best_model_path = os.path.join(
                model_save_dir,
                f"ModelV{v}_T{t}M-Status.pth"
            )
            torch.save(model.state_dict(), best_model_path)
            print(f"[{get_current_time()}] Best Model Saved at: {best_model_path}")
        else:
            epochs_no_improve += 1

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)
        print(f"[{get_current_time()}] Time for Epoch {epoch + 1}: {epoch_time:.2f} seconds")

        # Record the epoch *before* the early-stopping check so the final epoch is kept.
        val_fold_results.append({
            'TEST': t,
            'VAL': v,
            'epoch': epoch,
            'train_loss': tr_loss / max(len(dataloader_train), 1),
            'valid_loss': val_loss / max(len(dataloader_valid), 1),
            'valid_avg_precision': avg_precision_current,
            'time': epoch_time
        })

        if epochs_no_improve >= patience:
            print(f"[{get_current_time()}] Early stopping at epoch: {epoch + 1}")
            break

    _plot_epoch_times(epoch_times, v, t)

    return val_fold_results, best_fold_avg_precision, best_model_path

In [None]:
##################################################
# Nested cross-validation: for every (test fold, validation fold) pair, train on
# the remaining folds, early-stop on the validation fold, then score the test fold.
##################################################


def _predict_test_set(model, dataloader, device):
    """Run ``model`` over ``dataloader`` without gradients.

    Returns:
        targets: np.ndarray (N,) of integer classes (one-hot label rows are argmax-ed).
        probs:   np.ndarray (N, C) of softmax probabilities.
        ids:     list of sample ids in loader order.
        preds:   np.ndarray (N,), argmax of ``probs``.
    """
    targets, probs, ids_out = [], [], []
    model.eval()
    with torch.no_grad():
        for images, labels, ids in dataloader:
            if labels.ndim == 2 and labels.shape[1] > 1:
                labels = labels.argmax(dim=1)
            logits = model(images.float().to(device))
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            targets.append(labels.long().numpy())
            # Numeric ids may be collated into a tensor; otherwise they arrive as a sequence.
            ids_out.extend(ids.numpy() if hasattr(ids, "numpy") else ids)
    if not probs:
        raise ValueError("Test dataloader yielded no batches")
    probs_np = np.concatenate(probs, axis=0)
    return np.concatenate(targets, axis=0), probs_np, ids_out, probs_np.argmax(axis=1)


aggregated_results = []
N_FOLDS = 5

for test_fold in range(N_FOLDS):
    for val_fold in range(N_FOLDS):
        if val_fold == test_fold:
            continue

        # 1) Fresh 3-class model for this split. It is newly constructed, so no
        #    model.apply(reset_weights): resetting every layer would throw away the
        #    pretrained r3d_18 weights OctModel3D is built with.
        model = OctModel3D(num_classes=3).to(device)

        # 2) Prepare splits
        val_df = df[df["fold"] == val_fold].reset_index(drop=True)
        train_df = df[~df["fold"].isin([test_fold, val_fold])].reset_index(drop=True)
        test_df = df[df["fold"] == test_fold].reset_index(drop=True)

        print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

        # 3) Datasets / DataLoaders (num_workers=0: the loader may resize on CUDA)
        dataset_train = OctDataset(df=train_df, class_names=class_names)
        dataset_valid = OctDataset(df=val_df, class_names=class_names)
        dataset_test  = OctDataset(df=test_df, class_names=class_names)

        dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
        dataloader_valid = DataLoader(dataset_valid, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
        dataloader_test  = DataLoader(dataset_test,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

        # 4) Training setup
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=5e-5)

        # 5) Train & validate. train_one_fold stores `v` as 'VAL' and `t` as 'TEST'
        #    (checkpoint name ModelV{v}_T{t}...), so v must be the validation fold.
        val_fold_results, best_fold_ap, best_model_path = train_one_fold(
            v=val_fold,
            t=test_fold,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            dataloader_train=dataloader_train,
            dataloader_valid=dataloader_valid,
            model_save_dir=model_save_dir,
            patience=10
        )

        # 6) Load the best checkpoint (if one was written) for testing
        if best_model_path is not None:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        else:
            print(f"[TestFold={test_fold}, ValFold={val_fold}] No checkpoint saved; "
                  f"evaluating final-epoch weights.")
        model.eval()

        # 7) Test phase
        test_targets_np, test_probs_np, fold_test_ids, predictions = _predict_test_set(
            model, dataloader_test, device
        )

        # 7a) Multi-class ROC AUC (macro, one-vs-rest)
        try:
            roc_auc = roc_auc_score(test_targets_np, test_probs_np,
                                    multi_class='ovr', average='macro')
        except ValueError:
            # Undefined here (e.g. a class missing from this test fold). NaN rather
            # than 0.0, which would read as a perfectly inverted classifier.
            roc_auc = float("nan")

        print(f"[TestFold={test_fold}, ValFold={val_fold}] ROC AUC (macro, ovr): {roc_auc:.4f}")

        # 8) Correct predictions
        is_correct = predictions == test_targets_np
        correct_indices = np.flatnonzero(is_correct).tolist()
        num_correct = len(correct_indices)
        print(f"Correct Predictions: {num_correct}/{len(test_targets_np)}")

        # 9) Store per-sample results for this split
        sub = pd.DataFrame(test_probs_np, columns=["pred0", "pred1", "pred2"])
        sub["image_id"] = fold_test_ids
        sub["test_fold"] = test_fold
        sub["val_fold"] = val_fold
        sub["predicted_class"] = predictions
        sub["actual_class"] = test_targets_np
        sub["roc_auc"] = roc_auc

        aggregated_results.append(sub)

# Concatenate final results
final_results = pd.concat(aggregated_results, ignore_index=True)
print("Final Results (head):")
print(final_results.head())
