In [None]:
# --- Standard Library ---
import os
import re
import pickle
from collections import defaultdict

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd
import requests
import matplotlib.pyplot as plt
from scipy.ndimage import (
    label,
    binary_dilation,
    binary_erosion,
)

from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    UpSampling2D,
    Dense,
    Dropout,
    BatchNormalization,
    Concatenate,
    GlobalAveragePooling2D,
    Reshape,
    SpatialDropout2D,
    Multiply,
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

from skimage.filters import threshold_otsu


In [None]:
# ===========================================
# Configuration: data locations & download source
# ===========================================

# --------------------
# Paths & source
# --------------------
BASE_PATH = "data"
FILENAME = "image_dicts_256_wgrayscale_andcutoffs.pkl"
FILE_PATH = os.path.join(BASE_PATH, FILENAME)                     # local copy of the pickled image dicts
EXCEL_FILE_PATH = os.path.join(BASE_PATH, "sample_groups.xlsx")   # train/val/test sample groups
# GitHub release asset; its file name is FILENAME.
URL = "https://github.com/tylervasse/DOCI-Prediction/releases/download/v1.0/" + FILENAME

# --------------------
# IO helpers
# --------------------
def download_file(url, output_path, timeout=60):
    """
    Downloads a file from a given URL if it does not already exist locally.

    The response is streamed into a temporary ``<output_path>.part`` file,
    which is moved into place only after the transfer finishes. An interrupted
    download therefore never leaves a truncated file that a later call would
    mistake for a complete one.

    Args:
        url (str): The URL pointing to the file to download.
        output_path (str): Local path where the downloaded file should be saved.
            May be a bare file name, in which case it is written to the current
            working directory.
        timeout (float or tuple, optional): Passed to ``requests.get`` so a
            stalled connection raises instead of hanging the notebook.
            Defaults to 60 seconds.

    Returns:
        None: Prints status messages indicating whether the file was downloaded
            or already existed at the target path.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Any exception raised while streaming is re-raised after the partial
        temporary file has been removed.
    """
    if os.path.exists(output_path):
        print(f"File already exists at {output_path}")
        return
    print(f"Downloading to {output_path}...")

    # os.path.dirname("file.pkl") == "" and os.makedirs("") raises, so only
    # create a parent directory when the path actually has one.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = output_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic rename: output_path only ever appears fully written.
        os.replace(tmp_path, output_path)
    except BaseException:
        # BaseException also covers a notebook KeyboardInterrupt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download complete.")

def load_image_dicts(file_path):
    """
    Read the pickled list of image dictionaries stored at ``file_path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file. In this notebook the file is fetched from a remote release URL, so
    only load it from a source you trust.

    Args:
        file_path (str): Location of the pickle file.

    Returns:
        list: The unpickled object (expected to be a list of image dicts).
            On any failure an error message is printed and ``[]`` is returned.
    """
    try:
        with open(file_path, "rb") as handle:
            payload = pickle.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        payload = []
    except Exception as err:
        print(f"Error loading file: {err}")
        payload = []
    return payload

def load_sample_groups(excel_file_path):
    """
    Loads training, validation, and test sample group identifiers from an Excel file.

    Each cell is converted to ``str`` before whitespace and stray single quotes
    are stripped. Sample IDs that Excel stored as numbers therefore no longer
    raise ``AttributeError``, which previously made the whole file be treated
    as unreadable. Integral floats (e.g. ``123.0`` from a numeric column with
    blanks) are rendered without the trailing ``.0``.

    Args:
        excel_file_path (str): Path to the Excel file specifying sample groups.
                               Must contain the columns:
                               'Train Samples', 'Validation Samples', 'Test Samples'.

    Returns:
        tuple: A tuple containing three lists:
            - list of str: Training sample base names.
            - list of str: Validation sample base names.
            - list of str: Test sample base names.
        Each list will be empty if the file is missing or cannot be parsed.
    """
    def _cell_to_str(value):
        # numpy.float64 subclasses float, so this also catches pandas floats.
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip().strip("'")

    def _column(frame, name):
        # dropna() removes the blank cells of the shorter columns.
        return [_cell_to_str(v) for v in frame[name].dropna().tolist()]

    try:
        df = pd.read_excel(excel_file_path)
        return (
            _column(df, 'Train Samples'),
            _column(df, 'Validation Samples'),
            _column(df, 'Test Samples'),
        )
    except FileNotFoundError:
        print(f"Error: Sample groups file not found at {excel_file_path}")
        return [], [], []
    except Exception as e:
        print(f"Error reading Excel: {e}")
        return [], [], []

# --------------------
# Basic parsing utils
# --------------------
def get_base_name(name):
    """
    Return the specimen prefix of a DOCI filename: everything before the
    first occurrence of ``'_DOCI'``. If the marker is absent, the name is
    returned unchanged.

    Args:
        name (str): Image filename such as ``'SSW-23-12345_A1_DOCI_3'``.

    Returns:
        str: The prefix shared by all slices of one specimen
             (``'SSW-23-12345_A1'`` in the example).
    """
    prefix, _marker, _tail = name.partition('_DOCI')
    return prefix

def get_doci_number(name):
    """
    Parse the slice index that follows the first ``'_DOCI_'`` in a filename.

    Args:
        name (str): Filename expected to contain ``'_DOCI_<digits>'``.

    Returns:
        int: The parsed index, or -1 when the pattern is not present.
    """
    match = re.search(r'_DOCI_(\d+)', name)
    if match is None:
        return -1
    return int(match.group(1))

# --------------------
# Split image dicts into splits by sample base name
# --------------------
def categorize_images(image_data, train_samples, val_samples, test_samples):
    """
    Categorizes image dictionaries into training, validation, and test sets
    based on their base sample names.

    The base name is the first two '_'-separated tokens of ``d['name']``.
    A base listed in several groups goes to the first matching split in the
    order train, val, test. Images matching no group are silently skipped.

    Args:
        image_data (list of dict): List of image dictionaries containing at least
                                   the key 'name'.
        train_samples (iterable of str): Base names assigned to the training split.
        val_samples (iterable of str): Base names assigned to the validation split.
        test_samples (iterable of str): Base names assigned to the test split.

    Returns:
        tuple: A tuple containing three lists:
            - list of dict: Training image dictionaries.
            - list of dict: Validation image dictionaries.
            - list of dict: Test image dictionaries.
    """
    # Set lookups keep the loop linear in len(image_data) instead of scanning
    # each sample list once per image.
    train_keys = set(train_samples)
    val_keys = set(val_samples)
    test_keys = set(test_samples)

    train_set, val_set, test_set = [], [], []
    for d in image_data:
        base = "_".join(d['name'].split('_')[:2])  # e.g., 'SSW-23-12345_A1'
        if base in train_keys:
            train_set.append(d)
        elif base in val_keys:
            val_set.append(d)
        elif base in test_keys:
            test_set.append(d)
    return train_set, val_set, test_set

# --------------------
# Voxelize per-sample (group by base name, sort by DOCI)
# --------------------
def samples_to_voxels(dataset):
    """
    Stack the per-slice DOCI images of each specimen into depth-last voxels.

    Slices are grouped by ``get_base_name(d['name'])``. Within a group they
    are ordered by ``get_doci_number``; the sort is stable, so ties keep
    their input order. Specimens appear in the output in first-seen order.

    Args:
        dataset (list of dict): Slice dictionaries with keys
            'name', 'grayscale', 'image_grayscale_cutoff', 'tissue_type'
            and optionally 'mask'.

    Returns:
        list of dict: One entry per specimen with keys
            - 'name': the base name,
            - 'grayscale_voxel': float32 stack of shape [H, W, D],
            - 'grayscale_image_cutoff_voxel': uint8 stack of shape [H, W, D],
            - 'tissue_type': the value from the last slice seen for the specimen,
            - 'mask': the first non-None mask seen, else None.
    """
    def _empty_group():
        return {
            'names': [], 'grayscale': [], 'image_grayscale_cutoff': [],
            'mask': None, 'tissue_type': None,
        }

    groups = defaultdict(_empty_group)
    for item in dataset:
        bucket = groups[get_base_name(item['name'])]
        bucket['names'].append(item['name'])
        bucket['grayscale'].append(item['grayscale'])
        bucket['image_grayscale_cutoff'].append(item['image_grayscale_cutoff'])
        bucket['tissue_type'] = item['tissue_type']
        candidate = item.get('mask')
        if bucket['mask'] is None and candidate is not None:
            bucket['mask'] = candidate

    result = []
    for base_name, bucket in groups.items():
        # Stable sort of (name, gray, cutoff) rows keyed only on the DOCI index.
        rows = sorted(
            zip(bucket['names'], bucket['grayscale'], bucket['image_grayscale_cutoff']),
            key=lambda row: get_doci_number(row[0]),
        )
        gray_stack = np.stack([row[1] for row in rows], axis=-1).astype(np.float32)   # [H,W,D]
        cut_stack = np.stack([row[2] for row in rows], axis=-1).astype(np.uint8)      # [H,W,D]
        result.append({
            'name': base_name,
            'grayscale_voxel': gray_stack,
            'grayscale_image_cutoff_voxel': cut_stack,
            'tissue_type': bucket['tissue_type'],
            'mask': bucket['mask'],
        })
    return result

# ====================
# Main flow
# ====================
# 1) Ensure data file
download_file(URL, FILE_PATH)

# 2) Load raw dicts. load_image_dicts prints and returns [] on failure, so stop
#    here instead of letting every later cell silently run on empty splits.
image_dicts = load_image_dicts(FILE_PATH)
assert image_dicts, f"No image dictionaries loaded from {FILE_PATH}"

# 3) Exclude specific samples by substring match in 'name'
EXCLUDE_LIST = ["SSW-23-14395_C2", "SSW-23-05363_A7"]
image_dicts = [d for d in image_dicts if not any(excl in d['name'] for excl in EXCLUDE_LIST)]

# 4) Load sample groups from Excel (load_sample_groups also returns empty lists on failure)
train_samples, val_samples, test_samples = load_sample_groups(EXCEL_FILE_PATH)
assert train_samples, f"No training samples read from {EXCEL_FILE_PATH}"

# 5) Assign to splits and shuffle at image level.
#    samples_to_voxels re-sorts slices by DOCI index, so this shuffle only
#    changes the order in which specimens appear in the voxelized lists.
train_set, val_set, test_set = categorize_images(image_dicts, train_samples, val_samples, test_samples)
n_unassigned = len(image_dicts) - len(train_set) - len(val_set) - len(test_set)
if n_unassigned:
    # categorize_images silently drops images whose base name is in no group.
    print(f"[SPLIT] {n_unassigned} image(s) matched no sample group and were skipped")
train_set = shuffle(train_set, random_state=42)
val_set   = shuffle(val_set,   random_state=42)
test_set  = shuffle(test_set,  random_state=42)

# 6) Voxelize per sample
train_combined = samples_to_voxels(train_set)
val_combined   = samples_to_voxels(val_set)
test_combined  = samples_to_voxels(test_set)

print(f"Samples -> train: {len(train_combined)} | val: {len(val_combined)} | test: {len(test_combined)}")

In [None]:
# --------------------
# Load regional categorization results from Excel
# --------------------
excel_path = "regional_categorization_results.xlsx"
df_cat = pd.read_excel(excel_path)

# Lookup: (split, sample name) -> predicted regional label, all coerced to str.
# Duplicate (split, name) rows resolve to the last occurrence.
cat_lookup = {
    (str(split), str(name)): str(predicted)
    for split, name, predicted in zip(df_cat["split"], df_cat["name"], df_cat["predicted"])
}
print(f"[LOAD] Loaded {len(cat_lookup)} categorization entries from {excel_path}")

In [None]:
# --- class constants ---
TISSUES3 = ['Normal', 'Follicular', 'Papillary']
CLASS_TO_ID3 = dict((cls, idx) for idx, cls in enumerate(TISSUES3))
TARGET_TUMOR = "Papillary"   # target tumor class; determines FILTER_TUMOR below
TARGET_ID = CLASS_TO_ID3[TARGET_TUMOR]

# ---- Channel selection ----
# Channels to REMOVE, written 1-based for readability and stored 0-based.
REMOVE_VOXEL_CHANNELS = [ch - 1 for ch in (1, 2, 4, 7, 9, 11, 12, 14, 16, 17, 19)]

# Optional explicit 0-based list of channels to KEEP; when not None it takes
# precedence over REMOVE_VOXEL_CHANNELS.
KEEP_VOXEL_CHANNELS = None


# ===========================================
# Excel-based regional filtering + voxel-only augmentation
# ===========================================
# The opposite tumor class, whose predicted samples are filtered out for TARGET_TUMOR.
FILTER_TUMOR = {"Papillary": "Follicular", "Follicular": "Papillary"}.get(TARGET_TUMOR, "")

FILTER_VALTEST = True
assert FILTER_TUMOR in ("Follicular", "Papillary")


# ===========================================
# 1) Filtering using Excel-based predictions
# ===========================================
def filter_out_by_category_from_excel(samples, split_name, remove_class, lookup):
    """
    Partition samples using precomputed regional predictions.

    The prediction for each sample is ``lookup[(split_name, str(d['name']))]``.
    A missing 'name' is treated as ``'unknown_sample'``. A sample is dropped
    only when a prediction exists and equals ``remove_class``. Samples with
    no entry, or an entry of None, are kept.

    Args:
        samples (list of dict): Sample dictionaries (e.g. train_combined).
        split_name (str): Split label used in the lookup key, e.g. 'train',
            'val' or 'test'.
        remove_class (str): Predicted label to filter out, e.g. 'Follicular'.
        lookup (dict): Mapping (split_name, sample_name) -> predicted label.

    Returns:
        tuple: (kept_samples, dropped_samples), both lists in input order.
    """
    kept, dropped = [], []
    for sample in samples:
        key = (split_name, str(sample.get("name", "unknown_sample")))
        predicted = lookup.get(key)
        is_removed = predicted is not None and predicted == remove_class
        (dropped if is_removed else kept).append(sample)
    return kept, dropped


# Build the filtered splits from the Excel lookup. Train is always filtered;
# val/test are filtered only when FILTER_VALTEST is set, otherwise they pass
# through unchanged.
train_px_kept, train_px_dropped = filter_out_by_category_from_excel(
    train_combined, "train", FILTER_TUMOR, cat_lookup
)

if not FILTER_VALTEST:
    val_px, test_px = val_combined, test_combined
else:
    val_px_kept, val_px_dropped = filter_out_by_category_from_excel(
        val_combined, "val", FILTER_TUMOR, cat_lookup
    )
    test_px_kept, test_px_dropped = filter_out_by_category_from_excel(
        test_combined, "test", FILTER_TUMOR, cat_lookup
    )
    val_px, test_px = val_px_kept, test_px_kept

print(
    "[FILTER/EXCEL] kept: "
    f"train={len(train_px_kept)} val={len(val_px)} test={len(test_px)}"
)
print(
    f"[FILTER/EXCEL] dropped({FILTER_TUMOR}-pred): "
    f"train={len(train_px_dropped)} "
    f"val={len(val_px_dropped) if FILTER_VALTEST else 0} "
    f"test={len(test_px_dropped) if FILTER_VALTEST else 0}"
)

# Augmentation for model

In [None]:
# Regional-categorization parameters (the values picked earlier), forwarded as
# keyword arguments to categorize_sample_regional by filter_out_by_category.
REGIONAL_PARAMS = {
    "black_tol": 5,
    "win": 45,
    "stride": 20,
    "frac_thresh": 0.80,
    "min_tissue_px_per_win": 30,
}

# ===========================================
# 1) Channel selection helpers
# ===========================================
def _sanitize_indices(n_channels, keep, remove):
    """
    Computes the final set of channel indices to retain, based on keep/remove
    specifications and the total number of available channels.

    Args:
        n_channels (int): Total number of channels in the input array.
        keep (list of int or None): Explicit list of channel indices to keep
            (0-based). If not None, this overrides `remove`.
        remove (list of int or None): List of channel indices to exclude
            (0-based). Used only if `keep` is None.

    Returns:
        list of int: Sorted list of unique channel indices that should be
        retained, restricted to the range [0, n_channels).
    """
    if keep is not None:
        keep = sorted(int(i) for i in keep if 0 <= int(i) < n_channels)
        return keep
    # build keep from remove
    remove = set(int(i) for i in (remove or []) if 0 <= int(i) < n_channels)
    return [i for i in range(n_channels) if i not in remove]


def _apply_channel_filter(arr, keep_idx):
    """
    Applies a channel-selection filter to a voxel or image array.

    Args:
        arr (numpy.ndarray): Input array of shape [H, W, C] or [H, W]. If 2D,
            the array is returned unchanged (no channel dimension to filter).
        keep_idx (list of int): List of channel indices (0-based) to retain
            along the last axis.

    Returns:
        numpy.ndarray: Array with the same height and width as `arr`, but with
        only the channels in `keep_idx` preserved when `arr` is 3D.

    Raises:
        ValueError: If `arr` has a dimensionality other than 2 or 3.
    """
    arr = np.asarray(arr)
    if arr.ndim == 2:
        return arr  # nothing to filter if no channel dim
    if arr.ndim != 3:
        raise ValueError(f"Expected [H,W,C], got shape {arr.shape}")
    if len(keep_idx) == arr.shape[-1]:
        return arr
    return arr[..., keep_idx]


def filter_out_by_category(samples, px_model_all, remove_class: str):
    """
    Keep only samples whose regional category differs from `remove_class`.

    Each sample is categorized with
    ``categorize_sample_regional(px_model_all, sample, **REGIONAL_PARAMS)``.
    NOTE(review): `categorize_sample_regional` is not defined in this part of
    the notebook; it must already exist in the kernel when this is called.

    Args:
        samples (list of dict): Sample dictionaries to filter.
        px_model_all: Model object forwarded to `categorize_sample_regional`
            (described upstream as a PCA-based pixel classifier).
        remove_class (str): Label to exclude, e.g. 'Papillary' or 'Follicular'.

    Returns:
        list of dict: Samples, in input order, whose category != remove_class.
    """
    return [
        sample
        for sample in samples
        if categorize_sample_regional(px_model_all, sample, **REGIONAL_PARAMS) != remove_class
    ]

# ===========================================
# 2) Geometric aug (centered rotate + mild zoom + optional flips/noise)
# ===========================================
def _build_centered_rotation_transform(height, width, angle_rad):
    """
    Build the 8-element projective transform for a rotation about the image
    center, in the layout ImageProjectiveTransform ops expect.

    Each output pixel p samples the input at R·(p - c) + c, where
    R = [[cos, -sin], [sin, cos]] and c = ((W-1)/2, (H-1)/2). The translation
    terms are therefore c - R·c.

    Args:
        height (int or tf.Tensor): Image height in pixels.
        width (int or tf.Tensor): Image width in pixels.
        angle_rad (float or tf.Tensor): Rotation angle in radians; positive
            values rotate counterclockwise.

    Returns:
        tf.Tensor: Length-8 vector [a0, a1, a2, a3, a4, a5, 0, 0].
    """
    theta = tf.cast(angle_rad, tf.float32)
    cos_t = tf.math.cos(theta)
    sin_t = tf.math.sin(theta)
    center_x = (tf.cast(width, tf.float32) - 1.0) / 2.0
    center_y = (tf.cast(height, tf.float32) - 1.0) / 2.0
    shift_x = -cos_t * center_x + sin_t * center_y + center_x
    shift_y = -sin_t * center_x - cos_t * center_y + center_y
    return tf.stack([cos_t, -sin_t, shift_x, sin_t, cos_t, shift_y, 0.0, 0.0], axis=0)


def _rotate_any(img, angle_rad, interpolation="bilinear", fill_mode="REFLECT"):
    """
    Rotates an image by an arbitrary angle about its center using a projective
    transform, with configurable interpolation and fill mode.

    Args:
        img (numpy.ndarray or tf.Tensor): Input image of shape [H, W, C] or [H, W].
        angle_rad (float or tf.Tensor): Rotation angle in radians.
        interpolation (str): 'bilinear' (any string starting with it,
            case-insensitive); anything else selects nearest-neighbor.
        fill_mode (str): Fill mode passed to ImageProjectiveTransformV3, e.g.,
            'REFLECT', 'CONSTANT', 'NEAREST'. `fill_value` is 0.0.

    Returns:
        tf.Tensor: float32 rotated image with the same shape (and rank) as
        the input.
    """
    img = tf.convert_to_tensor(img, dtype=tf.float32)
    # ImageProjectiveTransformV3 requires a 4-D [N, H, W, C] batch, so a 2-D
    # [H, W] input gets a temporary channel axis that is removed at the end.
    add_channel = img.shape.rank == 2
    if add_channel:
        img = tf.expand_dims(img, axis=-1)
    H = tf.shape(img)[0]
    W = tf.shape(img)[1]
    transform = tf.reshape(_build_centered_rotation_transform(H, W, angle_rad), [1, 8])
    interp = "BILINEAR" if interpolation.lower().startswith("bilinear") else "NEAREST"
    out = tf.raw_ops.ImageProjectiveTransformV3(
        images=tf.expand_dims(img, axis=0),
        transforms=transform,
        output_shape=tf.stack([H, W]),
        interpolation=interp,
        fill_mode=fill_mode,
        fill_value=0.0,
    )
    out = tf.squeeze(out, axis=0)
    if add_channel:
        out = tf.squeeze(out, axis=-1)
    return out


def augment_triplet(image, image_cutoff, mask, out_size=256, noise_std=0.001,
                    max_angle_deg=15.0, zoom_range=(0.90, 1.10)):
    """
    Applies a matched geometric and noise augmentation to an image triplet:
      - voxel image
      - cutoff image
      - binary mask

    The same random flips (p=0.5 each), rotation, and zoom are applied to all
    three inputs. The outputs are center-cropped/padded to out_size×out_size.
    With p=0.5, Gaussian noise is then added to the two images (never to the
    mask). Randomness comes from TensorFlow's global RNG.

    NOTE(review): build_vox_inputs later divides voxels by 255, which suggests
    the images arriving here are on a 0–255 scale. If so, the default
    noise_std=0.001 is negligible relative to pixel values. Confirm this is
    intended, or pass a larger `noise_std`.

    Args:
        image (numpy.ndarray or tf.Tensor): Voxel image of shape [H, W, C].
        image_cutoff (numpy.ndarray or tf.Tensor): Cutoff voxel image of shape
            [H, W, C].
        mask (numpy.ndarray or tf.Tensor): Mask of shape [H, W] or [H, W, 1].
        out_size (int): Side length of the square output. Defaults to 256.
        noise_std (float): Std-dev of the additive Gaussian noise.
            Defaults to 0.001.
        max_angle_deg (float): Rotation angle drawn uniformly from
            [-max_angle_deg, max_angle_deg] degrees. Defaults to 15.
        zoom_range (tuple of float): (low, high) range of the uniform zoom
            factor. Defaults to (0.90, 1.10).

    Returns:
        tuple:
            - tf.Tensor: Augmented voxel image of shape [out_size, out_size, C].
            - tf.Tensor: Augmented cutoff image of shape [out_size, out_size, C].
            - tf.Tensor: Augmented mask of shape [out_size, out_size]. Nearest-
              neighbor interpolation keeps its label values intact.
    """
    image        = tf.cast(image, tf.float32)
    image_cutoff = tf.cast(image_cutoff, tf.float32)
    mask         = tf.cast(mask, tf.float32)

    # Give the mask a channel axis so it goes through the same image ops.
    if mask.shape.rank == 2:
        mask = tf.expand_dims(mask, axis=-1)

    # Draw every random parameter up front so all three inputs share one transform.
    max_angle_rad = max_angle_deg * np.pi / 180.0
    flip_lr   = tf.random.uniform([], 0.0, 1.0)
    flip_ud   = tf.random.uniform([], 0.0, 1.0)
    angle_rad = tf.random.uniform([], -max_angle_rad, max_angle_rad)
    zoom      = tf.random.uniform([], zoom_range[0], zoom_range[1])
    add_noise = tf.random.uniform([], 0.0, 1.0)

    # flips
    if flip_lr > 0.5:
        image        = tf.image.flip_left_right(image)
        image_cutoff = tf.image.flip_left_right(image_cutoff)
        mask         = tf.image.flip_left_right(mask)
    if flip_ud > 0.5:
        image        = tf.image.flip_up_down(image)
        image_cutoff = tf.image.flip_up_down(image_cutoff)
        mask         = tf.image.flip_up_down(mask)

    # rotation (nearest for the mask so labels stay discrete)
    image        = _rotate_any(image,        angle_rad, "bilinear", "REFLECT")
    image_cutoff = _rotate_any(image_cutoff, angle_rad, "bilinear", "REFLECT")
    mask         = _rotate_any(mask,         angle_rad, "nearest",  "REFLECT")

    # zoom
    H = tf.shape(image)[0]
    W = tf.shape(image)[1]
    new_h = tf.cast(tf.cast(H, tf.float32) * zoom, tf.int32)
    new_w = tf.cast(tf.cast(W, tf.float32) * zoom, tf.int32)
    image        = tf.image.resize(image,        [new_h, new_w])
    image_cutoff = tf.image.resize(image_cutoff, [new_h, new_w])
    mask         = tf.image.resize(mask, [new_h, new_w],
                                   method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # center crop/pad to out_size x out_size (zoom-in crops, zoom-out zero-pads)
    image        = tf.image.resize_with_crop_or_pad(image,        out_size, out_size)
    image_cutoff = tf.image.resize_with_crop_or_pad(image_cutoff, out_size, out_size)
    mask         = tf.image.resize_with_crop_or_pad(mask,         out_size, out_size)

    # noise: separate tensors to avoid channel-shape mismatch
    if add_noise > 0.5:
        noise_img = tf.random.normal(
            tf.shape(image), mean=0.0, stddev=noise_std, dtype=tf.float32
        )
        noise_cut = tf.random.normal(
            tf.shape(image_cutoff), mean=0.0, stddev=noise_std, dtype=tf.float32
        )
        image        = image + noise_img
        image_cutoff = image_cutoff + noise_cut

    return image, image_cutoff, tf.squeeze(mask, axis=-1)

# --- simple control over aug counts ---
POS_AUGS = 3
NEG_AUGS = 1


def _has_positive(msk):
    """
    Checks whether a mask contains any non-zero (positive) pixels.

    Args:
        msk (numpy.ndarray): Input mask of shape [H, W] or [H, W, 1].

    Returns:
        bool: True if the mask contains at least one positive pixel,
        False otherwise.
    """
    m = np.asarray(msk)
    if m.ndim == 3 and m.shape[-1] == 1:
        m = m[..., 0]
    return np.sum(m) > 0


def _aug_mask_2d(raw_mask, hw):
    """Return `raw_mask` as a float32 [H, W] array; a missing/None mask becomes zeros of shape `hw`."""
    if raw_mask is None:
        # samples_to_voxels stores 'mask': None when no slice had a mask.
        # np.asarray(None, np.float32) would yield a 0-d NaN array and crash
        # the resize below.
        return np.zeros(hw, np.float32)
    m = np.asarray(raw_mask, np.float32)
    if m.ndim == 3 and m.shape[-1] == 1:
        # Drop a trailing channel so the later m[..., None] is [H, W, 1],
        # not a 4-D tensor that tf.image.resize would treat as a batch.
        m = m[..., 0]
    return m


def _aug_entry(src, voxel, cutoff, mask, keep_vox, keep_cut):
    """Copy `src`, swap in the given arrays and channel metadata, and drop stale pixel-model fields."""
    entry = {
        **src,
        "grayscale_voxel": voxel,
        "grayscale_image_cutoff_voxel": cutoff,
        "mask": mask,
        # helpful metadata for traceability
        "voxel_channels_kept": keep_vox,
        "cutoff_channels_kept": keep_cut,
    }
    for k in ("px_probs", "px_map", "px_features"):
        entry.pop(k, None)
    return entry


def make_augmented_copies(samples, pos_augs=POS_AUGS, neg_augs=NEG_AUGS):
    """
    Creates resized base copies and multiple augmented versions of each sample,
    applying voxel-only geometric augmentation and channel selection.

    Workflow per sample:
        1) Apply channel selection to grayscale voxel and cutoff voxel using
           KEEP_VOXEL_CHANNELS / REMOVE_VOXEL_CHANNELS.
        2) Resize voxel, cutoff, and mask to 256×256 (nearest for the mask).
           A missing or None mask is treated as all zeros; a [H, W, 1] mask is
           reduced to [H, W].
        3) Store a base (non-augmented) copy with updated fields.
        4) If the resized mask contains positive pixels, generate `pos_augs`
           augmented copies via `augment_triplet`; otherwise generate `neg_augs`.
        5) Remove any stale pixel-level fields ('px_probs', 'px_map', 'px_features').

    Args:
        samples (list of dict): Original sample dictionaries, each containing:
            - 'grayscale_voxel'
            - 'grayscale_image_cutoff_voxel'
            - optional 'mask' (may be None)
        pos_augs (int): Number of augmentations for positive (tumor-containing)
            masks.
        neg_augs (int): Number of augmentations for negative (no tumor) masks.

    Returns:
        list of dict: For each input sample, its resized base copy followed by
        its augmented variants. Each dict carries all original keys (minus
        the stale px_* fields) plus:
            - 'grayscale_voxel'
            - 'grayscale_image_cutoff_voxel'
            - 'mask' ([256, 256] float32)
            - 'voxel_channels_kept'
            - 'cutoff_channels_kept'
    """
    out = []
    for d in samples:
        img = np.asarray(d["grayscale_voxel"], np.float32)              # [H,W,Cv]
        cut = np.asarray(d["grayscale_image_cutoff_voxel"], np.float32) # [H,W,Cc]
        msk = _aug_mask_2d(d.get("mask"), img.shape[:2])                # [H,W]

        # ---- channel filtering (pre-augmentation) ----
        # The cutoff voxel uses the same KEEP/REMOVE lists as the grayscale voxel.
        Cv = img.shape[-1] if img.ndim == 3 else 1
        Cc = cut.shape[-1] if cut.ndim == 3 else 1
        keep_vox = _sanitize_indices(Cv, KEEP_VOXEL_CHANNELS, REMOVE_VOXEL_CHANNELS)
        keep_cut = _sanitize_indices(Cc, KEEP_VOXEL_CHANNELS, REMOVE_VOXEL_CHANNELS)
        img = _apply_channel_filter(img, keep_vox)   # [H,W,Cv']
        cut = _apply_channel_filter(cut, keep_cut)   # [H,W,Cc']

        # resize to 256x256 once (voxel-only augmentation target)
        img0 = tf.image.resize(img, [256, 256]).numpy()
        cut0 = tf.image.resize(cut, [256, 256]).numpy()
        msk0 = tf.image.resize(
            msk[..., None], [256, 256],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        ).numpy()[..., 0]

        # base (unaltered except resize)
        out.append(_aug_entry(d, img0, cut0, msk0, keep_vox, keep_cut))

        # augmented copies (voxel-only)
        n_aug = pos_augs if _has_positive(msk0) else neg_augs
        for _ in range(n_aug):
            ai, ac, am = augment_triplet(img0, cut0, msk0)
            out.append(_aug_entry(d, ai.numpy(), ac.numpy(), am.numpy(), keep_vox, keep_cut))

    return out


# ---- Voxel-only augmentation of the Excel-filtered splits ----
# Training samples get extra augmented copies; val/test only receive their
# resized base copy (zero augmentations).
train_px_aug = make_augmented_copies(train_px_kept, pos_augs=POS_AUGS, neg_augs=NEG_AUGS)
val_px_aug = make_augmented_copies(val_px, pos_augs=0, neg_augs=0)
test_px_aug = make_augmented_copies(test_px, pos_augs=0, neg_augs=0)


# Model

In [None]:
# ===========================================
# Voxel-only SE-UNet pipeline
# (channel selection & regional gating applied upstream)
# ===========================================
# ------------------------------
# CONFIG
# ------------------------------
# Reporting-only flag; any gating of samples already happened upstream.
KEEP_ANAPLASTIC_WHEN_TARGET = True
# Only the two tumor classes handled by the filtering step are supported.
assert TARGET_TUMOR in ("Papillary", "Follicular")

# ------------------------------
# Masks & input tensors (expects 256×256 from augmentation)
# ------------------------------
def ensure_mask(d, H, W):
    """
    Constructs a tumor mask for a given sample, ensuring a standardized
    shape of [H, W, 1] with float32 values in {0, 1}.

    For Normal samples, this returns an all-zero mask. For other samples it
    binarizes the stored 'mask' field (value > 0 -> 1). A missing or None
    mask yields all zeros. The mask is resized with nearest-neighbor
    interpolation when its shape is not (H, W).

    Args:
        d (dict): Sample dictionary containing at least 'tissue_type' and
            optionally 'mask' ([h, w] or [h, w, 1], any numeric dtype, or None).
        H (int): Target mask height in pixels.
        W (int): Target mask width in pixels.

    Returns:
        numpy.ndarray: Binary tumor mask of shape (H, W, 1), dtype float32,
        where 1 indicates tumor and 0 indicates background.
    """
    if d.get("tissue_type", "") == "Normal":
        m = np.zeros((H, W), np.uint8)
    else:
        raw = d.get("mask")
        if raw is None:
            # Covers both a missing key and an explicit 'mask': None;
            # np.asarray(None, np.uint8) would raise a TypeError.
            m = np.zeros((H, W), np.uint8)
        else:
            m = np.asarray(raw)
            if m.ndim == 3 and m.shape[-1] == 1:
                m = m[..., 0]
            # Binarize on the original values. Casting floats to uint8 first
            # would truncate fractional positives (0 < v < 1) to 0.
            m = (m > 0).astype(np.uint8)
            if m.shape != (H, W):
                # [..., 0] rather than squeeze() so H == 1 or W == 1 survive.
                m = tf.image.resize(
                    m[..., None], (H, W),
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
                ).numpy()[..., 0]
    return (m > 0).astype(np.float32)[..., None]


def build_vox_inputs(samples):
    """
    Stack voxels and their binary tumor masks into model-ready arrays.

    Args:
        samples (list of dict): Each entry needs 'grayscale_voxel' of shape
            [H, W, C] and 'tissue_type'; 'mask' is optional (see ensure_mask).
            All voxels must share the same shape. At least one sample is
            required because np.stack rejects an empty list.

    Returns:
        tuple:
            - numpy.ndarray: X_vox, shape (N, H, W, C); each voxel divided by 255.
            - numpy.ndarray: Y, shape (N, H, W, 1); float32 masks from ensure_mask.
    """
    pairs = []
    for sample in samples:
        voxel = np.asarray(sample["grayscale_voxel"], np.float32)
        height, width, _channels = voxel.shape
        pairs.append((voxel / 255.0, ensure_mask(sample, height, width)))
    X_vox = np.stack([scaled for scaled, _ in pairs], 0)
    Y = np.stack([mask for _, mask in pairs], 0)
    return X_vox, Y

# ------------------------------
# Metrics & Losses
# ------------------------------
# BCE on probabilities with reduction="none". Keras still averages over the
# last axis, but leaves the batch/spatial axes unreduced for the caller.
_bce_none = tf.keras.losses.BinaryCrossentropy(reduction="none", from_logits=False)


def dice_coef(y_true, y_pred, smooth=1e-6, empty_score=1.0, empty_pred_thresh=0.0):
    """
    Computes the mean soft Dice coefficient over a batch, treating samples
    whose ground truth and prediction are both empty as a special case.

    Emptiness of the prediction is tested on the raw (unclipped) values. The
    Dice term itself uses predictions clipped to [1e-6, 1 - 1e-6], whose
    per-sample sum can never be 0, so testing emptiness on them made the
    `empty_score` branch unreachable.

    Args:
        y_true (tf.Tensor): Ground-truth masks of shape (N, H, W, 1).
        y_pred (tf.Tensor): Predicted masks (probabilities) of shape (N, H, W, 1).
        smooth (float): Small constant to avoid division by zero.
        empty_score (float): Dice score assigned to samples where both
            prediction and ground truth are empty.
        empty_pred_thresh (float): A prediction counts as empty when every raw
            value is <= this threshold. The default 0.0 means "exactly zero"
            for non-negative predictions; e.g. 0.5 treats "no pixel above 0.5"
            as empty.

    Returns:
        tf.Tensor: Scalar mean Dice coefficient over the batch.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred_raw = tf.cast(y_pred, tf.float32)
    y_pred = tf.clip_by_value(y_pred_raw, 1e-6, 1 - 1e-6)
    t = tf.reshape(y_true, (tf.shape(y_true)[0], -1))
    p = tf.reshape(y_pred, (tf.shape(y_pred)[0], -1))
    p_raw = tf.reshape(y_pred_raw, (tf.shape(y_pred_raw)[0], -1))
    inter = tf.reduce_sum(t * p, axis=1)
    den = tf.reduce_sum(t + p, axis=1)
    dice = (2.0 * inter + smooth) / (den + smooth)

    both_empty = tf.logical_and(
        tf.equal(tf.reduce_sum(t, axis=1), 0.0),
        tf.reduce_all(p_raw <= empty_pred_thresh, axis=1),
    )
    # Cast so an int empty_score (e.g. 1) cannot cause a dtype clash in tf.where.
    dice = tf.where(both_empty, tf.fill(tf.shape(dice), tf.cast(empty_score, dice.dtype)), dice)
    return tf.reduce_mean(dice)


def dice_nonempty(y_true, y_pred, smooth=1e-6):
    """
    Mean soft Dice restricted to the batch samples whose ground truth
    contains at least one positive pixel.

    Predictions are clipped to [1e-6, 1 - 1e-6] before scoring.

    NOTE(review): when used as a compiled Keras metric, a batch with no
    tumor-containing sample contributes 1.0 (assuming Keras averages the
    returned scalar per batch). Confirm this is acceptable for the
    `val_dice_nonempty` monitor.

    Args:
        y_true (tf.Tensor): Ground-truth masks of shape (N, H, W, 1).
        y_pred (tf.Tensor): Predicted probabilities of shape (N, H, W, 1).
        smooth (float): Small constant to avoid division by zero.

    Returns:
        tf.Tensor: Scalar mean Dice over the non-empty samples, or 1.0 when
        the batch has none.
    """
    truth = tf.cast(y_true, tf.float32)
    probs = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-6, 1 - 1e-6)

    # One row per sample so every reduction is per image.
    truth_flat = tf.reshape(truth, (tf.shape(truth)[0], -1))
    probs_flat = tf.reshape(probs, (tf.shape(probs)[0], -1))

    overlap = tf.reduce_sum(truth_flat * probs_flat, axis=1)
    total = tf.reduce_sum(truth_flat + probs_flat, axis=1)
    per_sample = (2.0 * overlap + smooth) / (total + smooth)

    has_tumor = tf.reduce_sum(truth_flat, axis=1) > 0.0
    kept = tf.boolean_mask(per_sample, has_tumor)

    def _mean_of_kept():
        return tf.reduce_mean(kept)

    def _no_positive_samples():
        return tf.constant(1.0)

    return tf.cond(tf.size(kept) > 0, _mean_of_kept, _no_positive_samples)


def iou_coef(y_true, y_pred, smooth=1e-6):
    """
    Mean soft Intersection-over-Union (Jaccard) over a batch of masks.

    Predictions are clipped to [1e-6, 1 - 1e-6]. The union is computed as
    sum(t + p) - sum(t * p) per sample.

    Args:
        y_true (tf.Tensor): Ground-truth masks of shape (N, H, W, 1).
        y_pred (tf.Tensor): Predicted probabilities of shape (N, H, W, 1).
        smooth (float): Small constant to avoid division by zero.

    Returns:
        tf.Tensor: Scalar mean IoU across all samples in the batch.
    """
    gt = tf.cast(y_true, tf.float32)
    pr = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-6, 1 - 1e-6)

    gt_rows = tf.reshape(gt, (tf.shape(gt)[0], -1))
    pr_rows = tf.reshape(pr, (tf.shape(pr)[0], -1))

    overlap = tf.reduce_sum(gt_rows * pr_rows, axis=1)
    union = tf.reduce_sum(gt_rows + pr_rows, axis=1) - overlap
    per_sample = (overlap + smooth) / (union + smooth)
    return tf.reduce_mean(per_sample)


def make_weighted_bce(pos_weight_scalar):
    """
    Build a per-pixel weighted binary cross-entropy loss.

    Positive (tumor) pixels get weight `pos_weight_scalar` and negative pixels
    get weight 1. For each image, the loss is the weighted mean of per-pixel
    BCE (weighted sum / weight sum). The per-image values are then averaged
    over the batch.

    Args:
        pos_weight_scalar (float): Weight for positive pixels; values > 1
            increase the contribution of tumor pixels.

    Returns:
        callable: `weighted_bce(y_true, y_pred)` returning a scalar tensor,
        suitable for `model.compile(loss=...)`.
    """
    pos_weight = tf.constant(float(pos_weight_scalar), dtype=tf.float32)

    def weighted_bce(y_true, y_pred):
        target = tf.cast(y_true, tf.float32)
        probs = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-6, 1 - 1e-6)
        # Per-pixel BCE from `_bce_none`. The squeezed weights below assume
        # it has shape (N, H, W), i.e. the channel axis was reduced — TODO confirm.
        pixel_loss = _bce_none(target, probs)
        # 1 for negatives, pos_weight for positives.
        pixel_weight = 1.0 + (pos_weight - 1.0) * tf.squeeze(target, -1)
        weighted_total = tf.reduce_sum(pixel_weight * pixel_loss, axis=[1, 2])
        weight_total = tf.reduce_sum(pixel_weight, axis=[1, 2]) + 1e-6
        return tf.reduce_mean(weighted_total / weight_total)

    return weighted_bce


def tversky_index(y_true, y_pred, alpha=0.7, beta=0.3, smooth=1e-6):
    """
    Batch-mean Tversky index, an asymmetric generalization of Dice/IoU.

    Per sample: TI = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth),
    with soft TP/FP/FN summed over axes (H, W, C). Here `alpha` multiplies
    false positives and `beta` multiplies false negatives. With the defaults
    (0.7 / 0.3), false positives are penalized more heavily than false
    negatives.
    NOTE(review): many references put the larger weight on FN to favor
    recall — confirm this precision-leaning trade-off is intended.

    Args:
        y_true (tf.Tensor): Ground-truth masks of shape (N, H, W, 1).
        y_pred (tf.Tensor): Predicted probabilities of shape (N, H, W, 1).
        alpha (float): Weight for false positives.
        beta (float): Weight for false negatives.
        smooth (float): Small constant to avoid division by zero.

    Returns:
        tf.Tensor: Scalar mean Tversky index across the batch.
    """
    gt = tf.cast(y_true, tf.float32)
    pr = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-6, 1 - 1e-6)
    axes = [1, 2, 3]

    true_pos = tf.reduce_sum(gt * pr, axis=axes)
    false_pos = tf.reduce_sum((1.0 - gt) * pr, axis=axes)
    false_neg = tf.reduce_sum(gt * (1.0 - pr), axis=axes)

    numerator = true_pos + smooth
    denominator = true_pos + alpha * false_pos + beta * false_neg + smooth
    return tf.reduce_mean(numerator / denominator)


def focal_tversky_loss(alpha=0.7, beta=0.3, gamma=0.75):
    """
    Build a Focal Tversky loss: (1 - TI)^gamma.

    `tversky_index` already returns the batch-mean Tversky index, so the
    exponent is applied once to that batch-level mean. It is not applied per
    sample and then averaged. The focusing therefore acts on the batch as a
    whole rather than re-weighting individual hard images.
    NOTE(review): if per-image focusing is intended, the power must be applied
    to per-sample indices before averaging — confirm.

    Args:
        alpha (float): Weight for false positives in the Tversky index.
        beta (float): Weight for false negatives in the Tversky index.
        gamma (float): Exponent applied to (1 - TI). With gamma < 1, small
            residual errors are amplified relative to a linear loss.

    Returns:
        callable: `_loss(y_true, y_pred)` returning the scalar
        (1 - mean_batch_TI) ** gamma.
    """
    def _loss(y_true, y_pred):
        # Scalar: Tversky index already averaged over the batch.
        t = tversky_index(y_true, y_pred, alpha=alpha, beta=beta)
        return tf.pow(1.0 - t, gamma)

    return _loss


def make_combined_loss(pos_weight_scalar, bce_weight=0.5, alpha=0.7, beta=0.3, gamma=0.75):
    """
    Build a composite loss mixing weighted BCE and Focal Tversky.

        loss = bce_weight * weighted_bce + (1 - bce_weight) * focal_tversky

    With the defaults, this is exactly the original equal mix:
    0.5 * weighted_bce + 0.5 * focal_tversky(alpha=0.7, beta=0.3, gamma=0.75).

    Args:
        pos_weight_scalar (float): Positive-pixel weight for the weighted BCE.
        bce_weight (float): Mixing weight of the BCE term in [0, 1]; the Focal
            Tversky term gets 1 - bce_weight.
        alpha (float): Tversky false-positive weight.
        beta (float): Tversky false-negative weight.
        gamma (float): Focal exponent applied to (1 - Tversky index).

    Returns:
        callable: `_loss(y_true, y_pred)` returning a scalar tensor.

    Raises:
        ValueError: If `bce_weight` is outside [0, 1].
    """
    if not 0.0 <= bce_weight <= 1.0:
        raise ValueError(f"bce_weight must be in [0, 1], got {bce_weight}")

    wbce = make_weighted_bce(pos_weight_scalar)
    ftv = focal_tversky_loss(alpha=alpha, beta=beta, gamma=gamma)
    tversky_weight = 1.0 - bce_weight

    def _loss(y_true, y_pred):
        return bce_weight * wbce(y_true, y_pred) + tversky_weight * ftv(y_true, y_pred)

    return _loss

# ------------------------------
# SE-UNet (voxel-only)
# ------------------------------
def squeeze_excite_block(x, reduction=16):
    """
    Squeeze-and-excitation channel attention.

    Global-average-pools `x` to one value per channel. The result then passes
    through a bottleneck Dense (ReLU, max(C // reduction, 4) units) and a
    Dense(C, sigmoid). The resulting per-channel gates in [0, 1] rescale `x`.

    Args:
        x (tf.Tensor): Feature map of shape (N, H, W, C); C must be static.
        reduction (int): Channel reduction factor for the bottleneck.

    Returns:
        tf.Tensor: Tensor of the same shape as `x`, channel-reweighted.
    """
    channels = x.shape[-1]
    hidden_units = max(channels // reduction, 4)

    # Squeeze: one descriptor per channel, kept 4D for broadcasting.
    gate = GlobalAveragePooling2D()(x)
    gate = Reshape((1, 1, channels))(gate)

    # Excite: bottleneck MLP producing sigmoid gates.
    gate = Dense(hidden_units, activation="relu")(gate)
    gate = Dense(channels, activation="sigmoid")(gate)

    return Multiply()([x, gate])


def enc_block_SE(x, filters, p_drop=0.10):
    """
    One encoder stage of the SE-UNet.

    Layers, in order:
        Conv2D(3x3, ReLU) -> BatchNorm -> [SE block if filters >= 128]
        -> SpatialDropout2D(p_drop) -> Conv2D(3x3, stride 2, ReLU).

    Args:
        x (tf.Tensor): Input feature map.
        filters (int): Number of convolution filters at this stage.
        p_drop (float): SpatialDropout2D rate.

    Returns:
        tuple:
            - tf.Tensor: Full-resolution features (after dropout), used as
              the skip connection.
            - tf.Tensor: Stride-2 downsampled features for the next stage.
    """
    skip = Conv2D(filters, 3, padding="same", activation="relu")(x)
    skip = BatchNormalization()(skip)

    # Channel attention only on the wider stages.
    if filters >= 128:
        skip = squeeze_excite_block(skip)

    skip = SpatialDropout2D(p_drop)(skip)

    # Learned stride-2 convolution halves H and W (instead of pooling).
    down = Conv2D(filters, 3, strides=2, padding="same", activation="relu")(skip)
    return skip, down


def dec_block_SE(x, skip, filters):
    """
    One decoder stage of the SE-UNet.

    Layers, in order:
        UpSampling2D(2x) -> Concatenate with `skip`
        -> Conv2D(3x3, ReLU) -> BatchNorm -> Conv2D(3x3, ReLU).

    Args:
        x (tf.Tensor): Features from the deeper stage.
        skip (tf.Tensor): Encoder skip features; must match the upsampled
            spatial size of `x`.
        filters (int): Number of convolution filters at this stage.

    Returns:
        tf.Tensor: Refined feature map at the skip's resolution.
    """
    upsampled = UpSampling2D((2, 2))(x)
    merged = Concatenate()([upsampled, skip])

    refined = Conv2D(filters, 3, padding="same", activation="relu")(merged)
    refined = BatchNormalization()(refined)
    return Conv2D(filters, 3, padding="same", activation="relu")(refined)


def build_unet_vox_only(vox_ch=23, input_hw=(256, 256)):
    """
    Build a 2D SE-UNet that segments tumors from voxel-only DOCI inputs.

    Architecture:
        - Conv2D(32) stem.
        - Five encoder stages (32, 64, 128, 256, 512 filters); SE attention
          on stages with >= 128 filters; stride-2 downsampling in each.
        - Bottleneck: Conv2D(1024) -> BatchNorm -> Dropout(0.2).
        - Five decoder stages with skip connections (512 ... 32 filters).
        - 1x1 Conv2D with sigmoid producing the mask ('seg_output').

    Args:
        vox_ch (int): Number of input voxel channels (C).
        input_hw (tuple[int, int]): Spatial input size (H, W). Default
            (256, 256). Both must be positive multiples of 32: five stride-2
            "same" convolutions are undone by five 2x upsamplings, and each
            upsampled map must match its skip tensor for concatenation.

    Returns:
        tf.keras.Model: Uncompiled model named "SE_UNet_VoxelOnly" with input
        shape (H, W, vox_ch) and a single-channel sigmoid output. Compile it
        before training.

    Raises:
        ValueError: If H or W is not a positive multiple of 32.
    """
    height, width = input_hw
    for dim in (height, width):
        if dim <= 0 or dim % 32 != 0:
            raise ValueError(
                f"input_hw must contain positive multiples of 32, got {tuple(input_hw)}"
            )

    vox_in = Input((height, width, vox_ch), name="vox_in")
    x0 = Conv2D(32, 3, padding="same", activation="relu")(vox_in)
    e1, p1 = enc_block_SE(x0,   32, 0.10)
    e2, p2 = enc_block_SE(p1,   64, 0.10)
    e3, p3 = enc_block_SE(p2,  128, 0.10)
    e4, p4 = enc_block_SE(p3,  256, 0.10)
    e5, p5 = enc_block_SE(p4,  512, 0.15)

    b = Conv2D(1024, 3, padding="same", activation="relu")(p5)
    b = BatchNormalization()(b)
    b = Dropout(0.2)(b)

    d5 = dec_block_SE(b,  e5, 512)
    d4 = dec_block_SE(d5, e4, 256)
    d3 = dec_block_SE(d4, e3, 128)
    d2 = dec_block_SE(d3, e2,  64)
    d1 = dec_block_SE(d2, e1,  32)

    out = Conv2D(1, 1, activation="sigmoid", name="seg_output")(d1)
    return Model(vox_in, out, name="SE_UNet_VoxelOnly")

# ------------------------------
# Build filtered datasets (already gated & augmented upstream)
# ------------------------------
# NOTE(review): val/test also come from the `*_px_aug` lists; confirm these
# evaluation splits are not augmented copies, which would bias val/test metrics.
train_f, val_f, test_f = train_px_aug, val_px_aug, test_px_aug

print(
    "Example channels:",
    *(
        f"{split_name} C={np.asarray(split[0]['grayscale_voxel']).shape[-1]}"
        for split_name, split in (("train", train_f), ("val", val_f), ("test", test_f))
    ),
)

# Stack each split into voxel inputs scaled to [0, 1] and binary masks.
Xtr_vox, Ytr = build_vox_inputs(train_f)
Xva_vox, Yva = build_vox_inputs(val_f)
Xte_vox, Yte = build_vox_inputs(test_f)

print(
    f"Kept after regional gate (Normal + {TARGET_TUMOR}"
    f"{' + Anaplastic' if KEEP_ANAPLASTIC_WHEN_TARGET else ''}): "
    f"train={len(train_f)}, val={len(val_f)}, test={len(test_f)}"
)
print(
    "Shapes:",
    "\n  train:", Xtr_vox.shape, Ytr.shape,
    "\n  val:  ", Xva_vox.shape, Yva.shape,
    "\n  test: ", Xte_vox.shape, Yte.shape,
)

# ------------------------------
# Compile & train
# ------------------------------
# BCE positive weight = negative/positive pixel ratio of the training masks,
# capped at 5 (a 0.5 positive fraction is assumed if Ytr is empty).
pos_frac = 0.5 if Ytr.size == 0 else float((Ytr > 0.5).mean())
neg_frac = 1.0 - pos_frac
pos_weight_val = min(5.0, float(neg_frac / (pos_frac + 1e-6)))
print(f"pos_frac ~ {pos_frac:.4f} | pos_weight (capped) ~ {pos_weight_val:.3f}")

model = build_unet_vox_only(vox_ch=Xtr_vox.shape[-1])
model.compile(
    optimizer=Adam(learning_rate=3e-4),
    loss=make_combined_loss(pos_weight_val),
    metrics=[dice_coef, dice_nonempty, iou_coef],
)

# Weights of the best epoch (by val_dice_nonempty) are written here.
ckpt_path = "test1.weights.h5"

# All three callbacks track val_dice_nonempty, where higher is better.
early = EarlyStopping(
    monitor="val_dice_nonempty",
    mode="max",
    patience=20,
    restore_best_weights=True,
    verbose=1,
)

ckpt = ModelCheckpoint(
    ckpt_path,
    monitor="val_dice_nonempty",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

rlrop = ReduceLROnPlateau(
    monitor="val_dice_nonempty",
    mode="max",
    factor=0.5,
    patience=6,
    min_lr=1e-6,
    verbose=1,
)

# ------------------------------
# Evaluation helper
# ------------------------------
def evaluate(y_true, y_prob, t):
    """
    Per-image Dice and IoU at a fixed threshold, summarized over the batch.

    Ground truth is binarized at > 0.5 and predictions at >= `t`. An image
    where both masks are empty scores 1.0 for Dice and IoU.

    Args:
        y_true (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1).
        y_prob (numpy.ndarray): Predicted probabilities of shape (N, H, W, 1).
        t (float): Binarization threshold for `y_prob`.

    Returns:
        tuple:
            - tuple: (dice_mean, dice_std, dice_median)
            - tuple: (iou_mean,  iou_std,  iou_median)
    """
    gt_all = (y_true > 0.5).astype(np.uint8)
    pr_all = (y_prob >= t).astype(np.uint8)

    def _image_scores(gt_img, pr_img):
        overlap = np.logical_and(gt_img, pr_img).sum()
        size_sum = gt_img.sum() + pr_img.sum()
        union = np.logical_or(gt_img, pr_img).sum()
        dice = 1.0 if size_sum == 0 else (2 * overlap) / (size_sum + 1e-6)
        iou = 1.0 if union == 0 else overlap / (union + 1e-6)
        return dice, iou

    scored = [
        _image_scores(gt_all[idx, ..., 0], pr_all[idx, ..., 0])
        for idx in range(len(gt_all))
    ]
    dices = [pair[0] for pair in scored]
    ious = [pair[1] for pair in scored]

    dice_stats = (np.mean(dices), np.std(dices), np.median(dices))
    iou_stats = (np.mean(ious), np.std(ious), np.median(ious))
    return dice_stats, iou_stats


# Training

In [None]:
# ------------------------------
# Channel policy (already applied upstream)
# ------------------------------
print(f"[ChannelFilter] Using {Xtr_vox.shape[-1]} channels.")

# ------------------------------
# Fit model
# ------------------------------
# Callbacks run in this order each epoch: checkpoint the best
# val_dice_nonempty, halve the LR on plateau, then check early stopping.
history = model.fit(
    x=Xtr_vox,
    y=Ytr,
    batch_size=2,
    epochs=100,
    shuffle=True,
    validation_data=(Xva_vox, Yva),
    callbacks=[ckpt, rlrop, early],
)

# ------------------------------
# Threshold calibration + evaluation
# ------------------------------
def _mean_dice_for_threshold(y_true, y_prob, t):
    """
    Computes the mean Dice coefficient over a batch for a given
    binarization threshold.

    Ground truth masks are binarized at 0.5, while predicted probabilities
    are binarized at threshold `t`. Dice is computed per image and then
    averaged across the batch. Empty cases (no positive pixels in either
    prediction or ground truth) receive a perfect score of 1.0.

    Args:
        y_true (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1),
            with values in {0, 1}.
        y_prob (numpy.ndarray): Predicted probabilities of shape (N, H, W, 1),
            with values in [0, 1].
        t (float): Threshold used to convert predicted probabilities into
            binary masks.

    Returns:
        float: Mean Dice coefficient across all samples in the batch at
        threshold `t`.
    """
    y_true_bin = (y_true > 0.5).astype(np.uint8)
    y_pred_bin = (y_prob >= t).astype(np.uint8)

    dices = []
    for i in range(len(y_true_bin)):
        gt = y_true_bin[i, ..., 0]
        pr = y_pred_bin[i, ..., 0]
        inter = np.logical_and(gt, pr).sum()
        den = gt.sum() + pr.sum()
        d = (2 * inter) / (den + 1e-6) if den > 0 else 1.0
        dices.append(d)
    return float(np.mean(dices)) if dices else 1.0


def best_dice_threshold(y_true, y_prob, grid=np.linspace(0.2, 0.9, 29)):
    """
    Grid-search the binarization threshold that maximizes mean per-image Dice.

    Each candidate in `grid` is scored with `_mean_dice_for_threshold`. The
    comparison is strict, so the earliest grid value wins on ties. If `grid`
    is empty, (0.5, -1.0) is returned.

    Args:
        y_true (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1).
        y_prob (numpy.ndarray): Predicted probabilities of shape (N, H, W, 1).
        grid (iterable of float): Candidate thresholds to evaluate.

    Returns:
        tuple:
            - float: Best threshold t*.
            - float: Mean Dice at t*.
    """
    best = (0.5, -1.0)
    for candidate in grid:
        score = _mean_dice_for_threshold(y_true, y_prob, candidate)
        if score > best[1]:
            best = (float(candidate), score)
    return best


# --- Use the best checkpointed weights ---
# EarlyStopping(restore_best_weights=True) swaps in its best weights only when
# it actually stops training early. If all epochs run, the in-memory model
# holds the last epoch's weights. Reload the ModelCheckpoint file so this
# cell's calibration and test metrics use the same best weights that the
# later cells load.
model.load_weights(ckpt_path)

# --- Calibrate on validation set ---
val_probs = model.predict(Xva_vox, batch_size=2)
t_star, dice_val = best_dice_threshold(Yva, val_probs)
print(f"[CAL] t*={t_star:.3f} (VAL mean Dice={dice_val:.3f})")

# --- Evaluate on test set ---
test_probs = model.predict(Xte_vox, batch_size=2)
(d_mean, d_std, d_med), (i_mean, i_std, i_med) = evaluate(Yte, test_probs, t_star)

print(f"[TEST] Dice mean={d_mean:.3f} ± {d_std:.3f} | median={d_med:.3f}")
print(f"[TEST]  IoU  mean={i_mean:.3f} ± {i_std:.3f} | median={i_med:.3f}")


# Visualization

In [None]:
# Load the best weights saved by ModelCheckpoint for the voxel-only model.
# Reuse `ckpt_path` from the training cell so the path is defined in one place.
model.load_weights(ckpt_path)

# ---------------------------------------------------------------------
# UID helper
# ---------------------------------------------------------------------
def sample_uid(d):
    """
    Return a readable identifier string for a sample dictionary.

    If one of 'uid', 'id', 'sample_id', 'image_name', 'name', 'file', 'path'
    is present and not None (checked in that order), `str(value)` is returned.

    Otherwise a fallback id is built from the shapes of `d["grayscale_voxel"]`
    and `d["grayscale_image_cutoff_voxel"]`, plus a short BLAKE2b digest of
    both arrays' contents and dtypes. Hashing the contents keeps ids distinct
    for samples that share shapes. The previous `voxel.size % 997` suffix
    depended only on the shape, so every same-shaped sample got the same id.

    Args:
        d (dict): Sample dictionary with voxel and cutoff-voxel arrays and
            optionally one of the metadata keys above.

    Returns:
        str: Identifier suitable for figure titles and logging.
    """
    import hashlib  # stdlib; only needed for the fallback path

    for key in ("uid", "id", "sample_id", "image_name", "name", "file", "path"):
        if key in d and d[key] is not None:
            return str(d[key])

    voxel = np.asarray(d["grayscale_voxel"])
    cutoff = np.asarray(d["grayscale_image_cutoff_voxel"])

    hasher = hashlib.blake2b(digest_size=6)
    for arr in (voxel, cutoff):
        # Include dtype so equal bytes under different dtypes do not collide.
        hasher.update(arr.dtype.str.encode("ascii"))
        hasher.update(arr.tobytes())

    return f"auto:{voxel.shape}-{cutoff.shape}-{hasher.hexdigest()}"


U_gt = [sample_uid(d) for d in test_f]
print("TEST shapes:", Xte_vox.shape, Yte.shape, f"uids={len(U_gt)}")

# --- Predict on test set (same tensors used for metrics) ---
test_probs = model.predict(Xte_vox, batch_size=2)

# Sanity checks: one probability map and one uid per test mask; same N, H, W.
assert test_probs.shape[0] == Yte.shape[0] == len(U_gt)
assert Xte_vox.shape[:3] == Yte.shape[:3] == test_probs.shape[:3]

GT_mask = Yte          # [N,H,W,1]
PROBS   = test_probs   # [N,H,W,1]

# Use the previously calibrated t_star if present; otherwise fall back to a
# per-image Otsu threshold on the probability map.
# NOTE(review): t_star comes from an earlier cell; confirm it was calibrated
# with the same weights loaded above.
t_star_value = globals().get("t_star", None)
use_global_thresh = t_star_value is not None

# ---------------------------------------------------------------------
# Per-sample visualization: voxel channel 0 / GT / thresholded prediction
# ---------------------------------------------------------------------
for i in range(PROBS.shape[0]):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # --- Extract data ---
    voxel_ch0 = Xte_vox[i, ..., 0]      # first channel of the voxel stack
    pt_cnn    = PROBS[i, ..., 0]        # CNN probability map
    gt        = GT_mask[i, ..., 0]      # GT mask

    # --- Threshold selection ---
    if use_global_thresh:
        t_use = float(t_star_value)
    else:
        try:
            t_use = float(threshold_otsu(pt_cnn))
        except Exception:
            # Best-effort display: fall back to 0.5 if Otsu cannot threshold this map.
            t_use = 0.5

    pred = (pt_cnn >= t_use).astype(np.uint8)

    # PANEL 1: first voxel channel
    ax[0].imshow(voxel_ch0, cmap="gray")
    ax[0].set_title("Voxel: Channel 0")
    ax[0].axis("off")

    # PANEL 2: ground-truth mask
    ax[1].imshow(gt, cmap="gray", vmin=0, vmax=1)
    ax[1].set_title("GT")
    ax[1].axis("off")

    # PANEL 3: thresholded prediction
    ax[2].imshow(pred, cmap="gray", vmin=0, vmax=1)
    ax[2].set_title(f"Pred (t={t_use:.2f})")
    ax[2].axis("off")

    fig.suptitle(f"sample: {U_gt[i]}")
    fig.tight_layout()
    plt.show()
    # Close explicitly: with non-interactive backends plt.show() does not close
    # figures, and one figure per test image would otherwise accumulate.
    plt.close(fig)



# Final Metrics

In [None]:
# ----------------------------
# Basic helpers
# ----------------------------
def _ensure_2d(mask):
    """
    Convert a mask array to a 2D uint8 array.

    Accepts inputs of shape [H,W] or [H,W,1]. Any trailing singleton channel
    dimension is removed. Values are cast to uint8.

    Args:
        mask (array-like): Input mask tensor of shape [H,W] or [H,W,1].

    Returns:
        np.ndarray: 2D uint8 mask of shape [H,W].
    """
    arr = np.asarray(mask)
    if arr.ndim == 3:
        arr = arr[..., 0]
    return arr.astype(np.uint8)


def _threshold_probs(y_prob, t):
    """
    Threshold probability maps at a fixed scalar threshold.

    Handles tensors of shape [N,H,W,1] or [N,H,W] and returns a binary mask.

    Args:
        y_prob (np.ndarray): Probability maps.
        t (float): Threshold in [0,1].

    Returns:
        np.ndarray: Binary masks (uint8) with same spatial dimensions.
    """
    arr = np.asarray(y_prob)
    if arr.ndim == 4:
        arr = arr[..., 0]
    return (arr >= float(t)).astype(np.uint8)


def dice_per_image(gt, pr, eps=1e-6):
    """
    Dice coefficient for one image: 2|G∩P| / (|G| + |P| + eps).

    Both masks are converted with `_ensure_2d`. When both are empty, the
    score is 1.0.

    Args:
        gt (array-like): Ground-truth mask [H,W] or [H,W,1].
        pr (array-like): Predicted mask [H,W] or [H,W,1].
        eps (float): Added to the denominator for numerical stability.

    Returns:
        float: Dice score.
    """
    g = _ensure_2d(gt)
    p = _ensure_2d(pr)
    size_sum = g.sum() + p.sum()
    if size_sum == 0:
        return 1.0
    overlap = np.logical_and(g, p).sum()
    return (2 * overlap) / (size_sum + eps)


def iou_per_image(gt, pr, eps=1e-6):
    """
    IoU (Jaccard index) for one image: |G∩P| / (|G∪P| + eps).

    Both masks are converted with `_ensure_2d`. When both are empty, the
    score is 1.0.

    Args:
        gt (array-like): Ground-truth mask [H,W] or [H,W,1].
        pr (array-like): Predicted mask [H,W] or [H,W,1].
        eps (float): Added to the denominator for numerical stability.

    Returns:
        float: IoU score.
    """
    g = _ensure_2d(gt)
    p = _ensure_2d(pr)
    union = np.logical_or(g, p).sum()
    if union == 0:
        return 1.0
    overlap = np.logical_and(g, p).sum()
    return overlap / (union + eps)


# ----------------------------
# Threshold calibration + simple Dice/IoU
# ----------------------------
def best_dice_threshold(y_true, y_prob, grid=np.linspace(0.2, 0.9, 29)):
    """
    Grid-search the threshold that maximizes mean per-image Dice.

    This redefines the `best_dice_threshold` from the training cell with the
    same contract. Ground truth is binarized at > 0.5 once; each candidate
    threshold binarizes `y_prob` via `_threshold_probs`, and images are scored
    with `dice_per_image`. The comparison is strict, so the earliest grid value
    wins on ties. An empty `grid` returns (0.5, -1.0).

    Args:
        y_true (np.ndarray): Ground-truth masks [N,H,W,1].
        y_prob (np.ndarray): Predicted probabilities [N,H,W,1].
        grid (iterable): Candidate thresholds.

    Returns:
        (float, float): (best_threshold, best_mean_dice); the mean Dice is 1.0
        when N == 0.
    """
    gt_bin = (y_true > 0.5).astype(np.uint8)
    n_images = len(gt_bin)
    best_t, best_d = 0.5, -1.0

    for candidate in grid:
        pr_bin = _threshold_probs(y_prob, candidate)
        scores = [dice_per_image(gt_bin[k, ..., 0], pr_bin[k]) for k in range(n_images)]
        mean_score = float(np.mean(scores)) if scores else 1.0
        if mean_score > best_d:
            best_t, best_d = float(candidate), mean_score

    return best_t, best_d


def evaluate(y_true, y_prob, t):
    """
    Mean / std / median of per-image Dice and IoU at a fixed threshold.

    This redefines the `evaluate` from the training cell with the same
    contract, but returns plain Python floats. Ground truth is binarized at
    > 0.5 and predictions at >= `t` (via `_threshold_probs`).

    Args:
        y_true (np.ndarray): Ground-truth masks [N,H,W,1].
        y_prob (np.ndarray): CNN probability maps [N,H,W,1].
        t (float): Threshold value.

    Returns:
        tuple:
          dice_stats = (mean, std, median)
          iou_stats  = (mean, std, median)
    """
    gt_bin = (y_true > 0.5).astype(np.uint8)
    pr_bin = _threshold_probs(y_prob, t)

    dice_scores, iou_scores = [], []
    for k in range(len(gt_bin)):
        gt_img, pr_img = gt_bin[k, ..., 0], pr_bin[k]
        dice_scores.append(dice_per_image(gt_img, pr_img))
        iou_scores.append(iou_per_image(gt_img, pr_img))

    def _summary(values):
        arr = np.asarray(values)
        return float(arr.mean()), float(arr.std()), float(np.median(arr))

    return _summary(dice_scores), _summary(iou_scores)


# --- Predict once on all splits ---
train_probs = model.predict(Xtr_vox, batch_size=2)
val_probs = model.predict(Xva_vox, batch_size=2)
test_probs = model.predict(Xte_vox, batch_size=2)

# Sanity checks: exactly one probability map per ground-truth mask.
assert len(train_probs) == len(Ytr)
assert len(val_probs) == len(Yva)
assert len(test_probs) == len(Yte)

# --- Calibrate threshold on validation set ---
t_star, dice_val = best_dice_threshold(Yva, val_probs)
print(f"[CAL] t*={t_star:.3f} (VAL mean Dice={dice_val:.3f})")

# --- Basic test-set evaluation at the validation-calibrated threshold ---
(d_mean, d_std, d_med), (i_mean, i_std, i_med) = evaluate(Yte, test_probs, t_star)
print(f"[TEST] Dice mean={d_mean:.3f} ± {d_std:.3f} | median={d_med:.3f}")
print(f"[TEST]  IoU  mean={i_mean:.3f} ± {i_std:.3f} | median={i_med:.3f}")


# ============================
# Rich metrics for imbalanced data
# ============================
def _boundary_band(mask, radius=2):
    """
    Band of pixels around the mask edge: binary dilation XOR binary erosion,
    each run for `radius` iterations with scipy's default structuring element.

    Used for boundary F1. An empty mask yields an empty band.
    NOTE(review): scipy's binary_erosion pads with 0 by default, so masks
    touching the image border also get band pixels along that border.

    Args:
        mask (array-like): Binary mask [H,W] or [H,W,1].
        radius (int): Number of dilation/erosion iterations; must be >= 1.

    Returns:
        np.ndarray: uint8 [H,W] mask of boundary-band pixels.

    Raises:
        ValueError: If `radius` < 1. scipy.ndimage repeats the operation until
            convergence when iterations < 1, which would turn the "band" into
            most of the image instead of a thin edge.
    """
    if radius < 1:
        raise ValueError(f"radius must be >= 1, got {radius}")
    m = _ensure_2d(mask) > 0
    if m.size == 0:
        return np.zeros_like(m, dtype=np.uint8)
    dil = binary_dilation(m, iterations=radius)
    ero = binary_erosion(m,  iterations=radius)
    return np.logical_xor(dil, ero).astype(np.uint8)


def dice_nonempty_mean(y_true, y_pred):
    """
    Mean Dice over only the images whose ground-truth mask is non-empty.

    Args:
        y_true (np.ndarray): GT masks, [N,H,W] or [N,H,W,1].
        y_pred (np.ndarray): Predicted binary masks, [N,H,W] or [N,H,W,1].

    Returns:
        (float, np.ndarray):
            mean Dice over non-empty-GT images (1.0 if there are none),
            array of their per-image Dice scores.
    """
    scores = []
    for idx in range(len(y_true)):
        gt_img = _ensure_2d(y_true[idx])
        pr_img = _ensure_2d(y_pred[idx])
        if gt_img.sum() == 0:
            continue  # empty-GT images are scored by empty_fp_penalty_mean
        scores.append(dice_per_image(gt_img, pr_img))

    mean_score = np.mean(scores) if scores else 1.0
    return mean_score, np.array(scores)


def empty_fp_penalty_mean(y_true, y_pred, tissue_masks=None):
    """
    Score false positives on images whose ground truth is empty.

    For each empty-GT image:
        score = 1 - FP_fraction,
    where FP_fraction is the number of predicted-positive pixels divided by
    the tissue pixel count (when `tissue_masks` is given, counting only FPs
    inside tissue) or by the total pixel count.

    Tissue masks are passed through `_ensure_2d`, so [H,W] and [H,W,1] are
    both accepted. Without this, an [H,W,1] tissue mask multiplied with the
    [H,W] prediction would broadcast to [H,H,W] and give a wrong FP count.

    Args:
        y_true (np.ndarray): Ground-truth masks, [N,H,W] or [N,H,W,1].
        y_pred (np.ndarray): Predicted binary masks, [N,H,W] or [N,H,W,1].
        tissue_masks (np.ndarray or None): Optional per-image tissue masks,
            [N,H,W] or [N,H,W,1]; pixels > 0 count as tissue.

    Returns:
        (float, np.ndarray):
            mean score across empty-GT images (1.0 if there are none),
            per-image scores.
    """
    scores = []
    for i in range(len(y_true)):
        gt = _ensure_2d(y_true[i])
        if gt.sum() != 0:
            continue
        pr = _ensure_2d(y_pred[i])
        # GT is empty here, so every predicted positive is a false positive.
        if tissue_masks is not None:
            tissue = _ensure_2d(np.asarray(tissue_masks[i]) > 0)
            denom = max(1, int(tissue.sum()))
            fp = int((pr * tissue).sum())
        else:
            denom = pr.size
            fp = int(pr.sum())
        scores.append(1.0 - (fp / denom))
    return (np.mean(scores) if scores else 1.0), np.array(scores)


def boundary_f1_mean(y_true, y_pred, radius=2, eps=1e-6):
    """
    Boundary F1: precision/recall of predicted boundary-band pixels against
    ground-truth boundary-band pixels (bands from `_boundary_band`).

    An image where both the GT and the prediction are empty has no boundary
    to disagree on and scores 1.0. This matches the empty/empty convention of
    `dice_per_image`, `iou_per_image` and `lesion_f1_mean`.

    Args:
        y_true (np.ndarray): Ground-truth masks, [N,H,W] or [N,H,W,1].
        y_pred (np.ndarray): Predicted binary masks, [N,H,W] or [N,H,W,1].
        radius (int): Boundary band thickness (dilation/erosion iterations).
        eps (float): Stability constant.

    Returns:
        (float, np.ndarray):
            mean boundary F1 (1.0 when there are no images),
            per-image F1 scores.
    """
    vals = []
    for i in range(len(y_true)):
        gt = _ensure_2d(y_true[i])
        pr = _ensure_2d(y_pred[i])
        if gt.sum() == 0 and pr.sum() == 0:
            # Without this, tp = fp = fn = 0 would give F1 = 0 for a correct
            # empty prediction, dragging the mean down on all-normal images.
            vals.append(1.0)
            continue
        band_gt = _boundary_band(gt, radius=radius) > 0
        band_pr = _boundary_band(pr, radius=radius) > 0
        tp = np.logical_and(band_gt, band_pr).sum()
        fp = np.logical_and(~band_gt, band_pr).sum()
        fn = np.logical_and(band_gt, ~band_pr).sum()
        prec = tp / (tp + fp + eps)
        rec  = tp / (tp + fn + eps)
        vals.append((2 * prec * rec) / (prec + rec + eps))
    mean_f1 = float(np.mean(vals)) if vals else 1.0
    return mean_f1, np.array(vals)


def lesion_f1_mean(y_true, y_pred, iou_thresh=0.5):
    """
    Lesion-level detection F1 over connected components.

    A GT component and a predicted component match when their IoU is
    >= `iou_thresh`. Matching is greedy: (GT, pred) pairs are visited from
    highest to lowest IoU, with ties kept in GT-major order, and each
    component is matched at most once. Per image: TP = matches,
    FP = unmatched predicted components, FN = unmatched GT components.
    An image with no components on either side scores 1.0; components on only
    one side score 0.0.

    Args:
        y_true (np.ndarray): GT masks, [N,H,W] or [N,H,W,1].
        y_pred (np.ndarray): Predicted binary masks, [N,H,W] or [N,H,W,1].
        iou_thresh (float): IoU threshold for declaring a match.

    Returns:
        (float, np.ndarray):
            mean lesion-level F1, per-image F1 values.
    """
    def _components(mask):
        # One uint8 image per connected foreground component (scipy `label`).
        labelled, count = label(mask > 0)
        return [(labelled == comp_id).astype(np.uint8) for comp_id in range(1, count + 1)]

    def _pairwise_iou(gt_parts, pr_parts):
        # IoU between every GT component (rows) and predicted component (cols).
        gt_area = np.array([c.sum() for c in gt_parts], dtype=np.float32)
        pr_area = np.array([c.sum() for c in pr_parts], dtype=np.float32)
        iou = np.zeros((len(gt_parts), len(pr_parts)), dtype=np.float32)
        for gi, g in enumerate(gt_parts):
            for pi, p in enumerate(pr_parts):
                overlap = np.logical_and(g, p).sum()
                union = gt_area[gi] + pr_area[pi] - overlap
                iou[gi, pi] = overlap / (union + 1e-6)
        return iou

    def _count_matches(iou):
        # A stable sort on -IoU gives descending IoU with ties in row-major
        # (GT-major) order. Once a pair fails the threshold, every later pair
        # (IoU no larger) fails too.
        taken_gt, taken_pr = set(), set()
        n_match = 0
        for flat in np.argsort(-iou, axis=None, kind="stable"):
            gi, pi = np.unravel_index(flat, iou.shape)
            if not (iou[gi, pi] >= iou_thresh):
                break
            if gi in taken_gt or pi in taken_pr:
                continue
            taken_gt.add(gi)
            taken_pr.add(pi)
            n_match += 1
        return n_match

    per_image = []
    for idx in range(len(y_true)):
        gt_parts = _components(_ensure_2d(y_true[idx]))
        pr_parts = _components(_ensure_2d(y_pred[idx]))

        if not gt_parts or not pr_parts:
            both_empty = not gt_parts and not pr_parts
            per_image.append(1.0 if both_empty else 0.0)
            continue

        tp = _count_matches(_pairwise_iou(gt_parts, pr_parts))
        fp = len(pr_parts) - tp
        fn = len(gt_parts) - tp
        prec = tp / (tp + fp + 1e-6)
        rec  = tp / (tp + fn + 1e-6)
        per_image.append((2 * prec * rec) / (prec + rec + 1e-6))

    return float(np.mean(per_image)), np.array(per_image)


def evaluate_imbalanced(y_true, y_prob, t, tissue_masks=None, boundary_radius=2,
                        lesion_iou_thresh=0.5):
    """
    Full evaluation suite for segmentation under severe class imbalance.

    Binarizes `y_prob` at threshold `t` (via `_threshold_probs`) and computes:
        - Dice on non-empty images only (`dice_nonempty_mean`)
        - Empty-image false-positive penalty (`empty_fp_penalty_mean`)
        - Balanced Dice: the plain average of the two scores above
        - Boundary F1 (`boundary_f1_mean`)
        - Lesion-level F1 over connected components (`lesion_f1_mean`)

    Args:
        y_true (np.ndarray): Ground-truth masks [N,H,W,1].
        y_prob (np.ndarray): Probability maps [N,H,W,1].
        t (float): Threshold for binarization.
        tissue_masks (np.ndarray or None):
            Optional tissue-region masks, forwarded unchanged to
            `empty_fp_penalty_mean`.
        boundary_radius (int): Thickness of the boundary band, forwarded as
            `radius` to `boundary_f1_mean`.
        lesion_iou_thresh (float): Minimum IoU for a predicted component to
            be matched to a ground-truth component in the lesion-level F1.
            Defaults to 0.5 (the value previously hard-coded here).

    Returns:
        dict: Scalar summaries under "dice_pos_mean", "empty_penalty_mean",
        "balanced_dice", "boundary_f1_mean" and "lesion_f1_mean", plus a
        "per_image" dict holding the per-image values returned by each helper.
    """
    y_pred = _threshold_probs(y_prob, t)

    dice_pos_mean_val, dice_pos_all = dice_nonempty_mean(y_true, y_pred)
    empty_pen_mean_val, empty_pen_all = empty_fp_penalty_mean(
        y_true, y_pred, tissue_masks=tissue_masks
    )
    # Equal weighting of "segment lesions well" and "stay quiet on empty images".
    balanced_dice = 0.5 * (dice_pos_mean_val + empty_pen_mean_val)

    b_f1_mean_val, b_f1_all = boundary_f1_mean(
        y_true, y_pred, radius=boundary_radius
    )
    l_f1_mean_val, l_f1_all = lesion_f1_mean(
        y_true, y_pred, iou_thresh=lesion_iou_thresh
    )

    return {
        "dice_pos_mean": float(dice_pos_mean_val),
        "empty_penalty_mean": float(empty_pen_mean_val),
        "balanced_dice": float(balanced_dice),
        "boundary_f1_mean": float(b_f1_mean_val),
        "lesion_f1_mean": float(l_f1_mean_val),
        "per_image": {
            "dice_pos": dice_pos_all,
            "empty_penalty": empty_pen_all,
            "boundary_f1": b_f1_all,
            "lesion_f1": l_f1_all,
        },
    }

# No tissue-region masks are available yet; None is forwarded to the FP-penalty metric.
train_tissue_masks = None
val_tissue_masks   = None
test_tissue_masks  = None

train_metrics = evaluate_imbalanced(Ytr, train_probs, t_star, tissue_masks=train_tissue_masks)
val_metrics   = evaluate_imbalanced(Yva, val_probs,   t_star, tissue_masks=val_tissue_masks)
test_metrics  = evaluate_imbalanced(Yte, test_probs,  t_star, tissue_masks=test_tissue_masks)

# One aligned summary line per split; keys are the scalar entries returned by evaluate_imbalanced().
SUMMARY_METRIC_KEYS = (
    "dice_pos_mean",
    "empty_penalty_mean",
    "balanced_dice",
    "boundary_f1_mean",
    "lesion_f1_mean",
)
for split_name, split_metrics in (
    ("TRAIN", train_metrics),
    ("VAL", val_metrics),
    ("TEST", test_metrics),
):
    summary = " | ".join(f"{key}={split_metrics[key]:.3f}" for key in SUMMARY_METRIC_KEYS)
    print(f"{'[' + split_name + ']':<8}{summary}")

# Filter Ablation

In [None]:
def _resolve_threshold(y_true, probs):
    """
    Determines an appropriate segmentation threshold for probability maps,
    preferring pre-calibrated global thresholds if available.

    The function checks for globally defined variables in the following order:
      1) t_star_calibrated
      2) t_star
    If neither is available or valid, it falls back to calibrating a threshold
    on the provided (y_true, probs) pair using `best_dice_threshold`, if
    defined. As a last resort, it returns 0.5.

    Args:
        y_true (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1) or
            (N, H, W), used if calibration is needed.
        probs (numpy.ndarray): Predicted probability maps of shape
            (N, H, W, 1) or (N, H, W).

    Returns:
        float: Selected threshold value in [0, 1] to be used for binarizing
        `probs` into segmentation masks.
    """
    # Prefer calibrated global thresholds if available
    if "t_star_calibrated" in globals():
        try:
            return float(t_star_calibrated)
        except Exception:
            pass
    if "t_star" in globals():
        try:
            return float(t_star)
        except Exception:
            pass
    # Otherwise calibrate on provided set (val/test) using  best_dice_threshold()
    if "best_dice_threshold" in globals():
        t_, _ = best_dice_threshold(y_true, probs)
        return float(t_)
    # Last resort
    return 0.5

# -- Helper: evaluate a probability map tensor at threshold t using evaluate() --
def _eval_with_threshold(y_true, probs, t):
    """
    Wrapper around `evaluate` to assess segmentation quality at a fixed
    threshold.

    This helper calls the global `evaluate(y_true, y_prob, t)` function and
    simply forwards its output, allowing it to be used within ablation
    experiments.

    Args:
        y_true (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1).
        probs (numpy.ndarray): Predicted probability maps of shape
            (N, H, W, 1).
        t (float): Threshold in [0, 1] used to binarize `probs`.

    Returns:
        tuple: A nested tuple containing:
            - (float, float, float): Dice mean, standard deviation, and median.
            - (float, float, float): IoU mean, standard deviation, and median.
    """
    return evaluate(y_true, probs, t)

# -- Optional: permutation within image (spatial shuffle) for a channel --
def _permute_channel_inplace(x, chan, rng):
    """
    Applies an in-place spatial permutation (shuffle) of one channel across
    all images in a batch.

    Each image's selected channel is flattened to 1D, shuffled independently
    using the provided random number generator, and then reshaped back to
    its original [H, W] structure. All other channels remain unchanged.

    Args:
        x (numpy.ndarray): Input tensor of shape (N, H, W, C) to be modified
            in-place.
        chan (int): Zero-based channel index to be permuted.
        rng (numpy.random.RandomState): Random state used to perform the
            shuffling, ensuring reproducibility.

    Returns:
        None: The function modifies `x` in place.
    """
    N, H, W, C = x.shape
    flat = x[..., chan].reshape(N, -1)
    for i in range(N):
        rng.shuffle(flat[i])
    x[..., chan] = flat.reshape(N, H, W)

def run_filter_ablation(model, X, Y, batch_size=2, mode="zero", repeats=1, seed=0,
                        channel_map_1based=None, csv_path="filter_ablation_results.csv"):
    """
    Performs per-channel ablation on the input voxel stack to estimate
    filter/channel importance for a trained segmentation model.

    Two ablation modes are supported:
      - "zero":    Set the selected channel to zero for all pixels.
      - "permute": Spatially shuffle the selected channel within each image,
                   repeating the experiment `repeats` times and averaging.

    For each channel, the function recomputes predictions, evaluates the
    model at a fixed threshold (resolved once on the unperturbed predictions
    by `_resolve_threshold`), and records Dice and IoU drops relative to the
    unperturbed baseline.

    If `channel_map_1based` is None and a global `KEEP_IDX` with exactly C
    entries exists, channels are labelled with those (0-based) indices + 1;
    otherwise labels default to [1..C].

    Args:
        model (tf.keras.Model or compatible): Trained segmentation model
            with a `.predict()` method accepting inputs of shape (N, H, W, C).
        X (numpy.ndarray): Input voxel tensor of shape (N, H, W, C) used
            for evaluation (e.g., validation or test set). Not modified.
        Y (numpy.ndarray): Ground-truth masks of shape (N, H, W, 1).
        batch_size (int, optional): Batch size for `model.predict`. Defaults to 2.
        mode (str, optional): Ablation mode, either "zero" or "permute".
            Defaults to "zero".
        repeats (int, optional): Number of repeated permutations per channel
            when `mode="permute"` (must be >= 1). Ignored in "zero" mode.
            Defaults to 1.
        seed (int, optional): Random seed for reproducible permutations.
            Defaults to 0.
        channel_map_1based (list of int or None, optional): Mapping from
            channel index (0-based) to original 1-based filter label.
        csv_path (str or None, optional): File the full results table is
            written to; pass None to skip writing.
            Defaults to 'filter_ablation_results.csv'.

    Returns:
        tuple:
            - pandas.DataFrame: One row per channel, sorted by
              (dice_drop, iou_drop) descending, with columns:
                * 'chan_idx_0b'       : 0-based channel index
                * 'chan_label_1b'     : 1-based channel label (original index)
                * 'dice_mean'         : mean Dice after ablation
                * 'dice_drop'         : Dice drop vs baseline
                * 'iou_mean'          : mean IoU after ablation
                * 'iou_drop'          : IoU drop vs baseline
                * 'baseline_dice_mean': baseline Dice (no ablation)
                * 'baseline_iou_mean' : baseline IoU (no ablation)
                * 'threshold_used'    : threshold used for evaluation
                * 'mode'              : ablation mode ("zero" or "permute")
                * 'repeats'           : the `repeats` argument as passed
            - dict: Metadata dictionary containing:
                * 'baseline': {'dice_mean', 'iou_mean', 't'}
                * 'mode': ablation mode
                * 'repeats': number of permutation repeats

    Raises:
        ValueError: If `mode` is not "zero"/"permute", or `repeats` < 1 in
            permute mode (checked before any prediction is run).
        AssertionError: If X/Y shapes disagree or `channel_map_1based` does
            not have one entry per channel.

    Side Effects:
        - Prints the top and bottom 10 filters by Dice drop and a summary.
        - Plots a bar chart of ΔDice for the most and least impactful filters.
        - Writes the full results to `csv_path` unless it is None.
    """
    assert X.ndim == 4 and Y.ndim == 4 and X.shape[:3] == Y.shape[:3]
    # Validate cheap arguments up front, before the expensive baseline prediction.
    if mode not in ("zero", "permute"):
        raise ValueError("mode must be 'zero' or 'permute'")
    if mode == "permute" and repeats < 1:
        raise ValueError("repeats must be >= 1 when mode='permute'")

    C = X.shape[-1]
    rng = np.random.RandomState(seed)
    n_runs = repeats if mode == "permute" else 1

    # ---- Channel labels ----
    if channel_map_1based is None:
        keep_idx = globals().get("KEEP_IDX")
        if keep_idx is not None and len(keep_idx) == C:
            # KEEP_IDX holds zero-based original indices; convert to 1-based labels.
            channel_map_1based = [int(i) + 1 for i in keep_idx]
        else:
            if keep_idx is not None:
                print(f"[run_filter_ablation] KEEP_IDX has {len(keep_idx)} entries but X has "
                      f"{C} channels; labelling channels 1..{C} instead.")
            channel_map_1based = list(range(1, C + 1))
    assert len(channel_map_1based) == C, "channel_map_1based length must match X[...,C]"

    # ---- Baseline ----
    probs_base = model.predict(X, batch_size=batch_size)
    t_use = _resolve_threshold(Y, probs_base)
    (d_mean_b, _, _), (i_mean_b, _, _) = _eval_with_threshold(Y, probs_base, t_use)

    # ---- Per-channel ablation ----
    # A single working copy is perturbed one channel at a time and restored from X
    # after every run, instead of copying the whole tensor per channel/repeat.
    X_work = X.copy()
    rows = []
    for c in range(C):
        d_scores, i_scores = [], []
        for _ in range(n_runs):
            if mode == "zero":
                X_work[..., c] = 0.0
            else:
                _permute_channel_inplace(X_work, c, rng)
            probs_c = model.predict(X_work, batch_size=batch_size)
            (d_run, _, _), (i_run, _, _) = _eval_with_threshold(Y, probs_c, t_use)
            d_scores.append(d_run)
            i_scores.append(i_run)
            X_work[..., c] = X[..., c]  # restore the unperturbed channel
        # Average over permutation repeats (a single run in "zero" mode).
        d_mean = float(np.mean(d_scores))
        i_mean = float(np.mean(i_scores))

        rows.append({
            "chan_idx_0b": c,
            "chan_label_1b": channel_map_1based[c],
            "dice_mean": d_mean,
            "dice_drop": d_mean_b - d_mean,
            "iou_mean": i_mean,
            "iou_drop": i_mean_b - i_mean,
            "baseline_dice_mean": d_mean_b,
            "baseline_iou_mean": i_mean_b,
            "threshold_used": t_use,
            "mode": mode,
            "repeats": repeats
        })

    # ---- Aggregate results ----
    df = (
        pd.DataFrame(rows)
        .sort_values(by=["dice_drop", "iou_drop"], ascending=[False, False])
        .reset_index(drop=True)
    )

    display_cols = [
        "chan_label_1b", "chan_idx_0b",
        "dice_mean", "dice_drop",
        "iou_mean", "iou_drop",
        "threshold_used", "mode",
    ]
    with pd.option_context('display.max_rows', None, 'display.width', 120):
        print("=== 10 Most Impactful Filters (largest Dice drop) ===")
        print(df[display_cols].head(10))

        print("\n=== 10 Least Impactful Filters (smallest Dice drop) ===")
        print(df[display_cols].tail(10))

    # ---- Summary ----
    # df is already sorted by dice_drop (desc). The "least" group is drawn only from
    # channels outside the "most" group so the two never overlap when C < 20.
    most10 = df.head(10)
    least10 = df.iloc[len(most10):].sort_values("dice_drop", ascending=True).head(10)

    print("\n=== Summary ===")
    print(f"Baseline Dice: {d_mean_b:.3f}")
    print(f"Baseline IoU : {i_mean_b:.3f}")
    print(f"Top filter impact range: {most10['dice_drop'].min():.3f}–{most10['dice_drop'].max():.3f}")
    if len(least10):
        print(f"Least filter impact range: {least10['dice_drop'].min():.3f}–{least10['dice_drop'].max():.3f}")

    # ---- Visualization ----
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(most10["chan_label_1b"].astype(str), most10["dice_drop"], color="salmon", label="Most Impactful")
    if len(least10):
        ax.bar(least10["chan_label_1b"].astype(str), least10["dice_drop"], color="skyblue", label="Least Impactful")
    ax.axhline(0, color="gray", lw=1)
    ax.set_xlabel("Channel Label (1-based)")
    ax.set_ylabel("ΔDice vs Baseline")
    ax.set_title("Filter Importance via Ablation")
    ax.legend()
    fig.tight_layout()
    plt.show()

    # ---- Optional export ----
    if csv_path is not None:
        df.to_csv(csv_path, index=False)
        print(f"\nSaved full results to '{csv_path}'")

    return df, {
        "baseline": {"dice_mean": d_mean_b, "iou_mean": i_mean_b, "t": t_use},
        "mode": mode,
        "repeats": repeats
    }

In [None]:
# Permutation ablation on the validation stack (3 shuffles per channel, fixed seed).
# Y is cast to float32 before evaluation.
df_perm, meta_perm = run_filter_ablation(
    model=model,
    X=Xva_vox,
    Y=Yva.astype(np.float32),
    batch_size=2,
    mode="permute",
    repeats=3,
    seed=123,
)