In [None]:
# ── Standard Library ───────────────────────────────────────────────────────────
import os
import csv
import time
import glob
from datetime import datetime
from pathlib import Path
from itertools import cycle

# ── Data Handling & Numerics ───────────────────────────────────────────────────
import numpy as np
from numpy import asarray
import pandas as pd

# ── PyTorch Ecosystem ──────────────────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import timm
from tqdm import tqdm

# ── Computer Vision & Augmentations ────────────────────────────────────────────
import torchvision
import torchvision.transforms as transforms
from torchvision.models.video import r3d_18

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
from PIL import Image
from pydicom import dcmread

# ── Scikit-Learn & Statistics ─────────────────────────────────────────────────
import sklearn
from sklearn.metrics import (
    roc_auc_score, auc, roc_curve, average_precision_score
)
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils.multiclass import type_of_target
from sklearn.utils import shuffle
from sklearn.preprocessing import label_binarize
import sklearn.metrics as metrics

from scipy import interp
from scipy.stats import ttest_rel, ttest_ind

# ── Visualization ─────────────────────────────────────────────────────────────
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import gridspec

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ── Grad-CAM & Explainability ─────────────────────────────────────────────────
from pytorch_grad_cam import (
    GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, FullGrad
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import (
    DeepFeatureFactorization
)
import pytorch_grad_cam.utils.reshape_transforms
import pytorch_grad_cam.metrics.cam_mult_image
import pytorch_grad_cam.metrics.road

In [None]:
# Remember the directory the notebook was launched from on the first execution.
# `os.chdir` with a relative path is not idempotent: without this anchor,
# re-running the cell would try to enter "Dataset Directory/Dataset Directory".
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", Path.cwd())
DATASET_DIR = NOTEBOOK_DIR / "Dataset Directory"
os.chdir(DATASET_DIR)

# Output locations, relative to the dataset directory.
visualization_dir = "visualization_dir"
model_save_dir = "Directory_Model"

df = pd.read_csv("PLOSONE_DF-Time.csv")
# Later cells read these columns ('path3D' per volume, 'fold' for CV splits).
assert {"path3D", "fold"}.issubset(df.columns), "CSV is missing 'path3D' and/or 'fold'"
# Sequential integer id per row, used to join test predictions back to samples.
df["image_id"] = np.arange(len(df))
len(df)

# MODEL

In [None]:
# Label columns in the CSV, in the order used for the one-hot targets and logits.
class_names = ['class_0', 'class_1']
class_labels = {name: idx for idx, name in enumerate(class_names)}

N_EPOCHS = 20
BATCH_SIZE = 6
# Number of output classes (an int; the name->index mapping is `class_labels`).
num_classes = len(class_names)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Device setup (re-derives `device` and reports which backend is in use)
if torch.cuda.is_available():
    print("CUDA is available. Using GPU.")
    device = torch.device("cuda:0")
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")


class OctModel3D(nn.Module):
    """
    torchvision ``r3d_18`` adapted to single-channel volumes.

    The first stem convolution is replaced by a 1-input-channel ``nn.Conv3d`` with
    the same output channels, kernel size, stride and padding. The final ``fc``
    layer is replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    pretrained : bool
        Forwarded to ``r3d_18(pretrained=...)``. When True, the pretrained stem
        filters are summed over their input-channel dimension to initialise the
        single-channel stem instead of being discarded.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        self.model = r3d_18(pretrained=pretrained)

        old_stem = self.model.stem[0]
        new_stem = nn.Conv3d(
            in_channels=1,
            out_channels=old_stem.out_channels,
            kernel_size=old_stem.kernel_size,
            stride=old_stem.stride,
            padding=old_stem.padding,
            bias=False,
        )
        if pretrained:
            # Keep the learned filters: collapsing the RGB input channels by
            # summation preserves the response to a grayscale image replicated
            # across 3 channels (same approach as timm's `adapt_input_conv`).
            with torch.no_grad():
                new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
        self.model.stem[0] = new_stem

        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        # Expects shape: (Batch, 1, Depth, Height, Width)
        return self.model(x)


model = OctModel3D(num_classes=2)
model = model.to(device)
model.eval()

In [None]:
def load_dicom_volume(dicom_dir, target_shape=(64, 256, 256), cache_dir=None, use_gpu=True):
    """
    Load a DICOM file or a directory of DICOM slices as a normalised, resized volume.

    Parameters
    ----------
    dicom_dir : str or path-like
        Either a single DICOM file, or a directory whose ``*.dcm`` files
        (case-insensitive) are the slices of one volume.
    target_shape : tuple of int
        Output (Depth, Height, Width) after trilinear resizing.
    cache_dir : optional
        Currently unused; accepted so existing callers keep working.
    use_gpu : bool
        Perform the resize on CUDA when available (result is moved back to CPU).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(1, *target_shape)``, min-max scaled to [0, 1].

    Raises
    ------
    ValueError
        If no DICOM slices are found, the stacked pixel data is neither 3D nor 4D,
        or the resized tensor does not have shape ``(1, *target_shape)``.
    """
    if os.path.isfile(dicom_dir):
        datasets = [dcmread(dicom_dir)]
    else:
        # Case-insensitive so ".DCM" exports are not silently skipped.
        dicom_paths = [
            os.path.join(dicom_dir, filename)
            for filename in sorted(os.listdir(dicom_dir))
            if filename.lower().endswith('.dcm')
        ]
        datasets = [dcmread(path) for path in dicom_paths]

        # Filename order is lexicographic ("10.dcm" sorts before "2.dcm"). When every
        # slice carries an InstanceNumber, order by it; the sort is stable, so ties
        # keep filename order.
        instance_numbers = [getattr(ds, "InstanceNumber", None) for ds in datasets]
        if datasets and all(n is not None and n != "" for n in instance_numbers):
            order = sorted(range(len(datasets)), key=lambda i: int(instance_numbers[i]))
            datasets = [datasets[i] for i in order]

    slices = [ds.pixel_array for ds in datasets]
    if not slices:
        raise ValueError(f"No valid DICOM slices found in: {dicom_dir}")

    # Stack slices => shape [Depth, Height, Width] (4D if each pixel_array is 3D)
    volume = np.stack(slices, axis=0).astype(np.float32)

    # Min-max normalise to [0, 1]; the epsilon guards against a constant volume.
    v_min, v_max = volume.min(), volume.max()
    volume = (volume - v_min) / (v_max - v_min + 1e-8)

    volume = torch.from_numpy(volume)

    # F.interpolate in trilinear mode needs [N, C, D, H, W].
    # A 4D stack is interpreted as [C, D, H, W]; only C == 1 passes the final check.
    if volume.ndim == 3:
        volume = volume.unsqueeze(0).unsqueeze(0)
    elif volume.ndim == 4:
        volume = volume.unsqueeze(0)
    else:
        raise ValueError(f"Unexpected volume shape {volume.shape}")

    # NOTE(review): CUDA work here is unsafe inside forked DataLoader workers;
    # callers in this notebook use num_workers=0.
    resize_device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    volume = F.interpolate(
        volume.to(resize_device),
        size=tuple(target_shape),
        mode="trilinear",
        align_corners=False,
    ).cpu()

    # Remove batch dimension => shape [1, D, H, W]
    volume = volume.squeeze(0)

    expected_shape = (1, *target_shape)
    if tuple(volume.shape) != expected_shape:
        raise ValueError(
            f"Unexpected shape after resizing: {tuple(volume.shape)}, expected: {expected_shape}"
        )

    return volume

class OctDataset(torch.utils.data.Dataset):
    """
    Map-style dataset yielding one OCT volume per DataFrame row.

    Each row must provide ``path3D`` (passed to ``load_dicom_volume``), one column
    per entry of ``class_names`` holding label values, and ``image_id``.

    ``__getitem__(idx)`` returns ``(volume, labels, sample_id)``:
    ``volume`` is the ``(1, D, H, W)`` tensor from ``load_dicom_volume``,
    ``labels`` is a float32 tensor with one value per class name, and
    ``sample_id`` is the row's ``image_id``.
    """

    def __init__(self, df, class_names, target_shape=(64, 256, 256), cache_dir=None, use_gpu=True):
        # Drop the caller's index so rows are numbered 0..n-1.
        self.df = df.reset_index(drop=True)
        self.class_names = class_names
        self.target_shape = target_shape
        self.cache_dir = cache_dir
        self.use_gpu = use_gpu

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        try:
            volume = load_dicom_volume(
                record['path3D'],
                target_shape=self.target_shape,
                cache_dir=self.cache_dir,
                use_gpu=self.use_gpu,
            )
        except ValueError as err:
            raise RuntimeError(f"Error loading DICOM at index {idx}: {err}")

        # Expected layout: [1, D, H, W]
        if volume.ndim != 4:
            raise ValueError(f"Volume shape is {volume.shape}, expected 4D (1, D, H, W)")

        label_values = record[self.class_names].values.astype(np.float32)
        return volume, torch.from_numpy(label_values), record['image_id']


In [None]:
def reset_weights(m):
    """
    Re-initialise ``m``'s parameters in place if it is an nn.Conv2d or nn.Linear.

    Meant to be used as ``model.apply(reset_weights)``. Modules of any other type
    (e.g. nn.Conv3d or normalisation layers) are left untouched.
    """
    resettable_types = (nn.Conv2d, nn.Linear)
    if not isinstance(m, resettable_types):
        return
    m.reset_parameters()

In [None]:
def get_current_time():
    """Return the current local time as a ``YYYY-MM-DD HH:MM:SS`` string for log lines."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"

def _run_train_epoch(model, criterion, optimizer, dataloader):
    """One optimisation pass over `dataloader`; returns the summed batch loss."""
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        images = batch[0].to(device).float()
        labels = batch[1].to(device).float()
        optimizer.zero_grad()
        outputs = model(images)  # logits, shape [batch_size, num_classes]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss


def _run_valid_epoch(model, criterion, dataloader):
    """Evaluate on `dataloader`; returns (summed loss, softmax probabilities, targets)."""
    model.eval()
    total_loss = 0.0
    probs, targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].to(device).float()
            labels = batch[1].to(device).float()
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            # Rank by softmax probability, as the test phase does. The raw class-1
            # logit alone ignores the class-0 logit and gives a different ordering.
            probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            targets.append(labels.cpu().numpy())
    if not probs:
        empty = np.empty((0, 0), dtype=np.float32)
        return total_loss, empty, empty
    return total_loss, np.concatenate(probs), np.concatenate(targets)


def _positive_class_ap(targets, probs):
    """Average precision for column 1 (positive class); 0.0 when empty or undefined."""
    if len(targets) == 0:
        return 0.0
    try:
        ap = average_precision_score(targets[:, 1], probs[:, 1])
    except ValueError:
        return 0.0
    return 0.0 if np.isnan(ap) else float(ap)


def _plot_epoch_times(epoch_times, v, t):
    """Line plot of wall-clock seconds per epoch for the fold pair (v, t)."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(epoch_times) + 1), epoch_times, marker='o')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Time (seconds)")
    ax.set_title(f"Epoch Timing - Fold {v}-{t}")
    ax.grid(True)
    plt.show()
    plt.close(fig)


def train_one_fold(v, t, model, criterion, optimizer,
                   dataloader_train, dataloader_valid,
                   model_save_dir, patience=10):
    """
    Train `model` for up to N_EPOCHS, keeping the checkpoint with the best validation AP.

    Parameters
    ----------
    v, t : int
        Validation and test fold indices. They are used in the checkpoint name
        ``Model{v}_T{t}Time.pth`` and in the 'VAL' / 'TEST' fields of the stats.
    model, criterion, optimizer :
        Network, loss (called as ``criterion(logits, labels)``) and optimiser.
    dataloader_train, dataloader_valid : DataLoader
        Batches of ``(images, labels, ...)``; labels have one column per class.
    model_save_dir : str
        Directory in which the best checkpoint is written.
    patience : int
        Stop after this many consecutive epochs without an AP improvement.

    Returns
    -------
    (list of dict, float, str or None)
        Per-epoch statistics, best validation average precision, and the path of
        the best checkpoint (None only if N_EPOCHS is 0).
    """
    val_fold_results = []
    best_fold_avg_precision = 0.0
    epochs_no_improve = 0
    best_model_path = None
    epoch_times = []

    for epoch in range(N_EPOCHS):
        start_time = time.time()
        print(f"[{get_current_time()}] Epoch {epoch + 1}/{N_EPOCHS}")

        tr_loss = _run_train_epoch(model, criterion, optimizer, dataloader_train)
        val_loss, val_probs, val_targets = _run_valid_epoch(model, criterion, dataloader_valid)

        avg_precision_current = _positive_class_ap(val_targets, val_probs)
        print(f"[{get_current_time()}] Average Precision Current Validation: {avg_precision_current:.4f}")

        # Always checkpoint the first epoch so a path exists even if AP never rises
        # above 0; otherwise the caller's torch.load(best_model_path) would fail.
        if best_model_path is None or avg_precision_current > best_fold_avg_precision:
            best_fold_avg_precision = avg_precision_current
            epochs_no_improve = 0
            best_model_path = os.path.join(model_save_dir, f"Model{v}_T{t}Time.pth")
            torch.save(model.state_dict(), best_model_path)
            print(f"[{get_current_time()}] Best Model Saved at: {best_model_path}")
        else:
            epochs_no_improve += 1

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)
        print(f"[{get_current_time()}] Time for E {epoch + 1}: {epoch_time:.2f} s")

        # Record stats before the early-stopping check so the final epoch is kept.
        val_fold_results.append({
            'TEST': t,
            'VAL': v,
            'epoch': epoch,
            'train_loss': tr_loss / max(len(dataloader_train), 1),
            'valid_loss': val_loss / max(len(dataloader_valid), 1),
            'valid_avg_precision': avg_precision_current,
            'time': epoch_time,
        })

        if epochs_no_improve >= patience:
            print(f"[{get_current_time()}] Early stopping at epoch: {epoch + 1}")
            break

    _plot_epoch_times(epoch_times, v, t)

    return val_fold_results, best_fold_avg_precision, best_model_path

In [None]:
# Per-sample test predictions collected across all (test, val) fold pairs.
aggregated_results = []


def _make_loader(split_df, shuffle, pin_memory=True):
    """Wrap a DataFrame split in an OctDataset and a DataLoader."""
    # num_workers=0: presumably because load_dicom_volume may resize on CUDA,
    # which is not safe inside forked DataLoader worker processes.
    return DataLoader(
        OctDataset(df=split_df, class_names=class_names),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=pin_memory,
    )


def _predict(model, loader):
    """Return (softmax probabilities [N, C], targets [N, C], sample ids [N]) for `loader`."""
    model.eval()
    probs, targets, ids = [], [], []
    with torch.no_grad():
        for batch in loader:
            images, labels, batch_ids = batch[0].to(device).float(), batch[1].float(), batch[2]
            outputs = model(images)
            probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            targets.append(labels.numpy())
            ids.append(np.asarray(batch_ids))
    return np.concatenate(probs), np.concatenate(targets), np.concatenate(ids)


for test_fold in range(5):
    for val_fold in range(5):
        if test_fold == val_fold:
            continue

        # 1) Fresh model for this fold pair. reset_weights only touches
        #    nn.Conv2d / nn.Linear modules, so the Conv3d stem is not re-initialised.
        model = OctModel3D(num_classes=2).to(device)
        model.apply(reset_weights)

        # 2) Dataset splits
        val_df = df[df['fold'] == val_fold].reset_index(drop=True)
        train_df = df[
            (df['fold'] != test_fold) & (df['fold'] != val_fold)
        ].reset_index(drop=True)
        test_df = df[df['fold'] == test_fold].reset_index(drop=True)

        print("!!!!!", len(train_df), len(val_df), len(test_df))

        # 3) Dataloaders
        dataloader_train = _make_loader(train_df, shuffle=True)
        dataloader_valid = _make_loader(val_df, shuffle=False)
        dataloader_test = _make_loader(test_df, shuffle=False, pin_memory=False)

        # 4) Training setup
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=5e-5)

        # 5) Train & validate. Keyword arguments: train_one_fold's signature is
        #    (v=validation fold, t=test fold); passing (test_fold, val_fold)
        #    positionally swapped the VAL/TEST labels and checkpoint names.
        val_fold_results, best_fold_avg_precision, best_model_path = train_one_fold(
            v=val_fold,
            t=test_fold,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            dataloader_train=dataloader_train,
            dataloader_valid=dataloader_valid,
            model_save_dir=model_save_dir,
            patience=10,
        )

        # 6) Load the best checkpoint for testing (if one was written)
        if best_model_path is None:
            print(f"WARNING: no checkpoint for Test Fold={test_fold}, Val Fold={val_fold}; "
                  "evaluating final-epoch weights.")
        else:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        model.eval()

        # 7) Test phase (predictions for ALL test samples)
        test_probabilities_np, test_targets_np, fold_test_ids = _predict(model, dataloader_test)

        # 7a) ROC AUC for the positive class
        try:
            roc_auc = roc_auc_score(test_targets_np[:, 1], test_probabilities_np[:, 1])
        except ValueError:
            # Undefined (e.g. only one class present). NaN rather than 0.0, which
            # would read as a perfectly inverted ranking.
            roc_auc = float("nan")

        print(f"ROC AUC Score [Test Fold={test_fold}, Val Fold={val_fold}]: {roc_auc:.4f}")

        # 8) Identify correct predictions
        predictions = test_probabilities_np.argmax(axis=1).tolist()
        actual_classes = np.argmax(test_targets_np, axis=1)  # shape (N,)
        correct_indices = np.flatnonzero(np.asarray(predictions) == actual_classes).tolist()
        print(f"Found {len(correct_indices)} correctly predicted samples.")

        # 9) Store per-sample results for this fold combination
        sub = pd.DataFrame(test_probabilities_np, columns=["pred0", "pred1"])
        sub["image_id"] = fold_test_ids
        sub["test_fold"] = test_fold
        sub["val_fold"] = val_fold
        sub["predicted_class"] = predictions
        sub["true_class"] = actual_classes
        aggregated_results.append(sub)

# Concatenate final results
final_results = pd.concat(aggregated_results, ignore_index=True)
print("Final Results (head):")
final_results.head()