In [None]:
# --- Standard Library ---
import os
import re
import pickle
from collections import defaultdict


# --- Third-Party Libraries ---
import numpy as np
import pandas as pd
import requests
import matplotlib.pyplot as plt
import seaborn as sns


from scipy.ndimage import (
    gaussian_filter,
    sobel,
)

from sklearn.decomposition import PCA
from sklearn.metrics import (
    confusion_matrix,
)
from sklearn.utils import shuffle
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV

import tensorflow as tf

In [None]:
# ===========================================
# Data preparation: paths, IO helpers, split assignment, per-sample voxelization
# ===========================================

# --------------------
# Paths & source
# --------------------
# Local directory that holds the downloaded pickle and the split spreadsheet.
BASE_PATH       = "data"
# Pickle containing the list of per-image dictionaries (see load_image_dicts).
FILENAME        = "image_dicts_256_wgrayscale_andcutoffs.pkl"
FILE_PATH       = os.path.join(BASE_PATH, FILENAME)
# Spreadsheet listing which sample base names belong to train / val / test.
EXCEL_FILE_PATH = os.path.join(BASE_PATH, "sample_groups.xlsx")
# Release asset fetched by download_file() when FILE_PATH is missing.
URL             = "https://github.com/tylervasse/DOCI-Prediction/releases/download/v1.0/image_dicts_256_wgrayscale_andcutoffs.pkl"

# --------------------
# IO helpers
# --------------------
def download_file(url, output_path, timeout=60):
    """
    Download `url` to `output_path` unless a file already exists there.

    The payload is streamed to a temporary ``<output_path>.part`` file and
    only renamed to `output_path` once the transfer finished, so an
    interrupted download can never be mistaken for a complete file on the
    next run.

    Parameters
    ----------
    url : str
        Source URL.
    output_path : str
        Destination path; parent directories are created when needed.
    timeout : float or tuple, optional
        Forwarded to ``requests.get`` (connect / read timeout in seconds) so a
        stalled connection cannot hang the notebook forever.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    """
    if os.path.exists(output_path):
        print(f"File already exists at {output_path}")
        return
    print(f"Downloading to {output_path}...")

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        # os.makedirs("") raises, which used to break bare-filename targets.
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = output_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic publish: the final name only ever points at a complete file.
        os.replace(partial_path, output_path)
    finally:
        # Clean up the partial file if anything above failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete.")

def load_image_dicts(file_path):
    """
    Unpickle and return the object stored at `file_path`.

    Any failure is reported on stdout and an empty list is returned instead
    of raising, so callers must be prepared for an empty result.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on archives from a trusted source.
    loaded = []
    try:
        with open(file_path, "rb") as handle:
            loaded = pickle.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
    except Exception as err:
        print(f"Error loading file: {err}")
    return loaded

def load_sample_groups(excel_file_path):
    """
    Read the train / validation / test sample lists from a spreadsheet.

    The sheet must provide the columns 'Train Samples', 'Validation Samples'
    and 'Test Samples'. Empty cells are ignored; every remaining entry has
    surrounding whitespace and single quotes stripped.

    Returns
    -------
    (train, val, test) : tuple of lists of str
        Three empty lists when the file is missing or cannot be parsed (the
        problem is printed rather than raised).
    """
    def _clean_column(frame, column):
        # Drop blanks, then trim whitespace and any wrapping single quotes.
        return [cell.strip().strip("'") for cell in frame[column].dropna().tolist()]

    try:
        sheet = pd.read_excel(excel_file_path)
        train_names = _clean_column(sheet, 'Train Samples')
        val_names = _clean_column(sheet, 'Validation Samples')
        test_names = _clean_column(sheet, 'Test Samples')
    except FileNotFoundError:
        print(f"Error: Sample groups file not found at {excel_file_path}")
        return [], [], []
    except Exception as err:
        print(f"Error reading Excel: {err}")
        return [], [], []
    return train_names, val_names, test_names

# --------------------
# Basic parsing utils
# --------------------
def get_base_name(name):
    """Return the part of `name` that precedes the first '_DOCI' marker
    (the whole name when the marker is absent)."""
    head, _, _ = name.partition('_DOCI')
    return head

def get_doci_number(name):
    """Return the integer that follows '_DOCI_' in `name`, or -1 when the
    name carries no such suffix."""
    match = re.search(r'_DOCI_(\d+)', name)
    if match is None:
        return -1
    return int(match.group(1))

# --------------------
# Split image dicts into splits by sample base name
# --------------------
def categorize_images(image_data, train_samples, val_samples, test_samples):
    """
    Split image dicts into train/val/test by base sample name.

    The base name is the first two '_'-separated tokens of ``d['name']``
    (e.g. 'SSW-23-12345_A1'). A name listed in several splits is assigned
    with precedence train > val > test; images whose base name appears in
    none of the lists are silently left out of every split.

    Parameters
    ----------
    image_data : iterable of dict
        Each item must provide a 'name' key.
    train_samples, val_samples, test_samples : iterable of str
        Sample base names per split.

    Returns
    -------
    (train_set, val_set, test_set) : tuple of lists of dict
    """
    # Hash the name lists once so each image costs O(1) per lookup instead of
    # a linear scan through every list.
    train_lookup = set(train_samples)
    val_lookup = set(val_samples)
    test_lookup = set(test_samples)

    train_set, val_set, test_set = [], [], []
    for d in image_data:
        base = "_".join(d['name'].split('_')[:2])  # e.g., 'SSW-23-12345_A1'
        if base in train_lookup:
            train_set.append(d)
        elif base in val_lookup:
            val_set.append(d)
        elif base in test_lookup:
            test_set.append(d)
    return train_set, val_set, test_set

# --------------------
# Voxelize per-sample (group by base name, sort by DOCI)
# --------------------
def samples_to_voxels(dataset):
    """
    Collapse per-image dicts into one dict per sample.

    Images are grouped by `get_base_name(d['name'])`, ordered by their DOCI
    index (`get_doci_number`) and stacked along a new last axis:

      - 'grayscale_voxel':              float32, images stacked on axis -1
      - 'grayscale_image_cutoff_voxel': uint8,   images stacked on axis -1
      - 'mask':        the first non-None mask encountered for the sample
      - 'tissue_type': the tissue type of the last image seen for the sample

    Each input item must provide 'name', 'grayscale',
    'image_grayscale_cutoff' and 'tissue_type'; 'mask' is optional.
    """
    per_sample = {}
    for item in dataset:
        key = get_base_name(item['name'])
        bucket = per_sample.setdefault(key, {
            'names': [], 'grayscale': [], 'image_grayscale_cutoff': [],
            'mask': None, 'tissue_type': None,
        })
        bucket['names'].append(item['name'])
        bucket['grayscale'].append(item['grayscale'])
        bucket['image_grayscale_cutoff'].append(item['image_grayscale_cutoff'])
        bucket['tissue_type'] = item['tissue_type']
        if bucket['mask'] is None and item.get('mask') is not None:
            bucket['mask'] = item['mask']

    voxelized = []
    for key, bucket in per_sample.items():
        # Stable sort by DOCI index; ties keep their arrival order.
        ranked = sorted(enumerate(bucket['names']),
                        key=lambda pair: get_doci_number(pair[1]))
        order = [position for position, _ in ranked]

        gray_planes = [bucket['grayscale'][pos] for pos in order]
        cutoff_planes = [bucket['image_grayscale_cutoff'][pos] for pos in order]

        voxelized.append({
            'name': key,
            'grayscale_voxel': np.stack(gray_planes, axis=-1).astype(np.float32),
            'grayscale_image_cutoff_voxel': np.stack(cutoff_planes, axis=-1).astype(np.uint8),
            'tissue_type': bucket['tissue_type'],
            'mask': bucket['mask'],
        })
    return voxelized

# ====================
# Main flow
# ====================
# 1) Ensure data file (no-op when FILE_PATH already exists)
download_file(URL, FILE_PATH)

# 2) Load raw dicts
#    load_image_dicts returns [] on any failure, in which case every split
#    below ends up empty.
image_dicts = load_image_dicts(FILE_PATH)

# 3) Exclude specific samples by substring match in 'name'
EXCLUDE_LIST = ["SSW-23-14395_C2", "SSW-23-05363_A7"]
image_dicts = [d for d in image_dicts if not any(excl in d['name'] for excl in EXCLUDE_LIST)]

# 4) Load sample groups from Excel (three lists of sample base names)
train_samples, val_samples, test_samples = load_sample_groups(EXCEL_FILE_PATH)

# 5) Assign to splits and shuffle at image level
#    The shuffle only changes the order in which samples are first seen;
#    samples_to_voxels re-sorts the images of each sample by DOCI index.
train_set, val_set, test_set = categorize_images(image_dicts, train_samples, val_samples, test_samples)
train_set = shuffle(train_set, random_state=42)
val_set   = shuffle(val_set,   random_state=42)
test_set  = shuffle(test_set,  random_state=42)

# 6) Voxelize per sample (one dict per sample base name)
train_combined = samples_to_voxels(train_set)
val_combined   = samples_to_voxels(val_set)
test_combined  = samples_to_voxels(test_set)

print(f"Samples -> train: {len(train_combined)} | val: {len(val_combined)} | test: {len(test_combined)}")

# Regional Categorization from the PCA

In [None]:
# ===========================================
# Train PCA/Classifier on ALL tumor types, then
# filter out samples predicted as the non-target tumor
# (NONTARGET_TUMOR) from TRAIN / VAL / TEST (based on PCA-feature model)
# ===========================================

# --- constants ---
# Class order defines the integer ids and the channel order of probability maps.
TISSUES3 = ['Normal', 'Follicular', 'Papillary']
CLASS_TO_ID3 = {c: i for i, c in enumerate(TISSUES3)}
# Tumor type of interest; NONTARGET_TUMOR is the other tumor class (None if
# TARGET_TUMOR is neither 'Papillary' nor 'Follicular').
TARGET_TUMOR = "Follicular"
NONTARGET_TUMOR = {"Papillary": "Follicular",
                   "Follicular": "Papillary"}.get(TARGET_TUMOR)
TARGET_ID = CLASS_TO_ID3[TARGET_TUMOR]

# ---- Channel selection ----
# Channels to REMOVE are written 1-based for readability and converted to
# 0-based indices on the next line.
REMOVE_VOXEL_CHANNELS = [1, 2, 4, 7, 9, 11, 12, 14, 16, 17, 19]
REMOVE_VOXEL_CHANNELS = [i - 1 for i in REMOVE_VOXEL_CHANNELS]

# Optionally, explicitly define 0-based channels to KEEP instead. This is
# mutually exclusive with REMOVE_VOXEL_CHANNELS: _sanitize_indices_pca raises
# ValueError when both are non-empty.
KEEP_VOXEL_CHANNELS = None        # e.g., [2,3,4,5,6]

# ------------------------------
# Channel policy for PCA pipeline
# ------------------------------
def _sanitize_indices_pca(C, keep=None, remove=None):
    """
    Compute channel indices to keep given KEEP_VOXEL_CHANNELS / REMOVE_VOXEL_CHANNELS.
    Use either keep OR remove (or neither), but not both.
    """
    if keep not in (None, []) and remove not in (None, []):
        raise ValueError("Use either KEEP_VOXEL_CHANNELS or REMOVE_VOXEL_CHANNELS, not both.")
    if keep not in (None, []):
        idx = sorted({int(i) for i in keep if 0 <= int(i) < C})
    elif remove not in (None, []):
        bad = {int(i) for i in remove if 0 <= int(i) < C}
        idx = [i for i in range(C) if i not in bad]
    else:
        idx = list(range(C))
    if not idx:
        raise ValueError("No channels left after applying channel policy.")
    return idx

def _apply_channel_filter_to_samples_pca(sample_list, keep_idx):
    """
    In-place: slice d['grayscale_voxel'] to keep only channels in keep_idx
    for every sample in sample_list.
    """
    if sample_list is None:
        return None
    for d in sample_list:
        if 'grayscale_voxel' not in d:
            continue
        v = np.asarray(d['grayscale_voxel'], np.float32)
        if v.ndim != 3:
            raise ValueError(f"Expected voxel shape [H,W,C], got {v.shape}")
        if max(keep_idx) >= v.shape[-1]:
            raise ValueError(
                f"keep_idx {keep_idx} incompatible with voxel shape {v.shape}"
            )
        d['grayscale_voxel'] = v[..., keep_idx]
    return sample_list

# Apply channel policy to each set *before* PCA/scaler training
# NOTE(review): the filter mutates the sample dicts in place and the names are
# rebound to the same lists; confirm that re-running this cell cannot apply
# the policy to voxels that were already reduced.
try:
    # Infer original channel count from the first train sample that has a voxel
    C0 = None
    for d0 in train_combined:
        if 'grayscale_voxel' in d0:
            C0 = np.asarray(d0['grayscale_voxel']).shape[-1]
            break
    if C0 is None:
        raise RuntimeError("Could not infer voxel channel count from train_combined.")

    keep_idx_pca = _sanitize_indices_pca(C0, KEEP_VOXEL_CHANNELS, REMOVE_VOXEL_CHANNELS)
    # Indices are printed 1-based to match how REMOVE_VOXEL_CHANNELS is written.
    print(f"[PCA] Applying voxel channel policy: {C0} -> {len(keep_idx_pca)} using {[i+1 for i in keep_idx_pca]}")

    # Apply to all relevant sample sets in-place
    train_combined = _apply_channel_filter_to_samples_pca(train_combined, keep_idx_pca)
    # val/test are optional here: a missing name is tolerated and skipped.
    try:
        val_combined = _apply_channel_filter_to_samples_pca(val_combined, keep_idx_pca)
    except NameError:
        pass
    try:
        test_combined = _apply_channel_filter_to_samples_pca(test_combined, keep_idx_pca)
    except NameError:
        pass

except NameError:
    # If KEEP_VOXEL_CHANNELS / REMOVE_VOXEL_CHANNELS or train_combined not defined yet
    # NOTE(review): this also hides a NameError raised inside the helpers above.
    print("[PCA] Channel policy not found or sample sets undefined here; using all voxel channels for PCA.")

# ------------------------------
# Tissue mask from cutoff
# ------------------------------
def _tissue_mask_from_cutoff(cutvoxel, black_tolerance=5):
    """
    v1-style tissue mask: use a single grayscale cutoff plane.
    If cutvoxel is HxW, use it directly.
    If cutvoxel is HxWxC, use the FIRST channel.
    """
    arr = np.asarray(cutvoxel)
    if arr.ndim == 3:  # HxWxC
        arr = arr[..., 0]
    arr = arr.astype(np.uint8)
    return (arr > black_tolerance).astype(np.uint8)

def _resize_mask_to(img_mask, H, W):
    m = np.asarray(img_mask, np.uint8)
    if m.ndim == 3 and m.shape[-1] == 1:
        m = m[..., 0]
    if m.shape != (H, W):
        m = tf.image.resize(
            m[..., None], (H, W),
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
        ).numpy().squeeze().astype(np.uint8)
    return m

# ------------------------------
# Pixels for PCA ONLY
#   - Normal slides: pixels from tissue region
#   - Follicular/Papillary/Anaplastic slides: pixels from tumor region ∩ tissue
# ------------------------------
def collect_pixels_for_pca_regions(
    samples,
    max_pixels_per_image=200000,
    rng=42,
    black_tolerance=5,
    include_anaplastic=False,
):
    """
    Return X_raw (N_pix, C) for PCA fitting or projection:
      - Normal: pixels where tissue==1
      - Follicular/Papillary: pixels where tumor==1 AND tissue==1
      - Anaplastic: ONLY included if include_anaplastic=True

    Parameters
    ----------
    samples : iterable of dict
        Each used sample needs 'grayscale_voxel', 'grayscale_image_cutoff_voxel'
        and 'tissue_type'; tumor samples additionally use 'mask'.
    max_pixels_per_image : int
        Upper bound on pixels drawn (without replacement) per sample.
    rng : int
        Seed for the ``numpy.random.RandomState`` used for subsampling.
    black_tolerance : int
        Threshold forwarded to `_tissue_mask_from_cutoff`.
    include_anaplastic : bool
        Whether 'Anaplastic' samples are treated like the other tumor types.

    Returns
    -------
    numpy.ndarray, float32
        Stacked pixel spectra, or an empty (0, 0) array if nothing qualified.

    Samples of any other tissue type, tumor samples whose mask is missing
    (None), and samples with no qualifying pixel are skipped.
    """
    rs = np.random.RandomState(rng)
    tumor_types = ('Follicular', 'Papillary') + (('Anaplastic',) if include_anaplastic else ())

    chunks = []
    for d in samples:
        tt = d.get('tissue_type', '')
        is_normal = (tt == 'Normal')
        if not is_normal and tt not in tumor_types:
            continue  # ignore anything else

        raw_mask = None
        if not is_normal:
            raw_mask = d.get('mask')
            if raw_mask is None:
                # The key may exist with value None; an absent annotation
                # contributes no tumor pixels (same as an all-zero mask).
                continue

        voxel = np.asarray(d['grayscale_voxel'], np.float32)  # [H,W,C] (already filtered if policy applied)
        cut   = np.asarray(d['grayscale_image_cutoff_voxel'], np.uint8)
        H, W, _ = voxel.shape
        tissue = _tissue_mask_from_cutoff(cut, black_tolerance)

        if is_normal:
            ys, xs = np.where(tissue > 0)
        else:
            tumor = _resize_mask_to(raw_mask, H, W)
            ys, xs = np.where((tumor > 0) & (tissue > 0))

        if ys.size == 0:
            continue
        cap = min(max_pixels_per_image, ys.size)
        idx = rs.choice(ys.size, size=cap, replace=False)
        chunks.append(voxel[ys[idx], xs[idx], :])

    if not chunks:
        return np.empty((0, 0), np.float32)
    return np.vstack(chunks).astype(np.float32)

# ------------------------------
# PCA + multi-scale features
# ------------------------------
def build_pca_and_scaler(train_X, n_components=2):
    """
    Fit a standardising scaler on `train_X`, then a PCA on the scaled data.

    Returns the fitted ``(scaler, pca)`` pair; the PCA uses `n_components`
    components and ``random_state=0``.
    """
    scaler = StandardScaler(with_mean=True, with_std=True).fit(train_X)
    standardized = scaler.transform(train_X)
    pca = PCA(n_components=n_components, random_state=0).fit(standardized)
    return scaler, pca

def project_image_to_pcs(voxel, scaler, pca):
    """
    Project every pixel of an [H, W, C] voxel through `scaler` then `pca`.

    Both objects only need a ``transform`` method. The result is reshaped
    back to [H, W, D], D being whatever width ``pca.transform`` returns.
    """
    n_rows, n_cols, n_channels = voxel.shape
    pixels = voxel.reshape(-1, n_channels)
    scores = pca.transform(scaler.transform(pixels))
    return scores.reshape(n_rows, n_cols, -1)   # [H,W,D]

def _grad_mag(img):
    gx = sobel(img, axis=0, mode='nearest')
    gy = sobel(img, axis=1, mode='nearest')
    return np.hypot(gx, gy)

def image_features_from_pcs(pc_maps, sigmas=(0.0, 1.0, 2.0)):
    """
    Build per-pixel context features from an [H, W, D] stack of PC maps.

    Feature layout along the last axis:
      1. the raw maps,
      2. one Gaussian-blurred copy per strictly positive sigma (entries
         <= 0 are skipped because the raw maps already cover "no blur"),
      3. the Sobel gradient magnitude of every array from 1 and 2.

    Returns a float32 array of shape [H, W, F].
    """
    _, _, n_maps = pc_maps.shape

    scales = [pc_maps]
    for sigma in sigmas:
        if sigma > 0:
            smoothed = [gaussian_filter(pc_maps[..., ch], sigma) for ch in range(n_maps)]
            scales.append(np.stack(smoothed, axis=-1))

    gradients = []
    for level in scales:
        per_channel = [_grad_mag(level[..., ch]) for ch in range(level.shape[-1])]
        gradients.append(np.stack(per_channel, axis=-1))

    return np.concatenate(scales + gradients, axis=-1).astype(np.float32)  # [H,W,F]

# ------------------------------
# Classifier trained on ALL tumors (modified fit: PCA source only)
# ------------------------------
class PixelPCAContextClassifierAll:
    """
    Pixel-level Normal / Follicular / Papillary classifier.

    Pipeline: StandardScaler -> PCA on raw voxel spectra, then multi-scale
    context features (`image_features_from_pcs`) fed to a multinomial
    logistic regression, or alternatively a sigmoid/isotonic-calibrated
    LinearSVC.

    Parameters
    ----------
    n_pcs : int
        Number of principal components.
    sigmas : sequence of float
        Gaussian scales forwarded to `image_features_from_pcs`.
    use_linear_svc : bool
        Use a calibrated LinearSVC instead of logistic regression.
    calibration : str
        Calibration method for the LinearSVC path.
    C, max_iter :
        Forwarded to the underlying linear model.
    """

    def __init__(self, n_pcs=8, sigmas=(0.0,1.0,2.0),
                 use_linear_svc=False, calibration='sigmoid',
                 C=1.0, max_iter=1000):
        self.n_pcs = n_pcs
        self.sigmas = tuple(sigmas)
        self.use_linear_svc = use_linear_svc
        self.calibration = calibration
        self.C = C
        self.max_iter = max_iter
        # Fitted state (set by fit()).
        self.scaler_ = None
        self.pca_ = None
        self.clf_ = None
        # Column order of the full probability output.
        self.full_order_ids_ = np.array([CLASS_TO_ID3[c] for c in TISSUES3])
        # Class ids actually seen during training.
        self.present_ids_ = None

    def fit(self, train_samples, max_pixels_per_image=200000, class_weight='balanced',
            max_clf_pixels_per_image=30000):
        """
        Fit scaler + PCA, then the pixel classifier.

        Parameters
        ----------
        train_samples : list of dict
            Voxelized samples ('grayscale_voxel', 'grayscale_image_cutoff_voxel',
            'tissue_type', optional 'mask').
        max_pixels_per_image : int
            Per-sample pixel cap for the PCA fit.
        class_weight :
            Forwarded to the linear model.
        max_clf_pixels_per_image : int
            Per-sample pixel cap for the classifier training set.

        Raises
        ------
        RuntimeError
            If no pixel is available for the PCA or for the classifier.
        """
        # (A) PCA/scaler: train WITHOUT anaplastic
        X_raw_pca = collect_pixels_for_pca_regions(
            train_samples,
            max_pixels_per_image=max_pixels_per_image,
            rng=123,
            black_tolerance=5,
            include_anaplastic=False
        )
        if X_raw_pca.size == 0:
            raise RuntimeError("No pixels collected for PCA.")
        self.scaler_, self.pca_ = build_pca_and_scaler(X_raw_pca, n_components=self.n_pcs)

        # (B) Build contextual TRAIN set for classifier
        feats_list, y_list = [], []
        rs = np.random.RandomState(123)
        for d in train_samples:
            voxel  = np.asarray(d['grayscale_voxel'], np.float32)
            cutvox = np.asarray(d['grayscale_image_cutoff_voxel'], np.uint8)
            H, W, _ = voxel.shape
            tissue = _tissue_mask_from_cutoff(cutvox)

            tt = d.get('tissue_type', '')
            if tt == 'Normal':
                labels = np.full((H, W), CLASS_TO_ID3['Normal'], np.uint8)
                labels[tissue == 0] = 255  # 255 = ignore (background)
            elif tt in ('Follicular', 'Papillary'):
                raw_mask = d.get('mask')
                if raw_mask is None:
                    # The 'mask' key can be present with value None; treat a
                    # missing annotation as an empty tumor mask.
                    tumor = np.zeros((H, W), np.uint8)
                else:
                    tumor = _resize_mask_to(raw_mask, H, W)
                labels = np.full((H, W), CLASS_TO_ID3['Normal'], np.uint8)
                labels[(tumor > 0) & (tissue > 0)] = CLASS_TO_ID3[tt]
                labels[tissue == 0] = 255
            else:
                continue  # Anaplastic excluded from classifier labels

            pcs = project_image_to_pcs(voxel, self.scaler_, self.pca_)      # [H,W,D]
            F   = image_features_from_pcs(pcs, sigmas=self.sigmas)          # [H,W,F]
            ys, xs = np.where(labels != 255)
            if ys.size == 0:
                continue
            cap = min(max_clf_pixels_per_image, ys.size)
            idx = rs.choice(ys.size, size=cap, replace=False)
            feats_list.append(F[ys[idx], xs[idx], :])
            y_list.append(labels[ys[idx], xs[idx]])

        if not feats_list:
            raise RuntimeError("No labelled pixels collected for classifier training.")
        X_feat = np.vstack(feats_list)
        y_feat = np.concatenate(y_list)

        # (C) multinomial LR (or LinearSVC + calibration)
        if not self.use_linear_svc:
            # `multi_class` is deprecated in recent scikit-learn; with the
            # lbfgs solver the multinomial formulation is already what is
            # used for three or more classes.
            base = LogisticRegression(
                solver='lbfgs',
                C=self.C,
                max_iter=self.max_iter,
                class_weight=class_weight
            )
            self.clf_ = base.fit(X_feat, y_feat)
        else:
            base = LinearSVC(C=self.C, max_iter=self.max_iter, class_weight=class_weight)
            # Estimator passed positionally: the keyword was renamed from
            # `base_estimator` to `estimator` across scikit-learn versions.
            self.clf_ = CalibratedClassifierCV(base, cv=4, method=self.calibration)
            self.clf_.fit(X_feat, y_feat)
        self.present_ids_ = self.clf_.classes_.astype(int)
        return self

    def _expand_to_full(self, proba_small):
        """Scatter the classifier's [N, K'] probabilities into an [N, 3]
        float32 array, leaving zeros for classes absent from training."""
        N = proba_small.shape[0]
        proba_full = np.zeros((N, len(self.full_order_ids_)), dtype=np.float32)
        for j, cls_id in enumerate(self.present_ids_):
            proba_full[:, cls_id] = proba_small[:, j]
        return proba_full

    def predict_proba_map(self, sample):
        """Return the [H, W, 3] class-probability map (order TISSUES3) for one
        voxelized sample."""
        voxel = np.asarray(sample['grayscale_voxel'], np.float32)  # already filtered channels
        pcs   = project_image_to_pcs(voxel, self.scaler_, self.pca_)
        F     = image_features_from_pcs(pcs, sigmas=self.sigmas)
        H, W, Fdim = F.shape
        proba_small = self.clf_.predict_proba(F.reshape(-1, Fdim))     # [H*W,K']
        P = self._expand_to_full(proba_small).reshape(H, W, -1)        # [H,W,3] (N,F,P)
        return P

# ------------------------------
# Fit PCA + classifier on TRAIN (all tumors)
# ------------------------------
# Instantiate the pixel classifier: 8 PCs, raw + two blur scales, logistic
# regression head (the LinearSVC/calibration settings are unused here because
# use_linear_svc=False).
px_model_all = PixelPCAContextClassifierAll(
    n_pcs=8, sigmas=(0.0,1.0,2.0),
    use_linear_svc=False, calibration='sigmoid',
    C=1.0, max_iter=1000
)
print("[PX] Fitting PCA+Classifier on TRAIN with Normal + Follicular + Papillary ...")
px_model_all.fit(train_combined, max_pixels_per_image=200000, class_weight='balanced')
print("[PX] ... done.")

# (optional) save the fitted model into the current working directory
with open("pixel_pca_context_classifier_all.pkl", "wb") as f:
    pickle.dump(px_model_all, f)


In [None]:
# =========================
# Tunable regional categorization (categorize-first, then filter)
# =========================

# Ground-truth labels that are scored in confusion matrices and bootstraps;
# samples with any other 'tissue_type' are excluded from those metrics.
VALID_CLASSES3 = ['Normal', 'Follicular', 'Papillary']

def _pred_label_map_over_tissue(px_model_all, d, black_tol=5):
    """
    Predict a per-pixel class name for one sample.

    Returns:
      labels_map: [H,W] object array holding the TISSUES3 name with the
                  highest predicted probability at each pixel
      tissue:     [H,W] uint8 in {0,1}; replaced by all ones when the cutoff
                  mask contains no tissue pixel at all
      P:          [H,W,3] float32 probabilities in order TISSUES3
    """
    proba = px_model_all.predict_proba_map(d).astype(np.float32)  # [H,W,3]

    cutoff = np.asarray(d['grayscale_image_cutoff_voxel'], np.uint8)
    tissue = _tissue_mask_from_cutoff(cutoff, black_tolerance=black_tol)
    if not tissue.any():
        tissue = np.ones_like(tissue, np.uint8)

    winner = np.argmax(proba, axis=-1)  # [H,W] ints 0..2
    labels_map = np.full(winner.shape, None, dtype=object)
    for class_id, class_name in enumerate(TISSUES3):
        labels_map[winner == class_id] = class_name
    return labels_map, tissue, proba

def _sliding_window_indices(H, W, win=64, stride=32):
    ys = list(range(0, max(1, H - win + 1), stride))
    xs = list(range(0, max(1, W - win + 1), stride))
    if len(ys) == 0:
        ys = [0]
    if len(xs) == 0:
        xs = [0]
    if ys[-1] != max(0, H - win):
        ys.append(max(0, H - win))
    if xs[-1] != max(0, W - win):
        xs.append(max(0, W - win))
    return ys, xs

def categorize_sample_regional(
    px_model_all, d,
    black_tol=5,
    win=64, stride=32,
    frac_thresh=0.60,
    min_tissue_px_per_win=50
):
    """
    Sample-level label from regional evidence.

      1) Take the arg-max class of every pixel and restrict it to tissue.
      2) Slide a `win` x `win` window with step `stride`; windows holding
         fewer than `min_tissue_px_per_win` tissue pixels are ignored. A
         cancer class is "present" as soon as it covers at least
         `frac_thresh` of the tissue pixels of some window.
      3) No cancer present -> 'Normal'; exactly one -> that cancer; both ->
         the one with more tissue pixels over the whole image (ties go to
         'Follicular').
    """
    labels_map, tissue, _ = _pred_label_map_over_tissue(px_model_all, d, black_tol=black_tol)
    H, W = tissue.shape

    in_tissue = tissue.astype(bool)
    if not in_tissue.any():
        return 'Normal'

    cancers = ('Follicular', 'Papillary')
    # Whole-image pixel counts, only needed to break a tie in step 3.
    totals = {name: np.sum((labels_map == name) & in_tissue) for name in cancers}

    row_starts, col_starts = _sliding_window_indices(H, W, win=win, stride=stride)
    detected = set()
    for top in row_starts:
        for left in col_starts:
            window = (slice(top, top + win), slice(left, left + win))
            win_tissue = in_tissue[window]
            n_tissue = int(win_tissue.sum())
            if n_tissue < min_tissue_px_per_win:
                continue

            win_labels = labels_map[window]
            fractions = {
                name: int(np.sum((win_labels == name) & win_tissue)) / n_tissue
                for name in cancers
            }
            for name in cancers:
                if fractions[name] >= frac_thresh:
                    detected.add(name)

    if not detected:
        return 'Normal'
    if len(detected) == 1:
        return next(iter(detected))
    if totals['Follicular'] >= totals['Papillary']:
        return 'Follicular'
    return 'Papillary'

# ------------------------------
# Categorize FIRST, then filter out samples predicted as NONTARGET_TUMOR
# ------------------------------
def categorize_split(samples, px_model_all,
                     black_tol=5, win=64, stride=32,
                     frac_thresh=0.60, min_tissue_px_per_win=50):
    """
    Run `categorize_sample_regional` on every sample of a split.

    Returns one record per sample, in input order:
      {'name': sample name (or 'unknown_sample'),
       'gt':   the sample's 'tissue_type' (or None),
       'pred': predicted label,
       'sample': the original sample dict}
    """
    def _record(sample):
        predicted = categorize_sample_regional(
            px_model_all, sample,
            black_tol=black_tol,
            win=win, stride=stride,
            frac_thresh=frac_thresh,
            min_tissue_px_per_win=min_tissue_px_per_win
        )
        return {
            'name': sample.get('name', 'unknown_sample'),
            'gt': sample.get('tissue_type', None),
            'pred': predicted,
            'sample': sample,
        }

    return [_record(sample) for sample in samples]

def filter_after_categorization(cat_results):
    """
    Partition categorized records by whether they were predicted as the
    non-target tumor (module constant NONTARGET_TUMOR).

    Returns (kept_samples, dropped_samples): the original sample dicts, with
    every record predicted as NONTARGET_TUMOR in the second list.
    """
    kept, dropped = [], []
    for record in cat_results:
        destination = dropped if record['pred'] == NONTARGET_TUMOR else kept
        destination.append(record['sample'])
    return kept, dropped

# ---- RUN: categorize first, then filter ----
# Window/threshold settings shared by all three splits.
REGIONAL_PARAMS = dict(black_tol=5, win=80, stride=10,
                       frac_thresh=0.80, min_tissue_px_per_win=35)

train_cat = categorize_split(train_combined, px_model_all, **REGIONAL_PARAMS)
val_cat   = categorize_split(val_combined,   px_model_all, **REGIONAL_PARAMS)
test_cat  = categorize_split(test_combined,  px_model_all, **REGIONAL_PARAMS)

train_filtered, train_dropped = filter_after_categorization(train_cat)
val_filtered,   val_dropped   = filter_after_categorization(val_cat)
test_filtered,  test_dropped  = filter_after_categorization(test_cat)

print(f"[CATEGORIZE] counts (pred): "
      f"train N/F/P = "
      f"{sum(r['pred']=='Normal' for r in train_cat)}/"
      f"{sum(r['pred']=='Follicular' for r in train_cat)}/"
      f"{sum(r['pred']=='Papillary' for r in train_cat)}")
print(f"[FILTER] kept: train={len(train_filtered)} val={len(val_filtered)} test={len(test_filtered)}")
# The dropped class is NONTARGET_TUMOR (not necessarily Follicular), so name
# it from the constant instead of hard-coding it in the log line.
print(f"[FILTER] drop({NONTARGET_TUMOR}-pred): train={len(train_dropped)} val={len(val_dropped)} test={len(test_dropped)}")

# ------------------------------
# Confusion matrices & verbose misclassifications
# ------------------------------
def confusion_from_cat(cat_results, split_name):
    """
    Plot the sample-level confusion matrix of one split and list its errors.

    Only records whose ground truth is in VALID_CLASSES3 enter the matrix and
    the misclassification list. Records with ground truth 'Anaplastic' that
    were predicted 'Normal' are reported separately.
    """
    scored = [r for r in cat_results if r['gt'] in VALID_CLASSES3]
    names  = [r['name'] for r in scored]
    y_true = [r['gt'] for r in scored]
    y_pred = [r['pred'] for r in scored]

    cm = confusion_matrix(y_true, y_pred, labels=VALID_CLASSES3)

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=VALID_CLASSES3, yticklabels=VALID_CLASSES3)
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.title(f"{split_name} Confusion Matrix (Regional categorization)")
    plt.tight_layout()
    plt.show()

    mis = [
        (sample_name, truth, guess)
        for sample_name, truth, guess in zip(names, y_true, y_pred)
        if truth != guess
    ]
    anaplastic_as_normal = [
        r['name'] for r in cat_results
        if r['gt'] == 'Anaplastic' and r['pred'] == 'Normal'
    ]

    print(f"\n[DETAILS] {split_name}")
    if not mis:
        print("No misclassifications in this split.")
    else:
        print(f"Misclassified samples ({len(mis)}):")
        for sample_name, truth, guess in mis:
            print(f"  - {sample_name}: GT={truth}, Pred={guess}")

    if anaplastic_as_normal:
        print(f"\nAnaplastic → Normal cases ({len(anaplastic_as_normal)}):")
        for sample_name in anaplastic_as_normal:
            print(f"  - {sample_name}")
    print()

# Pre-filter confusion matrices (to inspect gate performance): these use the
# *_cat results, i.e. every categorized sample, not the *_filtered lists.
confusion_from_cat(train_cat, "TRAIN (pre-filter)")
confusion_from_cat(val_cat,   "VAL (pre-filter)")
confusion_from_cat(test_cat,  "TEST (pre-filter)")

# Bootstrapping

In [None]:
# ------------------------------
# Patient-level bootstrapped CIs for regional categorization
# ------------------------------
def _patient_id_from_name(name):
    """
    Map a sample name to a patient ID.
    Example naming: 'SSW-23-12345_A1' -> patient 'SSW-23-12345'.
    """
    s = str(name)
    parts = s.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:2])
    return s


def _extract_arrays_for_bootstrap(cat_results):
    """
    From cat_results, extract:
      y_true : np.array of GT labels (strings)
      y_pred : np.array of predicted labels (strings)
      pids   : np.array of patient IDs (strings)
    Restricted to VALID_CLASSES3.
    """
    y_true, y_pred, pids = [], [], []
    for r in cat_results:
        gt = r.get("gt", None)
        if gt not in VALID_CLASSES3:
            continue
        y_true.append(gt)
        y_pred.append(r.get("pred", None))
        pids.append(_patient_id_from_name(r.get("name", "unknown_sample")))
    return np.array(y_true, dtype=object), np.array(y_pred, dtype=object), np.array(pids, dtype=object)


def _compute_metrics(y_true, y_pred, classes=VALID_CLASSES3):
    """
    Compute overall accuracy and per-class recall.
    Returns:
      acc, per_class_recall_dict
    """
    y_true = np.asarray(y_true, dtype=object)
    y_pred = np.asarray(y_pred, dtype=object)

    if y_true.size == 0:
        return np.nan, {c: np.nan for c in classes}

    acc = np.mean(y_true == y_pred)

    recalls = {}
    for c in classes:
        mask = (y_true == c)
        if mask.sum() == 0:
            recalls[c] = np.nan  # avoid division by zero
        else:
            tp = np.sum((y_pred == c) & mask)
            fn = np.sum((y_pred != c) & mask)
            recalls[c] = tp / (tp + fn + 1e-6)
    return float(acc), recalls


def bootstrap_regional_metrics(cat_results, n_boot=1000, random_state=0):
    """
    Bootstrap confidence intervals for the regional categorization, resampling
    whole groups (IDs from `_patient_id_from_name`) with replacement.

    Each replicate draws as many IDs as there are distinct IDs, pools all
    records of the drawn IDs (repeats included) and recomputes accuracy and
    per-class recall. Replicates in which a class is absent do not contribute
    to that class's recall interval.

    Returns
    -------
    dict
        {'accuracy': {'point', 'ci_95'}, 'recall': {cls: {'point', 'ci_95'}},
         'n_patients', 'n_boot'}; intervals are the 2.5th / 97.5th
        percentiles. An empty dict is returned when no scorable record exists.
        A summary is also printed.
    """
    y_true, y_pred, pids = _extract_arrays_for_bootstrap(cat_results)
    classes = VALID_CLASSES3

    # Point estimates on the un-resampled data
    acc_point, rec_point = _compute_metrics(y_true, y_pred, classes=classes)

    rng = np.random.RandomState(random_state)
    unique_pids = np.unique(pids)
    n_patients = len(unique_pids)
    if n_patients == 0:
        print("[BOOT] No patients found for bootstrap.")
        return {}

    # Row indices of every ID, computed once up front
    rows_by_pid = {pid: np.where(pids == pid)[0] for pid in unique_pids}

    acc_draws = []
    recall_draws = {c: [] for c in classes}
    for _ in range(n_boot):
        drawn = rng.choice(unique_pids, size=n_patients, replace=True)
        rows = np.concatenate([rows_by_pid[pid] for pid in drawn], axis=0)

        acc_b, rec_b = _compute_metrics(y_true[rows], y_pred[rows], classes=classes)
        acc_draws.append(acc_b)
        for c in classes:
            if not np.isnan(rec_b[c]):
                recall_draws[c].append(rec_b[c])

    def _percentile_interval(values):
        values = np.array(values, dtype=float)
        values = values[~np.isnan(values)]
        if values.size == 0:
            return (np.nan, np.nan)
        return (float(np.percentile(values, 2.5)),
                float(np.percentile(values, 97.5)))

    results = {
        "accuracy": {
            "point": float(acc_point),
            "ci_95": _percentile_interval(acc_draws),
        },
        "recall": {},
        "n_patients": int(n_patients),
        "n_boot": int(n_boot),
    }
    for c in classes:
        point = rec_point[c]
        results["recall"][c] = {
            "point": float(point) if not np.isnan(point) else np.nan,
            "ci_95": _percentile_interval(recall_draws[c]),
        }

    # Human-readable summary
    print(f"[BOOT] Patient-level bootstrap (n_patients={n_patients}, n_boot={n_boot})")
    acc_lo, acc_hi = results["accuracy"]["ci_95"]
    print(f"  Accuracy: {results['accuracy']['point']:.3f} "
          f"(95% CI {acc_lo:.3f}–{acc_hi:.3f})")

    for c in classes:
        entry = results["recall"][c]
        r_pt = entry["point"]
        r_lo, r_hi = entry["ci_95"]
        print(f"  Recall {c:10s}: {r_pt:.3f} "
              f"(95% CI {r_lo:.3f}–{r_hi:.3f})")

    return results


# Bootstrap each split on its pre-filter categorization results; every split
# uses its own seed so the three resampling streams are independent.
print("\n[BOOTSTRAP] TRAIN split (regional categorization)")
train_boot = bootstrap_regional_metrics(train_cat, n_boot=2000, random_state=0)

print("\n[BOOTSTRAP] VAL split (regional categorization)")
val_boot = bootstrap_regional_metrics(val_cat, n_boot=2000, random_state=1)

print("\n[BOOTSTRAP] TEST split (regional categorization)")
test_boot = bootstrap_regional_metrics(test_cat, n_boot=2000, random_state=2)


# Visualizing Pixel Classification

In [None]:
# ------------ visualize per-image probability maps ------------
def visualize_px_maps_combined(samples, model, max_imgs=20,
                               save_path="px_maps_combined2.png",
                               save_prefix=None):
    """
    Creates ONE combined figure showing probability maps for all samples.
    Columns = Normal, Follicular, Papillary.
    Rows    = 1 row per sample (the first `max_imgs` entries of `samples`).

    samples:     list of sample dicts
    model:       trained PixelPCAContextClassifierAll; only
                 `model.predict_proba_map(sample)` is used, and its result is
                 indexed as P[..., class_id]
    max_imgs:    maximum number of samples (rows) to draw
    save_path:   output image path used when `save_prefix` is not given
    save_prefix: optional alternative to `save_path`; when provided the figure
                 is written to f"{save_prefix}.png". Accepted because existing
                 callers in this notebook pass `save_prefix=...`.

    Returns None. If there is nothing to draw (empty `samples` or
    `max_imgs <= 0`) a message is printed and no file is written.
    """
    if save_prefix is not None:
        save_path = f"{save_prefix}.png"

    k = min(max_imgs, len(samples))
    if k <= 0:
        # plt.subplots(0, 3) would raise; bail out with a clear message instead.
        print("[PX] No samples to visualize; nothing saved.")
        return

    # (title, class index) per column: Normal, Follicular, Papillary
    columns = [
        ("P(Normal)",     CLASS_TO_ID3['Normal']),
        ("P(Follicular)", CLASS_TO_ID3['Follicular']),
        ("P(Papillary)",  CLASS_TO_ID3['Papillary']),
    ]

    # squeeze=False keeps `axes` 2-D even when k == 1
    fig, axes = plt.subplots(k, len(columns), figsize=(12, 4 * k), squeeze=False)

    for row in range(k):
        ex = samples[row]
        P  = model.predict_proba_map(ex)  # [H,W,3]

        for col, (title, cls_id) in enumerate(columns):
            ax = axes[row, col]
            # Fixed 0..1 colour range so panels are comparable across samples
            ax.imshow(P[..., cls_id], vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the (potentially very tall) figure once it has been rendered
    plt.close(fig)
    print(f"[PX] Saved combined figure: {save_path}")

In [None]:
# Probability maps for up to 20 samples of the filtered TRAIN split.
# Uses the `save_path` keyword (the function has no `save_prefix` parameter)
# and a split-specific file name so it is not overwritten by the test figure.
visualize_px_maps_combined(train_filtered, model=px_model_all, max_imgs=20,
                           save_path="pixel_prob_maps_train.png")

In [None]:
# Probability maps for up to 20 samples of the filtered TEST split.
# Uses the `save_path` keyword (the function has no `save_prefix` parameter)
# and a split-specific file name so it does not clobber the train figure.
visualize_px_maps_combined(test_filtered, model=px_model_all, max_imgs=20,
                           save_path="pixel_prob_maps_test.png")

# Visualizing PCA

In [None]:
# ===========================================
# PCA diagnostics (PC1–PC2 scatter, variance, loadings)
# Reuses the PCA/scaler already trained in px_model_all
# ===========================================

def _boundary_mask(bin_mask):
    from scipy.ndimage import binary_dilation, binary_erosion
    m = np.asarray(bin_mask)
    if m.ndim == 3 and m.shape[-1] == 1:
        m = m[..., 0]
    m = m.astype(bool)
    if m.size == 0:
        return np.zeros_like(m, dtype=np.uint8)
    dil = binary_dilation(m, iterations=1)
    ero = binary_erosion(m, iterations=1)
    return np.logical_xor(dil, ero).astype(np.uint8)

# ---------------------------------------------------------------------
# Collect pixels for Normal + Follicular + Papillary
# (balanced; boundary-emphasized for tumors)
# ---------------------------------------------------------------------
def collect_pixel_samples_all_tumors_balanced_boundary(
    samples,
    max_pixels_per_image=20000,
    per_class_cap=150000,
    rng=42,
    black_tolerance=5,
    boundary_boost=0.5
):
    """
    Collect a class-balanced set of per-pixel feature vectors for the classes
    listed in TISSUES3 (Normal / Follicular / Papillary).

    Per sample dict (keyed on d['tissue_type']):
      * 'Normal'     -> up to `max_pixels_per_image` random tissue pixels.
      * 'Follicular' / 'Papillary' -> the per-image budget is split in half:
          - one half for tumor pixels (mask > 0 and tissue), itself split into
            a boundary share (`boundary_boost` fraction, clamped to [0, 1])
            and a non-boundary share;
          - the other half for non-tumor tissue pixels of the same slide,
            which are stored under 'Normal'.
      * any other tissue_type contributes nothing.

    After collection every non-empty class is subsampled (without
    replacement) to `min(smallest non-empty class, per_class_cap)` rows.

    Parameters
    ----------
    samples : iterable of dict
        Each dict is read for 'grayscale_voxel' ([H,W,C]),
        'grayscale_image_cutoff_voxel', 'tissue_type' and optionally 'mask'.
    max_pixels_per_image : int
        Per-image sampling budget (see above).
    per_class_cap : int
        Upper bound on rows kept per class.
    rng : int
        Seed for the RandomState used for per-image draws; the same seed is
        also used to build a fresh RandomState for each per-class subsample.
    black_tolerance : int
        Forwarded to `_tissue_mask_from_cutoff` (presumably the intensity
        threshold separating background from tissue -- confirm in that helper).
    boundary_boost : float
        Fraction of the tumor half-budget reserved for tumor pixels lying on
        the `_boundary_mask` band.

    Returns
    -------
    X_raw : np.ndarray, float32, shape (N, C)
        Stacked pixel vectors, grouped by class in TISSUES3 order (not shuffled).
    y : np.ndarray, int32, shape (N,)
        Class ids from CLASS_TO_ID3.
        If no sample has 'grayscale_voxel', returns arrays of shape (0, 0) and
        (0,); if nothing was sampled, shapes (0, C) and (0,).
    """
    # Single RNG stream shared by all per-image draws; the order of draws below
    # therefore determines the exact pixels selected.
    rs = np.random.RandomState(rng)
    per_class = {c: [] for c in TISSUES3}

    # infer channel count
    C_channels = None
    for d0 in samples:
        if 'grayscale_voxel' in d0:
            C_channels = int(np.asarray(d0['grayscale_voxel']).shape[-1])
            break
    if C_channels is None:
        return np.empty((0, 0), np.float32), np.empty((0,), np.int32)

    for d in samples:
        voxel = np.asarray(d['grayscale_voxel'], np.float32)  # [H,W,C]
        cut   = np.asarray(d['grayscale_image_cutoff_voxel'], np.uint8)
        H, W, _ = voxel.shape
        # Pixels with tissue > 0 are treated as valid tissue below
        tissue = _tissue_mask_from_cutoff(cut, black_tolerance)
        tt = d.get('tissue_type', '')

        # Normal slides: sample tissue pixels
        if tt == 'Normal':
            ys, xs = np.where(tissue > 0)
            if ys.size == 0:
                continue
            cap = min(max_pixels_per_image, ys.size)
            idx = rs.choice(ys.size, size=cap, replace=False)
            per_class['Normal'].append(voxel[ys[idx], xs[idx], :])
            continue

        # Tumor slides: Follicular / Papillary
        if tt in ('Follicular', 'Papillary'):
            # A missing mask defaults to all zeros, i.e. the whole slide is
            # treated as non-tumor tissue.
            tumor = np.asarray(d.get('mask', np.zeros((H, W), np.uint8)), np.uint8)
            if tumor.ndim == 3 and tumor.shape[-1] == 1:
                tumor = tumor[..., 0]
            # Bring the mask onto the voxel grid; nearest-neighbour keeps it binary
            if tumor.shape != (H, W):
                tumor = tf.image.resize(
                    tumor[..., None], (H, W),
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
                ).numpy().squeeze().astype(np.uint8)

            y_t, x_t = np.where((tumor > 0) & (tissue > 0))   # tumor pixels
            y_n, x_n = np.where((tumor == 0) & (tissue > 0))  # same-slide normal pixels

            # Budget: half for tumor (boundary share + non-boundary share),
            # half for same-slide normal pixels.
            bnd = _boundary_mask(tumor)
            half_cap = max_pixels_per_image // 2
            cap_b  = int(max(0, min(1.0, boundary_boost)) * half_cap)
            cap_nt = max(0, half_cap - cap_b)

            # Draw up to k coordinate pairs without replacement; returns empty
            # index arrays (and consumes no random numbers) when there is
            # nothing to draw.
            def sub(ys, xs, k):
                if ys.size <= 0 or k <= 0:
                    return np.empty((0,), int), np.empty((0,), int)
                k = min(k, ys.size)
                idx = rs.choice(ys.size, size=k, replace=False)
                return ys[idx], xs[idx]

            # Split tumor pixels into those on the boundary band and the rest
            if y_t.size:
                tumor_is_bnd = (bnd[y_t, x_t] > 0)
                y_tb,  x_tb  = y_t[tumor_is_bnd],  x_t[tumor_is_bnd]
                y_tnb, x_tnb = y_t[~tumor_is_bnd], x_t[~tumor_is_bnd]
            else:
                y_tb = x_tb = y_tnb = x_tnb = np.empty((0,), int)

            # Each share is capped independently; an unused boundary budget is
            # not handed over to the non-boundary share (or vice versa).
            y_tb,  x_tb  = sub(y_tb,  x_tb,  cap_b)
            y_tnb, x_tnb = sub(y_tnb, x_tnb, cap_nt)
            y_n,   x_n   = sub(y_n,   x_n,   half_cap)

            if y_tb.size + y_tnb.size:
                y_all = np.concatenate([y_tb, y_tnb])
                x_all = np.concatenate([x_tb, x_tnb])
                per_class[tt].append(voxel[y_all, x_all, :])
            if y_n.size:
                per_class['Normal'].append(voxel[y_n, x_n, :])

    # stack & balance classes
    for k in per_class:
        per_class[k] = (
            np.vstack(per_class[k]).astype(np.float32)
            if len(per_class[k]) else
            np.empty((0, C_channels), np.float32)
        )

    # Balance to the smallest non-empty class, bounded by per_class_cap
    sizes = {k: per_class[k].shape[0] for k in TISSUES3}
    nonzero = [v for v in sizes.values() if v > 0]
    if not nonzero:
        return np.empty((0, C_channels), np.float32), np.empty((0,), np.int32)
    target_count = min(min(nonzero), per_class_cap)

    X_list, y_list = [], []
    for cls_name in TISSUES3:
        Xc = per_class[cls_name]
        if Xc.shape[0] == 0:
            continue
        if Xc.shape[0] > target_count:
            # NOTE: a fresh RandomState(rng) is created for every class, so
            # classes of equal size receive the same index selection.
            idx = np.random.RandomState(rng).choice(Xc.shape[0], size=target_count, replace=False)
            Xc = Xc[idx]
        X_list.append(Xc)
        y_list.append(np.full((Xc.shape[0],), CLASS_TO_ID3[cls_name], np.int32))

    X_raw = np.vstack(X_list)
    y     = np.concatenate(y_list)
    return X_raw, y

# ---------------------------------------------------------------------
# Build pixel dataset for PCA diagnostics
# ---------------------------------------------------------------------
def get_pixel_dataset_for_pca(train_samples,
                              max_pixels_per_image=12000,
                              per_class_cap=150000,
                              rng=123):
    """
    Build the pixel matrix used by the PCA diagnostics.

    Delegates to collect_pixel_samples_all_tumors_balanced_boundary (the
    balanced, boundary-aware sampler) and returns its (X_raw, y) pair:
    X_raw is (N, C) and y is (N,).

    Raises RuntimeError when the sampler returns no pixels at all.
    """
    sampler_kwargs = dict(
        max_pixels_per_image=max_pixels_per_image,
        per_class_cap=per_class_cap,
        rng=rng,
    )
    pixels, labels = collect_pixel_samples_all_tumors_balanced_boundary(
        train_samples, **sampler_kwargs
    )
    if pixels.size != 0:
        return pixels, labels
    raise RuntimeError("No pixel samples collected for PCA diagnostics.")

# ---------------------------------------------------------------------
# Reuse trained PCA / scaler from px_model_all
# ---------------------------------------------------------------------
# Fail early unless both fitted transformers are present on the pixel model.
if any(getattr(px_model_all, attr, None) is None for attr in ("scaler_", "pca_")):
    raise RuntimeError("px_model_all must be trained first (scaler_/pca_ not found).")

# Aliases used by every diagnostic below
scaler_used, pca_used = px_model_all.scaler_, px_model_all.pca_

# ---------------------------------------------------------------------
# 2D PCA scatter (PC1 vs PC2)
# ---------------------------------------------------------------------
def plot_pca_scatter_2d(X_raw, y, scaler, pca, sample_cap=200000, alpha=0.25):
    """
    Scatter PC1 against PC2, coloured by class label
    (0: Normal, 1: Follicular, 2: Papillary).

    X_raw is passed through `scaler.transform` and then `pca.transform`.
    When `sample_cap` is not None and there are more rows than that, a
    fixed-seed (0) random subset of `sample_cap` rows is plotted instead.
    """
    n_rows = X_raw.shape[0]
    X_plot, y_plot = X_raw, y
    if sample_cap is not None and n_rows > sample_cap:
        keep = np.random.RandomState(0).choice(n_rows, size=sample_cap, replace=False)
        X_plot, y_plot = X_raw[keep], y[keep]

    scores = pca.transform(scaler.transform(X_plot))
    pc1 = scores[:, 0]
    pc2 = scores[:, 1]

    # (class id, legend label, colour)
    class_styles = (
        (0, 'Normal',     'tab:blue'),
        (1, 'Follicular', 'tab:orange'),
        (2, 'Papillary',  'tab:green'),
    )

    plt.figure(figsize=(7, 6))
    for class_id, class_name, class_color in class_styles:
        selected = (y_plot == class_id)
        if not np.any(selected):
            continue  # class absent from this subset
        plt.scatter(pc1[selected], pc2[selected], s=5, alpha=alpha,
                    label=class_name, c=class_color)

    # Faint reference lines through the origin
    for draw_zero_line in (plt.axhline, plt.axvline):
        draw_zero_line(0, lw=0.5, c='k', alpha=0.2)

    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.title('PCA scatter (PC1 vs PC2)')
    plt.legend(markerscale=3, frameon=True)
    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------------
# Explained variance (bar + cumulative line)
# ---------------------------------------------------------------------
def plot_pca_explained_variance(pca, max_components=8):
    """
    Bar chart of `pca.explained_variance_ratio_` for the leading components
    (at most `max_components`), overlaid with their cumulative sum.
    """
    ratios = np.asarray(pca.explained_variance_ratio_)
    n_shown = min(max_components, ratios.shape[0])
    shown = ratios[:n_shown]
    pc_numbers = np.arange(1, n_shown + 1)  # 1-based component numbers

    plt.figure(figsize=(7, 4))
    plt.bar(pc_numbers, shown, label='Explained variance ratio')
    plt.plot(pc_numbers, np.cumsum(shown), marker='o', label='Cumulative')
    plt.xticks(pc_numbers)
    plt.ylim(0, 1.05)
    plt.xlabel('Principal Component')
    plt.ylabel('Variance Ratio')
    plt.title(f'Explained Variance (first {n_shown} PCs)')
    plt.legend()
    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------------
# Loadings plots (stacked per channel; abs values; prints top-K)
# ---------------------------------------------------------------------
def plot_pca_loadings_stacked(pca, feature_names=None, max_components=8,
                              weight_by_variance=True, top_k=10, start_index=1):
    """
    Plot absolute PCA loadings per input channel and rank channels by them.

    Three figures are shown: (1) stacked |loading| for every channel,
    (2) the `top_k` channels by total importance, (3) the per-PC stacked
    breakdown for those `top_k` channels. The ranking is also printed.

    Parameters
    ----------
    pca : fitted PCA-like object exposing `components_` (n_components, n_features)
        and `explained_variance_ratio_` (n_components,).
    feature_names : list of str or None
        Channel labels; when None they are generated as
        f"ch{start_index + j}".
    max_components : int
        Number of leading PCs included (capped at the number available).
    weight_by_variance : bool
        If True, each PC's |loading| row is multiplied by that PC's
        explained-variance ratio before summing.
    top_k : int
        Number of channels in the ranking plots. Clamped to
        [0, n_features] so a request larger than the channel count no longer
        fails with a shape mismatch in the bar plots.
    start_index : int
        First index used for auto-generated channel names.

    Returns
    -------
    dict with keys "importance" (per-channel score), "order_desc" (channel
    indices, most important first), "top_idx", "stack_vals" (k, n_features)
    and "evr".
    """
    comps = np.asarray(pca.components_)               # (n_components, n_features)
    evr   = np.asarray(pca.explained_variance_ratio_) # (n_components,)
    n_comp, n_feat = comps.shape
    k = min(max_components, n_comp)

    if feature_names is None:
        feature_names = [f"ch{start_index + j}" for j in range(n_feat)]

    # order[:top_k] can never be longer than n_feat; keep top_k consistent
    # with it so bar positions and heights always have matching lengths.
    top_k = max(0, min(top_k, n_feat))

    vals = np.abs(comps[:k, :])
    if weight_by_variance:
        vals = vals * evr[:k, None]

    importance = vals.sum(axis=0)
    weight_suffix = " × EVR" if weight_by_variance else ""

    # Full stacked bar plot
    x = np.arange(n_feat)
    plt.figure(figsize=(max(10, 0.5 * n_feat), 5))
    _draw_stacked_pc_bars(x, vals)
    plt.xticks(x, feature_names, rotation=90)
    plt.ylabel("Stacked |loading|" + weight_suffix)
    plt.title(f"PCA Loadings (first {k} PCs) — stacked per channel")
    plt.legend(ncol=min(4, k), fontsize=8)
    plt.tight_layout()
    plt.show()

    # Top-K channels by importance
    order = np.argsort(importance)[::-1]
    top_idx = order[:top_k]
    top_pos = np.arange(top_k)
    top_names = [feature_names[i] for i in top_idx]
    print(f"\nTop {top_k} channels by importance "
          f"({'|loading|×EVR' if weight_by_variance else '|loading|'} over first {k} PCs):")
    for rnk, i in enumerate(top_idx, 1):
        print(f"{rnk:2d}. {feature_names[i]}  —  score={importance[i]:.6f}")

    plt.figure(figsize=(max(8, 0.7 * top_k), 4))
    plt.bar(top_pos, importance[top_idx])
    plt.xticks(top_pos, top_names, rotation=45, ha='right')
    plt.ylabel("Importance (sum of |loading|" + weight_suffix + ")")
    plt.title(f"Top {top_k} Channels by PCA Loading Importance")
    plt.tight_layout()
    plt.show()

    # Stacked breakdown for top-K channels
    plt.figure(figsize=(max(8, 0.7 * top_k), 4.5))
    _draw_stacked_pc_bars(top_pos, vals[:, top_idx])
    plt.xticks(top_pos, top_names, rotation=45, ha='right')
    plt.ylabel("Stacked |loading|" + weight_suffix)
    plt.title(f"Top {top_k}: Stacked Contribution by PC")
    plt.legend(ncol=min(4, k), fontsize=8)
    plt.tight_layout()
    plt.show()

    return {
        "importance": importance,
        "order_desc": order,
        "top_idx": top_idx,
        "stack_vals": vals,
        "evr": evr,
    }


def _draw_stacked_pc_bars(positions, stack_vals):
    """Draw one stacked bar layer per row of `stack_vals` (labelled PC1, PC2, ...) on the current figure."""
    bottom = np.zeros(len(positions), dtype=float)
    for pc_idx, layer in enumerate(stack_vals):
        plt.bar(positions, layer, bottom=bottom, label=f"PC{pc_idx + 1}")
        bottom = bottom + layer

# ---------------------------------------------------------------------
# Per-sample PCA averages and plotting (one point per slide)
# ---------------------------------------------------------------------
def compute_sample_pca_averages(samples, scaler, pca, black_tolerance=5, split_tags=None):
    """
    Average each sample's principal-component scores over its tissue pixels.

    Each sample's 'grayscale_voxel' is projected with project_image_to_pcs;
    tissue pixels are those where any channel of
    'grayscale_image_cutoff_voxel' exceeds `black_tolerance`. A sample with
    no tissue pixel falls back to the mean over all of its pixels.

    Returns:
      Z_avg:   (N, D) mean PC per sample (float32 rows stacked with np.vstack)
      names:   list of sample names ('unknown' when missing)
      labels:  list of sample GT tissue_type ('unknown' when missing)
      splits:  list of split tags; all 'unknown' when `split_tags` is None
    """
    if split_tags is None:
        split_tags = ['unknown'] * len(samples)

    per_sample_means = []
    names, labels, splits = [], [], []

    for sample, tag in zip(samples, split_tags):
        voxel  = np.asarray(sample['grayscale_voxel'], np.float32)
        cutoff = np.asarray(sample['grayscale_image_cutoff_voxel'], np.uint8)
        pcs    = project_image_to_pcs(voxel, scaler, pca)

        # True wherever at least one channel is brighter than the tolerance
        tissue_px = (cutoff > black_tolerance).any(axis=-1)

        if tissue_px.any():
            mean_pc = pcs[tissue_px].mean(axis=0)
        else:
            # No tissue detected: average over every pixel instead
            mean_pc = pcs.reshape(-1, pcs.shape[-1]).mean(axis=0)

        per_sample_means.append(mean_pc.astype(np.float32))
        names.append(sample.get('name', 'unknown'))
        labels.append(sample.get('tissue_type', 'unknown'))
        splits.append(tag)

    return np.vstack(per_sample_means), names, labels, splits

def plot_sample_pca_averages(Z_avg, sample_labels, split_tags, alpha=0.9):
    """
    Scatter of per-sample mean PC1 vs PC2, one point per sample.

    Colour encodes the tissue label (unlisted labels are gray); marker shape
    encodes the split. Only the splits 'train', 'val', 'test' and 'unknown'
    are drawn.
    """
    known_labels  = ['Normal', 'Follicular', 'Papillary', 'Anaplastic']
    label_colors  = dict(zip(known_labels,
                             ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']))
    split_markers = {'train': 'o', 'val': 's', 'test': '^', 'unknown': 'x'}
    split_order   = ['train', 'val', 'test', 'unknown']

    def _label_rank(name):
        # Known labels keep their listed order; anything else sorts last
        return known_labels.index(name) if name in known_labels else 999

    plt.figure(figsize=(7, 6))
    for lbl in sorted(set(sample_labels), key=_label_rank):
        for sp in split_order:
            rows = [pos for pos, (lab, tag) in enumerate(zip(sample_labels, split_tags))
                    if lab == lbl and tag == sp]
            if len(rows) == 0:
                continue
            group = Z_avg[rows]
            plt.scatter(group[:, 0], group[:, 1],
                        s=40, alpha=alpha,
                        c=label_colors.get(lbl, 'gray'),
                        marker=split_markers.get(sp, 'x'),
                        label=f"{lbl} — {sp}")

    for draw_zero_line in (plt.axhline, plt.axvline):
        draw_zero_line(0, lw=0.5, c='k', alpha=0.2)
    plt.xlabel('PC1 (sample mean)')
    plt.ylabel('PC2 (sample mean)')
    plt.title('Per-sample PCA (mean over tissue pixels) — split-marked')

    # Collapse duplicate legend entries (last handle per label wins)
    legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
    unique_entries = dict(zip(legend_labels, legend_handles))
    plt.legend(unique_entries.values(), unique_entries.keys(),
               frameon=True, fontsize=9, ncol=2)
    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------------
# Example usage — BEFORE any filtering
# ---------------------------------------------------------------------
# 1) Pixel dataset from TRAIN
#    (the arguments below equal get_pixel_dataset_for_pca's defaults; they are
#    spelled out so the sampling settings are visible at the call site)
X_raw, y = get_pixel_dataset_for_pca(
    train_combined,
    max_pixels_per_image=12000,
    per_class_cap=150000,
    rng=123
)

# 2) Pixel-level PCA diagnostics
#    All plots reuse scaler_used / pca_used taken from px_model_all; nothing is refit here.
plot_pca_scatter_2d(X_raw, y, scaler_used, pca_used,
                    sample_cap=200000, alpha=0.25)
plot_pca_explained_variance(pca_used, max_components=8)

feature_names = None  # auto-label as ch1..chN
# The returned summary dict is not needed here, so it is discarded.
_ = plot_pca_loadings_stacked(
    pca_used,
    feature_names=feature_names,
    max_components=8,
    weight_by_variance=True,
    top_k=10,
    start_index=1
)

# 3) Per-sample PCA means across train/val/test
#    Concatenate the three splits in order, remembering each sample's split tag.
samples_for_plot = []
split_tags = []
for tag, split_samples in (('train', train_combined),
                           ('val',   val_combined),
                           ('test',  test_combined)):
    for d in split_samples:
        samples_for_plot.append(d)
        split_tags.append(tag)

Z_avg, sample_names, sample_labels, splits = compute_sample_pca_averages(
    samples_for_plot,
    scaler_used,
    pca_used,
    black_tolerance=5,
    split_tags=split_tags,
)

plot_sample_pca_averages(Z_avg, sample_labels, splits, alpha=0.9)


In [None]:
# -----------------------------------------
# Save regional categorization results to Excel
# -----------------------------------------
# One output row per categorized sample, tagged with the split it came from.
excel_path = "regional_categorization_results.xlsx"
rows = []

for split_name, cat_list in (("train", train_cat), ("val", val_cat), ("test", test_cat)):
    for r in cat_list:
        rows.append(dict(
            split=split_name,
            name=r.get("name", "unknown_sample"),
            ground_truth=r.get("gt"),
            predicted=r.get("pred"),
        ))

df_cat = pd.DataFrame(rows)
df_cat.to_excel(excel_path, index=False)
print(f"[SAVE] Wrote regional categorization results to {excel_path}")