In [None]:
# Mount Google Drive at /content/drive (Colab only); the project paths used in
# later cells all live under /content/drive/MyDrive/MSc_project.
from google.colab import drive
drive.mount('/content/drive')

# **Package installation and imports**

In [None]:
!pip install pandas numpy pydicom dicom2nifti tqdm matplotlib scikit-learn torch torchvision torchcam nibabel albumentations --quiet

**Restart the session if you get an error here.**

In [None]:
# Data handling and utilities
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import glob
import tempfile
import heapq
import os
import pickle
import ast
import random
import re

# Plotting and visualization
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
from scipy.ndimage import map_coordinates

# PyTorch and model tools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Vision and image processing
import pydicom                             # For reading DICOM CT slices
import torchvision.models as models        # Pretrained CNN models
import albumentations as A                 # Image augmentation
import cv2                                 # Image resizing and processing
import dicom2nifti                        # DICOM to NIfTI conversion

# Explainable AI (Grad-CAM)
from torchcam.methods import GradCAM

# Medical imaging tools
import nibabel as nib                      # For reading .nii segmentation masks

# Evaluation and metrics
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import (
    roc_auc_score, confusion_matrix, accuracy_score,
    roc_curve, auc, precision_recall_curve, average_precision_score
)

# Suppress unnecessary warnings
import warnings
warnings.filterwarnings("ignore")

**Copy and Extract Segmentation Masks to Colab Local Storage**

In [None]:
# Copy the ZIP from Drive to Colab local storage
!cp "/content/drive/MyDrive/MSc_project/segmentations.zip" "/content/segmentations.zip"

# Unzip in Colab local storage
!unzip -q "/content/segmentations.zip" -d "/content/segmentations/"

# For checking: confirm a few files exist
print(os.listdir('/content/segmentations/')[:5])

**DICOM Indexing and Label Matching**

This section indexes all DICOM slices from the filtered patient/series folders and matches each slice to an "Active Extravasation" (bleeding) label using the provided CSV.
It builds a Python list of dicts with each DICOM's patient, series, slice index, label, and path.

**RUN THIS ONLY ONCE**

In [None]:
# Paths to the DICOM subset and the image-level label table
data_dir = '/content/drive/MyDrive/MSc_project/bleed_subset_images'
labels_csv = '/content/drive/MyDrive/MSc_project/image_level_labels_2024.csv'
labels_df = pd.read_csv(labels_csv)

# Build the label dictionary for extravasation (bleeding only).
# Only rows whose injury_name is 'Active_Extravasation' contribute an entry, so
# filter first instead of walking every row of the table.
# Keys are (patient_id as str, series_id as str, instance_number as int) so they
# can be compared with folder names and the integer parsed from each filename.
positive_rows = labels_df[labels_df['injury_name'] == 'Active_Extravasation']
label_dict = {
    (str(patient_id), str(series_id), int(instance_number)): 1
    for patient_id, series_id, instance_number in zip(
        positive_rows['patient_id'],
        positive_rows['series_id'],
        positive_rows['instance_number'],
    )
}

def extract_int(filename):
    """
    Extract the integer slice number from a DICOM filename.

    The first run of digits in the name is used, so '796 (1)' -> 796,
    '0050' -> 50 and '420' -> 420.

    Args:
        filename: File name (normally without its extension).

    Returns:
        The first integer found in `filename`, or None if it has no digits.

    Uses the module-level `re` import.
    """
    # Only the first digit run matters, so stop at the first match.
    first_number = re.search(r'\d+', filename)
    return int(first_number.group()) if first_number else None

# Index all available DICOM files and match to labels.
# Folder layout walked here: data_dir/<patient_id>/<series_id>/<slice file>
available = []
for patient_id in tqdm(os.listdir(data_dir), desc="Patients"):
    patient_path = os.path.join(data_dir, patient_id)
    if not os.path.isdir(patient_path):
        # A stray file next to the patient folders would make os.listdir raise.
        continue
    for series_id in os.listdir(patient_path):
        series_path = os.path.join(patient_path, series_id)
        if not os.path.isdir(series_path):
            continue
        for dcm_file in os.listdir(series_path):
            fname_no_ext = os.path.splitext(dcm_file)[0]
            slice_id = extract_int(fname_no_ext)
            if slice_id is None:
                print(f"Warning: Could not extract integer from {dcm_file}")
                continue
            key = (patient_id, series_id, slice_id)
            dcm_path = os.path.join(series_path, dcm_file)
            label = label_dict.get(key, 0)  # 1 for extravasation, else 0
            available.append({
                'patient_id': patient_id,
                'series_id': series_id,
                'slice_id': slice_id,
                'label': label,
                'dcm_path': dcm_path
            })

print(f"Total DICOMs indexed: {len(available)}")

**Slice Selection for Model Input**

This block selects which DICOM slices to process:
- **For positive (bleeding) cases:**  
  All labeled slices **plus their neighbors** (within `NEIGHBOR_RANGE` slices) are included for better context.
- **For negative cases:**  
  A fixed number of evenly spaced slices (`NEG_SLICES_PER_SERIES`) per series are sampled to balance the dataset.
- **Duplicates are removed** to avoid repeated slices.

This balances positive and negative data and ensures enough context for 2.5D/stack-based modeling.

**RUN THIS ONLY ONCE**

In [None]:
NEIGHBOR_RANGE = 3             # Number of neighbor slices to include on each side of a positive
NEG_SLICES_PER_SERIES = 20     # Number of negative slices to sample per negative series

# Organize slices by (patient, series).
# `defaultdict` comes from the notebook's import cell.
slices_by_patient_series = defaultdict(list)
for item in available:
    slices_by_patient_series[(item['patient_id'], item['series_id'])].append(item)

# Select slices: all positive + neighbors, or downsample negatives
selected_records = []

for (patient_id, series_id), slices in slices_by_patient_series.items():
    # Sort slices by their slice index
    slices = sorted(slices, key=lambda x: x['slice_id'])
    positive_positions = [pos for pos, s in enumerate(slices) if s['label'] == 1]

    if positive_positions:  # Positive series
        # Neighbors are taken by position in the sorted list (not by slice_id
        # distance); the window is clipped at both ends of the series.
        for pos in positive_positions:
            first = max(pos - NEIGHBOR_RANGE, 0)
            last = min(pos + NEIGHBOR_RANGE, len(slices) - 1)
            selected_records.extend(slices[first:last + 1])
    elif len(slices) <= NEG_SLICES_PER_SERIES:  # Short negative series: keep everything
        selected_records.extend(slices)
    else:  # Negative series: sample evenly spaced slices
        indices = np.linspace(0, len(slices) - 1, NEG_SLICES_PER_SERIES, dtype=int)
        selected_records.extend(slices[idx] for idx in indices)

# Remove duplicate (patient, series, slice); overlapping neighbor windows add
# the same slice more than once. The first occurrence is kept.
unique_keys = set()
final_records = []
for rec in selected_records:
    key = (rec['patient_id'], rec['series_id'], rec['slice_id'])
    if key not in unique_keys:
        final_records.append(rec)
        unique_keys.add(key)

print(f"Selected slices for processing: {len(final_records)}")

**Save Slice/Label Index for Fast Reload**

This block saves the **slice selection list** (`available` or `final_records`) and, optionally, `label_dict` as `.pkl` files, so future runs can load them quickly without re-indexing the DICOM files or regenerating labels.


**UNCOMMENT WHEN RUNNING**

In [None]:
# Persist the slice index and the label lookup to Drive so later sessions can
# reload them instead of re-running the indexing cells.
# The code is left commented out (presumably so existing cache files are not
# overwritten by accident); uncomment it when (re)generating the caches.
# `pickle` is already imported in the notebook's import cell.

# Save the slice index
# with open('/content/drive/MyDrive/MSc_project/available_slices.pkl', 'wb') as f:
#     pickle.dump(available, f)   # Or: pickle.dump(final_records, f)

# Save the label dictionary
# with open('/content/drive/MyDrive/MSc_project/label_dict.pkl', 'wb') as f:
#     pickle.dump(label_dict, f)

# NOTE(review): the preprocessing cell reads
# '/content/drive/MyDrive/MSc_project/selected_slice_index.csv', but no cell
# shown in this notebook writes it. Presumably `final_records` was exported
# once by hand, e.g.
#     pd.DataFrame(final_records).to_csv(<that path>, index=False)
# Confirm, and add the export here so a fresh run can reproduce the file.

**Load Precomputed Slice and Label Index**

This block **loads your precomputed list of selected DICOM slices** and the `label_dict` from pickle files.  
This avoids reprocessing: when resuming work, simply reload and continue.


In [None]:
# Reload the cached slice index and label lookup from Drive.
# NOTE: unpickling can execute arbitrary code, so only load files you wrote yourself.

# Slice index -> `available`
with open('/content/drive/MyDrive/MSc_project/available_slices.pkl', 'rb') as slices_file:
    available = pickle.load(slices_file)

# Positive-slice lookup -> `label_dict`
with open('/content/drive/MyDrive/MSc_project/label_dict.pkl', 'rb') as labels_file:
    label_dict = pickle.load(labels_file)

**DICOM to 3-Channel NPY Conversion and Indexing**

This block **preprocesses DICOM slices into NPY files** for fast model training:
- Loads each DICOM, applies HU scaling, resizes, and creates a 3-channel image using three standard CT windows (soft tissue, liver, blood).
- Saves each preprocessed array as `.npy` for efficient future access.
- Updates the dataframe with the generated NPY filenames for model indexing.

In [None]:
# Input: CSV index of the selected slices. The loop below reads the columns
# dcm_path, label, patient_id, series_id and slice_id from it, so the file must
# already exist on Drive.
csv_path = '/content/drive/MyDrive/MSc_project/selected_slice_index.csv'
# Output: folder that receives one preprocessed .npy file per slice.
output_dir = '/content/drive/MyDrive/MSc_project/preproc_npy_3ch'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder is already there

# Passed to cv2.resize as the target size; the loop compares the reversed tuple
# against arr.shape.
RESIZE_SHAPE = (256, 256)  # (width, height) for model input

def apply_window(img, center, width):
    """Apply CT windowing to HU values and scale to [0,1].

    Values are clipped to [center - width/2, center + width/2] and then mapped
    linearly so the lower bound becomes 0 and the upper bound becomes 1.

    Args:
        img: NumPy array of intensities (HU in this notebook).
        center: Window centre (level).
        width: Full window width; must be positive.

    Returns:
        A new array with the same shape as `img`, with values in [0, 1].

    Raises:
        ValueError: If `width` is not positive (the window would be empty and
            the rescale would divide by zero).
    """
    if width <= 0:
        raise ValueError(f"width must be positive, got {width}")
    # True division keeps the full extent of odd widths; floor division would
    # silently shrink them by one unit (and collapse width=1 to a zero window).
    half_width = width / 2
    lower = center - half_width
    upper = center + half_width
    clipped = np.clip(img, lower, upper)
    return (clipped - lower) / (upper - lower)

# Load the index of slices to preprocess (one row per DICOM slice).
df = pd.read_csv(csv_path)

# Derive each output filename once, so the files written in the loop and the
# names stored in the saved index cannot drift apart.
df['npy_file'] = df.apply(
    lambda r: f"{r['patient_id']}_{r['series_id']}_{r['slice_id']}_label{r['label']}.npy",
    axis=1
)

failed_rows = []  # index labels of rows whose DICOM could not be converted

for idx, row in tqdm(df.iterrows(), total=len(df)):
    dcm_path = row['dcm_path']
    npy_path = os.path.join(output_dir, row['npy_file'])

    # Check the cache first: skipping before the DICOM is read and decoded
    # makes re-running this cell cheap.
    if os.path.exists(npy_path):
        continue

    try:
        ds = pydicom.dcmread(dcm_path)
        arr = ds.pixel_array.astype(np.float32)

        # Convert to Hounsfield Units (HU)
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            arr = arr * float(ds.RescaleSlope) + float(ds.RescaleIntercept)

        # Resize before windowing for efficiency.
        # arr.shape is (height, width) while RESIZE_SHAPE is (width, height).
        if arr.shape != RESIZE_SHAPE[::-1]:
            arr = cv2.resize(arr, RESIZE_SHAPE, interpolation=cv2.INTER_LINEAR)

        # Apply three windowings for multi-channel input
        arr_soft  = apply_window(arr, center=50, width=400)    # General abdomen
        arr_liver = apply_window(arr, center=60, width=150)    # Liver
        arr_blood = apply_window(arr, center=40, width=80)     # Blood/hemorrhage

        arr_3ch = np.stack([arr_soft, arr_liver, arr_blood], axis=0)  # [3, H, W]
        arr_3ch = arr_3ch.astype(np.float32)

        np.save(npy_path, arr_3ch)
    except Exception as e:
        # Best effort per slice: report the failure and continue with the rest.
        print(f"Error processing {dcm_path}: {e}")
        failed_rows.append(idx)

print("Preprocessing to 3-channel NPY complete.")
if failed_rows:
    print(f"{len(failed_rows)} slice(s) failed and are left out of the saved index.")

# Index only slices whose NPY file was written (or already existed), so
# downstream loaders never point at a missing file.
df = df.drop(index=failed_rows)
df.to_csv('/content/drive/MyDrive/MSc_project/preproc_npy_3ch_index.csv', index=False)

**NPY File checking for 2.5D Input (3-Channel Per Slice)**

**Just check for any corrupted files.**

This block is **now only for quality check:** confirms all `.npy` have shape `[3, H, W]` and lists any corrupted ones.

**RUN THIS ONLY ONCE FOR CHECKING**

In [None]:
npy_dir = '/content/drive/MyDrive/MSc_project/preproc_npy_3ch'
npy_files = sorted(f for f in os.listdir(npy_dir) if f.endswith('.npy'))

# One record per file. A file that cannot be opened, or that is not 3-D, is
# recorded with an `error` message instead of aborting the whole check.
shape_columns = ['npy_file', 'shape_0', 'shape_1', 'shape_2', 'error']
records = []
for fname in tqdm(npy_files, desc="Checking NPY shapes"):
    fpath = os.path.join(npy_dir, fname)
    record = dict.fromkeys(shape_columns)
    record['npy_file'] = fname
    try:
        arr = np.load(fpath, mmap_mode='r')  # Use mmap_mode to save memory
        for axis, size in enumerate(arr.shape[:3]):
            record[f'shape_{axis}'] = size
        if arr.ndim != 3:
            record['error'] = f"expected 3 dimensions, got shape {arr.shape}"
    except Exception as e:
        record['error'] = f"{type(e).__name__}: {e}"
    records.append(record)

# Explicit columns keep the lookups below valid even when the folder is empty.
shapes_df = pd.DataFrame(records, columns=shape_columns)
shapes_df.to_csv('/content/drive/MyDrive/MSc_project/npy_shapes.csv', index=False)

print(shapes_df['shape_0'].value_counts())
print(shapes_df.head())

# Find any files that are not [3, H, W] or that could not be read
bad_files = shapes_df[(shapes_df['shape_0'] != 3) | shapes_df['error'].notna()]
if not bad_files.empty:
    print("Problem files:", bad_files)
else:
    print("All NPY files are valid [3, H, W].")

**2.5D Stack Construction: Sliding Window NPY Index**

This **creates sliding stacks of NPY slices** for each patient/series, ready for 2.5D model training.  
- Each stack contains `stack_size` slices (`3` by default) centered on each slice.
- Edges are padded by repeating the nearest slice.
- Output is a DataFrame with patient/series, center slice, stack file list, and label (center slice label).
- Saved for efficient, balanced batch loading in PyTorch.


In [None]:
# Settings
csv_path = '/content/drive/MyDrive/MSc_project/preproc_npy_3ch_index.csv'
stack_size = 3  # Odd number, e.g. 3 for [prev, center, next]

# Per-slice NPY index produced by the preprocessing step
df = pd.read_csv(csv_path)
df['slice_id'] = df['slice_id'].astype(int)

# Offsets of the stack members relative to the center slice, e.g. (-1, 0, 1)
half_window = stack_size // 2
window_offsets = range(-half_window, half_window + 1)

records = []
# One sliding-window stack per indexed slice, built separately for each series
for (patient_id, series_id), group in df.groupby(['patient_id', 'series_id']):
    ordered = group.sort_values('slice_id').reset_index(drop=True)
    slice_ids = ordered['slice_id'].tolist()
    npy_files = ordered['npy_file'].tolist()
    labels = ordered['label'].tolist()
    last_pos = len(slice_ids) - 1

    for center_pos, center_label in enumerate(labels):
        # Positions falling outside the series are clamped to its first/last
        # slice, which repeats the edge slice in the stack.
        window = [min(max(center_pos + off, 0), last_pos) for off in window_offsets]
        records.append({
            'patient_id': patient_id,
            'series_id': series_id,
            'center_slice': slice_ids[center_pos],
            'stack_slice_ids': [slice_ids[pos] for pos in window],
            'stack_npy_files': [npy_files[pos] for pos in window],
            'label': center_label  # the stack inherits the center slice's label
        })

# Persist the stack index for later loading
stack_df = pd.DataFrame(records)
stack_df.to_csv('/content/drive/MyDrive/MSc_project/stack3_index.csv', index=False)

print(stack_df.head())
print(f"Total stacks created: {len(stack_df)}")

**Match Segmentation Masks to Series**

This checks **which patient/series in the stack index have available segmentation masks**.  
It adds a `has_mask` boolean column to the stack DataFrame for downstream use (segmentation/evaluation).


In [None]:
# Series IDs that have a mask: the part of every '.nii' filename before its first dot
segmentation_files = os.listdir('/content/segmentations/')
mask_series_ids = {
    fname.split('.')[0]
    for fname in segmentation_files
    if fname.endswith('.nii')
}

# Sliding-window stack index built in the previous step
df = pd.read_csv('/content/drive/MyDrive/MSc_project/stack3_index.csv')

# Flag every stack whose series has a segmentation mask (IDs compared as strings)
df['has_mask'] = df['series_id'].astype(str).isin(mask_series_ids)

# Report mask coverage
stacks_with_mask = df[df['has_mask']]
print(df['has_mask'].sum(), "slices/stacks with masks found.")
print(stacks_with_mask.head())
print("Unique series_ids with masks in your data:", stacks_with_mask['series_id'].unique())

# Keep the flagged index for later steps
df.to_csv('/content/drive/MyDrive/MSc_project/stack3_index_with_mask.csv', index=False)

**Train/Validation/Test Split by Patient**

This section splits the dataset into train, validation, and test sets at the patient level, making sure each patient appears in only one set. It uses `GroupShuffleSplit` to avoid patient overlap and ensure unbiased evaluation.


In [None]:
# Work on a copy so the master DataFrame stays untouched
df_to_split = df.copy()

# One label per patient: 1 when any of the patient's stacks is positive, else 0
patient_labels = df_to_split.groupby('patient_id')['label'].max()
all_patient_ids = patient_labels.index

# Stage 1: 70% of patients for training, 30% held back for validation + test.
# Each patient is its own group, so no patient can land in two sets.
# Note: GroupShuffleSplit accepts `y` but does not stratify on it.
splitter = GroupShuffleSplit(test_size=0.3, n_splits=1, random_state=42)
train_inds, temp_inds = next(
    splitter.split(all_patient_ids, groups=all_patient_ids, y=patient_labels)
)
train_patient_ids, temp_patient_ids = all_patient_ids[train_inds], all_patient_ids[temp_inds]

# Stage 2: halve the held-back patients into validation and test
val_test_splitter = GroupShuffleSplit(test_size=0.5, n_splits=1, random_state=42)
val_inds, test_inds = next(
    val_test_splitter.split(
        temp_patient_ids,
        groups=temp_patient_ids,
        y=patient_labels.loc[temp_patient_ids],
    )
)
val_patient_ids, test_patient_ids = temp_patient_ids[val_inds], temp_patient_ids[test_inds]

# Rows of each split, selected by patient membership
patient_column = df_to_split['patient_id']
train_df = df_to_split[patient_column.isin(train_patient_ids)]
val_df = df_to_split[patient_column.isin(val_patient_ids)]
test_df = df_to_split[patient_column.isin(test_patient_ids)]

# Sanity check: the three splits must be pairwise patient-disjoint
for left_df, right_df in ((train_df, val_df), (train_df, test_df), (val_df, test_df)):
    assert set(left_df['patient_id']).isdisjoint(right_df['patient_id'])

print(f"Train Patients: {len(train_patient_ids)} | Validation Patients: {len(val_patient_ids)} | Test Patients: {len(test_patient_ids)}")

**Image Augmentation for CT Slices**

This block defines a **data augmentation pipeline** using Albumentations, designed for 2.5D/3D medical images.

[Albumentation Documentation](https://albumentations.ai/docs/3-basic-usage/keypoint-augmentations/)


In [None]:
# Compose image augmentations for training (applied per CT stack slice)
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5), # Randomly flip images horizontally (50% chance)
    A.VerticalFlip(p=0.5), # Randomly flip images vertically (50% chance)
    A.RandomRotate90(p=0.5), # Randomly rotate by 90°, 180°, or 270° (50% chance)
    A.Affine(
        rotate=(-20, 20), # Random small rotation in range -20 to 20 degrees
        scale=(0.95, 1.05), # Random scale (zoom in/out by up to 5%)
        translate_percent={ # Random translation up to 5% of image size in x/y
            "x": (-0.05, 0.05),
            "y": (-0.05, 0.05)
        },
        p=0.5
    ),
    A.RandomBrightnessContrast(p=0.2),  # Randomly adjust brightness and contrast (20% chance)
    A.RandomGamma(gamma_limit=(85, 115), p=0.2),  # Small random gamma changes (20% chance)
    A.Blur(blur_limit=3, p=0.1),           # Slight blur (3x3 kernel, 10% chance)
    A.ElasticTransform(alpha=10, sigma=1, p=0.05),  # Elastic deformation for more realistic variation (5% chance)
])